Skip to content

Commit

Permalink
Use transport action
Browse files Browse the repository at this point in the history
Signed-off-by: Derek Ho <dxho@amazon.com>
  • Loading branch information
derek-ho committed Jan 6, 2025
1 parent 6418226 commit bc8aacf
Show file tree
Hide file tree
Showing 11 changed files with 427 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.security.action.apitokens.ApiTokenAction;
import org.opensearch.security.action.apitokens.ApiTokenIndexListenerCache;
import org.opensearch.security.action.apitokens.ApiTokenUpdateAction;
import org.opensearch.security.action.apitokens.TransportApiTokenUpdateAction;
import org.opensearch.security.action.configupdate.ConfigUpdateAction;
import org.opensearch.security.action.configupdate.TransportConfigUpdateAction;
import org.opensearch.security.action.onbehalf.CreateOnBehalfOfTokenAction;
Expand Down Expand Up @@ -686,6 +688,7 @@ public UnaryOperator<RestHandler> getRestHandlerWrapper(final ThreadContext thre
List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> actions = new ArrayList<>(1);
if (!disabled && !SSLConfig.isSslOnlyMode()) {
actions.add(new ActionHandler<>(ConfigUpdateAction.INSTANCE, TransportConfigUpdateAction.class));
actions.add(new ActionHandler<>(ApiTokenUpdateAction.INSTANCE, TransportApiTokenUpdateAction.class));
// external storage does not support reload and does not provide SSL certs info
if (!ExternalSecurityKeyStore.hasExternalSslContext(settings)) {
actions.add(new ActionHandler<>(CertificatesActionType.INSTANCE, TransportCertificatesInfoNodesAction.class));
Expand Down Expand Up @@ -719,14 +722,6 @@ public void onIndexModule(IndexModule indexModule) {
)
);

// TODO: Is there a higher level approach that makes more sense here? Does this cover unsuccessful index ops?
if (ConfigConstants.OPENSEARCH_API_TOKENS_INDEX.equals(indexModule.getIndex().getName())) {
ApiTokenIndexListenerCache apiTokenIndexListenerCacher = ApiTokenIndexListenerCache.getInstance();
apiTokenIndexListenerCacher.initialize();
indexModule.addIndexOperationListener(apiTokenIndexListenerCacher);
log.warn("Security plugin started listening to operations on index {}", ConfigConstants.OPENSEARCH_API_TOKENS_INDEX);
}

indexModule.forceQueryCacheProvider((indexSettings, nodeCache) -> new QueryCache() {

@Override
Expand Down Expand Up @@ -1105,6 +1100,7 @@ public Collection<Object> createComponents(
adminDns = new AdminDNs(settings);

cr = ConfigurationRepository.create(settings, this.configPath, threadPool, localClient, clusterService, auditLog);
ApiTokenIndexListenerCache.getInstance().initialize(clusterService, localClient);

this.passwordHasher = PasswordHasherFactory.createPasswordHasher(settings);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@
import java.util.stream.Collectors;

import com.google.common.collect.ImmutableList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.security.identity.SecurityTokenManager;
Expand All @@ -48,6 +52,7 @@

public class ApiTokenAction extends BaseRestHandler {
private final ApiTokenRepository apiTokenRepository;
public Logger log = LogManager.getLogger(this.getClass());

private static final List<RestHandler.Route> ROUTES = addRoutesPrefix(
ImmutableList.of(
Expand Down Expand Up @@ -133,20 +138,32 @@ private RestChannelConsumer handlePost(RestRequest request, NodeClient client) {
(Long) requestBody.getOrDefault(EXPIRATION_FIELD, Instant.now().toEpochMilli() + TimeUnit.DAYS.toMillis(30))
);

builder.startObject();
builder.field("Api Token: ", token);
builder.endObject();

response = new BytesRestResponse(RestStatus.OK, builder);
// Then trigger the update action
ApiTokenUpdateRequest updateRequest = new ApiTokenUpdateRequest();
client.execute(ApiTokenUpdateAction.INSTANCE, updateRequest, new ActionListener<ApiTokenUpdateResponse>() {
@Override
public void onResponse(ApiTokenUpdateResponse updateResponse) {
try {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("Api Token: ", token);
builder.endObject();

BytesRestResponse response = new BytesRestResponse(RestStatus.OK, builder);
channel.sendResponse(response);
} catch (IOException e) {
sendErrorResponse(channel, RestStatus.INTERNAL_SERVER_ERROR, "Failed to send response after token creation");
}
}

@Override
public void onFailure(Exception e) {
sendErrorResponse(channel, RestStatus.INTERNAL_SERVER_ERROR, "Failed to propagate token creation");
}
});
} catch (final Exception exception) {
builder.startObject()
.field("error", "An unexpected error occurred. Please check the input and try again.")
.field("message", exception.getMessage())
.endObject();
response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, builder);
sendErrorResponse(channel, RestStatus.INTERNAL_SERVER_ERROR, exception.getMessage());
}
builder.close();
channel.sendResponse(response);
};
}

Expand Down Expand Up @@ -239,22 +256,46 @@ private RestChannelConsumer handleDelete(RestRequest request, NodeClient client)
validateRequestParameters(requestBody);
apiTokenRepository.deleteApiToken((String) requestBody.get(NAME_FIELD));

builder.startObject();
builder.field("message", "token " + requestBody.get(NAME_FIELD) + " deleted successfully.");
builder.endObject();

response = new BytesRestResponse(RestStatus.OK, builder);
ApiTokenUpdateRequest updateRequest = new ApiTokenUpdateRequest();
client.execute(ApiTokenUpdateAction.INSTANCE, updateRequest, new ActionListener<ApiTokenUpdateResponse>() {
@Override
public void onResponse(ApiTokenUpdateResponse updateResponse) {
try {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("message", "token " + requestBody.get(NAME_FIELD) + " deleted successfully.");
builder.endObject();

BytesRestResponse response = new BytesRestResponse(RestStatus.OK, builder);
channel.sendResponse(response);
} catch (Exception e) {
sendErrorResponse(channel, RestStatus.INTERNAL_SERVER_ERROR, "Failed to send response after token update");
}
}

@Override
public void onFailure(Exception e) {
sendErrorResponse(channel, RestStatus.INTERNAL_SERVER_ERROR, "Failed to propagate token deletion");
}
});
} catch (final ApiTokenException exception) {
builder.startObject().field("error", exception.getMessage()).endObject();
response = new BytesRestResponse(RestStatus.NOT_FOUND, builder);
sendErrorResponse(channel, RestStatus.NOT_FOUND, exception.getMessage());
} catch (final Exception exception) {
builder.startObject().field("error", exception.getMessage()).endObject();
response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, builder);
sendErrorResponse(channel, RestStatus.INTERNAL_SERVER_ERROR, exception.getMessage());
}
builder.close();
channel.sendResponse(response);
};

}

private void sendErrorResponse(RestChannel channel, RestStatus status, String errorMessage) {
try {
XContentBuilder builder = channel.newBuilder();
builder.startObject().field("error", errorMessage).endObject();
BytesRestResponse response = new BytesRestResponse(status, builder);
channel.sendResponse(response);
} catch (Exception e) {
log.error("Failed to send error response", e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,105 +8,137 @@

package org.opensearch.security.action.apitokens;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.shard.IndexingOperationListener;

/**
* This class implements an index operation listener for operations performed on api tokens
* These indices are defined on bootstrap and configured to listen in OpenSearchSecurityPlugin.java
*/
public class ApiTokenIndexListenerCache implements IndexingOperationListener {
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterStateListener;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.security.support.ConfigConstants;

private final static Logger log = LogManager.getLogger(ApiTokenIndexListenerCache.class);
public class ApiTokenIndexListenerCache implements ClusterStateListener {

private static final Logger log = LogManager.getLogger(ApiTokenIndexListenerCache.class);
private static final ApiTokenIndexListenerCache INSTANCE = new ApiTokenIndexListenerCache();
private final ConcurrentHashMap<String, String> idToJtiMap = new ConcurrentHashMap<>();

private Map<String, Permissions> jtis = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, String> idToJtiMap = new ConcurrentHashMap<>();
private final Map<String, Permissions> jtis = new ConcurrentHashMap<>();

private boolean initialized;
private final AtomicBoolean initialized = new AtomicBoolean(false);
private ClusterService clusterService;
private Client client;

private ApiTokenIndexListenerCache() {}

public static ApiTokenIndexListenerCache getInstance() {
return ApiTokenIndexListenerCache.INSTANCE;
return INSTANCE;
}

public void initialize(ClusterService clusterService, Client client) {
if (initialized.compareAndSet(false, true)) {
this.clusterService = clusterService;
this.client = client;

// Register as cluster state listener
this.clusterService.addListener(this);
}
}

/**
* Initializes the ApiTokenIndexListenerCache.
* This method is called during the plugin's initialization process.
*
*/
public void initialize() {
@Override
public void clusterChanged(ClusterChangedEvent event) {
// Reload cache if the security index has changed
IndexMetadata securityIndex = event.state().metadata().index(getSecurityIndexName());
if (securityIndex != null) {
reloadApiTokensFromIndex();
}
}

if (initialized) {
void reloadApiTokensFromIndex() {
if (!initialized.get()) {
log.debug("Cache not yet initialized or client is null, skipping reload");
return;
}

initialized = true;
if (clusterService.state() != null && clusterService.state().blocks().hasGlobalBlockWithStatus(RestStatus.SERVICE_UNAVAILABLE)) {
log.debug("Cluster not yet ready, skipping API tokens cache reload");
return;
}

try {
// Clear existing caches
log.info("Reloading API tokens cache from index: {}", jtis.entrySet().toString());

idToJtiMap.clear();
jtis.clear();

// Search request to get all API tokens from the security index
client.prepareSearch(getSecurityIndexName())
.setQuery(QueryBuilders.matchAllQuery())
.execute()
.actionGet()
.getHits()
.forEach(hit -> {
// Parse the document and update the cache
Map<String, Object> source = hit.getSourceAsMap();
String id = hit.getId();
String jti = (String) source.get("jti");
Permissions permissions = parsePermissions(source);

idToJtiMap.put(id, jti);
jtis.put(jti, permissions);
});

log.debug("Successfully reloaded API tokens cache");
} catch (Exception e) {
log.error("Failed to reload API tokens cache", e);
}
}

public boolean isInitialized() {
return initialized;
private String getSecurityIndexName() {
// Return the name of your security index
return ConfigConstants.OPENSEARCH_API_TOKENS_INDEX;
}

/**
* This method is called after an index operation is performed.
* It adds the JTI of the indexed document to the cache and maps the document ID to the JTI (for deletion handling).
* @param shardId The shard ID of the index where the operation was performed.
* @param index The index where the operation was performed.
* @param result The result of the index operation.
*/
@Override
public void postIndex(ShardId shardId, Engine.Index index, Engine.IndexResult result) {
BytesReference sourceRef = index.source();

try {
XContentParser parser = XContentType.JSON.xContent()
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, sourceRef.streamInput());
@SuppressWarnings("unchecked")
private Permissions parsePermissions(Map<String, Object> source) {
// Implement parsing logic for permissions from the document
return new Permissions(
(List<String>) source.get(ApiToken.CLUSTER_PERMISSIONS_FIELD),
(List<ApiToken.IndexPermission>) source.get(ApiToken.INDEX_PERMISSIONS_FIELD)
);
}

ApiToken token = ApiToken.fromXContent(parser);
jtis.put(token.getJti(), new Permissions(token.getClusterPermissions(), token.getIndexPermissions()));
idToJtiMap.put(index.id(), token.getJti());
// Getter methods for cached data
public String getJtiForId(String id) {
return idToJtiMap.get(id);
}

} catch (IOException e) {
log.error("Failed to parse indexed document", e);
}
public Permissions getPermissionsForJti(String jti) {
return jtis.get(jti);
}

/**
* This method is called after a delete operation is performed.
* It deletes the corresponding document id in the map and the corresponding jti from the cache.
* @param shardId The shard ID of the index where the delete operation was performed.
* @param delete The delete operation that was performed.
* @param result The result of the delete operation.
*/
@Override
public void postDelete(ShardId shardId, Engine.Delete delete, Engine.DeleteResult result) {
String docId = delete.id();
String jti = idToJtiMap.remove(docId);
if (jti != null) {
jtis.remove(jti);
log.debug("Removed token with ID {} and JTI {} from cache", docId, jti);
}
// Method to check if a token is valid
public boolean isValidToken(String jti) {
return jtis.containsKey(jti);
}

public Map<String, Permissions> getJtis() {
return jtis;
}

// Cleanup method
public void close() {
if (clusterService != null) {
clusterService.removeListener(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.security.action.apitokens;

import org.opensearch.action.ActionType;

public class ApiTokenUpdateAction extends ActionType<ApiTokenUpdateResponse> {

public static final ApiTokenUpdateAction INSTANCE = new ApiTokenUpdateAction();
public static final String NAME = "cluster:admin/opendistro_security/apitoken/update";

protected ApiTokenUpdateAction() {
super(NAME, ApiTokenUpdateResponse::new);
}
}
Loading

0 comments on commit bc8aacf

Please sign in to comment.