From 2455761560031131f590931b83343e5e951f1311 Mon Sep 17 00:00:00 2001 From: Michael Sandstedt Date: Fri, 7 Jan 2022 17:49:40 -0600 Subject: [PATCH] Cleanup CASESession and PASESession lifetime (#13357) The fix in #12794 means that the CASESession and PASESession objects do not need to persist for ExchangeMessageDispatch::SendMessage to succeed at the final ACK of session establishment. Instead, SendMessage uses the SessionEstablishmentExchangeDispatch global singleton. This means we can address 13146 such that CASESession and PASESession may actually be freed or reused when completion callbacks fire. This will only work, however, if these objects clear themselves as delegates for their exchange contexts when discarding references to these. This commit does so. This commit also reorders all calls to mDelegate->OnSessionEstablished and mDelegate->OnSessionEstablishmentError to occur last in any given method in case mDelegate frees or reuses the CASESession or PASESession objects on execution of these completion callbacks. With this, we can remove the code in the OperationalDeviceProxy that defers release of the CASESession object until after an iteration of the event loop. Now when OnSessionEstablished fires, the CASESession can be reused or discarded immediately. --- src/app/OperationalDeviceProxy.cpp | 19 +++----- src/app/OperationalDeviceProxy.h | 2 +- src/protocols/secure_channel/CASESession.cpp | 49 ++++++++++++++------ src/protocols/secure_channel/CASESession.h | 6 +++ src/protocols/secure_channel/PASESession.cpp | 37 +++++++++++---- src/protocols/secure_channel/PASESession.h | 6 +++ 6 files changed, 82 insertions(+), 37 deletions(-) diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index 0164e5c08247d6..5d3aca2620cf91 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -217,7 +217,7 @@ void OperationalDeviceProxy::HandleCASEConnectionFailure(void * context, CASECli device->DequeueConnectionSuccessCallbacks(/* executeCallback */ false); device->DequeueConnectionFailureCallbacks(error, /* executeCallback */ true); - device->DeferCloseCASESession(); + device->CloseCASESession(); } void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * client) @@ -238,7 +238,7 @@ void OperationalDeviceProxy::HandleCASEConnected(void * context, CASEClient * cl device->DequeueConnectionFailureCallbacks(CHIP_NO_ERROR, /* executeCallback */ false); device->DequeueConnectionSuccessCallbacks(/* executeCallback */ true); - device->DeferCloseCASESession(); + device->CloseCASESession(); } } @@ -276,22 +276,15 @@ void OperationalDeviceProxy::Clear() mInitParams = DeviceProxyInitParams(); } -void OperationalDeviceProxy::CloseCASESessionTask(System::Layer * layer, void * context) +void OperationalDeviceProxy::CloseCASESession() { - OperationalDeviceProxy * device = static_cast(context); - if (device->mCASEClient) + if (mCASEClient) { - device->mInitParams.clientPool->Release(device->mCASEClient); - device->mCASEClient = nullptr; + mInitParams.clientPool->Release(mCASEClient); + mCASEClient = nullptr; } } -void OperationalDeviceProxy::DeferCloseCASESession() -{ - // Defer the release for the pending Ack to be sent - mSystemLayer->ScheduleWork(CloseCASESessionTask, this); -} - void OperationalDeviceProxy::OnSessionReleased() { mState = State::Initialized; diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 91dd127706ebce..292c58c07728ea 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -235,7 +235,7 @@ class DLL_EXPORT OperationalDeviceProxy : public DeviceProxy, SessionReleaseDele static void CloseCASESessionTask(System::Layer * layer, void * context); - void DeferCloseCASESession(); + void CloseCASESession(); void EnqueueConnectionCallbacks(Callback::Callback * onConnection, Callback::Callback * onFailure); diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index ff9b25b9ca2750..a069263f371b50 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -122,6 +122,19 @@ void CASESession::CloseExchange() } } +void CASESession::DiscardExchange() +{ + if (mExchangeCtxt != nullptr) + { + // Make sure the exchange doesn't try to notify us when it closes, + // since we might be dead by then. + mExchangeCtxt->SetDelegate(nullptr); + // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // exchange will handle that. + mExchangeCtxt = nullptr; + } +} + CHIP_ERROR CASESession::ToCachable(CASESessionCachable & cachableSession) { const NodeId peerNodeId = GetPeerNodeId(); @@ -252,11 +265,12 @@ void CASESession::OnResponseTimeout(ExchangeContext * ec) VerifyOrReturn(mExchangeCtxt == ec, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match")); ChipLogError(SecureChannel, "CASESession timed out while waiting for a response from the peer. Current state was %" PRIu8, mState); - mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); - // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. - mExchangeCtxt = nullptr; + DiscardExchange(); Clear(); + // Do this last in case the delegate frees us. + mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); } CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) @@ -683,10 +697,12 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg) mCASESessionEstablished = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); // Call delegate to indicate session establishment is successful + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablished(); exit: @@ -1117,10 +1133,12 @@ CHIP_ERROR CASESession::HandleSigma3(System::PacketBufferHandle && msg) mCASESessionEstablished = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); // Call delegate to indicate session establishment is successful + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablished(); exit: @@ -1301,16 +1319,18 @@ void CASESession::OnSuccessStatusReport() ChipLogProgress(SecureChannel, "Success status report received. Session was established"); mCASESessionEstablished = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; - - // Call delegate to indicate pairing completion - mDelegate->OnSessionEstablished(); + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); mState = kInitialized; // TODO: Set timestamp on the new session, to allow selecting a least-recently-used session for eviction // on running out of session contexts. + + // Call delegate to indicate pairing completion. + // Do this last in case the delegate frees us. + mDelegate->OnSessionEstablished(); } CHIP_ERROR CASESession::OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) @@ -1522,10 +1542,11 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea // Call delegate to indicate session establishment failure. if (err != CHIP_NO_ERROR) { - // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. - mExchangeCtxt = nullptr; + DiscardExchange(); Clear(); + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablishmentError(err); } return err; diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index b266158c011f4d..a780c298a6dd9a 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -220,6 +220,12 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin void CloseExchange(); + /** + * Clear our reference to our exchange context pointer so that it can close + * itself at some later time. + */ + void DiscardExchange(); + // TODO: Remove this and replace with system method to retrieve current time CHIP_ERROR SetEffectiveTime(void); diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 4056ee1a44da31..3f1d29cadc146e 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -116,6 +116,19 @@ void PASESession::CloseExchange() } } +void PASESession::DiscardExchange() +{ + if (mExchangeCtxt != nullptr) + { + // Make sure the exchange doesn't try to notify us when it closes, + // since we might be dead by then. + mExchangeCtxt->SetDelegate(nullptr); + // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // exchange will handle that. + mExchangeCtxt = nullptr; + } +} + CHIP_ERROR PASESession::Serialize(PASESessionSerialized & output) { PASESessionSerializable serializable; @@ -349,11 +362,12 @@ void PASESession::OnResponseTimeout(ExchangeContext * ec) ChipLogError(SecureChannel, "PASESession timed out while waiting for a response from the peer. Expected message type was %" PRIu8, to_underlying(mNextExpectedMsg)); - mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); - // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. - mExchangeCtxt = nullptr; + DiscardExchange(); Clear(); + // Do this last in case the delegate frees us. + mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); } CHIP_ERROR PASESession::DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) @@ -829,10 +843,12 @@ CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg) mPairingComplete = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); // Call delegate to indicate pairing completion + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablished(); exit: @@ -848,10 +864,12 @@ void PASESession::OnSuccessStatusReport() { mPairingComplete = true; - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; + // Discard the exchange so that Clear() doesn't try closing it. The + // exchange will handle that. + DiscardExchange(); // Call delegate to indicate pairing completion + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablished(); } @@ -942,11 +960,12 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl // Call delegate to indicate pairing failure if (err != CHIP_NO_ERROR) { - // Null out mExchangeCtxt so that Clear() doesn't try closing it. The + // Discard the exchange so that Clear() doesn't try closing it. The // exchange will handle that. - mExchangeCtxt = nullptr; + DiscardExchange(); Clear(); ChipLogError(SecureChannel, "Failed during PASE session setup. %s", ErrorStr(err)); + // Do this last in case the delegate frees us. mDelegate->OnSessionEstablishmentError(err); } return err; diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index 4bb43440b0cc37..f3ff1ae62f0df0 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -263,6 +263,12 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin void CloseExchange(); + /** + * Clear our reference to our exchange context pointer so that it can close + * itself at some later time. + */ + void DiscardExchange(); + SessionEstablishmentDelegate * mDelegate = nullptr; Protocols::SecureChannel::MsgType mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_PakeError;