diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index a48cc6ed17..b553a222a9 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -248,11 +248,20 @@ public void createInteraction( /** * Update a conversation + * @param conversationId the conversation id to update * @param updateContent update content for the conversations index * @param listener receives the update response */ public void updateConversation(String conversationId, Map updateContent, ActionListener listener); + /** + * Update an interaction + * @param interactionId the interaction id to update + * @param updateContent update content for the interaction index + * @param listener receives the update response + */ + public void updateInteraction(String interactionId, Map updateContent, ActionListener listener); + /** * Get a single ConversationMeta object * @param conversationId id of the conversation to get diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java index 995edb8b16..9f256fd075 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java @@ -8,18 +8,20 @@ import java.time.Instant; import java.util.Map; +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -28,29 +30,51 @@ @Log4j2 public class UpdateConversationTransportAction extends HandledTransportAction { Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; @Inject - public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + public UpdateConversationTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + OpenSearchConversationalMemoryHandler cmHandler, + ClusterService clusterService + ) { super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new); this.client = client; + this.cmHandler = cmHandler; + System.out.println(clusterService.getSettings()); + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); } @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); - String conversationId = updateConversationRequest.getConversationId(); - UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); - Map updateContent = updateConversationRequest.getUpdateContent(); - updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); + if (!featureIsEnabled) { + listener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); + String conversationId = updateConversationRequest.getConversationId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + Map updateContent = updateConversationRequest.getUpdateContent(); + updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); - updateRequest.doc(updateContent); - updateRequest.docAsUpsert(true); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context)); - } catch (Exception e) { - log.error("Failed to update Conversation for conversation id" + conversationId, e); - listener.onFailure(e); + cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context)); + } catch (Exception e) { + log.error("Failed to update Conversation " + conversationId, e); + listener.onFailure(e); + } } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java index 9abf8571c4..785e4e3fdb 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java @@ -5,18 +5,22 @@ package org.opensearch.ml.memory.action.conversation; +import java.util.Map; + +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -25,26 +29,47 @@ @Log4j2 public class UpdateInteractionTransportAction extends HandledTransportAction { Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; @Inject - public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + public UpdateInteractionTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + OpenSearchConversationalMemoryHandler cmHandler, + ClusterService clusterService + ) { super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new); this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); } @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + if (!featureIsEnabled) { + listener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request); String interactionId = updateInteractionRequest.getInteractionId(); - UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId); - updateRequest.doc(updateInteractionRequest.getUpdateContent()); - updateRequest.docAsUpsert(true); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + Map updateContent = updateInteractionRequest.getUpdateContent(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context)); + cmHandler.updateInteraction(interactionId, updateContent, getUpdateResponseListener(interactionId, listener, context)); } catch (Exception e) { - log.error("Failed to update Interaction for interaction id " + interactionId, e); + log.error("Failed to update Interaction " + interactionId, e); listener.onFailure(e); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index f7e2a63138..86a045a35a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -360,11 +360,12 @@ public void searchConversations(SearchRequest request, ActionListener listener) { + public void updateConversation(String conversationId, UpdateRequest updateRequest, ActionListener listener) { if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener .onFailure( @@ -372,6 +373,22 @@ public void updateConversation(UpdateRequest updateRequest, ActionListener { + if (access) { + innerUpdateConversation(updateRequest, listener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); })); + } + + private void innerUpdateConversation(UpdateRequest updateRequest, ActionListener listener) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); client.update(updateRequest, internalListener); diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index efd4b562d4..13c00753f7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -40,6 +40,8 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.client.Requests; import org.opensearch.cluster.service.ClusterService; @@ -330,6 +332,47 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList listener.onResponse(List.of()); return; } + + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the interaction doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { + throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); + } + Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); + // checks if the user has permission to access the conversation that the interaction belongs to + String conversationId = interaction.getConversationId(); + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + innerGetTraces(interactionId, from, maxResults, listener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null + ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId); + } + }, e -> { listener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + }, e -> { internalListener.onFailure(e); }); + client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, ActionListener.runBefore(al, () -> threadContext.restore())); + }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + @VisibleForTesting + void innerGetTraces(String interactionId, int from, int maxResults, ActionListener> listener) { SearchRequest request = Requests.searchRequest(INTERACTIONS_INDEX_NAME); // Build the query BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); @@ -364,7 +407,7 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList } /** - * Gets all of the interactions in a conversation, regardless of conversation size + * Gets all interactions in a conversation, regardless of conversation size * @param conversationId conversation to get all interactions of * @param maxResults how many interactions to get per search query * @param listener receives the list of all interactions in the conversation @@ -509,15 +552,16 @@ public void getInteraction(String interactionId, ActionListener lis ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); ActionListener al = ActionListener.wrap(getResponse -> { - // If the conversation doesn't exist, fail + // If the interaction doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); } Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); - internalListener.onResponse(interaction); + // checks if the user has permission to access the conversation that the interaction belongs to + checkInteractionPermission(interactionId, interaction, internalListener); }, e -> { internalListener.onFailure(e); }); client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { - client.get(request, al); + client.get(request, ActionListener.runBefore(al, () -> threadContext.restore())); }, e -> { log.error("Failed to refresh interactions index during get interaction ", e); internalListener.onFailure(e); @@ -526,4 +570,87 @@ public void getInteraction(String interactionId, ActionListener lis listener.onFailure(e); } } + + /** + * Update interaction in the index + * @param interactionId the interaction id that needs update + * @param updateRequest original update request + * @param listener receives the update response for the wrapped query + */ + public void updateInteraction(String interactionId, UpdateRequest updateRequest, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { + listener + .onFailure( + new IndexNotFoundException( + "cannot update interaction since the interaction index does not exist", + INTERACTIONS_INDEX_NAME + ) + ); + return; + } + + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the interaction doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { + throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); + } + Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); + // checks if the user has permission to access the conversation that the interaction belongs to + String conversationId = interaction.getConversationId(); + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + innerUpdateInteraction(updateRequest, internalListener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null + ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId); + } + }, e -> { listener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + }, e -> { internalListener.onFailure(e); }); + client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, ActionListener.runBefore(al, () -> threadContext.restore())); + }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private void innerUpdateInteraction(UpdateRequest updateRequest, ActionListener listener) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + client.update(updateRequest, internalListener); + } catch (Exception e) { + log.error("Failed to update Conversation. Details {}:", e); + listener.onFailure(e); + } + } + + private void checkInteractionPermission(String interactionId, Interaction interaction, ActionListener internalListener) { + String conversationId = interaction.getConversationId(); + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + internalListener.onResponse(interaction); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId); + } + }, e -> { internalListener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index 64d39991df..1b4aa0c495 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -384,10 +384,23 @@ public ActionFuture searchInteractions(String conversationId, Se return fut; } + /** + * List all traces of an interaction + * @param interactionId id of the parent interaction + * @param from where to start listing from + * @maxResults how many traces to list + * @listener process the response + */ public void getTraces(String interactionId, int from, int maxResults, ActionListener> listener) { interactionsIndex.getTraces(interactionId, from, maxResults, listener); } + /** + * Update conversation in the index + * @param conversationId the conversation id that needs update + * @param updateContent original update content + * @param listener receives the update response for the wrapped query + */ public void updateConversation(String conversationId, Map updateContent, ActionListener listener) { UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); @@ -396,7 +409,23 @@ public void updateConversation(String conversationId, Map update updateRequest.docAsUpsert(true); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - conversationMetaIndex.updateConversation(updateRequest, listener); + conversationMetaIndex.updateConversation(conversationId, updateRequest, listener); + } + + /** + * Update interaction in the index + * @param interactionId the interaction id that needs update + * @param updateContent original update content + * @param listener receives the update response for the wrapped query + */ + public void updateInteraction(String interactionId, Map updateContent, ActionListener listener) { + UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId); + + updateRequest.doc(updateContent); + updateRequest.docAsUpsert(true); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + interactionsIndex.updateInteraction(interactionId, updateRequest, listener); } /** diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java index f8476f9ccb..5fb4067acb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java @@ -6,9 +6,9 @@ package org.opensearch.ml.memory.action.conversation; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; @@ -18,21 +18,26 @@ import java.time.Instant; import java.util.HashMap; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -65,20 +70,30 @@ public class UpdateConversationTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - ThreadContext threadContext; + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + @Mock + ClusterService clusterService; - private Settings settings; + ThreadContext threadContext; private ShardId shardId; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - settings = Settings.builder().build(); + this.clusterService = Mockito.mock(ClusterService.class); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); threadContext = new ThreadContext(settings); + this.threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + String conversationId = "test_conversation_id"; Map updateContent = new HashMap<>(); updateContent.put(META_NAME_FIELD, "new name"); @@ -87,27 +102,36 @@ public void setup() throws IOException { when(updateRequest.getUpdateContent()).thenReturn(updateContent); shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); - - transportUpdateConversationAction = new UpdateConversationTransportAction(transportService, actionFilters, client); + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + transportUpdateConversationAction = new UpdateConversationTransportAction( + transportService, + actionFilters, + client, + cmHandler, + clusterService + ); } public void test_execute_Success() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateConversation(any(), any(), any()); transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getResult().equals(DocWriteResponse.Result.UPDATED)); } public void test_execute_UpdateFailure() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Error in Update Request")); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateConversation(any(), any(), any()); transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); @@ -118,17 +142,17 @@ public void test_execute_UpdateFailure() { public void test_execute_UpdateWrongStatus() { UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateConversation(any(), any(), any()); transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); } public void test_execute_ThrowException() { - doThrow(new RuntimeException("Error in Update Request")).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + doThrow(new RuntimeException("Error in Update Request")).when(cmHandler).updateConversation(any(), any(), any()); transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java index 3dbd16ca64..f81dbe0e94 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java @@ -16,21 +16,26 @@ import java.io.IOException; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -62,20 +67,29 @@ public class UpdateInteractionTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - ThreadContext threadContext; + @Mock + OpenSearchConversationalMemoryHandler cmHandler; - private Settings settings; + @Mock + ClusterService clusterService; + + ThreadContext threadContext; private ShardId shardId; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - settings = Settings.builder().build(); + this.clusterService = Mockito.mock(ClusterService.class); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); threadContext = new ThreadContext(settings); + this.threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); String interactionId = "test_interaction_id"; Map updateContent = Map .of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!"), INTERACTIONS_RESPONSE_FIELD, "response"); @@ -84,15 +98,21 @@ public void setup() throws IOException { shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); - updateInteractionTransportAction = new UpdateInteractionTransportAction(transportService, actionFilters, client); + updateInteractionTransportAction = new UpdateInteractionTransportAction( + transportService, + actionFilters, + client, + cmHandler, + clusterService + ); } public void test_execute_Success() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateInteraction(any(String.class), any(Map.class), isA(ActionListener.class)); updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); @@ -100,10 +120,10 @@ public void test_execute_Success() { public void test_execute_UpdateFailure() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Error in Update Request")); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateInteraction(any(String.class), any(Map.class), isA(ActionListener.class)); updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); @@ -114,17 +134,19 @@ public void test_execute_UpdateFailure() { public void test_execute_UpdateWrongStatus() { UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateInteraction(any(String.class), any(Map.class), isA(ActionListener.class)); updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); } public void test_execute_ThrowException() { - doThrow(new RuntimeException("Error in Update Request")).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + doThrow(new RuntimeException("Error in Update Request")) + .when(cmHandler) + .updateInteraction(any(String.class), any(Map.class), isA(ActionListener.class)); updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index e3841209c5..650ef8eab8 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -644,7 +644,7 @@ public void testUpdateConversation_NoIndex_ThenFail() { doReturn(false).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); - conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + conversationMetaIndex.updateConversation("tester_id", new UpdateRequest(), getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor @@ -654,6 +654,16 @@ public void testUpdateConversation_NoIndex_ThenFail() { } public void testUpdateConversation_Success() { + setupRefreshSuccess(); + final String id = "test_id"; + GetResponse dummyGetResponse = mock(GetResponse.class); + doReturn(true).when(dummyGetResponse).isExists(); + doReturn(id).when(dummyGetResponse).getId(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(dummyGetResponse); + return null; + }).when(client).get(any(), any()); doReturn(true).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); @@ -665,18 +675,28 @@ public void testUpdateConversation_Success() { listener.onResponse(updateResponse); return null; }).when(client).update(any(), any()); - conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + conversationMetaIndex.updateConversation("test_id", new UpdateRequest(), getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); verify(getListener, times(1)).onResponse(argCaptor.capture()); } public void testUpdateConversation_ClientFails() { + setupRefreshSuccess(); + final String id = "test_id"; + GetResponse dummyGetResponse = mock(GetResponse.class); + doReturn(true).when(dummyGetResponse).isExists(); + doReturn(id).when(dummyGetResponse).getId(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(dummyGetResponse); + return null; + }).when(client).get(any(), any()); doReturn(true).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); doThrow(new RuntimeException("Client Failure")).when(client).update(any(), any()); - conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + conversationMetaIndex.updateConversation("test_id", new UpdateRequest(), getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure")); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 41fdb1af41..3e523bc07c 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -480,7 +480,7 @@ public void testGetTraces() { doReturn(true).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener> getTracesListener = mock(ActionListener.class); - interactionsIndex.getTraces("cid", 0, 10, getTracesListener); + interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); @@ -491,7 +491,7 @@ public void testGetTraces_clientFail() { doReturn(true).when(metadata).hasIndex(anyString()); doThrow(new RuntimeException("Client Failure")).when(client).search(any(), any()); ActionListener> getTracesListener = mock(ActionListener.class); - interactionsIndex.getTraces("cid", 0, 10, getTracesListener); + interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getTracesListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure")); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index 2c1c28c529..216435c9ec 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -179,12 +179,12 @@ public void testGetTraces() { public void testUpdateConversation() { doAnswer(invocation -> { - ActionListener al = invocation.getArgument(1); + ActionListener al = invocation.getArgument(2); ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); al.onResponse(updateResponse); return null; - }).when(conversationMetaIndex).updateConversation(any(), any()); + }).when(conversationMetaIndex).updateConversation(any(), any(), any()); ActionListener updateConversationListener = mock(ActionListener.class); cmHandler.updateConversation("cId", new HashMap<>(), updateConversationListener);