Skip to content

Commit

Permalink
unauthenticated-session
Browse files Browse the repository at this point in the history
  • Loading branch information
kghost committed Aug 26, 2021
1 parent f7cbced commit 5dd6f18
Show file tree
Hide file tree
Showing 26 changed files with 466 additions and 183 deletions.
3 changes: 1 addition & 2 deletions src/app/server/RendezvousServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ CHIP_ERROR RendezvousServer::WaitForPairing(const RendezvousParameters & params,
ReturnErrorOnFailure(mPairingSession.WaitForPairing(params.GetSetupPINCode(), pbkdf2IterCount, salt, keyID, this));
}

ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr));
mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress());
ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mSessionMgr));

return CHIP_NO_ERROR;
}
Expand Down
17 changes: 13 additions & 4 deletions src/channel/ChannelContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,22 @@ void ChannelContext::EnterCasePairingState()
auto & prepare = GetPrepareVars();
prepare.mCasePairingSession = Platform::New<CASESession>();

ExchangeContext * ctxt =
mExchangeManager->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), prepare.mCasePairingSession);
VerifyOrReturn(ctxt != nullptr);

// TODO: currently only supports IP/UDP paring
Transport::PeerAddress addr;
addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(prepare.mAddress);

auto session = mExchangeManager->GetSessionMgr()->CreateUnauthenticatedSession(addr);
if (!session.HasValue())
{
ExitCasePairingState();
ExitPreparingState();
EnterFailedState(CHIP_ERROR_NO_MEMORY);
return;
}

ExchangeContext * ctxt = mExchangeManager->NewContext(session.Value(), prepare.mCasePairingSession);
VerifyOrReturn(ctxt != nullptr);

Transport::FabricInfo * fabric = mFabricsTable->FindFabricWithIndex(mFabricIndex);
VerifyOrReturn(fabric != nullptr);
CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession(addr, fabric, prepare.mBuilder.GetPeerNodeId(),
Expand Down
12 changes: 8 additions & 4 deletions src/controller/CHIPDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,12 +550,16 @@ CHIP_ERROR Device::WarmupCASESession()
VerifyOrReturnError(mDeviceOperationalCertProvisioned, CHIP_ERROR_INCORRECT_STATE);
VerifyOrReturnError(mState == ConnectionState::NotConnected, CHIP_NO_ERROR);

Messaging::ExchangeContext * exchange =
mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mCASESession);
// Create a UnauthenticatedSession for CASE pairing.
// Don't use mSecureSession here, because mSecureSession is the secure session.
Optional<SessionHandle> session = mSessionManager->CreateUnauthenticatedSession(mDeviceAddress);
if (!session.HasValue()) {
return CHIP_ERROR_NO_MEMORY;
}
Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(session.Value(), &mCASESession);
VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL);

ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager()));
mCASESession.MessageDispatch().SetPeerAddress(mDeviceAddress);
ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager));

uint16_t keyID = 0;
ReturnErrorOnFailure(mIDAllocator->Allocate(keyID));
Expand Down
9 changes: 6 additions & 3 deletions src/controller/CHIPDeviceController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam
Transport::PeerAddress peerAddress = Transport::PeerAddress::UDP(Inet::IPAddress::Any);

Messaging::ExchangeContext * exchangeCtxt = nullptr;
Optional<SessionHandle> session;

uint16_t keyID = 0;

Expand Down Expand Up @@ -863,9 +864,8 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam

mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle);

err = mPairingSession.MessageDispatch().Init(mTransportMgr);
err = mPairingSession.MessageDispatch().Init(mSessionMgr);
SuccessOrExit(err);
mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress());

device->Init(GetControllerDeviceInitParams(), mListenPort, remoteDeviceId, peerAddress, fabric->GetFabricIndex());

Expand All @@ -891,7 +891,10 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam
}
}
#endif
exchangeCtxt = mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mPairingSession);
session = mSessionMgr->CreateUnauthenticatedSession(params.GetPeerAddress());
VerifyOrExit(session.HasValue(), CHIP_ERROR_NO_MEMORY);

exchangeCtxt = mExchangeMgr->NewContext(session.Value(), &mPairingSession);
VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL);

err = mIDAllocator.Allocate(keyID);
Expand Down
17 changes: 15 additions & 2 deletions src/lib/support/Pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ template <class T, size_t N>
class BitMapObjectPool : public StaticAllocatorBitmap
{
public:
BitMapObjectPool() : StaticAllocatorBitmap(mMemory, mUsage, N, sizeof(T)) {}
BitMapObjectPool() : StaticAllocatorBitmap(mData.mMemory, mUsage, N, sizeof(T)) {}

static size_t Size() { return N; }

Expand All @@ -110,6 +110,13 @@ class BitMapObjectPool : public StaticAllocatorBitmap
Deallocate(element);
}

template <typename... Args>
void ResetObject(T * element, Args &&... args)
{
element->~T();
new (element) T(std::forward<Args>(args)...);
}

/**
* @brief
* Run a functor for each active object in the pool
Expand Down Expand Up @@ -144,7 +151,13 @@ class BitMapObjectPool : public StaticAllocatorBitmap
};

std::atomic<tBitChunkType> mUsage[(N + kBitChunkSize - 1) / kBitChunkSize];
alignas(alignof(T)) uint8_t mMemory[N * sizeof(T)];
union Data
{
Data() {}
~Data() {}
alignas(alignof(T)) uint8_t mMemory[N * sizeof(T)];
T mMemoryViewForDebug[N]; // Just for debugger
} mData;
};

} // namespace chip
15 changes: 13 additions & 2 deletions src/lib/support/ReferenceCountedHandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,25 @@ class ReferenceCountedHandle
explicit ReferenceCountedHandle(Target & target) : mTarget(target) { mTarget.Retain(); }
~ReferenceCountedHandle() { mTarget.Release(); }

ReferenceCountedHandle(const ReferenceCountedHandle & that) = delete;
ReferenceCountedHandle(const ReferenceCountedHandle & that) : mTarget(that.mTarget)
{
mTarget.Retain();
}

ReferenceCountedHandle(ReferenceCountedHandle && that) : mTarget(that.mTarget)
{
mTarget.Retain();
}

ReferenceCountedHandle & operator=(const ReferenceCountedHandle & that) = delete;
ReferenceCountedHandle(ReferenceCountedHandle && that) = delete;
ReferenceCountedHandle & operator=(ReferenceCountedHandle && that) = delete;

bool operator==(const ReferenceCountedHandle & that) const { return &mTarget == &that.mTarget; }
bool operator!=(const ReferenceCountedHandle & that) const { return !(*this == that); }

Target * operator->() { return &mTarget; }
Target & Get() const { return mTarget; }

private:
Target & mTarget;
};
Expand Down
2 changes: 1 addition & 1 deletion src/messaging/ApplicationExchangeDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ CHIP_ERROR ApplicationExchangeDispatch::PrepareMessage(SessionHandle session, Pa
System::PacketBufferHandle && message,
EncryptedPacketBufferHandle & preparedMessage)
{
return mSessionMgr->BuildEncryptedMessagePayload(session, payloadHeader, std::move(message), preparedMessage);
return mSessionMgr->PrepareMessage(session, payloadHeader, std::move(message), preparedMessage);
}

CHIP_ERROR ApplicationExchangeDispatch::SendPreparedMessage(SessionHandle session,
Expand Down
3 changes: 2 additions & 1 deletion src/messaging/ExchangeMessageDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SessionHandle session, uint16_t
std::unique_ptr<ReliableMessageMgr::RetransTableEntry, decltype(deleter)> entryOwner(entry, deleter);

ReturnErrorOnFailure(PrepareMessage(session, payloadHeader, std::move(message), entryOwner->retainedBuf));
reliableMessageMgr->StartRetransmision(entryOwner.get());
ReturnErrorOnFailure(SendPreparedMessage(session, entryOwner->retainedBuf));
reliableMessageMgr->StartRetransmision(entryOwner.release());
entryOwner.release();
}
else
{
Expand Down
3 changes: 3 additions & 0 deletions src/messaging/ExchangeMessageDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ class ExchangeMessageDispatch : public ReferenceCounted<ExchangeMessageDispatch>

protected:
virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0;

// TODO: remove IsReliableTransmissionAllowed, this function should be provided over session.
virtual bool IsReliableTransmissionAllowed() const { return true; }

virtual bool IsEncryptionRequired() const { return true; }
};

Expand Down
2 changes: 1 addition & 1 deletion src/messaging/ReliableMessageMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ void ReliableMessageMgr::ClearRetransTable(RetransTableEntry & rEntry)
// Expire any virtual ticks that have expired so all wakeup sources reflect the current time
ExpireTicks();

rEntry.rc->ReleaseContext();
rEntry.rc->SetOccupied(false);
rEntry.rc->ReleaseContext();
rEntry.rc = nullptr;

// Clear all other fields
Expand Down
14 changes: 12 additions & 2 deletions src/messaging/tests/MessagingContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ CHIP_ERROR MessagingContext::Init(nlTestSuite * suite, TransportMgrBase * transp
ReturnErrorOnFailure(mExchangeManager.Init(&mSecureSessionMgr));
ReturnErrorOnFailure(mMessageCounterManager.Init(&mExchangeManager));

ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(mPeer, GetDestinationNodeId(), &mPairingLocalToPeer,
ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(Optional<Transport::PeerAddress>::Value(mPeerAddress), GetDestinationNodeId(), &mPairingLocalToPeer,
SecureSession::SessionRole::kInitiator, mSrcFabricIndex));

return mSecureSessionMgr.NewPairing(mPeer, GetSourceNodeId(), &mPairingPeerToLocal, SecureSession::SessionRole::kResponder,
return mSecureSessionMgr.NewPairing(Optional<Transport::PeerAddress>::Value(mLocalAddress), GetSourceNodeId(), &mPairingPeerToLocal, SecureSession::SessionRole::kResponder,
mDestFabricIndex);
}

Expand All @@ -67,6 +67,16 @@ SessionHandle MessagingContext::GetSessionPeerToLocal()
return SessionHandle(GetSourceNodeId(), GetPeerKeyId(), GetLocalKeyId(), mDestFabricIndex);
}

Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToPeer(Messaging::ExchangeDelegate * delegate)
{
return mExchangeManager.NewContext(mSecureSessionMgr.CreateUnauthenticatedSession(mPeerAddress).Value(), delegate);
}

Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToLocal(Messaging::ExchangeDelegate * delegate)
{
return mExchangeManager.NewContext(mSecureSessionMgr.CreateUnauthenticatedSession(mLocalAddress).Value(), delegate);
}

Messaging::ExchangeContext * MessagingContext::NewExchangeToPeer(Messaging::ExchangeDelegate * delegate)
{
// TODO: temprary create a SessionHandle from node id, will be fix in PR 3602
Expand Down
9 changes: 7 additions & 2 deletions src/messaging/tests/MessagingContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class MessagingContext
{
public:
MessagingContext() :
mInitialized(false), mPeer(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)),
mInitialized(false), mLocalAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)),
mPeerAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT + 1)),
mPairingPeerToLocal(GetLocalKeyId(), GetPeerKeyId()), mPairingLocalToPeer(GetPeerKeyId(), GetLocalKeyId())
{}
~MessagingContext() { VerifyOrDie(mInitialized == false); }
Expand Down Expand Up @@ -80,6 +81,9 @@ class MessagingContext
SessionHandle GetSessionLocalToPeer();
SessionHandle GetSessionPeerToLocal();

Messaging::ExchangeContext * NewUnauthenticatedExchangeToPeer(Messaging::ExchangeDelegate * delegate);
Messaging::ExchangeContext * NewUnauthenticatedExchangeToLocal(Messaging::ExchangeDelegate * delegate);

Messaging::ExchangeContext * NewExchangeToPeer(Messaging::ExchangeDelegate * delegate);
Messaging::ExchangeContext * NewExchangeToLocal(Messaging::ExchangeDelegate * delegate);

Expand All @@ -98,7 +102,8 @@ class MessagingContext
NodeId mDestinationNodeId = 111222333;
uint16_t mLocalKeyId = 1;
uint16_t mPeerKeyId = 2;
Optional<Transport::PeerAddress> mPeer;
Transport::PeerAddress mLocalAddress;
Transport::PeerAddress mPeerAddress;
SecurePairingUsingTestSecret mPairingPeerToLocal;
SecurePairingUsingTestSecret mPairingLocalToPeer;
Transport::FabricTable mFabrics;
Expand Down
55 changes: 19 additions & 36 deletions src/messaging/tests/TestReliableMessageProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,9 @@ class MockAppDelegate : public ExchangeDelegate
nlTestSuite * mTestSuite = nullptr;
};

class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDispatch
class MockSessionEstablishmentExchangeDispatch : public Messaging::ApplicationExchangeDispatch
{
public:
CHIP_ERROR PrepareMessage(SessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message,
EncryptedPacketBufferHandle & preparedMessage) override
{
PacketHeader packetHeader;

ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message));
ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message));

preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(message));
return CHIP_NO_ERROR;
}

CHIP_ERROR SendPreparedMessage(SessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const override
{
return gTransportMgr.SendMessage(Transport::PeerAddress(), preparedMessage.CastToWritable());
}

bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; }

bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; }
Expand Down Expand Up @@ -367,6 +350,9 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext)
CHIP_ERROR err = CHIP_NO_ERROR;

MockSessionEstablishmentDelegate mockSender;
err = mockSender.mMessageDispatch.Init(&ctx.GetSecureSessionManager());
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender);
NL_TEST_ASSERT(inSuite, exchange != nullptr);

Expand All @@ -380,9 +366,6 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext)
1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
});

err = mockSender.mMessageDispatch.Init();
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

mockSender.mMessageDispatch.mRetainMessageOnSend = false;

// Let's drop the initial message
Expand Down Expand Up @@ -414,24 +397,27 @@ void CheckUnencryptedMessageReceiveFailure(nlTestSuite * inSuite, void * inConte
NL_TEST_ASSERT(inSuite, !buffer.IsNull());

MockSessionEstablishmentDelegate mockReceiver;
CHIP_ERROR err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver);
CHIP_ERROR err = mockReceiver.mMessageDispatch.Init(&ctx.GetSecureSessionManager());
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// Expect the received messages to be encrypted
mockReceiver.mMessageDispatch.mRequireEncryption = true;

MockSessionEstablishmentDelegate mockSender;
ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender);
err = mockSender.mMessageDispatch.Init(&ctx.GetSecureSessionManager());
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

ExchangeContext * exchange = ctx.NewUnauthenticatedExchangeToPeer(&mockSender);
NL_TEST_ASSERT(inSuite, exchange != nullptr);

ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
ReliableMessageContext * rc = exchange->GetReliableMessageContext();
NL_TEST_ASSERT(inSuite, rm != nullptr);
NL_TEST_ASSERT(inSuite, rc != nullptr);

err = mockSender.mMessageDispatch.Init();
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

gLoopback.mSentMessageCount = 0;
gLoopback.mNumMessagesToDrop = 0;
gLoopback.mDroppedMessageCount = 0;
Expand Down Expand Up @@ -584,23 +570,23 @@ void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuit
CHIP_ERROR err = ctx.Init(inSuite, &gTransportMgr, &gIOContext);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

ctx.SetSourceNodeId(kPlaceholderNodeId);
ctx.SetDestinationNodeId(kPlaceholderNodeId);
ctx.SetLocalKeyId(0);
ctx.SetPeerKeyId(0);
ctx.SetFabricIndex(kUndefinedFabricIndex);

chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD));
NL_TEST_ASSERT(inSuite, !buffer.IsNull());

MockSessionEstablishmentDelegate mockReceiver;
err = mockReceiver.mMessageDispatch.Init(&ctx.GetSecureSessionManager());
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

mockReceiver.mTestSuite = inSuite;

MockSessionEstablishmentDelegate mockSender;
ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockSender);
err = mockSender.mMessageDispatch.Init(&ctx.GetSecureSessionManager());
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

ExchangeContext * exchange = ctx.NewUnauthenticatedExchangeToPeer(&mockSender);
NL_TEST_ASSERT(inSuite, exchange != nullptr);

ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr();
Expand All @@ -613,9 +599,6 @@ void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuit
1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
});

err = mockSender.mMessageDispatch.Init();
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// Let's drop the initial message
gLoopback.mSentMessageCount = 0;
gLoopback.mNumMessagesToDrop = 1;
Expand Down
2 changes: 1 addition & 1 deletion src/protocols/secure_channel/CASEServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager

Cleanup();

ReturnErrorOnFailure(GetSession().MessageDispatch().Init(transportMgr));
ReturnErrorOnFailure(GetSession().MessageDispatch().Init(sessionMgr));

return CHIP_NO_ERROR;
}
Expand Down
2 changes: 0 additions & 2 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,8 +1218,6 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PacketHead
CHIP_ERROR err = ValidateReceivedMessage(ec, packetHeader, payloadHeader, msg);
SuccessOrExit(err);

SetPeerAddress(mMessageDispatch.GetPeerAddress());

switch (static_cast<Protocols::SecureChannel::MsgType>(payloadHeader.GetMessageType()))
{
case Protocols::SecureChannel::MsgType::CASE_SigmaR1:
Expand Down
Loading

0 comments on commit 5dd6f18

Please sign in to comment.