From cbe2ca5afa7df8f64f88956a557c4c26e10fc73c Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Wed, 24 Jan 2024 10:25:40 -0800 Subject: [PATCH 1/2] add more user based permission check in Memory Signed-off-by: Xun Zhang --- .../memory/ConversationalMemoryHandler.java | 9 ++ .../UpdateConversationTransportAction.java | 56 +++++--- .../UpdateInteractionTransportAction.java | 45 ++++-- .../memory/index/ConversationMetaIndex.java | 21 ++- .../ml/memory/index/InteractionsIndex.java | 135 +++++++++++++++++- ...OpenSearchConversationalMemoryHandler.java | 31 +++- ...pdateConversationTransportActionTests.java | 54 +++++-- ...UpdateInteractionTransportActionTests.java | 46 ++++-- .../index/ConversationMetaIndexTests.java | 26 +++- .../memory/index/InteractionsIndexTests.java | 4 +- ...earchConversationalMemoryHandlerTests.java | 4 +- 11 files changed, 364 insertions(+), 67 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index a48cc6ed17..b553a222a9 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -248,11 +248,20 @@ public void createInteraction( /** * Update a conversation + * @param conversationId the conversation id to update * @param updateContent update content for the conversations index * @param listener receives the update response */ public void updateConversation(String conversationId, Map updateContent, ActionListener listener); + /** + * Update an interaction + * @param interactionId the interaction id to update + * @param updateContent update content for the interaction index + * @param listener receives the update response + */ + public void updateInteraction(String interactionId, Map updateContent, ActionListener listener); + /** * Get a single ConversationMeta object * @param conversationId id of the conversation to get diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java index 995edb8b16..9f256fd075 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java @@ -8,18 +8,20 @@ import java.time.Instant; import java.util.Map; +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -28,29 +30,51 @@ @Log4j2 public class UpdateConversationTransportAction extends HandledTransportAction { Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; @Inject - public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + public UpdateConversationTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + OpenSearchConversationalMemoryHandler cmHandler, + ClusterService clusterService + ) { super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new); this.client = client; + this.cmHandler = cmHandler; + System.out.println(clusterService.getSettings()); + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); } @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); - String conversationId = updateConversationRequest.getConversationId(); - UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); - Map updateContent = updateConversationRequest.getUpdateContent(); - updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); + if (!featureIsEnabled) { + listener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request); + String conversationId = updateConversationRequest.getConversationId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + Map updateContent = updateConversationRequest.getUpdateContent(); + updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); - updateRequest.doc(updateContent); - updateRequest.docAsUpsert(true); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context)); - } catch (Exception e) { - log.error("Failed to update Conversation for conversation id" + conversationId, e); - listener.onFailure(e); + cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context)); + } catch (Exception e) { + log.error("Failed to update Conversation " + conversationId, e); + listener.onFailure(e); + } } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java index 9abf8571c4..785e4e3fdb 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java @@ -5,18 +5,22 @@ package org.opensearch.ml.memory.action.conversation; +import java.util.Map; + +import org.opensearch.OpenSearchException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -25,26 +29,47 @@ @Log4j2 public class UpdateInteractionTransportAction extends HandledTransportAction { Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; @Inject - public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { + public UpdateInteractionTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + OpenSearchConversationalMemoryHandler cmHandler, + ClusterService clusterService + ) { super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new); this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); } @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + if (!featureIsEnabled) { + listener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request); String interactionId = updateInteractionRequest.getInteractionId(); - UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId); - updateRequest.doc(updateInteractionRequest.getUpdateContent()); - updateRequest.docAsUpsert(true); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + Map updateContent = updateInteractionRequest.getUpdateContent(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context)); + cmHandler.updateInteraction(interactionId, updateContent, getUpdateResponseListener(interactionId, listener, context)); } catch (Exception e) { - log.error("Failed to update Interaction for interaction id " + interactionId, e); + log.error("Failed to update Interaction " + interactionId, e); listener.onFailure(e); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index f7e2a63138..86a045a35a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -360,11 +360,12 @@ public void searchConversations(SearchRequest request, ActionListener listener) { + public void updateConversation(String conversationId, UpdateRequest updateRequest, ActionListener listener) { if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener .onFailure( @@ -372,6 +373,22 @@ public void updateConversation(UpdateRequest updateRequest, ActionListener { + if (access) { + innerUpdateConversation(updateRequest, listener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); })); + } + + private void innerUpdateConversation(UpdateRequest updateRequest, ActionListener listener) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); client.update(updateRequest, internalListener); diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index efd4b562d4..13c00753f7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -40,6 +40,8 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.client.Requests; import org.opensearch.cluster.service.ClusterService; @@ -330,6 +332,47 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList listener.onResponse(List.of()); return; } + + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the interaction doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { + throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); + } + Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); + // checks if the user has permission to access the conversation that the interaction belongs to + String conversationId = interaction.getConversationId(); + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + innerGetTraces(interactionId, from, maxResults, listener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null + ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId); + } + }, e -> { listener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + }, e -> { internalListener.onFailure(e); }); + client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, ActionListener.runBefore(al, () -> threadContext.restore())); + }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + @VisibleForTesting + void innerGetTraces(String interactionId, int from, int maxResults, ActionListener> listener) { SearchRequest request = Requests.searchRequest(INTERACTIONS_INDEX_NAME); // Build the query BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); @@ -364,7 +407,7 @@ public void getTraces(String interactionId, int from, int maxResults, ActionList } /** - * Gets all of the interactions in a conversation, regardless of conversation size + * Gets all interactions in a conversation, regardless of conversation size * @param conversationId conversation to get all interactions of * @param maxResults how many interactions to get per search query * @param listener receives the list of all interactions in the conversation @@ -509,15 +552,16 @@ public void getInteraction(String interactionId, ActionListener lis ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); ActionListener al = ActionListener.wrap(getResponse -> { - // If the conversation doesn't exist, fail + // If the interaction doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); } Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); - internalListener.onResponse(interaction); + // checks if the user has permission to access the conversation that the interaction belongs to + checkInteractionPermission(interactionId, interaction, internalListener); }, e -> { internalListener.onFailure(e); }); client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { - client.get(request, al); + client.get(request, ActionListener.runBefore(al, () -> threadContext.restore())); }, e -> { log.error("Failed to refresh interactions index during get interaction ", e); internalListener.onFailure(e); @@ -526,4 +570,87 @@ public void getInteraction(String interactionId, ActionListener lis listener.onFailure(e); } } + + /** + * Update interaction in the index + * @param interactionId the interaction id that needs update + * @param updateRequest original update request + * @param listener receives the update response for the wrapped query + */ + public void updateInteraction(String interactionId, UpdateRequest updateRequest, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { + listener + .onFailure( + new IndexNotFoundException( + "cannot update interaction since the interaction index does not exist", + INTERACTIONS_INDEX_NAME + ) + ); + return; + } + + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the interaction doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { + throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); + } + Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); + // checks if the user has permission to access the conversation that the interaction belongs to + String conversationId = interaction.getConversationId(); + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + innerUpdateInteraction(updateRequest, internalListener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null + ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId); + } + }, e -> { listener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + }, e -> { internalListener.onFailure(e); }); + client.admin().indices().refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, ActionListener.runBefore(al, () -> threadContext.restore())); + }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private void innerUpdateInteraction(UpdateRequest updateRequest, ActionListener listener) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + client.update(updateRequest, internalListener); + } catch (Exception e) { + log.error("Failed to update Conversation. Details {}:", e); + listener.onFailure(e); + } + } + + private void checkInteractionPermission(String interactionId, Interaction interaction, ActionListener internalListener) { + String conversationId = interaction.getConversationId(); + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + internalListener.onResponse(interaction); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to interaction " + interactionId); + } + }, e -> { internalListener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index 64d39991df..1b4aa0c495 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -384,10 +384,23 @@ public ActionFuture searchInteractions(String conversationId, Se return fut; } + /** + * List all traces of an interaction + * @param interactionId id of the parent interaction + * @param from where to start listing from + * @maxResults how many traces to list + * @listener process the response + */ public void getTraces(String interactionId, int from, int maxResults, ActionListener> listener) { interactionsIndex.getTraces(interactionId, from, maxResults, listener); } + /** + * Update conversation in the index + * @param conversationId the conversation id that needs update + * @param updateContent original update content + * @param listener receives the update response for the wrapped query + */ public void updateConversation(String conversationId, Map updateContent, ActionListener listener) { UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId); updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now()); @@ -396,7 +409,23 @@ public void updateConversation(String conversationId, Map update updateRequest.docAsUpsert(true); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - conversationMetaIndex.updateConversation(updateRequest, listener); + conversationMetaIndex.updateConversation(conversationId, updateRequest, listener); + } + + /** + * Update interaction in the index + * @param interactionId the interaction id that needs update + * @param updateContent original update content + * @param listener receives the update response for the wrapped query + */ + public void updateInteraction(String interactionId, Map updateContent, ActionListener listener) { + UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId); + + updateRequest.doc(updateContent); + updateRequest.docAsUpsert(true); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + interactionsIndex.updateInteraction(interactionId, updateRequest, listener); } /** diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java index f8476f9ccb..5fb4067acb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportActionTests.java @@ -6,9 +6,9 @@ package org.opensearch.ml.memory.action.conversation; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_NAME_FIELD; @@ -18,21 +18,26 @@ import java.time.Instant; import java.util.HashMap; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -65,20 +70,30 @@ public class UpdateConversationTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - ThreadContext threadContext; + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + @Mock + ClusterService clusterService; - private Settings settings; + ThreadContext threadContext; private ShardId shardId; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - settings = Settings.builder().build(); + this.clusterService = Mockito.mock(ClusterService.class); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); threadContext = new ThreadContext(settings); + this.threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + String conversationId = "test_conversation_id"; Map updateContent = new HashMap<>(); updateContent.put(META_NAME_FIELD, "new name"); @@ -87,27 +102,36 @@ public void setup() throws IOException { when(updateRequest.getUpdateContent()).thenReturn(updateContent); shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); - - transportUpdateConversationAction = new UpdateConversationTransportAction(transportService, actionFilters, client); + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + transportUpdateConversationAction = new UpdateConversationTransportAction( + transportService, + actionFilters, + client, + cmHandler, + clusterService + ); } public void test_execute_Success() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateConversation(any(), any(), any()); transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getResult().equals(DocWriteResponse.Result.UPDATED)); } public void test_execute_UpdateFailure() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Error in Update Request")); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateConversation(any(), any(), any()); transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); @@ -118,17 +142,17 @@ public void test_execute_UpdateFailure() { public void test_execute_UpdateWrongStatus() { UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateConversation(any(), any(), any()); transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); } public void test_execute_ThrowException() { - doThrow(new RuntimeException("Error in Update Request")).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + doThrow(new RuntimeException("Error in Update Request")).when(cmHandler).updateConversation(any(), any(), any()); transportUpdateConversationAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java index 3dbd16ca64..f81dbe0e94 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportActionTests.java @@ -16,21 +16,26 @@ import java.io.IOException; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -62,20 +67,29 @@ public class UpdateInteractionTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - ThreadContext threadContext; + @Mock + OpenSearchConversationalMemoryHandler cmHandler; - private Settings settings; + @Mock + ClusterService clusterService; + + ThreadContext threadContext; private ShardId shardId; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - settings = Settings.builder().build(); + this.clusterService = Mockito.mock(ClusterService.class); + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); threadContext = new ThreadContext(settings); + this.threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); String interactionId = "test_interaction_id"; Map updateContent = Map .of(INTERACTIONS_ADDITIONAL_INFO_FIELD, Map.of("feedback", "thumbs up!"), INTERACTIONS_RESPONSE_FIELD, "response"); @@ -84,15 +98,21 @@ public void setup() throws IOException { shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); - updateInteractionTransportAction = new UpdateInteractionTransportAction(transportService, actionFilters, client); + updateInteractionTransportAction = new UpdateInteractionTransportAction( + transportService, + actionFilters, + client, + cmHandler, + clusterService + ); } public void test_execute_Success() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateInteraction(any(String.class), any(Map.class), isA(ActionListener.class)); updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); @@ -100,10 +120,10 @@ public void test_execute_Success() { public void test_execute_UpdateFailure() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Error in Update Request")); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateInteraction(any(String.class), any(Map.class), isA(ActionListener.class)); updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); @@ -114,17 +134,19 @@ public void test_execute_UpdateFailure() { public void test_execute_UpdateWrongStatus() { UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(2); listener.onResponse(updateResponse); return null; - }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + }).when(cmHandler).updateInteraction(any(String.class), any(Map.class), isA(ActionListener.class)); updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); verify(actionListener).onResponse(updateResponse); } public void test_execute_ThrowException() { - doThrow(new RuntimeException("Error in Update Request")).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + doThrow(new RuntimeException("Error in Update Request")) + .when(cmHandler) + .updateInteraction(any(String.class), any(Map.class), isA(ActionListener.class)); updateInteractionTransportAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index e3841209c5..650ef8eab8 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -644,7 +644,7 @@ public void testUpdateConversation_NoIndex_ThenFail() { doReturn(false).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); - conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + conversationMetaIndex.updateConversation("tester_id", new UpdateRequest(), getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor @@ -654,6 +654,16 @@ public void testUpdateConversation_NoIndex_ThenFail() { } public void testUpdateConversation_Success() { + setupRefreshSuccess(); + final String id = "test_id"; + GetResponse dummyGetResponse = mock(GetResponse.class); + doReturn(true).when(dummyGetResponse).isExists(); + doReturn(id).when(dummyGetResponse).getId(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(dummyGetResponse); + return null; + }).when(client).get(any(), any()); doReturn(true).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); @@ -665,18 +675,28 @@ public void testUpdateConversation_Success() { listener.onResponse(updateResponse); return null; }).when(client).update(any(), any()); - conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + conversationMetaIndex.updateConversation("test_id", new UpdateRequest(), getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); verify(getListener, times(1)).onResponse(argCaptor.capture()); } public void testUpdateConversation_ClientFails() { + setupRefreshSuccess(); + final String id = "test_id"; + GetResponse dummyGetResponse = mock(GetResponse.class); + doReturn(true).when(dummyGetResponse).isExists(); + doReturn(id).when(dummyGetResponse).getId(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(dummyGetResponse); + return null; + }).when(client).get(any(), any()); doReturn(true).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener getListener = mock(ActionListener.class); doThrow(new RuntimeException("Client Failure")).when(client).update(any(), any()); - conversationMetaIndex.updateConversation(new UpdateRequest(), getListener); + conversationMetaIndex.updateConversation("test_id", new UpdateRequest(), getListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure")); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 41fdb1af41..3e523bc07c 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -480,7 +480,7 @@ public void testGetTraces() { doReturn(true).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener> getTracesListener = mock(ActionListener.class); - interactionsIndex.getTraces("cid", 0, 10, getTracesListener); + interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); @@ -491,7 +491,7 @@ public void testGetTraces_clientFail() { doReturn(true).when(metadata).hasIndex(anyString()); doThrow(new RuntimeException("Client Failure")).when(client).search(any(), any()); ActionListener> getTracesListener = mock(ActionListener.class); - interactionsIndex.getTraces("cid", 0, 10, getTracesListener); + interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(getTracesListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure")); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index 2c1c28c529..216435c9ec 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -179,12 +179,12 @@ public void testGetTraces() { public void testUpdateConversation() { doAnswer(invocation -> { - ActionListener al = invocation.getArgument(1); + ActionListener al = invocation.getArgument(2); ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); al.onResponse(updateResponse); return null; - }).when(conversationMetaIndex).updateConversation(any(), any()); + }).when(conversationMetaIndex).updateConversation(any(), any(), any()); ActionListener updateConversationListener = mock(ActionListener.class); cmHandler.updateConversation("cId", new HashMap<>(), updateConversationListener); From fc476acc0d3335e4eb833b85ecadd954c9ec3dc7 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Thu, 25 Jan 2024 21:37:45 -0800 Subject: [PATCH 2/2] add UT for acess denied cases Signed-off-by: Xun Zhang --- .../UpdateConversationTransportAction.java | 1 - .../index/ConversationMetaIndexTests.java | 15 ++ .../memory/index/InteractionsIndexTests.java | 216 +++++++++++++++--- 3 files changed, 196 insertions(+), 36 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java index 9f256fd075..4a5431e02c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java @@ -45,7 +45,6 @@ public UpdateConversationTransportAction( super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new); this.client = client; this.cmHandler = cmHandler; - System.out.println(clusterService.getSettings()); this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); clusterService .getClusterSettings() diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index 650ef8eab8..cf0d73fedb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -701,4 +701,19 @@ public void testUpdateConversation_ClientFails() { verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure")); } + + public void testUpdateConversation_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(false); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + + ActionListener updateListener = mock(ActionListener.class); + conversationMetaIndex.updateConversation("conversationId", new UpdateRequest(), updateListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(updateListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [BAD_USER] does not have access to conversation conversationId")); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 3e523bc07c..2ca66a1abe 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -32,12 +32,14 @@ import java.time.Instant; import java.util.Collections; import java.util.List; +import java.util.Map; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.opensearch.OpenSearchWrapperException; import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.refresh.RefreshResponse; import org.opensearch.action.bulk.BulkResponse; @@ -47,6 +49,8 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -59,6 +63,8 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -441,46 +447,34 @@ public void testGetTraces_NoIndex_ThenEmpty() { assert (argCaptor.getValue().size() == 0); } - public void testGetTraces() { - doAnswer(invocation -> { - XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); - content.startObject(); - content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now()); - content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); - content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); - content.endObject(); + public void testInnerGetTraces_success() { + setUpSearchTraceResponse(); + doReturn(true).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener> getTracesListener = mock(ActionListener.class); + interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 1); + } - SearchHit[] hits = new SearchHit[1]; - hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); - SearchHits searchHits = new SearchHits(hits, null, Float.NaN); - SearchResponseSections searchSections = new SearchResponseSections( - searchHits, - InternalAggregations.EMPTY, - null, - false, - false, - null, - 1 - ); - SearchResponse searchResponse = new SearchResponse( - searchSections, - null, - 1, - 1, - 0, - 11, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - ActionListener al = invocation.getArgument(1); - al.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); + public void testGetTraces_success() { + setupGrantAccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + setUpSearchTraceResponse(); doReturn(true).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener> getTracesListener = mock(ActionListener.class); - interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); + interactionsIndex.getTraces("iid", 0, 10, getTracesListener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); @@ -800,4 +794,156 @@ public void testGetSg_ClientFails_ThenFail() { verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure in Sg Get")); } + + public void testGetSg_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess("Henry"); + setupRefreshSuccess(); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [Henry] does not have access to interaction iid")); + } + + public void testGetSg_GrantAccess_Success() { + setupGrantAccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Interaction.class); + verify(getListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getId().equals("iid")); + assert (argCaptor.getValue().getConversationId().equals("conversation test 1")); + } + + public void testGetTraces_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + setupDenyAccess("Xun"); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + ActionListener> getListener = mock(ActionListener.class); + interactionsIndex.getTraces("iid", 0, 10, getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid")); + } + + public void testUpdateInteraction_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + setupDenyAccess("Xun"); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + ActionListener updateListener = mock(ActionListener.class); + interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(updateListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid")); + } + + public void testUpdateInteraction_Success() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + setupGrantAccess(); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(), any()); + + ActionListener updateListener = mock(ActionListener.class); + interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(updateListener, times(1)).onResponse(argCaptor.capture()); + } + + private GetResponse setUpInteractionResponse(String interactionId) { + @SuppressWarnings("unchecked") + GetResponse response = mock(GetResponse.class); + doReturn(true).when(response).isExists(); + doReturn(interactionId).when(response).getId(); + doReturn( + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, + "conversation test 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ).when(response).getSourceAsMap(); + return response; + } + + private void setUpSearchTraceResponse() { + doAnswer(invocation -> { + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now()); + content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); + content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); + content.endObject(); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); + SearchHits searchHits = new SearchHits(hits, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + searchHits, + InternalAggregations.EMPTY, + null, + false, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + ActionListener al = invocation.getArgument(1); + al.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + } }