Skip to content

Commit

Permalink
add UT for acess denied cases
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt committed Jan 26, 2024
1 parent cbe2ca5 commit fc476ac
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ public UpdateConversationTransportAction(
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
System.out.println(clusterService.getSettings());
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,4 +701,19 @@ public void testUpdateConversation_ClientFails() {
verify(getListener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().equals("Client Failure"));
}

public void testUpdateConversation_NoAccess_ThenFail() {
doReturn(true).when(metadata).hasIndex(anyString());
doAnswer(invocation -> {
ActionListener<Boolean> al = invocation.getArgument(1);
al.onResponse(false);
return null;
}).when(conversationMetaIndex).checkAccess(anyString(), any());

ActionListener<UpdateResponse> updateListener = mock(ActionListener.class);
conversationMetaIndex.updateConversation("conversationId", new UpdateRequest(), updateListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(updateListener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().equals("User [BAD_USER] does not have access to conversation conversationId"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.opensearch.OpenSearchWrapperException;
import org.opensearch.ResourceAlreadyExistsException;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.admin.indices.refresh.RefreshResponse;
import org.opensearch.action.bulk.BulkResponse;
Expand All @@ -47,6 +49,8 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.IndicesAdminClient;
Expand All @@ -59,6 +63,8 @@
import org.opensearch.commons.ConfigConstants;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.MatchAllQueryBuilder;
Expand Down Expand Up @@ -441,46 +447,34 @@ public void testGetTraces_NoIndex_ThenEmpty() {
assert (argCaptor.getValue().size() == 0);
}

public void testGetTraces() {
doAnswer(invocation -> {
XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent());
content.startObject();
content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now());
content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs");
content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id");
content.endObject();
public void testInnerGetTraces_success() {
setUpSearchTraceResponse();
doReturn(true).when(metadata).hasIndex(anyString());
@SuppressWarnings("unchecked")
ActionListener<List<Interaction>> getTracesListener = mock(ActionListener.class);
interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener);
@SuppressWarnings("unchecked")
ArgumentCaptor<List<Interaction>> argCaptor = ArgumentCaptor.forClass(List.class);
verify(getTracesListener, times(1)).onResponse(argCaptor.capture());
assert (argCaptor.getValue().size() == 1);
}

SearchHit[] hits = new SearchHit[1];
hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content));
SearchHits searchHits = new SearchHits(hits, null, Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(
searchHits,
InternalAggregations.EMPTY,
null,
false,
false,
null,
1
);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
ActionListener<SearchResponse> al = invocation.getArgument(1);
al.onResponse(searchResponse);
return null;
}).when(client).search(any(), any());
public void testGetTraces_success() {
setupGrantAccess();
doReturn(true).when(metadata).hasIndex(anyString());
setupRefreshSuccess();

GetResponse response = setUpInteractionResponse("iid");
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());
setUpSearchTraceResponse();
doReturn(true).when(metadata).hasIndex(anyString());
@SuppressWarnings("unchecked")
ActionListener<List<Interaction>> getTracesListener = mock(ActionListener.class);
interactionsIndex.innerGetTraces("cid", 0, 10, getTracesListener);
interactionsIndex.getTraces("iid", 0, 10, getTracesListener);
@SuppressWarnings("unchecked")
ArgumentCaptor<List<Interaction>> argCaptor = ArgumentCaptor.forClass(List.class);
verify(getTracesListener, times(1)).onResponse(argCaptor.capture());
Expand Down Expand Up @@ -800,4 +794,156 @@ public void testGetSg_ClientFails_ThenFail() {
verify(getListener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().equals("Client Failure in Sg Get"));
}

public void testGetSg_NoAccess_ThenFail() {
doReturn(true).when(metadata).hasIndex(anyString());
setupDenyAccess("Henry");
setupRefreshSuccess();
GetResponse response = setUpInteractionResponse("iid");
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());
ActionListener<Interaction> getListener = mock(ActionListener.class);
interactionsIndex.getInteraction("iid", getListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(getListener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().equals("User [Henry] does not have access to interaction iid"));
}

public void testGetSg_GrantAccess_Success() {
setupGrantAccess();
doReturn(true).when(metadata).hasIndex(anyString());
setupRefreshSuccess();
GetResponse response = setUpInteractionResponse("iid");
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());
ActionListener<Interaction> getListener = mock(ActionListener.class);
interactionsIndex.getInteraction("iid", getListener);
ArgumentCaptor<Interaction> argCaptor = ArgumentCaptor.forClass(Interaction.class);
verify(getListener, times(1)).onResponse(argCaptor.capture());
assert (argCaptor.getValue().getId().equals("iid"));
assert (argCaptor.getValue().getConversationId().equals("conversation test 1"));
}

public void testGetTraces_NoAccess_ThenFail() {
doReturn(true).when(metadata).hasIndex(anyString());
setupRefreshSuccess();
setupDenyAccess("Xun");
GetResponse response = setUpInteractionResponse("iid");
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());

ActionListener<List<Interaction>> getListener = mock(ActionListener.class);
interactionsIndex.getTraces("iid", 0, 10, getListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(getListener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid"));
}

public void testUpdateInteraction_NoAccess_ThenFail() {
doReturn(true).when(metadata).hasIndex(anyString());
setupRefreshSuccess();
setupDenyAccess("Xun");
GetResponse response = setUpInteractionResponse("iid");
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());

ActionListener<UpdateResponse> updateListener = mock(ActionListener.class);
interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(updateListener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().equals("User [Xun] does not have access to interaction iid"));
}

public void testUpdateInteraction_Success() {
doReturn(true).when(metadata).hasIndex(anyString());
setupRefreshSuccess();
setupGrantAccess();
GetResponse response = setUpInteractionResponse("iid");
doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());

doAnswer(invocation -> {
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED);
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(), any());

ActionListener<UpdateResponse> updateListener = mock(ActionListener.class);
interactionsIndex.updateInteraction("iid", new UpdateRequest(), updateListener);
ArgumentCaptor<UpdateResponse> argCaptor = ArgumentCaptor.forClass(UpdateResponse.class);
verify(updateListener, times(1)).onResponse(argCaptor.capture());
}

private GetResponse setUpInteractionResponse(String interactionId) {
@SuppressWarnings("unchecked")
GetResponse response = mock(GetResponse.class);
doReturn(true).when(response).isExists();
doReturn(interactionId).when(response).getId();
doReturn(
Map
.of(
ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD,
Instant.now().toString(),
ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD,
"conversation test 1",
ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD,
"answer1"
)
).when(response).getSourceAsMap();
return response;
}

private void setUpSearchTraceResponse() {
doAnswer(invocation -> {
XContentBuilder content = XContentBuilder.builder(XContentType.JSON.xContent());
content.startObject();
content.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now());
content.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "sample inputs");
content.field(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "conversation-id");
content.endObject();

SearchHit[] hits = new SearchHit[1];
hits[0] = new SearchHit(0, "iId", null, null).sourceRef(BytesReference.bytes(content));
SearchHits searchHits = new SearchHits(hits, null, Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(
searchHits,
InternalAggregations.EMPTY,
null,
false,
false,
null,
1
);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
ActionListener<SearchResponse> al = invocation.getArgument(1);
al.onResponse(searchResponse);
return null;
}).when(client).search(any(), any());
}
}

0 comments on commit fc476ac

Please sign in to comment.