From 75ac743bdd22dc65ee5a3d00f43a919edfd4190d Mon Sep 17 00:00:00 2001 From: Michael Sandstedt Date: Mon, 29 Nov 2021 21:26:38 -0600 Subject: [PATCH] Add variant-type-driven state machine to libs (#12223) * Add variant-type-driven state machine to libs This commit introduces a functional approach for State Machine implementation that leverages the sdk's Variant and Optional types. By providing a clean separation between states, events, and transitions, state machines can be implemented in such a way that states and events may be provided in the core sdk, but with transitions defined in consuming applications to support individual use cases. And because states and events are associated to the state machine by inclusion in variants and not by inheritance, applications may also extend state machines with arbitrary, application-specific events and states. Co-Authored-by: Bill Schiller * fix Wshadow * plausible simplified approach * Revert "plausible simplified approach" This reverts commit 93227b149224f6e2282c21b18ec6d1a656451dd5. * Use preorder recursion to remove need for Dispatch queue * cleanup test state machine state construction * Add guardrails to disallow Dispatch from illegal contexts The following are made explicitly illegal and will abort: * Dispatch from Exit() or LogTransition() state methods * Dispatch from transitions table w/ returned transition * make StateMachine members variables private Co-authored-by: Bill Schiller --- src/lib/support/BUILD.gn | 1 + src/lib/support/StateMachine.h | 238 +++++++++++++++++++ src/lib/support/Variant.h | 8 + src/lib/support/tests/BUILD.gn | 1 + src/lib/support/tests/TestStateMachine.cpp | 251 +++++++++++++++++++++ 5 files changed, 499 insertions(+) create mode 100644 src/lib/support/StateMachine.h create mode 100644 src/lib/support/tests/TestStateMachine.cpp diff --git a/src/lib/support/BUILD.gn b/src/lib/support/BUILD.gn index 8d060acd1feff9..2c5e651db49830 100644 --- a/src/lib/support/BUILD.gn +++ b/src/lib/support/BUILD.gn @@ -92,6 +92,7 @@ static_library("support") { "SafeInt.h", "SerializableIntegerSet.cpp", "SerializableIntegerSet.h", + "StateMachine.h", "ThreadOperationalDataset.cpp", "ThreadOperationalDataset.h", "TimeUtils.cpp", diff --git a/src/lib/support/StateMachine.h b/src/lib/support/StateMachine.h new file mode 100644 index 00000000000000..329dd27eab062a --- /dev/null +++ b/src/lib/support/StateMachine.h @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * Copyright (c) 2021 SmartThings + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace chip { +namespace StateMachine { + +/** + * An extension of the Optional class that removes the explicit requirement + * for construction from a T value as a convenience to allow auto construction + * of Optional. + */ +template +class Optional : public chip::Optional +{ +public: + Optional(const T & value) : chip::Optional(value) {} + Optional() : chip::Optional() {} +}; + +/** + * An extension of the Variant class offering pattern matching of State types + * to dynamically dispatch execution of the required State interface methods: + * Enter, Exit, GetName, LogTtransition. + */ +template +struct VariantState : Variant +{ + +private: + template + void Enter() + { + if (chip::Variant::template Is()) + { + chip::Variant::template Get().Enter(); + } + } + + template + void Exit() + { + if (chip::Variant::template Is()) + { + chip::Variant::template Get().Exit(); + } + } + + template + void GetName(const char ** name) + { + if (name && chip::Variant::template Is()) + { + *name = chip::Variant::template Get().GetName(); + } + } + + template + void LogTransition(const char * previous) + { + if (chip::Variant::template Is()) + { + chip::Variant::template Get().LogTransition(previous); + } + } + +public: + template + static VariantState Create(Args &&... args) + { + VariantState instance; + instance.template Set(std::forward(args)...); + return instance; + } + + void Enter() + { + [](...) {}((this->template Enter(), 0)...); + } + + void Exit() + { + [](...) {}((this->template Exit(), 0)...); + } + + const char * GetName() + { + const char * name = nullptr; + [](...) {}((this->template GetName(&name), 0)...); + return name; + } + + void LogTransition(const char * previous) + { + [](...) {}((this->template LogTransition(previous), 0)...); + } +}; + +/** + * The interface for dispatching events into the State Machine. + * @tparam TEvent a variant holding the Events for the State Machine. + */ +template +class Context +{ +public: + virtual ~Context() = default; + + /** + * Dispatch an event to the current state. + * @param evt a variant holding an Event for the State Machine. + */ + virtual void Dispatch(const TEvent & evt) = 0; +}; + +/** + * This is a functional approach to the State Machine design pattern. The design is + * borrowed from http://www.vishalchovatiya.com/state-design-pattern-in-modern-cpp + * and extended for this application. + * + * At a high-level, the purpose of a State Machine is to switch between States. Each + * State handles Events. The handling of Events may lead to Transitions. The purpose + * of this design pattern is to decouple States, Events, and Transitions. For instance, + * it is desirable to remove knowledge of next/previous States from each individual + * State. This allows adding/removing States with minimal code change and leads to a + * simpler implementation. + * + * This State Machine design emulates C++17 features to achieve the functional approach. + * Instead of using an enum or inheritance for the Events, the Events are defined as + * structs and placed in a variant. Likewise, the States are all defined as structs and + * placed in a variant. With the Events and States in two different variants, the + * Transitions table uses the type introspction feature of the variant object to match a + * given state and event to an optional new-state return. + * + * For event dispatch, the State Machine implements the Context interface. The Context + * interface is passed to States to allow Dispatch() of events when needed. + * + * The State held in the TState must provide four methods to support calls from + * the State Machine: + * @code + * struct State { + * void Enter() { } + * void Exit() { } + * void LogTransition(const char *) { } + * const char *GetName() { return ""; } + * } + * @endcode + * + * The TTransitions table type is implemented with an overloaded callable operator method + * to match the combinations of State / Event variants that may produce a new-state return. + * This allows the Transition table to define how each State responds to Events. Below is + * an example of a Transitions table implemented as a struct: + * + * @code + * struct Transitions { + * using State = chip::StateMachine::VariantState; + * chip::StateMachine::Optional operator()(State &state, Event &event) + * { + * if (state.Is() && event.Is()) + * { + * return State::Create(); + * } + * else if (state.Is() && event.Is()) + * { + * return State::Create(); + * } + * else + * { + * return {} + * } + * } + * } + * @endcode + * + * The rules for calling Dispatch from within the state machien are as follows: + * + * (1) Only the State::Enter method should call Dispatch. Calls from Exit or + * LogTransition will cause an abort. + * (2) The transitions table may return a new state OR call Dispatch, but must + * never do both. Doing both will cause an abort. + * + * @tparam TState a variant holding the States. + * @tparam TEvent a variant holding the Events. + * @tparam TTransitions an object that implements the () operator for transitions. + */ +template +class StateMachine : public Context +{ +public: + StateMachine(TTransitions & tr) : mCurrentState(tr.GetInitState()), mTransitions(tr), mSequence(0) {} + ~StateMachine() override = default; + void Dispatch(const TEvent & evt) override + { + ++mSequence; + auto prev = mSequence; + auto newState = mTransitions(mCurrentState, evt); + if (newState.HasValue()) + { + auto oldState = mCurrentState; + oldState.Exit(); + mCurrentState = newState.Value(); + mCurrentState.LogTransition(oldState.GetName()); + // It is impermissible to dispatch events from Exit() or + // LogTransition(), or from the transitions table when a transition + // has also been returned. Verify that this hasn't occured. + VerifyOrDie(prev == mSequence); + mCurrentState.Enter(); + } + } + TState GetState() { return mCurrentState; } + +private: + TState mCurrentState; + TTransitions & mTransitions; + unsigned mSequence; +}; + +} // namespace StateMachine +} // namespace chip diff --git a/src/lib/support/Variant.h b/src/lib/support/Variant.h index 338bff6a419b25..8c6a99dae3dc41 100644 --- a/src/lib/support/Variant.h +++ b/src/lib/support/Variant.h @@ -189,6 +189,14 @@ struct Variant bool Valid() const { return (mTypeId != kInvalidType); } + template + static Variant Create(Args &&... args) + { + Variant instance; + instance.template Set(std::forward(args)...); + return instance; + } + template void Set(Args &&... args) { diff --git a/src/lib/support/tests/BUILD.gn b/src/lib/support/tests/BUILD.gn index 13c63bd3690b42..7ef9c6c8be1ad5 100644 --- a/src/lib/support/tests/BUILD.gn +++ b/src/lib/support/tests/BUILD.gn @@ -41,6 +41,7 @@ chip_test_suite("tests") { "TestScopedBuffer.cpp", "TestSerializableIntegerSet.cpp", "TestSpan.cpp", + "TestStateMachine.cpp", "TestStringBuilder.cpp", "TestThreadOperationalDataset.cpp", "TestTimeUtils.cpp", diff --git a/src/lib/support/tests/TestStateMachine.cpp b/src/lib/support/tests/TestStateMachine.cpp new file mode 100644 index 00000000000000..37433dd64a0b9d --- /dev/null +++ b/src/lib/support/tests/TestStateMachine.cpp @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * Copyright (c) 2021 SmartThings + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace { + +struct Event1 +{ +}; +struct Event2 +{ +}; +struct Event3 +{ +}; +struct Event4 +{ +}; + +using Event = chip::Variant; +using Context = chip::StateMachine::Context; + +struct MockState +{ + unsigned mEntered; + unsigned mExited; + unsigned mLogged; + const char * mPrevious; + + void Enter() { ++mEntered; } + void Exit() { ++mExited; } + void LogTransition(const char * previous) + { + ++mLogged; + mPrevious = previous; + } +}; + +struct BaseState +{ + void Enter() { mMock.Enter(); } + void Exit() { mMock.Exit(); } + void LogTransition(const char * previous) { mMock.LogTransition(previous); } + const char * GetName() { return mName; } + + chip::StateMachine::Context & mCtx; + const char * mName; + MockState & mMock; +}; + +struct State1 : public BaseState +{ + State1(Context & ctx, MockState & mock) : BaseState{ ctx, "State1", mock } {} +}; + +struct State2 : public BaseState +{ + State2(Context & ctx, MockState & mock) : BaseState{ ctx, "State2", mock } {} +}; + +using State = chip::StateMachine::VariantState; + +struct StateFactory +{ + Context & mCtx; + MockState ms1{ 0, 0, 0, nullptr }; + MockState ms2{ 0, 0, 0, nullptr }; + + StateFactory(Context & ctx) : mCtx(ctx) {} + + auto CreateState1() { return State::Create(mCtx, ms1); } + + auto CreateState2() { return State::Create(mCtx, ms2); } +}; + +struct Transitions +{ + Context & mCtx; + StateFactory mFactory; + Transitions(Context & ctx) : mCtx(ctx), mFactory(ctx) {} + + using OptState = chip::StateMachine::Optional; + State GetInitState() { return mFactory.CreateState1(); } + OptState operator()(const State & state, const Event & event) + { + if (state.Is() && event.Is()) + { + return mFactory.CreateState2(); + } + else if (state.Is() && event.Is()) + { + return mFactory.CreateState1(); + } + else if (state.Is() && event.Is()) + { + // legal - Dispatches event without transition + mCtx.Dispatch(Event::Create()); + return {}; + } + else if (state.Is() && event.Is()) + { + // mCtx.Dispatch(Event::Create()); // dsipatching an event and returning a transition would be illegal + return mFactory.CreateState1(); + } + else + { + return {}; + } + } +}; + +class SimpleStateMachine +{ +public: + Transitions mTransitions; + chip::StateMachine::StateMachine mStateMachine; + + SimpleStateMachine() : mTransitions(mStateMachine), mStateMachine(mTransitions) {} + ~SimpleStateMachine() {} +}; + +void TestInit(nlTestSuite * inSuite, void * inContext) +{ + // state machine initializes to State1 + SimpleStateMachine fsm; + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); +} + +void TestIgnoredEvents(nlTestSuite * inSuite, void * inContext) +{ + // in State1 - ignore Event1 and Event3 + SimpleStateMachine fsm; + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // transition to State2 + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // in State2 - ignore Event2 and Event3 + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); +} + +void TestTransitions(nlTestSuite * inSuite, void * inContext) +{ + // in State1 + SimpleStateMachine fsm; + // dispatch Event2 to transition to State2 + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // dispatch Event1 to transition back to State1 + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // dispatch Event2 to transition to State2 + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // dispatch Event4 to transitions to State1. + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); +} + +void TestTransitionsDispatch(nlTestSuite * inSuite, void * inContext) +{ + // in State1 + SimpleStateMachine fsm; + // Dispatch Event4, which in turn dispatches Event2 from the transitions + // table and ultimately places us in State2. + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); +} + +void TestMethodExec(nlTestSuite * inSuite, void * inContext) +{ + // in State1 + SimpleStateMachine fsm; + // transition to State2 + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // verify expected method calls + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mEntered == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mExited == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mLogged == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mPrevious == nullptr); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mEntered == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mExited == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mLogged == 1); + NL_TEST_ASSERT(inSuite, strcmp(fsm.mTransitions.mFactory.ms2.mPrevious, "State1") == 0); + // transition back to State1 + fsm.mStateMachine.Dispatch(Event::Create()); + NL_TEST_ASSERT(inSuite, fsm.mStateMachine.GetState().Is()); + // verify expected method calls + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mEntered == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mExited == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms1.mLogged == 1); + NL_TEST_ASSERT(inSuite, strcmp(fsm.mTransitions.mFactory.ms1.mPrevious, "State2") == 0); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mEntered == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mExited == 1); + NL_TEST_ASSERT(inSuite, fsm.mTransitions.mFactory.ms2.mLogged == 1); + NL_TEST_ASSERT(inSuite, strcmp(fsm.mTransitions.mFactory.ms2.mPrevious, "State1") == 0); +} + +int Setup(void * inContext) +{ + return SUCCESS; +} + +int Teardown(void * inContext) +{ + return SUCCESS; +} + +} // namespace + +static const nlTest sTests[] = { + NL_TEST_DEF("TestInit", TestInit), + NL_TEST_DEF("TestIgnoredEvents", TestIgnoredEvents), + NL_TEST_DEF("TestTransitions", TestTransitions), + NL_TEST_DEF("TestTransitionsDispatch", TestTransitionsDispatch), + NL_TEST_DEF("TestMethodExec", TestMethodExec), + NL_TEST_SENTINEL(), +}; + +int StateMachineTestSuite() +{ + nlTestSuite suite = { "CHIP State Machine tests", &sTests[0], Setup, Teardown }; + nlTestRunner(&suite, nullptr); + return nlTestRunnerStats(&suite); +} + +CHIP_REGISTER_TEST_SUITE(StateMachineTestSuite);