From c604e74e2ca0bcf827441558e40cecbaf9f4fe94 Mon Sep 17 00:00:00 2001 From: Bailey Pearson Date: Thu, 4 Apr 2024 09:23:17 -0600 Subject: [PATCH] fix(NODE-6051): only provide expected allowed keys to libmongocrypt after fetching aws kms credentials (#4057) --- src/client-side-encryption/providers/aws.ts | 27 ++++--- test/integration/auth/mongodb_aws.test.ts | 73 +++++++++++++++---- ...records_for_mongos_discovery.prose.test.ts | 10 +-- .../providers/credentialsProvider.test.ts | 38 +++++----- test/unit/connection_string.spec.test.ts | 7 +- test/unit/sdam/monitor.test.ts | 5 +- test/unit/sdam/topology.test.ts | 10 +-- 7 files changed, 110 insertions(+), 60 deletions(-) diff --git a/src/client-side-encryption/providers/aws.ts b/src/client-side-encryption/providers/aws.ts index 64aa9f0adc..240e560bd9 100644 --- a/src/client-side-encryption/providers/aws.ts +++ b/src/client-side-encryption/providers/aws.ts @@ -1,20 +1,27 @@ -import { getAwsCredentialProvider } from '../../deps'; +import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials'; import { type KMSProviders } from '.'; /** * @internal */ export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise { - const credentialProvider = getAwsCredentialProvider(); + const credentialProvider = new AWSSDKCredentialProvider(); - if ('kModuleError' in credentialProvider) { - return kmsProviders; - } + // We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey` + // or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings + // and let libmongocrypt error if we're unable to fetch the required keys. + const { + SecretAccessKey = '', + AccessKeyId = '', + Token + } = await credentialProvider.getCredentials(); + const aws: NonNullable = { + secretAccessKey: SecretAccessKey, + accessKeyId: AccessKeyId + }; + // the AWS session token is only required for temporary credentials so only attach it to the + // result if it's present in the response from the aws sdk + Token != null && (aws.sessionToken = Token); - const { fromNodeProviderChain } = credentialProvider; - const provider = fromNodeProviderChain(); - // The state machine is the only place calling this so it will - // catch if there is a rejection here. - const aws = await provider(); return { ...kmsProviders, aws }; } diff --git a/test/integration/auth/mongodb_aws.test.ts b/test/integration/auth/mongodb_aws.test.ts index a96ed91d53..5fceca50a4 100644 --- a/test/integration/auth/mongodb_aws.test.ts +++ b/test/integration/auth/mongodb_aws.test.ts @@ -5,30 +5,26 @@ import * as http from 'http'; import { performance } from 'perf_hooks'; import * as sinon from 'sinon'; +// eslint-disable-next-line @typescript-eslint/no-restricted-imports +import { refreshKMSCredentials } from '../../../src/client-side-encryption/providers'; import { AWSTemporaryCredentialProvider, MongoAWSError, type MongoClient, MongoDBAWS, MongoMissingCredentialsError, - MongoServerError + MongoServerError, + setDifference } from '../../mongodb'; -function awsSdk() { - try { - return require('@aws-sdk/credential-providers'); - } catch { - return null; - } -} +const isMongoDBAWSAuthEnvironment = (process.env.MONGODB_URI ?? '').includes('MONGODB-AWS'); describe('MONGODB-AWS', function () { let awsSdkPresent; let client: MongoClient; beforeEach(function () { - const MONGODB_URI = process.env.MONGODB_URI; - if (!MONGODB_URI || MONGODB_URI.indexOf('MONGODB-AWS') === -1) { + if (!isMongoDBAWSAuthEnvironment) { this.currentTest.skipReason = 'requires MONGODB_URI to contain MONGODB-AWS auth mechanism'; return this.skip(); } @@ -39,7 +35,7 @@ describe('MONGODB-AWS', function () { `Always inform the AWS tests if they run with or without the SDK (MONGODB_AWS_SDK=${MONGODB_AWS_SDK})` ).to.include(MONGODB_AWS_SDK); - awsSdkPresent = !!awsSdk(); + awsSdkPresent = AWSTemporaryCredentialProvider.isAWSSDKInstalled; expect( awsSdkPresent, MONGODB_AWS_SDK === 'true' @@ -244,8 +240,10 @@ describe('MONGODB-AWS', function () { const envCheck = () => { const { AWS_WEB_IDENTITY_TOKEN_FILE = '' } = process.env; - credentialProvider = awsSdk(); - return AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 || credentialProvider == null; + return ( + AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 || + !AWSTemporaryCredentialProvider.isAWSSDKInstalled + ); }; beforeEach(function () { @@ -255,6 +253,9 @@ describe('MONGODB-AWS', function () { return this.skip(); } + // @ts-expect-error We intentionally access a protected variable. + credentialProvider = AWSTemporaryCredentialProvider.awsSDK; + storedEnv = process.env; if (test.env.AWS_STS_REGIONAL_ENDPOINTS === undefined) { delete process.env.AWS_STS_REGIONAL_ENDPOINTS; @@ -324,3 +325,49 @@ describe('MONGODB-AWS', function () { } }); }); + +describe('AWS KMS Credential Fetching', function () { + context('when the AWS SDK is not installed', function () { + beforeEach(function () { + this.currentTest.skipReason = !isMongoDBAWSAuthEnvironment + ? 'Test must run in an AWS auth testing environment' + : AWSTemporaryCredentialProvider.isAWSSDKInstalled + ? 'This test must run in an environment where the AWS SDK is not installed.' + : undefined; + this.currentTest?.skipReason && this.skip(); + }); + it('fetching AWS KMS credentials throws an error', async function () { + const error = await refreshKMSCredentials({ aws: {} }).catch(e => e); + expect(error).to.be.instanceOf(MongoAWSError); + }); + }); + + context('when the AWS SDK is installed', function () { + beforeEach(function () { + this.currentTest.skipReason = !isMongoDBAWSAuthEnvironment + ? 'Test must run in an AWS auth testing environment' + : !AWSTemporaryCredentialProvider.isAWSSDKInstalled + ? 'This test must run in an environment where the AWS SDK is installed.' + : undefined; + this.currentTest?.skipReason && this.skip(); + }); + it('KMS credentials are successfully fetched.', async function () { + const { aws } = await refreshKMSCredentials({ aws: {} }); + + expect(aws).to.have.property('accessKeyId'); + expect(aws).to.have.property('secretAccessKey'); + }); + + it('does not return any extra keys for the `aws` credential provider', async function () { + const { aws } = await refreshKMSCredentials({ aws: {} }); + + const keys = new Set(Object.keys(aws ?? {})); + const allowedKeys = ['accessKeyId', 'secretAccessKey', 'sessionToken']; + + expect( + Array.from(setDifference(keys, allowedKeys)), + 'received an unexpected key in the response refreshing KMS credentials' + ).to.deep.equal([]); + }); + }); +}); diff --git a/test/unit/assorted/polling_srv_records_for_mongos_discovery.prose.test.ts b/test/unit/assorted/polling_srv_records_for_mongos_discovery.prose.test.ts index d60a81d034..8ec5cd8e29 100644 --- a/test/unit/assorted/polling_srv_records_for_mongos_discovery.prose.test.ts +++ b/test/unit/assorted/polling_srv_records_for_mongos_discovery.prose.test.ts @@ -1,7 +1,7 @@ import { expect } from 'chai'; import * as dns from 'dns'; import { once } from 'events'; -import { coerce } from 'semver'; +import { satisfies } from 'semver'; import * as sinon from 'sinon'; import { @@ -51,11 +51,9 @@ describe('Polling Srv Records for Mongos Discovery', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const test = this.currentTest!; - const { major } = coerce(process.version); - test.skipReason = - major === 18 || major === 20 - ? 'TODO(NODE-5666): fix failing unit tests on Node18' - : undefined; + test.skipReason = satisfies(process.version, '>=18.0.0') + ? `TODO(NODE-5666): fix failing unit tests on Node18 (Running with Nodejs ${process.version})` + : undefined; if (test.skipReason) this.skip(); }); diff --git a/test/unit/client-side-encryption/providers/credentialsProvider.test.ts b/test/unit/client-side-encryption/providers/credentialsProvider.test.ts index ddd9c87b45..486fb41c60 100644 --- a/test/unit/client-side-encryption/providers/credentialsProvider.test.ts +++ b/test/unit/client-side-encryption/providers/credentialsProvider.test.ts @@ -20,6 +20,8 @@ import { } from '../../../../src/client-side-encryption/providers/azure'; // eslint-disable-next-line @typescript-eslint/no-restricted-imports import * as utils from '../../../../src/client-side-encryption/providers/utils'; +// eslint-disable-next-line @typescript-eslint/no-restricted-imports +import { AWSSDKCredentialProvider } from '../../../../src/cmap/auth/aws_temporary_credentials'; import * as requirements from '../requirements.helper'; const originalAccessKeyId = process.env.AWS_ACCESS_KEY_ID; @@ -154,25 +156,25 @@ describe('#refreshKMSCredentials', function () { }); }); - context('when the sdk is not installed', function () { - const kmsProviders = { - local: { - key: Buffer.alloc(96) - }, - aws: {} - }; - - before(function () { - if (requirements.credentialProvidersInstalled.aws && this.currentTest) { - this.currentTest.skipReason = 'Credentials will be loaded when sdk present'; - this.currentTest.skip(); - return; - } + context('when the AWS SDK returns unknown fields', function () { + beforeEach(() => { + sinon.stub(AWSSDKCredentialProvider.prototype, 'getCredentials').resolves({ + Token: 'example', + SecretAccessKey: 'example', + AccessKeyId: 'example', + Expiration: new Date() + }); }); - - it('does not refresh credentials', async function () { - const providers = await refreshKMSCredentials(kmsProviders); - expect(providers).to.deep.equal(kmsProviders); + afterEach(() => sinon.restore()); + it('only returns fields libmongocrypt expects', async function () { + const credentials = await refreshKMSCredentials({ aws: {} }); + expect(credentials).to.deep.equal({ + aws: { + accessKeyId: accessKey, + secretAccessKey: secretKey, + sessionToken: sessionToken + } + }); }); }); }); diff --git a/test/unit/connection_string.spec.test.ts b/test/unit/connection_string.spec.test.ts index cc4fbb7c69..213b2101c9 100644 --- a/test/unit/connection_string.spec.test.ts +++ b/test/unit/connection_string.spec.test.ts @@ -1,4 +1,4 @@ -import { coerce } from 'semver'; +import { satisfies } from 'semver'; import { loadSpecTests } from '../spec'; import { executeUriValidationTest } from '../tools/uri_spec_runner'; @@ -15,14 +15,13 @@ describe('Connection String spec tests', function () { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const test = this.currentTest!; - const { major } = coerce(process.version); const skippedTests = [ 'Invalid port (zero) with IP literal', 'Invalid port (zero) with hostname' ]; test.skipReason = - major === 20 && skippedTests.includes(test.title) - ? 'TODO(NODE-5666): fix failing unit tests on Node18' + satisfies(process.version, '>=20.0.0') && skippedTests.includes(test.title) + ? 'TODO(NODE-5666): fix failing unit tests on Node20+' : undefined; if (test.skipReason) this.skip(); diff --git a/test/unit/sdam/monitor.test.ts b/test/unit/sdam/monitor.test.ts index 408a2b2665..939c1c2de3 100644 --- a/test/unit/sdam/monitor.test.ts +++ b/test/unit/sdam/monitor.test.ts @@ -2,7 +2,7 @@ import { once } from 'node:events'; import * as net from 'node:net'; import { expect } from 'chai'; -import { coerce } from 'semver'; +import { satisfies } from 'semver'; import * as sinon from 'sinon'; import { setTimeout } from 'timers'; import { setTimeout as setTimeoutPromise } from 'timers/promises'; @@ -57,7 +57,6 @@ describe('monitoring', function () { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const test = this.currentTest!; - const { major } = coerce(process.version); const failingTests = [ 'should connect and issue an initial server check', 'should ignore attempts to connect when not already closed', @@ -67,7 +66,7 @@ describe('monitoring', function () { 'correctly returns the mean of the heartbeat durations' ]; test.skipReason = - (major === 18 || major === 20) && failingTests.includes(test.title) + satisfies(process.version, '>=18.0.0') && failingTests.includes(test.title) ? 'TODO(NODE-5666): fix failing unit tests on Node18' : undefined; diff --git a/test/unit/sdam/topology.test.ts b/test/unit/sdam/topology.test.ts index c6511d27c4..e4a34417d5 100644 --- a/test/unit/sdam/topology.test.ts +++ b/test/unit/sdam/topology.test.ts @@ -2,7 +2,7 @@ import { expect } from 'chai'; import { once } from 'events'; import * as net from 'net'; import { type AddressInfo } from 'net'; -import { coerce, type SemVer } from 'semver'; +import { satisfies } from 'semver'; import * as sinon from 'sinon'; import { clearTimeout } from 'timers'; @@ -284,11 +284,9 @@ describe('Topology (unit)', function () { it('should encounter a server selection timeout on garbled server responses', function () { const test = this.test; - const { major } = coerce(process.version) as SemVer; - test.skipReason = - major === 18 || major === 20 - ? 'TODO(NODE-5666): fix failing unit tests on Node18' - : undefined; + test.skipReason = satisfies(process.version, '>=18.0.0') + ? 'TODO(NODE-5666): fix failing unit tests on Node18' + : undefined; if (test.skipReason) this.skip();