Skip to content

Commit

Permalink
feat: DIM access token cache (#1250)
Browse files Browse the repository at this point in the history
  • Loading branch information
wolf4ood authored Apr 29, 2024
1 parent 57affd6 commit 8c53419
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
import org.eclipse.edc.runtime.metamodel.annotation.Extension;
import org.eclipse.edc.runtime.metamodel.annotation.Inject;
import org.eclipse.edc.runtime.metamodel.annotation.Provider;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.security.Vault;
import org.eclipse.edc.spi.system.ServiceExtension;
import org.eclipse.tractusx.edc.iam.iatp.sts.dim.oauth.DimOauth2Client;
import org.eclipse.tractusx.edc.iam.iatp.sts.dim.oauth.DimOauthClientImpl;

import java.time.Clock;

/**
* Extension that provides an implementation if {@link DimOauth2Client} using {@link Oauth2Client}
* and the {@link StsRemoteClientConfiguration} configuration for fetching an OAuth token
Expand All @@ -46,14 +49,20 @@ public class DimOauthClientExtension implements ServiceExtension {
@Inject
private Vault vault;

@Inject
private Monitor monitor;

@Inject
private Clock clock;

@Override
public String name() {
return NAME;
}

@Provider
public DimOauth2Client oauth2Client() {
return new DimOauthClientImpl(oauth2Client, vault, clientConfiguration);
return new DimOauthClientImpl(oauth2Client, vault, clientConfiguration, clock, monitor);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,66 @@
import org.eclipse.edc.iam.oauth2.spi.client.Oauth2CredentialsRequest;
import org.eclipse.edc.iam.oauth2.spi.client.SharedSecretOauth2CredentialsRequest;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.security.Vault;
import org.eclipse.tractusx.edc.iam.iatp.sts.dim.StsRemoteClientConfiguration;
import org.jetbrains.annotations.NotNull;

import java.time.Clock;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Optional;

public class DimOauthClientImpl implements DimOauth2Client {

private static final String GRANT_TYPE = "client_credentials";
private final StsRemoteClientConfiguration configuration;
private final Oauth2Client oauth2Client;

private final Vault vault;
private final Clock clock;
private final Monitor monitor;

private volatile TimestampedToken authToken;

public DimOauthClientImpl(Oauth2Client oauth2Client, Vault vault, StsRemoteClientConfiguration configuration) {
public DimOauthClientImpl(Oauth2Client oauth2Client, Vault vault, StsRemoteClientConfiguration configuration, Clock clock, Monitor monitor) {
this.configuration = configuration;
this.oauth2Client = oauth2Client;
this.vault = vault;
this.clock = clock;
this.monitor = monitor;
}

@Override
public Result<TokenRepresentation> obtainRequestToken() {
if (isExpired()) {
synchronized (this) {
if (isExpired()) {
monitor.debug("DIM Token expired, need to refresh.");
// expiresIn should always be present, but if not we don't cache it
return requestToken().onSuccess(tokenRepresentation -> Optional.ofNullable(tokenRepresentation.getExpiresIn())
.ifPresent(expiresIn -> this.authToken = new TimestampedToken(tokenRepresentation, Instant.now(clock), expiresIn)));
} else {
return Result.success(authToken.value);
}
}
} else {
return Result.success(authToken.value);
}
}

private Result<TokenRepresentation> requestToken() {
return createRequest().compose(oauth2Client::requestToken);
}

private boolean isExpired() {
if (authToken == null) {
return true;
}
return authToken.isExpired(clock);
}

@NotNull
private Result<Oauth2CredentialsRequest> createRequest() {
var secret = vault.resolveSecret(configuration.clientSecretAlias());
Expand All @@ -63,4 +98,11 @@ private Result<Oauth2CredentialsRequest> createRequest() {
return Result.failure("Failed to fetch client secret from the vault with alias: %s".formatted(configuration.clientSecretAlias()));
}
}

record TimestampedToken(TokenRepresentation value, Instant lastUpdatedAt, long validitySeconds) {

public boolean isExpired(Clock clock) {
return lastUpdatedAt.plus(validitySeconds, ChronoUnit.SECONDS).isBefore(Instant.now(clock));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,37 @@
import org.eclipse.edc.iam.oauth2.spi.client.Oauth2Client;
import org.eclipse.edc.iam.oauth2.spi.client.SharedSecretOauth2CredentialsRequest;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.security.Vault;
import org.eclipse.tractusx.edc.iam.iatp.sts.dim.StsRemoteClientConfiguration;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import java.time.Clock;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class DimOauthClientImplTest {

private final Oauth2Client oauth2Client = mock(Oauth2Client.class);
private final Oauth2Client oauth2Client = mock();

private final Vault vault = mock();

private final Vault vault = mock(Vault.class);
private final Monitor monitor = mock();

@Test
void obtainRequestToken() {
void obtainRequestToken_withNoExpiration() {
var config = new StsRemoteClientConfiguration("http://localhost:8081/token", "clientId", "client_secret_alias");
var tokenRepresentation = TokenRepresentation.Builder.newInstance().token("token").build();
when(vault.resolveSecret("client_secret_alias")).thenReturn("client_secret");
when(oauth2Client.requestToken(any())).thenReturn(Result.success(tokenRepresentation));
var client = new DimOauthClientImpl(oauth2Client, vault, config);
var client = new DimOauthClientImpl(oauth2Client, vault, config, Clock.systemUTC(), monitor);

var response = client.obtainRequestToken();
assertThat(response).isNotNull().extracting(Result::getContent).isEqualTo(tokenRepresentation);
Expand All @@ -60,14 +66,75 @@ void obtainRequestToken() {
assertThat(request.getClientSecret()).isEqualTo("client_secret");
assertThat(request.getUrl()).isEqualTo(config.tokenUrl());

response = client.obtainRequestToken();
assertThat(response).isNotNull().extracting(Result::getContent).isEqualTo(tokenRepresentation);

verify(oauth2Client, times(2)).requestToken(any());

}

@Test
void obtainRequestToken_withExpiration_whenNotExpired() {
var config = new StsRemoteClientConfiguration("http://localhost:8081/token", "clientId", "client_secret_alias");
var tokenRepresentation = TokenRepresentation.Builder.newInstance().token("token").expiresIn(10L).build();
when(vault.resolveSecret("client_secret_alias")).thenReturn("client_secret");
when(oauth2Client.requestToken(any())).thenReturn(Result.success(tokenRepresentation));
var client = new DimOauthClientImpl(oauth2Client, vault, config, Clock.systemUTC(), monitor);

var response = client.obtainRequestToken();
assertThat(response).isNotNull().extracting(Result::getContent).isEqualTo(tokenRepresentation);

var captor = ArgumentCaptor.forClass(SharedSecretOauth2CredentialsRequest.class);
verify(oauth2Client).requestToken(captor.capture());

var request = captor.getValue();

assertThat(request.getClientId()).isEqualTo(config.clientId());
assertThat(request.getClientSecret()).isEqualTo("client_secret");
assertThat(request.getUrl()).isEqualTo(config.tokenUrl());

response = client.obtainRequestToken();
assertThat(response).isNotNull().extracting(Result::getContent).isEqualTo(tokenRepresentation);

verify(oauth2Client, times(1)).requestToken(any());

}

@Test
void obtainRequestToken_withExpiration_whenExpired() throws InterruptedException {
var config = new StsRemoteClientConfiguration("http://localhost:8081/token", "clientId", "client_secret_alias");
var tokenRepresentation = TokenRepresentation.Builder.newInstance().token("token").expiresIn(2L).build();
when(vault.resolveSecret("client_secret_alias")).thenReturn("client_secret");
when(oauth2Client.requestToken(any())).thenReturn(Result.success(tokenRepresentation));
var client = new DimOauthClientImpl(oauth2Client, vault, config, Clock.systemUTC(), monitor);

var response = client.obtainRequestToken();
assertThat(response).isNotNull().extracting(Result::getContent).isEqualTo(tokenRepresentation);

var captor = ArgumentCaptor.forClass(SharedSecretOauth2CredentialsRequest.class);
verify(oauth2Client).requestToken(captor.capture());

var request = captor.getValue();

assertThat(request.getClientId()).isEqualTo(config.clientId());
assertThat(request.getClientSecret()).isEqualTo("client_secret");
assertThat(request.getUrl()).isEqualTo(config.tokenUrl());

Thread.sleep(2100);

response = client.obtainRequestToken();
assertThat(response).isNotNull().extracting(Result::getContent).isEqualTo(tokenRepresentation);

verify(oauth2Client, times(2)).requestToken(any());

}

@Test
void obtainRequestToken_failed() {
var config = new StsRemoteClientConfiguration("http://localhost:8081/token", "clientId", "client_secret");

when(oauth2Client.requestToken(any())).thenReturn(Result.failure("failure"));
var client = new DimOauthClientImpl(oauth2Client, vault, config);
var client = new DimOauthClientImpl(oauth2Client, vault, config, Clock.systemUTC(), monitor);

var response = client.obtainRequestToken();
assertThat(response).isNotNull().matches(Result::failed);
Expand Down

0 comments on commit 8c53419

Please sign in to comment.