diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp index 97457252632ef1..e5a8c2025839a1 100644 --- a/src/messaging/ReliableMessageMgr.cpp +++ b/src/messaging/ReliableMessageMgr.cpp @@ -133,6 +133,17 @@ void ReliableMessageMgr::ExecuteActions() if (!rc || entry.nextRetransTimeTick != 0) continue; + if (entry.retainedBuf.IsNull()) + { + // We generally try to prevent entries with a null buffer being in a table, but it could happen + // if the message dispatch (which is supposed to fill in the buffer) fails to do so _and_ returns + // success (so its caller doesn't clear out the bogus table entry). + // + // If that were to happen, we would crash in the code below. Guard against it, just in case. + ClearRetransTable(entry); + continue; + } + uint8_t sendCount = entry.sendCount; uint32_t msgId = entry.retainedBuf.GetMsgId(); diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 818335e3c46fa3..6b7d371ebf9a44 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -102,6 +102,49 @@ class MockAppDelegate : public ExchangeDelegate bool IsOnMessageReceivedCalled = false; }; +class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDispatch +{ +public: + CHIP_ERROR SendMessageImpl(SecureSessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, + EncryptedPacketBufferHandle * retainedMessage) override + { + PacketHeader packetHeader; + + ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); + ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); + + if (retainedMessage != nullptr && mRetainMessageOnSend) + { + *retainedMessage = EncryptedPacketBufferHandle::MarkEncrypted(message.Retain()); + } + return gTransportMgr.SendMessage(Transport::PeerAddress(), std::move(message)); + } + + bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } + + bool mRetainMessageOnSend = true; +}; + +class MockSessionEstablishmentDelegate : public ExchangeDelegate +{ +public: + void OnMessageReceived(ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + System::PacketBufferHandle && buffer) override + { + IsOnMessageReceivedCalled = true; + } + + void OnResponseTimeout(ExchangeContext * ec) override {} + + virtual ExchangeMessageDispatch * GetMessageDispatch(ReliableMessageMgr * rmMgr, SecureSessionMgr * sessionMgr) override + { + return &mMessageDispatch; + } + + bool IsOnMessageReceivedCalled = false; + MockSessionEstablishmentExchangeDispatch mMessageDispatch; +}; + void test_os_sleep_ms(uint64_t millisecs) { struct timespec sleep_time; @@ -291,6 +334,61 @@ void CheckCloseExchangeAndResendApplicationMessage(nlTestSuite * inSuite, void * rm->ClearRetransTable(rc); } +void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) +{ + TestContext & ctx = *reinterpret_cast(inContext); + + ctx.GetInetLayer().SystemLayer()->Init(nullptr); + + chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); + NL_TEST_ASSERT(inSuite, !buffer.IsNull()); + + CHIP_ERROR err = CHIP_NO_ERROR; + + MockSessionEstablishmentDelegate mockSender; + ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender); + NL_TEST_ASSERT(inSuite, exchange != nullptr); + + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + ReliableMessageContext * rc = exchange->GetReliableMessageContext(); + NL_TEST_ASSERT(inSuite, rm != nullptr); + NL_TEST_ASSERT(inSuite, rc != nullptr); + + rc->SetConfig({ + 1, // CHIP_CONFIG_RMP_DEFAULT_INITIAL_RETRY_INTERVAL + 1, // CHIP_CONFIG_RMP_DEFAULT_ACTIVE_RETRY_INTERVAL + }); + + err = mockSender.mMessageDispatch.Init(rm); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + mockSender.mMessageDispatch.mRetainMessageOnSend = false; + + // Let's drop the initial message + gLoopback.mSendMessageCount = 0; + gLoopback.mNumMessagesToDrop = 1; + gLoopback.mDroppedMessageCount = 0; + + // Ensure the retransmit table is empty right now + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); + err = exchange->SendMessage(Echo::MsgType::EchoRequest, std::move(buffer)); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // Ensure the message was dropped + NL_TEST_ASSERT(inSuite, gLoopback.mDroppedMessageCount == 1); + + // 1 tick is 64 ms, sleep 65 ms to trigger first re-transmit + test_os_sleep_ms(65); + ReliableMessageMgr::Timeout(&ctx.GetSystemLayer(), rm, CHIP_SYSTEM_NO_ERROR); + + // Ensure the retransmit table is empty, as we did not provide a message to retain + NL_TEST_ASSERT(inSuite, rm->TestGetCountRetransTable() == 0); + + exchange->Close(); + + rm->ClearRetransTable(rc); +} + void CheckSendStandaloneAckMessage(nlTestSuite * inSuite, void * inContext) { TestContext & ctx = *reinterpret_cast(inContext); @@ -323,6 +421,7 @@ const nlTest sTests[] = NL_TEST_DEF("Test ReliableMessageMgr::CheckFailRetrans", CheckFailRetrans), NL_TEST_DEF("Test ReliableMessageMgr::CheckResendApplicationMessage", CheckResendApplicationMessage), NL_TEST_DEF("Test ReliableMessageMgr::CheckCloseExchangeAndResendApplicationMessage", CheckCloseExchangeAndResendApplicationMessage), + NL_TEST_DEF("Test ReliableMessageMgr::CheckFailedMessageRetainOnSend", CheckFailedMessageRetainOnSend), NL_TEST_DEF("Test ReliableMessageMgr::CheckSendStandaloneAckMessage", CheckSendStandaloneAckMessage), NL_TEST_SENTINEL() diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 788296affa3425..85cc40c9ddd8c2 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -187,7 +187,7 @@ CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHea // Retain the packet buffer in case it's needed for retransmissions. if (bufferRetainSlot != nullptr) { - (*bufferRetainSlot) = msgBuf.Retain(); + *bufferRetainSlot = EncryptedPacketBufferHandle::MarkEncrypted(msgBuf.Retain()); } ChipLogProgress(Inet, "Sending msg from 0x" ChipLogFormatX64 " to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index a716a408c6936a..f1a4ae78abaffd 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -53,7 +53,7 @@ namespace chip { * EncryptedPacketBufferHandle is a kind of PacketBufferHandle class and used to hold a packet buffer * object whose payload has already been encrypted. */ -class EncryptedPacketBufferHandle final : private System::PacketBufferHandle +class EncryptedPacketBufferHandle final : public System::PacketBufferHandle { public: EncryptedPacketBufferHandle() {} @@ -92,13 +92,13 @@ class EncryptedPacketBufferHandle final : private System::PacketBufferHandle CHIP_ERROR InsertPacketHeader(const PacketHeader & aPacketHeader) { return aPacketHeader.EncodeBeforeData(*this); } #endif // CHIP_ENABLE_TEST_ENCRYPTED_BUFFER_API -private: - // Allow SecureSessionMgr to assign or construct us from a PacketBufferHandle - friend class SecureSessionMgr; + static EncryptedPacketBufferHandle MarkEncrypted(PacketBufferHandle && aBuffer) + { + return EncryptedPacketBufferHandle(std::move(aBuffer)); + } +private: EncryptedPacketBufferHandle(PacketBufferHandle && aBuffer) : PacketBufferHandle(std::move(aBuffer)) {} - - void operator=(PacketBufferHandle && aBuffer) { PacketBufferHandle::operator=(std::move(aBuffer)); } }; /**