diff --git a/src/lib/support/StateMachine.h b/src/lib/support/StateMachine.h index 4cbdf767ac772d..eb96d3d3b9dfe7 100644 --- a/src/lib/support/StateMachine.h +++ b/src/lib/support/StateMachine.h @@ -48,12 +48,12 @@ struct VariantState : Variant { private: - template - void Enter() + template + void Enter(Optional & event) { if (chip::Variant::template Is()) { - chip::Variant::template Get().Enter(); + event = chip::Variant::template Get().Enter(); } } @@ -93,9 +93,12 @@ struct VariantState : Variant return instance; } - void Enter() + template + Optional Enter() { - [](...) {}((this->template Enter(), 0)...); + Optional event; + [](...) {}((this->template Enter(event), 0)...); + return event; } void Exit() @@ -116,25 +119,6 @@ struct VariantState : Variant } }; -/** - * The interface for dispatching events into the State Machine. - * @tparam TEvent a variant holding the Events for the State Machine. - */ -template -class Context -{ -public: - virtual ~Context() = default; - - /** - * Dispatch an event to the current state. - * @note This call can result in the current state being deleted. Do not - * access current state memory after calling this method. - * @param evt a variant holding an Event for the State Machine. - */ - virtual void Dispatch(const TEvent & evt) = 0; -}; - /** * This is a functional approach to the State Machine design pattern. The design is * borrowed from http://www.vishalchovatiya.com/state-design-pattern-in-modern-cpp @@ -203,48 +187,27 @@ class Context * @tparam TTransitions an object that implements the () operator for transitions. */ template -class StateMachine : public Context +class StateMachine { public: StateMachine(TTransitions & tr) : mCurrentState(tr.GetInitState()), mTransitions(tr) {} - ~StateMachine() override = default; - void Dispatch(const TEvent & evt) override - { - auto inProcess = !events.empty(); - events.push_back(evt); - if (!inProcess) - { - HandleEvents(); - } - } - TState mCurrentState; - -private: - void HandleEvents() + void Dispatch(const TEvent & evt) { - while (!events.empty()) + Optional optEvent(evt); + Optional optState; + while (optEvent.HasValue() && (optState = mTransitions(mCurrentState, evt)).HasValue()) { - auto count = events.size(); - auto optState = mTransitions(mCurrentState, events.front()); - if (optState.HasValue()) - { - auto newState = optState.Value(); - newState.LogTransition(mCurrentState.GetName()); - mCurrentState.Exit(); - mCurrentState = newState; - while (events.size() > count) - { - events.pop_back(); // events discarded per design - } - mCurrentState.Enter(); - } - events.pop_front(); + auto newState = optState.Value(); + newState.LogTransition(mCurrentState.GetName()); + mCurrentState.Exit(); + mCurrentState = newState; + optEvent = mCurrentState.template Enter(); } } + TState mCurrentState; TTransitions & mTransitions; - std::deque events{}; }; } // namespace StateMachine diff --git a/src/lib/support/tests/TestStateMachine.cpp b/src/lib/support/tests/TestStateMachine.cpp index a9fbafd419a729..053e94f240988b 100644 --- a/src/lib/support/tests/TestStateMachine.cpp +++ b/src/lib/support/tests/TestStateMachine.cpp @@ -36,7 +36,6 @@ struct Event4 }; using Event = chip::Variant; -using Context = chip::StateMachine::Context; struct MockState { @@ -45,7 +44,7 @@ struct MockState unsigned mLogged; const char * mPrevious; - void Enter() { ++mEntered; } + chip::StateMachine::Optional Enter() { ++mEntered; return {}; } void Exit() { ++mExited; } void LogTransition(const char * previous) { @@ -56,46 +55,40 @@ struct MockState struct BaseState { - void Enter() { mMock.Enter(); } + chip::StateMachine::Optional Enter() { return mMock.Enter(); } void Exit() { mMock.Exit(); } void LogTransition(const char * previous) { mMock.LogTransition(previous); } const char * GetName() { return mName; } - chip::StateMachine::Context & mCtx; const char * mName; MockState & mMock; }; struct State1 : public BaseState { - State1(Context & ctx, const char * name, MockState & mock) : BaseState{ ctx, name, mock } {} + State1(MockState & mock) : BaseState{ "State2", mock } {} }; struct State2 : public BaseState { - State2(Context & ctx, const char * name, MockState & mock) : BaseState{ ctx, name, mock } {} + State2(MockState & mock) : BaseState{ "State2", mock } {} }; using State = chip::StateMachine::VariantState; struct StateFactory { - Context & mCtx; MockState ms1{ 0, 0, 0, nullptr }; MockState ms2{ 0, 0, 0, nullptr }; - StateFactory(Context & ctx) : mCtx(ctx) {} + auto CreateState1() { return State::Create(ms1); } - auto CreateState1() { return State::Create(mCtx, "State1", ms1); } - - auto CreateState2() { return State::Create(mCtx, "State2", ms2); } + auto CreateState2() { return State::Create(ms2); } }; struct Transitions { - Context & mCtx; StateFactory mFactory; - Transitions(Context & ctx) : mCtx(ctx), mFactory(ctx) {} using OptState = chip::StateMachine::Optional; State GetInitState() { return mFactory.CreateState1(); } @@ -111,15 +104,10 @@ struct Transitions } else if (state.Is() && event.Is()) { - // legal - Dispatches event without transition - mCtx.Dispatch(Event::Create()); - return {}; + return mFactory.CreateState2(); } else if (state.Is() && event.Is()) { - // illegal - Returned Transition will cause events - // dispatched from the transitions table to be ignored. - mCtx.Dispatch(Event::Create()); return mFactory.CreateState1(); } else @@ -135,7 +123,7 @@ class SimpleStateMachine Transitions mTransitions; chip::StateMachine::StateMachine mStateMachine; - SimpleStateMachine() : mTransitions(mStateMachine), mStateMachine(mTransitions) {} + SimpleStateMachine() : mStateMachine(mTransitions) {} ~SimpleStateMachine() {} };