Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Ec2ImdsHttpHandler #119334

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContent;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -45,21 +48,38 @@ public class Ec2ImdsHttpHandler implements HttpHandler {

private final BiConsumer<String, String> newCredentialsConsumer;
private final Map<String, String> instanceAddresses;
private final Set<String> validCredentialsEndpoints = ConcurrentCollections.newConcurrentSet();
private final Set<String> validCredentialsEndpoints;
private final boolean dynamicProfileNames;
private final Supplier<String> availabilityZoneSupplier;
@Nullable // if instance identity document not available
private final ToXContent instanceIdentityDocument;

public Ec2ImdsHttpHandler(
Ec2ImdsVersion ec2ImdsVersion,
BiConsumer<String, String> newCredentialsConsumer,
Collection<String> alternativeCredentialsEndpoints,
Supplier<String> availabilityZoneSupplier,
@Nullable ToXContent instanceIdentityDocument,
Map<String, String> instanceAddresses
) {
this.ec2ImdsVersion = Objects.requireNonNull(ec2ImdsVersion);
this.newCredentialsConsumer = Objects.requireNonNull(newCredentialsConsumer);
this.instanceAddresses = instanceAddresses;
this.validCredentialsEndpoints.addAll(alternativeCredentialsEndpoints);

if (alternativeCredentialsEndpoints.isEmpty()) {
dynamicProfileNames = true;
validCredentialsEndpoints = ConcurrentCollections.newConcurrentSet();
} else if (ec2ImdsVersion == Ec2ImdsVersion.V2) {
throw new IllegalArgumentException(
Strings.format("alternative credentials endpoints %s requires IMDSv1", alternativeCredentialsEndpoints)
);
} else {
dynamicProfileNames = false;
validCredentialsEndpoints = Set.copyOf(alternativeCredentialsEndpoints);
}

this.availabilityZoneSupplier = availabilityZoneSupplier;
this.instanceIdentityDocument = instanceIdentityDocument;
}

@Override
Expand All @@ -78,6 +98,8 @@ public void handle(final HttpExchange exchange) throws IOException {
validImdsTokens.add(token);
final var responseBody = token.getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().add("Content-Type", "text/plain");
exchange.getResponseHeaders()
.add("x-aws-ec2-metadata-token-ttl-seconds", Long.toString(TimeValue.timeValueDays(1).seconds()));
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), responseBody.length);
exchange.getResponseBody().write(responseBody);
}
Expand All @@ -98,7 +120,7 @@ public void handle(final HttpExchange exchange) throws IOException {
}

if ("GET".equals(requestMethod)) {
if (path.equals(IMDS_SECURITY_CREDENTIALS_PATH)) {
if (path.equals(IMDS_SECURITY_CREDENTIALS_PATH) && dynamicProfileNames) {
final var profileName = randomIdentifier();
validCredentialsEndpoints.add(IMDS_SECURITY_CREDENTIALS_PATH + profileName);
sendStringResponse(exchange, profileName);
Expand All @@ -107,6 +129,9 @@ public void handle(final HttpExchange exchange) throws IOException {
final var availabilityZone = availabilityZoneSupplier.get();
sendStringResponse(exchange, availabilityZone);
return;
} else if (instanceIdentityDocument != null && path.equals("/latest/dynamic/instance-identity/document")) {
sendStringResponse(exchange, Strings.toString(instanceIdentityDocument));
return;
} else if (validCredentialsEndpoints.contains(path)) {
final String accessKey = randomIdentifier();
final String sessionToken = randomIdentifier();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package fixture.aws.imds;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;

import java.util.Collection;
import java.util.HashMap;
Expand All @@ -24,6 +25,7 @@ public class Ec2ImdsServiceBuilder {
private BiConsumer<String, String> newCredentialsConsumer = Ec2ImdsServiceBuilder::rejectNewCredentials;
private Collection<String> alternativeCredentialsEndpoints = Set.of();
private Supplier<String> availabilityZoneSupplier = Ec2ImdsServiceBuilder::rejectAvailabilityZone;
private ToXContent instanceIdentityDocument = null;
private final Map<String, String> instanceAddresses = new HashMap<>();

public Ec2ImdsServiceBuilder(Ec2ImdsVersion ec2ImdsVersion) {
Expand Down Expand Up @@ -64,8 +66,13 @@ public Ec2ImdsHttpHandler buildHandler() {
newCredentialsConsumer,
alternativeCredentialsEndpoints,
availabilityZoneSupplier,
instanceIdentityDocument,
Map.copyOf(instanceAddresses)
);
}

public Ec2ImdsServiceBuilder instanceIdentityDocument(ToXContent instanceIdentityDocument) {
this.instanceIdentityDocument = instanceIdentityDocument;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,13 @@ public void testImdsV1() throws IOException {
assertTrue(Strings.hasText(profileName));

final var credentialsResponse = handleRequest(handler, "GET", SECURITY_CREDENTIALS_URI + profileName);
assertEquals(RestStatus.OK, credentialsResponse.status());

assertThat(generatedCredentials, aMapWithSize(1));
final var accessKey = generatedCredentials.keySet().iterator().next();
final var sessionToken = generatedCredentials.values().iterator().next();

final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
assertEquals(accessKey, responseMap.get("AccessKeyId"));
assertEquals(sessionToken, responseMap.get("Token"));
assertValidCredentialsResponse(
credentialsResponse,
generatedCredentials.keySet().iterator().next(),
generatedCredentials.values().iterator().next()
);
}

public void testImdsV2Disabled() {
Expand All @@ -78,6 +75,7 @@ public void testImdsV2() throws IOException {

final var tokenResponse = handleRequest(handler, "PUT", "/latest/api/token");
assertEquals(RestStatus.OK, tokenResponse.status());
assertEquals(List.of("86400" /* seconds in a day */), tokenResponse.responseHeaders().get("x-aws-ec2-metadata-token-ttl-seconds"));
final var token = tokenResponse.body().utf8ToString();

final var roleResponse = checkImdsV2GetRequest(handler, SECURITY_CREDENTIALS_URI, token);
Expand All @@ -86,16 +84,13 @@ public void testImdsV2() throws IOException {
assertTrue(Strings.hasText(profileName));

final var credentialsResponse = checkImdsV2GetRequest(handler, SECURITY_CREDENTIALS_URI + profileName, token);
assertEquals(RestStatus.OK, credentialsResponse.status());

assertThat(generatedCredentials, aMapWithSize(1));
final var accessKey = generatedCredentials.keySet().iterator().next();
final var sessionToken = generatedCredentials.values().iterator().next();

final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
assertEquals(accessKey, responseMap.get("AccessKeyId"));
assertEquals(sessionToken, responseMap.get("Token"));
assertValidCredentialsResponse(
credentialsResponse,
generatedCredentials.keySet().iterator().next(),
generatedCredentials.values().iterator().next()
);
}

public void testAvailabilityZone() {
Expand All @@ -113,7 +108,54 @@ public void testAvailabilityZone() {
assertEquals(generatedAvailabilityZones, Set.of(availabilityZone));
}

private record TestHttpResponse(RestStatus status, BytesReference body) {}
public void testAlternativeCredentialsEndpoint() throws IOException {
expectThrows(
IllegalArgumentException.class,
new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V2).alternativeCredentialsEndpoints(Set.of("/should-not-work"))::buildHandler
);

final var alternativePaths = randomList(1, 5, () -> "/" + randomIdentifier());
final Map<String, String> generatedCredentials = new HashMap<>();

final var handler = new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V1).alternativeCredentialsEndpoints(alternativePaths)
.newCredentialsConsumer(generatedCredentials::put)
.buildHandler();

final var credentialsResponse = handleRequest(handler, "GET", randomFrom(alternativePaths));

assertThat(generatedCredentials, aMapWithSize(1));
assertValidCredentialsResponse(
credentialsResponse,
generatedCredentials.keySet().iterator().next(),
generatedCredentials.values().iterator().next()
);
}

private static void assertValidCredentialsResponse(TestHttpResponse credentialsResponse, String accessKey, String sessionToken)
throws IOException {
assertEquals(RestStatus.OK, credentialsResponse.status());
final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
assertEquals(accessKey, responseMap.get("AccessKeyId"));
assertEquals(sessionToken, responseMap.get("Token"));
}

public void testInstanceIdentityDocument() {
final Set<String> generatedRegions = new HashSet<>();
final var handler = new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V1).instanceIdentityDocument((builder, params) -> {
final var newRegion = randomIdentifier();
generatedRegions.add(newRegion);
return builder.field("region", newRegion);
}).buildHandler();

final var instanceIdentityResponse = handleRequest(handler, "GET", "/latest/dynamic/instance-identity/document");
assertEquals(RestStatus.OK, instanceIdentityResponse.status());
final var instanceIdentityString = instanceIdentityResponse.body().utf8ToString();

assertEquals(Strings.format("{\"region\":\"%s\"}", generatedRegions.iterator().next()), instanceIdentityString);
}

private record TestHttpResponse(RestStatus status, Headers responseHeaders, BytesReference body) {}

private static TestHttpResponse checkImdsV2GetRequest(Ec2ImdsHttpHandler handler, String uri, String token) {
final var unauthorizedResponse = handleRequest(handler, "GET", uri, null);
Expand Down Expand Up @@ -145,7 +187,11 @@ private static TestHttpResponse handleRequest(Ec2ImdsHttpHandler handler, String
fail(e);
}
assertNotEquals(0, httpExchange.getResponseCode());
return new TestHttpResponse(RestStatus.fromCode(httpExchange.getResponseCode()), httpExchange.getResponseBodyContents());
return new TestHttpResponse(
RestStatus.fromCode(httpExchange.getResponseCode()),
httpExchange.getResponseHeaders(),
httpExchange.getResponseBodyContents()
);
}

private static class TestHttpExchange extends HttpExchange {
Expand Down