Skip to content

Commit

Permalink
add more user based permission check in Memory
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt committed Jan 26, 2024
1 parent 849fecf commit cbe2ca5
Show file tree
Hide file tree
Showing 11 changed files with 364 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,20 @@ public void createInteraction(

/**
* Update a conversation
* @param conversationId the conversation id to update
* @param updateContent update content for the conversations index
* @param listener receives the update response
*/
public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);

/**
* Update an interaction
* @param interactionId the interaction id to update
* @param updateContent update content for the interaction index
* @param listener receives the update response
*/
public void updateInteraction(String interactionId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);

/**
* Get a single ConversationMeta object
* @param conversationId id of the conversation to get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
import java.time.Instant;
import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -28,29 +30,51 @@
@Log4j2
public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

@Inject
public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public UpdateConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
OpenSearchConversationalMemoryHandler cmHandler,
ClusterService clusterService
) {
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
System.out.println(clusterService.getSettings());
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
if (!featureIsEnabled) {
listener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
} else {
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());

updateRequest.doc(updateContent);
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context));
} catch (Exception e) {
log.error("Failed to update Conversation for conversation id" + conversationId, e);
listener.onFailure(e);
cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context));
} catch (Exception e) {
log.error("Failed to update Conversation " + conversationId, e);
listener.onFailure(e);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@

package org.opensearch.ml.memory.action.conversation;

import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -25,26 +29,47 @@
@Log4j2
public class UpdateInteractionTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

@Inject
public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public UpdateInteractionTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
OpenSearchConversationalMemoryHandler cmHandler,
ClusterService clusterService
) {
super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
if (!featureIsEnabled) {
listener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
}
UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request);
String interactionId = updateInteractionRequest.getInteractionId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId);
updateRequest.doc(updateInteractionRequest.getUpdateContent());
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Map<String, Object> updateContent = updateInteractionRequest.getUpdateContent();

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context));
cmHandler.updateInteraction(interactionId, updateContent, getUpdateResponseListener(interactionId, listener, context));
} catch (Exception e) {
log.error("Failed to update Interaction for interaction id " + interactionId, e);
log.error("Failed to update Interaction " + interactionId, e);
listener.onFailure(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,18 +360,35 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
}

/**
* Update conversations in the index
* Update conversation in the index
* @param conversationId the conversation id that needs update
* @param updateRequest original update request
* @param listener receives the update response for the wrapped query
*/
public void updateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
public void updateConversation(String conversationId, UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener
.onFailure(
new IndexNotFoundException("cannot update conversation since the conversation index does not exist", META_INDEX_NAME)
);
return;
}

this.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
innerUpdateConversation(updateRequest, listener);
} else {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
}
}, e -> { listener.onFailure(e); }));
}

private void innerUpdateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<UpdateResponse> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
client.update(updateRequest, internalListener);
Expand Down
Loading

0 comments on commit cbe2ca5

Please sign in to comment.