From 568e05fdc3d78882e925e8e799aca6fb86c88295 Mon Sep 17 00:00:00 2001 From: Alena Khineika Date: Thu, 18 Jan 2024 22:05:01 +0100 Subject: [PATCH] fix(NODE-5127): implement reject kmsRequest on server close (#3964) --- src/client-side-encryption/state_machine.ts | 139 ++++++++++-------- .../state_machine.test.ts | 90 ++++++++++++ 2 files changed, 167 insertions(+), 62 deletions(-) diff --git a/src/client-side-encryption/state_machine.ts b/src/client-side-encryption/state_machine.ts index 7d5dc23bf8..a4b2379fb5 100644 --- a/src/client-side-encryption/state_machine.ts +++ b/src/client-side-encryption/state_machine.ts @@ -13,7 +13,7 @@ import { import { type ProxyOptions } from '../cmap/connection'; import { getSocks, type SocksLib } from '../deps'; import { type MongoClient, type MongoClientOptions } from '../mongo_client'; -import { BufferPool, MongoDBCollectionNamespace } from '../utils'; +import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils'; import { type DataKey } from './client_encryption'; import { MongoCryptError } from './errors'; import { type MongocryptdManager } from './mongocryptd_manager'; @@ -282,7 +282,7 @@ export class StateMachine { * @param kmsContext - A C++ KMS context returned from the bindings * @returns A promise that resolves when the KMS reply has be fully parsed */ - kmsRequest(request: MongoCryptKMSRequest): Promise { + async kmsRequest(request: MongoCryptKMSRequest): Promise { const parsedUrl = request.endpoint.split(':'); const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT; const options: tls.ConnectionOptions & { host: string; port: number } = { @@ -291,52 +291,73 @@ export class StateMachine { port }; const message = request.message; + const buffer = new BufferPool(); - // TODO(NODE-3959): We can adopt `for-await on(socket, 'data')` with logic to control abort - // eslint-disable-next-line @typescript-eslint/no-misused-promises, no-async-promise-executor - return new Promise(async (resolve, reject) => { - const buffer = new BufferPool(); + const netSocket: net.Socket = new net.Socket(); + let socket: tls.TLSSocket; - // eslint-disable-next-line prefer-const - let socket: net.Socket; - let rawSocket: net.Socket; - - function destroySockets() { - for (const sock of [socket, rawSocket]) { - if (sock) { - sock.removeAllListeners(); - sock.destroy(); - } + function destroySockets() { + for (const sock of [socket, netSocket]) { + if (sock) { + sock.removeAllListeners(); + sock.destroy(); } } + } - function ontimeout() { - destroySockets(); - reject(new MongoCryptError('KMS request timed out')); - } + function ontimeout() { + return new MongoCryptError('KMS request timed out'); + } + + function onerror(cause: Error) { + return new MongoCryptError('KMS request failed', { cause }); + } - function onerror(err: Error) { - destroySockets(); - const mcError = new MongoCryptError('KMS request failed', { cause: err }); - reject(mcError); + function onclose() { + return new MongoCryptError('KMS request closed'); + } + + const tlsOptions = this.options.tlsOptions; + if (tlsOptions) { + const kmsProvider = request.kmsProvider as ClientEncryptionDataKeyProvider; + const providerTlsOptions = tlsOptions[kmsProvider]; + if (providerTlsOptions) { + const error = this.validateTlsOptions(kmsProvider, providerTlsOptions); + if (error) { + throw error; + } + try { + await this.setTlsOptions(providerTlsOptions, options); + } catch (err) { + throw onerror(err); + } } + } + const { + promise: willConnect, + reject: rejectOnNetSocketError, + resolve: resolveOnNetSocketConnect + } = promiseWithResolvers(); + netSocket + .once('timeout', () => rejectOnNetSocketError(ontimeout())) + .once('error', err => rejectOnNetSocketError(onerror(err))) + .once('close', () => rejectOnNetSocketError(onclose())) + .once('connect', () => resolveOnNetSocketConnect()); + + try { if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) { - rawSocket = net.connect({ + netSocket.connect({ host: this.options.proxyOptions.proxyHost, port: this.options.proxyOptions.proxyPort || 1080 }); + await willConnect; - rawSocket.on('timeout', ontimeout); - rawSocket.on('error', onerror); try { - // eslint-disable-next-line @typescript-eslint/no-var-requires - const events = require('events') as typeof import('events'); - await events.once(rawSocket, 'connect'); socks ??= loadSocks(); options.socket = ( await socks.SocksClient.createConnection({ - existing_socket: rawSocket, + existing_socket: netSocket, command: 'connect', destination: { host: options.host, port: options.port }, proxy: { @@ -350,45 +371,39 @@ export class StateMachine { }) ).socket; } catch (err) { - return onerror(err); + throw onerror(err); } } - const tlsOptions = this.options.tlsOptions; - if (tlsOptions) { - const kmsProvider = request.kmsProvider as ClientEncryptionDataKeyProvider; - const providerTlsOptions = tlsOptions[kmsProvider]; - if (providerTlsOptions) { - const error = this.validateTlsOptions(kmsProvider, providerTlsOptions); - if (error) reject(error); - try { - await this.setTlsOptions(providerTlsOptions, options); - } catch (error) { - return onerror(error); - } - } - } socket = tls.connect(options, () => { socket.write(message); }); - socket.once('timeout', ontimeout); - socket.once('error', onerror); - - socket.on('data', data => { - buffer.append(data); - while (request.bytesNeeded > 0 && buffer.length) { - const bytesNeeded = Math.min(request.bytesNeeded, buffer.length); - request.addResponse(buffer.read(bytesNeeded)); - } + const { + promise: willResolveKmsRequest, + reject: rejectOnTlsSocketError, + resolve + } = promiseWithResolvers(); + socket + .once('timeout', () => rejectOnTlsSocketError(ontimeout())) + .once('error', err => rejectOnTlsSocketError(onerror(err))) + .once('close', () => rejectOnTlsSocketError(onclose())) + .on('data', data => { + buffer.append(data); + while (request.bytesNeeded > 0 && buffer.length) { + const bytesNeeded = Math.min(request.bytesNeeded, buffer.length); + request.addResponse(buffer.read(bytesNeeded)); + } - if (request.bytesNeeded <= 0) { - // There's no need for any more activity on this socket at this point. - destroySockets(); - resolve(); - } - }); - }); + if (request.bytesNeeded <= 0) { + resolve(); + } + }); + await willResolveKmsRequest; + } finally { + // There's no need for any more activity on this socket at this point. + destroySockets(); + } } *requests(context: MongoCryptContext) { diff --git a/test/unit/client-side-encryption/state_machine.test.ts b/test/unit/client-side-encryption/state_machine.test.ts index e84b7c1f18..eda8669b3e 100644 --- a/test/unit/client-side-encryption/state_machine.test.ts +++ b/test/unit/client-side-encryption/state_machine.test.ts @@ -251,6 +251,96 @@ describe('StateMachine', function () { }); }); + context('when server closed the socket', function () { + context('Socks5', function () { + let server; + + beforeEach(async function () { + server = net.createServer(async socket => { + socket.end(); + }); + server.listen(0); + await once(server, 'listening'); + }); + + afterEach(function () { + server.close(); + }); + + it('throws a MongoCryptError with SocksClientError cause', async function () { + const stateMachine = new StateMachine({ + proxyOptions: { + proxyHost: 'localhost', + proxyPort: server.address().port + } + } as any); + const request = new MockRequest(Buffer.from('foobar'), 500); + + try { + await stateMachine.kmsRequest(request); + } catch (err) { + expect(err.name).to.equal('MongoCryptError'); + expect(err.message).to.equal('KMS request failed'); + expect(err.cause.constructor.name).to.equal('SocksClientError'); + return; + } + expect.fail('missed exception'); + }); + }); + + context('endpoint with host and port', function () { + let server; + let serverSocket; + + beforeEach(async function () { + server = net.createServer(async socket => { + serverSocket = socket; + }); + server.listen(0); + await once(server, 'listening'); + }); + + afterEach(function () { + server.close(); + }); + + beforeEach(async function () { + const netSocket = net.connect({ + port: server.address().port + }); + await once(netSocket, 'connect'); + this.sinon.stub(tls, 'connect').returns(netSocket); + }); + + afterEach(function () { + server.close(); + this.sinon.restore(); + }); + + it('throws a MongoCryptError error', async function () { + const stateMachine = new StateMachine({ + host: 'localhost', + port: server.address().port + } as any); + const request = new MockRequest(Buffer.from('foobar'), 500); + + try { + const kmsRequestPromise = stateMachine.kmsRequest(request); + + await promisify(setTimeout)(0); + serverSocket.end(); + + await kmsRequestPromise; + } catch (err) { + expect(err.name).to.equal('MongoCryptError'); + expect(err.message).to.equal('KMS request closed'); + return; + } + expect.fail('missed exception'); + }); + }); + }); + afterEach(function () { this.sinon.restore(); });