diff --git a/lib/_tls_wrap.js b/lib/_tls_wrap.js index 5950c52c1dc2ec..84599102be4b74 100644 --- a/lib/_tls_wrap.js +++ b/lib/_tls_wrap.js @@ -502,17 +502,28 @@ function TLSSocket(socket, opts) { this[kPendingSession] = null; let wrap; - if ((socket instanceof net.Socket && socket._handle) || !socket) { - // 1. connected socket - // 2. no socket, one will be created with net.Socket().connect - wrap = socket; + let handle; + let wrapHasActiveWriteFromPrevOwner; + + if (socket) { + if (socket instanceof net.Socket && socket._handle) { + // 1. connected socket + wrap = socket; + } else { + // 2. socket has no handle so it is js not c++ + // 3. unconnected sockets are wrapped + // TLS expects to interact from C++ with a net.Socket that has a C++ stream + // handle, but a JS stream doesn't have one. Wrap it up to make it look like + // a socket. + wrap = new JSStreamSocket(socket); + } + + handle = wrap._handle; + wrapHasActiveWriteFromPrevOwner = wrap.writableLength > 0; } else { - // 3. socket has no handle so it is js not c++ - // 4. unconnected sockets are wrapped - // TLS expects to interact from C++ with a net.Socket that has a C++ stream - // handle, but a JS stream doesn't have one. Wrap it up to make it look like - // a socket. - wrap = new JSStreamSocket(socket); + // 4. no socket, one will be created with net.Socket().connect + wrap = null; + wrapHasActiveWriteFromPrevOwner = false; } // Just a documented property to make secure sockets @@ -520,7 +531,7 @@ function TLSSocket(socket, opts) { this.encrypted = true; ReflectApply(net.Socket, this, [{ - handle: this._wrapHandle(wrap), + handle: this._wrapHandle(wrap, handle, wrapHasActiveWriteFromPrevOwner), allowHalfOpen: socket ? socket.allowHalfOpen : tlsOptions.allowHalfOpen, pauseOnCreate: tlsOptions.pauseOnConnect, manualStart: true, @@ -539,6 +550,21 @@ function TLSSocket(socket, opts) { if (enableTrace && this._handle) this._handle.enableTrace(); + if (wrapHasActiveWriteFromPrevOwner) { + // `wrap` is a streams.Writable in JS. This empty write will be queued + // and hence finish after all existing writes, which is the timing + // we want to start to send any tls data to `wrap`. + wrap.write('', (err) => { + if (err) { + debug('error got before writing any tls data to the underlying stream'); + this.destroy(err); + return; + } + + this._handle.writesIssuedByPrevListenerDone(); + }); + } + // Read on next tick so the caller has a chance to setup listeners process.nextTick(initRead, this, socket); } @@ -599,11 +625,14 @@ TLSSocket.prototype.disableRenegotiation = function disableRenegotiation() { this[kDisableRenegotiation] = true; }; -TLSSocket.prototype._wrapHandle = function(wrap, handle) { - if (!handle && wrap) { - handle = wrap._handle; - } - +/** + * + * @param {null|net.Socket} wrap + * @param {null|object} handle + * @param {boolean} wrapHasActiveWriteFromPrevOwner + * @returns {object} + */ +TLSSocket.prototype._wrapHandle = function(wrap, handle, wrapHasActiveWriteFromPrevOwner) { const options = this._tlsOptions; if (!handle) { handle = options.pipe ? @@ -620,7 +649,10 @@ TLSSocket.prototype._wrapHandle = function(wrap, handle) { if (!(context.context instanceof NativeSecureContext)) { throw new ERR_TLS_INVALID_CONTEXT('context'); } - const res = tls_wrap.wrap(handle, context.context, !!options.isServer); + + const res = tls_wrap.wrap(handle, context.context, + !!options.isServer, + wrapHasActiveWriteFromPrevOwner); res._parent = handle; // C++ "wrap" object: TCPWrap, JSStream, ... res._parentWrap = wrap; // JS object: net.Socket, JSStreamSocket, ... res._secureContext = context; @@ -637,7 +669,7 @@ TLSSocket.prototype[kReinitializeHandle] = function reinitializeHandle(handle) { const originalServername = this.ssl ? this._handle.getServername() : null; const originalSession = this.ssl ? this._handle.getSession() : null; - this.handle = this._wrapHandle(null, handle); + this.handle = this._wrapHandle(null, handle, false); this.ssl = this._handle; net.Socket.prototype[kReinitializeHandle].call(this, this.handle); diff --git a/src/crypto/crypto_tls.cc b/src/crypto/crypto_tls.cc index 71d59b511e1736..13937262997365 100644 --- a/src/crypto/crypto_tls.cc +++ b/src/crypto/crypto_tls.cc @@ -319,12 +319,15 @@ TLSWrap::TLSWrap(Environment* env, Local obj, Kind kind, StreamBase* stream, - SecureContext* sc) + SecureContext* sc, + UnderlyingStreamWriteStatus under_stream_ws) : AsyncWrap(env, obj, AsyncWrap::PROVIDER_TLSWRAP), StreamBase(env), env_(env), kind_(kind), - sc_(sc) { + sc_(sc), + has_active_write_issued_by_prev_listener_( + under_stream_ws == UnderlyingStreamWriteStatus::kHasActive) { MakeWeak(); CHECK(sc_); ssl_ = sc_->CreateSSL(); @@ -434,14 +437,19 @@ void TLSWrap::InitSSL() { void TLSWrap::Wrap(const FunctionCallbackInfo& args) { Environment* env = Environment::GetCurrent(args); - CHECK_EQ(args.Length(), 3); + CHECK_EQ(args.Length(), 4); CHECK(args[0]->IsObject()); CHECK(args[1]->IsObject()); CHECK(args[2]->IsBoolean()); + CHECK(args[3]->IsBoolean()); Local sc = args[1].As(); Kind kind = args[2]->IsTrue() ? Kind::kServer : Kind::kClient; + UnderlyingStreamWriteStatus under_stream_ws = + args[3]->IsTrue() ? UnderlyingStreamWriteStatus::kHasActive + : UnderlyingStreamWriteStatus::kVacancy; + StreamBase* stream = StreamBase::FromObject(args[0].As()); CHECK_NOT_NULL(stream); @@ -452,7 +460,8 @@ void TLSWrap::Wrap(const FunctionCallbackInfo& args) { return; } - TLSWrap* res = new TLSWrap(env, obj, kind, stream, Unwrap(sc)); + TLSWrap* res = new TLSWrap( + env, obj, kind, stream, Unwrap(sc), under_stream_ws); args.GetReturnValue().Set(res->object()); } @@ -558,6 +567,13 @@ void TLSWrap::EncOut() { return; } + if (UNLIKELY(has_active_write_issued_by_prev_listener_)) { + Debug(this, + "Returning from EncOut(), " + "has_active_write_issued_by_prev_listener_ is true"); + return; + } + // Split-off queue if (established_ && current_write_) { Debug(this, "EncOut() write is scheduled"); @@ -628,6 +644,15 @@ void TLSWrap::EncOut() { void TLSWrap::OnStreamAfterWrite(WriteWrap* req_wrap, int status) { Debug(this, "OnStreamAfterWrite(status = %d)", status); + + if (UNLIKELY(has_active_write_issued_by_prev_listener_)) { + Debug(this, "Notify write finish to the previous_listener_"); + CHECK_EQ(write_size_, 0); // we must have restrained writes + + previous_listener_->OnStreamAfterWrite(req_wrap, status); + return; + } + if (current_empty_write_) { Debug(this, "Had empty write"); BaseObjectPtr current_empty_write = @@ -1974,6 +1999,16 @@ void TLSWrap::GetALPNNegotiatedProto(const FunctionCallbackInfo& args) { args.GetReturnValue().Set(result); } +void TLSWrap::WritesIssuedByPrevListenerDone( + const FunctionCallbackInfo& args) { + TLSWrap* w; + ASSIGN_OR_RETURN_UNWRAP(&w, args.Holder()); + + Debug(w, "WritesIssuedByPrevListenerDone is called"); + w->has_active_write_issued_by_prev_listener_ = false; + w->EncOut(); // resume all of our restrained writes +} + void TLSWrap::Cycle() { // Prevent recursion if (++cycle_depth_ > 1) @@ -2050,6 +2085,10 @@ void TLSWrap::Initialize( SetProtoMethod(isolate, t, "setSession", SetSession); SetProtoMethod(isolate, t, "setVerifyMode", SetVerifyMode); SetProtoMethod(isolate, t, "start", Start); + SetProtoMethod(isolate, + t, + "writesIssuedByPrevListenerDone", + WritesIssuedByPrevListenerDone); SetProtoMethodNoSideEffect( isolate, t, "exportKeyingMaterial", ExportKeyingMaterial); @@ -2131,6 +2170,7 @@ void TLSWrap::RegisterExternalReferences(ExternalReferenceRegistry* registry) { registry->Register(GetSharedSigalgs); registry->Register(GetTLSTicket); registry->Register(VerifyError); + registry->Register(WritesIssuedByPrevListenerDone); #ifdef SSL_set_max_send_fragment registry->Register(SetMaxSendFragment); diff --git a/src/crypto/crypto_tls.h b/src/crypto/crypto_tls.h index 5d1b24ff772328..01414efc6b0258 100644 --- a/src/crypto/crypto_tls.h +++ b/src/crypto/crypto_tls.h @@ -48,6 +48,8 @@ class TLSWrap : public AsyncWrap, kServer }; + enum class UnderlyingStreamWriteStatus { kHasActive, kVacancy }; + static void Initialize(v8::Local target, v8::Local unused, v8::Local context, @@ -136,7 +138,8 @@ class TLSWrap : public AsyncWrap, v8::Local obj, Kind kind, StreamBase* stream, - SecureContext* sc); + SecureContext* sc, + UnderlyingStreamWriteStatus under_stream_ws); static void SSLInfoCallback(const SSL* ssl_, int where, int ret); void InitSSL(); @@ -216,6 +219,8 @@ class TLSWrap : public AsyncWrap, static void Start(const v8::FunctionCallbackInfo& args); static void VerifyError(const v8::FunctionCallbackInfo& args); static void Wrap(const v8::FunctionCallbackInfo& args); + static void WritesIssuedByPrevListenerDone( + const v8::FunctionCallbackInfo& args); #ifdef SSL_set_max_send_fragment static void SetMaxSendFragment( @@ -283,6 +288,8 @@ class TLSWrap : public AsyncWrap, BIOPointer bio_trace_; + bool has_active_write_issued_by_prev_listener_ = false; + public: std::vector alpn_protos_; // Accessed by SelectALPNCallback. }; diff --git a/test/parallel/test-double-tls-client.js b/test/parallel/test-double-tls-client.js new file mode 100644 index 00000000000000..8309a1b9559c7a --- /dev/null +++ b/test/parallel/test-double-tls-client.js @@ -0,0 +1,58 @@ +'use strict'; +const common = require('../common'); +const assert = require('assert'); +if (!common.hasCrypto) common.skip('missing crypto'); +const fixtures = require('../common/fixtures'); +const tls = require('tls'); + +// In reality, this can be a HTTP CONNECT message, signaling the incoming +// data is TLS encrypted +const HEAD = 'XXXX'; + +const subserver = tls.createServer({ + key: fixtures.readKey('agent1-key.pem'), + cert: fixtures.readKey('agent1-cert.pem'), +}) + .on('secureConnection', common.mustCall(() => { + process.exit(0); + })); + +const server = tls.createServer({ + key: fixtures.readKey('agent1-key.pem'), + cert: fixtures.readKey('agent1-cert.pem'), +}) + .listen(client) + .on('secureConnection', (serverTlsSock) => { + serverTlsSock.on('data', (chunk) => { + assert.strictEqual(chunk.toString(), HEAD); + subserver.emit('connection', serverTlsSock); + }); + }); + +function client() { + const down = tls.connect({ + host: '127.0.0.1', + port: server.address().port, + rejectUnauthorized: false + }).on('secureConnect', () => { + down.write(HEAD, common.mustSucceed()); + + // Sending tls data on a client TLSSocket with an active write led to a crash: + // + // node[16862]: ../src/crypto/crypto_tls.cc:963:virtual int node::crypto::TLSWrap::DoWrite(node::WriteWrap*, + // uv_buf_t*, size_t, uv_stream_t*): Assertion `!current_write_' failed. + // 1: 0xb090e0 node::Abort() [node] + // 2: 0xb0915e [node] + // 3: 0xca8413 node::crypto::TLSWrap::DoWrite(node::WriteWrap*, uv_buf_t*, unsigned long, uv_stream_s*) [node] + // 4: 0xcaa549 node::StreamBase::Write(uv_buf_t*, unsigned long, uv_stream_s*, v8::Local) [node] + // 5: 0xca88d7 node::crypto::TLSWrap::EncOut() [node] + // 6: 0xd3df3e [node] + // 7: 0xd3f35f v8::internal::Builtin_HandleApiCall(int, unsigned long*, v8::internal::Isolate*) [node] + // 8: 0x15d9ef9 [node] + // Aborted + tls.connect({ + socket: down, + rejectUnauthorized: false + }); + }); +} diff --git a/test/parallel/test-double-tls-server.js b/test/parallel/test-double-tls-server.js new file mode 100644 index 00000000000000..2964748a121159 --- /dev/null +++ b/test/parallel/test-double-tls-server.js @@ -0,0 +1,94 @@ +'use strict'; +const common = require('../common'); +const assert = require('assert'); +if (!common.hasCrypto) common.skip('missing crypto'); +const fixtures = require('../common/fixtures'); +const tls = require('tls'); +const net = require('net'); + +// Sending tls data on a server TLSSocket with an active write led to a crash: +// +// node[1296]: ../src/crypto/crypto_tls.cc:963:virtual int node::crypto::TLSWrap::DoWrite(node::WriteWrap*, +// uv_buf_t*, size_t, uv_stream_t*): Assertion `!current_write_' failed. +// 1: 0xb090e0 node::Abort() [node] +// 2: 0xb0915e [node] +// 3: 0xca8413 node::crypto::TLSWrap::DoWrite(node::WriteWrap*, uv_buf_t*, unsigned long, uv_stream_s*) [node] +// 4: 0xcaa549 node::StreamBase::Write(uv_buf_t*, unsigned long, uv_stream_s*, v8::Local) [node] +// 5: 0xca88d7 node::crypto::TLSWrap::EncOut() [node] +// 6: 0xca9ba8 node::crypto::TLSWrap::OnStreamRead(long, uv_buf_t const&) [node] +// 7: 0xca8eb0 node::crypto::TLSWrap::ClearOut() [node] +// 8: 0xca9ba0 node::crypto::TLSWrap::OnStreamRead(long, uv_buf_t const&) [node] +// 9: 0xbe50dd node::LibuvStreamWrap::OnUvRead(long, uv_buf_t const*) [node] +// 10: 0xbe54c4 [node] +// 11: 0x15583d7 [node] +// 12: 0x1558c00 [node] +// 13: 0x155ede4 [node] +// 14: 0x154d008 uv_run [node] + +const serverReplaySize = 2 * 1024 * 1024; + +(async function() { + const tlsClientHello = await getClientHello(); + + const subserver = tls.createServer({ + key: fixtures.readKey('agent1-key.pem'), + cert: fixtures.readKey('agent1-cert.pem'), + }); + + const server = tls.createServer({ + key: fixtures.readKey('agent1-key.pem'), + cert: fixtures.readKey('agent1-cert.pem'), + }) + .listen(startClient) + .on('secureConnection', (serverTlsSock) => { + // Craft writes that are large enough to stuck in sending + // In reality this can be a 200 response to the incoming HTTP CONNECT + const half = Buffer.alloc(serverReplaySize / 2, 0); + serverTlsSock.write(half, common.mustSucceed()); + serverTlsSock.write(half, common.mustSucceed()); + + subserver.emit('connection', serverTlsSock); + }); + + + function startClient() { + const clientTlsSock = tls.connect({ + host: '127.0.0.1', + port: server.address().port, + rejectUnauthorized: false, + }); + + const recv = []; + let revcLen = 0; + clientTlsSock.on('data', (chunk) => { + revcLen += chunk.length; + recv.push(chunk); + if (revcLen > serverReplaySize) { + // Check the server's replay is followed by the subserver's TLS ServerHello + const serverHelloFstByte = Buffer.concat(recv).subarray(serverReplaySize, serverReplaySize + 1); + assert.strictEqual(serverHelloFstByte.toString('hex'), '16'); + process.exit(0); + } + }); + + // In reality, one may want to send a HTTP CONNECT before starting this double TLS + clientTlsSock.write(tlsClientHello); + } +})().then(common.mustCall()); + +function getClientHello() { + return new Promise((resolve) => { + const server = net.createServer((sock) => { + sock.on('data', (chunk) => { + resolve(chunk); + }); + }) + .listen(() => { + tls.connect({ + port: server.address().port, + host: '127.0.0.1', + ALPNProtocols: ['h2'], + }).on('error', () => {}); + }); + }); +} diff --git a/test/parallel/test-socket-writes-before-passed-to-tls-socket.js b/test/parallel/test-socket-writes-before-passed-to-tls-socket.js new file mode 100644 index 00000000000000..22c5b87111579c --- /dev/null +++ b/test/parallel/test-socket-writes-before-passed-to-tls-socket.js @@ -0,0 +1,42 @@ +'use strict'; +const common = require('../common'); +const assert = require('assert'); +if (!common.hasCrypto) common.skip('missing crypto'); +const tls = require('tls'); +const net = require('net'); + +const HEAD = Buffer.alloc(1024 * 1024, 0); + +const server = net.createServer((serverSock) => { + let recvLen = 0; + const recv = []; + serverSock.on('data', common.mustCallAtLeast((chunk) => { + recv.push(chunk); + recvLen += chunk.length; + + // Check that HEAD is followed by a client hello + if (recvLen > HEAD.length) { + const clientHelloFstByte = Buffer.concat(recv).subarray(HEAD.length, HEAD.length + 1); + assert.strictEqual(clientHelloFstByte.toString('hex'), '16'); + process.exit(0); + } + }, 1)); +}) + .listen(client); + +function client() { + const socket = net.createConnection({ + host: '127.0.0.1', + port: server.address().port, + }); + socket.write(HEAD.subarray(0, HEAD.length / 2), common.mustSucceed()); + + // This write will be queued by streams.Writable, the super class of net.Socket, + // which will dequeue this write when it gets notified about the finish of the first write. + // We had a bug that it wouldn't get notified. This test verifies the bug is fixed. + socket.write(HEAD.subarray(HEAD.length / 2), common.mustSucceed()); + + tls.connect({ + socket, + }); +} diff --git a/test/sequential/test-async-wrap-getasyncid.js b/test/sequential/test-async-wrap-getasyncid.js index de7af137c47489..2f1cc19b49905b 100644 --- a/test/sequential/test-async-wrap-getasyncid.js +++ b/test/sequential/test-async-wrap-getasyncid.js @@ -292,7 +292,7 @@ if (common.hasCrypto) { // eslint-disable-line node-core/crypto-check // TLSWrap is exposed, but needs to be instantiated via tls_wrap.wrap(). const tls_wrap = internalBinding('tls_wrap'); - testInitialized(tls_wrap.wrap(tcp, credentials.context, true), 'TLSWrap'); + testInitialized(tls_wrap.wrap(tcp, credentials.context, true, false), 'TLSWrap'); } {