From fc476acc0d3335e4eb833b85ecadd954c9ec3dc7 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Thu, 25 Jan 2024 21:37:45 -0800 Subject: [PATCH] add UT for acess denied cases Signed-off-by: Xun Zhang --- .../UpdateConversationTransportAction.java | 1 - .../index/ConversationMetaIndexTests.java | 15 ++ .../memory/index/InteractionsIndexTests.java | 216 +++++++++++++++--- 3 files changed, 196 insertions(+), 36 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java index 9f256fd075..4a5431e02c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java @@ -45,7 +45,6 @@ public UpdateConversationTransportAction( super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new); this.client = client; this.cmHandler = cmHandler; - System.out.println(clusterService.getSettings()); this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); clusterService .getClusterSettings() diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index 650ef8eab8..cf0d73fedb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -701,4 +701,19 @@ public void testUpdateConversation_ClientFails() { verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure")); } + + public void testUpdateConversation_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(false); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + + ActionListener updateListener = mock(ActionListener.class); + conversationMetaIndex.updateConversation("conversationId", new UpdateRequest(), updateListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(updateListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [BAD_USER] does not have access to conversation conversationId")); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 3e523bc07c..2ca66a1abe 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -32,12 +32,14 @@ import java.time.Instant; import java.util.Collections; import java.util.List; +import java.util.Map; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.opensearch.OpenSearchWrapperException; import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.refresh.RefreshResponse; import org.opensearch.action.bulk.BulkResponse; @@ -47,6 +49,8 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; @@ -59,6 +63,8 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -441,46 +447,34 @@ public void testGetTraces_NoIndex_ThenEmpty() { assert (argCaptor.getValue().size() == 0); } - public void testGetTraces() { - doAnswer(invocation -> { - XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); - content.startObject(); - content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now()); - content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); - content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); - content.endObject(); + public void testInnerGetTraces_success() { + setUpSearchTraceResponse(); + doReturn(true).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener> getTracesListener = mock(ActionListener.class); + interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 1); + } - SearchHit[] hits = new SearchHit[1]; - hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); - SearchHits searchHits = new SearchHits(hits, null, Float.NaN); - SearchResponseSections searchSections = new SearchResponseSections( - searchHits, - InternalAggregations.EMPTY, - null, - false, - false, - null, - 1 - ); - SearchResponse searchResponse = new SearchResponse( - searchSections, - null, - 1, - 1, - 0, - 11, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - ActionListener al = invocation.getArgument(1); - al.onResponse(searchResponse); - return null; - }).when(client).search(any(), any()); + public void testGetTraces_success() { + setupGrantAccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + setUpSearchTraceResponse(); doReturn(true).when(metadata).hasIndex(anyString()); @SuppressWarnings("unchecked") ActionListener> getTracesListener = mock(ActionListener.class); - interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener); + interactionsIndex.getTraces("iid", 0, 10, getTracesListener); @SuppressWarnings("unchecked") ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); verify(getTracesListener, times(1)).onResponse(argCaptor.capture()); @@ -800,4 +794,156 @@ public void testGetSg_ClientFails_ThenFail() { verify(getListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Failure in Sg Get")); } + + public void testGetSg_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess("Henry"); + setupRefreshSuccess(); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [Henry] does not have access to interaction iid")); + } + + public void testGetSg_GrantAccess_Success() { + setupGrantAccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Interaction.class); + verify(getListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getId().equals("iid")); + assert (argCaptor.getValue().getConversationId().equals("conversation test 1")); + } + + public void testGetTraces_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + setupDenyAccess("Xun"); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + ActionListener> getListener = mock(ActionListener.class); + interactionsIndex.getTraces("iid", 0, 10, getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid")); + } + + public void testUpdateInteraction_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + setupDenyAccess("Xun"); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + ActionListener updateListener = mock(ActionListener.class); + interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(updateListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid")); + } + + public void testUpdateInteraction_Success() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + setupGrantAccess(); + GetResponse response = setUpInteractionResponse("iid"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(), any()); + + ActionListener updateListener = mock(ActionListener.class); + interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(UpdateResponse.class); + verify(updateListener, times(1)).onResponse(argCaptor.capture()); + } + + private GetResponse setUpInteractionResponse(String interactionId) { + @SuppressWarnings("unchecked") + GetResponse response = mock(GetResponse.class); + doReturn(true).when(response).isExists(); + doReturn(interactionId).when(response).getId(); + doReturn( + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, + "conversation test 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ).when(response).getSourceAsMap(); + return response; + } + + private void setUpSearchTraceResponse() { + doAnswer(invocation -> { + XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent()); + content.startObject(); + content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now()); + content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs"); + content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id"); + content.endObject(); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content)); + SearchHits searchHits = new SearchHits(hits, null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + searchHits, + InternalAggregations.EMPTY, + null, + false, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + ActionListener al = invocation.getArgument(1); + al.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + } }