Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(credential-providers): pass caller client options to fromTemporaryCredentials inner STSClient #6838

Merged
merged 4 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .prettierignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ CHANGELOG.md
**/*.hbs
**/*/report.md
clients/*/src/endpoint/ruleset.ts
packages/nested-clients/src/submodules/*/endpoint/ruleset.ts
**/*.java
10 changes: 6 additions & 4 deletions clients/client-sts/src/defaultStsRoleAssumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ const resolveRegion = async (
*/
export const getDefaultRoleAssumer = (
stsOptions: STSRoleAssumerOptions,
stsClientCtor: new (options: STSClientConfig) => STSClient
STSClient: new (options: STSClientConfig) => STSClient
): RoleAssumer => {
let stsClient: STSClient;
let closureSourceCreds: AwsCredentialIdentity;
Expand All @@ -104,7 +104,8 @@ export const getDefaultRoleAssumer = (
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
stsClient = new STSClient({
profile: stsOptions?.parentClientConfig?.profile,
// A hack to make sts client uses the credential in current closure.
credentialDefaultProvider: () => async () => closureSourceCreds,
region: resolvedRegion,
Expand Down Expand Up @@ -146,7 +147,7 @@ export type RoleAssumerWithWebIdentity = (
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: STSRoleAssumerOptions,
stsClientCtor: new (options: STSClientConfig) => STSClient
STSClient: new (options: STSClientConfig) => STSClient
): RoleAssumerWithWebIdentity => {
let stsClient: STSClient;
return async (params) => {
Expand All @@ -164,7 +165,8 @@ export const getDefaultRoleAssumerWithWebIdentity = (
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
stsClient = new STSClient({
profile: stsOptions?.parentClientConfig?.profile,
region: resolvedRegion,
requestHandler: isCompatibleRequestHandler ? (requestHandler as any) : undefined,
logger: logger as any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ const resolveRegion = async (
*/
export const getDefaultRoleAssumer = (
stsOptions: STSRoleAssumerOptions,
stsClientCtor: new (options: STSClientConfig) => STSClient
STSClient: new (options: STSClientConfig) => STSClient
): RoleAssumer => {
let stsClient: STSClient;
let closureSourceCreds: AwsCredentialIdentity;
Expand All @@ -101,7 +101,8 @@ export const getDefaultRoleAssumer = (
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
stsClient = new STSClient({
profile: stsOptions?.parentClientConfig?.profile,
// A hack to make sts client uses the credential in current closure.
credentialDefaultProvider: () => async () => closureSourceCreds,
region: resolvedRegion,
Expand Down Expand Up @@ -143,7 +144,7 @@ export type RoleAssumerWithWebIdentity = (
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: STSRoleAssumerOptions,
stsClientCtor: new (options: STSClientConfig) => STSClient
STSClient: new (options: STSClientConfig) => STSClient
): RoleAssumerWithWebIdentity => {
let stsClient: STSClient;
return async (params) => {
Expand All @@ -161,7 +162,8 @@ export const getDefaultRoleAssumerWithWebIdentity = (
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
stsClient = new STSClient({
profile: stsOptions?.parentClientConfig?.profile,
region: resolvedRegion,
requestHandler: isCompatibleRequestHandler ? (requestHandler as any) : undefined,
logger: logger as any,
Expand Down
1 change: 1 addition & 0 deletions packages/credential-providers/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"@aws-sdk/credential-provider-web-identity": "*",
"@aws-sdk/nested-clients": "*",
"@aws-sdk/types": "*",
"@smithy/core": "^3.0.0",
"@smithy/credential-provider-imds": "^4.0.0",
"@smithy/property-provider": "^4.0.0",
"@smithy/types": "^4.0.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,39 @@ import type {
CredentialProviderOptions,
RuntimeConfigAwsCredentialIdentityProvider,
} from "@aws-sdk/types";
import { normalizeProvider } from "@smithy/core";
import { CredentialsProviderError } from "@smithy/property-provider";
import { AwsCredentialIdentity, AwsCredentialIdentityProvider, Pluggable } from "@smithy/types";
import { AwsCredentialIdentity, AwsCredentialIdentityProvider, Logger, Pluggable, RequestHandler } from "@smithy/types";

export interface FromTemporaryCredentialsOptions extends CredentialProviderOptions {
params: Omit<AssumeRoleCommandInput, "RoleSessionName"> & { RoleSessionName?: string };
masterCredentials?: AwsCredentialIdentity | AwsCredentialIdentityProvider;
clientConfig?: STSClientConfig;
logger?: Logger;
clientPlugins?: Pluggable<any, any>[];
mfaCodeProvider?: (mfaSerial: string) => Promise<string>;
}

const ASSUME_ROLE_DEFAULT_REGION = "us-east-1";

export const fromTemporaryCredentials = (
options: FromTemporaryCredentialsOptions,
credentialDefaultProvider?: () => AwsCredentialIdentityProvider
): RuntimeConfigAwsCredentialIdentityProvider => {
let stsClient: STSClient;
return async (awsIdentityProperties: AwsIdentityProperties = {}): Promise<AwsCredentialIdentity> => {
options.logger?.debug("@aws-sdk/credential-providers - fromTemporaryCredentials (STS)");
const { callerClientConfig } = awsIdentityProperties;
const logger = options.logger ?? callerClientConfig?.logger;
logger?.debug("@aws-sdk/credential-providers - fromTemporaryCredentials (STS)");

const params = { ...options.params, RoleSessionName: options.params.RoleSessionName ?? "aws-sdk-js-" + Date.now() };
if (params?.SerialNumber) {
if (!options.mfaCodeProvider) {
throw new CredentialsProviderError(
`Temporary credential requires multi-factor authentication, but no MFA code callback was provided.`,
{
tryNextLink: false,
logger: options.logger,
logger,
}
);
}
Expand All @@ -42,14 +49,68 @@ export const fromTemporaryCredentials = (
const defaultCredentialsOrError =
typeof credentialDefaultProvider === "function" ? credentialDefaultProvider() : undefined;

const { callerClientConfig } = awsIdentityProperties;
const credentialSources = [
options.masterCredentials,
options.clientConfig?.credentials,
/**
* Important (!): callerClientConfig?.credentials is not a valid
* credential source for this provider, because this function
* is the caller client's credential provider function.
*/
void callerClientConfig?.credentials,
callerClientConfig?.credentialDefaultProvider?.(),
defaultCredentialsOrError,
];
let credentialSource = "STS client default credentials";
if (credentialSources[0]) {
credentialSource = "options.masterCredentials";
} else if (credentialSources[1]) {
credentialSource = "options.clientConfig.credentials";
} else if (credentialSources[2]) {
// This branch is not possible, see above void note.
// This code is here to prevent accidental attempts to utilize
// the invalid credential source.
credentialSource = "caller client's credentials";
throw new Error("fromTemporaryCredentials recursion in callerClientConfig.credentials");
} else if (credentialSources[3]) {
credentialSource = "caller client's credentialDefaultProvider";
} else if (credentialSources[4]) {
credentialSource = "AWS SDK default credentials";
}

const regionSources = [options.clientConfig?.region, callerClientConfig?.region, ASSUME_ROLE_DEFAULT_REGION];
let regionSource = "default partition's default region";
if (regionSources[0]) {
regionSource = "options.clientConfig.region";
} else if (regionSources[1]) {
regionSource = "caller client's region";
}

const requestHandlerSources = [
filterRequestHandler(options.clientConfig?.requestHandler),
filterRequestHandler(callerClientConfig?.requestHandler),
];
let requestHandlerSource = "STS default requestHandler";
if (requestHandlerSources[0]) {
requestHandlerSource = "options.clientConfig.requestHandler";
} else if (requestHandlerSources[1]) {
requestHandlerSource = "caller client's requestHandler";
}

logger?.debug?.(
`@aws-sdk/credential-providers - fromTemporaryCredentials STS client init with ` +
`${regionSource}=${await normalizeProvider(
coalesce(regionSources)
)()}, ${credentialSource}, ${requestHandlerSource}.`
);

stsClient = new STSClient({
...options.clientConfig,
credentials:
options.masterCredentials ??
options.clientConfig?.credentials ??
callerClientConfig?.credentialDefaultProvider?.() ??
defaultCredentialsOrError,
credentials: coalesce(credentialSources),
logger,
profile: options.clientConfig?.profile ?? callerClientConfig?.profile,
region: coalesce(regionSources),
requestHandler: coalesce(requestHandlerSources),
});
}
if (options.clientPlugins) {
Expand All @@ -60,7 +121,7 @@ export const fromTemporaryCredentials = (
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new CredentialsProviderError(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`, {
logger: options.logger,
logger,
});
}
return {
Expand All @@ -73,3 +134,21 @@ export const fromTemporaryCredentials = (
};
};
};

/**
* @internal
*/
const filterRequestHandler = (requestHandler: STSClientConfig["requestHandler"]): undefined | typeof requestHandler => {
return (requestHandler as RequestHandler<any, any>)?.metadata?.handlerProtocol === "h2" ? undefined : requestHandler;
};

/**
* @internal
*/
const coalesce = (args: any) => {
for (const item of args) {
if (item !== undefined) {
return item;
}
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ describe("fromTemporaryCredentials", () => {
await provider();
expect(vi.mocked(STSClient as any)).toHaveBeenCalledWith({
credentials: masterCredentials,
logger: void 0,
profile: void 0,
region: "us-east-1",
requestHandler: void 0,
});
expect(mockUsePlugin).toHaveBeenCalledTimes(1);
expect(mockUsePlugin).toHaveBeenNthCalledWith(1, plugin);
Expand Down Expand Up @@ -193,6 +197,44 @@ describe("fromTemporaryCredentials", () => {
});
});

it("uses caller client options if not overridden with provider client options", async () => {
const provider = fromTemporaryCredentialsNode({
params: {
RoleArn,
RoleSessionName,
},
});
const logger = {
debug() {},
info() {},
warn() {},
error() {},
};
const credentials = {
accessKeyId: "",
secretAccessKey: "",
};
const credentialProvider = async () => credentials;
const regionProvider = async () => "B";
await provider({
callerClientConfig: {
profile: "A",
region: regionProvider,
logger,
requestHandler: Symbol.for("requestHandler") as any,
credentialDefaultProvider: () => credentialProvider,
},
});
expect(vi.mocked(STSClient as any).mock.calls[0][0]).toEqual({
profile: "A",
region: regionProvider,
logger,
requestHandler: Symbol.for("requestHandler") as any,
// mockImpl resolved the credentials.
credentials,
});
});

it("should allow assume roles assuming roles assuming roles ad infinitum", async () => {
const roleArnOf = (id: string) => `arn:aws:iam::123456789:role/${id}`;
const idOf = (roleArn: string) => roleArn.split("/")?.[1] ?? "UNKNOWN";
Expand Down
Loading
Loading