Skip to content

Commit

Permalink
Persist and reload message counters on the controller
Browse files Browse the repository at this point in the history
  • Loading branch information
sagar-apple committed May 25, 2021
1 parent 0d14370 commit 2573970
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 60 deletions.
62 changes: 56 additions & 6 deletions src/controller/CHIPDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@
#include <support/CHIPMem.h>
#include <support/CodeUtils.h>
#include <support/ErrorStr.h>
#include <support/PersistentStorageUtils.h>
#include <support/SafeInt.h>
#include <support/logging/CHIPLogging.h>
#include <system/TLVPacketBufferBackingStore.h>
#include <transport/MessageCounter.h>
#include <transport/PeerMessageCounter.h>

using namespace chip::Inet;
using namespace chip::System;
Expand Down Expand Up @@ -150,8 +153,10 @@ CHIP_ERROR Device::SendCommands(app::CommandSender * commandObj)

CHIP_ERROR Device::Serialize(SerializedDevice & output)
{
CHIP_ERROR error = CHIP_NO_ERROR;
uint16_t serializedLen = 0;
CHIP_ERROR error = CHIP_NO_ERROR;
uint16_t serializedLen = 0;
uint32_t localMessageCounter = 0;
uint32_t peerMessageCounter = 0;
SerializableDevice serializable;

static_assert(BASE64_ENCODED_LEN(sizeof(serializable)) <= sizeof(output.inner),
Expand All @@ -164,6 +169,14 @@ CHIP_ERROR Device::Serialize(SerializedDevice & output)
serializable.mDevicePort = Encoding::LittleEndian::HostSwap16(mDeviceAddress.GetPort());
serializable.mAdminId = Encoding::LittleEndian::HostSwap16(mAdminId);

Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession);
VerifyOrExit(connectionState != nullptr, error = CHIP_ERROR_INCORRECT_STATE);
localMessageCounter = connectionState->GetSessionMessageCounter().GetLocalMessageCounter().Value();
peerMessageCounter = connectionState->GetSessionMessageCounter().GetPeerMessageCounter().GetCounter();

serializable.mLocalMessageCounter = Encoding::LittleEndian::HostSwap32(localMessageCounter);
serializable.mPeerMessageCounter = Encoding::LittleEndian::HostSwap32(peerMessageCounter);

serializable.mCASESessionKeyId = Encoding::LittleEndian::HostSwap16(mCASESessionKeyId);
serializable.mDeviceProvisioningComplete = (mDeviceProvisioningComplete) ? 1 : 0;

Expand Down Expand Up @@ -215,10 +228,12 @@ CHIP_ERROR Device::Deserialize(const SerializedDevice & input)
IPAddress::FromString(Uint8::to_const_char(serializable.mDeviceAddr), sizeof(serializable.mDeviceAddr) - 1, ipAddress),
error = CHIP_ERROR_INVALID_ADDRESS);

mPairing = serializable.mOpsCreds;
mDeviceId = Encoding::LittleEndian::HostSwap64(serializable.mDeviceId);
port = Encoding::LittleEndian::HostSwap16(serializable.mDevicePort);
mAdminId = Encoding::LittleEndian::HostSwap16(serializable.mAdminId);
mPairing = serializable.mOpsCreds;
mDeviceId = Encoding::LittleEndian::HostSwap64(serializable.mDeviceId);
port = Encoding::LittleEndian::HostSwap16(serializable.mDevicePort);
mAdminId = Encoding::LittleEndian::HostSwap16(serializable.mAdminId);
mLocalMessageCounter = Encoding::LittleEndian::HostSwap32(serializable.mLocalMessageCounter);
mPeerMessageCounter = Encoding::LittleEndian::HostSwap32(serializable.mPeerMessageCounter);

mCASESessionKeyId = Encoding::LittleEndian::HostSwap16(serializable.mCASESessionKeyId);
mDeviceProvisioningComplete = (serializable.mDeviceProvisioningComplete != 0);
Expand Down Expand Up @@ -262,6 +277,17 @@ void Device::OnNewConnection(SecureSessionHandle session)
{
mState = ConnectionState::SecureConnected;
mSecureSession = session;

// Reset the message counters here because this is the first time we get a handle to the secure session.
Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession);
VerifyOrReturn(connectionState != nullptr);
MessageCounter & localCounter = connectionState->GetSessionMessageCounter().GetLocalMessageCounter();
if (localCounter.SetCounter(mLocalMessageCounter))
{
ChipLogError(Controller, "Unable to restore local counter to %d", mLocalMessageCounter);
};
Transport::PeerMessageCounter & peerCounter = connectionState->GetSessionMessageCounter().GetPeerMessageCounter();
peerCounter.SetCounter(mPeerMessageCounter);
}

void Device::OnConnectionExpired(SecureSessionHandle session)
Expand Down Expand Up @@ -470,5 +496,29 @@ void Device::AddReportHandler(EndpointId endpoint, ClusterId cluster, AttributeI
mCallbacksMgr.AddReportCallback(mDeviceId, endpoint, cluster, attribute, onReportCallback);
}

Device::~Device()
{
if (mExchangeMgr)
{
// Ensure that any exchange contexts we have open get closed now,
// because we don't want them to call back in to us after this
// point.
mExchangeMgr->CloseAllContextsForDelegate(this);
}

if (mStorageDelegate)
{
// Store the current device in persistent storage so we have the latest
// message counters available next time.
Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession);
VerifyOrReturn(connectionState != nullptr);
SerializedDevice serialized;
Serialize(serialized);
// TODO: no need to base-64 the serialized values AGAIN
PERSISTENT_KEY_OP(GetDeviceId(), kPairedDeviceKeyPrefix, key,
mStorageDelegate->SyncSetKeyValue(key, serialized.inner, sizeof(serialized.inner)));
}
}

} // namespace Controller
} // namespace chip
44 changes: 22 additions & 22 deletions src/controller/CHIPDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ using DeviceTransportMgr = TransportMgr<Transport::UDP /* IPv6 */

struct ControllerDeviceInitParams
{
DeviceTransportMgr * transportMgr = nullptr;
SecureSessionMgr * sessionMgr = nullptr;
Messaging::ExchangeManager * exchangeMgr = nullptr;
Inet::InetLayer * inetLayer = nullptr;

DeviceTransportMgr * transportMgr = nullptr;
SecureSessionMgr * sessionMgr = nullptr;
Messaging::ExchangeManager * exchangeMgr = nullptr;
Inet::InetLayer * inetLayer = nullptr;
PersistentStorageDelegate * storageDelegate = nullptr;
Credentials::OperationalCredentialSet * credentials = nullptr;
#if CONFIG_NETWORK_LAYER_BLE
Ble::BleLayer * bleLayer = nullptr;
Expand All @@ -87,16 +87,7 @@ struct ControllerDeviceInitParams
class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEstablishmentDelegate
{
public:
~Device()
{
if (mExchangeMgr)
{
// Ensure that any exchange contexts we have open get closed now,
// because we don't want them to call back in to us after this
// point.
mExchangeMgr->CloseAllContextsForDelegate(this);
}
}
~Device();

enum class PairingWindowOption
{
Expand Down Expand Up @@ -173,13 +164,14 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta
*/
void Init(ControllerDeviceInitParams params, uint16_t listenPort, Transport::AdminId admin)
{
mTransportMgr = params.transportMgr;
mSessionManager = params.sessionMgr;
mExchangeMgr = params.exchangeMgr;
mInetLayer = params.inetLayer;
mListenPort = listenPort;
mAdminId = admin;
mCredentials = params.credentials;
mTransportMgr = params.transportMgr;
mSessionManager = params.sessionMgr;
mExchangeMgr = params.exchangeMgr;
mInetLayer = params.inetLayer;
mListenPort = listenPort;
mAdminId = admin;
mStorageDelegate = params.storageDelegate;
mCredentials = params.credentials;
#if CONFIG_NETWORK_LAYER_BLE
mBleLayer = params.bleLayer;
#endif
Expand Down Expand Up @@ -406,6 +398,10 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta

uint8_t mSequenceNumber = 0;

// Message counts start at 1
uint32_t mLocalMessageCounter = 1;
uint32_t mPeerMessageCounter = 1;

app::CHIPDeviceCallbacksMgr & mCallbacksMgr = app::CHIPDeviceCallbacksMgr::GetInstance();

/**
Expand Down Expand Up @@ -440,6 +436,8 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta
uint16_t mCASESessionKeyId = 0;

Credentials::OperationalCredentialSet * mCredentials = nullptr;

PersistentStorageDelegate * mStorageDelegate = nullptr;
};

/**
Expand Down Expand Up @@ -494,6 +492,8 @@ typedef struct SerializableDevice
uint8_t mDeviceTransport;
uint8_t mDeviceProvisioningComplete;
uint8_t mInterfaceName[kMaxInterfaceName];
uint32_t mLocalMessageCounter; /* This field is serialized in LittleEndian byte order */
uint32_t mPeerMessageCounter; /* This field is serialized in LittleEndian byte order */
} SerializableDevice;

typedef struct SerializedDevice
Expand Down
19 changes: 1 addition & 18 deletions src/controller/CHIPDeviceController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include <support/CHIPMem.h>
#include <support/CodeUtils.h>
#include <support/ErrorStr.h>
#include <support/PersistentStorageUtils.h>
#include <support/SafeInt.h>
#include <support/ScopedBuffer.h>
#include <support/TimeUtils.h>
Expand Down Expand Up @@ -89,10 +90,6 @@ namespace Controller {

using namespace chip::Encoding;

constexpr const char kPairedDeviceListKeyPrefix[] = "ListPairedDevices";
constexpr const char kPairedDeviceKeyPrefix[] = "PairedDevice";
constexpr const char kNextAvailableKeyID[] = "StartKeyID";

#if CHIP_DEVICE_CONFIG_ENABLE_MDNS
constexpr uint16_t kMdnsPort = 5353;
#endif
Expand All @@ -103,20 +100,6 @@ constexpr uint32_t kMaxCHIPOpCertLength = 1024;
constexpr uint32_t kMaxCHIPCSRLength = 1024;
constexpr uint32_t kOpCSRNonceLength = 32;

// This macro generates a key using node ID an key prefix, and performs the given action
// on that key.
#define PERSISTENT_KEY_OP(node, keyPrefix, key, action) \
do \
{ \
constexpr size_t len = std::extent<decltype(keyPrefix)>::value; \
nlSTATIC_ASSERT_PRINT(len > 0, "keyPrefix length must be known at compile time"); \
/* 2 * sizeof(NodeId) to accomodate 2 character for each byte in Node Id */ \
char key[len + 2 * sizeof(NodeId) + 1]; \
nlSTATIC_ASSERT_PRINT(sizeof(node) <= sizeof(uint64_t), "Node ID size is greater than expected"); \
snprintf(key, sizeof(key), "%s%" PRIx64, keyPrefix, node); \
action; \
} while (0)

DeviceController::DeviceController()
{
mState = State::NotInitialized;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,9 @@ - (void)onPairingComplete:(NSError *)error
if (error.code != CHIPSuccess) {
NSLog(@"Got pairing error back %@", error);
} else {
dispatch_after(dispatch_time(DISPATCH_TIME_NOW, DISPATCH_TIME_NOW), dispatch_get_main_queue(), ^{
dispatch_async(dispatch_get_main_queue(), ^{
[self->_deviceList refreshDeviceList];
[self retrieveAndSendWifiCredentials];
[self setVendorIDOnAccessory];
});
}
}
Expand Down Expand Up @@ -623,6 +622,7 @@ - (void)onAddressUpdated:(NSError *)error
NSLog(@"Error retrieving device informations over Mdns: %@", error);
return;
}
[self setVendorIDOnAccessory];
}

- (void)updateUIFields:(CHIPSetupPayload *)payload decimalString:(nullable NSString *)decimalString
Expand Down
1 change: 1 addition & 0 deletions src/lib/support/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ static_library("support") {
"LifetimePersistedCounter.h",
"PersistedCounter.cpp",
"PersistedCounter.h",
"PersistentStorageUtils.h",
"Pool.cpp",
"Pool.h",
"PrivateHeap.cpp",
Expand Down
48 changes: 48 additions & 0 deletions src/lib/support/PersistentStorageUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
*
* Copyright (c) 2021 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* @file
* This file defines and implements some utlities for generating persistent storage keys
*
*/

#pragma once

#include <support/CodeUtils.h>

namespace chip {

constexpr const char kPairedDeviceListKeyPrefix[] = "ListPairedDevices";
constexpr const char kPairedDeviceKeyPrefix[] = "PairedDevice";
constexpr const char kNextAvailableKeyID[] = "StartKeyID";

// This macro generates a key for storage using a node ID and a key prefix, and performs the given action
// on that key.
#define PERSISTENT_KEY_OP(node, keyPrefix, key, action) \
do \
{ \
constexpr size_t len = std::extent<decltype(keyPrefix)>::value; \
nlSTATIC_ASSERT_PRINT(len > 0, "keyPrefix length must be known at compile time"); \
/* 2 * sizeof(NodeId) to accomodate 2 character for each byte in Node Id */ \
char key[len + 2 * sizeof(NodeId) + 1]; \
nlSTATIC_ASSERT_PRINT(sizeof(node) <= sizeof(uint64_t), "Node ID size is greater than expected"); \
snprintf(key, sizeof(key), "%s%" PRIx64, keyPrefix, node); \
action; \
} while (0)

} // namespace chip
4 changes: 2 additions & 2 deletions src/messaging/ExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const
UnsolicitedMessageHandler * matchingUMH = nullptr;
bool sendAckAndCloseExchange = false;

ChipLogProgress(ExchangeManager, "Received message of type %d and protocolId %d", payloadHeader.GetMessageType(),
payloadHeader.GetProtocolID());
ChipLogProgress(ExchangeManager, "Received message of type %d and protocolId %d on exchange %d", payloadHeader.GetMessageType(),
payloadHeader.GetProtocolID(), payloadHeader.GetExchangeID());

// Search for an existing exchange that the message applies to. If a match is found...
bool found = false;
Expand Down
22 changes: 18 additions & 4 deletions src/transport/MessageCounter.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ class MessageCounter

virtual ~MessageCounter() = 0;

virtual Type GetType() = 0;
virtual void Reset() = 0;
virtual uint32_t Value() = 0; /** Get current value */
virtual CHIP_ERROR Advance() = 0; /** Advance the counter */
virtual Type GetType() = 0;
virtual void Reset() = 0;
virtual uint32_t Value() = 0; /** Get current value */
virtual CHIP_ERROR Advance() = 0; /** Advance the counter */
virtual CHIP_ERROR SetCounter(uint32_t count) = 0; /** Set the counter to the specified value */
};

inline MessageCounter::~MessageCounter() {}
Expand All @@ -71,6 +72,12 @@ class GlobalUnencryptedMessageCounter : public MessageCounter
++value;
return CHIP_NO_ERROR;
}
CHIP_ERROR SetCounter(uint32_t count) override
{
Reset();
value = count;
return CHIP_NO_ERROR;
}

private:
uint32_t value;
Expand All @@ -89,6 +96,7 @@ class GlobalEncryptedMessageCounter : public MessageCounter
}
uint32_t Value() override { return persisted.GetValue(); }
CHIP_ERROR Advance() override { return persisted.Advance(); }
CHIP_ERROR SetCounter(uint32_t count) override { return CHIP_ERROR_NOT_IMPLEMENTED; }

private:
#if CONFIG_DEVICE_LAYER
Expand Down Expand Up @@ -127,6 +135,12 @@ class LocalSessionMessageCounter : public MessageCounter
++value;
return CHIP_NO_ERROR;
}
CHIP_ERROR SetCounter(uint32_t count) override
{
Reset();
value = count;
return CHIP_NO_ERROR;
}

private:
uint32_t value;
Expand Down
1 change: 0 additions & 1 deletion src/transport/PeerMessageCounter.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ class PeerMessageCounter
mSynced.mWindow.reset();
}

/* Test-only */
uint32_t GetCounter() { return mSynced.mMaxCounter; }

private:
Expand Down
6 changes: 1 addition & 5 deletions src/transport/SecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,7 @@ void SecureSessionMgr::SecureMessageDispatch(const PacketHeader & packetHeader,
{
ChipLogError(Inet, "Message counter verify failed, err = %d", err);
}
// TODO - Enable exit on error for message counter verification failure.
// We are now using IM messages in commissioner class to provision op creds and
// other device commissioning steps. This is somehow causing issues with message counter
// verification. Disabling this check for now. Enable it after debugging the cause.
// SuccessOrExit(err);
SuccessOrExit(err);
}

admin = mAdmins->FindAdminWithId(state->GetAdminId());
Expand Down

0 comments on commit 2573970

Please sign in to comment.