From 4dca0c046be63f20e81aad2c38303e2ea448a190 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Tue, 21 Jan 2025 14:46:06 -0800 Subject: [PATCH] adding multi-tenancy + sdk client related changes to model, model group and connector update (#3399) * adding multi-tenancy + sdk client related changes to model, model group and connector update Signed-off-by: Dhrubo Saha * addressed comments Signed-off-by: Dhrubo Saha * addressed more comments + refactored few codes Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha (cherry picked from commit f63b961bcdadefba237180f3110027bcfe7380d3) --- .../ml/client/MachineLearningClient.java | 39 +- .../ml/client/MachineLearningNodeClient.java | 19 +- .../ml/client/MachineLearningClientTest.java | 34 + .../client/MachineLearningNodeClientTest.java | 199 +++++ .../org/opensearch/ml/common/MLModel.java | 24 +- .../opensearch/ml/common/MLModelGroup.java | 25 +- .../ml/common/connector/HttpConnector.java | 2 +- .../connector/MLConnectorGetResponse.java | 6 +- .../connector/MLCreateConnectorInput.java | 2 +- .../connector/MLUpdateConnectorRequest.java | 4 +- .../transport/model/MLModelDeleteRequest.java | 18 +- .../transport/model/MLModelGetRequest.java | 13 +- .../transport/model/MLUpdateModelInput.java | 24 +- .../MLModelGroupDeleteRequest.java | 13 +- .../model_group/MLModelGroupGetRequest.java | 12 +- .../MLRegisterModelGroupInput.java | 26 +- .../model_group/MLUpdateModelGroupInput.java | 24 +- .../register/MLRegisterModelInput.java | 28 +- .../ml/common/MLModelGroupTest.java | 91 ++ .../opensearch/ml/common/MLModelTests.java | 42 + .../common/connector/HttpConnectorTest.java | 82 ++ .../MLUpdateConnectorRequestTests.java | 53 +- .../model/MLModelDeleteRequestTest.java | 106 +++ .../model/MLModelGetRequestTest.java | 79 ++ .../model/MLUpdateModelInputTest.java | 60 +- .../MLModelGroupDeleteRequestTest.java | 62 +- .../MLModelGroupGetRequestTest.java | 44 + .../MLRegisterModelGroupInputTest.java | 103 +++ .../MLUpdateModelGroupInputTest.java | 93 +- .../register/MLRegisterModelInputTest.java | 48 + .../MetricsCorrelation.java | 2 +- .../ml/engine/indices/MLIndicesHandler.java | 10 +- .../DeleteConnectorTransportAction.java | 85 +- .../TransportCreateConnectorAction.java | 4 +- .../UpdateConnectorTransportAction.java | 211 +++-- .../DeleteModelGroupTransportAction.java | 217 +++-- .../GetModelGroupTransportAction.java | 185 ++-- .../TransportRegisterModelGroupAction.java | 14 +- .../TransportUpdateModelGroupAction.java | 207 +++-- .../models/DeleteModelTransportAction.java | 247 ++++-- .../models/GetModelTransportAction.java | 161 ++-- .../models/UpdateModelTransportAction.java | 374 ++++---- .../TransportRegisterModelAction.java | 193 +++-- .../TransportRegisterModelMetaAction.java | 4 +- .../helper/ConnectorAccessControlHelper.java | 84 +- .../ml/helper/ModelAccessControlHelper.java | 137 ++- .../ml/model/MLModelGroupManager.java | 206 +++-- .../opensearch/ml/model/MLModelManager.java | 819 ++++++++++++++---- .../ml/plugin/MachineLearningPlugin.java | 37 +- .../ml/rest/RestMLDeleteModelAction.java | 12 +- .../ml/rest/RestMLDeleteModelGroupAction.java | 11 +- .../ml/rest/RestMLGetModelAction.java | 12 +- .../ml/rest/RestMLGetModelGroupAction.java | 11 +- .../ml/rest/RestMLRegisterModelAction.java | 3 + .../rest/RestMLRegisterModelGroupAction.java | 9 +- .../ml/rest/RestMLSearchModelGroupAction.java | 6 +- .../ml/rest/RestMLUpdateConnectorAction.java | 4 +- .../ml/rest/RestMLUpdateModelAction.java | 9 + .../ml/rest/RestMLUpdateModelGroupAction.java | 13 + .../plugin-metadata/plugin-security.policy | 3 + .../ml/action/MLCommonsIntegTestCase.java | 4 +- .../DeleteConnectorTransportActionTests.java | 232 ++--- .../GetConnectorTransportActionTests.java | 55 +- .../TransportCreateConnectorActionTests.java | 114 +-- .../UpdateConnectorTransportActionTests.java | 114 +-- .../DeleteModelGroupTransportActionTests.java | 191 +++- .../GetModelGroupTransportActionTests.java | 23 +- .../RegisterModelGroupITTests.java | 19 +- .../model_group/SearchModelGroupITTests.java | 9 +- ...ransportRegisterModelGroupActionTests.java | 35 +- .../TransportUpdateModelGroupActionTests.java | 29 +- .../model_group/UpdateModelGroupITTests.java | 20 +- .../DeleteModelTransportActionTests.java | 161 ++-- .../ml/action/models/GetModelITTests.java | 4 +- .../models/GetModelTransportActionTests.java | 67 +- .../ml/action/models/SearchModelITTests.java | 9 +- .../UpdateModelTransportActionTests.java | 693 ++++++++++----- .../TransportRegisterModelActionTests.java | 92 +- ...TransportRegisterModelMetaActionTests.java | 28 +- .../ConnectorAccessControlHelperTests.java | 79 +- .../helper/ModelAccessControlHelperTests.java | 186 +++- .../ml/model/MLModelGroupManagerTests.java | 111 ++- .../ml/model/MLModelManagerTests.java | 79 +- .../ml/rest/RestMLDeleteModelActionTests.java | 70 +- .../RestMLDeleteModelGroupActionTests.java | 56 +- .../ml/rest/RestMLGetModelActionTests.java | 52 +- .../ml/rest/RestMLGetModelGroupActionIT.java | 2 +- .../rest/RestMLGetModelGroupActionTests.java | 51 +- .../RestMLRegisterModelGroupActionTests.java | 29 +- .../RestMLSearchModelGroupActionTests.java | 8 +- .../ml/rest/RestMLUpdateModelActionTests.java | 31 +- .../RestMLUpdateModelGroupActionTests.java | 21 +- 92 files changed, 5379 insertions(+), 1960 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 8d98f86b6a..e50c7dc609 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -142,7 +142,18 @@ default ActionFuture getModel(String modelId) { * @param modelId id of the model * @param listener action listener */ - void getModel(String modelId, ActionListener listener); + default void getModel(String modelId, ActionListener listener) { + getModel(modelId, null, listener); + } + + /** + * Get MLModel and return model in listener + * For more info on get model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#get-model-information + * @param modelId id of the model + * @param tenantId id of the tenant + * @param listener action listener + */ + void getModel(String modelId, String tenantId, ActionListener listener); /** * Get MLTask and return ActionFuture. @@ -182,7 +193,18 @@ default ActionFuture deleteModel(String modelId) { * @param modelId id of the model * @param listener action listener */ - void deleteModel(String modelId, ActionListener listener); + default void deleteModel(String modelId, ActionListener listener) { + deleteModel(modelId, null, listener); + } + + /** + * Delete MLModel + * For more info on delete model, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#delete-model + * @param modelId id of the model + * @param tenantId the tenant id. This is necessary for multi-tenancy. + * @param listener action listener + */ + void deleteModel(String modelId, String tenantId, ActionListener listener); /** * Delete the task with taskId. @@ -323,19 +345,10 @@ default ActionFuture deleteConnector(String connectorId) { return actionFuture; } - /** - * Delete connector for remote model - * @param connectorId The id of the connector to delete - * @return the result future - */ - default ActionFuture deleteConnector(String connectorId, String tenantId) { - PlainActionFuture actionFuture = PlainActionFuture.newFuture(); - deleteConnector(connectorId, tenantId, actionFuture); - return actionFuture; + default void deleteConnector(String connectorId, ActionListener listener) { + deleteConnector(connectorId, null, listener); } - void deleteConnector(String connectorId, ActionListener listener); - void deleteConnector(String connectorId, String tenantId, ActionListener listener); /** diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 288a9f2e3a..56bc5e1425 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -164,8 +164,8 @@ public void run(MLInput mlInput, Map args, ActionListener listener) { - MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); + public void getModel(String modelId, String tenantId, ActionListener listener) { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).tenantId(tenantId).build(); client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener)); } @@ -178,8 +178,8 @@ private ActionListener getMlGetModelResponseActionListener(A } @Override - public void deleteModel(String modelId, ActionListener listener) { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); + public void deleteModel(String modelId, String tenantId, ActionListener listener) { + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).tenantId(tenantId).build(); client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, ActionListener.wrap(listener::onResponse, listener::onFailure)); } @@ -259,17 +259,6 @@ public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, Actio client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener)); } - @Override - public void deleteConnector(String connectorId, ActionListener listener) { - MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId); - client - .execute( - MLConnectorDeleteAction.INSTANCE, - connectorDeleteRequest, - ActionListener.wrap(listener::onResponse, listener::onFailure) - ); - } - @Override public void deleteConnector(String connectorId, String tenantId, ActionListener listener) { MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId, tenantId); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index b0bdb80db8..01b51bd9f0 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -76,6 +76,9 @@ public class MachineLearningClientTest { @Mock ActionListener dataFrameActionListener; + @Mock + ActionListener mlModelActionListener; + @Mock DeleteResponse deleteResponse; @@ -166,11 +169,21 @@ public void getModel(String modelId, ActionListener listener) { listener.onResponse(mlModel); } + @Override + public void getModel(String modelId, String tenantId, ActionListener listener) { + listener.onResponse(mlModel); + } + @Override public void deleteModel(String modelId, ActionListener listener) { listener.onResponse(deleteResponse); } + @Override + public void deleteModel(String modelId, String tenantId, ActionListener listener) { + listener.onResponse(deleteResponse); + } + @Override public void searchModel(SearchRequest searchRequest, ActionListener listener) { listener.onResponse(searchResponse); @@ -352,6 +365,22 @@ public void getModel() { assertEquals(mlModel, machineLearningClient.getModel("modelId").actionGet()); } + @Test + public void getModelActionListener() { + ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLModel.class); + machineLearningClient.getModel("modelId", mlModelActionListener); + verify(mlModelActionListener).onResponse(dataFrameArgumentCaptor.capture()); + assertEquals(mlModel, dataFrameArgumentCaptor.getValue()); + assertEquals(mlModel.getTenantId(), dataFrameArgumentCaptor.getValue().getTenantId()); + } + + @Test + public void undeploy_WithSpecificNodes() { + String[] modelIds = new String[] { "model1", "model2" }; + String[] nodeIds = new String[] { "node1", "node2" }; + assertEquals(undeployModelsResponse, machineLearningClient.undeploy(modelIds, nodeIds).actionGet()); + } + @Test public void deleteModel() { assertEquals(deleteResponse, machineLearningClient.deleteModel("modelId").actionGet()); @@ -362,6 +391,11 @@ public void searchModel() { assertEquals(searchResponse, machineLearningClient.searchModel(new SearchRequest()).actionGet()); } + @Test + public void deleteConnector_WithTenantId() { + assertEquals(deleteResponse, machineLearningClient.deleteConnector("connectorId").actionGet()); + } + @Test public void registerModelGroup() { List backendRoles = Arrays.asList("IT", "HR"); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 0f4904e20c..7d77d2132d 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Answers.RETURNS_DEEP_STUBS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -325,6 +326,64 @@ public void train() { assertEquals(status, ((MLTrainingOutput) argumentCaptor.getValue()).getStatus()); } + @Test + public void getModel_withTenantId() { + String modelContent = "test content"; + String tenantId = "tenantId"; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("test").content(modelContent).build(); + MLModelGetResponse output = MLModelGetResponse.builder().mlModel(mlModel).build(); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLModelGetAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModel.class); + machineLearningNodeClient.getModel("modelId", tenantId, getModelActionListener); + + verify(client).execute(eq(MLModelGetAction.INSTANCE), isA(MLModelGetRequest.class), any()); + verify(getModelActionListener).onResponse(argumentCaptor.capture()); + assertEquals(FunctionName.KMEANS, argumentCaptor.getValue().getAlgorithm()); + assertEquals(modelContent, argumentCaptor.getValue().getContent()); + } + + @Test + public void undeployModels_withNullNodeIds() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLUndeployModelsResponse output = new MLUndeployModelsResponse( + new MLUndeployModelNodesResponse(ClusterName.DEFAULT, Collections.emptyList(), Collections.emptyList()) + ); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLUndeployModelsAction.INSTANCE), any(), any()); + + machineLearningNodeClient.undeploy(new String[] { "model1" }, null, undeployModelsActionListener); + verify(client).execute(eq(MLUndeployModelsAction.INSTANCE), isA(MLUndeployModelsRequest.class), any()); + } + + @Test + public void createConnector_withValidInput() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLCreateConnectorResponse output = new MLCreateConnectorResponse("connectorId"); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLCreateConnectorAction.INSTANCE), any(), any()); + + MLCreateConnectorInput input = MLCreateConnectorInput + .builder() + .name("testConnector") + .protocol("http") + .version("1") + .credential(Map.of("TEST_CREDENTIAL_KEY", "TEST_CREDENTIAL_VALUE")) + .parameters(Map.of("endpoint", "https://example.com")) + .build(); + + machineLearningNodeClient.createConnector(input, createConnectorActionListener); + verify(client).execute(eq(MLCreateConnectorAction.INSTANCE), isA(MLCreateConnectorRequest.class), any()); + } + @Test public void registerModelGroup_withValidInput() { doAnswer(invocation -> { @@ -346,6 +405,146 @@ public void registerModelGroup_withValidInput() { verify(client).execute(eq(MLRegisterModelGroupAction.INSTANCE), isA(MLRegisterModelGroupRequest.class), any()); } + @Test + public void listTools_withValidRequest() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLToolsListResponse output = MLToolsListResponse + .builder() + .toolMetadata( + Arrays + .asList( + ToolMetadata.builder().name("tool1").description("description1").build(), + ToolMetadata.builder().name("tool2").description("description2").build() + ) + ) + .build(); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any()); + + machineLearningNodeClient.listTools(listToolsActionListener); + verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any()); + } + + @Test + public void listTools_withEmptyResponse() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLToolsListResponse output = MLToolsListResponse.builder().toolMetadata(Collections.emptyList()).build(); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLListToolsAction.INSTANCE), any(), any()); + + ArgumentCaptor> argumentCaptor = ArgumentCaptor.forClass(List.class); + machineLearningNodeClient.listTools(listToolsActionListener); + + verify(client).execute(eq(MLListToolsAction.INSTANCE), isA(MLToolsListRequest.class), any()); + verify(listToolsActionListener).onResponse(argumentCaptor.capture()); + + List capturedTools = argumentCaptor.getValue(); + assertTrue(capturedTools.isEmpty()); + } + + @Test + public void getTool_withValidToolName() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLToolGetResponse output = MLToolGetResponse + .builder() + .toolMetadata(ToolMetadata.builder().name("tool1").description("description1").build()) + .build(); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any()); + + machineLearningNodeClient.getTool("tool1", getToolActionListener); + verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any()); + } + + @Test + public void getTool_withValidRequest() { + ToolMetadata toolMetadata = ToolMetadata + .builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLToolGetResponse output = MLToolGetResponse.builder().toolMetadata(toolMetadata).build(); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(ToolMetadata.class); + machineLearningNodeClient.getTool("MathTool", getToolActionListener); + + verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any()); + verify(getToolActionListener).onResponse(argumentCaptor.capture()); + + ToolMetadata capturedTool = argumentCaptor.getValue(); + assertEquals("MathTool", capturedTool.getName()); + assertEquals("Use this tool to calculate any math problem.", capturedTool.getDescription()); + } + + @Test + public void getTool_withFailureResponse() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Test exception")); + return null; + }).when(client).execute(eq(MLGetToolAction.INSTANCE), any(), any()); + + machineLearningNodeClient.getTool("MathTool", new ActionListener<>() { + @Override + public void onResponse(ToolMetadata toolMetadata) { + fail("Expected failure but got response"); + } + + @Override + public void onFailure(Exception e) { + assertEquals("Test exception", e.getMessage()); + } + }); + + verify(client).execute(eq(MLGetToolAction.INSTANCE), isA(MLToolGetRequest.class), any()); + } + + @Test + public void train_withAsync() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLTrainingOutput output = MLTrainingOutput.builder().status("InProgress").modelId("modelId").build(); + actionListener.onResponse(MLTaskResponse.builder().output(output).build()); + return null; + }).when(client).execute(eq(MLTrainingTaskAction.INSTANCE), any(), any()); + + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build(); + machineLearningNodeClient.train(mlInput, true, trainingActionListener); + verify(client).execute(eq(MLTrainingTaskAction.INSTANCE), isA(MLTrainingTaskRequest.class), any()); + } + + @Test + public void deleteModel_withTenantId() { + String modelId = "testModelId"; + String tenantId = "tenantId"; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, modelId, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLModelDeleteAction.INSTANCE), any(), any()); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + machineLearningNodeClient.deleteModel(modelId, tenantId, deleteModelActionListener); + + verify(client).execute(eq(MLModelDeleteAction.INSTANCE), isA(MLModelDeleteRequest.class), any()); + verify(deleteModelActionListener).onResponse(argumentCaptor.capture()); + assertEquals(modelId, argumentCaptor.getValue().getId()); + } + @Test public void train_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index a742f542c5..d6d0e03fdd 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -6,7 +6,9 @@ package org.opensearch.ml.common; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.CommonValue.USER; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import static org.opensearch.ml.common.connector.Connector.createConnector; import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @@ -20,6 +22,7 @@ import java.util.Map; import java.util.Set; +import org.opensearch.Version; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -142,6 +145,7 @@ public class MLModel implements ToXContentObject { private Connector connector; private String connectorId; private Guardrails guardrails; + private String tenantId; /** * Model interface is a map that contains the input and output fields of the model, with JSON schema as the value. @@ -208,7 +212,8 @@ public MLModel( Connector connector, String connectorId, Guardrails guardrails, - Map modelInterface + Map modelInterface, + String tenantId ) { this.name = name; this.modelGroupId = modelGroupId; @@ -244,9 +249,11 @@ public MLModel( this.connectorId = connectorId; this.guardrails = guardrails; this.modelInterface = modelInterface; + this.tenantId = tenantId; } public MLModel(StreamInput input) throws IOException { + Version streamInputVersion = input.getVersion(); name = input.readOptionalString(); algorithm = input.readEnum(FunctionName.class); version = input.readString(); @@ -308,10 +315,14 @@ public MLModel(StreamInput input) throws IOException { if (input.readBoolean()) { modelInterface = input.readMap(StreamInput::readString, StreamInput::readString); } + if (streamInputVersion.onOrAfter(VERSION_2_19_0)) { + tenantId = input.readOptionalString(); + } } } public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeOptionalString(name); out.writeEnum(algorithm); out.writeString(version); @@ -391,6 +402,9 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override @@ -498,6 +512,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (modelInterface != null) { builder.field(INTERFACE_FIELD, modelInterface); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @@ -543,6 +560,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws String connectorId = null; Guardrails guardrails = null; Map modelInterface = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -677,6 +695,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case INTERFACE_FIELD: modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); + break; default: parser.skipChildren(); break; @@ -718,6 +739,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .connectorId(connectorId) .guardrails(guardrails) .modelInterface(modelInterface) + .tenantId(tenantId) .build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 91b21131d4..498f0127a8 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.time.Instant; @@ -13,6 +15,7 @@ import java.util.List; import java.util.Objects; +import org.opensearch.Version; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -51,6 +54,7 @@ public class MLModelGroup implements ToXContentObject { private Instant createdTime; private Instant lastUpdatedTime; + private String tenantId; @Builder(toBuilder = true) public MLModelGroup( @@ -62,7 +66,8 @@ public MLModelGroup( String access, String modelGroupId, Instant createdTime, - Instant lastUpdatedTime + Instant lastUpdatedTime, + String tenantId ) { this.name = Objects.requireNonNull(name, "model group name must not be null"); this.description = description; @@ -73,9 +78,11 @@ public MLModelGroup( this.modelGroupId = modelGroupId; this.createdTime = createdTime; this.lastUpdatedTime = lastUpdatedTime; + this.tenantId = tenantId; } public MLModelGroup(StreamInput input) throws IOException { + Version streamInputVersion = input.getVersion(); name = input.readString(); description = input.readOptionalString(); latestVersion = input.readInt(); @@ -91,9 +98,11 @@ public MLModelGroup(StreamInput input) throws IOException { modelGroupId = input.readOptionalString(); createdTime = input.readOptionalInstant(); lastUpdatedTime = input.readOptionalInstant(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; } public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeString(name); out.writeOptionalString(description); out.writeInt(latestVersion); @@ -113,6 +122,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(modelGroupId); out.writeOptionalInstant(createdTime); out.writeOptionalInstant(lastUpdatedTime); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override @@ -141,6 +153,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (lastUpdatedTime != null) { builder.field(LAST_UPDATED_TIME_FIELD, lastUpdatedTime.toEpochMilli()); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @@ -155,6 +170,7 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { String modelGroupId = null; Instant createdTime = null; Instant lastUpdateTime = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -193,6 +209,9 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { case LAST_UPDATED_TIME_FIELD: lastUpdateTime = Instant.ofEpochMilli(parser.longValue()); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); + break; default: parser.skipChildren(); break; @@ -209,11 +228,11 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { .modelGroupId(modelGroupId) .createdTime(createdTime) .lastUpdatedTime(lastUpdateTime) + .tenantId(tenantId) .build(); } public static MLModelGroup fromStream(StreamInput in) throws IOException { - MLModelGroup mlModel = new MLModelGroup(in); - return mlModel; + return new MLModelGroup(in); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index c7d23bc8f5..74f63f2260 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -144,7 +144,7 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException connectorClientConfig = ConnectorClientConfig.parse(parser); break; case TENANT_ID_FIELD: - tenantId = parser.text(); + tenantId = parser.textOrNull(); break; default: parser.skipChildren(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java index 3265db4ff2..680558a8a8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java @@ -21,7 +21,9 @@ import org.opensearch.ml.common.connector.Connector; import lombok.Builder; +import lombok.Getter; +@Getter public class MLConnectorGetResponse extends ActionResponse implements ToXContentObject { Connector mlConnector; @@ -35,10 +37,6 @@ public MLConnectorGetResponse(StreamInput in) throws IOException { mlConnector = Connector.fromStream(in); } - public Connector getMlConnector() { - return mlConnector; - } - @Override public void writeTo(StreamOutput out) throws IOException { mlConnector.writeTo(out); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 99dc51ab99..d6442d70c7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -190,7 +190,7 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update connectorClientConfig = ConnectorClientConfig.parse(parser); break; case TENANT_ID_FIELD: - tenantId = parser.text(); + tenantId = parser.textOrNull(); break; default: parser.skipChildren(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java index 8a365140de..d496004db5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java @@ -62,9 +62,9 @@ public ActionRequestValidationException validate() { return exception; } - public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException { + public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId, String tenantId) throws IOException { MLCreateConnectorInput updateContent = MLCreateConnectorInput.parse(parser, true); - + updateContent.setTenantId(tenantId); return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java index 4c57c5912c..ea7019788b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java @@ -6,12 +6,14 @@ package org.opensearch.ml.common.transport.model; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -26,20 +28,34 @@ public class MLModelDeleteRequest extends ActionRequest { @Getter String modelId; + @Getter + String tenantId; + @Builder - public MLModelDeleteRequest(String modelId) { + public MLModelDeleteRequest(String modelId, String tenantId) { this.modelId = modelId; + this.tenantId = tenantId; } public MLModelDeleteRequest(StreamInput input) throws IOException { super(input); + Version streamInputVersion = input.getVersion(); this.modelId = input.readString(); + if (streamInputVersion.onOrAfter(VERSION_2_19_0)) { + this.tenantId = input.readOptionalString(); + } else { + this.tenantId = null; + } } @Override public void writeTo(StreamOutput output) throws IOException { super.writeTo(output); + Version streamOutputVersion = output.getVersion(); output.writeString(modelId); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + output.writeOptionalString(tenantId); + } } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java index 8a6e4c2c1e..c558b07a91 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java @@ -6,12 +6,14 @@ package org.opensearch.ml.common.transport.model; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -36,27 +38,36 @@ public class MLModelGetRequest extends ActionRequest { // delete/update options, we also perform get operation. This field is to distinguish between // these two situations. boolean isUserInitiatedGetRequest; + String tenantId; @Builder - public MLModelGetRequest(String modelId, boolean returnContent, boolean isUserInitiatedGetRequest) { + public MLModelGetRequest(String modelId, boolean returnContent, boolean isUserInitiatedGetRequest, String tenantId) { this.modelId = modelId; this.returnContent = returnContent; this.isUserInitiatedGetRequest = isUserInitiatedGetRequest; + this.tenantId = tenantId; } public MLModelGetRequest(StreamInput in) throws IOException { super(in); + Version streamInputVersion = in.getVersion(); this.modelId = in.readString(); this.returnContent = in.readBoolean(); this.isUserInitiatedGetRequest = in.readBoolean(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; + } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + Version streamOutputVersion = out.getVersion(); out.writeString(this.modelId); out.writeBoolean(returnContent); out.writeBoolean(isUserInitiatedGetRequest); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index 81dae3a903..4dc54bb23c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common.transport.model; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @@ -70,6 +72,7 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private MLCreateConnectorInput connector; private Instant lastUpdateTime; private Guardrails guardrails; + private String tenantId; private Map modelInterface; @@ -89,7 +92,8 @@ public MLUpdateModelInput( MLCreateConnectorInput connector, Instant lastUpdateTime, Guardrails guardrails, - Map modelInterface + Map modelInterface, + String tenantId ) { this.modelId = modelId; this.description = description; @@ -106,6 +110,7 @@ public MLUpdateModelInput( this.lastUpdateTime = lastUpdateTime; this.guardrails = guardrails; this.modelInterface = modelInterface; + this.tenantId = tenantId; } public MLUpdateModelInput(StreamInput in) throws IOException { @@ -143,6 +148,7 @@ public MLUpdateModelInput(StreamInput in) throws IOException { modelInterface = in.readMap(StreamInput::readString, StreamInput::readString); } } + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; } @Override @@ -190,6 +196,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelInterface != null) { builder.field(MLModel.INTERFACE_FIELD, modelInterface); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @@ -251,6 +260,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } public static MLUpdateModelInput parse(XContentParser parser) throws IOException { @@ -269,12 +281,16 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException Instant lastUpdateTime = null; Guardrails guardrails = null; Map modelInterface = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); parser.nextToken(); switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; case DESCRIPTION_FIELD: description = parser.text(); break; @@ -308,6 +324,9 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException case MLModel.INTERFACE_FIELD: modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); + break; default: parser.skipChildren(); break; @@ -330,7 +349,8 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException connector, lastUpdateTime, guardrails, - modelInterface + modelInterface, + tenantId ); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java index 86a1d093ee..cd2be26209 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java @@ -6,12 +6,14 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -25,21 +27,30 @@ public class MLModelGroupDeleteRequest extends ActionRequest { @Getter String modelGroupId; + @Getter + String tenantId; @Builder - public MLModelGroupDeleteRequest(String modelGroupId) { + public MLModelGroupDeleteRequest(String modelGroupId, String tenantId) { this.modelGroupId = modelGroupId; + this.tenantId = tenantId; } public MLModelGroupDeleteRequest(StreamInput input) throws IOException { super(input); + Version streamInputVersion = input.getVersion(); this.modelGroupId = input.readString(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; } @Override public void writeTo(StreamOutput output) throws IOException { super.writeTo(output); + Version streamOutputVersion = output.getVersion(); output.writeString(modelGroupId); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + output.writeOptionalString(tenantId); + } } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java index a3a3d0fa57..cc9dbcd444 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java @@ -6,12 +6,14 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -31,21 +33,29 @@ public class MLModelGroupGetRequest extends ActionRequest { String modelGroupId; + String tenantId; @Builder - public MLModelGroupGetRequest(String modelGroupId) { + public MLModelGroupGetRequest(String modelGroupId, String tenantId) { this.modelGroupId = modelGroupId; + this.tenantId = tenantId; } public MLModelGroupGetRequest(StreamInput in) throws IOException { super(in); + Version streamInputVersion = in.getVersion(); this.modelGroupId = in.readString(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + Version streamOutputVersion = out.getVersion(); out.writeString(this.modelGroupId); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java index 8f4162f11f..c75a0e7c37 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.util.ArrayList; @@ -13,6 +15,7 @@ import java.util.Locale; import java.util.Objects; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -38,6 +41,7 @@ public class MLRegisterModelGroupInput implements ToXContentObject, Writeable { private List backendRoles; private AccessMode modelAccessMode; private Boolean isAddAllBackendRoles; + private String tenantId; @Builder(toBuilder = true) public MLRegisterModelGroupInput( @@ -45,16 +49,19 @@ public MLRegisterModelGroupInput( String description, List backendRoles, AccessMode modelAccessMode, - Boolean isAddAllBackendRoles + Boolean isAddAllBackendRoles, + String tenantId ) { this.name = Objects.requireNonNull(name, "model group name must not be null"); this.description = description; this.backendRoles = backendRoles; this.modelAccessMode = modelAccessMode; this.isAddAllBackendRoles = isAddAllBackendRoles; + this.tenantId = tenantId; } public MLRegisterModelGroupInput(StreamInput in) throws IOException { + Version streamInputVersion = in.getVersion(); this.name = in.readString(); this.description = in.readOptionalString(); this.backendRoles = in.readOptionalStringList(); @@ -62,10 +69,14 @@ public MLRegisterModelGroupInput(StreamInput in) throws IOException { modelAccessMode = in.readEnum(AccessMode.class); } this.isAddAllBackendRoles = in.readOptionalBoolean(); + if (streamInputVersion.onOrAfter(VERSION_2_19_0)) { + this.tenantId = in.readOptionalString(); + } } @Override public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeString(name); out.writeOptionalString(description); if (backendRoles != null) { @@ -81,6 +92,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(isAddAllBackendRoles); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override @@ -99,6 +113,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isAddAllBackendRoles != null) { builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @@ -109,6 +126,7 @@ public static MLRegisterModelGroupInput parse(XContentParser parser) throws IOEx List backendRoles = null; AccessMode modelAccessMode = null; Boolean isAddAllBackendRoles = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -134,12 +152,14 @@ public static MLRegisterModelGroupInput parse(XContentParser parser) throws IOEx case ADD_ALL_BACKEND_ROLES: isAddAllBackendRoles = parser.booleanValue(); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); + break; default: parser.skipChildren(); break; } } - return new MLRegisterModelGroupInput(name, description, backendRoles, modelAccessMode, isAddAllBackendRoles); + return new MLRegisterModelGroupInput(name, description, backendRoles, modelAccessMode, isAddAllBackendRoles, tenantId); } - } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java index 3dd92082c8..2851c9188d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java @@ -6,12 +6,15 @@ package org.opensearch.ml.common.transport.model_group; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Locale; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -39,6 +42,7 @@ public class MLUpdateModelGroupInput implements ToXContentObject, Writeable { private List backendRoles; private AccessMode modelAccessMode; private Boolean isAddAllBackendRoles; + private String tenantId; @Builder(toBuilder = true) public MLUpdateModelGroupInput( @@ -47,7 +51,8 @@ public MLUpdateModelGroupInput( String description, List backendRoles, AccessMode modelAccessMode, - Boolean isAddAllBackendRoles + Boolean isAddAllBackendRoles, + String tenantId ) { this.modelGroupID = modelGroupID; this.name = name; @@ -55,9 +60,11 @@ public MLUpdateModelGroupInput( this.backendRoles = backendRoles; this.modelAccessMode = modelAccessMode; this.isAddAllBackendRoles = isAddAllBackendRoles; + this.tenantId = tenantId; } public MLUpdateModelGroupInput(StreamInput in) throws IOException { + Version streamInputVersion = in.getVersion(); this.modelGroupID = in.readString(); this.name = in.readOptionalString(); this.description = in.readOptionalString(); @@ -66,6 +73,8 @@ public MLUpdateModelGroupInput(StreamInput in) throws IOException { modelAccessMode = in.readEnum(AccessMode.class); } this.isAddAllBackendRoles = in.readOptionalBoolean(); + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; + } @Override @@ -87,12 +96,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isAddAllBackendRoles != null) { builder.field(ADD_ALL_BACKEND_ROLES_FIELD, isAddAllBackendRoles); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @Override public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeString(modelGroupID); out.writeOptionalString(name); out.writeOptionalString(description); @@ -109,6 +122,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(isAddAllBackendRoles); + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } public static MLUpdateModelGroupInput parse(XContentParser parser) throws IOException { @@ -118,6 +134,7 @@ public static MLUpdateModelGroupInput parse(XContentParser parser) throws IOExce List backendRoles = null; AccessMode modelAccessMode = null; Boolean isAddAllBackendRoles = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -146,11 +163,14 @@ public static MLUpdateModelGroupInput parse(XContentParser parser) throws IOExce case ADD_ALL_BACKEND_ROLES_FIELD: isAddAllBackendRoles = parser.booleanValue(); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); + break; default: parser.skipChildren(); break; } } - return new MLUpdateModelGroupInput(modelGroupID, name, description, backendRoles, modelAccessMode, isAddAllBackendRoles); + return new MLUpdateModelGroupInput(modelGroupID, name, description, backendRoles, modelAccessMode, isAddAllBackendRoles, tenantId); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 2db003f005..2490246023 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -6,6 +6,8 @@ package org.opensearch.ml.common.transport.register; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; import static org.opensearch.ml.common.connector.Connector.createConnector; import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; @@ -102,6 +104,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private Guardrails guardrails; private Map modelInterface; + private String tenantId; @Builder(toBuilder = true) public MLRegisterModelInput( @@ -127,7 +130,8 @@ public MLRegisterModelInput( Boolean doesVersionCreateModelGroup, Boolean isHidden, Guardrails guardrails, - Map modelInterface + Map modelInterface, + String tenantId ) { this.functionName = Objects.requireNonNullElse(functionName, FunctionName.TEXT_EMBEDDING); if (modelName == null) { @@ -169,6 +173,7 @@ public MLRegisterModelInput( this.isHidden = isHidden; this.guardrails = guardrails; this.modelInterface = modelInterface; + this.tenantId = tenantId; } public MLRegisterModelInput(StreamInput in) throws IOException { @@ -228,6 +233,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.modelInterface = in.readMap(StreamInput::readString, StreamInput::readString); } } + this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null; } @Override @@ -309,6 +315,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { + out.writeOptionalString(tenantId); + } } @Override @@ -377,6 +386,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelInterface != null) { builder.field(MLModel.INTERFACE_FIELD, modelInterface); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } builder.endObject(); return builder; } @@ -403,6 +415,7 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName Boolean isHidden = null; Guardrails guardrails = null; Map modelInterface = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -479,6 +492,9 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case MLModel.INTERFACE_FIELD: modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); + break; default: parser.skipChildren(); break; @@ -507,7 +523,8 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName doesVersionCreateModelGroup, isHidden, guardrails, - modelInterface + modelInterface, + tenantId ); } @@ -534,6 +551,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo Boolean isHidden = null; Guardrails guardrails = null; Map modelInterface = null; + String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -617,6 +635,9 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case MLModel.INTERFACE_FIELD: modelInterface = filteredParameterMap(parser.map(), allowedInterfaceFieldKeys); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); + break; default: parser.skipChildren(); break; @@ -645,7 +666,8 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo doesVersionCreateModelGroup, isHidden, guardrails, - modelInterface + modelInterface, + tenantId ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java index c1abe07297..cdad6f21ca 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java @@ -145,4 +145,95 @@ public void writeTo_Empty() throws IOException { Assert.assertNull(modelGroup.getAccess()); Assert.assertNull(modelGroup.getOwner()); } + + @Test + public void toXContent_WithTenantId() throws IOException { + MLModelGroup modelGroup = MLModelGroup + .builder() + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .tenantId("test_tenant") + .build(); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + Assert + .assertEquals( + "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\",\"tenant_id\":\"test_tenant\"}", + content + ); + } + + @Test + public void parse_WithTenantId() throws IOException { + String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\",\"tenant_id\":\"test_tenant\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLModelGroup modelGroup = MLModelGroup.parse(parser); + Assert.assertEquals("test", modelGroup.getName()); + Assert.assertEquals("this is test group", modelGroup.getDescription()); + Assert.assertEquals("PUBLIC", modelGroup.getAccess()); + Assert.assertEquals("test_tenant", modelGroup.getTenantId()); + Assert.assertEquals(2, modelGroup.getBackendRoles().size()); + Assert.assertEquals("role1", modelGroup.getBackendRoles().get(0)); + Assert.assertEquals("role2", modelGroup.getBackendRoles().get(1)); + } + + @Test + public void parse_WithoutTenantId() throws IOException { + String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLModelGroup modelGroup = MLModelGroup.parse(parser); + Assert.assertEquals("test", modelGroup.getName()); + Assert.assertEquals("this is test group", modelGroup.getDescription()); + Assert.assertEquals("PUBLIC", modelGroup.getAccess()); + Assert.assertNull(modelGroup.getTenantId()); + Assert.assertEquals(2, modelGroup.getBackendRoles().size()); + Assert.assertEquals("role1", modelGroup.getBackendRoles().get(0)); + Assert.assertEquals("role2", modelGroup.getBackendRoles().get(1)); + } + + @Test + public void toBuilder_WithTenantId() { + MLModelGroup originalModelGroup = MLModelGroup + .builder() + .name("test") + .description("this is test group") + .latestVersion(1) + .tenantId("test_tenant") + .build(); + + MLModelGroup newModelGroup = originalModelGroup.toBuilder().build(); + Assert.assertEquals("test_tenant", newModelGroup.getTenantId()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java index 77a3dd6a25..c12069888c 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java @@ -136,4 +136,46 @@ public void readInputStream(MLModel mlModel) throws IOException { assertEquals(mlModel.getChunkNumber(), parsedMLModel.getChunkNumber()); assertEquals(mlModel.getTotalChunks(), parsedMLModel.getTotalChunks()); } + + @Test + public void toXContent_WithTenantId() throws IOException { + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.KMEANS) + .name("model_name") + .version("1.0.0") + .content("test_content") + .isHidden(true) + .tenantId("test_tenant") + .build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlModel.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + assertEquals( + "{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"model_version\":\"1.0.0\",\"model_content\":\"test_content\",\"is_hidden\":true,\"tenant_id\":\"test_tenant\"}", + mlModelContent + ); + } + + @Test + public void parse_WithTenantId() throws IOException { + String modelJson = """ + { + "name": "model_name", + "algorithm": "KMEANS", + "model_version": "1.0.0", + "model_content": "test_content", + "is_hidden": true, + "tenant_id": "test_tenant" + } + """; + TestHelper.testParseFromString(config, modelJson, function); + } + + @Test + public void toBuilder_WithTenantId() { + MLModel mlModelWithTenantId = mlModel.toBuilder().tenantId("test_tenant").build(); + assertEquals("test_tenant", mlModelWithTenantId.getTenantId()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 16bbc76bfa..7f83444c66 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -325,4 +325,86 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod return connector; } + @Test + public void writeToAndReadFrom_WithTenantId() throws IOException { + HttpConnector originalConnector = HttpConnector + .builder() + .name("test_connector_name") + .description("this is a test connector") + .protocol("http") + .tenantId("test_tenant") + .build(); + + BytesStreamOutput output = new BytesStreamOutput(); + originalConnector.writeTo(output); + + HttpConnector deserializedConnector = new HttpConnector(output.bytes().streamInput()); + Assert.assertEquals("test_tenant", deserializedConnector.getTenantId()); + Assert.assertEquals(originalConnector, deserializedConnector); + } + + @Test + public void writeToAndReadFrom_WithoutTenantId() throws IOException { + HttpConnector originalConnector = HttpConnector + .builder() + .name("test_connector_name") + .description("this is a test connector") + .protocol("http") + .build(); + + BytesStreamOutput output = new BytesStreamOutput(); + originalConnector.writeTo(output); + + HttpConnector deserializedConnector = new HttpConnector(output.bytes().streamInput()); + Assert.assertNull(deserializedConnector.getTenantId()); + Assert.assertEquals(originalConnector, deserializedConnector); + } + + @Test + public void toXContent_WithTenantId() throws IOException { + HttpConnector connector = HttpConnector + .builder() + .name("test_connector_name") + .description("this is a test connector") + .protocol("http") + .tenantId("test_tenant") + .build(); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + connector.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertTrue(content.contains("\"tenant_id\":\"test_tenant\"")); + } + + @Test + public void constructor_WithTenantId() { + HttpConnector connector = HttpConnector + .builder() + .name("test_connector_name") + .description("this is a test connector") + .protocol("http") + .tenantId("test_tenant") + .build(); + + Assert.assertEquals("test_tenant", connector.getTenantId()); + } + + @Test + public void parse_WithTenantId() throws IOException { + String jsonStr = "{\"name\":\"test_connector_name\",\"protocol\":\"http\",\"tenant_id\":\"test_tenant\"}"; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + + HttpConnector connector = new HttpConnector("http", parser); + Assert.assertEquals("test_tenant", connector.getTenantId()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java index 49b013cdf2..9fab57d545 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -77,7 +77,7 @@ public void parse_success() throws IOException { jsonStr ); parser.nextToken(); - MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); + MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId, null); assertEquals(updateConnectorRequest.getConnectorId(), connectorId); assertTrue(updateConnectorRequest.getUpdateContent().isUpdateConnector()); assertEquals("new version", updateConnectorRequest.getUpdateContent().getVersion()); @@ -133,4 +133,55 @@ public void writeTo(StreamOutput out) throws IOException { }; MLUpdateConnectorRequest.fromActionRequest(actionRequest); } + + @Test + public void parse_withTenantId_success() throws IOException { + String tenantId = "test-tenant"; + String jsonStr = "{\"version\":\"new version\",\"description\":\"new description\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId, tenantId); + assertEquals(updateConnectorRequest.getConnectorId(), connectorId); + assertEquals(tenantId, updateConnectorRequest.getUpdateContent().getTenantId()); + assertEquals("new version", updateConnectorRequest.getUpdateContent().getVersion()); + assertEquals("new description", updateConnectorRequest.getUpdateContent().getDescription()); + } + + @Test + public void parse_withoutTenantId_success() throws IOException { + String jsonStr = "{\"version\":\"new version\",\"description\":\"new description\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId, null); + assertEquals(updateConnectorRequest.getConnectorId(), connectorId); + assertNull(updateConnectorRequest.getUpdateContent().getTenantId()); + assertEquals("new version", updateConnectorRequest.getUpdateContent().getVersion()); + assertEquals("new description", updateConnectorRequest.getUpdateContent().getDescription()); + } + + @Test + public void writeTo_withTenantId_Success() throws IOException { + updateContent.setTenantId("tenant-1"); + MLUpdateConnectorRequest request = MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build(); + + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLUpdateConnectorRequest parsedRequest = new MLUpdateConnectorRequest(bytesStreamOutput.bytes().streamInput()); + + assertEquals("tenant-1", parsedRequest.getUpdateContent().getTenantId()); + assertEquals(connectorId, parsedRequest.getConnectorId()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java index 109eafde95..be39c35edd 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java @@ -15,9 +15,11 @@ import org.junit.Before; import org.junit.Test; +import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; public class MLModelDeleteRequestTest { @@ -94,4 +96,108 @@ public void fromActionRequestWithModelDeleteRequest_Success() { assertSame(mlModelDeleteRequest, mlModelDeleteRequestFromActionRequest); assertEquals(mlModelDeleteRequest.getModelId(), mlModelDeleteRequestFromActionRequest.getModelId()); } + + @Test + public void writeTo_withTenantId_Success() throws IOException { + String tenantId = "tenant-1"; + MLModelDeleteRequest request = MLModelDeleteRequest.builder().modelId(modelId).tenantId(tenantId).build(); + + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + MLModelDeleteRequest parsedRequest = new MLModelDeleteRequest(out.bytes().streamInput()); + + assertEquals(modelId, parsedRequest.getModelId()); + assertEquals(tenantId, parsedRequest.getTenantId()); + } + + @Test + public void writeTo_withoutTenantId_Success() throws IOException { + MLModelDeleteRequest request = MLModelDeleteRequest.builder().modelId(modelId).build(); + + BytesStreamOutput out = new BytesStreamOutput(); + request.writeTo(out); + MLModelDeleteRequest parsedRequest = new MLModelDeleteRequest(out.bytes().streamInput()); + + assertEquals(modelId, parsedRequest.getModelId()); + assertNull(parsedRequest.getTenantId()); + } + + @Test + public void serialization_withOlderVersion_Success() throws IOException { + MLModelDeleteRequest request = MLModelDeleteRequest.builder().modelId(modelId).tenantId("tenant-1").build(); + + // Serialize with an older version + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_2_18_0); // Older version without tenantId support + request.writeTo(out); + + // Deserialize with the same older version + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_2_18_0); // Ensure the version matches + MLModelDeleteRequest parsedRequest = new MLModelDeleteRequest(in); + + // Validate + assertEquals(modelId, parsedRequest.getModelId()); + assertNull(parsedRequest.getTenantId()); // tenantId should not be read + } + + @Test + public void serialization_withNewVersion_Success() throws IOException { + String tenantId = "tenant-1"; + MLModelDeleteRequest request = MLModelDeleteRequest.builder().modelId(modelId).tenantId(tenantId).build(); + + // Serialize with a newer version + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_2_19_0); + request.writeTo(out); + + // Deserialize with the same newer version + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_2_19_0); + MLModelDeleteRequest parsedRequest = new MLModelDeleteRequest(in); + + // Validate + assertEquals(modelId, parsedRequest.getModelId()); + assertEquals(tenantId, parsedRequest.getTenantId()); // tenantId should be preserved + } + + @Test + public void fromActionRequest_withTenantId_Success() { + MLModelDeleteRequest originalRequest = MLModelDeleteRequest.builder().modelId(modelId).tenantId("tenant-1").build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + originalRequest.writeTo(out); + } + }; + + MLModelDeleteRequest parsedRequest = MLModelDeleteRequest.fromActionRequest(actionRequest); + assertNotSame(originalRequest, parsedRequest); + assertEquals(originalRequest.getModelId(), parsedRequest.getModelId()); + assertEquals(originalRequest.getTenantId(), parsedRequest.getTenantId()); + } + + @Test + public void writeTo_withOlderVersion_withoutTenantId_Success() throws IOException { + MLModelDeleteRequest request = MLModelDeleteRequest.builder().modelId(modelId).tenantId("xyz").build(); + + // Serialize with an older version + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.V_2_19_0); + request.writeTo(out); + + // Deserialize + StreamInput in = out.bytes().streamInput(); + in.setVersion(Version.V_2_18_0); + MLModelDeleteRequest parsedRequest = new MLModelDeleteRequest(in); + + assertEquals(modelId, parsedRequest.getModelId()); + assertNull(parsedRequest.getTenantId()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java index 4a16bf9347..1174c4cca3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java @@ -6,9 +6,12 @@ package org.opensearch.ml.common.transport.model; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.CommonValue.VERSION_2_18_0; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.io.UncheckedIOException; @@ -18,6 +21,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; public class MLModelGetRequestTest { @@ -94,4 +98,79 @@ public void fromActionRequestWithMLModelGetRequest_Success() { assertSame(mlModelGetRequest, mlModelGetRequestFromActionRequest); assertEquals(mlModelGetRequest.getModelId(), mlModelGetRequestFromActionRequest.getModelId()); } + + @Test + public void writeTo_withTenantId_Success() throws IOException { + String tenantId = "tenant-1"; + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).tenantId(tenantId).build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Newer version supporting tenantId + mlModelGetRequest.writeTo(out); + + MLModelGetRequest parsedRequest = new MLModelGetRequest(out.bytes().streamInput()); + assertEquals(modelId, parsedRequest.getModelId()); + assertEquals(tenantId, parsedRequest.getTenantId()); + } + + @Test + public void writeTo_withoutTenantId_Success() throws IOException { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); // No tenantId set + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Newer version supporting tenantId + mlModelGetRequest.writeTo(out); + + MLModelGetRequest parsedRequest = new MLModelGetRequest(out.bytes().streamInput()); + assertEquals(modelId, parsedRequest.getModelId()); + assertNull(parsedRequest.getTenantId()); // TenantId should be null + } + + @Test + public void writeTo_withOlderVersion_Success() throws IOException { + String tenantId = "tenant-1"; + MLModelGetRequest mlModelGetRequest = MLModelGetRequest + .builder() + .modelId(modelId) + .tenantId(tenantId) // Tenant ID is set, but won't be written for older versions + .build(); + + // Serialize with an older version + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_18_0); // Older version without tenantId support + mlModelGetRequest.writeTo(out); + + // Deserialize with the same older version + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_18_0); // Set version explicitly + MLModelGetRequest parsedRequest = new MLModelGetRequest(in); + + // Validate deserialization + assertEquals(modelId, parsedRequest.getModelId()); + assertNull(parsedRequest.getTenantId()); // TenantId should not be present for older versions + assertFalse(parsedRequest.isReturnContent()); // Default value for boolean fields + assertFalse(parsedRequest.isUserInitiatedGetRequest()); // Default value for boolean fields + } + + @Test + public void fromActionRequest_withTenantId_Success() { + String tenantId = "tenant-1"; + MLModelGetRequest originalRequest = MLModelGetRequest.builder().modelId(modelId).tenantId(tenantId).build(); + + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + originalRequest.writeTo(out); + } + }; + + MLModelGetRequest parsedRequest = MLModelGetRequest.fromActionRequest(actionRequest); + assertEquals(originalRequest.getModelId(), parsedRequest.getModelId()); + assertEquals(originalRequest.getTenantId(), parsedRequest.getTenantId()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index 46e89d0aa6..5591e8d273 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -6,8 +6,11 @@ package org.opensearch.ml.common.transport.model; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.time.Instant; @@ -202,13 +205,68 @@ public void parseWithIllegalFieldWithoutModel() throws Exception { + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1,\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { - assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + String jsonStr = serializationWithToXContent(parsedInput); + assertTrue(jsonStr.contains("\"model_id\":\"test-model_id\"")); // Validate expected content + assertFalse(jsonStr.contains("\"illegal_field\"")); // Ensure illegal fields are skipped } catch (IOException e) { throw new RuntimeException(e); } }); } + @Test + public void serializationWithTenantId_Success() throws IOException { + MLUpdateModelInput input = updateModelInput.toBuilder().tenantId("tenant-1").build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Version with tenantId support + input.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_19_0); + MLUpdateModelInput parsedInput = new MLUpdateModelInput(in); + + assertEquals(input.getTenantId(), parsedInput.getTenantId()); + } + + @Test + public void serializationWithoutTenantId_Success() throws IOException { + MLUpdateModelInput input = updateModelInput.toBuilder().tenantId(null).build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Version with tenantId support + input.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_19_0); + MLUpdateModelInput parsedInput = new MLUpdateModelInput(in); + + assertNull(parsedInput.getTenantId()); + } + + @Test + public void parseWithTenantId_Success() throws Exception { + String jsonWithTenantId = + "{\"model_id\":\"test-model_id\",\"tenant_id\":\"tenant-1\",\"name\":\"name\",\"description\":\"description\"}"; + testParseFromJsonString(jsonWithTenantId, parsedInput -> { + assertEquals("tenant-1", parsedInput.getTenantId()); + assertEquals("test-model_id", parsedInput.getModelId()); + }); + } + + @Test + public void toXContentWithTenantId_Success() throws IOException { + MLUpdateModelInput input = updateModelInput.toBuilder().tenantId("tenant-1").build(); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonOutput = builder.toString(); + + // Validate that tenantId is present in the serialized JSON + assertTrue(jsonOutput.contains("\"tenant_id\":\"tenant-1\"")); + } + private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { XContentParser parser = XContentType.JSON .xContent() diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java index a33e169668..a63a8f3b6c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java @@ -4,6 +4,8 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.CommonValue.VERSION_2_18_0; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.io.UncheckedIOException; @@ -13,12 +15,13 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; public class MLModelGroupDeleteRequestTest { private String modelGroupId; - + private String tenantId; private MLModelGroupDeleteRequest request; @Before @@ -88,4 +91,61 @@ public void writeTo(StreamOutput out) throws IOException { MLModelGroupDeleteRequest.fromActionRequest(actionRequest); } + @Test + public void writeToAndReadFrom_withTenantId_Success() throws IOException { + tenantId = "tenant-1"; + request = MLModelGroupDeleteRequest.builder().modelGroupId(modelGroupId).tenantId(tenantId).build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Newer version supporting tenantId + request.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_19_0); // Ensure version alignment + MLModelGroupDeleteRequest parsedRequest = new MLModelGroupDeleteRequest(in); + + assertEquals(modelGroupId, parsedRequest.getModelGroupId()); + assertEquals(tenantId, parsedRequest.getTenantId()); + } + + @Test + public void fromActionRequest_withTenantId_Success() { + tenantId = "tenant-1"; + request = MLModelGroupDeleteRequest.builder().modelGroupId(modelGroupId).tenantId(tenantId).build(); + + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + + MLModelGroupDeleteRequest result = MLModelGroupDeleteRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(result.getModelGroupId(), request.getModelGroupId()); + assertEquals(result.getTenantId(), request.getTenantId()); + } + + @Test + public void writeToAndReadFrom_withOlderVersion_TenantIdIgnored() throws IOException { + tenantId = "tenant-1"; + request = MLModelGroupDeleteRequest.builder().modelGroupId(modelGroupId).tenantId(tenantId).build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Serialize with newer version + request.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_18_0); // Older version without tenantId support + MLModelGroupDeleteRequest parsedRequest = new MLModelGroupDeleteRequest(in); + + assertEquals(modelGroupId, parsedRequest.getModelGroupId()); + assertNull(parsedRequest.getTenantId()); // tenantId should not be deserialized + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java index 7a463f28bc..9529b74bee 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java @@ -9,6 +9,8 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.opensearch.ml.common.CommonValue.VERSION_2_18_0; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.io.UncheckedIOException; @@ -18,6 +20,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; public class MLModelGroupGetRequestTest { @@ -94,4 +97,45 @@ public void fromActionRequestWithMLModelGroupGetRequest_Success() { assertSame(mlModelGroupGetRequest, mlModelGroupGetRequestFromActionRequest); assertEquals(mlModelGroupGetRequest.getModelGroupId(), mlModelGroupGetRequestFromActionRequest.getModelGroupId()); } + + @Test + public void writeToAndReadFrom_withOlderVersion_TenantIdIgnored() throws IOException { + String tenantId = "tenant-1"; + MLModelGroupGetRequest request = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).tenantId(tenantId).build(); + + // Serialize with newer version + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Newer version with tenantId support + request.writeTo(out); + + // Deserialize with older version + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_18_0); // Older version without tenantId support + MLModelGroupGetRequest parsedRequest = new MLModelGroupGetRequest(in); + + // Validate + assertEquals(modelGroupId, parsedRequest.getModelGroupId()); + assertNull(parsedRequest.getTenantId()); // tenantId should not be deserialized + } + + @Test + public void writeToAndReadFrom_withNewerVersion_TenantIdIncluded() throws IOException { + String tenantId = "tenant-1"; + MLModelGroupGetRequest request = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).tenantId(tenantId).build(); + + // Serialize with newer version + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Newer version with tenantId support + request.writeTo(out); + + // Deserialize with newer version + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_19_0); // Newer version with tenantId support + MLModelGroupGetRequest parsedRequest = new MLModelGroupGetRequest(in); + + // Validate + assertEquals(modelGroupId, parsedRequest.getModelGroupId()); + assertEquals(tenantId, parsedRequest.getTenantId()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java index 68cd72836e..ef69e82bf7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java @@ -1,6 +1,9 @@ package org.opensearch.ml.common.transport.model_group; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.opensearch.ml.common.CommonValue.VERSION_2_18_0; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.util.Arrays; @@ -8,7 +11,11 @@ import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; public class MLRegisterModelGroupInputTest { @@ -36,4 +43,100 @@ public void readInputStream() throws IOException { MLRegisterModelGroupInput parsedInput = new MLRegisterModelGroupInput(streamInput); assertEquals(mlRegisterModelGroupInput.getName(), parsedInput.getName()); } + + @Test + public void writeToAndReadFrom_withTenantId_Success() throws IOException { + MLRegisterModelGroupInput input = MLRegisterModelGroupInput + .builder() + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .tenantId("tenant-1") + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Serialize with newer version + input.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_19_0); // Deserialize with the same version + MLRegisterModelGroupInput parsedInput = new MLRegisterModelGroupInput(in); + + assertEquals("name", parsedInput.getName()); + assertEquals("tenant-1", parsedInput.getTenantId()); + } + + @Test + public void writeToAndReadFrom_withOlderVersion_TenantIdIgnored() throws IOException { + MLRegisterModelGroupInput input = MLRegisterModelGroupInput + .builder() + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .tenantId("tenant-1") + .build(); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(VERSION_2_19_0); // Serialize with newer version + input.writeTo(out); + + StreamInput in = out.bytes().streamInput(); + in.setVersion(VERSION_2_18_0); // Deserialize with older version + MLRegisterModelGroupInput parsedInput = new MLRegisterModelGroupInput(in); + + assertEquals("name", parsedInput.getName()); + assertNull(parsedInput.getTenantId()); // tenantId should not be deserialized in older versions + } + + @Test + public void parse_withTenantId_Success() throws IOException { + String jsonWithTenantId = """ + { + "name": "name", + "description": "description", + "backend_roles": ["IT"], + "access_mode": "restricted", + "add_all_backend_roles": true, + "tenant_id": "tenant-1" + } + """; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, jsonWithTenantId); + + parser.nextToken(); // Start parsing + MLRegisterModelGroupInput parsedInput = MLRegisterModelGroupInput.parse(parser); + + assertEquals("name", parsedInput.getName()); + assertEquals("tenant-1", parsedInput.getTenantId()); + } + + @Test + public void parse_withoutTenantId_Success() throws IOException { + String jsonWithoutTenantId = """ + { + "name": "name", + "description": "description", + "backend_roles": ["IT"], + "access_mode": "restricted", + "add_all_backend_roles": true + } + """; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, jsonWithoutTenantId); + + parser.nextToken(); // Start parsing + MLRegisterModelGroupInput parsedInput = MLRegisterModelGroupInput.parse(parser); + + assertEquals("name", parsedInput.getName()); + assertNull(parsedInput.getTenantId()); // tenantId is not provided in the JSON + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java index 96d6b36a45..17680a621a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java @@ -1,14 +1,21 @@ package org.opensearch.ml.common.transport.model_group; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.opensearch.ml.common.CommonValue.VERSION_2_18_0; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; -import java.util.Arrays; +import java.util.List; import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; public class MLUpdateModelGroupInputTest { @@ -23,7 +30,7 @@ public void setUp() throws Exception { .modelGroupID("modelGroupId") .name("name") .description("description") - .backendRoles(Arrays.asList("IT")) + .backendRoles(List.of("IT")) .modelAccessMode(AccessMode.RESTRICTED) .isAddAllBackendRoles(true) .build(); @@ -37,4 +44,86 @@ public void readInputStream() throws IOException { MLUpdateModelGroupInput parsedInput = new MLUpdateModelGroupInput(streamInput); assertEquals(mlUpdateModelGroupInput.getName(), parsedInput.getName()); } + + @Test + public void readInputStream_withTenantId_Success() throws IOException { + // Ensure tenantId is included in the test setup + MLUpdateModelGroupInput inputWithTenantId = mlUpdateModelGroupInput.toBuilder().tenantId("tenant-1").build(); + + // Serialize with a newer version that supports tenantId + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + bytesStreamOutput.setVersion(VERSION_2_19_0); + inputWithTenantId.writeTo(bytesStreamOutput); + + // Deserialize and verify + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + streamInput.setVersion(VERSION_2_19_0); + MLUpdateModelGroupInput parsedInput = new MLUpdateModelGroupInput(streamInput); + + assertEquals("modelGroupId", parsedInput.getModelGroupID()); + assertEquals("tenant-1", parsedInput.getTenantId()); + } + + @Test + public void writeToAndReadFrom_withOlderVersion_TenantIdIgnored() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + bytesStreamOutput.setVersion(VERSION_2_19_0); // Serialize with newer version + mlUpdateModelGroupInput.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + streamInput.setVersion(VERSION_2_18_0); // Deserialize with older version + MLUpdateModelGroupInput parsedInput = new MLUpdateModelGroupInput(streamInput); + + assertEquals("modelGroupId", parsedInput.getModelGroupID()); + assertNull(parsedInput.getTenantId()); // tenantId should not be deserialized in older versions + } + + @Test + public void parse_withTenantId_Success() throws IOException { + String jsonWithTenantId = """ + { + "model_group_id": "modelGroupId", + "name": "name", + "description": "description", + "backend_roles": ["IT"], + "access_mode": "restricted", + "add_all_backend_roles": true, + "tenant_id": "tenant-1" + } + """; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, jsonWithTenantId); + + parser.nextToken(); // Start parsing + MLUpdateModelGroupInput parsedInput = MLUpdateModelGroupInput.parse(parser); + + assertEquals("modelGroupId", parsedInput.getModelGroupID()); + assertEquals("tenant-1", parsedInput.getTenantId()); + } + + @Test + public void parse_withoutTenantId_Success() throws IOException { + String jsonWithoutTenantId = """ + { + "model_group_id": "modelGroupId", + "name": "name", + "description": "description", + "backend_roles": ["IT"], + "access_mode": "restricted", + "add_all_backend_roles": true + } + """; + + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, jsonWithoutTenantId); + + parser.nextToken(); // Start parsing + MLUpdateModelGroupInput parsedInput = MLUpdateModelGroupInput.parse(parser); + + assertEquals("modelGroupId", parsedInput.getModelGroupID()); + assertNull(parsedInput.getTenantId()); // tenantId is not provided in the JSON + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index 8caed811b5..119e22e1f7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -4,6 +4,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; import java.io.IOException; import java.util.Collections; @@ -363,6 +365,52 @@ public void readInputStream_MCorr() throws IOException { }); } + @Test + public void readInputStream_withTenantId_Success() throws IOException { + // Add tenantId to input + input = input.toBuilder().tenantId("tenant-1").build(); + + // Serialize with newer version + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + bytesStreamOutput.setVersion(VERSION_2_19_0); + input.writeTo(bytesStreamOutput); + + // Deserialize and verify tenantId + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + streamInput.setVersion(VERSION_2_19_0); + MLRegisterModelInput parsedInput = new MLRegisterModelInput(streamInput); + + assertEquals("tenant-1", parsedInput.getTenantId()); + } + + @Test + public void toXContent_withTenantId_Success() throws IOException { + // Add tenantId to input + input = input.toBuilder().tenantId("tenant-1").build(); + + // Convert to XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + + // Verify tenantId is serialized correctly + assertTrue(jsonStr.contains("\"tenant_id\":\"tenant-1\"")); + } + + @Test + public void toXContent_withoutTenantId_Success() throws IOException { + // Ensure input does not have tenantId + input = input.toBuilder().tenantId(null).build(); + + // Convert to XContent + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + + // Verify tenantId is not present + assertFalse(jsonStr.contains("\"tenant_id\"")); + } + private void readInputStream(MLRegisterModelInput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index 9efe9372b8..ffce3e0d1f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -364,7 +364,7 @@ public MLTask getTask(String taskId) { } public MLModel getModel(String modelId) { - MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, false); + MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, false, null); ActionFuture future = client.execute(MLModelGetAction.INSTANCE, getRequest); MLModelGetResponse response = future.actionGet(5000); return response.getMlModel(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index c44846fa98..9dd4cb25b3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -107,7 +107,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) log.info("Skip creating the Index:{} that is already created by another parallel request", indexName); internalListener.onResponse(true); } else { - log.error("Failed to create index " + indexName, e); + log.error("Failed to create index {}", indexName, e); internalListener.onFailure(e); } }); @@ -148,14 +148,14 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) .onFailure(new MLException("Failed to update index setting for: " + indexName)); } }, exception -> { - log.error("Failed to update index setting for: " + indexName, exception); + log.error("Failed to update index setting for: {}", indexName, exception); internalListener.onFailure(exception); })); } else { internalListener.onFailure(new MLException("Failed to update index: " + indexName)); } }, exception -> { - log.error("Failed to update index " + indexName, exception); + log.error("Failed to update index {}", indexName, exception); internalListener.onFailure(exception); }) ); @@ -175,7 +175,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) } } } catch (Exception e) { - log.error("Failed to init index " + indexName, e); + log.error("Failed to init index {}", indexName, e); listener.onFailure(e); } } @@ -197,7 +197,7 @@ public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListen Integer oldVersion = CommonValue.NO_SCHEMA_VERSION; Map indexMapping = indexMetaData.mapping().getSourceAsMap(); Object meta = indexMapping.get(META); - if (meta != null && meta instanceof Map) { + if (meta instanceof Map) { @SuppressWarnings("unchecked") Map metaMapping = (Map) meta; Object schemaVersion = metaMapping.get(SCHEMA_VERSION_FIELD); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index e9552cc650..ae5521da9f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -8,13 +8,13 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.delete.DeleteRequest; @@ -41,6 +41,7 @@ import org.opensearch.remote.metadata.client.DeleteDataObjectResponse; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.client.SearchDataObjectResponse; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -120,48 +121,57 @@ private void handleConnectorAccessValidationFailure(String connectorId, Exceptio private void checkForModelsUsingConnector(String connectorId, String tenantId, ActionListener actionListener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener restoringListener = ActionListener.runBefore(actionListener, context::restore); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); - if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) { - sourceBuilder.query(QueryBuilders.matchQuery(TENANT_ID_FIELD, tenantId)); - } + SearchDataObjectRequest searchRequest = buildModelSearchRequest(connectorId, tenantId); - SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest - .builder() - .indices(ML_MODEL_INDEX) - .tenantId(tenantId) - .searchSourceBuilder(sourceBuilder) - .build(); sdkClient - .searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) - .whenComplete((sr, st) -> { - if (sr != null) { - try { - SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser()); - SearchHit[] searchHits = searchResponse.getHits().getHits(); - if (searchHits.length == 0) { - deleteConnector(connectorId, tenantId, restoringListener); - } else { - handleModelsUsingConnector(searchHits, connectorId, restoringListener); - } - } catch (Exception e) { - log.error("Failed to parse search response", e); - restoringListener - .onFailure( - new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR) - ); - } - } else { - Exception cause = SdkClientUtils.unwrapAndConvertToException(st); - handleSearchFailure(connectorId, tenantId, cause, restoringListener); - } - }); + .searchDataObjectAsync(searchRequest) + .whenComplete( + (searchResponse, throwable) -> handleSearchResponse(connectorId, tenantId, restoringListener, searchResponse, throwable) + ); } catch (Exception e) { log.error("Failed to check for models using connector: {}", connectorId, e); actionListener.onFailure(e); } } + private SearchDataObjectRequest buildModelSearchRequest(String connectorId, String tenantId) { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); + if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) { + sourceBuilder.query(QueryBuilders.matchQuery(TENANT_ID_FIELD, tenantId)); + } + + return SearchDataObjectRequest.builder().indices(ML_MODEL_INDEX).tenantId(tenantId).searchSourceBuilder(sourceBuilder).build(); + } + + private void handleSearchResponse( + String connectorId, + String tenantId, + ActionListener restoringListener, + SearchDataObjectResponse searchResponse, + Throwable throwable + ) { + if (searchResponse == null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + handleSearchFailure(connectorId, tenantId, cause, restoringListener); + return; + } + + try { + SearchResponse response = SearchResponse.fromXContent(searchResponse.parser()); + SearchHit[] searchHits = response.getHits().getHits(); + + if (searchHits.length == 0) { + deleteConnector(connectorId, tenantId, restoringListener); + } else { + handleModelsUsingConnector(searchHits, connectorId, restoringListener); + } + } catch (Exception e) { + log.error("Failed to parse search response", e); + restoringListener.onFailure(new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR)); + } + } + private void handleModelsUsingConnector(SearchHit[] searchHits, String connectorId, ActionListener actionListener) { log.error("{} models are still using this connector, please delete or update the models first!", searchHits.length); List modelIds = new ArrayList<>(); @@ -180,7 +190,7 @@ private void handleModelsUsingConnector(SearchHit[] searchHits, String connector } private void handleSearchFailure(String connectorId, String tenantId, Exception cause, ActionListener actionListener) { - if (cause instanceof IndexNotFoundException) { + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { deleteConnector(connectorId, tenantId, actionListener); return; } @@ -193,8 +203,7 @@ private void deleteConnector(String connectorId, String tenantId, ActionListener try { sdkClient .deleteDataObjectAsync( - DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).tenantId(tenantId).build(), - client.threadPool().executor(GENERAL_THREAD_POOL) + DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).tenantId(tenantId).build() ) .whenComplete((response, throwable) -> handleDeleteResponse(response, throwable, connectorId, actionListener)); } catch (Exception e) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 1c2aa3a2d5..5e766b093d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -6,7 +6,6 @@ package org.opensearch.ml.action.connector; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import java.io.IOException; @@ -153,8 +152,7 @@ private void indexConnector(Connector connector, ActionListener { context.restore(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index 328bbd7e83..e646cadb83 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -9,41 +9,49 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import java.io.IOException; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; +import org.opensearch.remote.metadata.client.UpdateDataObjectResponse; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -54,11 +62,13 @@ @Log4j2 @FieldDefaults(level = AccessLevel.PRIVATE) public class UpdateConnectorTransportAction extends HandledTransportAction { - Client client; + final Client client; + final SdkClient sdkClient; - ConnectorAccessControlHelper connectorAccessControlHelper; - MLModelManager mlModelManager; - MLEngine mlEngine; + final ConnectorAccessControlHelper connectorAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + final MLModelManager mlModelManager; + final MLEngine mlEngine; volatile List trustedConnectorEndpointsRegex; @Inject @@ -66,17 +76,21 @@ public UpdateConnectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, ConnectorAccessControlHelper connectorAccessControlHelper, MLModelManager mlModelManager, Settings settings, ClusterService clusterService, - MLEngine mlEngine + MLEngine mlEngine, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new); this.client = client; + this.sdkClient = sdkClient; this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlModelManager = mlModelManager; this.mlEngine = mlEngine; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService .getClusterSettings() @@ -86,31 +100,57 @@ public UpdateConnectorTransportAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request); + MLCreateConnectorInput mlCreateConnectorInput = mlUpdateConnectorAction.getUpdateContent(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, mlCreateConnectorInput.getTenantId(), listener)) { + return; + } String connectorId = mlUpdateConnectorAction.getConnectorId(); - + String tenantId = mlCreateConnectorInput.getTenantId(); + FetchSourceContext fetchSourceContext = new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY); + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() + .index(ML_CONNECTOR_INDEX) + .id(connectorId) + .tenantId(tenantId) + .fetchSourceContext(fetchSourceContext) + .build(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper.getConnector(client, connectorId, ActionListener.wrap(connector -> { - boolean hasPermission = connectorAccessControlHelper.validateConnectorAccess(client, connector); - if (Boolean.TRUE.equals(hasPermission)) { - connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt); - connector.validateConnectorURL(trustedConnectorEndpointsRegex); - - connector.setLastUpdateTime(Instant.now()); - - UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - updateRequest.doc(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); - updateUndeployedConnector(connectorId, updateRequest, listener, context); - } else { - listener - .onFailure( - new IllegalArgumentException("You don't have permission to update the connector, connector id: " + connectorId) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception); - listener.onFailure(exception); - })); + connectorAccessControlHelper + .getConnector(sdkClient, client, context, getDataObjectRequest, connectorId, ActionListener.wrap(connector -> { + // context is already restored here + if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, connector.getTenantId(), listener)) { + boolean hasPermission = connectorAccessControlHelper.validateConnectorAccess(client, connector); + if (hasPermission) { + connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + connector.setLastUpdateTime(Instant.now()); + UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest + .builder() + .index(ML_CONNECTOR_INDEX) + .id(connectorId) + .tenantId(tenantId) + .dataObject(connector) + .build(); + try (ThreadContext.StoredContext innerContext = client.threadPool().getThreadContext().stashContext()) { + updateUndeployedConnector( + connectorId, + updateDataObjectRequest, + ActionListener.runBefore(listener, innerContext::restore) + ); + } + } else { + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to update the connector, connector id: " + connectorId + ) + ); + } + } + }, exception -> { + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception); + listener.onFailure(exception); + })); } catch (Exception e) { log.error("Failed to update ML connector for connector id {}. Details {}:", connectorId, e); listener.onFailure(e); @@ -119,55 +159,84 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener, - ThreadContext.StoredContext context + UpdateDataObjectRequest updateDataObjectRequest, + ActionListener listener ) { - SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); boolQueryBuilder.must(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); boolQueryBuilder.must(QueryBuilders.idsQuery().addIds(mlModelManager.getAllModelIds())); sourceBuilder.query(boolQueryBuilder); - searchRequest.source(sourceBuilder); - client.search(searchRequest, ActionListener.wrap(searchResponse -> { - SearchHit[] searchHits = searchResponse.getHits().getHits(); - if (searchHits.length == 0) { - client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(ML_MODEL_INDEX) + .tenantId(updateDataObjectRequest.tenantId()) + .searchSourceBuilder(sourceBuilder) + .build(); + sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((sr, st) -> { + if (sr != null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser()); + SearchHit[] searchHits = searchResponse.getHits().getHits(); + if (searchHits.length == 0) { + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((r, throwable) -> { + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener)); + }); + } else { + log.error("{} models are still using this connector, please undeploy the models first!", searchHits.length); + List modelIds = new ArrayList<>(); + for (SearchHit hit : searchHits) { + modelIds.add(hit.getId()); + } + listener + .onFailure( + new OpenSearchStatusException( + searchHits.length + + " models are still using this connector, please undeploy the models first: " + + Arrays.toString(modelIds.toArray(new String[0])), + RestStatus.BAD_REQUEST + ) + ); + } + } catch (Exception e) { + log.error("Failed to parse search response", e); + listener.onFailure(new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR)); + } } else { - log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); - List modelIds = new ArrayList<>(); - for (SearchHit hit : searchHits) { - modelIds.add(hit.getId()); + Exception cause = SdkClientUtils.unwrapAndConvertToException(st); + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((r, throwable) -> { + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener)); + }); + } else { + log.error("Failed to update ML connector: {}", connectorId, cause); + listener.onFailure(cause); } - listener - .onFailure( - new OpenSearchStatusException( - searchHits.length - + " models are still using this connector, please undeploy the models first: " - + Arrays.toString(modelIds.toArray(new String[0])), - RestStatus.BAD_REQUEST - ) - ); } - }, e -> { - if (e instanceof IndexNotFoundException) { - client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); - return; - } - log.error("Failed to update ML connector: " + connectorId, e); - listener.onFailure(e); - - })); + }); } - private ActionListener getUpdateResponseListener( - String connectorId, - ActionListener actionListener, - ThreadContext.StoredContext context + private void handleUpdateDataObjectCompletionStage( + UpdateDataObjectResponse r, + Throwable throwable, + ActionListener updateListener ) { - return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + updateListener.onFailure(cause); + } else { + try { + UpdateResponse updateResponse = r.parser() == null ? null : UpdateResponse.fromXContent(r.parser()); + updateListener.onResponse(updateResponse); + } catch (IOException e) { + updateListener.onFailure(e); + } + } + } + + private ActionListener getUpdateResponseListener(String connectorId, ActionListener actionListener) { + return ActionListener.wrap(updateResponse -> { if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { log.error("Failed to update the connector with ID: {}", connectorId); actionListener.onResponse(updateResponse); @@ -178,6 +247,6 @@ private ActionListener getUpdateResponseListener( }, exception -> { log.error("Failed to update ML connector with ID {}. Details: {}", connectorId, exception); actionListener.onFailure(exception); - }), context::restore); + }); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index d75a5668dc..b7642ba252 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -9,19 +9,23 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; +import java.io.IOException; + +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; @@ -30,7 +34,15 @@ import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.DeleteDataObjectRequest; +import org.opensearch.remote.metadata.client.DeleteDataObjectResponse; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.client.SearchDataObjectResponse; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -43,82 +55,183 @@ @FieldDefaults(level = AccessLevel.PRIVATE) public class DeleteModelGroupTransportAction extends HandledTransportAction { - Client client; - NamedXContentRegistry xContentRegistry; - ClusterService clusterService; + final Client client; + final SdkClient sdkClient; + final NamedXContentRegistry xContentRegistry; + final ClusterService clusterService; - ModelAccessControlHelper modelAccessControlHelper; + final ModelAccessControlHelper modelAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject public DeleteModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, NamedXContentRegistry xContentRegistry, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper + ModelAccessControlHelper modelAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLModelGroupDeleteAction.NAME, transportService, actionFilters, MLModelGroupDeleteRequest::new); this.client = client; + this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { - MLModelGroupDeleteRequest mlModelGroupDeleteRequest = MLModelGroupDeleteRequest.fromActionRequest(request); - String modelGroupId = mlModelGroupDeleteRequest.getModelGroupId(); - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - User user = RestActionUtils.getUserContext(client); + MLModelGroupDeleteRequest deleteRequest = MLModelGroupDeleteRequest.fromActionRequest(request); + String modelGroupId = deleteRequest.getModelGroupId(); + String tenantId = deleteRequest.getTenantId(); + + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); - modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group")); - } else { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(mlModels -> { - if (mlModels == null || mlModels.getHits().getTotalHits() == null || mlModels.getHits().getTotalHits().value == 0) { - deleteModelGroup(deleteRequest, modelGroupId, wrappedListener); - } else { - throw new MLValidationException("Cannot delete the model group when it has associated model versions"); - } - - }, e -> { - if (e instanceof IndexNotFoundException) { - deleteModelGroup(deleteRequest, modelGroupId, wrappedListener); - } else { - log.error("Failed to search models with the specified Model Group Id " + modelGroupId, e); - wrappedListener.onFailure(e); - } - })); - } - }, e -> { - log.error("Failed to validate Access for Model Group " + modelGroupId, e); - wrappedListener.onFailure(e); - })); + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + validateAndDeleteModelGroup(modelGroupId, tenantId, wrappedListener); } } - private void deleteModelGroup(DeleteRequest deleteRequest, String modelGroupId, ActionListener actionListener) { - client.delete(deleteRequest, new ActionListener() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - log.debug("Completed Delete Model Group Request, task id:{} deleted", modelGroupId); - actionListener.onResponse(deleteResponse); + private void validateAndDeleteModelGroup(String modelGroupId, String tenantId, ActionListener listener) { + User user = RestActionUtils.getUserContext(client); + modelAccessControlHelper + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + tenantId, + modelGroupId, + client, + sdkClient, + ActionListener + .wrap( + hasAccess -> handleAccessValidation(hasAccess, modelGroupId, tenantId, listener), + error -> handleValidationError(error, modelGroupId, listener) + ) + ); + } + + private void handleAccessValidation(boolean hasAccess, String modelGroupId, String tenantId, ActionListener listener) { + if (!hasAccess) { + listener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group")); + return; + } + checkForAssociatedModels(modelGroupId, tenantId, listener); + } + + private void checkForAssociatedModels(String modelGroupId, String tenantId, ActionListener listener) { + SearchDataObjectRequest searchRequest = buildModelSearchRequest(modelGroupId, tenantId); + + sdkClient + .searchDataObjectAsync(searchRequest) + .whenComplete( + (searchResponse, throwable) -> handleModelSearchResponse(searchResponse, throwable, modelGroupId, tenantId, listener) + ); + } + + private SearchDataObjectRequest buildModelSearchRequest(String modelGroupId, String tenantId) { + BoolQueryBuilder query = new BoolQueryBuilder().filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId)); + SearchSourceBuilder searchSource = new SearchSourceBuilder().query(query); + + return SearchDataObjectRequest.builder().indices(ML_MODEL_INDEX).tenantId(tenantId).searchSourceBuilder(searchSource).build(); + } + + private void handleModelSearchResponse( + SearchDataObjectResponse searchResponse, + Throwable throwable, + String modelGroupId, + String tenantId, + ActionListener listener + ) { + if (searchResponse == null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + handleModelSearchFailure(modelGroupId, tenantId, cause, listener); + return; + } + + try { + SearchResponse response = SearchResponse.fromXContent(searchResponse.parser()); + if (response.getHits().getHits().length == 0) { + DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId); + deleteModelGroup(deleteRequest, tenantId, listener); + } else { + listener + .onFailure( + new OpenSearchStatusException( + "Cannot delete the model group when it has associated model versions", + RestStatus.CONFLICT + ) + ); } + } catch (Exception e) { + log.error("Failed to parse search response", e); + listener.onFailure(new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + private void deleteModelGroup(DeleteRequest deleteRequest, String tenantId, ActionListener listener) { + try { + DeleteDataObjectRequest request = DeleteDataObjectRequest + .builder() + .index(deleteRequest.index()) + .id(deleteRequest.id()) + .tenantId(tenantId) + .build(); + + sdkClient + .deleteDataObjectAsync(request) + .whenComplete((response, throwable) -> handleDeleteResponse(response, throwable, deleteRequest.id(), listener)); + } catch (Exception e) { + log.error("Failed to delete Model group : {}", deleteRequest.id(), e); + listener.onFailure(e); + } + } + + private void handleValidationError(Exception error, String modelGroupId, ActionListener listener) { + log.error("Failed to validate Access for Model Group {}", modelGroupId, error); + listener.onFailure(error); + } - @Override - public void onFailure(Exception e) { - log.error("Failed to delete ML Model Group " + modelGroupId, e); + private void handleDeleteResponse( + DeleteDataObjectResponse response, + Throwable throwable, + String modelGroupId, + ActionListener actionListener + ) { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + log.error("Failed to delete ML Model Group {}", modelGroupId, cause); + actionListener.onFailure(cause); + } else { + try { + DeleteResponse deleteResponse = DeleteResponse.fromXContent(response.parser()); + log.debug("Completed Delete Model Group Request, model group id:{} deleted", response.id()); + actionListener.onResponse(deleteResponse); + } catch (IOException e) { actionListener.onFailure(e); } - }); + } + } + + private void handleModelSearchFailure( + String modelGroupId, + String tenantId, + Exception cause, + ActionListener actionListener + ) { + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId); + deleteModelGroup(deleteRequest, tenantId, actionListener); + return; + } + + log.error("Failed to search for models using model group id: {}", modelGroupId, cause); + actionListener.onFailure(cause); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java index a846c9c0f6..b0939d91fe 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -5,32 +5,41 @@ package org.opensearch.ml.action.model_group; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; -import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.GetDataObjectResponse; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -42,85 +51,153 @@ @FieldDefaults(level = AccessLevel.PRIVATE) public class GetModelGroupTransportAction extends HandledTransportAction { - Client client; - NamedXContentRegistry xContentRegistry; - ClusterService clusterService; - - ModelAccessControlHelper modelAccessControlHelper; + final Client client; + final SdkClient sdkClient; + final NamedXContentRegistry xContentRegistry; + final ClusterService clusterService; + final ModelAccessControlHelper modelAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject public GetModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, NamedXContentRegistry xContentRegistry, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper + ModelAccessControlHelper modelAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLModelGroupGetAction.NAME, transportService, actionFilters, MLModelGroupGetRequest::new); this.client = client; + this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.fromActionRequest(request); String modelGroupId = mlModelGroupGetRequest.getModelGroupId(); - GetRequest getRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); + String tenantId = mlModelGroupGetRequest.getTenantId(); + + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + + FetchSourceContext fetchSourceContext = new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY); + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .tenantId(tenantId) + .fetchSourceContext(fetchSourceContext) + .build(); + User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - - MLModelGroup mlModelGroup = MLModelGroup.parse(parser); - modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model group", - RestStatus.FORBIDDEN - ) - ); - } else { - wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); - } - }, e -> { - log.error("Failed to validate access for Model Group " + modelGroupId, e); - wrappedListener.onFailure(e); - })); - - } catch (Exception e) { - log.error("Failed to parse ml model group" + r.getId(), e); - wrappedListener.onFailure(e); + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + + sdkClient + .getDataObjectAsync(getDataObjectRequest) + .whenComplete((r, throwable) -> handleResponse(r, throwable, modelGroupId, tenantId, user, wrappedListener)); + } catch (Exception e) { + log.error("Failed to get ML model group {}", modelGroupId, e); + actionListener.onFailure(e); + } + } + + private void handleResponse( + GetDataObjectResponse getDataObjectResponse, + Throwable throwable, + String modelGroupId, + String tenantId, + User user, + ActionListener wrappedListener + ) { + log.debug("Completed Get Model group Request, id:{}", modelGroupId); + if (throwable != null) { + handleThrowable(throwable, modelGroupId, wrappedListener); + } else { + processResponse(getDataObjectResponse, modelGroupId, tenantId, user, wrappedListener); + } + } + + private void handleThrowable(Throwable throwable, String modelGroupId, ActionListener wrappedListener) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + log.error("Failed to find model group index", cause); + wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group index", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML group {}", modelGroupId, cause); + wrappedListener.onFailure(cause); + } + } + + private void processResponse( + GetDataObjectResponse getDataObjectResponse, + String modelGroupId, + String tenantId, + User user, + ActionListener wrappedListener + ) { + try { + GetResponse gr = getDataObjectResponse.parser() == null ? null : GetResponse.fromXContent(getDataObjectResponse.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + + if (TenantAwareHelper + .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModelGroup.getTenantId(), wrappedListener)) { + validateModelGroupAccess(user, modelGroupId, mlModelGroup, wrappedListener); } - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Failed to find model group with the provided model group id: " + modelGroupId, - RestStatus.NOT_FOUND - ) - ); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group index")); - } else { - log.error("Failed to get ML model group" + modelGroupId, e); + } catch (Exception e) { + log.error("Failed to parse ml connector {}", getDataObjectResponse.id(), e); wrappedListener.onFailure(e); } - })); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model group with the provided model group id: " + modelGroupId, + RestStatus.NOT_FOUND + ) + ); + } } catch (Exception e) { - log.error("Failed to get ML model group " + modelGroupId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } + } + private void validateModelGroupAccess( + User user, + String modelGroupId, + MLModelGroup mlModelGroup, + ActionListener wrappedListener + ) { + modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model group", + RestStatus.FORBIDDEN + ) + ); + } else { + wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build()); + } + }, e -> { + log.error("Failed to validate access for Model Group {}", modelGroupId, e); + wrappedListener.onFailure(e); + })); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java index 4e29db680d..464becdace 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupAction.java @@ -20,6 +20,9 @@ import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -34,10 +37,12 @@ public class TransportRegisterModelGroupAction extends HandledTransportAction listener) { MLRegisterModelGroupRequest createModelGroupRequest = MLRegisterModelGroupRequest.fromActionRequest(request); MLRegisterModelGroupInput createModelGroupInput = createModelGroupRequest.getRegisterModelGroupInput(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, createModelGroupInput.getTenantId(), listener)) { + return; + } mlModelGroupManager.createModelGroup(createModelGroupInput, ActionListener.wrap(modelGroupId -> { listener.onResponse(new MLRegisterModelGroupResponse(modelGroupId, MLTaskState.CREATED.name())); }, ex -> { diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index f3793fe59c..6bdf091665 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -5,28 +5,29 @@ package org.opensearch.ml.action.model_group; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.util.HashSet; -import java.util.Iterator; import java.util.Map; import org.apache.commons.lang3.StringUtils; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -34,7 +35,6 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; @@ -42,9 +42,15 @@ import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; -import org.opensearch.ml.utils.MLNodeUtils; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -58,30 +64,36 @@ public class TransportUpdateModelGroupAction extends HandledTransportAction wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); - client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { - if (modelGroup.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelGroup mlModelGroup = MLModelGroup.parse(parser); - if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { - validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((r, throwable) -> { + log.debug("Completed Get Model group Request, id:{}", modelGroupId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + log.error("Failed to get model group index", cause); + wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML group {}", modelGroupId, cause); + wrappedListener.onFailure(cause); + } + } else { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + if (TenantAwareHelper + .validateTenantResource( + mlFeatureEnabledSetting, + tenantId, + mlModelGroup.getTenantId(), + wrappedListener + )) { + if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { + validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); + } else { + validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); + } + updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user); + } + } catch (Exception e) { + log.error("Failed to parse ml connector {}", r.id(), e); + wrappedListener.onFailure(e); + } } else { - validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model group with the provided model group id: " + modelGroupId, + RestStatus.NOT_FOUND + ) + ); } - updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user); } catch (Exception e) { - log.error("Failed to parse ml model group" + modelGroup.getId(), e); wrappedListener.onFailure(e); } - } else { - wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); - } else { - logException("Failed to get model group", e, log); - wrappedListener.onFailure(e); } - })); + }); } catch (Exception e) { logException("Failed to Update model group", e, log); listener.onFailure(e); @@ -155,54 +203,63 @@ private void updateModelGroup( source.put(MLModelGroup.DESCRIPTION_FIELD, updateModelGroupInput.getDescription()); } if (StringUtils.isNotBlank(updateModelGroupInput.getName()) && !updateModelGroupInput.getName().equals(modelGroupName)) { - mlModelGroupManager.validateUniqueModelGroupName(updateModelGroupInput.getName(), ActionListener.wrap(modelGroups -> { - if (modelGroups != null - && modelGroups.getHits().getTotalHits() != null - && modelGroups.getHits().getTotalHits().value != 0) { - Iterator iterator = modelGroups.getHits().iterator(); - while (iterator.hasNext()) { - String id = iterator.next().getId(); - listener - .onFailure( - new IllegalArgumentException( - "The name you provided is already being used by another model with ID: " - + id - + ". Please provide a different name" - ) - ); - } - } else { - source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName()); - updateModelGroup(modelGroupId, source, listener); - } - }, e -> { - log.error("Failed to search model group index", e); - listener.onFailure(e); - })); + mlModelGroupManager + .validateUniqueModelGroupName( + updateModelGroupInput.getName(), + updateModelGroupInput.getTenantId(), + ActionListener.wrap(modelGroups -> { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + for (SearchHit documentFields : modelGroups.getHits()) { + String id = documentFields.getId(); + listener + .onFailure( + new IllegalArgumentException( + "The name you provided is already being used by another model with ID: " + + id + + ". Please provide a different name" + ) + ); + } + } else { + source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName()); + updateModelGroup(modelGroupId, updateModelGroupInput.getTenantId(), source, listener); + } + }, e -> { + log.error("Failed to search model group index", e); + listener.onFailure(e); + }) + ); } else { - updateModelGroup(modelGroupId, source, listener); + updateModelGroup(modelGroupId, updateModelGroupInput.getTenantId(), source, listener); } - } - private void updateModelGroup(String modelGroupId, Map source, ActionListener listener) { - UpdateRequest updateModelGroupRequest = new UpdateRequest(); - updateModelGroupRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - updateModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId).doc(source); + private void updateModelGroup( + String modelGroupId, + String tenantId, + Map source, + ActionListener listener + ) { + UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest + .builder() + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .tenantId(tenantId) + .dataObject(source) + .build(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); - client - .update( - updateModelGroupRequest, - ActionListener.wrap(r -> { wrappedListener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> { - if (e instanceof IndexNotFoundException) { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); - } else { - log.error("Failed to update model group", e, log); - wrappedListener.onFailure(new MLValidationException("Failed to update Model Group")); - } - }) - ); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((ur, ut) -> { + if (ut == null) { + wrappedListener.onResponse(new MLUpdateModelGroupResponse("Updated")); + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(ut); + log.error("Failed to update model group {}", modelGroupId, e); + wrappedListener.onFailure(new MLValidationException("Failed to update Model Group")); + } + }); } catch (Exception e) { logException("Failed to Update model group ", e, log); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 055d5556d1..b65940399d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -6,50 +6,62 @@ package org.opensearch.ml.action.models; import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; +import static org.opensearch.ml.common.MLModel.FUNCTION_NAME_FIELD; import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD; import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; -import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import java.util.Map; +import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.DeleteDataObjectRequest; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -68,81 +80,87 @@ public class DeleteModelTransportAction extends HandledTransportAction actionListener) { MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.fromActionRequest(request); String modelId = mlModelDeleteRequest.getModelId(); - MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false, false); + String tenantId = mlModelDeleteRequest.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } + MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false, false, tenantId); FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent()); - GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(modelId) + .tenantId(tenantId) + .fetchSourceContext(fetchSourceContext) + .build(); User user = RestActionUtils.getUserContext(client); boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - GetResponse getResponse = r; - String algorithmName = ""; - if (getResponse.getSource() != null && getResponse.getSource().get(ALGORITHM_FIELD) != null) { - algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString(); - } - MLModel mlModel = MLModel.parse(parser, algorithmName); - Boolean isHidden = (Boolean) r.getSource().get(IS_HIDDEN_FIELD); - MLModelState mlModelState = mlModel.getModelState(); - if (isHidden != null && isHidden) { - if (!isSuperAdmin) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else { - if (isModelNotDeployed(mlModelState)) { - deleteModel(modelId, isHidden, actionListener); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete", - RestStatus.BAD_REQUEST - ) - ); + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + String algorithmName = ""; + Map source = r.source(); + if (source != null) { + if (source.get(FUNCTION_NAME_FIELD) != null) { + algorithmName = source.get(FUNCTION_NAME_FIELD).toString(); + } else if (source.get(ALGORITHM_FIELD) != null) { + algorithmName = source.get(ALGORITHM_FIELD).toString(); + } } - } - } else { - modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { + MLModel mlModel = MLModel.parse(parser, algorithmName); + if (!TenantAwareHelper + .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), actionListener)) { + return; + } + Boolean isHidden = (Boolean) r.source().get(IS_HIDDEN_FIELD); + MLModelState mlModelState = mlModel.getModelState(); + if (isHidden != null && isHidden) { + if (!isSuperAdmin) { wrappedListener .onFailure( new OpenSearchStatusException( @@ -150,34 +168,66 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - log.error(getErrorMessage("Failed to validate Access", modelId, isHidden), e); - wrappedListener.onFailure(e); - })); + } else { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else if (isModelNotDeployed(mlModelState)) { + deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete", + RestStatus.BAD_REQUEST + ) + ); + } + }, e -> { + log.error(getErrorMessage("Failed to validate Access", modelId, isHidden), e); + wrappedListener.onFailure(e); + })); + } + } catch (Exception e) { + log.error("Failed to parse ml model {}", r.id(), e); + wrappedListener.onFailure(e); + } + } else { + // when model metadata is not found, model chunk and controller might still there, delete them here and + // return success response as we can't see the metadata we are providing functionName as null. In this way, + // code will try to remove model chunks for any models other than remote. As remote + // model doesn't have any model chunks. + deleteModelChunksAndController(wrappedListener, modelId, null, false, null); } } catch (Exception e) { - log.error("Failed to parse ml model " + r.getId(), e); wrappedListener.onFailure(e); } } else { - // when model metadata is not found, model chunk and controller might still there, delete them here and return success - // response - deleteModelChunksAndController(wrappedListener, modelId, false, null); + wrappedListener.onFailure((new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND))); } - }, e -> { wrappedListener.onFailure((new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND))); })); + }); } catch (Exception e) { - log.error("Failed to delete ML model " + modelId, e); + log.error("Failed to delete ML model {}", modelId, e); actionListener.onFailure(e); } } @@ -185,11 +235,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX); - deleteModelsRequest.setQuery(new TermsQueryBuilder(MODEL_ID_FIELD, modelId)).setRefresh(true); - + deleteModelsRequest + .setQuery( + new BoolQueryBuilder() + .must(new TermsQueryBuilder(MODEL_ID_FIELD, modelId)) // Match documents with the same model_id + // The just deleted model document does not have the same fields as the model chunks and can result in parsing errors if + // it is read. OpenSearch is eventually consistent on search, so a search may return deleted documents until the next + // merge. A force merge between deletions would have performance impact. A more robust solution is just to make sure the + // model document does not appear in the search results. + .mustNot(new TermQueryBuilder("_id", modelId)) // exclude the document just deleted + ); client.execute(DeleteByQueryAction.INSTANCE, deleteModelsRequest, ActionListener.wrap(r -> { - if ((r.getBulkFailures() == null || r.getBulkFailures().size() == 0) - && (r.getSearchFailures() == null || r.getSearchFailures().size() == 0)) { + if ((r.getBulkFailures() == null || r.getBulkFailures().isEmpty()) + && (r.getSearchFailures() == null || r.getSearchFailures().isEmpty())) { log.debug(getErrorMessage("All model chunks are deleted for the provided model.", modelId, isHidden)); actionListener.onResponse(true); } else { @@ -202,7 +260,7 @@ void deleteModelChunks(String modelId, Boolean isHidden, ActionListener } private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener actionListener) { - String errorMessage = ""; + String errorMessage; if (response.isTimedOut()) { errorMessage = OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + modelId; } else if (!response.getBulkFailures().isEmpty()) { @@ -214,18 +272,31 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); } - private void deleteModel(String modelId, Boolean isHidden, ActionListener actionListener) { - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.delete(deleteRequest, new ActionListener<>() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - deleteModelChunksAndController(actionListener, modelId, isHidden, deleteResponse); - } - - @Override - public void onFailure(Exception e) { - if (e instanceof ResourceNotFoundException) { - deleteModelChunksAndController(actionListener, modelId, isHidden, null); + private void deleteModel( + String modelId, + String tenantId, + String functionName, + Boolean isHidden, + ActionListener actionListener + ) { + DeleteDataObjectRequest deleteDataObjectRequest = DeleteDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(modelId) + .tenantId(tenantId) + .build(); + sdkClient.deleteDataObjectAsync(deleteDataObjectRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + DeleteResponse deleteResponse = DeleteResponse.fromXContent(r.parser()); + deleteModelChunksAndController(actionListener, modelId, functionName, isHidden, deleteResponse); + } catch (Exception e) { + actionListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(e, ResourceNotFoundException.class) != null) { + deleteModelChunksAndController(actionListener, modelId, functionName, isHidden, null); } else { log.error(getErrorMessage("Model is not all cleaned up, please try again.", modelId, isHidden), e); actionListener.onFailure(e); @@ -237,6 +308,7 @@ public void onFailure(Exception e) { private void deleteModelChunksAndController( ActionListener actionListener, String modelId, + String functionName, Boolean isHidden, DeleteResponse deleteResponse ) { @@ -277,7 +349,12 @@ private void deleteModelChunksAndController( ); } }); - deleteModelChunks(modelId, isHidden, countDownActionListener); + if (!Objects.equals(functionName, FunctionName.REMOTE.name())) { + deleteModelChunks(modelId, isHidden, countDownActionListener); + } else { + // for remote model we don't need to delete model chunks so reducing one latch countdown. + countDownLatch.countDown(); + } deleteController(modelId, isHidden, countDownActionListener); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 0d5fc5f659..db6d3329ea 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -5,16 +5,17 @@ package org.opensearch.ml.action.models; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD; -import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -22,6 +23,7 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -30,12 +32,16 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -50,11 +56,14 @@ @FieldDefaults(level = AccessLevel.PRIVATE) public class GetModelTransportAction extends HandledTransportAction { - Client client; - NamedXContentRegistry xContentRegistry; - ClusterService clusterService; + final Client client; + final SdkClient sdkClient; + final NamedXContentRegistry xContentRegistry; + final ClusterService clusterService; - ModelAccessControlHelper modelAccessControlHelper; + final ModelAccessControlHelper modelAccessControlHelper; + + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; Settings settings; @@ -63,53 +72,65 @@ public GetModelTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, Settings settings, NamedXContentRegistry xContentRegistry, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper + ModelAccessControlHelper modelAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLModelGetAction.NAME, transportService, actionFilters, MLModelGetRequest::new); this.client = client; + this.sdkClient = sdkClient; this.settings = settings; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLModelGetRequest mlModelGetRequest = MLModelGetRequest.fromActionRequest(request); String modelId = mlModelGetRequest.getModelId(); + String tenantId = mlModelGetRequest.getTenantId(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { + return; + } FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent()); - GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(modelId) + .tenantId(tenantId) + .fetchSourceContext(fetchSourceContext) + .build(); User user = RestActionUtils.getUserContext(client); boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - String algorithmName = r.getSource().get(ALGORITHM_FIELD).toString(); - Boolean isHidden = (Boolean) r.getSource().get(IS_HIDDEN_FIELD); - MLModel mlModel = MLModel.parse(parser, algorithmName); - if (isHidden != null && isHidden) { - if (isSuperAdmin || !mlModelGetRequest.isUserInitiatedGetRequest()) { - wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } - } else { - modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + String algorithmName = r.source().get(ALGORITHM_FIELD).toString(); + Boolean isHidden = (Boolean) r.source().get(IS_HIDDEN_FIELD); + MLModel mlModel = MLModel.parse(parser, algorithmName); + if (!TenantAwareHelper + .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), actionListener)) { + return; + } + if (isHidden != null && isHidden) { + if (isSuperAdmin || !mlModelGetRequest.isUserInitiatedGetRequest()) { + wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); + } else { wrappedListener .onFailure( new OpenSearchStatusException( @@ -117,45 +138,61 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - log.error("Failed to validate Access for Model Id " + modelId, e); - wrappedListener.onFailure(e); - })); + } else { + modelAccessControlHelper + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else { + log.debug("Completed Get Model Request, id:{}", modelId); + Connector connector = mlModel.getConnector(); + if (connector != null) { + connector.removeCredential(); + } + wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); + } + }, e -> { + log.error("Failed to validate Access for Model Id {}", modelId, e); + wrappedListener.onFailure(e); + })); + } + } catch (Exception e) { + log.error("Failed to parse ml model {}", r.id(), e); + wrappedListener.onFailure(e); + } + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Failed to find model with the provided model id: " + modelId, + RestStatus.NOT_FOUND + ) + ); } } catch (Exception e) { - log.error("Failed to parse ml model " + r.getId(), e); wrappedListener.onFailure(e); } } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Failed to find model with the provided model id: " + modelId, - RestStatus.NOT_FOUND - ) - ); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model")); - } else { - log.error("Failed to get ML model " + modelId, e); - wrappedListener.onFailure(e); + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(e, IndexNotFoundException.class) != null) { + wrappedListener.onFailure(new OpenSearchStatusException("Fail to find model", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML model {}", modelId, e); + wrappedListener.onFailure(e); + } } - })); + }); } catch (Exception e) { - log.error("Failed to get ML model " + modelId, e); + log.error("Failed to get ML model {}", modelId, e); actionListener.onFailure(e); } - } // this method is only to stub static method. diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index c4ecd57805..0dfa5a3b1a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -18,6 +18,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; import org.opensearch.OpenSearchStatusException; @@ -27,7 +28,6 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -36,12 +36,12 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; @@ -59,7 +59,12 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -72,15 +77,16 @@ @Log4j2 @FieldDefaults(level = AccessLevel.PRIVATE) public class UpdateModelTransportAction extends HandledTransportAction { - Client client; - - Settings settings; - ClusterService clusterService; - ModelAccessControlHelper modelAccessControlHelper; - ConnectorAccessControlHelper connectorAccessControlHelper; - MLModelManager mlModelManager; - MLModelGroupManager mlModelGroupManager; - MLEngine mlEngine; + final Client client; + private final SdkClient sdkClient; + final Settings settings; + final ClusterService clusterService; + final ModelAccessControlHelper modelAccessControlHelper; + final ConnectorAccessControlHelper connectorAccessControlHelper; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + final MLModelManager mlModelManager; + final MLModelGroupManager mlModelGroupManager; + final MLEngine mlEngine; volatile List trustedConnectorEndpointsRegex; @Inject @@ -88,16 +94,19 @@ public UpdateModelTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + SdkClient sdkClient, ConnectorAccessControlHelper connectorAccessControlHelper, ModelAccessControlHelper modelAccessControlHelper, MLModelManager mlModelManager, MLModelGroupManager mlModelGroupManager, Settings settings, ClusterService clusterService, - MLEngine mlEngine + MLEngine mlEngine, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { super(MLUpdateModelAction.NAME, transportService, actionFilters, MLUpdateModelRequest::new); this.client = client; + this.sdkClient = sdkClient; this.modelAccessControlHelper = modelAccessControlHelper; this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlModelManager = mlModelManager; @@ -105,6 +114,7 @@ public UpdateModelTransportAction( this.clusterService = clusterService; this.mlEngine = mlEngine; this.settings = settings; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); clusterService .getClusterSettings() @@ -116,6 +126,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { - if (!isModelDeploying(mlModel.getModelState())) { - FunctionName functionName = mlModel.getAlgorithm(); - // TODO: Support update as well as model/user level throttling in all other DLModel categories - if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { - if (mlModel.getIsHidden() != null && mlModel.getIsHidden()) { - if (isSuperAdmin) { - updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, wrappedListener); + mlModelManager.getModel(modelId, tenantId, null, excludes, ActionListener.wrap(mlModel -> { + if (TenantAwareHelper.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModel.getTenantId(), actionListener)) { + if (!isModelDeploying(mlModel.getModelState())) { + FunctionName functionName = mlModel.getAlgorithm(); + // TODO: Support update as well as model/user level throttling in all other DLModel categories + if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { + if (mlModel.getIsHidden() != null && mlModel.getIsHidden()) { + if (isSuperAdmin) { + updateRemoteOrTextEmbeddingModel(modelId, tenantId, updateModelInput, mlModel, user, wrappedListener); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) + modelAccessControlHelper + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + tenantId, + mlModel.getModelGroupId(), + client, + sdkClient, + ActionListener.wrap(hasPermission -> { + if (hasPermission) { + updateRemoteOrTextEmbeddingModel( + modelId, + tenantId, + updateModelInput, + mlModel, + user, + wrappedListener + ); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model, model ID " + + modelId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + "Permission denied: Unable to update the model with ID {}. Details: {}", + modelId, + exception + ); + wrappedListener.onFailure(exception); + }) ); } + } else { - modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { - if (hasPermission) { - updateRemoteOrTextEmbeddingModel(modelId, updateModelInput, mlModel, user, wrappedListener); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model, model ID " - + modelId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the model with ID {}. Details: {}", modelId, exception); - wrappedListener.onFailure(exception); - })); + wrappedListener + .onFailure( + new OpenSearchStatusException( + "The function category " + functionName.toString() + " is not supported at this time.", + RestStatus.FORBIDDEN + ) + ); } - } else { wrappedListener .onFailure( new OpenSearchStatusException( - "The function category " + functionName.toString() + " is not supported at this time.", - RestStatus.FORBIDDEN + "Model is deploying. Please wait for the model to complete deployment. model ID " + modelId, + RestStatus.CONFLICT ) ); } - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Model is deploying. Please wait for the model to complete deployment. model ID " + modelId, - RestStatus.CONFLICT - ) - ); } }, e -> wrappedListener @@ -189,13 +225,14 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (hasNewConnectorPermission) { - updateModelWithRegisteringToAnotherModelGroup( - modelId, - newModelGroupId, - user, - updateModelInput, - wrappedListener, - isUpdateModelCache - ); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to update the connector, connector id: " + newConnectorId, - RestStatus.FORBIDDEN - ) - ); - } - }, exception -> { - log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", newConnectorId, exception); - wrappedListener.onFailure(exception); - })); + connectorAccessControlHelper + .validateConnectorAccess( + sdkClient, + client, + newConnectorId, + tenantId, + mlFeatureEnabledSetting, + ActionListener.wrap(hasNewConnectorPermission -> { + if (hasNewConnectorPermission) { + updateModelWithRegisteringToAnotherModelGroup( + modelId, + newModelGroupId, + tenantId, + user, + updateModelInput, + wrappedListener, + isUpdateModelCache + ); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to update the connector, connector id: " + newConnectorId, + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", newConnectorId, exception); + wrappedListener.onFailure(exception); + }) + ); } else { wrappedListener .onFailure( @@ -321,17 +371,18 @@ private void updateModelWithNewStandAloneConnector( private void updateModelWithRegisteringToAnotherModelGroup( String modelId, String newModelGroupId, + String tenantId, User user, MLUpdateModelInput updateModelInput, ActionListener wrappedListener, boolean isUpdateModelCache ) { - UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); if (newModelGroupId != null) { modelAccessControlHelper .validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> { if (hasNewModelGroupPermission) { - mlModelGroupManager.getModelGroupResponse(newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { + mlModelGroupManager.getModelGroupResponse(sdkClient, newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { buildUpdateRequest( modelId, newModelGroupId, @@ -366,37 +417,49 @@ private void updateModelWithRegisteringToAnotherModelGroup( wrappedListener.onFailure(exception); })); } else { - buildUpdateRequest(modelId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache); + buildUpdateRequest(modelId, tenantId, updateRequest, updateModelInput, wrappedListener, isUpdateModelCache); } } private void buildUpdateRequest( String modelId, + String tenantId, UpdateRequest updateRequest, MLUpdateModelInput updateModelInput, ActionListener wrappedListener, boolean isUpdateModelCache ) { - try { - updateModelInput.setLastUpdateTime(Instant.now()); - updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); - updateRequest.docAsUpsert(true); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - if (isUpdateModelCache) { - String[] targetNodeIds = getAllNodes(); - MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(targetNodeIds, modelId); - client - .update( - updateRequest, - getUpdateResponseListenerWithUpdateModelCache(modelId, wrappedListener, mlUpdateModelCacheNodesRequest) - ); + updateModelInput.setLastUpdateTime(Instant.now()); + UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest + .builder() + .index(updateRequest.index()) + .id(updateRequest.id()) + .tenantId(tenantId) + .dataObject(updateModelInput) + .build(); + // TODO: This should probably be default on update data object: + // updateRequest.docAsUpsert(true); + ActionListener updateListener; + if (isUpdateModelCache) { + String[] targetNodeIds = getAllNodes(); + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(targetNodeIds, modelId); + updateListener = getUpdateResponseListenerWithUpdateModelCache(modelId, wrappedListener, mlUpdateModelCacheNodesRequest); + } else { + updateListener = getUpdateResponseListener(modelId, wrappedListener); + } + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((ur, ut) -> { + if (ut == null) { + try { + UpdateResponse updateResponse = ur.parser() == null ? null : UpdateResponse.fromXContent(ur.parser()); + updateListener.onResponse(updateResponse); + } catch (Exception e) { + updateListener.onFailure(e); + } } else { - client.update(updateRequest, getUpdateResponseListener(modelId, wrappedListener)); + Exception e = SdkClientUtils.unwrapAndConvertToException(ut); + updateListener.onFailure(e); } - } catch (IOException e) { - log.error("Failed to build update request.", e); - wrappedListener.onFailure(e); - } + }); } private void buildUpdateRequest( @@ -412,54 +475,56 @@ private void buildUpdateRequest( String updatedVersion = incrementLatestVersion(newModelGroupSourceMap); updateModelInput.setVersion(updatedVersion); updateModelInput.setLastUpdateTime(Instant.now()); - UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( + UpdateDataObjectRequest updateModelGroupRequest = createUpdateModelGroupRequest( newModelGroupSourceMap, newModelGroupId, newModelGroupResponse.getSeqNo(), newModelGroupResponse.getPrimaryTerm(), Integer.parseInt(updatedVersion) ); - try { - updateRequest.doc(updateModelInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); - updateRequest.docAsUpsert(true); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - if (isUpdateModelCache) { - String[] targetNodeIds = getAllNodes(); - MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(targetNodeIds, modelId); - client.update(updateModelGroupRequest, ActionListener.wrap(r -> { - client - .update( - updateRequest, - getUpdateResponseListenerWithUpdateModelCache(modelId, wrappedListener, mlUpdateModelCacheNodesRequest) - ); - }, e -> { - log - .error( - "Failed to register ML model with model ID {} to the new model group with model group ID {}", - modelId, - newModelGroupId, - e - ); - wrappedListener.onFailure(e); - })); + UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest + .builder() + .index(updateRequest.index()) + .id(updateRequest.id()) + .dataObject(updateModelInput) + .build(); + // TODO: This should probably be default on update data object: + // updateRequest.docAsUpsert(true); + ActionListener updateListener; + if (isUpdateModelCache) { + String[] targetNodeIds = getAllNodes(); + MLUpdateModelCacheNodesRequest mlUpdateModelCacheNodesRequest = new MLUpdateModelCacheNodesRequest(targetNodeIds, modelId); + updateListener = getUpdateResponseListenerWithUpdateModelCache(modelId, wrappedListener, mlUpdateModelCacheNodesRequest); + } else { + updateListener = getUpdateResponseListener(modelId, wrappedListener); + } + sdkClient.updateDataObjectAsync(updateModelGroupRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((ur, ut) -> { + if (ut == null) { + try { + UpdateResponse updateResponse = ur.parser() == null ? null : UpdateResponse.fromXContent(ur.parser()); + updateListener.onResponse(updateResponse); + } catch (Exception e) { + updateListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(ut); + updateListener.onFailure(e); + } + }); } else { - client.update(updateModelGroupRequest, ActionListener.wrap(r -> { - client.update(updateRequest, getUpdateResponseListener(modelId, wrappedListener)); - }, e -> { - log - .error( - "Failed to register ML model with model ID {} to the new model group with model group ID {}", - modelId, - newModelGroupId, - e - ); - wrappedListener.onFailure(e); - })); + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + log + .error( + "Failed to register ML model with model ID {} to the new model group with model group ID {}", + modelId, + newModelGroupId, + e + ); + wrappedListener.onFailure(e); } - } catch (IOException e) { - log.error("Failed to build update request."); - wrappedListener.onFailure(e); - } + }); } private ActionListener getUpdateResponseListenerWithUpdateModelCache( @@ -542,7 +607,7 @@ private String incrementLatestVersion(Map modelGroupSourceMap) { return Integer.toString((int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1); } - private UpdateRequest createUpdateModelGroupRequest( + private UpdateDataObjectRequest createUpdateModelGroupRequest( Map modelGroupSourceMap, String modelGroupId, long seqNo, @@ -551,17 +616,24 @@ private UpdateRequest createUpdateModelGroupRequest( ) { modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); - UpdateRequest updateModelGroupRequest = new UpdateRequest(); - - updateModelGroupRequest + ToXContentObject dataObject = new ToXContentObject() { + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + for (Entry e : modelGroupSourceMap.entrySet()) { + builder.field(e.getKey(), e.getValue()); + } + return builder.endObject(); + } + }; + return UpdateDataObjectRequest + .builder() .index(ML_MODEL_GROUP_INDEX) .id(modelGroupId) - .setIfSeqNo(seqNo) - .setIfPrimaryTerm(primaryTerm) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .doc(modelGroupSourceMap); - - return updateModelGroupRequest; + .ifSeqNo(seqNo) + .ifPrimaryTerm(primaryTerm) + .dataObject(dataObject) + .build(); } private Boolean isModelDeployed(MLModelState mlModelState) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 4ae39defe6..2a555e22bd 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -65,6 +65,8 @@ import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.MLExceptionUtils; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -84,6 +86,7 @@ public class TransportRegisterModelAction extends HandledTransportAction listener) { MLRegisterModelRequest registerModelRequest = MLRegisterModelRequest.fromActionRequest(request); MLRegisterModelInput registerModelInput = registerModelRequest.getRegisterModelInput(); + if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, registerModelInput.getTenantId(), listener)) { + return; + } if (FunctionName.isDLModel(registerModelInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); } @@ -164,20 +172,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - if (modelGroups != null - && modelGroups.getHits().getTotalHits() != null - && modelGroups.getHits().getTotalHits().value != 0) { - String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId(); - registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided); - checkUserAccess(registerModelInput, listener, true); - } else { - doRegister(registerModelInput, listener); - } - }, e -> { - log.error("Failed to search model group index", e); - listener.onFailure(e); - })); + mlModelGroupManager + .validateUniqueModelGroupName( + registerModelInput.getModelName(), + registerModelInput.getTenantId(), + ActionListener.wrap(modelGroups -> { + if (modelGroups != null + && modelGroups.getHits().getTotalHits() != null + && modelGroups.getHits().getTotalHits().value != 0) { + String modelGroupIdOfTheNameProvided = modelGroups.getHits().getAt(0).getId(); + registerModelInput.setModelGroupId(modelGroupIdOfTheNameProvided); + checkUserAccess(registerModelInput, listener, true); + } else { + doRegister(registerModelInput, listener); + } + }, e -> { + log.error("Failed to search model group index", e); + listener.onFailure(e); + }) + ); } else { checkUserAccess(registerModelInput, listener, false); } @@ -190,81 +203,93 @@ private void checkUserAccess( ) { User user = RestActionUtils.getUserContext(client); modelAccessControlHelper - .validateModelGroupAccess(user, registerModelInput.getModelGroupId(), client, ActionListener.wrap(access -> { - if (access) { - doRegister(registerModelInput, listener); - return; - } - // if the user does not have access, we need to check three more conditions before throwing exception. - // if we are checking the access based on the name provided in the input, we let user know the name is already used by a - // model group they do not have access to. - if (isModelNameAlreadyExisting) { - // This case handles when user is using the same pre-trained model already registered by another user on the cluster. - // The only way here is for the user to first create model group and use its ID in the request - if (registerModelInput.getUrl() == null - && registerModelInput.getFunctionName() != FunctionName.REMOTE - && registerModelInput.getConnectorId() == null) { - listener - .onFailure( - new IllegalArgumentException( - "Without a model group ID, the system will use the model name {" - + registerModelInput.getModelName() - + "} to create a new model group. However, this name is taken by another group with id {" - + registerModelInput.getModelGroupId() - + "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request." - ) - ); - } else { - listener - .onFailure( - new IllegalArgumentException( - "The name {" - + registerModelInput.getModelName() - + "} you provided is unavailable because it is used by another model group with id {" - + registerModelInput.getModelGroupId() - + "} to which you do not have access. Please provide a different name." - ) - ); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + registerModelInput.getTenantId(), + registerModelInput.getModelGroupId(), + client, + sdkClient, + ActionListener.wrap(access -> { + if (access) { + doRegister(registerModelInput, listener); + return; } - return; - } - // if user does not have access to the model group ID provided in the input, we let user know they do not have access to the - // specified model group - listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); - }, listener::onFailure)); + // if the user does not have access, we need to check three more conditions before throwing exception. + // if we are checking the access based on the name provided in the input, we let user know the name is already used by a + // model group they do not have access to. + if (isModelNameAlreadyExisting) { + // This case handles when user is using the same pre-trained model already registered by another user on the + // cluster. + // The only way here is for the user to first create model group and use its ID in the request + if (registerModelInput.getUrl() == null + && registerModelInput.getFunctionName() != FunctionName.REMOTE + && registerModelInput.getConnectorId() == null) { + listener + .onFailure( + new IllegalArgumentException( + "Without a model group ID, the system will use the model name {" + + registerModelInput.getModelName() + + "} to create a new model group. However, this name is taken by another group with id {" + + registerModelInput.getModelGroupId() + + "} you can't access. To register this pre-trained model, create a new model group and use its ID in your request." + ) + ); + } else { + listener + .onFailure( + new IllegalArgumentException( + "The name {" + + registerModelInput.getModelName() + + "} you provided is unavailable because it is used by another model group with id {" + + registerModelInput.getModelGroupId() + + "} to which you do not have access. Please provide a different name." + ) + ); + } + return; + } + // if user does not have access to the model group ID provided in the input, we let user know they do not have access to + // the + // specified model group + listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); + }, listener::onFailure) + ); } private void doRegister(MLRegisterModelInput registerModelInput, ActionListener listener) { FunctionName functionName = registerModelInput.getFunctionName(); if (FunctionName.REMOTE == functionName) { if (Strings.isNotBlank(registerModelInput.getConnectorId())) { - connectorAccessControlHelper.validateConnectorAccess(client, registerModelInput.getConnectorId(), ActionListener.wrap(r -> { - if (Boolean.TRUE.equals(r)) { - if (registerModelInput.getModelInterface() == null) { - mlModelManager.getConnector(registerModelInput.getConnectorId(), ActionListener.wrap(connector -> { - updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput, connector); + connectorAccessControlHelper + .validateConnectorAccess( + sdkClient, + client, + registerModelInput.getConnectorId(), + registerModelInput.getTenantId(), + mlFeatureEnabledSetting, + ActionListener.wrap(r -> { + if (Boolean.TRUE.equals(r)) { createModelGroup(registerModelInput, listener); - }, listener::onFailure)); - } else { - createModelGroup(registerModelInput, listener); - } - } else { - listener - .onFailure( - new IllegalArgumentException( - "You don't have permission to use the connector provided, connector id: " - + registerModelInput.getConnectorId() - ) - ); - } - }, e -> { - log - .error( - "You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId(), - e - ); - listener.onFailure(e); - })); + } else { + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to use the connector provided, connector id: " + + registerModelInput.getConnectorId() + ) + ); + } + }, e -> { + log + .error( + "You don't have permission to use the connector provided, connector id: {}", + registerModelInput.getConnectorId(), + e + ); + listener.onFailure(e); + }) + ); } else { validateInternalConnector(registerModelInput); ActionListener dryRunResultListener = ActionListener.wrap(res -> { @@ -347,7 +372,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); - mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, listener); + mlModelManager.registerMLRemoteModel(sdkClient, registerModelInput, mlTask, listener); }, e -> { logException("Failed to register model", e, log); listener.onFailure(e); @@ -364,7 +389,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name())); ActionListener forwardActionListener = ActionListener.wrap(res -> { - log.debug("Register model response: " + res); + log.debug("Register model response: {}", res); if (!clusterService.localNode().getId().equals(nodeId)) { mlTaskManager.remove(taskId); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java index a730de712f..b39521ab30 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -64,7 +64,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + + // Local models are out of scope for multi-tenancy. Therefore, null is used as the default tenant for single tenancy. + mlModelGroupManager.validateUniqueModelGroupName(mlUploadInput.getName(), null, ActionListener.wrap(modelGroups -> { if (modelGroups != null && modelGroups.getHits().getTotalHits() != null && modelGroups.getHits().getTotalHits().value != 0) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index 7b7f68e8c0..84e92bf52f 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -10,11 +10,11 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; @@ -85,7 +85,7 @@ public void validateConnectorAccess(Client client, String connectorId, ActionLis wrappedListener.onResponse(hasPermission); }, wrappedListener::onFailure)); } catch (Exception e) { - log.error("Failed to validate Access for connector:" + connectorId, e); + log.error("Failed to validate Access for connector:{}", connectorId, e); listener.onFailure(e); } } @@ -150,7 +150,7 @@ public void getConnector(Client client, String connectorId, ActionListener listener ) { - sdkClient - .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) - .whenComplete((r, throwable) -> { - context.restore(); - log.debug("Completed Get Connector Request, id:{}", connectorId); - if (throwable != null) { - Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); - if (cause instanceof IndexNotFoundException) { - log.error("Failed to get connector index", cause); - listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); - } else { - log.error("Failed to get ML connector " + connectorId, cause); - listener.onFailure(cause); - } + sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((r, throwable) -> { + context.restore(); + log.debug("Completed Get Connector Request, id:{}", connectorId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(throwable, IndexNotFoundException.class) != null) { + log.error("Failed to get connector index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); } else { - try { - GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); - if (gr != null && gr.isExists()) { - try ( - XContentParser parser = jsonXContent - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector mlConnector = Connector.createConnector(parser); - mlConnector.removeCredential(); - listener.onResponse(mlConnector); - } catch (Exception e) { - log.error("Failed to parse ml connector {}", r.id(), e); - listener.onFailure(e); - } - } else { - listener - .onFailure( - new OpenSearchStatusException( - "Failed to find connector with the provided connector id: " + connectorId, - RestStatus.NOT_FOUND - ) - ); + log.error("Failed to get ML connector {}", connectorId, cause); + listener.onFailure(cause); + } + } else { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector mlConnector = Connector.createConnector(parser); + mlConnector.removeCredential(); + listener.onResponse(mlConnector); + } catch (Exception e) { + log.error("Failed to parse ml connector {}", r.id(), e); + listener.onFailure(e); } - } catch (Exception e) { - listener.onFailure(e); + } else { + listener + .onFailure( + new OpenSearchStatusException( + "Failed to find connector with the provided connector id: " + connectorId, + RestStatus.NOT_FOUND + ) + ); } + } catch (Exception e) { + listener.onFailure(e); } - }); + } + }); } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index 0967544cba..7b5c218cfd 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -7,20 +7,24 @@ package org.opensearch.ml.helper; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Optional; import org.apache.lucene.search.join.ScoreMode; import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; @@ -43,7 +47,12 @@ import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.MLNodeUtils; +import org.opensearch.ml.utils.TenantAwareHelper; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; import com.google.common.collect.ImmutableList; @@ -74,17 +83,17 @@ public ModelAccessControlHelper(ClusterService clusterService, Settings settings RangeQueryBuilder.class ); + // TODO Eventually remove this when all usages of it have been migrated to the SdkClient version public void validateModelGroupAccess(User user, String modelGroupId, Client client, ActionListener listener) { if (modelGroupId == null || isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { listener.onResponse(true); return; } - List userBackendRoles = user.getBackendRoles(); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); client.get(getModelGroupRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try ( @@ -93,31 +102,7 @@ public void validateModelGroupAccess(User user, String modelGroupId, Client clie ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLModelGroup mlModelGroup = MLModelGroup.parse(parser); - AccessMode modelAccessMode = AccessMode.from(mlModelGroup.getAccess()); - if (mlModelGroup.getOwner() == null) { - // previous security plugin not enabled, model defaults to public. - wrappedListener.onResponse(true); - } else if (AccessMode.RESTRICTED == modelAccessMode) { - if (mlModelGroup.getBackendRoles() == null || mlModelGroup.getBackendRoles().size() == 0) { - throw new IllegalStateException("Backend roles shouldn't be null"); - } else { - wrappedListener - .onResponse( - Optional - .ofNullable(userBackendRoles) - .orElse(ImmutableList.of()) - .stream() - .anyMatch(mlModelGroup.getBackendRoles()::contains) - ); - } - } else if (AccessMode.PUBLIC == modelAccessMode) { - wrappedListener.onResponse(true); - } else if (AccessMode.PRIVATE == modelAccessMode) { - if (isOwner(mlModelGroup.getOwner(), user)) - wrappedListener.onResponse(true); - else - wrappedListener.onResponse(false); - } + checkModelGroupPermission(mlModelGroup, user, wrappedListener); } catch (Exception e) { log.error("Failed to parse ml model group"); wrappedListener.onFailure(e); @@ -139,6 +124,104 @@ public void validateModelGroupAccess(User user, String modelGroupId, Client clie } } + public void validateModelGroupAccess( + User user, + MLFeatureEnabledSetting mlFeatureEnabledSetting, + String tenantId, + String modelGroupId, + Client client, + SdkClient sdkClient, + ActionListener listener + ) { + if (modelGroupId == null + || (!mlFeatureEnabledSetting.isMultiTenancyEnabled() + && (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)))) { + listener.onResponse(true); + return; + } + GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest + .builder() + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .tenantId(tenantId) + .build(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + sdkClient.getDataObjectAsync(getModelGroupRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + if (TenantAwareHelper + .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModelGroup.getTenantId(), listener)) { + if (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { + listener.onResponse(true); + return; + } + checkModelGroupPermission(mlModelGroup, user, wrappedListener); + } + } catch (Exception e) { + log.error("Failed to parse ml model group"); + wrappedListener.onFailure(e); + } + } else { + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + } + } catch (Exception e) { + listener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + if (e instanceof IndexNotFoundException) { + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + } else { + log.error("Fail to get model group", e); + wrappedListener.onFailure(new MLValidationException("Fail to get model group")); + } + } + }); + } catch (Exception e) { + log.error("Failed to validate Access", e); + listener.onFailure(e); + } + } + + public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, ActionListener wrappedListener) { + AccessMode modelAccessMode = AccessMode.from(mlModelGroup.getAccess()); + if (mlModelGroup.getOwner() == null) { + // previous security plugin not enabled, model defaults to public. + wrappedListener.onResponse(true); + } else { + switch (modelAccessMode) { + case RESTRICTED: + if (mlModelGroup.getBackendRoles() == null || mlModelGroup.getBackendRoles().isEmpty()) { + throw new IllegalStateException("Backend roles shouldn't be null"); + } else { + wrappedListener + .onResponse( + Optional + .ofNullable(user.getBackendRoles()) + .orElse(Collections.emptyList()) + .stream() + .anyMatch(mlModelGroup.getBackendRoles()::contains) + ); + } + break; + case PRIVATE: + wrappedListener.onResponse(isOwner(mlModelGroup.getOwner(), user)); + break; + default: // PUBLIC + wrappedListener.onResponse(true); + } + } + } + public boolean skipModelAccessControl(User user) { // Case 1: user == null when 1. Security is disabled. 2. When user is super-admin // Case 2: If Security is enabled and filter is disabled, proceed with search as diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 73c48e776a..e76ae8e306 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -5,38 +5,50 @@ package org.opensearch.ml.model; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import java.io.IOException; import java.time.Instant; import java.util.HashSet; -import java.util.Iterator; -import org.opensearch.action.get.GetRequest; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.util.CollectionUtils; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.GetDataObjectResponse; +import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -46,21 +58,27 @@ public class MLModelGroupManager { private final MLIndicesHandler mlIndicesHandler; private final Client client; + private final SdkClient sdkClient; ClusterService clusterService; ModelAccessControlHelper modelAccessControlHelper; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; @Inject public MLModelGroupManager( MLIndicesHandler mlIndicesHandler, Client client, + SdkClient sdkClient, ClusterService clusterService, - ModelAccessControlHelper modelAccessControlHelper + ModelAccessControlHelper modelAccessControlHelper, + MLFeatureEnabledSetting mlFeatureEnabledSetting ) { this.mlIndicesHandler = mlIndicesHandler; this.client = client; + this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } public void createModelGroup(MLRegisterModelGroupInput input, ActionListener listener) { @@ -68,14 +86,13 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); - validateUniqueModelGroupName(input.getName(), ActionListener.wrap(modelGroups -> { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + validateUniqueModelGroupName(input.getName(), input.getTenantId(), ActionListener.wrap(modelGroups -> { if (modelGroups != null && modelGroups.getHits().getTotalHits() != null && modelGroups.getHits().getTotalHits().value != 0) { - Iterator iterator = modelGroups.getHits().iterator(); - while (iterator.hasNext()) { - String id = iterator.next().getId(); + for (SearchHit documentFields : modelGroups.getHits()) { + String id = documentFields.getId(); wrappedListener .onFailure( new IllegalArgumentException( @@ -99,6 +116,7 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener { - if (!res) { - wrappedListener.onFailure(new RuntimeException("No response to create ML Model Group index")); - return; - } - IndexRequest indexRequest = new IndexRequest(ML_MODEL_GROUP_INDEX); - indexRequest - .source( - mlModelGroup.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS) - ); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - client.index(indexRequest, ActionListener.wrap(r -> { - log.debug("Indexed model group doc successfully {}", modelName); - wrappedListener.onResponse(r.getId()); - }, e -> { - log.error("Failed to index model group doc", e); - wrappedListener.onFailure(e); - })); + sdkClient + .putDataObjectAsync( + PutDataObjectRequest + .builder() + .tenantId(mlModelGroup.getTenantId()) + .index(ML_MODEL_GROUP_INDEX) + .dataObject(mlModelGroup) + .build() + ) + .whenComplete((r, throwable) -> { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + log.error("Failed to index model group", cause); + wrappedListener.onFailure(cause); + } else { + try { + IndexResponse indexResponse = IndexResponse.fromXContent(r.parser()); + log + .info( + "Model group creation result: {}, model group id: {}", + indexResponse.getResult(), + indexResponse.getId() + ); + wrappedListener.onResponse(r.id()); + } catch (Exception e) { + wrappedListener.onFailure(e); + } + } + }); + }, ex -> { log.error("Failed to init model group index", ex); wrappedListener.onFailure(ex); @@ -188,26 +219,48 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us } } - public void validateUniqueModelGroupName(String name, ActionListener listener) throws IllegalArgumentException { + public void validateUniqueModelGroupName(String name, String tenantId, ActionListener listener) + throws IllegalArgumentException { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name)); + if (tenantId != null) { + query.filter(new TermQueryBuilder(CommonValue.TENANT_ID_FIELD, tenantId)); + } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder); - client - .search( - searchRequest, - ActionListener.runBefore(ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onResponse(null); - } else { - log.error("Failed to search model group index", e); - listener.onFailure(e); - } - }), () -> context.restore()) - ); + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(searchRequest.indices()) + .tenantId(tenantId) + .searchSourceBuilder(searchRequest.source()) + .tenantId(tenantId) + .build(); + + sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(throwable, IndexNotFoundException.class) != null) { + log.debug("Model group index does not exist"); + listener.onResponse(null); + } else { + log.error("Failed to search model group index", cause); + listener.onFailure(cause); + } + } else { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + log.info("Model group search complete: {}", searchResponse.getHits().getTotalHits()); + listener.onResponse(searchResponse); + } catch (IOException e) { + log.error("Failed to parse search response", e); + listener + .onFailure(new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR)); + } + } + }); } catch (Exception e) { log.error("Failed to search model group index", e); listener.onFailure(e); @@ -217,19 +270,64 @@ public void validateUniqueModelGroupName(String name, ActionListener listener) { - GetRequest getRequest = new GetRequest(); - getRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - listener.onResponse(r); - } else { + public void getModelGroupResponse(SdkClient sdkClient, String modelGroupId, ActionListener listener) { + GetDataObjectRequest getRequest = buildGetModelGroupRequest(modelGroupId); + + sdkClient.getDataObjectAsync(getRequest).whenComplete((response, throwable) -> { + if (throwable != null) { + handleError(throwable, listener); + return; + } + + processModelGroupResponse(response, modelGroupId, listener); + }); + } + + private GetDataObjectRequest buildGetModelGroupRequest(String modelGroupId) { + return GetDataObjectRequest.builder().index(ML_MODEL_GROUP_INDEX).id(modelGroupId).build(); + } + + private void handleError(Throwable throwable, ActionListener listener) { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + listener.onFailure(exception); + } + + private void processModelGroupResponse(GetDataObjectResponse response, String modelGroupId, ActionListener listener) { + try { + GetResponse getResponse = parseGetResponse(response); + if (getResponse == null || !getResponse.isExists()) { listener.onFailure(new MLResourceNotFoundException("Failed to find model group with ID: " + modelGroupId)); + return; } - }, e -> { listener.onFailure(e); })); + + parseAndRespond(getResponse, listener); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private GetResponse parseGetResponse(GetDataObjectResponse response) throws IOException { + return response.parser() == null ? null : GetResponse.fromXContent(response.parser()); + } + + private void parseAndRespond(GetResponse getResponse, ActionListener listener) { + try ( + XContentParser parser = jsonXContent + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + Strings.toString(MediaTypeRegistry.JSON, getResponse) + ) + ) { + listener.onResponse(GetResponse.fromXContent(parser)); + } catch (Exception e) { + log.error("Failed to parse model group response: {}", getResponse.getId(), e); + listener.onFailure(e); + } } private void validateSecurityDisabledOrModelAccessControlDisabled(MLRegisterModelGroupInput input) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index c04caaa711..db391c9179 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -6,8 +6,10 @@ package org.opensearch.ml.model; import static org.opensearch.common.xcontent.XContentType.JSON; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; @@ -53,10 +55,10 @@ import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import java.io.File; +import java.io.IOException; import java.nio.file.Path; import java.security.PrivilegedActionException; import java.time.Instant; -import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; @@ -76,9 +78,12 @@ import org.apache.commons.lang3.BooleanUtils; import org.apache.logging.log4j.util.Strings; +import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.IndicesOptions; @@ -92,6 +97,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.TokenBucket; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -103,7 +109,6 @@ import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLModelGroup; @@ -141,6 +146,13 @@ import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.MLExceptionUtils; import org.opensearch.ml.utils.MLNodeUtils; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.GetDataObjectResponse; +import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.UpdateDataObjectRequest; +import org.opensearch.remote.metadata.client.UpdateDataObjectResponse; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.script.ScriptService; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.threadpool.ThreadPool; @@ -159,14 +171,15 @@ public class MLModelManager { public static final int TIMEOUT_IN_MILLIS = 5000; - public static final long MODEL_FILE_SIZE_LIMIT = 4l * 1024 * 1024 * 1024;// 4GB + public static final long MODEL_FILE_SIZE_LIMIT = 4L * 1024 * 1024 * 1024;// 4GB private final Client client; + private final SdkClient sdkClient; private final ClusterService clusterService; private final ScriptService scriptService; - private ThreadPool threadPool; - private NamedXContentRegistry xContentRegistry; - private ModelHelper modelHelper; + private final ThreadPool threadPool; + private final NamedXContentRegistry xContentRegistry; + private final ModelHelper modelHelper; private final MLModelCacheHelper modelCacheHelper; private final MLStats mlStats; @@ -183,7 +196,7 @@ public class MLModelManager { private volatile Integer maxBatchInferenceTasks; private volatile Integer maxBatchIngestionTasks; - public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet + public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet .of( MLModelState.TRAINED, MLModelState.REGISTERED, @@ -197,6 +210,7 @@ public MLModelManager( ClusterService clusterService, ScriptService scriptService, Client client, + SdkClient sdkClient, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, ModelHelper modelHelper, @@ -211,6 +225,7 @@ public MLModelManager( MLFeatureEnabledSetting mlFeatureEnabledSetting ) { this.client = client; + this.sdkClient = sdkClient; this.threadPool = threadPool; this.xContentRegistry = xContentRegistry; this.modelHelper = modelHelper; @@ -259,7 +274,7 @@ public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, uploadMLModelMeta(mlRegisterModelMetaInput, "1", listener); } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { @@ -305,7 +320,7 @@ public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, String version, ActionListener listener) { FunctionName functionName = mlRegisterModelMetaInput.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); String modelName = mlRegisterModelMetaInput.getName(); mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { if (!res) { @@ -365,12 +380,13 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput } /** - * + * @param sdkClient metadata client * @param mlRegisterModelInput register model input for remote models * @param mlTask ML task * @param listener action listener */ public void registerMLRemoteModel( + SdkClient sdkClient, MLRegisterModelInput mlRegisterModelInput, MLTask mlTask, ActionListener listener @@ -382,52 +398,83 @@ public void registerMLRemoteModel( mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); String modelGroupId = mlRegisterModelInput.getModelGroupId(); - GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); - client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> { - if (getModelGroupResponse.isExists()) { - Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); - int updatedVersion = incrementLatestVersion(modelGroupSourceMap); - UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest( - modelGroupSourceMap, - modelGroupId, - getModelGroupResponse.getSeqNo(), - getModelGroupResponse.getPrimaryTerm(), - updatedVersion - ); - client.update(updateModelGroupRequest, ActionListener.wrap(r -> { - indexRemoteModel(mlRegisterModelInput, mlTask, updatedVersion + "", listener); - }, e -> { - log.error("Failed to update model group " + modelGroupId, e); - handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); + GetDataObjectRequest getModelGroupRequest = GetDataObjectRequest + .builder() + .index(ML_MODEL_GROUP_INDEX) + .tenantId(mlRegisterModelInput.getTenantId()) + .id(modelGroupId) + .build(); + + sdkClient.getDataObjectAsync(getModelGroupRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + GetResponse getModelGroupResponse = GetResponse.fromXContent(r.parser()); + if (getModelGroupResponse.isExists()) { + Map modelGroupSourceMap = getModelGroupResponse.getSourceAsMap(); + int updatedVersion = incrementLatestVersion(modelGroupSourceMap); + modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion); + modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + UpdateDataObjectRequest updateDataObjectRequest = UpdateDataObjectRequest + .builder() + .index(ML_MODEL_GROUP_INDEX) + .id(modelGroupId) + .tenantId(mlRegisterModelInput.getTenantId()) + .ifSeqNo(getModelGroupResponse.getSeqNo()) + .ifPrimaryTerm(getModelGroupResponse.getPrimaryTerm()) + .dataObject(modelGroupSourceMap) + .build(); + try (ThreadContext.StoredContext innerContext = client.threadPool().getThreadContext().stashContext()) { + + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((ur, ut) -> { + if (ut == null) { + indexRemoteModel(sdkClient, mlRegisterModelInput, mlTask, updatedVersion + "", listener); + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(ut); + log.error("Failed to update model group {}", modelGroupId, e); + handleException( + mlRegisterModelInput.getFunctionName(), + mlTask.getTaskId(), + mlRegisterModelInput.getTenantId(), + e + ); + listener.onFailure(e); + } + }); + } + } else { + log.error("Model group response is empty"); + handleException( + mlRegisterModelInput.getFunctionName(), + mlTask.getTaskId(), + mlRegisterModelInput.getTenantId(), + new MLValidationException("Model group not found") + ); + listener.onFailure(new MLResourceNotFoundException("Model Group Response is empty for " + modelGroupId)); + } + } catch (Exception e) { listener.onFailure(e); - })); - } else { - log.error("Model group response is empty"); - handleException( - mlRegisterModelInput.getFunctionName(), - mlTask.getTaskId(), - new MLValidationException("Model group not found") - ); - listener.onFailure(new MLResourceNotFoundException("Model Group Response is empty for " + modelGroupId)); - } - }, error -> { - if (error instanceof IndexNotFoundException) { - log.error("Model group Index is missing"); - handleException( - mlRegisterModelInput.getFunctionName(), - mlTask.getTaskId(), - new MLResourceNotFoundException("Failed to get model group due to index missing") - ); - listener.onFailure(error); + } } else { - log.error("Failed to get model group", error); - handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), error); - listener.onFailure(error); + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + if (e instanceof IndexNotFoundException) { + log.error("Model group Index is missing"); + handleException( + mlRegisterModelInput.getFunctionName(), + mlTask.getTaskId(), + mlRegisterModelInput.getTenantId(), + new MLResourceNotFoundException("Failed to get model group due to index missing") + ); + listener.onFailure(e); + } else { + log.error("Failed to get model group", e); + handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), mlRegisterModelInput.getTenantId(), e); + listener.onFailure(e); + } } - })); + }); } catch (Exception e) { log.error("Failed to register remote model", e); - handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e); + handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), mlRegisterModelInput.getTenantId(), e); listener.onFailure(e); } finally { mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); @@ -469,7 +516,12 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa updateModelGroupRequest, ActionListener.wrap(r -> { uploadModel(registerModelInput, mlTask, updatedVersion + ""); }, e -> { log.error("Failed to update model group", e); - handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); + handleException( + registerModelInput.getFunctionName(), + mlTask.getTaskId(), + registerModelInput.getTenantId(), + e + ); }) ); } @@ -478,6 +530,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa handleException( registerModelInput.getFunctionName(), mlTask.getTaskId(), + registerModelInput.getTenantId(), new MLValidationException("Model group not found") ); } @@ -486,19 +539,20 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa handleException( registerModelInput.getFunctionName(), mlTask.getTaskId(), + registerModelInput.getTenantId(), new MLResourceNotFoundException("Failed to get model group") ); } else { log.error("Failed to get model group", e); - handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); + handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), registerModelInput.getTenantId(), e); } - }), () -> context.restore())); + }), context::restore)); } catch (Exception e) { log.error("Failed to register model", e); - handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); + handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), registerModelInput.getTenantId(), e); } } catch (Exception e) { - handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); + handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), registerModelInput.getTenantId(), e); } finally { mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); } @@ -527,10 +581,11 @@ private UpdateRequest createUpdateModelGroupRequest( } private int incrementLatestVersion(Map modelGroupSourceMap) { - return (int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1; + return Integer.parseInt(modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD).toString()) + 1; } private void indexRemoteModel( + SdkClient sdkClient, MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion, @@ -570,14 +625,16 @@ private void indexRemoteModel( .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) .modelInterface(registerModelInput.getModelInterface()) + .tenantId(registerModelInput.getTenantId()) .build(); - IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); - if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { - indexModelMetaRequest.id(modelName); - } - indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); - indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + PutDataObjectRequest putModelMetaRequest = PutDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(Boolean.TRUE.equals(registerModelInput.getIsHidden()) ? modelName : null) + .tenantId(registerModelInput.getTenantId()) + .dataObject(mlModelMeta) + .build(); // index remote model doc ActionListener indexListener = ActionListener.wrap(modelMetaRes -> { @@ -591,15 +648,28 @@ private void indexRemoteModel( listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name(), modelId)); }, e -> { log.error("Failed to index model meta doc", e); - handleException(functionName, taskId, e); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); listener.onFailure(e); }); - client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, indexListener)); + ThreadedActionListener putListener = threadedActionListener(REGISTER_THREAD_POOL, indexListener); + sdkClient.putDataObjectAsync(putModelMetaRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + IndexResponse ir = IndexResponse.fromXContent(r.parser()); + putListener.onResponse(ir); + } catch (Exception e) { + putListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + putListener.onFailure(e); + } + }); }, error -> { // failed to initialize the model index log.error("Failed to init model index", error); - handleException(functionName, taskId, error); + handleException(functionName, taskId, registerModelInput.getTenantId(), error); listener.onFailure(error); })); } @@ -619,7 +689,12 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St } mlIndicesHandler.initModelIndexIfAbsent(ActionListener.runBefore(ActionListener.wrap(res -> { if (!res) { - handleException(functionName, taskId, new RuntimeException("No response to create ML Model index")); + handleException( + functionName, + taskId, + registerModelInput.getTenantId(), + new RuntimeException("No response to create ML Model index") + ); return; } MLModel mlModelMeta = MLModel @@ -641,7 +716,17 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St .isHidden(registerModelInput.getIsHidden()) .guardrails(registerModelInput.getGuardrails()) .modelInterface(registerModelInput.getModelInterface()) + .tenantId(registerModelInput.getTenantId()) + .build(); + + PutDataObjectRequest putModelMetaRequest = PutDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(Boolean.TRUE.equals(registerModelInput.getIsHidden()) ? modelName : null) + .tenantId(registerModelInput.getTenantId()) + .dataObject(mlModelMeta) .build(); + IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); if (registerModelInput.getIsHidden() != null && registerModelInput.getIsHidden()) { indexModelMetaRequest.id(modelName); @@ -659,16 +744,31 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St } }, e -> { log.error("Failed to index model meta doc", e); - handleException(functionName, taskId, e); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); + }); + + ThreadedActionListener putListener = threadedActionListener(REGISTER_THREAD_POOL, indexListener); + sdkClient.putDataObjectAsync(putModelMetaRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + IndexResponse ir = IndexResponse.fromXContent(r.parser()); + putListener.onResponse(ir); + } catch (Exception e) { + putListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + putListener.onFailure(e); + } }); - client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, indexListener)); + }, e -> { log.error("Failed to init model index", e); - handleException(functionName, taskId, e); - }), () -> context.restore())); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); + }), context::restore)); } catch (Exception e) { logException("Failed to upload model", e, log); - handleException(functionName, taskId, e); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); } } @@ -692,7 +792,12 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas Instant now = Instant.now(); mlIndicesHandler.initModelIndexIfAbsent(ActionListener.runBefore(ActionListener.wrap(res -> { if (!res) { - handleException(functionName, taskId, new RuntimeException("No response to create ML Model index")); + handleException( + functionName, + taskId, + registerModelInput.getTenantId(), + new RuntimeException("No response to create ML Model index") + ); return; } MLModel mlModelMeta = MLModel @@ -731,16 +836,16 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas registerModel(registerModelInput, taskId, functionName, modelName, version, modelId); }, e -> { log.error("Failed to index model meta doc", e); - handleException(functionName, taskId, e); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); }); client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, listener)); }, e -> { log.error("Failed to init model index", e); - handleException(functionName, taskId, e); - }), () -> context.restore())); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); + }), context::restore)); } catch (Exception e) { logException("Failed to register model", e, log); - handleException(functionName, taskId, e); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); } } @@ -824,9 +929,9 @@ private void registerModel( } semaphore.release(); }, e -> { - log.error("Failed to index model chunk " + chunkId, e); + log.error("Failed to index model chunk {}", chunkId, e); failedToUploadChunk.set(true); - handleException(functionName, taskId, e); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); deleteFileQuietly(file); // remove model doc as failed to upload model deleteModel(modelId, registerModelInput, version); @@ -838,7 +943,7 @@ private void registerModel( log.error("Failed to index chunk file", e); deleteFileQuietly(mlEngine.getRegisterModelPath(modelId)); deleteModel(modelId, registerModelInput, version); - handleException(functionName, taskId, e); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); }) ); } @@ -857,7 +962,7 @@ private void registerPrebuiltModel(MLRegisterModelInput registerModelInput, MLTa registerModelFromUrl(mlRegisterModelInput, mlTask, modelVersion); }, e -> { log.error("Failed to register prebuilt model", e); - handleException(registerModelInput.getFunctionName(), taskId, e); + handleException(registerModelInput.getFunctionName(), taskId, registerModelInput.getTenantId(), e); })); } @@ -890,6 +995,16 @@ public void checkMaxBatchJobTask(MLTask mlTask, ActionListener listener mlTaskManager.checkMaxBatchJobTask(taskType, maxLimit, listener); } + /** + * Update model register state as done. This is only for local model. Not for remote model. + * @param registerModelInput model input for local model registration + * @param taskId id of the task + * @param modelId id of the model + * @param modelSizeInBytes size of the model in bytes + * @param chunkFiles list of chunk files + * @param hashValue model hash value + * @param version model version + */ private void updateModelRegisterStateAsDone( MLRegisterModelInput registerModelInput, String taskId, @@ -915,14 +1030,16 @@ private void updateModelRegisterStateAsDone( modelSizeInBytes ); log.info("Model registered successfully, model id: {}, task id: {}", modelId, taskId); - updateModel(modelId, updatedFields, ActionListener.wrap(updateResponse -> { + + // For local model we don't support multi-tenancy. So we are providing tenant Id null by default. + updateModel(modelId, null, updatedFields, ActionListener.wrap(updateResponse -> { mlTaskManager.updateMLTask(taskId, Map.of(STATE_FIELD, COMPLETED, MODEL_ID_FIELD, modelId), TIMEOUT_IN_MILLIS, true); if (registerModelInput.isDeployModel()) { deployModelAfterRegistering(registerModelInput, modelId); } }, e -> { log.error("Failed to update model", e); - handleException(functionName, taskId, e); + handleException(functionName, taskId, registerModelInput.getTenantId(), e); deleteModel(modelId, registerModelInput, version); })); } @@ -975,13 +1092,15 @@ private void updateLatestVersionInModelGroup(String modelGroupID, Integer latest updateRequest.doc(updatedFields); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + client.update(updateRequest, ActionListener.runBefore(listener, context::restore)); } catch (Exception e) { listener.onFailure(e); } } - private void handleException(FunctionName functionName, String taskId, Exception e) { + // I will raise another PR where tenantId will be provided to the task manager. That time + // I will refactor updateMLTask method. + private void handleException(FunctionName functionName, String taskId, String tenantId, Exception e) { if (!(e instanceof MLLimitExceededException) && !(e instanceof MLResourceNotFoundException) && !(e instanceof IllegalArgumentException)) { @@ -992,8 +1111,15 @@ private void handleException(FunctionName functionName, String taskId, Exception mlTaskManager.updateMLTask(taskId, updated, TIMEOUT_IN_MILLIS, true); } + /** + * Get connector credential by connector id. + * This code is developed only for batch. And currently multi-tenancy isn't implemented in batch + * so by default, tenantId is provided null when we invoke getConnector method + * @param connectorId connector id + * @param connectorCredentialListener listener + */ public void getConnectorCredential(String connectorId, ActionListener> connectorCredentialListener) { - getConnector(connectorId, ActionListener.wrap(connector -> { + getConnector(connectorId, null, ActionListener.wrap(connector -> { Map credential = mlEngine.getConnectorCredential(connector); connectorCredentialListener.onResponse(credential); log.info("Completed loading credential in the connector {}", connectorId); @@ -1006,6 +1132,7 @@ public void getConnectorCredential(String connectorId, ActionListener { + this.getModel(modelId, tenantId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); modelCacheHelper.setModelInfo(modelId, mlModel); if (FunctionName.REMOTE == mlModel.getAlgorithm() @@ -1072,8 +1200,8 @@ public void deployModel( }, e -> { log .error( - "Trying to deploy remote model with exceptions in re-deploying its model controller. Model ID: " - + modelId, + "Trying to deploy remote model with exceptions in re-deploying its model controller. Model ID: {}", + modelId, e ); deployRemoteOrBuiltInModel(mlModel, eligibleNodeCount, wrappedListener); @@ -1135,11 +1263,11 @@ public void deployModel( } } }, e -> { - log.error("Failed to retrieve model " + modelId, e); + log.error("Failed to retrieve model {}", modelId, e); handleDeployModelException(modelId, functionName, wrappedListener, e); })); }, e -> { - log.error("Failed to deploy model " + modelId, e); + log.error("Failed to deploy model {}", modelId, e); handleDeployModelException(modelId, functionName, wrappedListener, e); }))); } catch (Exception e) { @@ -1149,20 +1277,154 @@ public void deployModel( } } - public void deployRemoteModelToLocal(String modelId, MLModel mlModel, ActionListener listener) { + /** + * Read model chunks from model index. Concat chunks into a whole model file, + * then load + * into memory. + * + * TODO: I'll remove this method later. Currently this method is being used in multiple classes. + * + * @param modelId model id + * @param modelContentHash model content hash value + * @param functionName function name + * @param mlTask ML task + * @param listener action listener + */ + public void deployModel( + String modelId, + String modelContentHash, + FunctionName functionName, + boolean deployToAllNodes, + boolean autoDeployModel, + MLTask mlTask, + ActionListener listener + ) { + mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); + mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); + mlStats.createModelCounterStatIfAbsent(modelId, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); + List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { - listener.onResponse("Success"); + if (!autoDeployModel && workerNodes != null && !workerNodes.isEmpty()) { + log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId); + modelCacheHelper.setDeployToAllNodes(modelId, deployToAllNodes); + modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes); + modelCacheHelper.refreshLastAccessTime(modelId); + } + listener.onResponse("successful"); + return; + } + if (functionName != FunctionName.REMOTE && modelCacheHelper.getLocalDeployedModels().length >= maxModelPerNode) { + listener.onFailure(new IllegalArgumentException("Exceed max local model per node limit")); return; } - modelCacheHelper - .initModelState(modelId, MLModelState.DEPLOYING, FunctionName.REMOTE, new ArrayList<>(), mlModel.isDeployToAllNodes()); + int eligibleNodeCount = workerNodes.size(); + if (!autoDeployModel) { + modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); + } else { + modelCacheHelper.initModelStateAutoDeploy(modelId, MLModelState.DEPLOYING, functionName, workerNodes); + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); - deployRemoteOrBuiltInModel(mlModel, 1, wrappedListener); + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> { + context.restore(); + modelCacheHelper.removeAutoDeployModel(modelId); + modelCacheHelper.setIsAutoDeploying(modelId, false); + }); + if (!autoDeployModel) { + checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); + } + this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { + modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); + modelCacheHelper.setModelInfo(modelId, mlModel); + if (FunctionName.REMOTE == mlModel.getAlgorithm() + || (!FunctionName.isDLModel(mlModel.getAlgorithm()) && mlModel.getAlgorithm() != FunctionName.METRICS_CORRELATION)) { + // deploy remote model or model trained by built-in algorithm like kmeans + // deploy remote model with internal connector or model trained by built-in + // algorithm like kmeans + if (BooleanUtils.isTrue(mlModel.getIsControllerEnabled())) { + getController(modelId, ActionListener.wrap(controller -> { + setupUserRateLimiterMap(modelId, eligibleNodeCount, controller.getUserRateLimiter()); + log.info("Successfully redeployed model controller for model " + modelId); + log.info("Trying to deploy remote model with model controller configured."); + deployRemoteOrBuiltInModel(mlModel, eligibleNodeCount, wrappedListener); + }, e -> { + log + .error( + "Trying to deploy remote model with exceptions in re-deploying its model controller. Model ID: " + + modelId, + e + ); + deployRemoteOrBuiltInModel(mlModel, eligibleNodeCount, wrappedListener); + })); + return; + } else { + log.info("Trying to deploy remote or built-in model without model controller configured."); + deployRemoteOrBuiltInModel(mlModel, eligibleNodeCount, wrappedListener); + } + return; + } + + setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); + setupMLGuard(modelId, mlModel.getGuardrails()); + setupModelInterface(modelId, mlModel.getModelInterface()); + deployControllerWithDeployingModel(mlModel, eligibleNodeCount); + // check circuit breaker before deploying custom model chunks + checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); + retrieveModelChunks(mlModel, ActionListener.wrap(modelZipFile -> {// read model chunks + String hash = calculateFileHash(modelZipFile); + if (modelContentHash != null && !modelContentHash.equals(hash)) { + log.error("Model content hash can't match original hash value"); + removeModel(modelId); + wrappedListener.onFailure(new IllegalArgumentException("model content changed")); + return; + } + log.debug("Model content matches original hash value, continue deploying"); + Map params = Map.of(MODEL_ZIP_FILE, modelZipFile, MODEL_HELPER, modelHelper, ML_ENGINE, mlEngine); + if (FunctionName.METRICS_CORRELATION.equals(mlModel.getAlgorithm())) { + MLExecutable mlExecutable = mlEngine.deployExecute(mlModel, params); + try { + modelCacheHelper.setMLExecutor(modelId, mlExecutable); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); + modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); + wrappedListener.onResponse("successful"); + } catch (Exception e) { + log.error("Failed to add predictor to cache", e); + mlExecutable.close(); + wrappedListener.onFailure(e); + } + } else { + Predictable predictable = mlEngine.deploy(mlModel, params); + try { + modelCacheHelper.setPredictor(modelId, predictable); + mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); + modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); + Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes(); + long contentSize = modelContentSizeInBytes == null + ? mlModel.getTotalChunks() * CHUNK_SIZE + : modelContentSizeInBytes; + modelCacheHelper.setMemSizeEstimation(modelId, mlModel.getModelFormat(), contentSize); + wrappedListener.onResponse("successful"); + } catch (Exception e) { + log.error("Failed to add predictor to cache", e); + predictable.close(); + wrappedListener.onFailure(e); + } + } + }, e -> { + log.error("Failed to retrieve model " + modelId, e); + handleDeployModelException(modelId, functionName, wrappedListener, e); + })); + }, e -> { + log.error("Failed to deploy model " + modelId, e); + handleDeployModelException(modelId, functionName, wrappedListener, e); + }))); } catch (Exception e) { - log.error("Failed to deploy model to local node" + modelId, e); - listener.onFailure(e); + handleDeployModelException(modelId, functionName, listener, e); + } finally { + mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement(); } } @@ -1180,7 +1442,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou return; } log.info("Set connector {} for the model: {}", mlModel.getConnectorId(), modelId); - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); @@ -1258,7 +1520,7 @@ public synchronized void updateModelCache(String modelId, ActionListener wrappedListener.onResponse("Successfully updated model cache for the remote model " + modelId); log.info("Completed the model cache update for the remote model {}", modelId); } else { - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully updated model cache for the remote model " + modelId); @@ -1270,7 +1532,7 @@ public synchronized void updateModelCache(String modelId, ActionListener log.info("Completed the model cache update for the model {}", modelId); }, wrappedListener::onFailure)); } catch (Exception e) { - log.error("Failed to updated model cache for the model " + modelId, e); + log.error("Failed to updated model cache for the model {}", modelId, e); listener.onFailure(e); } } @@ -1306,7 +1568,7 @@ public synchronized void deployControllerWithDeployedModel(String modelId, Actio wrappedListener.onResponse("Successfully deployed model controller for the remote model " + modelId); log.info("Deployed model controller for the remote model {}", modelId); } else { - getConnector(mlModel.getConnectorId(), ActionListener.wrap(connector -> { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully deployed model controller for the remote model " + modelId); @@ -1320,7 +1582,7 @@ public synchronized void deployControllerWithDeployedModel(String modelId, Actio }, wrappedListener::onFailure)); }, wrappedListener::onFailure)); } catch (Exception e) { - log.error("Failed to deploy model controller for the model " + modelId, e); + log.error("Failed to deploy model controller for the model {}", modelId, e); listener.onFailure(e); } } @@ -1344,7 +1606,7 @@ public synchronized void undeployController(String modelId, ActionListener { + getConnector(mlModel.getConnectorId(), mlModel.getTenantId(), ActionListener.wrap(connector -> { mlModel.setConnector(connector); setupParamsAndPredictable(modelId, mlModel); wrappedListener.onResponse("Successfully undeployed model controller for the remote model " + modelId); @@ -1357,14 +1619,14 @@ public synchronized void undeployController(String modelId, ActionListener log.error("Failed to re-deploy the model controller for model: " + mlModel.getModelId(), e))); + }, e -> log.error("Failed to re-deploy the model controller for model: {}", mlModel.getModelId(), e))); } private void setupRateLimiter(String modelId, Integer eligibleNodeCount, MLRateLimiter rateLimiter) { @@ -1579,8 +1841,8 @@ public Map getModelInterface(String modelId) { /** * Set up ML guard with model id. * - * @param modelId - * @param guardrails + * @param modelId model id + * @param guardrails guardrail for the model */ private void setupMLGuard(String modelId, Guardrails guardrails) { @@ -1613,38 +1875,136 @@ public MLGuard getMLGuard(String modelId) { * @param listener action listener */ public void getModel(String modelId, ActionListener listener) { - getModel(modelId, null, null, listener); + getModel(modelId, null, listener); + } + + /** + * Get model from model index. + * + * @param modelId model id + * @param tenantId tenant id + * @param listener action listener + */ + public void getModel(String modelId, String tenantId, ActionListener listener) { + getModel(modelId, tenantId, null, null, listener); } /** * Get model from model index with includes/excludes filter. * + * TODO: I will remove this method later. Currently other classes are invoking this method. * @param modelId model id * @param includes fields included * @param excludes fields excluded * @param listener action listener */ public void getModel(String modelId, String[] includes, String[] excludes, ActionListener listener) { - GetRequest getRequest = new GetRequest(); - FetchSourceContext fetchContext = new FetchSourceContext(true, includes, excludes); - getRequest.index(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchContext); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - String algorithmName = r.getSource().get(ALGORITHM_FIELD).toString(); - - MLModel mlModel = MLModel.parse(parser, algorithmName); - mlModel.setModelId(modelId); - listener.onResponse(mlModel); + GetDataObjectRequest getRequest = GetDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(modelId) + .fetchSourceContext(new FetchSourceContext(true, includes, excludes)) + .build(); + sdkClient.getDataObjectAsync(getRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + String algorithmName = r.source().get(ALGORITHM_FIELD).toString(); + + MLModel mlModel = MLModel.parse(parser, algorithmName); + mlModel.setModelId(modelId); + listener.onResponse(mlModel); + } catch (Exception e) { + log.error("Failed to parse ml task{}", r.id(), e); + listener.onFailure(e); + } + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); + } } catch (Exception e) { - log.error("Failed to parse ml task" + r.getId(), e); listener.onFailure(e); } } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable); + listener.onFailure(e); + } + }); + } + + /** + * Get model from model index with includes/excludes filter. + * + * @param modelId model id + * @param tenantId tenant id + * @param includes fields included + * @param excludes fields excluded + * @param listener action listener + */ + public void getModel(String modelId, String tenantId, String[] includes, String[] excludes, ActionListener listener) { + GetDataObjectRequest getRequest = buildGetRequest(modelId, tenantId, includes, excludes); + + sdkClient.getDataObjectAsync(getRequest).whenComplete((response, throwable) -> { + if (throwable != null) { + handleError(throwable, listener); + return; + } + + processGetResponse(response, modelId, listener); + }); + } + + private GetDataObjectRequest buildGetRequest(String modelId, String tenantId, String[] includes, String[] excludes) { + return GetDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(modelId) + .tenantId(tenantId) + .fetchSourceContext(new FetchSourceContext(true, includes, excludes)) + .build(); + } + + private void handleError(Throwable throwable, ActionListener listener) { + Exception exception = SdkClientUtils.unwrapAndConvertToException(throwable); + listener.onFailure(exception); + } + + private void processGetResponse(GetDataObjectResponse response, String modelId, ActionListener listener) { + try { + GetResponse getResponse = parseGetResponse(response); + if (getResponse == null || !getResponse.isExists()) { listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); + return; } - }, listener::onFailure)); + + parseAndReturnModel(getResponse, response.source().get(ALGORITHM_FIELD).toString(), modelId, listener); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private GetResponse parseGetResponse(GetDataObjectResponse response) throws IOException { + return response.parser() == null ? null : GetResponse.fromXContent(response.parser()); + } + + private void parseAndReturnModel(GetResponse getResponse, String algorithmName, String modelId, ActionListener listener) { + try ( + XContentParser parser = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModel mlModel = MLModel.parse(parser, algorithmName); + mlModel.setModelId(modelId); + listener.onResponse(mlModel); + } catch (Exception e) { + log.error("Failed to parse ml model {}", modelId, e); + listener.onFailure(e); + } } /** @@ -1663,7 +2023,7 @@ public void getController(String modelId, ActionListener listener) MLController controller = MLController.parse(parser); listener.onResponse(controller); } catch (Exception e) { - log.error("Failed to parse ml task" + r.getId(), e); + log.error("Failed to parse ml task{}", r.getId(), e); listener.onFailure(e); } } else { @@ -1676,37 +2036,51 @@ public void getController(String modelId, ActionListener listener) * Get connector from connector index. * * @param connectorId connector id + * @param tenantId tenant id * @param listener action listener */ - public void getConnector(String connectorId, ActionListener listener) { - GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector connector = Connector.createConnector(parser); - wrappedListener.onResponse(connector); - } catch (Exception e) { - log.error("Failed to parse connector:" + connectorId); - wrappedListener.onFailure(e); - } + private void getConnector(String connectorId, String tenantId, ActionListener listener) { + GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest + .builder() + .index(ML_CONNECTOR_INDEX) + .id(connectorId) + .tenantId(tenantId) + .build(); + + sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((r, throwable) -> { + log.debug("Completed Get Connector Request, id:{}", connectorId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + log.error("Failed to get connector index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to find connector", RestStatus.NOT_FOUND)); } else { - wrappedListener - .onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); + log.error("Failed to get ML connector {}", connectorId, cause); + listener.onFailure(cause); } - }, e -> { - log.error("Failed to get connector", e); - wrappedListener.onFailure(new OpenSearchStatusException("Failed to get connector:" + connectorId, RestStatus.NOT_FOUND)); - })); - } catch (Exception e) { - log.error("Failed to get connector", e); - listener.onFailure(e); - } + } else { + try { + GetResponse gr = r.parser() == null ? null : GetResponse.fromXContent(r.parser()); + if (gr != null && gr.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, gr.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + listener.onResponse(connector); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + listener.onFailure(e); + } + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find connector:" + connectorId, RestStatus.NOT_FOUND)); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + }); } /** @@ -1748,7 +2122,7 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener liste }, e -> { stopNow.set(true); semaphore.release(); - log.error("Failed to retrieve model chunk " + modelChunkId, e); + log.error("Failed to retrieve model chunk {}", modelChunkId, e); if (retrievedChunks.get() == totalChunks - 1) { listener.onFailure(new MLResourceNotFoundException("Fail to find model chunk " + modelChunkId)); } @@ -1760,6 +2134,70 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener liste * Update model with build-in listener. * * @param modelId model id + * @param tenantId tenant id + * @param updatedFields updated fields + */ + public void updateModel(String modelId, String tenantId, Boolean isHidden, Map updatedFields) { + updateModel(modelId, tenantId, updatedFields, ActionListener.wrap(response -> { + if (response.status() == RestStatus.OK) { + log.debug(getErrorMessage("Updated ML model successfully: {}", modelId, isHidden), response.status()); + } else { + log.error(getErrorMessage("Failed to update provided ML model, status: {}", modelId, isHidden), response.status()); + } + }, e -> { log.error(getErrorMessage("Failed to update the provided ML model", modelId, isHidden), e); })); + } + + /** + * Update model. + * + * @param modelId model id + * @param tenantId tenant id + * @param updatedFields updated fields + * @param listener action listener + */ + public void updateModel(String modelId, String tenantId, Map updatedFields, ActionListener listener) { + if (updatedFields == null || updatedFields.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Updated fields is null or empty")); + return; + } + Map newUpdatedFields = new HashMap<>(); + newUpdatedFields.putAll(updatedFields); + newUpdatedFields.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); + + UpdateDataObjectRequest.Builder requestBuilder = UpdateDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(modelId) + .tenantId(tenantId) + .dataObject(newUpdatedFields); + + // Conditionally add retryOnConflict based on the provided condition + if (updatedFields.containsKey(MLModel.MODEL_STATE_FIELD) + && MODEL_DONE_STATES.contains(newUpdatedFields.get(MLModel.MODEL_STATE_FIELD))) { + requestBuilder.retryOnConflict(3); + } + + // Build the request + UpdateDataObjectRequest updateDataObjectRequest = requestBuilder.build(); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((r, throwable) -> { + context.restore(); // Restore the context once the operation is done + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(modelId, listener)); + }); + + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Update model with build-in listener. + * + * TODO: Will remove this method later. Currently it's being used by multiple classes. + * + * @param modelId model id * @param updatedFields updated fields */ public void updateModel(String modelId, Boolean isHidden, Map updatedFields) { @@ -1775,32 +2213,81 @@ public void updateModel(String modelId, Boolean isHidden, Map up /** * Update model. * + * * TODO: Will remove this method later. Currently it's being used by multiple classes. + * * @param modelId model id * @param updatedFields updated fields * @param listener action listener */ public void updateModel(String modelId, Map updatedFields, ActionListener listener) { - if (updatedFields == null || updatedFields.size() == 0) { + if (updatedFields == null || updatedFields.isEmpty()) { listener.onFailure(new IllegalArgumentException("Updated fields is null or empty")); return; } Map newUpdatedFields = new HashMap<>(); newUpdatedFields.putAll(updatedFields); newUpdatedFields.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); - UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); - updateRequest.doc(newUpdatedFields); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - if (newUpdatedFields.containsKey(MLModel.MODEL_STATE_FIELD) + + UpdateDataObjectRequest.Builder requestBuilder = UpdateDataObjectRequest + .builder() + .index(ML_MODEL_INDEX) + .id(modelId) + .dataObject(newUpdatedFields); + + // Conditionally add retryOnConflict based on the provided condition + if (updatedFields.containsKey(MLModel.MODEL_STATE_FIELD) && MODEL_DONE_STATES.contains(newUpdatedFields.get(MLModel.MODEL_STATE_FIELD))) { - updateRequest.retryOnConflict(3); + requestBuilder.retryOnConflict(3); } + + // Build the request + UpdateDataObjectRequest updateDataObjectRequest = requestBuilder.build(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); + + sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((r, throwable) -> { + context.restore(); // Restore the context once the operation is done + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(modelId, listener)); + }); + } catch (Exception e) { listener.onFailure(e); } } + private void handleUpdateDataObjectCompletionStage( + UpdateDataObjectResponse r, + Throwable throwable, + ActionListener updateListener + ) { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + updateListener.onFailure(cause); + } else { + try { + UpdateResponse updateResponse = r.parser() == null ? null : UpdateResponse.fromXContent(r.parser()); + updateListener.onResponse(updateResponse); + } catch (IOException e) { + updateListener.onFailure(e); + } + } + } + + private ActionListener getUpdateResponseListener(String modelId, ActionListener actionListener) { + return ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.error("Failed to update the model with ID: {}", modelId); + actionListener.onResponse(updateResponse); + return; + } + log.info("Successfully updated the model with ID: {}", modelId); + actionListener.onResponse(updateResponse); + }, exception -> { + log.error("Failed to update ML model with ID {}. Details: {}", modelId, exception); + actionListener.onFailure(exception); + }); + } + /** * Get model chunk id. * diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index cffe57e381..4d4bc788f0 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -343,6 +343,7 @@ public class MachineLearningPlugin extends Plugin SystemIndexPlugin { public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons."; public static final String GENERAL_THREAD_POOL = "opensearch_ml_general"; + public static final String SDK_CLIENT_THREAD_POOL = "opensearch_ml_sdkclient"; public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute"; public static final String TRAIN_THREAD_POOL = "opensearch_ml_train"; public static final String PREDICT_THREAD_POOL = "opensearch_ml_predict"; @@ -511,7 +512,11 @@ public Collection createComponents( Map.entry(TENANT_AWARE_KEY, "true"), Map.entry(TENANT_ID_FIELD_KEY, TENANT_ID_FIELD) ) - : Collections.emptyMap() + : Collections.emptyMap(), + // For node client / local cluster it won't use this thread pool + // but we haven't update the ddbclient to async for which we are keeping it like this + // todo: need to update this when ddbclient async is going to be implemented. + client.threadPool().executor(ThreadPool.Names.GENERIC) ); mlEngine = new MLEngine(dataPath, encryptor); @@ -557,6 +562,7 @@ public Collection createComponents( clusterService, scriptService, client, + sdkClient, threadPool, xContentRegistry, modelHelper, @@ -753,8 +759,8 @@ public List getRestHandlers( RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction(); RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager, mlFeatureEnabledSetting); RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(mlFeatureEnabledSetting); - RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(); - RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(); + RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); + RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction(); RestMLGetTaskAction restMLGetTaskAction = new RestMLGetTaskAction(); RestMLDeleteTaskAction restMLDeleteTaskAction = new RestMLDeleteTaskAction(); @@ -770,12 +776,12 @@ public List getRestHandlers( RestMLUndeployModelAction restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings); RestMLRegisterModelMetaAction restMLRegisterModelMetaAction = new RestMLRegisterModelMetaAction(clusterService, settings); RestMLUploadModelChunkAction restMLUploadModelChunkAction = new RestMLUploadModelChunkAction(clusterService, settings); - RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction(); - RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); - RestMLGetModelGroupAction restMLGetModelGroupAction = new RestMLGetModelGroupAction(); - RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); - RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(); - RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); + RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction(mlFeatureEnabledSetting); + RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(mlFeatureEnabledSetting); + RestMLGetModelGroupAction restMLGetModelGroupAction = new RestMLGetModelGroupAction(mlFeatureEnabledSetting); + RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(mlFeatureEnabledSetting); + RestMLUpdateModelAction restMLUpdateModelAction = new RestMLUpdateModelAction(mlFeatureEnabledSetting); + RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(mlFeatureEnabledSetting); RestMLCreateConnectorAction restMLCreateConnectorAction = new RestMLCreateConnectorAction(mlFeatureEnabledSetting); RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(clusterService, settings, mlFeatureEnabledSetting); RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(mlFeatureEnabledSetting); @@ -873,6 +879,16 @@ public List> getExecutorBuilders(Settings settings) { ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL, false ); + + FixedExecutorBuilder sdkClientThreadPool = new FixedExecutorBuilder( + settings, + SDK_CLIENT_THREAD_POOL, + OpenSearchExecutors.allocatedProcessors(settings) * 4, + 10000, + ML_THREAD_POOL_PREFIX + SDK_CLIENT_THREAD_POOL, + false + ); + FixedExecutorBuilder registerModelThreadPool = new FixedExecutorBuilder( settings, REGISTER_THREAD_POOL, @@ -939,7 +955,8 @@ public List> getExecutorBuilders(Settings settings) { trainThreadPool, predictThreadPool, remotePredictThreadPool, - batchIngestThreadPool + batchIngestThreadPool, + sdkClientThreadPool ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelAction.java index d8e0b9f3b6..191c086927 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -15,6 +16,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -27,7 +29,11 @@ public class RestMLDeleteModelAction extends BaseRestHandler { private static final String ML_DELETE_MODEL_ACTION = "ml_delete_model_action"; - public void RestMLDeleteModelAction() {} + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + public RestMLDeleteModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -43,8 +49,8 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String modelId = request.param(PARAMETER_MODEL_ID); - - MLModelDeleteRequest mlModelDeleteRequest = new MLModelDeleteRequest(modelId); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + MLModelDeleteRequest mlModelDeleteRequest = new MLModelDeleteRequest(modelId, tenantId); return channel -> client.execute(MLModelDeleteAction.INSTANCE, mlModelDeleteRequest, new RestToXContentListener<>(channel)); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelGroupAction.java index c72fb7959a..08399e7f2e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelGroupAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -15,6 +16,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -26,8 +28,11 @@ */ public class RestMLDeleteModelGroupAction extends BaseRestHandler { private static final String ML_DELETE_MODEL_GROUP_ACTION = "ml_delete_model_group_action"; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; - public void RestMLDeleteModelGroupAction() {} + public RestMLDeleteModelGroupAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -48,8 +53,8 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String modelGroupId = request.param(PARAMETER_MODEL_GROUP_ID); - - MLModelGroupDeleteRequest mlModelGroupDeleteRequest = new MLModelGroupDeleteRequest(modelGroupId); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + MLModelGroupDeleteRequest mlModelGroupDeleteRequest = new MLModelGroupDeleteRequest(modelGroupId, tenantId); return channel -> client .execute(MLModelGroupDeleteAction.INSTANCE, mlModelGroupDeleteRequest, new RestToXContentListener<>(channel)); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java index 097bc6fb77..75ec55acd0 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; import static org.opensearch.ml.utils.RestActionUtils.returnContent; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -17,6 +18,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -27,10 +29,14 @@ public class RestMLGetModelAction extends BaseRestHandler { private static final String ML_GET_MODEL_ACTION = "ml_get_model_action"; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + /** * Constructor */ - public RestMLGetModelAction() {} + public RestMLGetModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -59,7 +65,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client MLModelGetRequest getRequest(RestRequest request) throws IOException { String modelId = getParameterId(request, PARAMETER_MODEL_ID); boolean returnContent = returnContent(request); - - return new MLModelGetRequest(modelId, returnContent, true); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + return new MLModelGetRequest(modelId, returnContent, true, tenantId); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelGroupAction.java index 2869198ef0..27cadb4d8e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelGroupAction.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -16,6 +17,7 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -25,11 +27,14 @@ public class RestMLGetModelGroupAction extends BaseRestHandler { private static final String ML_GET_MODEL_GROUP_ACTION = "ml_get_model_group_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLGetModelGroupAction() {} + public RestMLGetModelGroupAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -59,7 +64,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client @VisibleForTesting MLModelGroupGetRequest getRequest(RestRequest request) throws IOException { String modelGroupId = getParameterId(request, PARAMETER_MODEL_GROUP_ID); - - return new MLModelGroupGetRequest(modelGroupId); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + return new MLModelGroupGetRequest(modelGroupId, tenantId); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java index 68fd73b20a..a20fb5bd0f 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelAction.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_DEPLOY_MODEL; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_VERSION; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -93,10 +94,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client */ @VisibleForTesting MLRegisterModelRequest getRequest(RestRequest request) throws IOException { + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); boolean loadModel = request.paramAsBoolean(PARAMETER_DEPLOY_MODEL, false); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLRegisterModelInput mlInput = MLRegisterModelInput.parse(parser, loadModel); + mlInput.setTenantId(tenantId); if (mlInput.getFunctionName() == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } else if (FunctionName.isDLModel(mlInput.getFunctionName()) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelGroupAction.java index 4c765732a2..95e8261f76 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLRegisterModelGroupAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -18,6 +19,7 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -27,11 +29,14 @@ public class RestMLRegisterModelGroupAction extends BaseRestHandler { private static final String ML_REGISTER_MODEL_GROUP_ACTION = "ml_register_model_group_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; /** * Constructor */ - public RestMLRegisterModelGroupAction() {} + public RestMLRegisterModelGroupAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -59,12 +64,14 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client @VisibleForTesting MLRegisterModelGroupRequest getRequest(RestRequest request) throws IOException { boolean hasContent = request.hasContent(); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); if (!hasContent) { throw new OpenSearchParseException("Model group request has empty body"); } XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLRegisterModelGroupInput input = MLRegisterModelGroupInput.parse(parser); + input.setTenantId(tenantId); return new MLRegisterModelGroupRequest(input); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java index b8e55f9152..d6123c57d7 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchModelGroupAction.java @@ -10,6 +10,7 @@ import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import com.google.common.collect.ImmutableList; @@ -19,9 +20,12 @@ public class RestMLSearchModelGroupAction extends AbstractMLSearchAction { private static final String ML_SEARCH_MODEL_GROUP_ACTION = "ml_search_model_group_action"; private static final String SEARCH_MODEL_GROUP_PATH = ML_BASE_URI + "/model_groups/_search"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; - public RestMLSearchModelGroupAction() { + public RestMLSearchModelGroupAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { super(ImmutableList.of(SEARCH_MODEL_GROUP_PATH), ML_MODEL_GROUP_INDEX, MLModelGroup.class, MLModelGroupSearchAction.INSTANCE); + + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java index b6e3822318..3398e7b24f 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -65,11 +66,12 @@ private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOExcept } String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); try { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - return MLUpdateConnectorRequest.parse(parser, connectorId); + return MLUpdateConnectorRequest.parse(parser, connectorId, tenantId); } catch (IllegalStateException illegalStateException) { throw new OpenSearchParseException(illegalStateException.getMessage()); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java index 5a40ae8c47..5f71c66633 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -22,6 +23,7 @@ import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -31,6 +33,11 @@ public class RestMLUpdateModelAction extends BaseRestHandler { private static final String ML_UPDATE_MODEL_ACTION = "ml_update_model_action"; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + public RestMLUpdateModelAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } @Override public String getName() { @@ -61,6 +68,7 @@ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException } String modelId = getParameterId(request, PARAMETER_MODEL_ID); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -76,6 +84,7 @@ private MLUpdateModelRequest getRequest(RestRequest request) throws IOException input.setModelId(modelId); input.setVersion(null); input.setUpdatedConnector(null); + input.setTenantId(tenantId); return new MLUpdateModelRequest(input); } catch (IllegalStateException e) { throw new OpenSearchParseException(e.getMessage()); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java index f7757ab652..901f05f22a 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -19,6 +20,7 @@ import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -29,6 +31,15 @@ public class RestMLUpdateModelGroupAction extends BaseRestHandler { private static final String ML_UPDATE_MODEL_GROUP_ACTION = "ml_update_model_group_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLUpdateModelGroupAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + @Override public String getName() { return ML_UPDATE_MODEL_GROUP_ACTION; @@ -50,6 +61,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli private MLUpdateModelGroupRequest getRequest(RestRequest request) throws IOException { String modelGroupID = getParameterId(request, PARAMETER_MODEL_GROUP_ID); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); boolean hasContent = request.hasContent(); if (!hasContent) { throw new IOException("Model group request has empty body"); @@ -58,6 +70,7 @@ private MLUpdateModelGroupRequest getRequest(RestRequest request) throws IOExcep ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLUpdateModelGroupInput input = MLUpdateModelGroupInput.parse(parser); input.setModelGroupID(modelGroupID); + input.setTenantId(tenantId); return new MLUpdateModelGroupRequest(input); } diff --git a/plugin/src/main/plugin-metadata/plugin-security.policy b/plugin/src/main/plugin-metadata/plugin-security.policy index 1914fd5eb2..b14d46926d 100644 --- a/plugin/src/main/plugin-metadata/plugin-security.policy +++ b/plugin/src/main/plugin-metadata/plugin-security.policy @@ -30,4 +30,7 @@ grant { // aws credential file access permission java.io.FilePermission "<>", "read"; + + // AWS credentials needed for clients + permission java.io.FilePermission "${user.home}/.aws/-", "read"; }; diff --git a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java index 9e913f98d4..30060e7115 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/action/MLCommonsIntegTestCase.java @@ -393,13 +393,13 @@ public MLTask getTask(String taskId) { } public MLModel getModel(String modelId) { - MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, true); + MLModelGetRequest getRequest = new MLModelGetRequest(modelId, false, true, null); MLModelGetResponse response = client().execute(MLModelGetAction.INSTANCE, getRequest).actionGet(5000); return response.getMlModel(); } public MLModelGroup getModelGroup(String modelGroupId) { - MLModelGroupGetRequest getRequest = new MLModelGroupGetRequest(modelGroupId); + MLModelGroupGetRequest getRequest = new MLModelGroupGetRequest(modelGroupId, null); MLModelGroupGetResponse response = client().execute(MLModelGroupGetAction.INSTANCE, getRequest).actionGet(5000); return response.getMlModelGroup(); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java index 7884c4cbaf..4166e050ee 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java @@ -5,25 +5,19 @@ package org.opensearch.ml.action.connector; +import static org.mockito.ArgumentCaptor.forClass; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.action.DocWriteResponse.Result.DELETED; -import static org.opensearch.action.DocWriteResponse.Result.NOT_FOUND; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; -import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; import java.util.Collections; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import org.apache.lucene.search.TotalHits; -import org.junit.AfterClass; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -32,20 +26,14 @@ import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; -import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -68,25 +56,13 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ScalingExecutorBuilder; -import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class DeleteConnectorTransportActionTests extends OpenSearchTestCase { private static final String CONNECTOR_ID = "connector_id"; - - private static TestThreadPool testThreadPool = new TestThreadPool( - TransportCreateConnectorActionTests.class.getName(), - new ScalingExecutorBuilder( - GENERAL_THREAD_POOL, - 1, - Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), - TimeValue.timeValueMinutes(1), - ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL - ) - ); + DeleteResponse deleteResponse = new DeleteResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, true); @Mock ThreadPool threadPool; @@ -105,9 +81,6 @@ public class DeleteConnectorTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - @Mock - DeleteResponse deleteResponse; - @Mock NamedXContentRegistry xContentRegistry; @@ -154,115 +127,73 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(threadPool.executor(any())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); } - @AfterClass - public static void cleanup() { - ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); - } - - public void testDeleteConnector_Success() throws InterruptedException { - DeleteResponse deleteResponse = new DeleteResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, true); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(deleteResponse); - when(client.delete(any(DeleteRequest.class))).thenReturn(future); + public void testDeleteConnector_Success() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(5); - listener.onResponse(true); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + }).when(client).delete(any(), any()); SearchResponse searchResponse = getEmptySearchResponse(); - PlainActionFuture searchFuture = PlainActionFuture.newFuture(); - searchFuture.onResponse(searchResponse); - when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); - ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); + // Capture and verify the response + ArgumentCaptor captor = forClass(DeleteResponse.class); verify(actionListener).onResponse(captor.capture()); - assertEquals(CONNECTOR_ID, captor.getValue().getId()); - assertEquals(DELETED, captor.getValue().getResult()); + + // Assert the captured response matches the expected values + DeleteResponse actualResponse = captor.getValue(); + assertEquals(deleteResponse.getId(), actualResponse.getId()); + assertEquals(deleteResponse.getIndex(), actualResponse.getIndex()); + assertEquals(deleteResponse.getVersion(), actualResponse.getVersion()); + assertEquals(deleteResponse.getResult(), actualResponse.getResult()); } public void testDeleteConnector_ModelIndexNotFoundSuccess() throws InterruptedException { - DeleteResponse deleteResponse = new DeleteResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, true); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(deleteResponse); - when(client.delete(any(DeleteRequest.class))).thenReturn(future); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(5); - listener.onResponse(true); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); - - PlainActionFuture searchFuture = PlainActionFuture.newFuture(); - searchFuture.onFailure(new IndexNotFoundException("ml_model index not found!")); - when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); - - ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); - verify(actionListener).onResponse(captor.capture()); - assertEquals(CONNECTOR_ID, captor.getValue().getId()); - assertEquals(DELETED, captor.getValue().getResult()); - } - - public void testDeleteConnector_ConnectorNotFound() throws InterruptedException { - DeleteResponse deleteResponse = new DeleteResponse(new ShardId(ML_CONNECTOR_INDEX, "_na_", 0), CONNECTOR_ID, 1, 0, 2, false); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(deleteResponse); - when(client.delete(any(DeleteRequest.class))).thenReturn(future); + }).when(client).delete(any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(5); - listener.onResponse(true); + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IndexNotFoundException("ml_model index not found!")); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); - - SearchResponse searchResponse = getEmptySearchResponse(); - PlainActionFuture searchFuture = PlainActionFuture.newFuture(); - searchFuture.onResponse(searchResponse); - when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + }).when(client).search(any(), any()); - ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); + // Capture and verify the response + ArgumentCaptor captor = forClass(DeleteResponse.class); verify(actionListener).onResponse(captor.capture()); - assertEquals(CONNECTOR_ID, captor.getValue().getId()); - assertEquals(NOT_FOUND, captor.getValue().getResult()); + + // Assert the captured response matches the expected values + DeleteResponse actualResponse = captor.getValue(); + assertEquals(deleteResponse.getId(), actualResponse.getId()); + assertEquals(deleteResponse.getIndex(), actualResponse.getIndex()); + assertEquals(deleteResponse.getVersion(), actualResponse.getVersion()); + assertEquals(deleteResponse.getResult(), actualResponse.getResult()); } - public void testDeleteConnector_BlockedByModel() throws IOException, InterruptedException { + public void testDeleteConnector_BlockedByModel() throws IOException { + SearchResponse searchResponse = getNonEmptySearchResponse(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(5); - listener.onResponse(true); + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + }).when(client).search(any(), any()); - SearchResponse searchResponse = getNonEmptySearchResponse(); - PlainActionFuture searchFuture = PlainActionFuture.newFuture(); - searchFuture.onResponse(searchResponse); - when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + ArgumentCaptor argumentCaptor = forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "1 models are still using this connector, please delete or update the models first: [model_ID]", @@ -270,7 +201,7 @@ public void testDeleteConnector_BlockedByModel() throws IOException, Interrupted ); } - public void test_UserHasNoAccessException() throws IOException { + public void test_UserHasNoAccessException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(5); listener.onResponse(false); @@ -278,41 +209,33 @@ public void test_UserHasNoAccessException() throws IOException { }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + ArgumentCaptor argumentCaptor = forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("You are not allowed to delete this connector", argumentCaptor.getValue().getMessage()); } public void testDeleteConnector_SearchFailure() throws InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(5); - listener.onResponse(true); + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("Search Failed!")); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); + }).when(client).search(any(), any()); - PlainActionFuture searchFuture = PlainActionFuture.newFuture(); - searchFuture.onFailure(new RuntimeException("Search Failed!")); - when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new ResourceNotFoundException("errorMessage")); + return null; + }).when(client).delete(any(), any()); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Search Failed!", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to search indices [.plugins-ml-model]", argumentCaptor.getValue().getMessage()); } public void testDeleteConnector_SearchException() { when(client.threadPool()).thenThrow(new RuntimeException("Thread Context Error!")); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(5); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -320,30 +243,32 @@ public void testDeleteConnector_SearchException() { } public void testDeleteConnector_ResourceNotFoundException() throws InterruptedException { - when(client.delete(any(DeleteRequest.class))).thenThrow(new ResourceNotFoundException("errorMessage")); - + SearchResponse searchResponse = getEmptySearchResponse(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(5); - listener.onResponse(true); + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); - - SearchResponse searchResponse = getEmptySearchResponse(); - PlainActionFuture searchFuture = PlainActionFuture.newFuture(); - searchFuture.onResponse(searchResponse); - when(client.search(any(SearchRequest.class))).thenReturn(searchFuture); + }).when(client).search(any(), any()); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new ResourceNotFoundException("errorMessage")); + return null; + }).when(client).delete(any(), any()); + deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to delete data object from index .plugins-ml-connector", argumentCaptor.getValue().getMessage()); } - public void test_ValidationFailedException() { + public void test_ValidationFailedException() throws IOException { + GetResponse getResponse = prepareMLConnector(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).search(any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(5); listener.onFailure(new Exception("Failed to validate access")); @@ -351,7 +276,7 @@ public void test_ValidationFailedException() { }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), any()); deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + ArgumentCaptor argumentCaptor = forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } @@ -363,12 +288,9 @@ public void testDeleteConnector_MultiTenancyEnabled_NoTenantId() throws Interrup // Create a request without a tenant ID MLConnectorDeleteRequest requestWithoutTenant = MLConnectorDeleteRequest.builder().connectorId(CONNECTOR_ID).build(); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - deleteConnectorTransportAction.doExecute(null, requestWithoutTenant, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + deleteConnectorTransportAction.doExecute(null, requestWithoutTenant, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + ArgumentCaptor argumentCaptor = forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("You don't have permission to access this resource", argumentCaptor.getValue().getMessage()); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java index 9626b75959..6868639757 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/GetConnectorTransportActionTests.java @@ -6,20 +6,14 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; -import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import java.io.IOException; import java.util.Collections; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -28,15 +22,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -56,8 +45,6 @@ import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ScalingExecutorBuilder; -import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -66,17 +53,6 @@ public class GetConnectorTransportActionTests extends OpenSearchTestCase { private static final String TENANT_ID = "_tenant_id"; - private static final TestThreadPool testThreadPool = new TestThreadPool( - GetConnectorTransportActionTests.class.getName(), - new ScalingExecutorBuilder( - GENERAL_THREAD_POOL, - 1, - Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), - TimeValue.timeValueMinutes(1), - ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL - ) - ); - @Mock ThreadPool threadPool; @@ -135,12 +111,6 @@ public void setup() throws IOException { threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); - } - - @AfterClass - public static void cleanup() { - ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } @Test @@ -153,15 +123,7 @@ public void testGetConnector_UserHasNoAccess() throws IOException, InterruptedEx return null; }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); - GetResponse getResponse = prepareConnector(null); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(getResponse); - when(client.get(any(GetRequest.class))).thenReturn(future); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -182,10 +144,7 @@ public void testGetConnector_NullResponse() throws InterruptedException { return null; }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -207,10 +166,7 @@ public void testGetConnector_MultiTenancyEnabled_Success() throws IOException, I }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), getDataObjectRequestArgumentCaptor.capture(), any(), any()); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); Assert.assertEquals(tenantId, getDataObjectRequestArgumentCaptor.getValue().tenantId()); Assert.assertEquals(CONNECTOR_ID, getDataObjectRequestArgumentCaptor.getValue().id()); @@ -234,10 +190,7 @@ public void testGetConnector_MultiTenancyEnabled_ForbiddenAccess() throws IOExce return null; }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + getConnectorTransportAction.doExecute(null, mlConnectorGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index 9709cbb0ce..030f91fece 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -5,15 +5,12 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; -import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; @@ -23,27 +20,20 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.LatchedActionListener; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; @@ -66,8 +56,6 @@ import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ScalingExecutorBuilder; -import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -79,17 +67,6 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { private static final String CONNECTOR_ID = "connector_id"; private static final String TENANT_ID = "_tenant_id"; - private static TestThreadPool testThreadPool = new TestThreadPool( - TransportCreateConnectorActionTests.class.getName(), - new ScalingExecutorBuilder( - GENERAL_THREAD_POOL, - 1, - Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), - TimeValue.timeValueMinutes(1), - ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL - ) - ); - private TransportCreateConnectorAction action; @Mock @@ -182,7 +159,6 @@ public void setup() { when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); List actions = new ArrayList<>(); actions @@ -209,11 +185,6 @@ public void setup() { when(request.getMlCreateConnectorInput()).thenReturn(input); } - @AfterClass - public static void cleanup() { - ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); - } - public void test_execute_connectorAccessControl_notEnabled_success() throws InterruptedException { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(true); input.setAddAllBackendRoles(null); @@ -225,16 +196,22 @@ public void test_execute_connectorAccessControl_notEnabled_success() throws Inte return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(indexResponse); - when(client.index(any(IndexRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - action.doExecute(task, request, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + action.doExecute(task, request, actionListener); - verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); + // Capture and verify the response + ArgumentCaptor captor = ArgumentCaptor.forClass(MLCreateConnectorResponse.class); + verify(actionListener).onResponse(captor.capture()); + + // Validate the captured response + MLCreateConnectorResponse actualResponse = captor.getValue(); + assertNotNull(actualResponse); + assertEquals(CONNECTOR_ID, actualResponse.getConnectorId()); } public void test_execute_connector_registration_multi_tenancy_fail() throws InterruptedException { @@ -249,14 +226,7 @@ public void test_execute_connector_registration_multi_tenancy_fail() throws Inte return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(indexResponse); - when(client.index(any(IndexRequest.class))).thenReturn(future); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - action.doExecute(task, request, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + action.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -277,14 +247,7 @@ public void test_execute_connectorAccessControl_notEnabled_withPermissionInfo_ex return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(mock(IndexResponse.class)); - when(client.index(any(IndexRequest.class))).thenReturn(future); - - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - action.doExecute(task, request, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + action.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -305,17 +268,22 @@ public void test_execute_connectorAccessControlEnabled_success() throws Interrup return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(indexResponse); - when(client.index(any(IndexRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); + + action.doExecute(task, request, actionListener); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - request.getMlCreateConnectorInput().setTenantId(TENANT_ID); - action.doExecute(task, request, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + // Capture and verify the response + ArgumentCaptor captor = ArgumentCaptor.forClass(MLCreateConnectorResponse.class); + verify(actionListener).onResponse(captor.capture()); - verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); + // Validate the captured response + MLCreateConnectorResponse actualResponse = captor.getValue(); + assertNotNull(actualResponse); + assertEquals(CONNECTOR_ID, actualResponse.getConnectorId()); } public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_defaultToPrivate() throws InterruptedException { @@ -329,16 +297,22 @@ public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_def return null; }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(indexResponse); - when(client.index(any(IndexRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); - action.doExecute(task, request, latchedActionListener); - latch.await(500, TimeUnit.MILLISECONDS); + action.doExecute(task, request, actionListener); - verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); + // Capture and verify the response + ArgumentCaptor captor = ArgumentCaptor.forClass(MLCreateConnectorResponse.class); + verify(actionListener).onResponse(captor.capture()); + + // Validate the captured response + MLCreateConnectorResponse actualResponse = captor.getValue(); + assertNotNull(actualResponse); + assertEquals(CONNECTOR_ID, actualResponse.getConnectorId()); } public void test_execute_connectorAccessControlEnabled_adminSpecifyAllBackendRoles_exception() { diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java index b10aff0d5b..2136b9ec56 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/UpdateConnectorTransportActionTests.java @@ -6,7 +6,6 @@ package org.opensearch.ml.action.connector; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.*; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; @@ -16,6 +15,7 @@ import java.nio.file.Path; import java.time.Instant; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.UUID; @@ -27,6 +27,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.DocWriteResponse.Result; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; @@ -42,6 +43,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; @@ -53,7 +55,10 @@ import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -76,6 +81,10 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { @Mock private Client client; + private SdkClient sdkClient; + + @Mock + private NamedXContentRegistry xContentRegistry; @Mock private ThreadPool threadPool; @@ -86,13 +95,15 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { @Mock private TransportService transportService; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock private ActionFilters actionFilters; @Mock private MLUpdateConnectorRequest updateRequest; - @Mock private UpdateResponse updateResponse; @Mock @@ -111,35 +122,40 @@ public class UpdateConnectorTransportActionTests extends OpenSearchTestCase { private MLEngine mlEngine; + private static final String TEST_CONNECTOR_ID = "test_connector_id"; private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + settings = Settings .builder() .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) .build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED ); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); threadContext = new ThreadContext(settings); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - String connector_id = "test_connector_id"; MLCreateConnectorInput updateContent = MLCreateConnectorInput .builder() .updateConnector(true) .version("2") .description("updated description") .build(); - when(updateRequest.getConnectorId()).thenReturn(connector_id); + when(updateRequest.getConnectorId()).thenReturn(TEST_CONNECTOR_ID); when(updateRequest.getUpdateContent()).thenReturn(updateContent); SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); @@ -162,11 +178,13 @@ public void setup() throws IOException { transportService, actionFilters, client, + sdkClient, connectorAccessControlHelper, mlModelManager, settings, clusterService, - mlEngine + mlEngine, + mlFeatureEnabledSetting ); when(mlModelManager.getAllModelIds()).thenReturn(new String[] {}); @@ -174,7 +192,7 @@ public void setup() throws IOException { updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); Connector connector = HttpConnector .builder() .name("test") @@ -200,60 +218,7 @@ public void setup() throws IOException { // doNothing().when(connector).update(any(), any()); listener.onResponse(connector); return null; - }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); - } - - @Test - public void testUpdateConnectorDoesNotUpdateHttpConnectorTimeFields() { - HttpConnector connector = HttpConnector - .builder() - .name("test") - .protocol("http") - .version("1") - .credential(Map.of("api_key", "credential_value")) - .parameters(Map.of("param1", "value1")) - .actions( - Arrays - .asList( - ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("https://api.openai.com/v1/chat/completions") - .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) - .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") - .build() - ) - ) - .build(); - - assertNull(connector.getCreatedTime()); - assertNull(connector.getLastUpdateTime()); - - doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(connector); - return null; - }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(searchResponse); - return null; - }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(updateResponse); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); - - updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - - assertNull(connector.getCreatedTime()); - assertNotNull(connector.getLastUpdateTime()); + }).when(connectorAccessControlHelper).getConnector(any(), any(), any(), any(), any(), any()); } @Test @@ -315,7 +280,7 @@ public void testUpdateConnectorUpdatesHttpConnectorTimeFields() { } @Test - public void testExecuteConnectorAccessControlSuccess() { + public void testExecuteConnectorAccessControlSuccess() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -331,7 +296,7 @@ public void testExecuteConnectorAccessControlSuccess() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + verify(actionListener).onResponse(any(UpdateResponse.class)); } @Test @@ -372,7 +337,7 @@ public void testExecuteConnectorAccessControlException() { } @Test - public void testExecuteUpdateWrongStatus() { + public void testExecuteUpdateWrongStatus() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -389,11 +354,15 @@ public void testExecuteUpdateWrongStatus() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testExecuteUpdateException() { + public void testExecuteUpdateException() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -411,11 +380,11 @@ public void testExecuteUpdateException() { updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("update document failure", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to update data object in index .plugins-ml-connector", argumentCaptor.getValue().getMessage()); } @Test - public void testExecuteSearchResponseNotEmpty() { + public void testExecuteSearchResponseNotEmpty() throws IOException, InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -433,7 +402,7 @@ public void testExecuteSearchResponseNotEmpty() { } @Test - public void testExecuteSearchResponseError() { + public void testExecuteSearchResponseError() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -445,11 +414,11 @@ public void testExecuteSearchResponseError() { updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to search indices [.plugins-ml-model]", argumentCaptor.getValue().getMessage()); } @Test - public void testExecuteSearchIndexNotFoundError() { + public void testExecuteSearchIndexNotFoundError() throws InterruptedException { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { @@ -494,7 +463,10 @@ public void testExecuteSearchIndexNotFoundError() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(Result.UPDATED, argumentCaptor.getValue().getResult()); } private SearchResponse noneEmptySearchResponse() throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java index 048fa38878..9dc9b6337d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java @@ -5,48 +5,67 @@ package org.opensearch.ml.action.model_group; +import static org.mockito.ArgumentCaptor.forClass; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import java.io.IOException; +import java.util.Collections; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchException; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class DeleteModelGroupTransportActionTests extends OpenSearchTestCase { + + private static final String MODEL_GROUP_ID = "model_group_id"; + DeleteResponse deleteResponse = new DeleteResponse(new ShardId(ML_MODEL_GROUP_INDEX, "_na_", 0), MODEL_GROUP_ID, 1, 0, 2, true); + @Mock ThreadPool threadPool; @Mock Client client; + SdkClient sdkClient; + @Mock TransportService transportService; @@ -56,15 +75,15 @@ public class DeleteModelGroupTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - @Mock - DeleteResponse deleteResponse; - @Mock ClusterService clusterService; @Mock NamedXContentRegistry xContentRegistry; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -79,38 +98,44 @@ public class DeleteModelGroupTransportActionTests extends OpenSearchTestCase { public void setup() throws IOException { MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); mlModelGroupDeleteRequest = MLModelGroupDeleteRequest.builder().modelGroupId("test_id").build(); deleteModelGroupTransportAction = spy( new DeleteModelGroupTransportAction( transportService, actionFilters, client, + sdkClient, xContentRegistry, clusterService, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); - Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); } - public void testDeleteModelGroup_Success() throws IOException { + @Test + public void testDeleteModelGroup_Success() throws InterruptedException { + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); return null; }).when(client).delete(any(), any()); - SearchResponse searchResponse = createModelGroupSearchResponse(0); + SearchResponse searchResponse = getEmptySearchResponse(); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); @@ -118,12 +143,23 @@ public void testDeleteModelGroup_Success() throws IOException { }).when(client).search(any(), isA(ActionListener.class)); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); + + // Capture and verify the response + ArgumentCaptor captor = forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + + // Assert the captured response matches the expected values + DeleteResponse actualResponse = captor.getValue(); + assertEquals(deleteResponse.getId(), actualResponse.getId()); + assertEquals(deleteResponse.getIndex(), actualResponse.getIndex()); + assertEquals(deleteResponse.getVersion(), actualResponse.getVersion()); + assertEquals(deleteResponse.getResult(), actualResponse.getResult()); } - public void test_AssociatedModelsExistException() throws IOException { + @Test + public void test_AssociatedModelsExistException() throws IOException, InterruptedException { - SearchResponse searchResponse = createModelGroupSearchResponse(1); + SearchResponse searchResponse = getNonEmptySearchResponse(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); @@ -134,15 +170,66 @@ public void test_AssociatedModelsExistException() throws IOException { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Cannot delete the model group when it has associated model versions", argumentCaptor.getValue().getMessage()); + } + + @Test + public void test_IndexNotFoundExceptionDuringSearch() throws IOException, InterruptedException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("index_not_found")); + return null; + }).when(client).search(any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); + + // Capture and verify the response + ArgumentCaptor captor = forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + + // Assert the captured response matches the expected values + DeleteResponse actualResponse = captor.getValue(); + assertEquals(deleteResponse.getId(), actualResponse.getId()); + assertEquals(deleteResponse.getIndex(), actualResponse.getIndex()); + assertEquals(deleteResponse.getVersion(), actualResponse.getVersion()); + assertEquals(deleteResponse.getResult(), actualResponse.getResult()); } + @Test + public void test_DeleteRequestInternalServerError() { + SearchResponse searchResponse = getEmptySearchResponse(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(client).search(any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new OpenSearchException("Internal Server Error", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(client).delete(any(), any()); + + deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to delete data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); + } + + @Test public void test_UserHasNoAccessException() throws IOException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -150,12 +237,13 @@ public void test_UserHasNoAccessException() throws IOException { assertEquals("User doesn't have privilege to delete this model group", argumentCaptor.getValue().getMessage()); } + @Test public void test_ValidationFailedException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -163,39 +251,74 @@ public void test_ValidationFailedException() { assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } - public void testDeleteModelGroup_Failure() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new Exception("errorMessage")); - return null; - }).when(client).delete(any(), any()); - - SearchResponse searchResponse = createModelGroupSearchResponse(0); + public void testDeleteModelGroup_Failure() { + SearchResponse searchResponse = getEmptySearchResponse(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("errorMessage")); + return null; + }).when(client).delete(any(), any()); + deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to delete data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); + } + + private SearchResponse getEmptySearchResponse() { + SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, true, false, null, 1); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + return searchResponse; } - private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { - SearchResponse searchResponse = mock(SearchResponse.class); + private SearchResponse getNonEmptySearchResponse() throws IOException { + SearchHit[] hits = new SearchHit[1]; String modelContent = "{\n" + " \"created_time\": 1684981986069,\n" - + " \"access\": \"public\",\n" - + " \"latest_version\": 0,\n" + " \"last_updated_time\": 1684981986069,\n" - + " \"name\": \"model_group_IT\",\n" + + " \"_id\": \"model_ID\",\n" + + " \"name\": \"test_model\",\n" + " \"description\": \"This is an example description\"\n" + " }"; - SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); - SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); + SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent)); + hits[0] = model; + SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchResponseSections searchSections = new SearchResponseSections( + searchHits, + InternalAggregations.EMPTY, + null, + true, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); return searchResponse; } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java index f120f10374..2e80adcd5f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java @@ -12,6 +12,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Collections; import org.junit.Before; import org.junit.Rule; @@ -37,6 +38,9 @@ import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -48,6 +52,8 @@ public class GetModelGroupTransportActionTests extends OpenSearchTestCase { @Mock Client client; + SdkClient sdkClient; + @Mock NamedXContentRegistry xContentRegistry; @@ -73,20 +79,26 @@ public class GetModelGroupTransportActionTests extends OpenSearchTestCase { @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId("test_id").build(); Settings settings = Settings.builder().build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId("test_id").build(); getModelGroupTransportAction = spy( new GetModelGroupTransportAction( transportService, actionFilters, client, + sdkClient, xContentRegistry, clusterService, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); @@ -176,7 +188,7 @@ public void testGetModel_IndexNotFoundException() { getModelGroupTransportAction.doExecute(null, mlModelGroupGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Fail to find model group index", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to find model group index", argumentCaptor.getValue().getMessage()); } public void testGetModel_RuntimeException() { @@ -188,7 +200,7 @@ public void testGetModel_RuntimeException() { getModelGroupTransportAction.doExecute(null, mlModelGroupGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to get data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); } public GetResponse prepareMLModelGroup() throws IOException { @@ -203,7 +215,6 @@ public GetResponse prepareMLModelGroup() throws IOException { XContentBuilder content = mlModelGroup.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); - return getResponse; + return new GetResponse(getResult); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/RegisterModelGroupITTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/RegisterModelGroupITTests.java index 040e688101..79d446f516 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/RegisterModelGroupITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/RegisterModelGroupITTests.java @@ -34,7 +34,8 @@ public void test_register_public_model_group() { "mock_model_group_desc", null, AccessMode.PUBLIC, - false + false, + null ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); @@ -47,14 +48,22 @@ public void test_register_private_model_group() { "mock_model_group_desc", null, AccessMode.PRIVATE, - false + false, + null ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); } public void test_register_model_group_without_access_fields() { - MLRegisterModelGroupInput input = new MLRegisterModelGroupInput("mock_model_group_name", "mock_model_group_desc", null, null, null); + MLRegisterModelGroupInput input = new MLRegisterModelGroupInput( + "mock_model_group_name", + "mock_model_group_desc", + null, + null, + null, + null + ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); } @@ -66,7 +75,8 @@ public void test_register_protected_model_group_with_addAllBackendRoles_true() { "mock_model_group_desc", null, AccessMode.RESTRICTED, - true + true, + null ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); @@ -79,6 +89,7 @@ public void test_register_protected_model_group_with_backendRoles_notEmpty() { "mock_model_group_desc", ImmutableList.of("role-1"), AccessMode.RESTRICTED, + null, null ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupITTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupITTests.java index b187cc4f8d..2e3a7a84b1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupITTests.java @@ -34,7 +34,14 @@ public void setUp() throws Exception { } private void registerModelGroup() { - MLRegisterModelGroupInput input = new MLRegisterModelGroupInput("mock_model_group_name", "mock model group desc", null, null, null); + MLRegisterModelGroupInput input = new MLRegisterModelGroupInput( + "mock_model_group_name", + "mock model group desc", + null, + null, + null, + null + ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); MLRegisterModelGroupResponse response = client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); this.modelGroupId = response.getModelGroupId(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java index f54f034b64..f9ba48cee1 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportRegisterModelGroupActionTests.java @@ -8,12 +8,15 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -25,6 +28,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; @@ -32,6 +36,9 @@ import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -58,6 +65,9 @@ public class TransportRegisterModelGroupActionTests extends OpenSearchTestCase { @Mock private Client client; + + SdkClient sdkClient; + @Mock private ActionFilters actionFilters; @@ -76,27 +86,37 @@ public class TransportRegisterModelGroupActionTests extends OpenSearchTestCase { @Mock private MLModelGroupManager mlModelGroupManager; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private final List backendRoles = Arrays.asList("IT", "HR"); @Before public void setup() { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); transportRegisterModelGroupAction = new TransportRegisterModelGroupAction( transportService, actionFilters, mlIndicesHandler, threadPool, client, + sdkClient, clusterService, modelAccessControlHelper, - mlModelGroupManager + mlModelGroupManager, + mlFeatureEnabledSetting ); assertNotNull(transportRegisterModelGroupAction); } + @Test public void test_Success() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -110,6 +130,7 @@ public void test_Success() { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void test_Failure() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -125,6 +146,18 @@ public void test_Failure() { assertEquals("Failed to init model group index", argumentCaptor.getValue().getMessage()); } + @Test + public void test_TenantIdValidationFailure() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + MLRegisterModelGroupRequest actionRequest = prepareRequest(null, AccessMode.PUBLIC, null); + transportRegisterModelGroupAction.doExecute(task, actionRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("You don't have permission to access this resource", argumentCaptor.getValue().getMessage()); + } + private MLRegisterModelGroupRequest prepareRequest( List backendRoles, AccessMode modelAccessMode, diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index a99488d0e9..1ca2c8fa67 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.apache.lucene.search.TotalHits; @@ -49,7 +50,10 @@ import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; @@ -76,6 +80,9 @@ public class TransportUpdateModelGroupActionTests extends OpenSearchTestCase { @Mock private Client client; + + SdkClient sdkClient; + @Mock private ActionFilters actionFilters; @@ -97,6 +104,9 @@ public class TransportUpdateModelGroupActionTests extends OpenSearchTestCase { @Mock private MLModelGroupManager mlModelGroupManager; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private String ownerString = "bob|IT,HR|myTenant"; private List backendRoles = Arrays.asList("IT"); @@ -104,15 +114,18 @@ public class TransportUpdateModelGroupActionTests extends OpenSearchTestCase { public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); threadContext = new ThreadContext(settings); transportUpdateModelGroupAction = new TransportUpdateModelGroupAction( transportService, actionFilters, client, + sdkClient, xContentRegistry, clusterService, modelAccessControlHelper, - mlModelGroupManager + mlModelGroupManager, + mlFeatureEnabledSetting ); assertNotNull(transportUpdateModelGroupAction); @@ -143,10 +156,10 @@ public void setup() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(0); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); when(client.threadPool()).thenReturn(threadPool); @@ -351,7 +364,7 @@ public void test_FailedToFindModelGroupException() { transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to find model group", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to get data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); } public void test_FailedToGetModelGroupException() { @@ -365,7 +378,7 @@ public void test_FailedToGetModelGroupException() { transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to get model group", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to get data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); } public void test_ModelGroupIndexNotFoundException() { @@ -379,7 +392,7 @@ public void test_ModelGroupIndexNotFoundException() { transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Fail to find model group", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to find model group", argumentCaptor.getValue().getMessage()); } public void test_FailedToUpdatetModelGroupException() { @@ -413,10 +426,10 @@ public void test_ModelGroupNameNotUnique() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); MLUpdateModelGroupRequest actionRequest = prepareRequest(null, null, null); transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/UpdateModelGroupITTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/UpdateModelGroupITTests.java index 19cf5b4bd5..f4ea8f79fd 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/UpdateModelGroupITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/UpdateModelGroupITTests.java @@ -35,7 +35,14 @@ public void setUp() throws Exception { } private void registerModelGroup() { - MLRegisterModelGroupInput input = new MLRegisterModelGroupInput("mock_model_group_name", "mock_model_group_desc", null, null, null); + MLRegisterModelGroupInput input = new MLRegisterModelGroupInput( + "mock_model_group_name", + "mock_model_group_desc", + null, + null, + null, + null + ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); MLRegisterModelGroupResponse response = client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); this.modelGroupId = response.getModelGroupId(); @@ -49,7 +56,8 @@ public void test_update_public_model_group() { "mock_model_group_desc", null, AccessMode.PUBLIC, - false + false, + null ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); client().execute(MLUpdateModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); @@ -63,7 +71,8 @@ public void test_update_private_model_group() { "mock_model_group_desc", null, AccessMode.PRIVATE, - false + false, + null ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); client().execute(MLUpdateModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); @@ -76,6 +85,7 @@ public void test_update_model_group_without_access_fields() { "mock_model_group_desc", null, null, + null, null ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); @@ -90,7 +100,8 @@ public void test_update_protected_model_group_with_addAllBackendRoles_true() { "mock_model_group_desc", null, AccessMode.RESTRICTED, - true + true, + null ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); client().execute(MLUpdateModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); @@ -104,6 +115,7 @@ public void test_update_protected_model_group_with_backendRoles_notEmpty() { "mock_model_group_desc", ImmutableList.of("role-1"), AccessMode.RESTRICTED, + null, null ); MLUpdateModelGroupRequest createModelGroupRequest = new MLUpdateModelGroupRequest(input); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index cdd3c80184..11d00d47b7 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.models; +import static org.mockito.ArgumentCaptor.forClass; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; @@ -23,10 +24,12 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -44,6 +47,8 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -57,17 +62,23 @@ import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class DeleteModelTransportActionTests extends OpenSearchTestCase { + @Mock ThreadPool threadPool; @Mock Client client; + SdkClient sdkClient; + @Mock TransportService transportService; @@ -77,7 +88,6 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - @Mock DeleteResponse deleteResponse; @Mock @@ -95,6 +105,9 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase { @Mock ClusterService clusterService; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + DeleteModelTransportAction deleteModelTransportAction; MLModelDeleteRequest mlModelDeleteRequest; ThreadContext threadContext; @@ -107,18 +120,23 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase { public void setup() throws IOException { MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId("test_id").build(); - Settings settings = Settings.builder().build(); + deleteResponse = new DeleteResponse(new ShardId(new Index(ML_MODEL_INDEX, "_na_"), 1), "taskId", 1, 1, 1, true); + deleteModelTransportAction = spy( new DeleteModelTransportAction( transportService, actionFilters, client, + sdkClient, settings, xContentRegistry, clusterService, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); @@ -134,7 +152,16 @@ public void setup() throws IOException { when(threadPool.getThreadContext()).thenReturn(threadContext); } + @Test public void testDeleteModel_Success() throws IOException { + + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); @@ -148,31 +175,27 @@ public void testDeleteModel_Success() throws IOException { return null; }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); + // Capture and verify the response + ArgumentCaptor captor = forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + + // Assert the captured response matches the expected values + DeleteResponse actualResponse = captor.getValue(); + assertEquals(deleteResponse.getId(), actualResponse.getId()); + assertEquals(deleteResponse.getIndex(), actualResponse.getIndex()); + assertEquals(deleteResponse.getVersion(), actualResponse.getVersion()); + assertEquals(deleteResponse.getResult(), actualResponse.getResult()); } - public void testDeleteRemoteModel_Success() throws IOException { + @Test + public void testDeleteRemoteModel_Success() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); return null; }).when(client).delete(any(), any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -181,10 +204,15 @@ public void testDeleteRemoteModel_Success() throws IOException { }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); + + ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals(deleteResponse.getId(), captor.getValue().getId()); + assertEquals(deleteResponse.getResult(), captor.getValue().getResult()); } - public void testDeleteRemoteModel_deleteModelController_failed() throws IOException { + @Test + public void testDeleteRemoteModel_deleteModelController_failed() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); @@ -195,13 +223,6 @@ public void testDeleteRemoteModel_deleteModelController_failed() throws IOExcept return null; }).when(client).delete(any(), any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -210,12 +231,14 @@ public void testDeleteRemoteModel_deleteModelController_failed() throws IOExcept }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Model is not all cleaned up, please try again. Model ID: test_id", argumentCaptor.getValue().getMessage()); } - public void testDeleteLocalModel_deleteModelController_failed() throws IOException { + @Test + public void testDeleteLocalModel_deleteModelController_failed() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); @@ -241,12 +264,14 @@ public void testDeleteLocalModel_deleteModelController_failed() throws IOExcepti }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Model is not all cleaned up, please try again. Model ID: test_id", argumentCaptor.getValue().getMessage()); } - public void testDeleteRemoteModel_deleteModelChunks_failed() throws IOException { + @Test + public void testDeleteRemoteModel_deleteModelChunks_failed() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); @@ -272,7 +297,8 @@ public void testDeleteRemoteModel_deleteModelChunks_failed() throws IOException assertEquals("Model is not all cleaned up, please try again. Model ID: test_id", argumentCaptor.getValue().getMessage()); } - public void testDeleteHiddenModel_Success() throws IOException { + @Test + public void testDeleteHiddenModel_Success() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); @@ -295,23 +321,21 @@ public void testDeleteHiddenModel_Success() throws IOException { doReturn(true).when(deleteModelTransportAction).isSuperAdminUserWrapper(clusterService, client); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); + + ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals(deleteResponse.getId(), captor.getValue().getId()); + assertEquals(deleteResponse.getResult(), captor.getValue().getResult()); } - public void testDeleteHiddenModel_NoSuperAdminPermission() throws IOException { + @Test + public void testDeleteHiddenModel_NoSuperAdminPermission() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); return null; }).when(client).delete(any(), any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, true); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -321,12 +345,14 @@ public void testDeleteHiddenModel_NoSuperAdminPermission() throws IOException { doReturn(false).when(deleteModelTransportAction).isSuperAdminUserWrapper(clusterService, client); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); } - public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { + @Test + public void testDeleteModel_Success_AlgorithmNotNull() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); @@ -348,10 +374,15 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); + + ArgumentCaptor captor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals(deleteResponse.getId(), captor.getValue().getId()); + assertEquals(deleteResponse.getResult(), captor.getValue().getResult()); } - public void test_UserHasNoAccessException() throws IOException { + @Test + public void test_UserHasNoAccessException() throws IOException, InterruptedException { GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID", false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -366,12 +397,14 @@ public void test_UserHasNoAccessException() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); } - public void testDeleteModel_CheckModelState() throws IOException { + @Test + public void testDeleteModel_CheckModelState() throws IOException, InterruptedException { GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING, null, false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -380,6 +413,7 @@ public void testDeleteModel_CheckModelState() throws IOException { }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -388,7 +422,8 @@ public void testDeleteModel_CheckModelState() throws IOException { ); } - public void testDeleteModel_ModelNotFoundException() throws IOException { + @Test + public void testDeleteModel_ModelNotFoundException() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onFailure(new Exception("Fail to find model")); @@ -396,12 +431,14 @@ public void testDeleteModel_ModelNotFoundException() throws IOException { }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); } - public void testDeleteModel_deleteModelController_ResourceNotFoundException() throws IOException { + @Test + public void testDeleteModel_deleteModelController_ResourceNotFoundException() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(deleteResponse); @@ -427,11 +464,13 @@ public void testDeleteModel_deleteModelController_ResourceNotFoundException() th }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); verify(actionListener, times(1)).onResponse(argumentCaptor.capture()); } - public void test_ValidationFailedException() throws IOException { + @Test + public void test_ValidationFailedException() throws IOException, InterruptedException { GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -446,12 +485,14 @@ public void test_ValidationFailedException() throws IOException { }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } - public void testDeleteRemoteModel_modelNotFound_ResourceNotFoundException() throws IOException { + @Test + public void testDeleteRemoteModel_modelNotFound_ResourceNotFoundException() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new ResourceNotFoundException("resource not found")); @@ -462,13 +503,6 @@ public void testDeleteRemoteModel_modelNotFound_ResourceNotFoundException() thro return null; }).when(client).delete(any(), any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -477,12 +511,14 @@ public void testDeleteRemoteModel_modelNotFound_ResourceNotFoundException() thro }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assert argumentCaptor.getValue().getMessage().equals("Failed to find model"); } - public void testDeleteRemoteModel_modelNotFound_RuntimeException() throws IOException { + @Test + public void testDeleteRemoteModel_modelNotFound_RuntimeException() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("runtime exception")); @@ -508,12 +544,14 @@ public void testDeleteRemoteModel_modelNotFound_RuntimeException() throws IOExce }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assert argumentCaptor.getValue().getMessage().equals("runtime exception"); + assert argumentCaptor.getValue().getMessage().equals("Failed to delete data object from index .plugins-ml-model"); } - public void testModelNotFound_modelChunks_modelController_delete_success() throws IOException { + @Test + public void testModelNotFound_modelChunks_modelController_delete_success() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(null); @@ -533,11 +571,13 @@ public void testModelNotFound_modelChunks_modelController_delete_success() throw return null; }).when(client).execute(any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); } + @Test public void testDeleteModelChunks_Success() { when(bulkByScrollResponse.getBulkFailures()).thenReturn(null); doAnswer(invocation -> { @@ -559,6 +599,7 @@ public void testDeleteModel_ThreadContextError() { assertEquals("thread context error", argumentCaptor.getValue().getMessage()); } + @Test public void test_FailToDeleteModel() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -572,6 +613,7 @@ public void test_FailToDeleteModel() { assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); } + @Test public void test_FailToDeleteAllModelChunks() { BulkItemResponse.Failure failure = new BulkItemResponse.Failure(ML_MODEL_INDEX, "test_id", new RuntimeException("Error!")); when(bulkByScrollResponse.getBulkFailures()).thenReturn(Arrays.asList(failure)); @@ -587,6 +629,7 @@ public void test_FailToDeleteAllModelChunks() { assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + BULK_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); } + @Test public void test_FailToDeleteAllModelChunks_TimeOut() { BulkItemResponse.Failure failure = new BulkItemResponse.Failure(ML_MODEL_INDEX, "test_id", new RuntimeException("Error!")); when(bulkByScrollResponse.getBulkFailures()).thenReturn(Arrays.asList(failure)); @@ -603,6 +646,7 @@ public void test_FailToDeleteAllModelChunks_TimeOut() { assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + "test_id", argumentCaptor.getValue().getMessage()); } + @Test public void test_FailToDeleteAllModelChunks_SearchFailure() { ScrollableHitSource.SearchFailure searchFailure = new ScrollableHitSource.SearchFailure( new RuntimeException("error"), @@ -632,6 +676,7 @@ public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID .modelState(mlModelState) .modelGroupId(modelGroupID) .isHidden(isHidden) + .algorithm(FunctionName.TEXT_EMBEDDING) .build(); return buildResponse(mlModel); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java index 62ed498e4f..160e93fed4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelITTests.java @@ -5,10 +5,10 @@ package org.opensearch.ml.action.models; +import org.opensearch.OpenSearchStatusException; import org.opensearch.OpenSearchTimeoutException; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.ml.action.MLCommonsIntegTestCase; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.test.OpenSearchIntegTestCase; @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 2) @@ -17,7 +17,7 @@ public class GetModelITTests extends MLCommonsIntegTestCase { private static final int MAX_RETRIES = 3; public void testGetModel_IndexNotFound() { - testGetModelExceptionsWithRetry(MLResourceNotFoundException.class, "test_id"); + testGetModelExceptionsWithRetry(OpenSearchStatusException.class, "test_id"); } public void testGetModel_NullModelIdException() { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java index 8a95805aa8..53da052ba3 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java @@ -13,6 +13,7 @@ import static org.mockito.Mockito.when; import java.io.IOException; +import java.util.Collections; import org.junit.Before; import org.junit.Rule; @@ -40,11 +41,15 @@ import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; public class GetModelTransportActionTests extends OpenSearchTestCase { + @Mock ThreadPool threadPool; @@ -60,6 +65,8 @@ public class GetModelTransportActionTests extends OpenSearchTestCase { @Mock ActionFilters actionFilters; + SdkClient sdkClient; + @Mock ActionListener actionListener; @@ -78,21 +85,28 @@ public class GetModelTransportActionTests extends OpenSearchTestCase { @Mock private ModelAccessControlHelper modelAccessControlHelper; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").build(); settings = Settings.builder().build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); getModelTransportAction = spy( new GetModelTransportAction( transportService, actionFilters, client, + sdkClient, settings, xContentRegistry, clusterService, - modelAccessControlHelper + modelAccessControlHelper, + mlFeatureEnabledSetting ) ); @@ -108,13 +122,7 @@ public void setup() throws IOException { when(threadPool.getThreadContext()).thenReturn(threadContext); } - public void testGetModel_UserHasNodeAccess() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(false); - return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - + public void testGetModel_UserHasNodeAccess() throws IOException, InterruptedException { GetResponse getResponse = prepareMLModel(false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -122,13 +130,20 @@ public void testGetModel_UserHasNodeAccess() throws IOException { return null; }).when(client).get(any(), any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); } - public void testGetModel_Success() throws IOException { + public void testGetModel_Success() throws IOException, InterruptedException { GetResponse getResponse = prepareMLModel(false); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -137,23 +152,28 @@ public void testGetModel_Success() throws IOException { }).when(client).get(any(), any()); getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); - verify(actionListener).onResponse(any(MLModelGetResponse.class)); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelGetResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); } - public void testGetModelHidden_Success() throws IOException { + public void testGetModelHidden_Success() throws IOException, InterruptedException { GetResponse getResponse = prepareMLModel(true); - mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").isUserInitiatedGetRequest(true).build(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(getResponse); return null; }).when(client).get(any(), any()); + doReturn(true).when(getModelTransportAction).isSuperAdminUserWrapper(clusterService, client); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); - verify(actionListener).onResponse(any(MLModelGetResponse.class)); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelGetResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); } - public void testGetModelHidden_SuperUserPermissionError() throws IOException { + public void testGetModelHidden_SuperUserPermissionError() throws IOException, InterruptedException { GetResponse getResponse = prepareMLModel(true); mlModelGetRequest = MLModelGetRequest.builder().modelId("test_id").isUserInitiatedGetRequest(true).build(); doAnswer(invocation -> { @@ -161,14 +181,17 @@ public void testGetModelHidden_SuperUserPermissionError() throws IOException { listener.onResponse(getResponse); return null; }).when(client).get(any(), any()); + doReturn(false).when(getModelTransportAction).isSuperAdminUserWrapper(clusterService, client); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("User doesn't have privilege to perform this operation on this model", argumentCaptor.getValue().getMessage()); } - public void testGetModel_ValidateAccessFailed() throws IOException { + public void testGetModel_ValidateAccessFailed() throws IOException, InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); @@ -183,6 +206,7 @@ public void testGetModel_ValidateAccessFailed() throws IOException { }).when(client).get(any(), any()); getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); @@ -194,19 +218,23 @@ public void testGetModel_NullResponse() { listener.onResponse(null); return null; }).when(client).get(any(), any()); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model with the provided model id: test_id", argumentCaptor.getValue().getMessage()); } - public void testGetModel_IndexNotFoundException() { + public void testGetModel_IndexNotFoundException() throws InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new IndexNotFoundException("Fail to find model")); return null; }).when(client).get(any(), any()); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Fail to find model", argumentCaptor.getValue().getMessage()); @@ -218,10 +246,12 @@ public void testGetModel_RuntimeException() { listener.onFailure(new RuntimeException("errorMessage")); return null; }).when(client).get(any(), any()); + getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to get data object from index .plugins-ml-model", argumentCaptor.getValue().getMessage()); } public GetResponse prepareMLModel(boolean isHidden) throws IOException { @@ -235,7 +265,6 @@ public GetResponse prepareMLModel(boolean isHidden) throws IOException { XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); - return getResponse; + return new GetResponse(getResult); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java index 5b2eae3f4a..9bbf98e139 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java @@ -49,7 +49,14 @@ public void setUp() throws Exception { } private void registerModelGroup() throws InterruptedException { - MLRegisterModelGroupInput input = new MLRegisterModelGroupInput("mock_model_group_name", "mock model group desc", null, null, null); + MLRegisterModelGroupInput input = new MLRegisterModelGroupInput( + "mock_model_group_name", + "mock model group desc", + null, + null, + null, + null + ); MLRegisterModelGroupRequest createModelGroupRequest = new MLRegisterModelGroupRequest(input); MLRegisterModelGroupResponse response = client().execute(MLRegisterModelGroupAction.INSTANCE, createModelGroupRequest).actionGet(); this.modelGroupId = response.getModelGroupId(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 942a968cf0..fa9cad2dd3 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.action.models; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; @@ -15,6 +16,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -25,6 +27,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import org.junit.Before; @@ -39,8 +42,10 @@ import org.opensearch.Version; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; @@ -58,6 +63,7 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.get.GetResult; @@ -80,6 +86,9 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -88,11 +97,13 @@ import com.google.common.collect.ImmutableList; public class UpdateModelTransportActionTests extends OpenSearchTestCase { + @Mock ThreadPool threadPool; @Mock Client client; + private SdkClient sdkClient; @Mock Task task; @@ -122,6 +133,8 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { MLModelGroupManager mlModelGroupManager; @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private ModelAccessControlHelper modelAccessControlHelper; @Mock @@ -162,11 +175,20 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { @Mock MLEngine mlEngine; + @Mock + NamedXContentRegistry xContentRegistry; + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of("^https://api\\.test\\.com/.*$"); @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + settings = Settings + .builder() + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .put(ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED.getKey(), true) + .build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); updateLocalModelInput = MLUpdateModelInput .builder() .modelId("test_model_id") @@ -185,12 +207,11 @@ public void setup() throws IOException { .modelState(MLModelState.REGISTERED) .build(); - settings = Settings - .builder() - .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) - .build(); - - ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, + ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED + ); InetAddress inetAddress1 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 }); InetAddress inetAddress2 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 2 }); @@ -229,18 +250,22 @@ public void setup() throws IOException { shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + modelAccessControlHelper = spy(new ModelAccessControlHelper(clusterService, settings)); + transportUpdateModelAction = spy( new UpdateModelTransportAction( transportService, actionFilters, client, + sdkClient, connectorAccessControlHelper, modelAccessControlHelper, mlModelManager, mlModelGroupManager, settings, clusterService, - mlEngine + mlEngine, + mlFeatureEnabledSetting ) ); @@ -266,12 +291,30 @@ public void setup() throws IOException { ) .build(); + // TODO eventually remove if migrated to sdkClient doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(6); + listener.onResponse(true); + return null; + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess( + any(), + any(), + any(), + eq("test_model_group_id"), + any(), + any(SdkClient.class), + isA(ActionListener.class) + ); + + // TODO eventually remove if migrated to sdkClient doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(true); @@ -281,12 +324,26 @@ public void setup() throws IOException { .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; }) - .when(connectorAccessControlHelper) - .validateConnectorAccess(any(Client.class), eq("updated_test_connector_id"), isA(ActionListener.class)); + .when(modelAccessControlHelper) + .validateModelGroupAccess( + any(), + any(), + any(), + eq("updated_test_model_group_id"), + any(), + any(SdkClient.class), + isA(ActionListener.class) + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(5); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -294,11 +351,22 @@ public void setup() throws IOException { return null; }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(updateResponse); + when(client.update(any(UpdateRequest.class))).thenReturn(future); + + // TODO eventually remove if migrated to sdkClient doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(localModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(4); + listener.onResponse(localModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); MLModelGroup modelGroup = MLModelGroup .builder() @@ -314,92 +382,142 @@ public void setup() throws IOException { GetResponse getResponse = prepareGetResponse(modelGroup); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(getResponse); return null; - }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq(sdkClient), eq("updated_test_model_group_id"), isA(ActionListener.class)); } @Test - public void testUpdateLocalModelSuccess() { - transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + public void testUpdateLocalModelSuccess() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelWithoutRegisterToNewModelGroupSuccess() { + public void testUpdateModelWithoutRegisterToNewModelGroupSuccess() throws InterruptedException { updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelWithRegisterToSameModelGroupSuccess() { + public void testUpdateModelWithRegisterToSameModelGroupSuccess() throws InterruptedException { updateLocalModelRequest.getUpdateModelInput().setModelGroupId("test_model_group_id"); - transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateRemoteModelWithLocalInformationSuccess() { + public void testUpdateRemoteModelWithLocalInformationSuccess() throws InterruptedException { MLModel remoteModel = prepareMLModel("REMOTE_EXTERNAL"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateExternalRemoteModelWithExternalRemoteInformationSuccess() { + public void testUpdateExternalRemoteModelWithExternalRemoteInformationSuccess() throws InterruptedException { MLModel remoteModel = prepareMLModel("REMOTE_EXTERNAL"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateInternalRemoteModelWithInternalRemoteInformationSuccess() { + public void testUpdateInternalRemoteModelWithInternalRemoteInformationSuccess() throws InterruptedException { MLModel remoteModel = prepareMLModel("REMOTE_INTERNAL"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); - transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); - verify(actionListener).onResponse(updateResponse); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateHiddenRemoteModelWithRemoteInformationSuccess() { + public void testUpdateHiddenRemoteModelWithRemoteInformationSuccess() throws InterruptedException { MLModel remoteModel = prepareMLModel("REMOTE_INTERNAL", true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doReturn(true).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); - transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test public void testUpdateHiddenRemoteModelPermissionError() { MLModel remoteModel = prepareMLModel("REMOTE_INTERNAL", true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doReturn(false).when(transportUpdateModelAction).isSuperAdminUserWrapper(clusterService, client); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_INTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -411,10 +529,10 @@ public void testUpdateHiddenRemoteModelPermissionError() { public void testUpdateRemoteModelWithNoExternalConnectorFound() { MLModel remoteModelWithInternalConnector = prepareUnsupportedMLModel(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModelWithInternalConnector); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -429,16 +547,16 @@ public void testUpdateRemoteModelWithNoExternalConnectorFound() { public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlNoPermission() { MLModel remoteModel = prepareMLModel("REMOTE_EXTERNAL"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(false); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -453,19 +571,19 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControlOtherException() { MLModel remoteModel = prepareMLModel("REMOTE_EXTERNAL"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(remoteModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener .onFailure( new RuntimeException("Any other connector access control Exception occurred. Please check log for more details.") ); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -477,14 +595,20 @@ public void testUpdateRemoteModelWithRemoteInformationWithConnectorAccessControl } @Test - public void testUpdateModelWithModelAccessControlNoPermission() { + public void testUpdateModelWithModelAccessControlNoPermission() throws InterruptedException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), any(), any(), any(), any(), any(SdkClient.class), isA(ActionListener.class)); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -561,10 +685,10 @@ public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherExc @Test public void testUpdateModelWithRegisterToNewModelGroupNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new MLResourceNotFoundException("Model group not found with MODEL_GROUP_ID: updated_test_model_group_id")); return null; - }).when(mlModelGroupManager).getModelGroupResponse(eq("updated_test_model_group_id"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq(sdkClient), eq("updated_test_model_group_id"), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -578,10 +702,10 @@ public void testUpdateModelWithRegisterToNewModelGroupNotFound() { @Test public void testUpdateModelWithModelNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(null); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -592,10 +716,10 @@ public void testUpdateModelWithModelNotFound() { @Test public void testUpdateModelWithFunctionNameFieldNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mlModelWithNullFunctionName); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -622,10 +746,10 @@ public void testUpdateLocalModelWithInternalRemoteInformation() { public void testUpdateLocalModelWithUnsupportedFunction() { MLModel localModelWithUnsupportedFunction = prepareUnsupportedMLModel(FunctionName.KMEANS); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(localModelWithUnsupportedFunction); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, prepareRemoteRequest("REMOTE_EXTERNAL"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -634,37 +758,42 @@ public void testUpdateLocalModelWithUnsupportedFunction() { } @Test - public void testUpdateRequestDocIOException() throws IOException { + public void testUpdateRequestDocIOException() throws IOException, InterruptedException { doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); doReturn("mockId").when(mockUpdateModelInput).getModelId(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); doReturn(MLModelState.REGISTERED).when(mockModel).getModelState(); doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Exception occurred during building update request.", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to parse data object to update in index .plugins-ml-model", argumentCaptor.getValue().getMessage()); } @Test - public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IOException { + public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IOException, InterruptedException { doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); doReturn("mockId").when(mockUpdateModelInput).getModelId(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), any(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -673,10 +802,12 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO doReturn("mockUpdateModelGroupId").when(mockUpdateModelInput).getModelGroupId(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(46); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("mockUpdateModelGroupId"), any(), isA(ActionListener.class)); + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), any(), any(), eq("mockUpdateModelGroupId"), any(), eq(sdkClient), isA(ActionListener.class)); MLModelGroup modelGroup = MLModelGroup .builder() @@ -692,20 +823,25 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO GetResponse getResponse = prepareGetResponse(modelGroup); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(getResponse); return null; - }).when(mlModelGroupManager).getModelGroupResponse(eq("mockUpdateModelGroupId"), isA(ActionListener.class)); + }).when(mlModelGroupManager).getModelGroupResponse(eq(sdkClient), eq("mockUpdateModelGroupId"), isA(ActionListener.class)); doThrow(new IOException("Exception occurred during building update request.")).when(mockUpdateModelInput).toXContent(any(), any()); - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IOException.class); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Exception occurred during building update request.", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to parse data object to update in index .plugins-ml-model", argumentCaptor.getValue().getMessage()); } @Test - public void testGetUpdateResponseListenerWithVersionBumpWrongStatus() { + public void testGetUpdateResponseListenerWithVersionBumpWrongStatus() throws InterruptedException { UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -714,11 +850,15 @@ public void testGetUpdateResponseListenerWithVersionBumpWrongStatus() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); - verify(actionListener).onResponse(updateWrongResponse); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateWrongResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateWrongResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testGetUpdateResponseListenerWithVersionBumpOtherException() { + public void testGetUpdateResponseListenerWithVersionBumpOtherException() throws InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener @@ -731,16 +871,14 @@ public void testGetUpdateResponseListenerWithVersionBumpOtherException() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", - argumentCaptor.getValue().getMessage() - ); + assertEquals("Failed to update data object in index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); } @Test - public void testGetUpdateResponseListenerWithNullUpdateResponse() { + public void testGetUpdateResponseListenerWithNullUpdateResponse() throws InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(null); @@ -748,13 +886,14 @@ public void testGetUpdateResponseListenerWithNullUpdateResponse() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to update ML model: test_model_id", argumentCaptor.getValue().getMessage()); } @Test - public void testGetUpdateResponseListenerWrongStatus() { + public void testGetUpdateResponseListenerWrongStatus() throws InterruptedException { UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -763,11 +902,15 @@ public void testGetUpdateResponseListenerWrongStatus() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); - verify(actionListener).onResponse(updateWrongResponse); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateWrongResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateWrongResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testGetUpdateResponseListenerOtherException() { + public void testGetUpdateResponseListenerOtherException() throws InterruptedException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener @@ -780,22 +923,20 @@ public void testGetUpdateResponseListenerOtherException() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); updateLocalModelRequest.getUpdateModelInput().setModelGroupId(null); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", - argumentCaptor.getValue().getMessage() - ); + assertEquals("Failed to update data object in index .plugins-ml-model", argumentCaptor.getValue().getMessage()); } @Test public void testUpdateModelStateDeployingException() { MLModel testDeployingModel = prepareMLModel("TEXT_EMBEDDING", MLModelState.DEPLOYING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testDeployingModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); @@ -810,10 +951,10 @@ public void testUpdateModelStateDeployingException() { public void testUpdateModelStateLoadingException() { MLModel testDeployingModel = prepareMLModel("TEXT_EMBEDDING", MLModelState.LOADING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testDeployingModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); @@ -825,13 +966,13 @@ public void testUpdateModelStateLoadingException() { } @Test - public void testUpdateModelCacheModelStateDeployedSuccess() { + public void testUpdateModelCacheModelStateDeployedSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -841,18 +982,26 @@ public void testUpdateModelCacheModelStateDeployedSuccess() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelCacheModelWithIsModelEnabledSuccess() { + public void testUpdateModelCacheModelWithIsModelEnabledSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -864,18 +1013,26 @@ public void testUpdateModelCacheModelWithIsModelEnabledSuccess() { testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); testUpdateModelCacheRequest.getUpdateModelInput().setConnector(null); testUpdateModelCacheRequest.getUpdateModelInput().setIsEnabled(true); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelCacheModelWithoutUpdateConnectorWithRateLimiterSuccess() { + public void testUpdateModelCacheModelWithoutUpdateConnectorWithRateLimiterSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -888,18 +1045,26 @@ public void testUpdateModelCacheModelWithoutUpdateConnectorWithRateLimiterSucces testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); testUpdateModelCacheRequest.getUpdateModelInput().setRateLimiter(rateLimiter); testUpdateModelCacheRequest.getUpdateModelInput().setConnector(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelCacheModelWithRateLimiterSuccess() { + public void testUpdateModelCacheModelWithRateLimiterSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -911,36 +1076,52 @@ public void testUpdateModelCacheModelWithRateLimiterSuccess() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); testUpdateModelCacheRequest.getUpdateModelInput().setRateLimiter(rateLimiter); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelWithPartialRateLimiterSuccess() { + public void testUpdateModelWithPartialRateLimiterSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").build(); MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); testUpdateModelCacheRequest.getUpdateModelInput().setRateLimiter(rateLimiter); testUpdateModelCacheRequest.getUpdateModelInput().setConnector(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelCacheModelWithPartialRateLimiterSuccess() { + public void testUpdateModelCacheModelWithPartialRateLimiterSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -952,8 +1133,16 @@ public void testUpdateModelCacheModelWithPartialRateLimiterSuccess() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); testUpdateModelCacheRequest.getUpdateModelInput().setRateLimiter(rateLimiter); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test @@ -984,17 +1173,17 @@ public void testUpdateModelCacheUpdateResponseListenerWithNullUpdateResponse() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to update ML model: test_model_id", argumentCaptor.getValue().getMessage()); + assertEquals("Trying to update the connector or connector_id field on a local model.", argumentCaptor.getValue().getMessage()); } @Test - public void testUpdateModelCacheModelWithUndeploySuccessEmptyFailures() { + public void testUpdateModelCacheModelWithUndeploySuccessEmptyFailures() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1005,21 +1194,29 @@ public void testUpdateModelCacheModelWithUndeploySuccessEmptyFailures() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateControllerWithUndeploySuccessPartiallyFailures() { + public void testUpdateControllerWithUndeploySuccessPartiallyFailures() throws InterruptedException { List failures = List .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1030,7 +1227,11 @@ public void testUpdateControllerWithUndeploySuccessPartiallyFailures() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -1041,13 +1242,13 @@ public void testUpdateControllerWithUndeploySuccessPartiallyFailures() { } @Test - public void testUpdateControllerWithUndeployNullResponse() { + public void testUpdateControllerWithUndeployNullResponse() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1057,7 +1258,12 @@ public void testUpdateControllerWithUndeployNullResponse() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -1067,13 +1273,13 @@ public void testUpdateControllerWithUndeployNullResponse() { } @Test - public void testUpdateControllerWithUndeployOtherException() { + public void testUpdateControllerWithUndeployOtherException() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -1083,7 +1289,11 @@ public void testUpdateControllerWithUndeployOtherException() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -1091,8 +1301,11 @@ public void testUpdateControllerWithUndeployOtherException() { } @Test - public void testUpdateModelCacheModelStateDeployedWrongStatus() { + public void testUpdateModelCacheModelStateDeployedWrongStatus() throws InterruptedException { + // Prepare a deployed MLModel with REMOTE_INTERNAL configuration MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + + // Mock the update response with a CREATED status (simulating the wrong status) UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -1100,32 +1313,45 @@ public void testUpdateModelCacheModelStateDeployedWrongStatus() { return null; }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + // Mock fetching the model successfully doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); + // Mock cache update response doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); listener.onResponse(updateModelCacheNodesResponse); return null; }).when(client).execute(any(), any(), isA(ActionListener.class)); + // Prepare the update request MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + + // Execute the transport action transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateWrongResponse); + + // Verify that onFailure was NOT invoked and onResponse was called instead + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + + // Validate the captured response + UpdateResponse capturedResponse = argumentCaptor.getValue(); + assertEquals(updateWrongResponse.getId(), capturedResponse.getId()); + assertEquals(updateWrongResponse.getResult(), capturedResponse.getResult()); } @Test - public void testUpdateModelCacheModelStateDeployedUpdateModelCacheException() { + public void testUpdateModelCacheModelStateDeployedUpdateModelCacheException() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1140,44 +1366,12 @@ public void testUpdateModelCacheModelStateDeployedUpdateModelCacheException() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", - argumentCaptor.getValue().getMessage() - ); - } - - @Test - public void testUpdateModelCacheModelStateDeployedUpdateException() { - MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener - .onFailure( - new RuntimeException( - "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details." - ) - ); - return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(testUpdateModelCacheModel); - return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(updateModelCacheNodesResponse); - return null; - }).when(client).execute(any(), any(), isA(ActionListener.class)); - - MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); - testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -1187,13 +1381,13 @@ public void testUpdateModelCacheModelStateDeployedUpdateException() { } @Test - public void testUpdateModelCacheModelRegisterToNewModelGroupSuccess() { + public void testUpdateModelCacheModelRegisterToNewModelGroupSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1202,12 +1396,20 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupSuccess() { }).when(client).execute(any(), any(), isA(ActionListener.class)); MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelCacheModelRegisterToNewModelGroupWrongStatus() { + public void testUpdateModelCacheModelRegisterToNewModelGroupWrongStatus() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { @@ -1217,10 +1419,10 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupWrongStatus() { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1229,18 +1431,23 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupWrongStatus() { }).when(client).execute(any(), any(), isA(ActionListener.class)); MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateWrongResponse); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateWrongResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateWrongResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateModelCacheException() { + public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateModelCacheException() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1254,7 +1461,12 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateModelCacheExce }).when(client).execute(any(), any(), isA(ActionListener.class)); MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -1264,7 +1476,7 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateModelCacheExce } @Test - public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateException() { + public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateException() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -1291,22 +1503,20 @@ public void testUpdateModelCacheModelRegisterToNewModelGroupUpdateException() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "Any other Exception occurred during running getUpdateResponseListener. Please check log for more details.", - argumentCaptor.getValue().getMessage() - ); + assertEquals("Trying to update the connector or connector_id field on a local model.", argumentCaptor.getValue().getMessage()); } @Test - public void testUpdateModelCacheModelStateLoadedSuccess() { + public void testUpdateModelCacheModelStateLoadedSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.LOADED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1316,18 +1526,26 @@ public void testUpdateModelCacheModelStateLoadedSuccess() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelCacheModelStatePartiallyDeployedSuccess() { + public void testUpdateModelCacheModelStatePartiallyDeployedSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.PARTIALLY_DEPLOYED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1337,18 +1555,26 @@ public void testUpdateModelCacheModelStatePartiallyDeployedSuccess() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } @Test - public void testUpdateModelCacheModelStatePartiallyLoadedSuccess() { + public void testUpdateModelCacheModelStatePartiallyLoadedSuccess() throws InterruptedException { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.PARTIALLY_LOADED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(testUpdateModelCacheModel); return null; - }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -1358,8 +1584,16 @@ public void testUpdateModelCacheModelStatePartiallyLoadedSuccess() { MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); - transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertEquals(updateResponse.getId(), argumentCaptor.getValue().getId()); + assertEquals(updateResponse.getResult(), argumentCaptor.getValue().getResult()); } // TODO: Add UT to make sure that version incremented successfully. @@ -1382,6 +1616,7 @@ private MLModel prepareMLModel(String functionName, MLModelState modelState, boo mlModel = MLModel .builder() .name("test_name") + .tenantId("_tenant_id") .modelId("test_model_id") .modelGroupId("test_model_group_id") .description("test_description") @@ -1499,10 +1734,10 @@ public void testUpdateModelStatePartiallyLoadedException() { doReturn("mockId").when(mockUpdateModelInput).getModelId(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), anyString(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); @@ -1524,10 +1759,10 @@ public void testUpdateModelStatePartiallyDeployedException() { doReturn("mockId").when(mockUpdateModelInput).getModelId(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(4); listener.onResponse(mockModel); return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); + }).when(mlModelManager).getModel(eq("mockId"), anyString(), any(), any(), isA(ActionListener.class)); doReturn("test_model_group_id").when(mockModel).getModelGroupId(); doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index d936f199a2..6a9856ae44 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -21,12 +21,14 @@ import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -43,6 +45,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTask; @@ -68,6 +71,8 @@ import org.opensearch.ml.task.MLTaskDispatcher; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.tasks.Task; @@ -78,6 +83,7 @@ import com.google.common.collect.ImmutableList; public class TransportRegisterModelActionTests extends OpenSearchTestCase { + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -110,6 +116,8 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Mock private Client client; + private SdkClient sdkClient; + @Mock private DiscoveryNodeHelper nodeFilter; @@ -158,6 +166,7 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); settings = Settings .builder() .put(ML_COMMONS_TRUSTED_URL_REGEX.getKey(), trustedUrlRegex) @@ -184,6 +193,7 @@ public void setup() throws IOException { settings, threadPool, client, + sdkClient, nodeFilter, mlTaskDispatcher, mlStats, @@ -195,10 +205,10 @@ public void setup() throws IOException { assertNotNull(transportRegisterModelAction); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); MLStat mlStat = mock(MLStat.class); when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat); @@ -217,10 +227,10 @@ public void setup() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(0); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(true); @@ -228,12 +238,13 @@ public void setup() throws IOException { when(node2.getId()).thenReturn("node2Id"); doAnswer(invocation -> { return null; }).when(mlModelManager).registerMLModel(any(), any()); - doAnswer(invocation -> { return null; }).when(mlModelManager).registerMLRemoteModel(any(), any(), any()); + doAnswer(invocation -> { return null; }).when(mlModelManager).registerMLRemoteModel(any(), any(), any(), any()); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); } + @Test public void testDoExecute_LocalModelDisabledException() { when(mlFeatureEnabledSetting.isLocalModelEnabled()).thenReturn(false); @@ -270,12 +281,13 @@ public void testDoExecute_LocalModelDisabledException() { ); } + @Test public void testDoExecute_userHasNoAccessException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -283,6 +295,7 @@ public void testDoExecute_userHasNoAccessException() { assertEquals("You don't have permissions to perform this operation on this model.", argumentCaptor.getValue().getMessage()); } + @Test public void testDoExecute_successWithLocalNodeEqualToClusterNode() { when(node1.getId()).thenReturn("NodeId1"); when(node2.getId()).thenReturn("NodeId1"); @@ -298,6 +311,7 @@ public void testDoExecute_successWithLocalNodeEqualToClusterNode() { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void testDoExecute_successWithCreateModelGroup() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -320,6 +334,7 @@ public void testDoExecute_successWithCreateModelGroup() { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void testDoExecute_failureWithCreateModelGroup() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -334,6 +349,7 @@ public void testDoExecute_failureWithCreateModelGroup() { assertEquals("Failed to create Model Group", argumentCaptor.getValue().getMessage()); } + @Test public void testDoExecute_invalidURL() { transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -341,6 +357,7 @@ public void testDoExecute_invalidURL() { assertEquals("URL can't match trusted url regex", argumentCaptor.getValue().getMessage()); } + @Test public void testRegisterModelUrlNotAllowed() throws Exception { Settings settings = Settings .builder() @@ -367,6 +384,7 @@ public void testRegisterModelUrlNotAllowed() throws Exception { settings, threadPool, client, + sdkClient, nodeFilter, mlTaskDispatcher, mlStats, @@ -386,6 +404,7 @@ public void testRegisterModelUrlNotAllowed() throws Exception { ); } + @Test public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() { when(node1.getId()).thenReturn("NodeId1"); when(node2.getId()).thenReturn("NodeId2"); @@ -401,6 +420,7 @@ public void testDoExecute_successWithLocalNodeNotEqualToClusterNode() { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void testDoExecute_FailToSendForwardRequest() { when(node1.getId()).thenReturn("NodeId1"); when(node2.getId()).thenReturn("NodeId2"); @@ -411,6 +431,7 @@ public void testDoExecute_FailToSendForwardRequest() { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void testTransportRegisterModelActionDoExecuteWithDispatchException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -424,12 +445,13 @@ public void testTransportRegisterModelActionDoExecuteWithDispatchException() { verify(actionListener).onFailure(argumentCaptor.capture()); } + @Test public void test_ValidationFailedException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", "modelGroupID"), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -437,6 +459,7 @@ public void test_ValidationFailedException() { assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } + @Test public void testTransportRegisterModelActionDoExecuteWithCreateTaskException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -450,6 +473,7 @@ public void testTransportRegisterModelActionDoExecuteWithCreateTaskException() { verify(actionListener).onFailure(argumentCaptor.capture()); } + @Test public void test_execute_registerRemoteModel_withConnectorId_success() { MLRegisterModelRequest request = mock(MLRegisterModelRequest.class); MLRegisterModelInput input = mock(MLRegisterModelInput.class); @@ -459,16 +483,15 @@ public void test_execute_registerRemoteModel_withConnectorId_success() { when(input.getConnectorId()).thenReturn("mockConnectorId"); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(true); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), anyString(), isA(ActionListener.class)); - MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); transportRegisterModelAction.doExecute(task, request, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); - verify(mlModelManager).registerMLRemoteModel(eq(input), isA(MLTask.class), eq(actionListener)); + verify(mlModelManager).registerMLRemoteModel(eq(sdkClient), eq(input), isA(MLTask.class), eq(actionListener)); } + @Test public void test_execute_registerRemoteModel_withConnectorId_noPermissionToConnectorId() { MLRegisterModelRequest request = mock(MLRegisterModelRequest.class); MLRegisterModelInput input = mock(MLRegisterModelInput.class); @@ -476,10 +499,10 @@ public void test_execute_registerRemoteModel_withConnectorId_noPermissionToConne when(input.getConnectorId()).thenReturn("mockConnectorId"); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onResponse(false); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), anyString(), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); @@ -495,6 +518,7 @@ public void test_execute_registerRemoteModel_withConnectorId_noPermissionToConne ); } + @Test public void test_execute_registerRemoteModel_withConnectorId_connectorValidationException() { MLRegisterModelRequest request = mock(MLRegisterModelRequest.class); MLRegisterModelInput input = mock(MLRegisterModelInput.class); @@ -502,16 +526,17 @@ public void test_execute_registerRemoteModel_withConnectorId_connectorValidation when(input.getConnectorId()).thenReturn("mockConnectorId"); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(5); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(), anyString(), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any(), any(), any(), isA(ActionListener.class)); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } + @Test public void test_execute_registerRemoteModel_withInternalConnector_success() { MLRegisterModelRequest request = mock(MLRegisterModelRequest.class); MLRegisterModelInput input = mock(MLRegisterModelInput.class); @@ -532,9 +557,10 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() { MLRegisterModelResponse response = mock(MLRegisterModelResponse.class); transportRegisterModelAction.doExecute(task, request, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); - verify(mlModelManager).registerMLRemoteModel(eq(input), isA(MLTask.class), eq(actionListener)); + verify(mlModelManager).registerMLRemoteModel(eq(sdkClient), eq(input), isA(MLTask.class), eq(actionListener)); } + @Test public void test_execute_registerRemoteModel_withInternalConnector_connectorIsNull() { MLRegisterModelRequest request = mock(MLRegisterModelRequest.class); MLRegisterModelInput input = mock(MLRegisterModelInput.class); @@ -550,6 +576,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_connectorIsNu ); } + @Test public void test_execute_registerRemoteModel_withInternalConnector_predictEndpointIsNull() { MLRegisterModelRequest request = mock(MLRegisterModelRequest.class); MLRegisterModelInput input = mock(MLRegisterModelInput.class); @@ -567,6 +594,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_predictEndpoi ); } + @Test public void test_ModelNameAlreadyExists() throws IOException { when(node1.getId()).thenReturn("NodeId1"); when(node2.getId()).thenReturn("NodeId2"); @@ -578,29 +606,29 @@ public void test_ModelNameAlreadyExists() throws IOException { }).when(transportService).sendRequest(any(), any(), any(), any()); SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", null), actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); - + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); MLRegisterModelInput registerModelInput = MLRegisterModelInput .builder() @@ -621,10 +649,10 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException public void test_FailureWhenSearchingModelGroupName() throws IOException { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Runtime exception")); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); @@ -637,16 +665,16 @@ public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); + ActionListener listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java index f7eb64c8eb..92e32a45ac 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java @@ -15,6 +15,7 @@ import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -103,19 +104,21 @@ public void setup() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(0); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); } + @Test public void testTransportRegisterModelMetaActionConstructor() { assertNotNull(action); } + @Test public void testTransportRegisterModelMetaActionDoExecute() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); @@ -125,6 +128,7 @@ public void testTransportRegisterModelMetaActionDoExecute() { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void testDoExecute_successWithCreateModelGroup() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -138,6 +142,7 @@ public void testDoExecute_successWithCreateModelGroup() { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void testDoExecute_failureWithCreateModelGroup() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -152,6 +157,7 @@ public void testDoExecute_failureWithCreateModelGroup() { assertEquals("Failed to create Model Group", argumentCaptor.getValue().getMessage()); } + @Test public void testDoExecute_userHasNoAccessException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -168,6 +174,7 @@ public void testDoExecute_userHasNoAccessException() { assertEquals("You don't have permissions to perform this operation on this model.", argumentCaptor.getValue().getMessage()); } + @Test public void test_ValidationFailedException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -184,14 +191,15 @@ public void test_ValidationFailedException() { assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } + @Test public void testDoExecute_ModelNameAlreadyExists() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); MLRegisterModelMetaRequest actionRequest = prepareRequest(null); action.doExecute(task, actionRequest, actionListener); @@ -199,6 +207,7 @@ public void testDoExecute_ModelNameAlreadyExists() throws IOException { verify(actionListener).onResponse(argumentCaptor.capture()); } + @Test public void testDoExecute_NoAccessWhenModelNameAlreadyExists() throws IOException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -208,10 +217,10 @@ public void testDoExecute_NoAccessWhenModelNameAlreadyExists() throws IOExceptio SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(searchResponse); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); MLRegisterModelMetaRequest actionRequest = prepareRequest(null); action.doExecute(task, actionRequest, actionListener); @@ -223,12 +232,13 @@ public void testDoExecute_NoAccessWhenModelNameAlreadyExists() throws IOExceptio ); } - public void test_FailureWhenSearchingModelGroupName() throws IOException { + @Test + public void test_FailureWhenSearchingModelGroupName() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Runtime exception")); return null; - }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any(), any()); MLRegisterModelMetaRequest actionRequest = prepareRequest(null); action.doExecute(task, actionRequest, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 06f86dadc9..00c50eafd8 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -12,8 +12,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; -import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -22,26 +20,18 @@ import java.util.Collections; import java.util.List; import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import org.junit.AfterClass; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; -import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.ConfigConstants; @@ -67,8 +57,6 @@ import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.threadpool.ScalingExecutorBuilder; -import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableList; @@ -106,17 +94,6 @@ public class ConnectorAccessControlHelperTests extends OpenSearchTestCase { @Mock MLFeatureEnabledSetting mlFeatureEnabledSetting; - private static TestThreadPool testThreadPool = new TestThreadPool( - ConnectorAccessControlHelperTests.class.getName(), - new ScalingExecutorBuilder( - GENERAL_THREAD_POOL, - 1, - Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), - TimeValue.timeValueMinutes(1), - ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL - ) - ); - @Before public void setup() { MockitoAnnotations.openMocks(this); @@ -138,12 +115,6 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(threadPool.executor(any())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); - } - - @AfterClass - public static void cleanup() { - ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } @Test @@ -442,12 +413,12 @@ public void testGetConnectorHappyCase() throws IOException, InterruptedException GetDataObjectRequest getRequest = GetDataObjectRequest.builder().index(CommonValue.ML_CONNECTOR_INDEX).id("connectorId").build(); GetResponse getResponse = prepareConnector(); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onResponse(getResponse); - when(client.get(any(GetRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); connectorAccessControlHelper .getConnector( sdkClient, @@ -455,25 +426,27 @@ public void testGetConnectorHappyCase() throws IOException, InterruptedException client.threadPool().getThreadContext().newStoredContext(true), getRequest, "connectorId", - latchedActionListener + getConnectorActionListener ); - latch.await(500, TimeUnit.MILLISECONDS); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Connector.class); + verify(getConnectorActionListener).onResponse(argumentCaptor.capture()); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(GetRequest.class); - verify(client, times(1)).get(requestCaptor.capture()); - assertEquals(CommonValue.ML_CONNECTOR_INDEX, requestCaptor.getValue().index()); + // Verify the captured connector + Connector capturedConnector = argumentCaptor.getValue(); + assertNotNull(capturedConnector); + assertEquals("test_connector", capturedConnector.getName()); } @Test public void testGetConnectorException() throws IOException, InterruptedException { GetDataObjectRequest getRequest = GetDataObjectRequest.builder().index(CommonValue.ML_CONNECTOR_INDEX).id("connectorId").build(); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onFailure(new RuntimeException("Failed to get connector")); - when(client.get(any(GetRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Failed to get connector")); + return null; + }).when(client).get(any(), any()); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); connectorAccessControlHelper .getConnector( sdkClient, @@ -481,25 +454,24 @@ public void testGetConnectorException() throws IOException, InterruptedException client.threadPool().getThreadContext().newStoredContext(true), getRequest, "connectorId", - latchedActionListener + getConnectorActionListener ); - latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(getConnectorActionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to get connector", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to get data object from index .plugins-ml-connector", argumentCaptor.getValue().getMessage()); } @Test public void testGetConnectorIndexNotFound() throws IOException, InterruptedException { GetDataObjectRequest getRequest = GetDataObjectRequest.builder().index(CommonValue.ML_CONNECTOR_INDEX).id("connectorId").build(); - PlainActionFuture future = PlainActionFuture.newFuture(); - future.onFailure(new IndexNotFoundException("Index not found")); - when(client.get(any(GetRequest.class))).thenReturn(future); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new IndexNotFoundException("Index not found")); + return null; + }).when(client).get(any(), any()); - CountDownLatch latch = new CountDownLatch(1); - LatchedActionListener latchedActionListener = new LatchedActionListener<>(getConnectorActionListener, latch); connectorAccessControlHelper .getConnector( sdkClient, @@ -507,9 +479,8 @@ public void testGetConnectorIndexNotFound() throws IOException, InterruptedExcep client.threadPool().getThreadContext().newStoredContext(true), getRequest, "connectorId", - latchedActionListener + getConnectorActionListener ); - latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(getConnectorActionListener).onFailure(argumentCaptor.capture()); diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java index a1cbebfba4..ffdc21269e 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java @@ -6,30 +6,42 @@ package org.opensearch.ml.helper; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_THREAD_POOL_PREFIX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.LatchedActionListener; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.get.GetResult; @@ -39,18 +51,42 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.MLModelGroup.MLModelGroupBuilder; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ScalingExecutorBuilder; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; public class ModelAccessControlHelperTests extends OpenSearchTestCase { + private static final TestThreadPool testThreadPool = new TestThreadPool( + ModelAccessControlHelperTests.class.getName(), + new ScalingExecutorBuilder( + GENERAL_THREAD_POOL, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(Settings.EMPTY) - 1), + TimeValue.timeValueMinutes(1), + ML_THREAD_POOL_PREFIX + GENERAL_THREAD_POOL + ) + ); + @Mock ClusterService clusterService; + @Mock + MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock Client client; + SdkClient sdkClient; + + @Mock + NamedXContentRegistry xContentRegistry; + @Mock private ActionListener actionListener; @@ -66,13 +102,16 @@ public class ModelAccessControlHelperTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put(ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); threadContext = new ThreadContext(settings); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings); assertNotNull(modelAccessControlHelper); + // TODO Remove when all calls are migrated to SdkClient version doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(getResponse); @@ -81,6 +120,12 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(threadPool.executor(anyString())).thenReturn(testThreadPool.executor(GENERAL_THREAD_POOL)); + } + + @AfterClass + public static void cleanup() { + ThreadPool.terminate(testThreadPool, 500, TimeUnit.MILLISECONDS); } public void setupModelGroup(String owner, String access, List backendRoles) throws IOException { @@ -92,14 +137,23 @@ public void setupModelGroup(String owner, String access, List backendRol }).when(client).get(any(), any()); } - public void test_UndefinedModelGroupID() { + // TODO Remove when all calls are migrated to SdkClient version + public void test_UndefinedModelGroupID_NoSdkClient() { modelAccessControlHelper.validateModelGroupAccess(null, null, client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); } - public void test_UndefinedOwner() throws IOException { + public void test_UndefinedModelGroupID() { + modelAccessControlHelper.validateModelGroupAccess(null, mlFeatureEnabledSetting, null, null, client, sdkClient, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue()); + } + + // TODO Remove when all calls are migrated to SdkClient version + public void test_UndefinedOwner_NoSdkClient() throws IOException { getResponse = modelGroupBuilder(null, null, null); modelAccessControlHelper.validateModelGroupAccess(null, "testGroupID", client, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -107,7 +161,17 @@ public void test_UndefinedOwner() throws IOException { assertTrue(argumentCaptor.getValue()); } - public void test_ExceptionEmptyBackendRoles() throws IOException { + public void test_UndefinedOwner() throws IOException { + getResponse = modelGroupBuilder(null, null, null); + modelAccessControlHelper + .validateModelGroupAccess(null, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue()); + } + + // TODO Remove when all calls are migrated to SdkClient version + public void test_ExceptionEmptyBackendRoles_NoSdkClient() throws IOException { String owner = "owner|IT,HR|myTenant"; User user = User.parse("owner|IT,HR|myTenant"); getResponse = modelGroupBuilder(null, AccessMode.RESTRICTED.getValue(), owner); @@ -117,7 +181,28 @@ public void test_ExceptionEmptyBackendRoles() throws IOException { assertEquals("Backend roles shouldn't be null", argumentCaptor.getValue().getMessage()); } - public void test_MatchingBackendRoles() throws IOException { + public void test_ExceptionEmptyBackendRoles() throws IOException, InterruptedException { + String owner = "owner|IT,HR|myTenant"; + User user = User.parse("owner|IT,HR|myTenant"); + getResponse = modelGroupBuilder(null, AccessMode.RESTRICTED.getValue(), owner); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any())).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + modelAccessControlHelper + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Backend roles shouldn't be null", argumentCaptor.getValue().getMessage()); + } + + // TODO Remove when all calls are migrated to SdkClient version + public void test_MatchingBackendRoles_NoSdkClient() throws IOException { String owner = "owner|IT,HR|myTenant"; List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.RESTRICTED.getValue(), backendRoles); @@ -128,7 +213,29 @@ public void test_MatchingBackendRoles() throws IOException { assertTrue(argumentCaptor.getValue()); } - public void test_PublicModelGroup() throws IOException { + public void test_MatchingBackendRoles() throws IOException, InterruptedException { + String owner = "owner|IT,HR|myTenant"; + List backendRoles = Arrays.asList("IT", "HR"); + setupModelGroup(owner, AccessMode.RESTRICTED.getValue(), backendRoles); + User user = User.parse("owner|IT,HR|myTenant"); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any())).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + modelAccessControlHelper + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue()); + } + + // TODO Remove when all calls are migrated to SdkClient version + public void test_PublicModelGroup_NoSdkClient() throws IOException { String owner = "owner|IT,HR|myTenant"; List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PUBLIC.getValue(), backendRoles); @@ -139,7 +246,29 @@ public void test_PublicModelGroup() throws IOException { assertTrue(argumentCaptor.getValue()); } - public void test_PrivateModelGroupWithSameOwner() throws IOException { + public void test_PublicModelGroup() throws IOException, InterruptedException { + String owner = "owner|IT,HR|myTenant"; + List backendRoles = Arrays.asList("IT", "HR"); + setupModelGroup(owner, AccessMode.PUBLIC.getValue(), backendRoles); + User user = User.parse("owner|IT,HR|myTenant"); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any())).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + modelAccessControlHelper + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue()); + } + + // TODO Remove when all calls are migrated to SdkClient version + public void test_PrivateModelGroupWithSameOwner_NoSdkClient() throws IOException { String owner = "owner|IT,HR|myTenant"; List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); @@ -150,7 +279,29 @@ public void test_PrivateModelGroupWithSameOwner() throws IOException { assertTrue(argumentCaptor.getValue()); } - public void test_PrivateModelGroupWithDifferentOwner() throws IOException { + public void test_PrivateModelGroupWithSameOwner() throws IOException, InterruptedException { + String owner = "owner|IT,HR|myTenant"; + List backendRoles = Arrays.asList("IT", "HR"); + setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); + User user = User.parse("owner|IT,HR|myTenant"); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any())).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + modelAccessControlHelper + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertTrue(argumentCaptor.getValue()); + } + + // TODO Remove when all calls are migrated to SdkClient version + public void test_PrivateModelGroupWithDifferentOwner_NoSdkClient() throws IOException { String owner = "owner|IT,HR|myTenant"; List backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); @@ -161,6 +312,27 @@ public void test_PrivateModelGroupWithDifferentOwner() throws IOException { assertFalse(argumentCaptor.getValue()); } + public void test_PrivateModelGroupWithDifferentOwner() throws IOException, InterruptedException { + String owner = "owner|IT,HR|myTenant"; + List backendRoles = Arrays.asList("IT", "HR"); + setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); + User user = User.parse("user|IT,HR|myTenant"); + + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(getResponse); + when(client.get(any())).thenReturn(future); + + CountDownLatch latch = new CountDownLatch(1); + LatchedActionListener latchedActionListener = new LatchedActionListener<>(actionListener, latch); + modelAccessControlHelper + .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + latch.await(500, TimeUnit.MILLISECONDS); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + assertFalse(argumentCaptor.getValue()); + } + public void test_SkipModelAccessControl() { User admin = User.parse("owner|IT,HR|all_access"); User user = User.parse("owner|IT,HR|myTenant"); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index f284f54a5e..8ebf6456a3 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -8,12 +8,13 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.apache.lucene.search.TotalHits; @@ -22,11 +23,14 @@ import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -36,6 +40,8 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.get.GetResult; @@ -44,13 +50,18 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; public class MLModelGroupManagerTests extends OpenSearchTestCase { + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -66,6 +77,8 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { @Mock private Client client; + SdkClient sdkClient; + @Mock private ActionListener actionListener; @@ -83,17 +96,28 @@ public class MLModelGroupManagerTests extends OpenSearchTestCase { @Mock private MLModelGroupManager mlModelGroupManager; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private final List backendRoles = Arrays.asList("IT", "HR"); @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); + sdkClient = Mockito.spy(SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap())); threadContext = new ThreadContext(settings); - mlModelGroupManager = new MLModelGroupManager(mlIndicesHandler, client, clusterService, modelAccessControlHelper); + mlModelGroupManager = new MLModelGroupManager( + mlIndicesHandler, + client, + sdkClient, + clusterService, + modelAccessControlHelper, + mlFeatureEnabledSetting + ); assertNotNull(mlModelGroupManager); - - when(indexResponse.getId()).thenReturn("modelGroupID"); + indexResponse = new IndexResponse(new ShardId(ML_MODEL_GROUP_INDEX, "_na_", 0), "model_group_ID", 1, 0, 2, true); + // when(indexResponse.getId()).thenReturn("modelGroupID"); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -107,7 +131,7 @@ public void setup() throws IOException { return null; }).when(mlIndicesHandler).initModelGroupIndexIfAbsent(any()); - SearchResponse searchResponse = createModelGroupSearchResponse(0); + SearchResponse searchResponse = getEmptySearchResponse(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); @@ -129,7 +153,7 @@ public void test_SuccessAddAllBackendRolesTrue() { } public void test_ModelGroupNameNotUnique() throws IOException {// - SearchResponse searchResponse = createModelGroupSearchResponse(1); + SearchResponse searchResponse = getNonEmptySearchResponse(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); @@ -316,7 +340,7 @@ public void test_ExceptionFailedToIndexModelGroup() { mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Index Not Found", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to put data object in index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); } public void test_ExceptionInitModelGroupIndexIfAbsent() { @@ -353,7 +377,7 @@ public void test_SuccessGetModelGroup() throws IOException { return null; }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + mlModelGroupManager.getModelGroupResponse(sdkClient, "testModelGroupID", modelGroupListener); verify(modelGroupListener).onResponse(getResponse); } @@ -367,13 +391,10 @@ public void test_OtherExceptionGetModelGroup() throws IOException { return null; }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + mlModelGroupManager.getModelGroupResponse(sdkClient, "testModelGroupID", modelGroupListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(modelGroupListener).onFailure(argumentCaptor.capture()); - assertEquals( - "Any other Exception occurred during getting the model group. Please check log for more details.", - argumentCaptor.getValue().getMessage() - ); + assertEquals("Failed to get data object from index .plugins-ml-model-group", argumentCaptor.getValue().getMessage()); } public void test_NotFoundGetModelGroup() throws IOException { @@ -383,29 +404,12 @@ public void test_NotFoundGetModelGroup() throws IOException { return null; }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); - mlModelGroupManager.getModelGroupResponse("testModelGroupID", modelGroupListener); + mlModelGroupManager.getModelGroupResponse(sdkClient, "testModelGroupID", modelGroupListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(modelGroupListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model group with ID: testModelGroupID", argumentCaptor.getValue().getMessage()); } - public void test_NoResponseoInitModelGroup() throws IOException { - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(0); - actionListener.onResponse(false); - return null; - }).when(mlIndicesHandler).initModelGroupIndexIfAbsent(any()); - - when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); - - MLRegisterModelGroupInput mlRegisterModelGroupInput = prepareRequest(null, null, null); - mlModelGroupManager.createModelGroup(mlRegisterModelGroupInput, actionListener); - - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("No response to create ML Model Group index", argumentCaptor.getValue().getMessage()); - } - private MLRegisterModelGroupInput prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { return MLRegisterModelGroupInput .builder() @@ -417,8 +421,24 @@ private MLRegisterModelGroupInput prepareRequest(List backendRoles, Acce .build(); } - private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { - SearchResponse searchResponse = mock(SearchResponse.class); + private SearchResponse getEmptySearchResponse() { + SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, true, false, null, 1); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + return searchResponse; + } + + private SearchResponse getNonEmptySearchResponse() throws IOException { + SearchHit[] hits = new SearchHit[1]; String modelContent = "{\n" + " \"created_time\": 1684981986069,\n" + " \"access\": \"public\",\n" @@ -428,9 +448,28 @@ private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOE + " \"name\": \"model_group_IT\",\n" + " \"description\": \"This is an example description\"\n" + " }"; - SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); - SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); + SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent)); + hits[0] = model; + SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); + SearchResponseSections searchSections = new SearchResponseSections( + searchHits, + InternalAggregations.EMPTY, + null, + true, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); return searchResponse; } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 8542891c22..182b2cf39c 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -70,6 +70,7 @@ import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; +import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -85,15 +86,22 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.CircuitBreakingException; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.get.GetResult; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.breaker.ThresholdCircuitBreaker; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLModelGroup; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -121,6 +129,8 @@ import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.script.ScriptService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -137,6 +147,7 @@ public class MLModelManagerTests extends OpenSearchTestCase { private ClusterService clusterService; @Mock private Client client; + private SdkClient sdkClient; @Mock private ThreadPool threadPool; private NamedXContentRegistry xContentRegistry; @@ -212,7 +223,7 @@ public void setup() throws URISyntaxException { ); clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService)); xContentRegistry = NamedXContentRegistry.EMPTY; - + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); modelName = "model_name1"; modelId = randomAlphaOfLength(10); modelContentHashValue = "c446f747520bcc6af053813cb1e8d34944a7c4686bbb405aeaa23883b5a806c8"; @@ -273,6 +284,7 @@ public void setup() throws URISyntaxException { clusterService, scriptService, client, + sdkClient, threadPool, xContentRegistry, modelHelper, @@ -326,6 +338,7 @@ public void setup() throws URISyntaxException { }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); } + @Test public void testRegisterMLModel_ExceedMaxRunningTask() { String error = "exceed max running task limit"; doThrow(new MLLimitExceededException(error)).when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); @@ -448,7 +461,7 @@ public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionE ); } - public void testRegisterMLRemoteModel() throws PrivilegedActionException { + public void testRegisterMLRemoteModel() throws PrivilegedActionException, IOException { ActionListener listener = mock(ActionListener.class); doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); @@ -458,18 +471,27 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException { MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + + GetResponse getResponse = prepareMLModelGroup(); doAnswer(invocation -> { - ActionListener indexResponseActionListener = (ActionListener) invocation.getArguments()[1]; + ActionListener getModelGrouplistener = invocation.getArgument(1); + getModelGrouplistener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ActionListener indexResponseActionListener = invocation.getArgument(1); indexResponseActionListener.onResponse(indexResponse); return null; }).when(client).index(any(), any()); + when(indexResponse.getId()).thenReturn("mockIndexId"); - modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener); + modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener); assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } - public void testRegisterMLRemoteModel_SkipMemoryCBOpen() { + public void testRegisterMLRemoteModel_SkipMemoryCBOpen() throws IOException { ActionListener listener = mock(ActionListener.class); doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); when(mlCircuitBreakerService.checkOpenCB()) @@ -484,18 +506,26 @@ public void testRegisterMLRemoteModel_SkipMemoryCBOpen() { MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + + GetResponse getResponse = prepareMLModelGroup(); + doAnswer(invocation -> { + ActionListener getModelGrouplistener = invocation.getArgument(1); + getModelGrouplistener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { ActionListener indexResponseActionListener = (ActionListener) invocation.getArguments()[1]; indexResponseActionListener.onResponse(indexResponse); return null; }).when(client).index(any(), any()); when(indexResponse.getId()).thenReturn("mockIndexId"); - modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener); + modelManager.registerMLRemoteModel(sdkClient, pretrainedInput, pretrainedTask, listener); assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } - public void testIndexRemoteModel() throws PrivilegedActionException { + public void testIndexRemoteModel() throws PrivilegedActionException, IOException { ActionListener listener = mock(ActionListener.class); doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null); @@ -505,12 +535,21 @@ public void testIndexRemoteModel() throws PrivilegedActionException { MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + + GetResponse getResponse = prepareMLModelGroup(); + + IndexResponse indexResponse = new IndexResponse(new ShardId("test", "test", 1), "mockIndexId", 1l, 1l, 1l, true); + doAnswer(invocation -> { + ActionListener getModelGrouplistener = invocation.getArgument(1); + getModelGrouplistener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { ActionListener indexResponseActionListener = (ActionListener) invocation.getArguments()[1]; indexResponseActionListener.onResponse(indexResponse); return null; }).when(client).index(any(), any()); - when(indexResponse.getId()).thenReturn("mockIndexId"); modelManager.indexRemoteModel(pretrainedInput, pretrainedTask, "1.0.0"); assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); @@ -621,12 +660,19 @@ public void testDeployModel_FailedToGetModel() { when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); mock_client_get_failure(client); + + doAnswer(invocation -> { + ActionListener listener1 = invocation.getArgument(1); + listener1.onFailure(new RuntimeException("get doc failure")); + return null; + }).when(client).get(any(), any()); + mock_client_ThreadContext(client, threadPool, threadContext); modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); - assertEquals("get doc failure", exception.getValue().getMessage()); + assertEquals("Failed to get data object from index .plugins-ml-model", exception.getValue().getMessage()); verify(mlStats) .createCounterStatIfAbsent( eq(FunctionName.TEXT_EMBEDDING), @@ -1281,4 +1327,19 @@ private MLRegisterModelInput mockRemoteModelInput(boolean isHidden) { .deployModel(true) .build(); } + + public GetResponse prepareMLModelGroup() throws IOException { + MLModelGroup mlModelGroup = MLModelGroup + .builder() + .modelGroupId("test_id") + .name("modelGroup") + .description("this is an example description") + .latestVersion(1) + .access("private") + .build(); + XContentBuilder content = mlModelGroup.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + BytesReference bytesReference = BytesReference.bytes(content); + GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); + return new GetResponse(getResult); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java index e3ef6422d0..edcb48261f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelActionTests.java @@ -9,6 +9,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; import static org.mockito.Mockito.times; +import static org.opensearch.ml.common.input.Constants.TENANT_ID_HEADER; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import java.util.HashMap; @@ -16,8 +17,10 @@ import java.util.Map; import org.junit.Before; +import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.client.node.NodeClient; import org.opensearch.common.settings.Settings; @@ -26,6 +29,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -44,9 +48,14 @@ public class RestMLDeleteModelActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() { - restMLDeleteModelAction = new RestMLDeleteModelAction(); + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -65,17 +74,20 @@ public void tearDown() throws Exception { client.close(); } + @Test public void testConstructor() { - RestMLDeleteModelAction mlDeleteModelAction = new RestMLDeleteModelAction(); + RestMLDeleteModelAction mlDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); assertNotNull(mlDeleteModelAction); } + @Test public void testGetName() { String actionName = restMLDeleteModelAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); assertEquals("ml_delete_model_action", actionName); } + @Test public void testRoutes() { List routes = restMLDeleteModelAction.routes(); assertNotNull(routes); @@ -85,6 +97,7 @@ public void testRoutes() { assertEquals("/_plugins/_ml/models/{model_id}", route.getPath()); } + @Test public void test_PrepareRequest() throws Exception { RestRequest request = getRestRequest(); restMLDeleteModelAction.handleRequest(request, channel, client); @@ -95,10 +108,59 @@ public void test_PrepareRequest() throws Exception { assertEquals(taskId, "test_id"); } + @Test + public void testPrepareRequest_WithTenantIdInHeader() throws Exception { + // Mock multi-tenancy enabled + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + // Create a RestRequest with tenantId in the header + RestRequest request = getRestRequestWithTenantId("test_tenant"); + restMLDeleteModelAction.handleRequest(request, channel, client); + + // Capture the request sent to the client + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelDeleteRequest.class); + verify(client, times(1)).execute(eq(MLModelDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + + // Verify modelId and tenantId + MLModelDeleteRequest capturedRequest = argumentCaptor.getValue(); + assertEquals("test_id", capturedRequest.getModelId()); + assertEquals("test_tenant", capturedRequest.getTenantId()); + } + + public void testPrepareRequest_WithoutTenantId() throws Exception { + // Mock multi-tenancy disabled + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + + // Create a RestRequest without tenantId + RestRequest request = getRestRequest(); + restMLDeleteModelAction.handleRequest(request, channel, client); + + // Capture the request sent to the client + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelDeleteRequest.class); + verify(client, times(1)).execute(eq(MLModelDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + + // Verify modelId and ensure tenantId is null + MLModelDeleteRequest capturedRequest = argumentCaptor.getValue(); + assertEquals("test_id", capturedRequest.getModelId()); + assertNull(capturedRequest.getTenantId()); + } + + private RestRequest getRestRequestWithTenantId(String tenantId) { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_ID, "test_id"); // Model ID as a parameter + + Map> headers = new HashMap<>(); + headers.put(TENANT_ID_HEADER, List.of(tenantId)); // Tenant ID as a header + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withHeaders(headers) // Set headers + .build(); + } + private RestRequest getRestRequest() { Map params = new HashMap<>(); params.put(PARAMETER_MODEL_ID, "test_id"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); - return request; + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelGroupActionTests.java index 48bf4f29dc..16a05f2d55 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelGroupActionTests.java @@ -11,23 +11,31 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.input.Constants.TENANT_ID_HEADER; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -46,9 +54,22 @@ public class RestMLDeleteModelGroupActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + private ClusterService clusterService; + + Settings settings; + @Before public void setup() { - restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); + MockitoAnnotations.openMocks(this); + MockitoAnnotations.openMocks(this); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -68,7 +89,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLDeleteModelGroupAction mlDeleteModelGroupAction = new RestMLDeleteModelGroupAction(); + RestMLDeleteModelGroupAction mlDeleteModelGroupAction = new RestMLDeleteModelGroupAction(mlFeatureEnabledSetting); assertNotNull(mlDeleteModelGroupAction); } @@ -97,10 +118,37 @@ public void test_PrepareRequest() throws Exception { assertEquals(taskId, "test_id"); } + public void test_PrepareRequest_WithTenantId() throws Exception { + // Enable multi-tenancy + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + // Create RestRequest with tenantId in header + RestRequest request = getRestRequestWithTenantId("test_tenant"); + restMLDeleteModelGroupAction.handleRequest(request, channel, client); + + // Capture request sent to client + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelGroupDeleteRequest.class); + verify(client, times(1)).execute(eq(MLModelGroupDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + + // Verify modelGroupId and tenantId + MLModelGroupDeleteRequest capturedRequest = argumentCaptor.getValue(); + assertEquals("test_id", capturedRequest.getModelGroupId()); + assertEquals("test_tenant", capturedRequest.getTenantId()); + } + private RestRequest getRestRequest() { Map params = new HashMap<>(); params.put(PARAMETER_MODEL_GROUP_ID, "test_id"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); - return request; + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } + + private RestRequest getRestRequestWithTenantId(String tenantId) { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_GROUP_ID, "test_id"); + + Map> headers = new HashMap<>(); + headers.put(TENANT_ID_HEADER, List.of(tenantId)); // Add tenant ID to headers + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).withHeaders(headers).build(); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java index a17358e213..fffc2614c7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java @@ -9,18 +9,24 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; import static org.mockito.Mockito.times; +import static org.opensearch.ml.common.input.Constants.TENANT_ID_HEADER; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -28,6 +34,7 @@ import org.opensearch.ml.common.transport.model.MLModelGetAction; import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.common.transport.model.MLModelGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -48,9 +55,22 @@ public class RestMLGetModelActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + Settings settings; + + @Mock + private ClusterService clusterService; + @Before public void setup() { - restMLGetModelAction = new RestMLGetModelAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -70,7 +90,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(); + RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); assertNotNull(mlGetModelAction); } @@ -89,6 +109,24 @@ public void testRoutes() { assertEquals("/_plugins/_ml/models/{model_id}", route.getPath()); } + public void test_PrepareRequest_WithTenantId() throws Exception { + // Enable multi-tenancy + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + // Create RestRequest with tenantId in the header + RestRequest request = getRestRequestWithTenantId("test_tenant"); + restMLGetModelAction.handleRequest(request, channel, client); + + // Capture request sent to client + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelGetRequest.class); + verify(client, times(1)).execute(eq(MLModelGetAction.INSTANCE), argumentCaptor.capture(), any()); + + // Verify modelId and tenantId + MLModelGetRequest capturedRequest = argumentCaptor.getValue(); + assertEquals("test_id", capturedRequest.getModelId()); + assertEquals("test_tenant", capturedRequest.getTenantId()); + } + public void test_PrepareRequest() throws Exception { RestRequest request = getRestRequest(); restMLGetModelAction.handleRequest(request, channel, client); @@ -99,6 +137,16 @@ public void test_PrepareRequest() throws Exception { assertEquals(taskId, "test_id"); } + private RestRequest getRestRequestWithTenantId(String tenantId) { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_ID, "test_id"); + + Map> headers = new HashMap<>(); + headers.put(TENANT_ID_HEADER, List.of(tenantId)); // Add tenant ID to headers + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).withHeaders(headers).build(); + } + private RestRequest getRestRequest() { Map params = new HashMap<>(); params.put(PARAMETER_MODEL_ID, "test_id"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionIT.java index 821f9e5973..35e7f8d199 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionIT.java @@ -23,7 +23,7 @@ public class RestMLGetModelGroupActionIT extends MLCommonsRestTestCase { public void testGetModelAPI_EmptyResources() throws IOException { exceptionRule.expect(ResponseException.class); - exceptionRule.expectMessage("Fail to find model group index"); + exceptionRule.expectMessage("Failed to find model group index"); TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/model_groups/111222333", null, "", null); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionTests.java index 0f99b406df..8eb1ed0ace 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelGroupActionTests.java @@ -11,18 +11,25 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.input.Constants.TENANT_ID_HEADER; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; @@ -30,6 +37,7 @@ import org.opensearch.ml.common.transport.model_group.MLModelGroupGetAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetRequest; import org.opensearch.ml.common.transport.model_group.MLModelGroupGetResponse; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -44,15 +52,27 @@ public class RestMLGetModelGroupActionTests extends OpenSearchTestCase { private RestMLGetModelGroupAction restMLGetModelGroupAction; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + private ClusterService clusterService; + NodeClient client; private ThreadPool threadPool; + Settings settings; @Mock RestChannel channel; @Before public void setup() { - restMLGetModelGroupAction = new RestMLGetModelGroupAction(); + MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + restMLGetModelGroupAction = new RestMLGetModelGroupAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -72,7 +92,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(); + RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); assertNotNull(mlGetModelAction); } @@ -91,6 +111,25 @@ public void testRoutes() { assertEquals("/_plugins/_ml/model_groups/{model_group_id}", route.getPath()); } + public void test_PrepareRequest_WithTenantId() throws Exception { + // Mock multi-tenancy enabled + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + + // Create a RestRequest with tenant ID in the headers + RestRequest request = getRestRequestWithTenantId("test_tenant"); + restMLGetModelGroupAction.handleRequest(request, channel, client); + + // Capture the request sent to the client + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelGroupGetRequest.class); + verify(client, times(1)).execute(eq(MLModelGroupGetAction.INSTANCE), argumentCaptor.capture(), any()); + + // Verify modelGroupId and tenantId + MLModelGroupGetRequest capturedRequest = argumentCaptor.getValue(); + assertEquals("test_id", capturedRequest.getModelGroupId()); + assertEquals("test_tenant", capturedRequest.getTenantId()); + + } + public void test_PrepareRequest() throws Exception { RestRequest request = getRestRequest(); restMLGetModelGroupAction.handleRequest(request, channel, client); @@ -101,6 +140,14 @@ public void test_PrepareRequest() throws Exception { assertEquals(taskId, "test_id"); } + private RestRequest getRestRequestWithTenantId(String tenantId) { + Map params = new HashMap<>(); + params.put(PARAMETER_MODEL_GROUP_ID, "test_id"); + Map> headers = new HashMap<>(); + headers.put(TENANT_ID_HEADER, List.of(tenantId)); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).withHeaders(headers).build(); + } + private RestRequest getRestRequest() { Map params = new HashMap<>(); params.put(PARAMETER_MODEL_GROUP_ID, "test_id"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelGroupActionTests.java index 221848c597..4c024920cb 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelGroupActionTests.java @@ -11,9 +11,12 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Rule; @@ -24,6 +27,8 @@ import org.opensearch.OpenSearchParseException; import org.opensearch.action.get.GetResponse; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; @@ -33,6 +38,7 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -51,6 +57,13 @@ public class RestMLRegisterModelGroupActionTests extends OpenSearchTestCase { private NodeClient client; private ThreadPool threadPool; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + Settings settings; + + @Mock + private ClusterService clusterService; + @Mock RestChannel channel; @@ -59,7 +72,11 @@ public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLRegisterModelGroupAction = new RestMLRegisterModelGroupAction(); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + restMLRegisterModelGroupAction = new RestMLRegisterModelGroupAction(mlFeatureEnabledSetting); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); return null; @@ -74,7 +91,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLRegisterModelGroupAction registerModelGroupAction = new RestMLRegisterModelGroupAction(); + RestMLRegisterModelGroupAction registerModelGroupAction = new RestMLRegisterModelGroupAction(mlFeatureEnabledSetting); assertNotNull(registerModelGroupAction); } @@ -113,22 +130,20 @@ public void testRegisterModelGroupRequestWithEmptyContent() throws Exception { private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.POST; final Map modelGroup = Map.of("name", "testModelGroupName", "description", "This is test description"); - String requestContent = new Gson().toJson(modelGroup).toString(); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + String requestContent = new Gson().toJson(modelGroup); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/model_groups/_register") .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); - return request; } private RestRequest getRestRequestWithEmptyContent() { RestRequest.Method method = RestRequest.Method.POST; - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/model_groups/_register") .withContent(new BytesArray(""), XContentType.JSON) .build(); - return request; } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java index c215a4259f..0aa45c4ea3 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java @@ -36,6 +36,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; @@ -57,10 +58,13 @@ public class RestMLSearchModelGroupActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(); + restMLSearchModelGroupAction = new RestMLSearchModelGroupAction(mlFeatureEnabledSetting); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -106,7 +110,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLSearchModelGroupAction mlSearchModelGroupAction = new RestMLSearchModelGroupAction(); + RestMLSearchModelGroupAction mlSearchModelGroupAction = new RestMLSearchModelGroupAction(mlFeatureEnabledSetting); assertNotNull(mlSearchModelGroupAction); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java index c11c7e3fb8..755a04e62c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.opensearch.ml.utils.TestHelper.toJsonString; import java.util.HashMap; @@ -38,6 +39,7 @@ import org.opensearch.ml.common.transport.model.MLUpdateModelAction; import org.opensearch.ml.common.transport.model.MLUpdateModelInput; import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -56,6 +58,9 @@ public class RestMLUpdateModelActionTests extends OpenSearchTestCase { private NodeClient client; private ThreadPool threadPool; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock RestChannel channel; @@ -64,7 +69,8 @@ public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLUpdateModelAction = new RestMLUpdateModelAction(); + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + restMLUpdateModelAction = new RestMLUpdateModelAction(mlFeatureEnabledSetting); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); return null; @@ -80,7 +86,7 @@ public void tearDown() throws Exception { @Test public void testConstructor() { - RestMLUpdateModelAction UpdateModelAction = new RestMLUpdateModelAction(); + RestMLUpdateModelAction UpdateModelAction = new RestMLUpdateModelAction(mlFeatureEnabledSetting); assertNotNull(UpdateModelAction); } @@ -176,26 +182,24 @@ private RestRequest getRestRequest() { String requestContent = new Gson().toJson(modelContent); Map params = new HashMap<>(); params.put("model_id", "test_modelId"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/models/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); - return request; } private RestRequest getRestRequestWithEmptyContent() { RestRequest.Method method = RestRequest.Method.PUT; Map params = new HashMap<>(); params.put("model_id", "test_modelId"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/models/{model_id}") .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build(); - return request; } private RestRequest getRestRequestWithNullModelId() { @@ -203,13 +207,12 @@ private RestRequest getRestRequestWithNullModelId() { final Map modelContent = Map.of("name", "testModelName", "description", "This is test description"); String requestContent = new Gson().toJson(modelContent); Map params = new HashMap<>(); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/models/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); - return request; } private RestRequest getRestRequestWithNullField() { @@ -217,13 +220,12 @@ private RestRequest getRestRequestWithNullField() { String requestContent = "{\"name\":\"testModelName\",\"description\":null}"; Map params = new HashMap<>(); params.put("model_id", "test_modelId"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/models/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); - return request; } private RestRequest getRestRequestWithConnectorIDAndConnectorUpdateContent() { @@ -248,13 +250,12 @@ private RestRequest getRestRequestWithConnectorIDAndConnectorUpdateContent() { String requestContent = new Gson().toJson(modelContent); Map params = new HashMap<>(); params.put("model_id", "test_modelId"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/models/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); - return request; } private RestRequest getRestRequestWithConnectorID() { @@ -264,13 +265,12 @@ private RestRequest getRestRequestWithConnectorID() { String requestContent = new Gson().toJson(modelContent); Map params = new HashMap<>(); params.put("model_id", "test_modelId"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/models/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); - return request; } private RestRequest getRestRequestWithConnectorUpdateContent() { @@ -286,12 +286,11 @@ private RestRequest getRestRequestWithConnectorUpdateContent() { String requestContent = new Gson().toJson(modelContent); Map params = new HashMap<>(); params.put("model_id", "test_modelId"); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) .withPath("/_plugins/_ml/models/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); - return request; } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java index 54412454ab..1653812fe5 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java @@ -11,11 +11,14 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED; import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Rule; @@ -25,6 +28,8 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.get.GetResponse; import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; @@ -34,6 +39,7 @@ import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -52,6 +58,13 @@ public class RestMLUpdateModelGroupActionTests extends OpenSearchTestCase { private NodeClient client; private ThreadPool threadPool; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + private ClusterService clusterService; + Settings settings; + @Mock RestChannel channel; @@ -60,7 +73,11 @@ public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(); + settings = Settings.builder().put(ML_COMMONS_MULTI_TENANCY_ENABLED.getKey(), false).build(); + when(clusterService.getSettings()).thenReturn(settings); + when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MULTI_TENANCY_ENABLED))); + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); + restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction(mlFeatureEnabledSetting); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); return null; @@ -75,7 +92,7 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLUpdateModelGroupAction UpdateModelGroupAction = new RestMLUpdateModelGroupAction(); + RestMLUpdateModelGroupAction UpdateModelGroupAction = new RestMLUpdateModelGroupAction(mlFeatureEnabledSetting); assertNotNull(UpdateModelGroupAction); }