Skip to content

Commit

Permalink
Add a way to capture message payloads in tests. (#21774)
Browse files Browse the repository at this point in the history
One use in TestReadInteraction shows how this works.  More uses in
that file will need to be added to address the other TODOs.
  • Loading branch information
bzbarsky-apple authored and pull[bot] committed Aug 23, 2022
1 parent f6e170a commit 1095108
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 3 deletions.
21 changes: 19 additions & 2 deletions src/app/tests/TestReadInteraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2696,6 +2696,20 @@ void TestReadInteraction::TestPostSubscribeRoundtripChunkReportTimeout(nlTestSui
ctx.CreateSessionBobToAlice();
}

namespace {

void CheckForInvalidAction(nlTestSuite * apSuite, Test::MessageCapturer & messageLog)
{
NL_TEST_ASSERT(apSuite, messageLog.MessageCount() == 1);
NL_TEST_ASSERT(apSuite, messageLog.IsMessageType(0, Protocols::InteractionModel::MsgType::StatusResponse));
CHIP_ERROR status;
NL_TEST_ASSERT(apSuite,
StatusResponse::ProcessStatusResponse(std::move(messageLog.MessagePayload(0)), status) == CHIP_NO_ERROR);
NL_TEST_ASSERT(apSuite, status == CHIP_IM_GLOBAL_STATUS(InvalidAction));
}

} // anonymous namespace

// Read Client sends the read request, Read Handler drops the response, then test injects unknown status reponse message for Read
// Client.
void TestReadInteraction::TestReadClientReceiveInvalidMessage(nlTestSuite * apSuite, void * apContext)
Expand Down Expand Up @@ -2750,6 +2764,9 @@ void TestReadInteraction::TestReadClientReceiveInvalidMessage(nlTestSuite * apSu
payloadHeader.SetExchangeID(0);
payloadHeader.SetMessageType(chip::Protocols::InteractionModel::MsgType::StatusResponse);

Test::MessageCapturer messageLog(ctx);
messageLog.mCaptureStandaloneAcks = false;

rm->ClearRetransTable(readClient.mExchange.Get());
NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mDroppedMessageCount == 1);
Expand All @@ -2760,12 +2777,12 @@ void TestReadInteraction::TestReadClientReceiveInvalidMessage(nlTestSuite * apSu
readClient.OnMessageReceived(readClient.mExchange.Get(), payloadHeader, std::move(msgBuf));
ctx.DrainAndServiceIO();

// TODO: Need to validate what status is being sent to the ReadHandler
// The ReadHandler closed its exchange when it sent the Report Data (which we dropped).
// Since we synthesized the StatusResponse to the ReadClient, instead of sending it from the ReadHandler,
// the only messages here are the ReadClient's StatusResponse to the unexpected message and an MRP ack.
NL_TEST_ASSERT(apSuite, ctx.GetLoopback().mSentMessageCount == 2);
NL_TEST_ASSERT(apSuite, delegate.mError == CHIP_IM_GLOBAL_STATUS(Busy));

CheckForInvalidAction(apSuite, messageLog);
}

engine->Shutdown();
Expand Down
12 changes: 12 additions & 0 deletions src/messaging/tests/MessagingContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <credentials/tests/CHIPCert_unit_test_vectors.h>
#include <lib/support/CodeUtils.h>
#include <lib/support/ErrorStr.h>
#include <protocols/secure_channel/Constants.h>

namespace chip {
namespace Test {
Expand Down Expand Up @@ -204,5 +205,16 @@ Messaging::ExchangeContext * MessagingContext::NewExchangeToBob(Messaging::Excha
return mExchangeManager.NewContext(GetSessionAliceToBob(), delegate);
}

void MessageCapturer::OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader,
const SessionHandle & session, DuplicateMessage isDuplicate,
System::PacketBufferHandle && msgBuf)
{
if (mCaptureStandaloneAcks || !payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::StandaloneAck))
{
mCapturedMessages.emplace_back(Message{ packetHeader, payloadHeader, isDuplicate, msgBuf.CloneData() });
}
mOriginalDelegate.OnMessageReceived(packetHeader, payloadHeader, session, isDuplicate, std::move(msgBuf));
}

} // namespace Test
} // namespace chip
52 changes: 51 additions & 1 deletion src/messaging/tests/MessagingContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

#include <nlunit-test.h>

#include <vector>

namespace chip {
namespace Test {

Expand Down Expand Up @@ -65,7 +67,7 @@ class PlatformMemoryUser
};

/**
* @brief The context of test cases for messaging layer. It wil initialize network layer and system layer, and create
* @brief The context of test cases for messaging layer. It will initialize network layer and system layer, and create
* two secure sessions, connected with each other. Exchanges can be created for each secure session.
*/
class MessagingContext : public PlatformMemoryUser
Expand Down Expand Up @@ -213,5 +215,53 @@ class LoopbackMessagingContext : public LoopbackTransportManager, public Messagi
using LoopbackTransportManager::GetSystemLayer;
};

// Class that can be used to capture decrypted message traffic in tests using
// MessagingContext.
class MessageCapturer : public SessionMessageDelegate
{
public:
MessageCapturer(MessagingContext & aContext) :
mSessionManager(aContext.GetSecureSessionManager()), mOriginalDelegate(aContext.GetExchangeManager())
{
// Interpose ourselves into the message flow.
mSessionManager.SetMessageDelegate(this);
}

~MessageCapturer()
{
// Restore the normal message flow.
mSessionManager.SetMessageDelegate(&mOriginalDelegate);
}

struct Message
{
PacketHeader mPacketHeader;
PayloadHeader mPayloadHeader;
DuplicateMessage mIsDuplicate;
System::PacketBufferHandle mPayload;
};

size_t MessageCount() const { return mCapturedMessages.size(); }

template <typename MessageType, typename = std::enable_if_t<std::is_enum<MessageType>::value>>
bool IsMessageType(size_t index, MessageType type)
{
return mCapturedMessages[index].mPayloadHeader.HasMessageType(type);
}

System::PacketBufferHandle & MessagePayload(size_t index) { return mCapturedMessages[index].mPayload; }

bool mCaptureStandaloneAcks = true;

private:
// SessionMessageDelegate implementation.
void OnMessageReceived(const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, const SessionHandle & session,
DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override;

SessionManager & mSessionManager;
SessionMessageDelegate & mOriginalDelegate;
std::vector<Message> mCapturedMessages;
};

} // namespace Test
} // namespace chip
1 change: 1 addition & 0 deletions src/transport/raw/MessageHeader.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ class PayloadHeader
{
public:
constexpr PayloadHeader() { SetProtocol(Protocols::NotSpecified); }
constexpr PayloadHeader(const PayloadHeader &) = default;
PayloadHeader & operator=(const PayloadHeader &) = default;

/** Get the Session ID from this header. */
Expand Down

0 comments on commit 1095108

Please sign in to comment.