Skip to content

Commit

Permalink
adding multi-tenancy + sdk client related changes to model, model gro…
Browse files Browse the repository at this point in the history
…up and connector update (#3399)

* adding multi-tenancy + sdk client related changes to model, model group and connector update

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* addressed comments

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* addressed more comments + refactored few codes

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

---------

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
(cherry picked from commit f63b961)
  • Loading branch information
dhrubo-os authored and github-actions[bot] committed Jan 21, 2025
1 parent eb225b9 commit 4dca0c0
Show file tree
Hide file tree
Showing 92 changed files with 5,379 additions and 1,960 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,18 @@ default ActionFuture<MLModel> getModel(String modelId) {
* @param modelId id of the model
* @param listener action listener
*/
void getModel(String modelId, ActionListener<MLModel> listener);
default void getModel(String modelId, ActionListener<MLModel> listener) {
getModel(modelId, null, listener);
}

/**
* Get MLModel and return model in listener
* For more info on get model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-model-information
* @param modelId id of the model
* @param tenantId id of the tenant
* @param listener action listener
*/
void getModel(String modelId, String tenantId, ActionListener<MLModel> listener);

/**
* Get MLTask and return ActionFuture.
Expand Down Expand Up @@ -182,7 +193,18 @@ default ActionFuture<DeleteResponse> deleteModel(String modelId) {
* @param modelId id of the model
* @param listener action listener
*/
void deleteModel(String modelId, ActionListener<DeleteResponse> listener);
default void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
deleteModel(modelId, null, listener);
}

/**
* Delete MLModel
* For more info on delete model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#delete-model
* @param modelId id of the model
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener action listener
*/
void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener);

/**
* Delete the task with taskId.
Expand Down Expand Up @@ -323,19 +345,10 @@ default ActionFuture<DeleteResponse> deleteConnector(String connectorId) {
return actionFuture;
}

/**
* Delete connector for remote model
* @param connectorId The id of the connector to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteConnector(String connectorId, String tenantId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteConnector(connectorId, tenantId, actionFuture);
return actionFuture;
default void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
deleteConnector(connectorId, null, listener);
}

void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener);

void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
}

@Override
public void getModel(String modelId, ActionListener<MLModel> listener) {
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build();
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).tenantId(tenantId).build();

client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener));
}
Expand All @@ -178,8 +178,8 @@ private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(A
}

@Override
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build();
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).tenantId(tenantId).build();

client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure));
}
Expand Down Expand Up @@ -259,17 +259,6 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener));
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
client
.execute(
MLConnectorDeleteAction.INSTANCE,
connectorDeleteRequest,
ActionListener.wrap(listener::onResponse, listener::onFailure)
);
}

@Override
public void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId, tenantId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ public class MachineLearningClientTest {
@Mock
ActionListener<MLOutput> dataFrameActionListener;

@Mock
ActionListener<MLModel> mlModelActionListener;

@Mock
DeleteResponse deleteResponse;

Expand Down Expand Up @@ -166,11 +169,21 @@ public void getModel(String modelId, ActionListener<MLModel> listener) {
listener.onResponse(mlModel);
}

@Override
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
listener.onResponse(mlModel);
}

@Override
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void searchModel(SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
listener.onResponse(searchResponse);
Expand Down Expand Up @@ -352,6 +365,22 @@ public void getModel() {
assertEquals(mlModel, machineLearningClient.getModel("modelId").actionGet());
}

@Test
public void getModelActionListener() {
ArgumentCaptor<MLModel> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLModel.class);
machineLearningClient.getModel("modelId", mlModelActionListener);
verify(mlModelActionListener).onResponse(dataFrameArgumentCaptor.capture());
assertEquals(mlModel, dataFrameArgumentCaptor.getValue());
assertEquals(mlModel.getTenantId(), dataFrameArgumentCaptor.getValue().getTenantId());
}

@Test
public void undeploy_WithSpecificNodes() {
String[] modelIds = new String[] { "model1", "model2" };
String[] nodeIds = new String[] { "node1", "node2" };
assertEquals(undeployModelsResponse, machineLearningClient.undeploy(modelIds, nodeIds).actionGet());
}

@Test
public void deleteModel() {
assertEquals(deleteResponse, machineLearningClient.deleteModel("modelId").actionGet());
Expand All @@ -362,6 +391,11 @@ public void searchModel() {
assertEquals(searchResponse, machineLearningClient.searchModel(new SearchRequest()).actionGet());
}

@Test
public void deleteConnector_WithTenantId() {
assertEquals(deleteResponse, machineLearningClient.deleteConnector("connectorId").actionGet());
}

@Test
public void registerModelGroup() {
List<String> backendRoles = Arrays.asList("IT", "HR");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand Down Expand Up @@ -325,6 +326,64 @@ public void train() {
assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus());
}

@Test
public void getModel_withTenantId() {
String modelContent = "test content";
String tenantId = "tenantId";
doAnswer(invocation -> {
ActionListener<MLModelGetResponse> actionListener = invocation.getArgument(2);
MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build();
MLModelGetResponse output = MLModelGetResponse.builder().mlModel(mlModel).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLModelGetAction.INSTANCE), any(), any());

ArgumentCaptor<MLModel> argumentCaptor = ArgumentCaptor.forClass(MLModel.class);
machineLearningNodeClient.getModel("modelId", tenantId, getModelActionListener);

verify(client).execute(eq(MLModelGetAction.INSTANCE), isA(MLModelGetRequest.class), any());
verify(getModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(FunctionName.KMEANS, argumentCaptor.getValue().getAlgorithm());
assertEquals(modelContent, argumentCaptor.getValue().getContent());
}

@Test
public void undeployModels_withNullNodeIds() {
doAnswer(invocation -> {
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2);
MLUndeployModelsResponse output = new MLUndeployModelsResponse(
new MLUndeployModelNodesResponse(ClusterName.DEFAULT, Collections.emptyList(), Collections.emptyList())
);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLUndeployModelsAction.INSTANCE), any(), any());

machineLearningNodeClient.undeploy(new String[] { "model1" }, null, undeployModelsActionListener);
verify(client).execute(eq(MLUndeployModelsAction.INSTANCE), isA(MLUndeployModelsRequest.class), any());
}

@Test
public void createConnector_withValidInput() {
doAnswer(invocation -> {
ActionListener<MLCreateConnectorResponse> actionListener = invocation.getArgument(2);
MLCreateConnectorResponse output = new MLCreateConnectorResponse("connectorId");
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), any());

MLCreateConnectorInput input = MLCreateConnectorInput
.builder()
.name("testConnector")
.protocol("http")
.version("1")
.credential(Map.of("TEST_CREDENTIAL_KEY", "TEST_CREDENTIAL_VALUE"))
.parameters(Map.of("endpoint", "https://example.com"))
.build();

machineLearningNodeClient.createConnector(input, createConnectorActionListener);
verify(client).execute(eq(MLCreateConnectorAction.INSTANCE), isA(MLCreateConnectorRequest.class), any());
}

@Test
public void registerModelGroup_withValidInput() {
doAnswer(invocation -> {
Expand All @@ -346,6 +405,146 @@ public void registerModelGroup_withValidInput() {
verify(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), isA(MLRegisterModelGroupRequest.class), any());
}

@Test
public void listTools_withValidRequest() {
doAnswer(invocation -> {
ActionListener<MLToolsListResponse> actionListener = invocation.getArgument(2);
MLToolsListResponse output = MLToolsListResponse
.builder()
.toolMetadata(
Arrays
.asList(
ToolMetadata.builder().name("tool1").description("description1").build(),
ToolMetadata.builder().name("tool2").description("description2").build()
)
)
.build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any());

machineLearningNodeClient.listTools(listToolsActionListener);
verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any());
}

@Test
public void listTools_withEmptyResponse() {
doAnswer(invocation -> {
ActionListener<MLToolsListResponse> actionListener = invocation.getArgument(2);
MLToolsListResponse output = MLToolsListResponse.builder().toolMetadata(Collections.emptyList()).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any());

ArgumentCaptor<List<ToolMetadata>> argumentCaptor = ArgumentCaptor.forClass(List.class);
machineLearningNodeClient.listTools(listToolsActionListener);

verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any());
verify(listToolsActionListener).onResponse(argumentCaptor.capture());

List<ToolMetadata> capturedTools = argumentCaptor.getValue();
assertTrue(capturedTools.isEmpty());
}

@Test
public void getTool_withValidToolName() {
doAnswer(invocation -> {
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
MLToolGetResponse output = MLToolGetResponse
.builder()
.toolMetadata(ToolMetadata.builder().name("tool1").description("description1").build())
.build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());

machineLearningNodeClient.getTool("tool1", getToolActionListener);
verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
}

@Test
public void getTool_withValidRequest() {
ToolMetadata toolMetadata = ToolMetadata
.builder()
.name("MathTool")
.description("Use this tool to calculate any math problem.")
.build();

doAnswer(invocation -> {
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
MLToolGetResponse output = MLToolGetResponse.builder().toolMetadata(toolMetadata).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());

ArgumentCaptor<ToolMetadata> argumentCaptor = ArgumentCaptor.forClass(ToolMetadata.class);
machineLearningNodeClient.getTool("MathTool", getToolActionListener);

verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
verify(getToolActionListener).onResponse(argumentCaptor.capture());

ToolMetadata capturedTool = argumentCaptor.getValue();
assertEquals("MathTool", capturedTool.getName());
assertEquals("Use this tool to calculate any math problem.", capturedTool.getDescription());
}

@Test
public void getTool_withFailureResponse() {
doAnswer(invocation -> {
ActionListener<MLToolGetResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException("Test exception"));
return null;
}).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any());

machineLearningNodeClient.getTool("MathTool", new ActionListener<>() {
@Override
public void onResponse(ToolMetadata toolMetadata) {
fail("Expected failure but got response");
}

@Override
public void onFailure(Exception e) {
assertEquals("Test exception", e.getMessage());
}
});

verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any());
}

@Test
public void train_withAsync() {
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
MLTrainingOutput output = MLTrainingOutput.builder().status("InProgress").modelId("modelId").build();
actionListener.onResponse(MLTaskResponse.builder().output(output).build());
return null;
}).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any());

MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
machineLearningNodeClient.train(mlInput, true, trainingActionListener);
verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any());
}

@Test
public void deleteModel_withTenantId() {
String modelId = "testModelId";
String tenantId = "tenantId";
doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, modelId, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLModelDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
machineLearningNodeClient.deleteModel(modelId, tenantId, deleteModelActionListener);

verify(client).execute(eq(MLModelDeleteAction.INSTANCE), isA(MLModelDeleteRequest.class), any());
verify(deleteModelActionListener).onResponse(argumentCaptor.capture());
assertEquals(modelId, argumentCaptor.getValue().getId());
}

@Test
public void train_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down
Loading

0 comments on commit 4dca0c0

Please sign in to comment.