Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(remote STS): lazy fetch the secret from the vault before request #4413

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
import org.eclipse.edc.spi.system.ServiceExtension;
import org.eclipse.edc.spi.system.ServiceExtensionContext;

import java.util.Objects;

import static java.lang.String.format;

/**
* Configuration Extension for the STS OAuth2 client
*/
Expand Down Expand Up @@ -58,10 +54,8 @@ public StsRemoteClientConfiguration clientConfiguration(ServiceExtensionContext
var tokenUrl = context.getConfig().getString(TOKEN_URL);
var clientId = context.getConfig().getString(CLIENT_ID);
var clientSecretAlias = context.getConfig().getString(CLIENT_SECRET_ALIAS);
var clientSecret = vault.resolveSecret(clientSecretAlias);
Objects.requireNonNull(clientSecret, format("Client secret could not be retrieved from the vault with alias %s", clientSecretAlias));

return new StsRemoteClientConfiguration(tokenUrl, clientId, clientSecret);
return new StsRemoteClientConfiguration(tokenUrl, clientId, clientSecretAlias);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.eclipse.edc.runtime.metamodel.annotation.Extension;
import org.eclipse.edc.runtime.metamodel.annotation.Inject;
import org.eclipse.edc.runtime.metamodel.annotation.Provider;
import org.eclipse.edc.spi.security.Vault;
import org.eclipse.edc.spi.system.ServiceExtension;

/**
Expand All @@ -37,13 +38,16 @@ public class StsRemoteClientExtension implements ServiceExtension {
@Inject
private Oauth2Client oauth2Client;

@Inject
private Vault vault;

@Override
public String name() {
return NAME;
}

@Provider
public SecureTokenService secureTokenService() {
return new RemoteSecureTokenService(oauth2Client, clientConfiguration);
return new RemoteSecureTokenService(oauth2Client, clientConfiguration, vault);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.eclipse.edc.iam.identitytrust.sts.remote.client.StsRemoteClientConfigurationExtension.CLIENT_ID;
import static org.eclipse.edc.iam.identitytrust.sts.remote.client.StsRemoteClientConfigurationExtension.CLIENT_SECRET_ALIAS;
import static org.eclipse.edc.iam.identitytrust.sts.remote.client.StsRemoteClientConfigurationExtension.TOKEN_URL;
Expand All @@ -41,14 +40,12 @@ void setup(ServiceExtensionContext context) {
}

@Test
void initialize(StsRemoteClientConfigurationExtension extension, ServiceExtensionContext context, Vault vault) {
void initialize(StsRemoteClientConfigurationExtension extension, ServiceExtensionContext context) {

var tokenUrl = "http://tokenUrl";
var clientId = "clientId";
var secretAlias = "secretAlias";

when(vault.resolveSecret(secretAlias)).thenReturn(secretAlias);

var configMap = Map.of(TOKEN_URL, tokenUrl, CLIENT_ID, clientId, CLIENT_SECRET_ALIAS, secretAlias);
var config = ConfigFactory.fromMap(configMap);

Expand All @@ -59,22 +56,8 @@ void initialize(StsRemoteClientConfigurationExtension extension, ServiceExtensio
.satisfies(configuration -> {
assertThat(configuration.tokenUrl()).isEqualTo(tokenUrl);
assertThat(configuration.clientId()).isEqualTo(clientId);
assertThat(configuration.clientSecret()).isEqualTo(secretAlias);
assertThat(configuration.clientSecretAlias()).isEqualTo(secretAlias);
});
}

@Test
void initialize_fail_withVaultSecretResolutionError(StsRemoteClientConfigurationExtension extension, ServiceExtensionContext context, Vault vault) {

var tokenUrl = "http://tokenUrl";
var clientId = "clientId";
var secretAlias = "secretAlias";

var configMap = Map.of(TOKEN_URL, tokenUrl, CLIENT_ID, clientId, CLIENT_SECRET_ALIAS, secretAlias);
var config = ConfigFactory.fromMap(configMap);

when(context.getConfig()).thenReturn(config);

assertThatThrownBy(() -> extension.clientConfiguration(context)).isInstanceOf(NullPointerException.class);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.eclipse.edc.iam.oauth2.spi.client.SharedSecretOauth2CredentialsRequest;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.security.Vault;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

Expand All @@ -41,35 +42,46 @@ public class RemoteSecureTokenService implements SecureTokenService {

private final Oauth2Client oauth2Client;
private final StsRemoteClientConfiguration configuration;
private final Vault vault;

public RemoteSecureTokenService(Oauth2Client oauth2Client, StsRemoteClientConfiguration configuration) {
public RemoteSecureTokenService(Oauth2Client oauth2Client, StsRemoteClientConfiguration configuration, Vault vault) {
this.oauth2Client = oauth2Client;
this.configuration = configuration;
this.vault = vault;
}

@Override
public Result<TokenRepresentation> createToken(Map<String, String> claims, @Nullable String bearerAccessScope) {
return oauth2Client.requestToken(createRequest(claims, bearerAccessScope));
return createRequest(claims, bearerAccessScope)
.compose(oauth2Client::requestToken);
}

@NotNull
private Oauth2CredentialsRequest createRequest(Map<String, String> claims, @Nullable String bearerAccessScope) {
var builder = SharedSecretOauth2CredentialsRequest.Builder.newInstance()
.url(configuration.tokenUrl())
.clientId(configuration.clientId())
.clientSecret(configuration.clientSecret())
.grantType(GRANT_TYPE);

var additionalParams = claims.entrySet().stream()
.filter(entry -> CLAIM_MAPPING.containsKey(entry.getKey()))
.map(entry -> Map.entry(CLAIM_MAPPING.get(entry.getKey()), entry.getValue()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

if (bearerAccessScope != null) {
additionalParams.put(BEARER_ACCESS_SCOPE, bearerAccessScope);
private Result<Oauth2CredentialsRequest> createRequest(Map<String, String> claims, @Nullable String bearerAccessScope) {

var secret = vault.resolveSecret(configuration.clientSecretAlias());
if (secret != null) {
var builder = SharedSecretOauth2CredentialsRequest.Builder.newInstance()
.url(configuration.tokenUrl())
.clientId(configuration.clientId())
.clientSecret(secret)
.grantType(GRANT_TYPE);

var additionalParams = claims.entrySet().stream()
.filter(entry -> CLAIM_MAPPING.containsKey(entry.getKey()))
.map(entry -> Map.entry(CLAIM_MAPPING.get(entry.getKey()), entry.getValue()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

if (bearerAccessScope != null) {
additionalParams.put(BEARER_ACCESS_SCOPE, bearerAccessScope);
}

builder.params(additionalParams);
return Result.success(builder.build());
} else {
return Result.failure("Failed to fetch client secret from the vault with alias: %s".formatted(configuration.clientSecretAlias()));
}

builder.params(additionalParams);
return builder.build();

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@
/**
* Configuration of the OAuth2 client
*/
public record StsRemoteClientConfiguration(String tokenUrl, String clientId, String clientSecret) {
public record StsRemoteClientConfiguration(String tokenUrl, String clientId, String clientSecretAlias) {

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
import org.eclipse.edc.iam.oauth2.spi.client.SharedSecretOauth2CredentialsRequest;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.security.Vault;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import java.util.Map;

import static java.lang.String.format;
import static org.assertj.core.api.Assertions.assertThat;
import static org.eclipse.edc.iam.identitytrust.spi.SelfIssuedTokenConstants.BEARER_ACCESS_SCOPE;
import static org.eclipse.edc.iam.identitytrust.spi.SelfIssuedTokenConstants.PRESENTATION_TOKEN_CLAIM;
Expand All @@ -38,19 +40,23 @@

public class RemoteSecureTokenServiceTest {

private final StsRemoteClientConfiguration configuration = new StsRemoteClientConfiguration("id", "secret", "url");
private final StsRemoteClientConfiguration configuration = new StsRemoteClientConfiguration("url", "id", "secretAlias");
private final Oauth2Client oauth2Client = mock();
private final Vault vault = mock();
private RemoteSecureTokenService secureTokenService;

@BeforeEach
void setup() {
secureTokenService = new RemoteSecureTokenService(oauth2Client, configuration);
secureTokenService = new RemoteSecureTokenService(oauth2Client, configuration, vault);
}

@Test
void createToken() {
var audience = "aud";
var secret = "secret";
when(oauth2Client.requestToken(any())).thenReturn(Result.success(TokenRepresentation.Builder.newInstance().build()));
when(vault.resolveSecret(configuration.clientSecretAlias())).thenReturn(secret);

assertThat(secureTokenService.createToken(Map.of(AUDIENCE, audience), null)).isSucceeded();

var captor = ArgumentCaptor.forClass(SharedSecretOauth2CredentialsRequest.class);
Expand All @@ -60,7 +66,7 @@ void createToken() {
assertThat(request.getUrl()).isEqualTo(configuration.tokenUrl());
assertThat(request.getClientId()).isEqualTo(configuration.clientId());
assertThat(request.getGrantType()).isEqualTo(GRANT_TYPE);
assertThat(request.getClientSecret()).isEqualTo(configuration.clientSecret());
assertThat(request.getClientSecret()).isEqualTo(secret);
assertThat(request.getParams())
.containsEntry(AUDIENCE_PARAM, audience);
});
Expand All @@ -70,7 +76,11 @@ void createToken() {
void createToken_withAccessScope() {
var audience = "aud";
var bearerAccessScope = "scope";
var secret = "secret";

when(oauth2Client.requestToken(any())).thenReturn(Result.success(TokenRepresentation.Builder.newInstance().build()));
when(vault.resolveSecret(configuration.clientSecretAlias())).thenReturn(secret);

assertThat(secureTokenService.createToken(Map.of(AUDIENCE, audience), bearerAccessScope)).isSucceeded();

var captor = ArgumentCaptor.forClass(SharedSecretOauth2CredentialsRequest.class);
Expand All @@ -80,7 +90,7 @@ void createToken_withAccessScope() {
assertThat(request.getUrl()).isEqualTo(configuration.tokenUrl());
assertThat(request.getClientId()).isEqualTo(configuration.clientId());
assertThat(request.getGrantType()).isEqualTo(GRANT_TYPE);
assertThat(request.getClientSecret()).isEqualTo(configuration.clientSecret());
assertThat(request.getClientSecret()).isEqualTo(secret);
assertThat(request.getParams())
.containsEntry(AUDIENCE_PARAM, audience)
.containsEntry(BEARER_ACCESS_SCOPE, bearerAccessScope);
Expand All @@ -91,7 +101,11 @@ void createToken_withAccessScope() {
void createToken_withAccessToken() {
var audience = "aud";
var accessToken = "accessToken";
var secret = "secret";

when(oauth2Client.requestToken(any())).thenReturn(Result.success(TokenRepresentation.Builder.newInstance().build()));
when(vault.resolveSecret(configuration.clientSecretAlias())).thenReturn(secret);

assertThat(secureTokenService.createToken(Map.of(AUDIENCE, audience, PRESENTATION_TOKEN_CLAIM, accessToken), null)).isSucceeded();

var captor = ArgumentCaptor.forClass(SharedSecretOauth2CredentialsRequest.class);
Expand All @@ -101,11 +115,25 @@ void createToken_withAccessToken() {
assertThat(request.getUrl()).isEqualTo(configuration.tokenUrl());
assertThat(request.getClientId()).isEqualTo(configuration.clientId());
assertThat(request.getGrantType()).isEqualTo(GRANT_TYPE);
assertThat(request.getClientSecret()).isEqualTo(configuration.clientSecret());
assertThat(request.getClientSecret()).isEqualTo(secret);
assertThat(request.getParams())
.containsEntry(AUDIENCE_PARAM, audience)
.containsEntry(PRESENTATION_TOKEN_CLAIM, accessToken);
});
}

@Test
void createToken_shouldFail_whenSecretIsNotPresent() {
var audience = "aud";
var accessToken = "accessToken";

when(oauth2Client.requestToken(any())).thenReturn(Result.success(TokenRepresentation.Builder.newInstance().build()));
when(vault.resolveSecret(configuration.clientSecretAlias())).thenReturn(null);

assertThat(secureTokenService.createToken(Map.of(AUDIENCE, audience, PRESENTATION_TOKEN_CLAIM, accessToken), null))
.isFailed()
.detail().isEqualTo(format("Failed to fetch client secret from the vault with alias: %s", configuration.clientSecretAlias()));

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import org.eclipse.edc.iam.oauth2.client.Oauth2ClientImpl;
import org.eclipse.edc.json.JacksonTypeManager;
import org.eclipse.edc.junit.annotations.EndToEndTest;
import org.eclipse.edc.junit.extensions.EdcRuntimeExtension;
import org.eclipse.edc.junit.extensions.EmbeddedRuntime;
import org.eclipse.edc.junit.extensions.RuntimePerClassExtension;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.result.Failure;
import org.eclipse.edc.spi.security.Vault;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
Expand All @@ -48,10 +50,10 @@ public class RemoteStsEndToEndTest extends StsEndToEndTestBase {

public static final int PORT = getFreePort();
public static final String STS_TOKEN_PATH = "http://localhost:" + PORT + "/sts/token";
private static final String SECRET = "secret";

@RegisterExtension
static EdcRuntimeExtension sts = new EdcRuntimeExtension(
":system-tests:sts-api:sts-api-test-runtime",
static RuntimePerClassExtension sts = new RuntimePerClassExtension(new EmbeddedRuntime(
"sts",
new HashMap<>() {
{
Expand All @@ -60,24 +62,26 @@ public class RemoteStsEndToEndTest extends StsEndToEndTestBase {
put("web.http.sts.path", "/sts");
put("web.http.sts.port", String.valueOf(PORT));
}
}
);
private final StsRemoteClientConfiguration config = new StsRemoteClientConfiguration(STS_TOKEN_PATH, "client_id", "client_secret");

},
":system-tests:sts-api:sts-api-test-runtime"
));
private final StsRemoteClientConfiguration config = new StsRemoteClientConfiguration(STS_TOKEN_PATH, "client_id", "client_secret_alias");
private RemoteSecureTokenService remoteSecureTokenService;

@BeforeEach
void setup() {
var oauth2Client = new Oauth2ClientImpl(testHttpClient(), new JacksonTypeManager());
remoteSecureTokenService = new RemoteSecureTokenService(oauth2Client, config);
var vault = sts.getService(Vault.class);
vault.storeSecret(config.clientSecretAlias(), SECRET);
remoteSecureTokenService = new RemoteSecureTokenService(oauth2Client, config, vault);
}

@Test
void requestToken() {
var audience = "audience";
var params = Map.of(AUDIENCE, audience);

var client = initClient(config.clientId(), config.clientSecret());
var client = initClient(config.clientId(), SECRET);

assertThat(remoteSecureTokenService.createToken(params, null))
.isSucceeded()
Expand All @@ -100,7 +104,7 @@ void requestToken_withBearerScope() {
var bearerAccessScope = "org.test.Member:read org.test.GoldMember:read";
var params = Map.of(AUDIENCE, audience);

var client = initClient(config.clientId(), config.clientSecret());
var client = initClient(config.clientId(), SECRET);

assertThat(remoteSecureTokenService.createToken(params, bearerAccessScope))
.isSucceeded()
Expand Down Expand Up @@ -133,7 +137,7 @@ void requestToken_withAttachedAccessToken() {
AUDIENCE, audience,
PRESENTATION_TOKEN_CLAIM, accessToken);

var client = initClient(config.clientId(), config.clientSecret());
var client = initClient(config.clientId(), SECRET);

assertThat(remoteSecureTokenService.createToken(params, null))
.isSucceeded()
Expand Down Expand Up @@ -161,7 +165,7 @@ void requestToken_shouldReturnError_whenClientNotFound() {
}

@Override
protected EdcRuntimeExtension getRuntime() {
protected RuntimePerClassExtension getRuntime() {
return sts;
}

Expand Down
Loading
Loading