Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Event Hubs] Support SAS token in connnection string #14912

Merged
merged 7 commits into from
Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,31 @@ public class ConnectionStringProperties {
private static final String ENDPOINT = "Endpoint";
private static final String SHARED_ACCESS_KEY_NAME = "SharedAccessKeyName";
private static final String SHARED_ACCESS_KEY = "SharedAccessKey";
private static final String SHARED_ACCESS_SIGNATURE = "SharedAccessSignature";
private static final String SAS_VALUE_PREFIX = "sharedaccesssignature ";
private static final String ENTITY_PATH = "EntityPath";
private static final String CONNECTION_STRING_WITH_ACCESS_KEY = "Endpoint={endpoint};"
+ "SharedAccessKeyName={sharedAccessKeyName};SharedAccessKey={sharedAccessKey};EntityPath={entityPath}";
private static final String CONNECTION_STRING_WITH_SAS = "Endpoint={endpoint};SharedAccessSignature="
+ "SharedAccessSignature {sharedAccessSignature};EntityPath={entityPath}";
private static final String ERROR_MESSAGE_FORMAT = "Could not parse 'connectionString'. Expected format: "
+ "'Endpoint={endpoint};SharedAccessKeyName={sharedAccessKeyName};"
+ "SharedAccessKey={sharedAccessKey};EntityPath={entityPath}'. Actual: %s";
+ CONNECTION_STRING_WITH_ACCESS_KEY + " or " + CONNECTION_STRING_WITH_SAS + ". Actual: %s";
private static final String ERROR_MESSAGE_ENDPOINT_FORMAT = "'Endpoint' must be provided in 'connectionString'."
+ " Actual: %s";

private final URI endpoint;
private final String entityPath;
private final String sharedAccessKeyName;
private final String sharedAccessKey;
private final String sharedAccessSignature;

/**
* Creates a new instance by parsing the {@code connectionString} into its components.
*
* @param connectionString The connection string to the Event Hub instance.
*
* @throws NullPointerException if {@code connectionString} is null.
* @throws IllegalArgumentException if {@code connectionString} is an empty string or the connection string has
* an invalid format.
* an invalid format.
*/
public ConnectionStringProperties(String connectionString) {
Objects.requireNonNull(connectionString, "'connectionString' cannot be null.");
Expand All @@ -56,6 +62,7 @@ public ConnectionStringProperties(String connectionString) {
String entityPath = null;
String sharedAccessKeyName = null;
String sharedAccessKeyValue = null;
String sharedAccessSignature = null;

for (String tokenValuePair : tokenValuePairs) {
final String[] pair = tokenValuePair.split(TOKEN_VALUE_SEPARATOR, 2);
Expand Down Expand Up @@ -83,25 +90,34 @@ public ConnectionStringProperties(String connectionString) {
sharedAccessKeyValue = value;
} else if (key.equalsIgnoreCase(ENTITY_PATH)) {
entityPath = value;
} else if (key.equalsIgnoreCase(SHARED_ACCESS_SIGNATURE)
&& value.toLowerCase(Locale.ROOT).startsWith(SAS_VALUE_PREFIX)) {
sharedAccessSignature = value;
} else {
throw new IllegalArgumentException(
String.format(Locale.US, "Illegal connection string parameter name: %s", key));
}
}

if (endpoint == null || sharedAccessKeyName == null || sharedAccessKeyValue == null) {
// connection string should have an endpoint and either shared access signature or shared access key and value
boolean includesSharedKey = sharedAccessKeyName != null || sharedAccessKeyValue != null;
boolean hasSharedKeyAndValue = sharedAccessKeyName != null && sharedAccessKeyValue != null;
boolean includesSharedAccessSignature = sharedAccessSignature != null;
if (endpoint == null
|| (includesSharedKey && includesSharedAccessSignature) // includes both SAS and key or value
|| (!hasSharedKeyAndValue && !includesSharedAccessSignature)) { // invalid key, value and SAS
throw new IllegalArgumentException(String.format(Locale.US, ERROR_MESSAGE_FORMAT, connectionString));
}

this.endpoint = endpoint;
this.entityPath = entityPath;
this.sharedAccessKeyName = sharedAccessKeyName;
this.sharedAccessKey = sharedAccessKeyValue;
this.sharedAccessSignature = sharedAccessSignature;
}

/**
* Gets the endpoint to be used for connecting to the AMQP message broker.
*
* @return The endpoint address, including protocol, from the connection string.
*/
public URI getEndpoint() {
Expand All @@ -110,7 +126,6 @@ public URI getEndpoint() {

/**
* Gets the entity path to connect to in the message broker.
*
* @return The entity path to connect to in the message broker.
*/
public String getEntityPath() {
Expand All @@ -119,7 +134,6 @@ public String getEntityPath() {

/**
* Gets the name of the shared access key, either for the Event Hubs namespace or the Event Hub instance.
*
* @return The name of the shared access key.
*/
public String getSharedAccessKeyName() {
Expand All @@ -128,13 +142,21 @@ public String getSharedAccessKeyName() {

/**
* The value of the shared access key, either for the Event Hubs namespace or the Event Hub.
*
* @return The value of the shared access key.
*/
public String getSharedAccessKey() {
return sharedAccessKey;
}

/**
* The value of the shared access signature, if the connection string used to create this instance included the
* shared access signature component.
* @return The shared access signature value, if included in the connection string.
*/
public String getSharedAccessSignature() {
return sharedAccessSignature;
}

/*
* The function checks for pre existing scheme of "sb://" , "http://" or "https://". If the scheme is not provided
* in endpoint, it will set the default scheme to "sb://".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.Locale;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertThrows;

Expand All @@ -16,6 +19,11 @@ public class ConnectionStringPropertiesTest {
private static final String EVENT_HUB = "event-hub-instance";
private static final String SAS_KEY = "test-sas-key";
private static final String SAS_VALUE = "some-secret-value";
private static final String SHARED_ACCESS_SIGNATURE = "SharedAccessSignature "
+ "sr=https%3A%2F%2Fentity-name.servicebus.windows.net%2F"
+ "&sig=encodedsignature%3D"
+ "&se=100000"
+ "&skn=test-sas-key";

@Test
public void nullConnectionString() {
Expand Down Expand Up @@ -130,7 +138,54 @@ public void parseConnectionString() {
Assertions.assertEquals(EVENT_HUB, properties.getEntityPath());
}

private static String getConnectionString(String hostname, String eventHubName, String sasKeyName, String sasKeyValue) {
@ParameterizedTest
@MethodSource("getInvalidConnectionString")
public void testConnectionStringWithSas(String invalidConnectionString) {
assertThrows(IllegalArgumentException.class, () -> new ConnectionStringProperties(invalidConnectionString));
}

@ParameterizedTest
@MethodSource("getSharedAccessSignature")
public void testInvalidSharedAccessSignature(String sas) {
assertThrows(IllegalArgumentException.class, () ->
new ConnectionStringProperties(getConnectionString(HOSTNAME_URI, null, null, null, sas)));
}

private static Stream<String> getInvalidConnectionString() {
String keyNameWithSas = getConnectionString(HOSTNAME_URI, EVENT_HUB, SAS_KEY, null, SHARED_ACCESS_SIGNATURE);
String keyValueWithSas = getConnectionString(HOSTNAME_URI, EVENT_HUB, null, SAS_VALUE, SHARED_ACCESS_SIGNATURE);
String keyNameAndValueWithSas = getConnectionString(HOSTNAME_URI, EVENT_HUB, SAS_KEY, SAS_VALUE,
SHARED_ACCESS_SIGNATURE);
String nullHostName = getConnectionString(null, EVENT_HUB, SAS_KEY, SAS_VALUE, SHARED_ACCESS_SIGNATURE);
String nullHostNameValidSas = getConnectionString(null, EVENT_HUB, null, null, SHARED_ACCESS_SIGNATURE);
String nullHostNameValidKey = getConnectionString(null, EVENT_HUB, SAS_KEY, SAS_VALUE, null);
return Stream.of(keyNameWithSas, keyValueWithSas, keyNameAndValueWithSas, nullHostName, nullHostNameValidSas,
nullHostNameValidKey);
}

private static Stream<String> getSharedAccessSignature() {
String nullSas = null;
String sasInvalidPrefix = "AccessSignature " // invalid prefix
+ "sr=https%3A%2F%2Fentity-name.servicebus.windows.net%2F"
+ "&sig=encodedsignature%3D"
+ "&se=100000"
+ "&skn=test-sas-key";
String sasWithoutSpace = "SharedAccessSignature" // no space after prefix
+ "sr=https%3A%2F%2Fentity-name.servicebus.windows.net%2F"
+ "&sig=encodedsignature%3D"
+ "&se=100000"
+ "&skn=test-sas-key";

return Stream.of(nullSas, sasInvalidPrefix, sasWithoutSpace);
}

private static String getConnectionString(String hostname, String eventHubName, String sasKeyName,
String sasKeyValue) {
return getConnectionString(hostname, eventHubName, sasKeyName, sasKeyValue, null);
}

private static String getConnectionString(String hostname, String eventHubName, String sasKeyName,
String sasKeyValue, String sharedAccessSignature) {
final StringBuilder builder = new StringBuilder();
if (hostname != null) {
builder.append(String.format(Locale.US, "Endpoint=%s;", hostname));
Expand All @@ -144,6 +199,9 @@ private static String getConnectionString(String hostname, String eventHubName,
if (sasKeyValue != null) {
builder.append(String.format(Locale.US, "SharedAccessKey=%s;", sasKeyValue));
}
if (sharedAccessSignature != null) {
builder.append(String.format(Locale.US, "SharedAccessSignature=%s;", sharedAccessSignature));
}

return builder.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,22 @@ public EventHubClientBuilder() {
* connection string.
*/
public EventHubClientBuilder connectionString(String connectionString) {
final ConnectionStringProperties properties = new ConnectionStringProperties(connectionString);
final TokenCredential tokenCredential = new EventHubSharedKeyCredential(properties.getSharedAccessKeyName(),
properties.getSharedAccessKey(), ClientConstants.TOKEN_VALIDITY);

ConnectionStringProperties properties = new ConnectionStringProperties(connectionString);
TokenCredential tokenCredential = getTokenCredential(properties);
return credential(properties.getEndpoint().getHost(), properties.getEntityPath(), tokenCredential);
}

private TokenCredential getTokenCredential(ConnectionStringProperties properties) {
TokenCredential tokenCredential;
if (properties.getSharedAccessSignature() == null) {
tokenCredential = new EventHubSharedKeyCredential(properties.getSharedAccessKeyName(),
properties.getSharedAccessKey(), ClientConstants.TOKEN_VALIDITY);
} else {
tokenCredential = new EventHubSharedKeyCredential(properties.getSharedAccessSignature());
}
return tokenCredential;
}

/**
* Sets the credential information given a connection string to the Event Hubs namespace and name to a specific
* Event Hub instance.
Expand Down Expand Up @@ -213,8 +222,7 @@ public EventHubClientBuilder connectionString(String connectionString, String ev
}

final ConnectionStringProperties properties = new ConnectionStringProperties(connectionString);
final TokenCredential tokenCredential = new EventHubSharedKeyCredential(properties.getSharedAccessKeyName(),
properties.getSharedAccessKey(), ClientConstants.TOKEN_VALIDITY);
TokenCredential tokenCredential = getTokenCredential(properties);

if (!CoreUtils.isNullOrEmpty(properties.getEntityPath())
&& !eventHubName.equals(properties.getEntityPath())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.Base64;
import java.util.Locale;
import java.util.Objects;
Expand Down Expand Up @@ -51,6 +53,7 @@ public class EventHubSharedKeyCredential implements TokenCredential {
private final String policyName;
private final Duration tokenValidity;
private final SecretKeySpec secretKeySpec;
private final String sharedAccessSignature;

/**
* Creates an instance that authorizes using the {@code policyName} and {@code sharedAccessKey}.
Expand Down Expand Up @@ -98,6 +101,26 @@ public EventHubSharedKeyCredential(String policyName, String sharedAccessKey, Du

final byte[] sasKeyBytes = sharedAccessKey.getBytes(UTF_8);
secretKeySpec = new SecretKeySpec(sasKeyBytes, HASH_ALGORITHM);
sharedAccessSignature = null;
}

/**
* Creates an instance using the provided Shared Access Signature (SAS) string. The credential created using this
* constructor will not be refreshed. The expiration time is set to the time defined in "se={
* tokenValidationSeconds}`. If the SAS string does not contain this or is in invalid format, then the token
* expiration will be set to {@link OffsetDateTime#MAX max duration}.
* <p><a href="https://docs.microsoft.com/rest/api/eventhub/generate-sas-token">See how to generate SAS
* programmatically.</a></p>
*
* @param sharedAccessSignature The base64 encoded shared access signature string.
* @throws NullPointerException if {@code sharedAccessSignature} is null.
*/
public EventHubSharedKeyCredential(String sharedAccessSignature) {
this.sharedAccessSignature = Objects.requireNonNull(sharedAccessSignature,
"'sharedAccessSignature' cannot be null");
this.policyName = null;
this.secretKeySpec = null;
this.tokenValidity = null;
}

/**
Expand All @@ -124,6 +147,10 @@ private AccessToken generateSharedAccessSignature(final String resource) throws
throw logger.logExceptionAsError(new IllegalArgumentException("resource cannot be empty"));
}

if (sharedAccessSignature != null) {
return new AccessToken(sharedAccessSignature, getExpirationTime(sharedAccessSignature));
}

final Mac hmac;
try {
hmac = Mac.getInstance(HASH_ALGORITHM);
Expand Down Expand Up @@ -153,4 +180,24 @@ private AccessToken generateSharedAccessSignature(final String resource) throws

return new AccessToken(token, expiresOn);
}

private OffsetDateTime getExpirationTime(String sharedAccessSignature) {
String[] parts = sharedAccessSignature.split("&");
return Arrays.stream(parts)
.map(part -> part.split("="))
.filter(pair -> pair.length == 2 && pair[0].equalsIgnoreCase("se"))
.findFirst()
.map(pair -> pair[1])
.map(expirationTimeStr -> {
try {
long epochSeconds = Long.parseLong(expirationTimeStr);
return Instant.ofEpochSecond(epochSeconds).atOffset(ZoneOffset.UTC);
} catch (NumberFormatException exception) {
logger.verbose("Invalid expiration time format in the SAS token: {}. Falling back to max "
+ "expiration time.", expirationTimeStr);
return OffsetDateTime.MAX;
}
})
.orElse(OffsetDateTime.MAX);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,23 @@ public void throwsWithProxyWhenTransportTypeNotChanged() {
assertNotNull(builder.buildAsyncClient());
});
}
@Test
public void testConnectionStringWithSas() {

String connectionStringWithNoEntityPath = "Endpoint=sb://eh-name.servicebus.windows.net/;"
+ "SharedAccessSignature=SharedAccessSignature test-value";
String connectionStringWithEntityPath = "Endpoint=sb://eh-name.servicebus.windows.net/;"
+ "SharedAccessSignature=SharedAccessSignature test-value;EntityPath=eh-name";

assertNotNull(new EventHubClientBuilder()
.connectionString(connectionStringWithNoEntityPath, "eh-name"));
assertNotNull(new EventHubClientBuilder()
.connectionString(connectionStringWithEntityPath));
assertThrows(NullPointerException.class, () -> new EventHubClientBuilder()
.connectionString(connectionStringWithNoEntityPath));
assertThrows(IllegalArgumentException.class, () -> new EventHubClientBuilder()
.connectionString(connectionStringWithEntityPath, "eh-name-mismatch"));
}

@MethodSource("getProxyConfigurations")
@ParameterizedTest
Expand Down
Loading