diff --git a/pom.xml b/pom.xml index f001b13..cc4a5d7 100644 --- a/pom.xml +++ b/pom.xml @@ -379,28 +379,23 @@ software.amazon.awssdk - http-auth-aws - 2.25.28 - - - software.amazon.awssdk - auth - 2.25.28 + dsql + 2.29.27 software.amazon.awssdk regions - 2.25.28 + 2.25.60 software.amazon.awssdk sdk-core - 2.25.28 + 2.25.60 software.amazon.awssdk - http-client-spi - 2.25.28 + auth + 2.25.60 diff --git a/src/main/java/com/oltpbenchmark/util/IAMUtil.java b/src/main/java/com/oltpbenchmark/util/IAMUtil.java index 24df45a..a47a8d1 100644 --- a/src/main/java/com/oltpbenchmark/util/IAMUtil.java +++ b/src/main/java/com/oltpbenchmark/util/IAMUtil.java @@ -1,19 +1,12 @@ package com.oltpbenchmark.util; -import java.net.URI; -import java.net.URISyntaxException; -import java.time.Clock; import java.time.Duration; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.auth.signer.Aws4Signer; -import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.http.SdkHttpFullRequest; -import software.amazon.awssdk.http.SdkHttpMethod; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.providers.AwsRegionProviderChain; import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; +import software.amazon.awssdk.services.dsql.DsqlUtilities; public class IAMUtil { // Default token validity is one hour @@ -21,12 +14,6 @@ public class IAMUtil { private static final String ADMIN_USERNAME = "admin"; - private static final String SIGNING_NAME = "dsql"; - - private static final String DB_CONNECT_ADMIN = "DbConnectAdmin"; - - private static final String DB_CONNECT = "DbConnect"; - public static String generateAuroraDsqlPasswordToken(String url, String username) { return generateAuroraDsqlPasswordToken( url, @@ -40,33 +27,29 @@ public static String generateAuroraDsqlPasswordToken( String username, AwsCredentialsProvider credentialsProvider, AwsRegionProviderChain regionProvider) { + DsqlUtilities utilities = + DsqlUtilities.builder() + .region(regionProvider.getRegion()) + .credentialsProvider(credentialsProvider) + .build(); + try { IAMUtil.validateUrl(url); String host = url.split("//")[1].split(":")[0]; - - Clock now = Clock.systemUTC(); - Region region = regionProvider.getRegion(); - if (region == null) region = Region.US_EAST_1; - - Aws4Signer signer = Aws4Signer.create(); - Aws4PresignerParams presignerParams = - Aws4PresignerParams.builder() - .signingName(SIGNING_NAME) - .signingRegion(region) - .awsCredentials(credentialsProvider.resolveCredentials()) - .signingClockOverride(now) - .expirationTime(now.instant().plus(DEFAULT_VALIDITY)) - .build(); - SdkHttpFullRequest request = - SdkHttpFullRequest.builder() - .uri(new URI("https", host, "/", null)) - .appendRawQueryParameter( - "Action", (username.equals(ADMIN_USERNAME)) ? DB_CONNECT_ADMIN : DB_CONNECT) - .method(SdkHttpMethod.GET) - .build(); - - return signer.presign(request, presignerParams).getUri().toString().replace("https://", ""); - } catch (URISyntaxException | SdkClientException e) { + return username.equals(ADMIN_USERNAME) + ? utilities.generateDbConnectAdminAuthToken( + builder -> + builder + .hostname(host) + .region(regionProvider.getRegion()) + .expiresIn(DEFAULT_VALIDITY)) + : utilities.generateDbConnectAuthToken( + builder -> + builder + .hostname(host) + .region(regionProvider.getRegion()) + .expiresIn(DEFAULT_VALIDITY)); + } catch (SdkClientException e) { throw new RuntimeException(e); } } diff --git a/src/test/java/com/oltpbenchmark/util/TestIAMUtil.java b/src/test/java/com/oltpbenchmark/util/TestIAMUtil.java index e6b647f..b9f58a1 100644 --- a/src/test/java/com/oltpbenchmark/util/TestIAMUtil.java +++ b/src/test/java/com/oltpbenchmark/util/TestIAMUtil.java @@ -1,8 +1,6 @@ package com.oltpbenchmark.util; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; import org.junit.Before; import org.junit.Test; @@ -42,29 +40,6 @@ public String secretAccessKey() { Mockito.when(regionProvider.getRegion()).thenReturn(Region.US_EAST_2); } - @Test - public void testGenerateAuroraDsqlPasswordToken() { - String token = - IAMUtil.generateAuroraDsqlPasswordToken( - VALID_URL, VALID_ADMIN_USERNAME, credentialsProvider, regionProvider); - assertNotNull(token); - assertTrue(token.contains("localhost/?")); - assertTrue(token.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256")); - assertTrue(token.contains("X-Amz-Expires=3600")); - assertTrue(token.contains("Action=DbConnectAdmin")); - assertTrue(token.contains("X-Amz-Credential=ACCESS_KEY")); - assertTrue(token.contains("X-Amz-Signature")); - } - - @Test - public void testGenerateAuroraDsqlPasswordTokenNonAdminUser() { - String token = - IAMUtil.generateAuroraDsqlPasswordToken( - VALID_URL, "other", credentialsProvider, regionProvider); - assertNotNull(token); - assertTrue(token.contains("Action=DbConnect")); - } - @Test public void testGenerateAuroraDsqlPasswordTokenInvalidUrl() { assertThrows(