Skip to content

Commit

Permalink
Enhance Ec2ImdsHttpHandler (elastic#119334)
Browse files Browse the repository at this point in the history
- Require IMDSv1 if using alternative endpoints (i.e. ECS)
- Forbid profile name lookup with alternative endpoints
- Add token TTL header for IMDSv2
- Add support for instance-identity docs
  • Loading branch information
DaveCTurner authored Dec 31, 2024
1 parent 4012208 commit 34ec706
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContent;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -45,21 +48,38 @@ public class Ec2ImdsHttpHandler implements HttpHandler {

private final BiConsumer<String, String> newCredentialsConsumer;
private final Map<String, String> instanceAddresses;
private final Set<String> validCredentialsEndpoints = ConcurrentCollections.newConcurrentSet();
private final Set<String> validCredentialsEndpoints;
private final boolean dynamicProfileNames;
private final Supplier<String> availabilityZoneSupplier;
@Nullable // if instance identity document not available
private final ToXContent instanceIdentityDocument;

public Ec2ImdsHttpHandler(
Ec2ImdsVersion ec2ImdsVersion,
BiConsumer<String, String> newCredentialsConsumer,
Collection<String> alternativeCredentialsEndpoints,
Supplier<String> availabilityZoneSupplier,
@Nullable ToXContent instanceIdentityDocument,
Map<String, String> instanceAddresses
) {
this.ec2ImdsVersion = Objects.requireNonNull(ec2ImdsVersion);
this.newCredentialsConsumer = Objects.requireNonNull(newCredentialsConsumer);
this.instanceAddresses = instanceAddresses;
this.validCredentialsEndpoints.addAll(alternativeCredentialsEndpoints);

if (alternativeCredentialsEndpoints.isEmpty()) {
dynamicProfileNames = true;
validCredentialsEndpoints = ConcurrentCollections.newConcurrentSet();
} else if (ec2ImdsVersion == Ec2ImdsVersion.V2) {
throw new IllegalArgumentException(
Strings.format("alternative credentials endpoints %s requires IMDSv1", alternativeCredentialsEndpoints)
);
} else {
dynamicProfileNames = false;
validCredentialsEndpoints = Set.copyOf(alternativeCredentialsEndpoints);
}

this.availabilityZoneSupplier = availabilityZoneSupplier;
this.instanceIdentityDocument = instanceIdentityDocument;
}

@Override
Expand All @@ -78,6 +98,8 @@ public void handle(final HttpExchange exchange) throws IOException {
validImdsTokens.add(token);
final var responseBody = token.getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().add("Content-Type", "text/plain");
exchange.getResponseHeaders()
.add("x-aws-ec2-metadata-token-ttl-seconds", Long.toString(TimeValue.timeValueDays(1).seconds()));
exchange.sendResponseHeaders(RestStatus.OK.getStatus(), responseBody.length);
exchange.getResponseBody().write(responseBody);
}
Expand All @@ -98,7 +120,7 @@ public void handle(final HttpExchange exchange) throws IOException {
}

if ("GET".equals(requestMethod)) {
if (path.equals(IMDS_SECURITY_CREDENTIALS_PATH)) {
if (path.equals(IMDS_SECURITY_CREDENTIALS_PATH) && dynamicProfileNames) {
final var profileName = randomIdentifier();
validCredentialsEndpoints.add(IMDS_SECURITY_CREDENTIALS_PATH + profileName);
sendStringResponse(exchange, profileName);
Expand All @@ -107,6 +129,9 @@ public void handle(final HttpExchange exchange) throws IOException {
final var availabilityZone = availabilityZoneSupplier.get();
sendStringResponse(exchange, availabilityZone);
return;
} else if (instanceIdentityDocument != null && path.equals("/latest/dynamic/instance-identity/document")) {
sendStringResponse(exchange, Strings.toString(instanceIdentityDocument));
return;
} else if (validCredentialsEndpoints.contains(path)) {
final String accessKey = randomIdentifier();
final String sessionToken = randomIdentifier();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package fixture.aws.imds;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;

import java.util.Collection;
import java.util.HashMap;
Expand All @@ -24,6 +25,7 @@ public class Ec2ImdsServiceBuilder {
private BiConsumer<String, String> newCredentialsConsumer = Ec2ImdsServiceBuilder::rejectNewCredentials;
private Collection<String> alternativeCredentialsEndpoints = Set.of();
private Supplier<String> availabilityZoneSupplier = Ec2ImdsServiceBuilder::rejectAvailabilityZone;
private ToXContent instanceIdentityDocument = null;
private final Map<String, String> instanceAddresses = new HashMap<>();

public Ec2ImdsServiceBuilder(Ec2ImdsVersion ec2ImdsVersion) {
Expand Down Expand Up @@ -64,8 +66,13 @@ public Ec2ImdsHttpHandler buildHandler() {
newCredentialsConsumer,
alternativeCredentialsEndpoints,
availabilityZoneSupplier,
instanceIdentityDocument,
Map.copyOf(instanceAddresses)
);
}

public Ec2ImdsServiceBuilder instanceIdentityDocument(ToXContent instanceIdentityDocument) {
this.instanceIdentityDocument = instanceIdentityDocument;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,13 @@ public void testImdsV1() throws IOException {
assertTrue(Strings.hasText(profileName));

final var credentialsResponse = handleRequest(handler, "GET", SECURITY_CREDENTIALS_URI + profileName);
assertEquals(RestStatus.OK, credentialsResponse.status());

assertThat(generatedCredentials, aMapWithSize(1));
final var accessKey = generatedCredentials.keySet().iterator().next();
final var sessionToken = generatedCredentials.values().iterator().next();

final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
assertEquals(accessKey, responseMap.get("AccessKeyId"));
assertEquals(sessionToken, responseMap.get("Token"));
assertValidCredentialsResponse(
credentialsResponse,
generatedCredentials.keySet().iterator().next(),
generatedCredentials.values().iterator().next()
);
}

public void testImdsV2Disabled() {
Expand All @@ -78,6 +75,7 @@ public void testImdsV2() throws IOException {

final var tokenResponse = handleRequest(handler, "PUT", "/latest/api/token");
assertEquals(RestStatus.OK, tokenResponse.status());
assertEquals(List.of("86400" /* seconds in a day */), tokenResponse.responseHeaders().get("x-aws-ec2-metadata-token-ttl-seconds"));
final var token = tokenResponse.body().utf8ToString();

final var roleResponse = checkImdsV2GetRequest(handler, SECURITY_CREDENTIALS_URI, token);
Expand All @@ -86,16 +84,13 @@ public void testImdsV2() throws IOException {
assertTrue(Strings.hasText(profileName));

final var credentialsResponse = checkImdsV2GetRequest(handler, SECURITY_CREDENTIALS_URI + profileName, token);
assertEquals(RestStatus.OK, credentialsResponse.status());

assertThat(generatedCredentials, aMapWithSize(1));
final var accessKey = generatedCredentials.keySet().iterator().next();
final var sessionToken = generatedCredentials.values().iterator().next();

final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
assertEquals(accessKey, responseMap.get("AccessKeyId"));
assertEquals(sessionToken, responseMap.get("Token"));
assertValidCredentialsResponse(
credentialsResponse,
generatedCredentials.keySet().iterator().next(),
generatedCredentials.values().iterator().next()
);
}

public void testAvailabilityZone() {
Expand All @@ -113,7 +108,54 @@ public void testAvailabilityZone() {
assertEquals(generatedAvailabilityZones, Set.of(availabilityZone));
}

private record TestHttpResponse(RestStatus status, BytesReference body) {}
public void testAlternativeCredentialsEndpoint() throws IOException {
expectThrows(
IllegalArgumentException.class,
new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V2).alternativeCredentialsEndpoints(Set.of("/should-not-work"))::buildHandler
);

final var alternativePaths = randomList(1, 5, () -> "/" + randomIdentifier());
final Map<String, String> generatedCredentials = new HashMap<>();

final var handler = new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V1).alternativeCredentialsEndpoints(alternativePaths)
.newCredentialsConsumer(generatedCredentials::put)
.buildHandler();

final var credentialsResponse = handleRequest(handler, "GET", randomFrom(alternativePaths));

assertThat(generatedCredentials, aMapWithSize(1));
assertValidCredentialsResponse(
credentialsResponse,
generatedCredentials.keySet().iterator().next(),
generatedCredentials.values().iterator().next()
);
}

private static void assertValidCredentialsResponse(TestHttpResponse credentialsResponse, String accessKey, String sessionToken)
throws IOException {
assertEquals(RestStatus.OK, credentialsResponse.status());
final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
assertEquals(accessKey, responseMap.get("AccessKeyId"));
assertEquals(sessionToken, responseMap.get("Token"));
}

public void testInstanceIdentityDocument() {
final Set<String> generatedRegions = new HashSet<>();
final var handler = new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V1).instanceIdentityDocument((builder, params) -> {
final var newRegion = randomIdentifier();
generatedRegions.add(newRegion);
return builder.field("region", newRegion);
}).buildHandler();

final var instanceIdentityResponse = handleRequest(handler, "GET", "/latest/dynamic/instance-identity/document");
assertEquals(RestStatus.OK, instanceIdentityResponse.status());
final var instanceIdentityString = instanceIdentityResponse.body().utf8ToString();

assertEquals(Strings.format("{\"region\":\"%s\"}", generatedRegions.iterator().next()), instanceIdentityString);
}

private record TestHttpResponse(RestStatus status, Headers responseHeaders, BytesReference body) {}

private static TestHttpResponse checkImdsV2GetRequest(Ec2ImdsHttpHandler handler, String uri, String token) {
final var unauthorizedResponse = handleRequest(handler, "GET", uri, null);
Expand Down Expand Up @@ -145,7 +187,11 @@ private static TestHttpResponse handleRequest(Ec2ImdsHttpHandler handler, String
fail(e);
}
assertNotEquals(0, httpExchange.getResponseCode());
return new TestHttpResponse(RestStatus.fromCode(httpExchange.getResponseCode()), httpExchange.getResponseBodyContents());
return new TestHttpResponse(
RestStatus.fromCode(httpExchange.getResponseCode()),
httpExchange.getResponseHeaders(),
httpExchange.getResponseBodyContents()
);
}

private static class TestHttpExchange extends HttpExchange {
Expand Down

0 comments on commit 34ec706

Please sign in to comment.