Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more user based permission check in Memory #1927

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,20 @@ public void createInteraction(

/**
* Update a conversation
* @param conversationId the conversation id to update
* @param updateContent update content for the conversations index
* @param listener receives the update response
*/
public void updateConversation(String conversationId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);

/**
* Update an interaction
* @param interactionId the interaction id to update
* @param updateContent update content for the interaction index
* @param listener receives the update response
*/
public void updateInteraction(String interactionId, Map<String, Object> updateContent, ActionListener<UpdateResponse> listener);

/**
* Get a single ConversationMeta object
* @param conversationId id of the conversation to get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,20 @@
import java.time.Instant;
import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -28,29 +30,50 @@
@Log4j2
public class UpdateConversationTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

@Inject
public UpdateConversationTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public UpdateConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
OpenSearchConversationalMemoryHandler cmHandler,
ClusterService clusterService
) {
super(UpdateConversationAction.NAME, transportService, actionFilters, UpdateConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.META_INDEX_NAME, conversationId);
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());
if (!featureIsEnabled) {
listener
.onFailure(

Check warning on line 58 in memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java

View check run for this annotation

Codecov / codecov/patch

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java#L57-L58

Added lines #L57 - L58 were not covered by tests
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()

Check warning on line 61 in memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java

View check run for this annotation

Codecov / codecov/patch

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java#L61

Added line #L61 was not covered by tests
)
);
return;

Check warning on line 64 in memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java

View check run for this annotation

Codecov / codecov/patch

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateConversationTransportAction.java#L64

Added line #L64 was not covered by tests
} else {
UpdateConversationRequest updateConversationRequest = UpdateConversationRequest.fromActionRequest(request);
String conversationId = updateConversationRequest.getConversationId();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Map<String, Object> updateContent = updateConversationRequest.getUpdateContent();
updateContent.putIfAbsent(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, Instant.now());

updateRequest.doc(updateContent);
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(conversationId, listener, context));
} catch (Exception e) {
log.error("Failed to update Conversation for conversation id" + conversationId, e);
listener.onFailure(e);
cmHandler.updateConversation(conversationId, updateContent, getUpdateResponseListener(conversationId, listener, context));
} catch (Exception e) {
log.error("Failed to update Conversation " + conversationId, e);
listener.onFailure(e);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@

package org.opensearch.ml.memory.action.conversation;

import java.util.Map;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -25,26 +29,47 @@
@Log4j2
public class UpdateInteractionTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

@Inject
public UpdateInteractionTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public UpdateInteractionTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
OpenSearchConversationalMemoryHandler cmHandler,
ClusterService clusterService
) {
super(UpdateInteractionAction.NAME, transportService, actionFilters, UpdateInteractionRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> listener) {
if (!featureIsEnabled) {
listener
.onFailure(

Check warning on line 57 in memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java

View check run for this annotation

Codecov / codecov/patch

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java#L56-L57

Added lines #L56 - L57 were not covered by tests
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()

Check warning on line 60 in memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java

View check run for this annotation

Codecov / codecov/patch

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java#L60

Added line #L60 was not covered by tests
)
);
Zhangxunmt marked this conversation as resolved.
Show resolved Hide resolved
return;

Check warning on line 63 in memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java

View check run for this annotation

Codecov / codecov/patch

memory/src/main/java/org/opensearch/ml/memory/action/conversation/UpdateInteractionTransportAction.java#L63

Added line #L63 was not covered by tests
}
UpdateInteractionRequest updateInteractionRequest = UpdateInteractionRequest.fromActionRequest(request);
String interactionId = updateInteractionRequest.getInteractionId();
UpdateRequest updateRequest = new UpdateRequest(ConversationalIndexConstants.INTERACTIONS_INDEX_NAME, interactionId);
updateRequest.doc(updateInteractionRequest.getUpdateContent());
updateRequest.docAsUpsert(true);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
Map<String, Object> updateContent = updateInteractionRequest.getUpdateContent();

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, getUpdateResponseListener(interactionId, listener, context));
cmHandler.updateInteraction(interactionId, updateContent, getUpdateResponseListener(interactionId, listener, context));
} catch (Exception e) {
log.error("Failed to update Interaction for interaction id " + interactionId, e);
log.error("Failed to update Interaction " + interactionId, e);
listener.onFailure(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,18 +360,35 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
}

/**
* Update conversations in the index
* Update conversation in the index
* @param conversationId the conversation id that needs update
* @param updateRequest original update request
* @param listener receives the update response for the wrapped query
*/
public void updateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
public void updateConversation(String conversationId, UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener
.onFailure(
new IndexNotFoundException("cannot update conversation since the conversation index does not exist", META_INDEX_NAME)
);
return;
}

this.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
innerUpdateConversation(updateRequest, listener);
} else {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
}
}, e -> { listener.onFailure(e); }));
}

private void innerUpdateConversation(UpdateRequest updateRequest, ActionListener<UpdateResponse> listener) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<UpdateResponse> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
client.update(updateRequest, internalListener);
Expand Down
Loading
Loading