vendredi 28 mai 2021

C++ base class for multi-threaded, event-driven state-pattern implementation

I'm starting a project with C++11.

In this project, lots of objects will interact with each other based on their states. The software also needs to handle many messages from different interfaces (UARTs, internet, Bluetooth, ...) asynchronously, so each object should own its own thread to do the message/event processing.

I'm trying to build a set of state-pattern base classes (State, StateMachine, Event) for my project, with the following constraints:

  1. Every StateMachine has its own thread to process events
  2. The derived classes can define their own Event and State

Following is my current implementation:


// statemachine.hpp

/// Root of every event type that can be queued on a state machine.
///
/// Carries only a human-readable name; concrete events derive from it and add
/// their own payload.
struct EventBase
{
    /// @param name Label stored in ev_name. Taken by const reference so the
    ///             string is copied once (into the member) rather than twice.
    EventBase(const std::string& name): ev_name(name) {}

    /// Virtual so derived events can be destroyed through an EventBase pointer.
    virtual ~EventBase() = default;

    /// Immutable label of the event, fixed at construction.
    const std::string ev_name;
};

/// Shared-ownership handle used to pass events around.
using EventBasePtr = std::shared_ptr<EventBase>;

/// Abstract base of every state handled by a state machine.
///
/// @tparam ContextType Type of the object returned by GetContext(), i.e. the
///                     object a state acts on when it handles an event.
template<class ContextType>
class StateBase
{
public:
    /// @param name Label of the state. Taken by const reference so the string
    ///             is copied once (into the member) rather than twice.
    StateBase(const std::string& name): st_name(name) {}

    /// Virtual so derived states can be destroyed through a StateBase pointer.
    virtual ~StateBase() = default;

    /// @return A copy of the label given at construction.
    std::string GetName() const { return st_name; }

    /// Handles one event while this state is the current one.
    /// @param ev Event to handle.
    virtual void OnEvent(EventBasePtr ev) = 0;

    /// Hook run when the state becomes the current state.
    virtual void ActionEntry() = 0;
    /// Hook run just before the state stops being the current state.
    virtual void ActionExit() = 0;

    /// @return The context object this state operates on.
    virtual ContextType& GetContext() = 0;

private:
    const std::string st_name;  // immutable label, fixed at construction
};

template<class ContextType>
class StateMachineBase {
public:
    StateMachineBase(std::shared_ptr<StateBase<ContextType> > st): st_(st) 
    {
        thread_run_ = true;
        ev_proc_thread_ = std::make_shared<std::thread>(&StateMachineBase::EventProcLoop, this);
    }
    virtual ~StateMachineBase() 
    {
        thread_run_ = false;
        if (ev_proc_thread_) {
            ev_proc_thread_->join();
            ev_proc_thread_.reset();
        }
    }

    // NOTE: 
    // DispatchEvent() should always be called in 
    // 1. callback functions
    // 2. msg handler functions
    void DispatchEvent(EventBasePtr ev) {
        ev_queue_.push(ev);
        cv_.notify_one();
    }

    // NOTE: TransitTo() should always be called in
    // 1. OnEvent()
    // 2. StateBase::ActionEntry() (if state base is a transition state)
    void TransitTo(std::shared_ptr<StateBase<ContextType> > st)
    {
        if (st->GetName() == st_->GetName())
        {
            LOG(WARN) << "ignore same-state transition";
            return;
        }
        st_->ActionExit();
        st_ = st;
        st_->ActionEntry();
    }
    std::shared_ptr<StateBase<ContextType> > st_;

private:
    void ProcessEvent(EventBasePtr ev)
    {
        st_->OnEvent(ev);
    }

    void EventProcLoop()
    {
        while (thread_run_)
        {
            std::unique_lock<std::mutex> lk(mtx_cv_);
            cv_.wait(lk, [this] {
                return !thread_run_ || ev_queue_.size() != 0;
            });
            lk.unlock();

            // thread_run_ might be toggled by another thread
            if (!thread_run_)
                break;

            EventBasePtr ev = nullptr;
            mtx_ev_queue_.lock();
            if (!ev_queue_.empty()) {
                ev = ev_queue_.front();
                ev_queue_.pop();
            }
            mtx_ev_queue_.unlock();

            if (ev) {
                ProcessEvent(ev);
            }
        }
    }

    bool thread_run_;
    std::condition_variable cv_;
    std::mutex mtx_cv_;
    std::shared_ptr<std::thread> ev_proc_thread_;

    std::mutex mtx_ev_queue_;
    std::queue<EventBasePtr> ev_queue_;
};

// wificontroller.hpp

/// Singleton state machine tracking a Wi-Fi connection through the states
/// DISCONNECTED, CONNECTING and CONNECTED.
class WifiController : public StateMachineBase<WifiController>
{
public:
    /// @return The single process-wide instance, created on first use.
    static WifiController& GetInstance()
    {
        static WifiController inst;
        return inst;
    }

private:
    // Defined out of line; it has to pass the initial state to
    // StateMachineBase, which has no default constructor.
    WifiController();
    ~WifiController() override = default;

    // A singleton must not be duplicated.
    WifiController(const WifiController&) = delete;
    WifiController& operator=(const WifiController&) = delete;

    // NOTE(review): ConnectingState::ActionEntry() calls ConnectWifi(), but the
    // original snippet never declared it. The signature below is an assumption
    // derived from that call; the definition belongs next to the constructor's.
    void ConnectWifi(const std::string& ssid, const std::string& password);

    // =======================================================================
    // state and event declarations
    // =======================================================================

    /// Base of all WifiController events; adds a type id to EventBase.
    struct Event : public EventBase {
        enum class Type
        {
            CONNECT_CMD,
            CONNECT_RESULT,
        };
        const Type ev_type;
        // Base listed first: that is the order in which initialization runs.
        Event(Type t, const std::string& n) : EventBase(n), ev_type(t) {}
        virtual ~Event() = default;
    };
    using EventPtr = std::shared_ptr<Event>;

    /// Request to start or stop a connection.
    struct ConnCmdEvent : public Event
    {
        int command; // start(1), stop(0)
        std::string ssid;
        std::string password;
        ConnCmdEvent(int cmd, const std::string& id, const std::string& pw)
            : Event(Type::CONNECT_CMD, "CONNECT_CMD"), command(cmd), ssid(id), password(pw) {}
    };

    /// Outcome of a connection attempt.
    struct ConnResEvent : public Event {
        bool success;
        explicit ConnResEvent(bool succ): Event(Type::CONNECT_RESULT, "CONNECT_RESULT"), success(succ) {}
    };

    /// Base of all WifiController states; adds a type id and binds the context
    /// to the singleton.
    class State : public StateBase<WifiController>
    {
    public:
        enum class Type
        {
            DISCONNECTED,
            CONNECTING,
            CONNECTED
        };
        State(Type id, const std::string& name): StateBase<WifiController>(name), st_type_(id) {}

        Type GetType() const { return st_type_; }
        WifiController& GetContext() override { return WifiController::GetInstance(); }
    private:
        const Type st_type_;
    };
    using StatePtr = std::shared_ptr<State>;

    /// Not connected; a CONNECT_CMD with command == 1 starts a connection.
    class DisconnectedState : public State
    {
    public:
        // State has no default constructor, so the type and name are passed
        // explicitly.
        DisconnectedState() : State(Type::DISCONNECTED, "DISCONNECTED") {}
        virtual ~DisconnectedState() = default;

        void OnEvent(EventBasePtr ev) override
        {
            // The cast yields null for anything that is not a WifiController
            // event (including a null ev); such events are ignored.
            const EventPtr event = std::dynamic_pointer_cast<Event>(ev);
            if (!event) {
                LOG(WARN) << "DisconnectedState ignores unsupported event";
                return;
            }
            switch (event->ev_type)
            {
            case Event::Type::CONNECT_CMD: {
                const std::shared_ptr<ConnCmdEvent> conn_cmd_ev = std::dynamic_pointer_cast<ConnCmdEvent>(event);
                if (conn_cmd_ev && conn_cmd_ev->command == 1) {
                    GetContext().TransitTo(std::make_shared<ConnectingState>(conn_cmd_ev->ssid, conn_cmd_ev->password));
                }
                break;
            }
            default:
                LOG(WARN) << "DisconnectedState ignores " << ev->ev_name;
            } // switch
        }
        void ActionEntry() override {}
        void ActionExit() override {}
    };

    /// Connection attempt in progress; entering the state starts the attempt
    /// and a CONNECT_RESULT decides the next state.
    class ConnectingState : public State
    {
    public:
        /// @param ssid Network to connect to.
        /// @param pw   Password for that network.
        ConnectingState(const std::string& ssid, const std::string& pw)
            : State(Type::CONNECTING, "CONNECTING"), ssid_(ssid), pw_(pw) {}

        void OnEvent(EventBasePtr ev) override
        {
            const EventPtr event = std::dynamic_pointer_cast<Event>(ev);
            if (!event) {
                LOG(WARN) << "ConnectingState ignores unsupported event";
                return;
            }
            switch (event->ev_type)
            {
            case Event::Type::CONNECT_RESULT: {
                const std::shared_ptr<ConnResEvent> conn_res_ev = std::dynamic_pointer_cast<ConnResEvent>(event);
                if (!conn_res_ev) {
                    // Type id says CONNECT_RESULT but the object is not a
                    // ConnResEvent: nothing to act on.
                    LOG(WARN) << "ConnectingState ignores " << ev->ev_name;
                    break;
                }
                if (conn_res_ev->success) {
                    GetContext().TransitTo(std::make_shared<ConnectedState>());
                } else {
                    GetContext().TransitTo(std::make_shared<DisconnectedState>());
                }
                break;
            }
            default:
                LOG(WARN) << "ConnectingState ignores " << ev->ev_name;
            } // switch
        }
        void ActionEntry() override { GetContext().ConnectWifi(ssid_, pw_); }
        void ActionExit() override {}

    private:
        const std::string ssid_;
        const std::string pw_;
    };

    // NOTE(review): ConnectingState transitions to ConnectedState, but the
    // original snippet never defined it. This is a minimal placeholder so the
    // transition compiles; it ignores every event.
    // TODO: handle a stop command / connection loss here.
    class ConnectedState : public State
    {
    public:
        ConnectedState() : State(Type::CONNECTED, "CONNECTED") {}

        void OnEvent(EventBasePtr ev) override
        {
            if (!ev) {
                LOG(WARN) << "ConnectedState ignores unsupported event";
                return;
            }
            LOG(WARN) << "ConnectedState ignores " << ev->ev_name;
        }
        void ActionEntry() override {}
        void ActionExit() override {}
    };
};

The code is a simplified example and is not complete. It has some code smells that I want to fix but don't know how:

  1. Since the states always call StateMachineBase::TransitTo() to trigger a state transition, is it possible to make TransitTo(shared_ptr<StateBase>) a member function of StateBase that always calls GetContext().TransitTo(st)?

  2. I want every class that inherits from EventBase or StateBase to have an enum class Type in which it defines the type IDs for its events or states, but the current implementation does not enforce this constraint. How should I modify the code to force all derived classes to follow this rule?

  3. Are there any other bad smells in my code? How can I improve them?

Thanks in advance.

Aucun commentaire:

Enregistrer un commentaire