diff --git a/internal/coord/behaviour_test.go b/internal/coord/behaviour_test.go index ed0761a..2624827 100644 --- a/internal/coord/behaviour_test.go +++ b/internal/coord/behaviour_test.go @@ -2,11 +2,12 @@ package coord import ( "context" + "testing" ) type RecordingSM[E any, S any] struct { State S - Received E + Received []E } func NewRecordingSM[E any, S any](response S) *RecordingSM[E, S] { @@ -16,6 +17,27 @@ func NewRecordingSM[E any, S any](response S) *RecordingSM[E, S] { } func (r *RecordingSM[E, S]) Advance(ctx context.Context, e E) S { - r.Received = e + r.Received = append(r.Received, e) return r.State } + +func (r *RecordingSM[E, S]) first() E { + if len(r.Received) == 0 { + var zero E + return zero + } + return r.Received[0] +} + +func DrainBehaviour[I BehaviourEvent, O BehaviourEvent](t *testing.T, ctx context.Context, b Behaviour[I, O]) { + for { + select { + case <-b.Ready(): + b.Perform(ctx) + case <-ctx.Done(): + t.Fatal("context cancelled while draining behaviour") + default: + return + } + } +} diff --git a/internal/coord/coordinator.go b/internal/coord/coordinator.go index 0f28a8e..f975928 100644 --- a/internal/coord/coordinator.go +++ b/internal/coord/coordinator.go @@ -96,7 +96,7 @@ type CoordinatorConfig struct { Routing RoutingConfig // Query is the configuration used for the [PooledQueryBehaviour] which manages the execution of user queries. - Query PooledQueryConfig + Query QueryConfig } // Validate checks the configuration options and returns an error if any have invalid values. @@ -141,7 +141,7 @@ func DefaultCoordinatorConfig() *CoordinatorConfig { TracerProvider: otel.GetTracerProvider(), } - cfg.Query = *DefaultPooledQueryConfig() + cfg.Query = *DefaultQueryConfig() cfg.Query.Clock = cfg.Clock cfg.Query.Logger = cfg.Logger.With("behaviour", "pooledquery") cfg.Query.Tracer = cfg.TracerProvider.Tracer(tele.TracerName) @@ -168,7 +168,7 @@ func NewCoordinator(self kadt.PeerID, rtr coordt.Router[kadt.Key, kadt.PeerID, * return nil, fmt.Errorf("init telemetry: %w", err) } - queryBehaviour, err := NewPooledQueryBehaviour(self, &cfg.Query) + queryBehaviour, err := NewQueryBehaviour(self, &cfg.Query) if err != nil { return nil, fmt.Errorf("query behaviour: %w", err) } diff --git a/internal/coord/query.go b/internal/coord/query.go index 8244c45..9fe117a 100644 --- a/internal/coord/query.go +++ b/internal/coord/query.go @@ -18,7 +18,7 @@ import ( "github.com/plprobelab/zikade/tele" ) -type PooledQueryConfig struct { +type QueryConfig struct { // Clock is a clock that may replaced by a mock when testing Clock clock.Clock @@ -42,7 +42,7 @@ type PooledQueryConfig struct { } // Validate checks the configuration options and returns an error if any have invalid values. -func (cfg *PooledQueryConfig) Validate() error { +func (cfg *QueryConfig) Validate() error { if cfg.Clock == nil { return &errs.ConfigurationError{ Component: "PooledQueryConfig", @@ -94,8 +94,8 @@ func (cfg *PooledQueryConfig) Validate() error { return nil } -func DefaultPooledQueryConfig() *PooledQueryConfig { - return &PooledQueryConfig{ +func DefaultQueryConfig() *QueryConfig { + return &QueryConfig{ Clock: clock.New(), Logger: tele.DefaultLogger("coord"), Tracer: tele.NoopTracer(), @@ -107,10 +107,10 @@ func DefaultPooledQueryConfig() *PooledQueryConfig { } } -// PooledQueryBehaviour holds the behaviour and state for managing a pool of queries. -type PooledQueryBehaviour struct { +// QueryBehaviour holds the behaviour and state for managing a pool of queries. +type QueryBehaviour struct { // cfg is a copy of the optional configuration supplied to the behaviour. - cfg PooledQueryConfig + cfg QueryConfig // performMu is held while Perform is executing to ensure sequential execution of work. performMu sync.Mutex @@ -137,11 +137,11 @@ type PooledQueryBehaviour struct { ready chan struct{} } -// NewPooledQueryBehaviour initialises a new PooledQueryBehaviour, setting up the query +// NewQueryBehaviour initialises a new [QueryBehaviour], setting up the query // pool and other internal state. -func NewPooledQueryBehaviour(self kadt.PeerID, cfg *PooledQueryConfig) (*PooledQueryBehaviour, error) { +func NewQueryBehaviour(self kadt.PeerID, cfg *QueryConfig) (*QueryBehaviour, error) { if cfg == nil { - cfg = DefaultPooledQueryConfig() + cfg = DefaultQueryConfig() } else if err := cfg.Validate(); err != nil { return nil, err } @@ -158,7 +158,7 @@ func NewPooledQueryBehaviour(self kadt.PeerID, cfg *PooledQueryConfig) (*PooledQ return nil, fmt.Errorf("query pool: %w", err) } - h := &PooledQueryBehaviour{ + h := &QueryBehaviour{ cfg: *cfg, pool: pool, notifiers: make(map[coordt.QueryID]*queryNotifier[*EventQueryFinished]), @@ -170,7 +170,7 @@ func NewPooledQueryBehaviour(self kadt.PeerID, cfg *PooledQueryConfig) (*PooledQ // Notify receives a behaviour event and takes appropriate actions such as starting, // stopping, or updating queries. It also queues events for later processing and // triggers the advancement of the query pool if applicable. -func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { +func (p *QueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { p.pendingInboundMu.Lock() defer p.pendingInboundMu.Unlock() @@ -187,14 +187,14 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { // Ready returns a channel that signals when the pooled query behaviour is ready to // perform work. -func (p *PooledQueryBehaviour) Ready() <-chan struct{} { +func (p *QueryBehaviour) Ready() <-chan struct{} { return p.ready } // Perform executes the next available task from the queue of pending events or advances // the query pool. Returns an event containing the result of the work performed and a // true value, or nil and a false value if no event was generated. -func (p *PooledQueryBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { +func (p *QueryBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { p.performMu.Lock() defer p.performMu.Unlock() @@ -230,7 +230,7 @@ func (p *PooledQueryBehaviour) Perform(ctx context.Context) (BehaviourEvent, boo return p.nextPendingOutbound() } -func (p *PooledQueryBehaviour) nextPendingOutbound() (BehaviourEvent, bool) { +func (p *QueryBehaviour) nextPendingOutbound() (BehaviourEvent, bool) { if len(p.pendingOutbound) == 0 { return nil, false } @@ -239,7 +239,7 @@ func (p *PooledQueryBehaviour) nextPendingOutbound() (BehaviourEvent, bool) { return ev, true } -func (p *PooledQueryBehaviour) nextPendingInbound() (CtxEvent[BehaviourEvent], bool) { +func (p *QueryBehaviour) nextPendingInbound() (CtxEvent[BehaviourEvent], bool) { p.pendingInboundMu.Lock() defer p.pendingInboundMu.Unlock() if len(p.pendingInbound) == 0 { @@ -250,7 +250,7 @@ func (p *PooledQueryBehaviour) nextPendingInbound() (CtxEvent[BehaviourEvent], b return pev, true } -func (p *PooledQueryBehaviour) perfomNextInbound(ctx context.Context) (BehaviourEvent, bool) { +func (p *QueryBehaviour) perfomNextInbound(ctx context.Context) (BehaviourEvent, bool) { ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.perfomNextInbound") defer span.End() pev, ok := p.nextPendingInbound() @@ -346,7 +346,7 @@ func (p *PooledQueryBehaviour) perfomNextInbound(ctx context.Context) (Behaviour return p.advancePool(pev.Ctx, cmd) } -func (p *PooledQueryBehaviour) updateReadyStatus() { +func (p *QueryBehaviour) updateReadyStatus() { if len(p.pendingOutbound) != 0 { select { case p.ready <- struct{}{}: @@ -371,7 +371,7 @@ func (p *PooledQueryBehaviour) updateReadyStatus() { // advancePool advances the query pool state machine and returns an outbound event if // there is work to be performed. Also notifies waiters of query completion or // progress. -func (p *PooledQueryBehaviour) advancePool(ctx context.Context, ev query.PoolEvent) (out BehaviourEvent, term bool) { +func (p *QueryBehaviour) advancePool(ctx context.Context, ev query.PoolEvent) (out BehaviourEvent, term bool) { ctx, span := p.cfg.Tracer.Start(ctx, "PooledQueryBehaviour.advancePool", trace.WithAttributes(tele.AttrInEvent(ev))) defer func() { span.SetAttributes(tele.AttrOutEvent(out)) @@ -420,7 +420,7 @@ func (p *PooledQueryBehaviour) advancePool(ctx context.Context, ev query.PoolEve return nil, false } -func (p *PooledQueryBehaviour) queueAddNodeEvents(nodes []kadt.PeerID) { +func (p *QueryBehaviour) queueAddNodeEvents(nodes []kadt.PeerID) { for _, info := range nodes { p.pendingOutbound = append(p.pendingOutbound, &EventAddNode{ NodeID: info, @@ -428,7 +428,7 @@ func (p *PooledQueryBehaviour) queueAddNodeEvents(nodes []kadt.PeerID) { } } -func (p *PooledQueryBehaviour) queueNonConnectivityEvent(nid kadt.PeerID) { +func (p *QueryBehaviour) queueNonConnectivityEvent(nid kadt.PeerID) { p.pendingOutbound = append(p.pendingOutbound, &EventNotifyNonConnectivity{ NodeID: nid, }) diff --git a/internal/coord/query_test.go b/internal/coord/query_test.go index 83afe34..2aeaabe 100644 --- a/internal/coord/query_test.go +++ b/internal/coord/query_test.go @@ -16,34 +16,34 @@ import ( "github.com/plprobelab/zikade/pb" ) -func TestPooledQueryConfigValidate(t *testing.T) { +func TestQueryConfigValidate(t *testing.T) { t.Run("default is valid", func(t *testing.T) { - cfg := DefaultPooledQueryConfig() + cfg := DefaultQueryConfig() require.NoError(t, cfg.Validate()) }) t.Run("clock is not nil", func(t *testing.T) { - cfg := DefaultPooledQueryConfig() + cfg := DefaultQueryConfig() cfg.Clock = nil require.Error(t, cfg.Validate()) }) t.Run("logger not nil", func(t *testing.T) { - cfg := DefaultPooledQueryConfig() + cfg := DefaultQueryConfig() cfg.Logger = nil require.Error(t, cfg.Validate()) }) t.Run("tracer not nil", func(t *testing.T) { - cfg := DefaultPooledQueryConfig() + cfg := DefaultQueryConfig() cfg.Tracer = nil require.Error(t, cfg.Validate()) }) t.Run("query concurrency positive", func(t *testing.T) { - cfg := DefaultPooledQueryConfig() + cfg := DefaultQueryConfig() cfg.Concurrency = 0 require.Error(t, cfg.Validate()) @@ -52,7 +52,7 @@ func TestPooledQueryConfigValidate(t *testing.T) { }) t.Run("query timeout positive", func(t *testing.T) { - cfg := DefaultPooledQueryConfig() + cfg := DefaultQueryConfig() cfg.Timeout = 0 require.Error(t, cfg.Validate()) @@ -61,7 +61,7 @@ func TestPooledQueryConfigValidate(t *testing.T) { }) t.Run("request concurrency positive", func(t *testing.T) { - cfg := DefaultPooledQueryConfig() + cfg := DefaultQueryConfig() cfg.RequestConcurrency = 0 require.Error(t, cfg.Validate()) @@ -70,7 +70,7 @@ func TestPooledQueryConfigValidate(t *testing.T) { }) t.Run("request timeout positive", func(t *testing.T) { - cfg := DefaultPooledQueryConfig() + cfg := DefaultQueryConfig() cfg.RequestTimeout = 0 require.Error(t, cfg.Validate()) @@ -86,7 +86,7 @@ func TestQueryBehaviourBase(t *testing.T) { type QueryBehaviourBaseTestSuite struct { suite.Suite - cfg *PooledQueryConfig + cfg *QueryConfig top *nettest.Topology nodes []*nettest.Peer } @@ -99,7 +99,7 @@ func (ts *QueryBehaviourBaseTestSuite) SetupTest() { ts.top = top ts.nodes = nodes - ts.cfg = DefaultPooledQueryConfig() + ts.cfg = DefaultQueryConfig() ts.cfg.Clock = clk } @@ -111,7 +111,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesNoProgress() { rt := ts.nodes[0].RoutingTable seeds := rt.NearestNodes(target, 5) - b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) + b, err := NewQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) ts.Require().NoError(err) waiter := NewQueryWaiter(5) @@ -158,7 +158,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryProgressed() { rt := ts.nodes[0].RoutingTable seeds := rt.NearestNodes(target, 5) - b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) + b, err := NewQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) ts.Require().NoError(err) waiter := NewQueryWaiter(5) @@ -206,7 +206,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryFinished() { rt := ts.nodes[0].RoutingTable seeds := rt.NearestNodes(target, 5) - b, err := NewPooledQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) + b, err := NewQueryBehaviour(ts.nodes[0].NodeID, ts.cfg) ts.Require().NoError(err) waiter := NewQueryWaiter(5) @@ -274,7 +274,7 @@ func (ts *QueryBehaviourBaseTestSuite) TestNotifiesQueryFinished() { kadtest.ReadItem[CtxEvent[*EventQueryFinished]](t, ctx, waiter.Finished()) } -func TestPooledQuery_deadlock_regression(t *testing.T) { +func TestQuery_deadlock_regression(t *testing.T) { t.Skip() ctx := kadtest.CtxShort(t) msg := &pb.Message{} diff --git a/internal/coord/routing.go b/internal/coord/routing.go index 8d212ef..a23489c 100644 --- a/internal/coord/routing.go +++ b/internal/coord/routing.go @@ -290,24 +290,40 @@ type RoutingBehaviour struct { // cfg is a copy of the optional configuration supplied to the behaviour cfg RoutingConfig + // performMu is held while Perform is executing to ensure sequential execution of work. + performMu sync.Mutex + // bootstrap is the bootstrap state machine, responsible for bootstrapping the routing table + // it must only be accessed while performMu is held bootstrap coordt.StateMachine[routing.BootstrapEvent, routing.BootstrapState] // include is the inclusion state machine, responsible for vetting nodes before including them in the routing table + // it must only be accessed while performMu is held include coordt.StateMachine[routing.IncludeEvent, routing.IncludeState] // probe is the node probing state machine, responsible for periodically checking connectivity of nodes in the routing table + // it must only be accessed while performMu is held probe coordt.StateMachine[routing.ProbeEvent, routing.ProbeState] // explore is the routing table explore state machine, responsible for increasing the occupant of the routing table + // it must only be accessed while performMu is held explore coordt.StateMachine[routing.ExploreEvent, routing.ExploreState] // crawl is the state machine that can crawl the network from a set of seed nodes + // it must only be accessed while performMu is held crawl coordt.StateMachine[routing.CrawlEvent, routing.CrawlState] - pendingMu sync.Mutex - pending []BehaviourEvent - ready chan struct{} + // pendingOutbound is a queue of outbound events. + // it must only be accessed while performMu is held + pendingOutbound []BehaviourEvent + + // pendingInboundMu guards access to pendingInbound + pendingInboundMu sync.Mutex + + // pendingInbound is a queue of inbound events that are awaiting processing + pendingInbound []CtxEvent[BehaviourEvent] + + ready chan struct{} } func NewRoutingBehaviour(self kadt.PeerID, rt routing.RoutingTableCpl[kadt.Key, kadt.PeerID], cfg *RoutingConfig) (*RoutingBehaviour, error) { @@ -351,7 +367,7 @@ func NewRoutingBehaviour(self kadt.PeerID, rt routing.RoutingTableCpl[kadt.Key, probeCfg.Concurrency = cfg.ProbeRequestConcurrency probeCfg.CheckInterval = cfg.ProbeCheckInterval - probe, err := routing.NewProbe[kadt.Key, kadt.PeerID](rt, probeCfg) + probe, err := routing.NewProbe[kadt.Key](rt, probeCfg) if err != nil { return nil, fmt.Errorf("probe: %w", err) } @@ -369,7 +385,7 @@ func NewRoutingBehaviour(self kadt.PeerID, rt routing.RoutingTableCpl[kadt.Key, return nil, fmt.Errorf("explore schedule: %w", err) } - explore, err := routing.NewExplore[kadt.Key, kadt.PeerID](self, rt, cplutil.GenRandPeerID, schedule, exploreCfg) + explore, err := routing.NewExplore[kadt.Key](self, rt, cplutil.GenRandPeerID, schedule, exploreCfg) if err != nil { return nil, fmt.Errorf("explore: %w", err) } @@ -415,29 +431,110 @@ func ComposeRoutingBehaviour( } func (r *RoutingBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { + r.pendingInboundMu.Lock() + defer r.pendingInboundMu.Unlock() + ctx, span := r.cfg.Tracer.Start(ctx, "RoutingBehaviour.Notify") defer span.End() - r.pendingMu.Lock() - defer r.pendingMu.Unlock() - r.notify(ctx, ev) + r.pendingInbound = append(r.pendingInbound, CtxEvent[BehaviourEvent]{Ctx: ctx, Event: ev}) + + select { + case r.ready <- struct{}{}: + default: + } +} + +func (r *RoutingBehaviour) Ready() <-chan struct{} { + return r.ready } -// notify must only be called while r.pendingMu is held -func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { - ctx, span := r.cfg.Tracer.Start(ctx, "RoutingBehaviour.notify", trace.WithAttributes(tele.AttrInEvent(ev))) +func (r *RoutingBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { + r.performMu.Lock() + defer r.performMu.Unlock() + + ctx, span := r.cfg.Tracer.Start(ctx, "RoutingBehaviour.Perform") defer span.End() - switch ev := ev.(type) { + defer r.updateReadyStatus() + + // drain queued events first. + // drain queued outbound events before starting new work. + ev, ok := r.nextPendingOutbound() + if ok { + return ev, true + } + + // perform one piece of pending inbound work. + ev, ok = r.perfomNextInbound() + if ok { + return ev, true + } + + // poll the child state machines in priority order to give each an opportunity to perform work + r.pollChildren(ctx) + + // finally check if any pending events were accumulated in the meantime + return r.nextPendingOutbound() +} + +func (r *RoutingBehaviour) nextPendingOutbound() (BehaviourEvent, bool) { + if len(r.pendingOutbound) == 0 { + return nil, false + } + var ev BehaviourEvent + ev, r.pendingOutbound = r.pendingOutbound[0], r.pendingOutbound[1:] + return ev, true +} + +func (r *RoutingBehaviour) updateReadyStatus() { + if len(r.pendingOutbound) != 0 { + select { + case r.ready <- struct{}{}: + default: + } + return + } + + r.pendingInboundMu.Lock() + hasPendingInbound := len(r.pendingInbound) != 0 + r.pendingInboundMu.Unlock() + + if hasPendingInbound { + select { + case r.ready <- struct{}{}: + default: + } + return + } +} + +func (r *RoutingBehaviour) nextPendingInbound() (CtxEvent[BehaviourEvent], bool) { + r.pendingInboundMu.Lock() + defer r.pendingInboundMu.Unlock() + if len(r.pendingInbound) == 0 { + return CtxEvent[BehaviourEvent]{}, false + } + var pev CtxEvent[BehaviourEvent] + pev, r.pendingInbound = r.pendingInbound[0], r.pendingInbound[1:] + return pev, true +} + +func (r *RoutingBehaviour) perfomNextInbound() (BehaviourEvent, bool) { + pev, ok := r.nextPendingInbound() + if !ok { + return nil, false + } + ctx, span := r.cfg.Tracer.Start(pev.Ctx, "PooledQueryBehaviour.perfomNextInbound", trace.WithAttributes(tele.AttrInEvent(pev))) + defer span.End() + + switch ev := pev.Event.(type) { case *EventStartBootstrap: cmd := &routing.EventBootstrapStart[kadt.Key, kadt.PeerID]{ KnownClosestNodes: ev.SeedNodes, } - // attempt to advance the bootstrap state machine - next, ok := r.advanceBootstrap(ctx, cmd) - if ok { - r.pending = append(r.pending, next) - } + // attempt to advance the bootstrap + return r.advanceBootstrap(ctx, cmd) case *EventStartCrawl: cmd := &routing.EventCrawlStart[kadt.Key, kadt.PeerID]{ @@ -446,7 +543,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the crawl state machine next, ok := r.advanceCrawl(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case *EventAddNode: @@ -470,7 +567,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the include state machine next, ok := r.advanceInclude(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case *EventRoutingUpdated: @@ -481,7 +578,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the probe state machine next, ok := r.advanceProbe(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case *EventGetCloserNodesSuccess: @@ -490,7 +587,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { case routing.BootstrapQueryID: for _, info := range ev.CloserNodes { // TODO: do this after advancing bootstrap - r.pending = append(r.pending, &EventAddNode{ + r.pendingOutbound = append(r.pendingOutbound, &EventAddNode{ NodeID: info, }) } @@ -501,7 +598,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the bootstrap next, ok := r.advanceBootstrap(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case IncludeQueryID: @@ -520,7 +617,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the include next, ok := r.advanceInclude(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case ProbeQueryID: @@ -539,12 +636,12 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the probe state machine next, ok := r.advanceProbe(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case routing.ExploreQueryID: for _, info := range ev.CloserNodes { - r.pending = append(r.pending, &EventAddNode{ + r.pendingOutbound = append(r.pendingOutbound, &EventAddNode{ NodeID: info, }) } @@ -554,11 +651,11 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { } next, ok := r.advanceExplore(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case routing.CrawlQueryID: - r.pending = append(r.pending, &EventAddNode{ + r.pendingOutbound = append(r.pendingOutbound, &EventAddNode{ NodeID: ev.To, Checked: true, }) @@ -572,7 +669,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the crawl next, ok := r.advanceCrawl(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } default: @@ -591,7 +688,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the bootstrap next, ok := r.advanceBootstrap(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case IncludeQueryID: cmd := &routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{ @@ -601,7 +698,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the include state machine next, ok := r.advanceInclude(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case ProbeQueryID: cmd := &routing.EventProbeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{ @@ -611,7 +708,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the probe state machine next, ok := r.advanceProbe(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case routing.ExploreQueryID: cmd := &routing.EventExploreFindCloserFailure[kadt.Key, kadt.PeerID]{ @@ -621,7 +718,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the explore next, ok := r.advanceExplore(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } case routing.CrawlQueryID: cmd := &routing.EventCrawlNodeFailure[kadt.Key, kadt.PeerID]{ @@ -632,7 +729,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { // attempt to advance the crawl next, ok := r.advanceCrawl(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } default: @@ -653,7 +750,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { } next, ok := r.advanceInclude(ctx, cmd) if ok { - r.pending = append(r.pending, next) + r.pendingOutbound = append(r.pendingOutbound, next) } // tell the probe state machine in case there are connectivity checks that could be satisfied @@ -662,7 +759,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { } nextProbe, ok := r.advanceProbe(ctx, cmdProbe) if ok { - r.pending = append(r.pending, nextProbe) + r.pendingOutbound = append(r.pendingOutbound, nextProbe) } case *EventNotifyNonConnectivity: @@ -674,7 +771,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { } nextProbe, ok := r.advanceProbe(ctx, cmdProbe) if ok { - r.pending = append(r.pending, nextProbe) + r.pendingOutbound = append(r.pendingOutbound, nextProbe) } case *EventRoutingPoll: @@ -684,76 +781,34 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) { panic(fmt.Sprintf("unexpected dht event: %T", ev)) } - if len(r.pending) > 0 { - select { - case r.ready <- struct{}{}: - default: - } - } -} - -func (r *RoutingBehaviour) Ready() <-chan struct{} { - return r.ready -} - -func (r *RoutingBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { - ctx, span := r.cfg.Tracer.Start(ctx, "RoutingBehaviour.Perform") - defer span.End() - - // No inbound work can be done until Perform is complete - r.pendingMu.Lock() - defer r.pendingMu.Unlock() - - for { - // drain queued events first. - if len(r.pending) > 0 { - var ev BehaviourEvent - ev, r.pending = r.pending[0], r.pending[1:] - - if len(r.pending) > 0 { - select { - case r.ready <- struct{}{}: - default: - } - } - return ev, true - } - - // poll the child state machines in priority order to give each an opportunity to perform work - r.pollChildren(ctx) - - // finally check if any pending events were accumulated in the meantime - if len(r.pending) == 0 { - return nil, false - } - } + return nil, false } // pollChildren must only be called while r.pendingMu is locked func (r *RoutingBehaviour) pollChildren(ctx context.Context) { ev, ok := r.advanceBootstrap(ctx, &routing.EventBootstrapPoll{}) if ok { - r.pending = append(r.pending, ev) + r.pendingOutbound = append(r.pendingOutbound, ev) } ev, ok = r.advanceInclude(ctx, &routing.EventIncludePoll{}) if ok { - r.pending = append(r.pending, ev) + r.pendingOutbound = append(r.pendingOutbound, ev) } ev, ok = r.advanceProbe(ctx, &routing.EventProbePoll{}) if ok { - r.pending = append(r.pending, ev) + r.pendingOutbound = append(r.pendingOutbound, ev) } ev, ok = r.advanceExplore(ctx, &routing.EventExplorePoll{}) if ok { - r.pending = append(r.pending, ev) + r.pendingOutbound = append(r.pendingOutbound, ev) } ev, ok = r.advanceCrawl(ctx, &routing.EventCrawlPoll{}) if ok { - r.pending = append(r.pending, ev) + r.pendingOutbound = append(r.pendingOutbound, ev) } } @@ -808,7 +863,7 @@ func (r *RoutingBehaviour) advanceInclude(ctx context.Context, ev routing.Includ // a node has been included in the routing table // notify other routing state machines that there is a new node in the routing table - r.notify(ctx, &EventRoutingUpdated{ + r.Notify(ctx, &EventRoutingUpdated{ NodeID: st.NodeID, }) @@ -852,14 +907,15 @@ func (r *RoutingBehaviour) advanceProbe(ctx context.Context, ev routing.ProbeEve // emit an EventRoutingRemoved event to notify clients that the node has been removed r.cfg.Logger.Debug("peer removed from routing table", tele.LogAttrPeerID(st.NodeID)) - r.pending = append(r.pending, &EventRoutingRemoved{ + r.pendingOutbound = append(r.pendingOutbound, &EventRoutingRemoved{ NodeID: st.NodeID, }) // add the node to the inclusion list for a second chance - r.notify(ctx, &EventAddNode{ + r.Notify(ctx, &EventAddNode{ NodeID: st.NodeID, }) + case *routing.StateProbeWaitingAtCapacity: // the probe state machine is waiting for responses for checks and the maximum number of concurrent checks has been reached. // nothing to do except wait for message response or timeout diff --git a/internal/coord/routing_test.go b/internal/coord/routing_test.go index 952b66b..89ea2ff 100644 --- a/internal/coord/routing_test.go +++ b/internal/coord/routing_test.go @@ -238,12 +238,13 @@ func TestRoutingStartBootstrapSendsEvent(t *testing.T) { } routingBehaviour.Notify(ctx, ev) + routingBehaviour.Perform(ctx) // the event that should be passed to the bootstrap state machine expected := &routing.EventBootstrapStart[kadt.Key, kadt.PeerID]{ KnownClosestNodes: ev.SeedNodes, } - require.Equal(t, expected, bootstrap.Received) + require.Equal(t, expected, bootstrap.first()) } func TestRoutingBootstrapGetClosestNodesSuccess(t *testing.T) { @@ -271,11 +272,12 @@ func TestRoutingBootstrapGetClosestNodesSuccess(t *testing.T) { } routingBehaviour.Notify(ctx, ev) + routingBehaviour.Perform(ctx) // bootstrap should receive message response event - require.IsType(t, &routing.EventBootstrapFindCloserResponse[kadt.Key, kadt.PeerID]{}, bootstrap.Received) + require.IsType(t, &routing.EventBootstrapFindCloserResponse[kadt.Key, kadt.PeerID]{}, bootstrap.first()) - rev := bootstrap.Received.(*routing.EventBootstrapFindCloserResponse[kadt.Key, kadt.PeerID]) + rev := bootstrap.first().(*routing.EventBootstrapFindCloserResponse[kadt.Key, kadt.PeerID]) require.True(t, nodes[1].NodeID.Equal(rev.NodeID)) require.Equal(t, ev.CloserNodes, rev.CloserNodes) } @@ -306,11 +308,12 @@ func TestRoutingBootstrapGetClosestNodesFailure(t *testing.T) { } routingBehaviour.Notify(ctx, ev) + routingBehaviour.Perform(ctx) // bootstrap should receive message response event - require.IsType(t, &routing.EventBootstrapFindCloserFailure[kadt.Key, kadt.PeerID]{}, bootstrap.Received) + require.IsType(t, &routing.EventBootstrapFindCloserFailure[kadt.Key, kadt.PeerID]{}, bootstrap.first()) - rev := bootstrap.Received.(*routing.EventBootstrapFindCloserFailure[kadt.Key, kadt.PeerID]) + rev := bootstrap.first().(*routing.EventBootstrapFindCloserFailure[kadt.Key, kadt.PeerID]) require.Equal(t, peer.ID(nodes[1].NodeID), peer.ID(rev.NodeID)) require.Equal(t, failure, rev.Error) } @@ -337,12 +340,13 @@ func TestRoutingAddNodeInfoSendsEvent(t *testing.T) { } routingBehaviour.Notify(ctx, ev) + routingBehaviour.Perform(ctx) // the event that should be passed to the include state machine expected := &routing.EventIncludeAddCandidate[kadt.Key, kadt.PeerID]{ NodeID: ev.NodeID, } - require.Equal(t, expected, include.Received) + require.Equal(t, expected, include.first()) } func TestRoutingIncludeGetClosestNodesSuccess(t *testing.T) { @@ -370,11 +374,12 @@ func TestRoutingIncludeGetClosestNodesSuccess(t *testing.T) { } routingBehaviour.Notify(ctx, ev) + routingBehaviour.Perform(ctx) // include should receive message response event - require.IsType(t, &routing.EventIncludeConnectivityCheckSuccess[kadt.Key, kadt.PeerID]{}, include.Received) + require.IsType(t, &routing.EventIncludeConnectivityCheckSuccess[kadt.Key, kadt.PeerID]{}, include.first()) - rev := include.Received.(*routing.EventIncludeConnectivityCheckSuccess[kadt.Key, kadt.PeerID]) + rev := include.first().(*routing.EventIncludeConnectivityCheckSuccess[kadt.Key, kadt.PeerID]) require.Equal(t, peer.ID(nodes[1].NodeID), peer.ID(rev.NodeID)) } @@ -404,11 +409,12 @@ func TestRoutingIncludeGetClosestNodesFailure(t *testing.T) { } routingBehaviour.Notify(ctx, ev) + routingBehaviour.Perform(ctx) // include should receive message response event - require.IsType(t, &routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{}, include.Received) + require.IsType(t, &routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{}, include.first()) - rev := include.Received.(*routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID]) + rev := include.first().(*routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID]) require.Equal(t, peer.ID(nodes[1].NodeID), peer.ID(rev.NodeID)) require.Equal(t, failure, rev.Error) } @@ -470,20 +476,23 @@ func TestRoutingIncludedNodeAddToProbeList(t *testing.T) { Target: oev.Target, CloserNodes: []kadt.PeerID{nodes[1].NodeID}, // must include one for include check to pass }) + dev, ok = routingBehaviour.Perform(ctx) // the routing table should now contain the node _, intable = rt.GetNode(candidate.Key()) require.True(t, intable) // routing update event should be emitted from the include state machine - dev, ok = routingBehaviour.Perform(ctx) require.True(t, ok) require.IsType(t, &EventRoutingUpdated{}, dev) + // drain any pending work + DrainBehaviour[BehaviourEvent, BehaviourEvent](t, ctx, routingBehaviour) + // advance time past the probe check interval clk.Add(probeCfg.CheckInterval) - // routing update event should be emitted from the include state machine + // probe should be sent for the node dev, ok = routingBehaviour.Perform(ctx) require.True(t, ok) require.IsType(t, &EventOutboundGetCloserNodes{}, dev) @@ -558,11 +567,12 @@ func TestRoutingExploreGetClosestNodesSuccess(t *testing.T) { CloserNodes: []kadt.PeerID{nodes[2].NodeID}, } routingBehaviour.Notify(ctx, ev) + routingBehaviour.Perform(ctx) // explore should receive message response event - require.IsType(t, &routing.EventExploreFindCloserResponse[kadt.Key, kadt.PeerID]{}, explore.Received) + require.IsType(t, &routing.EventExploreFindCloserResponse[kadt.Key, kadt.PeerID]{}, explore.first()) - rev := explore.Received.(*routing.EventExploreFindCloserResponse[kadt.Key, kadt.PeerID]) + rev := explore.first().(*routing.EventExploreFindCloserResponse[kadt.Key, kadt.PeerID]) require.True(t, nodes[1].NodeID.Equal(rev.NodeID)) require.Equal(t, ev.CloserNodes, rev.CloserNodes) } @@ -593,11 +603,12 @@ func TestRoutingExploreGetClosestNodesFailure(t *testing.T) { } routingBehaviour.Notify(ctx, ev) + routingBehaviour.Perform(ctx) // bootstrap should receive message response event - require.IsType(t, &routing.EventExploreFindCloserFailure[kadt.Key, kadt.PeerID]{}, explore.Received) + require.IsType(t, &routing.EventExploreFindCloserFailure[kadt.Key, kadt.PeerID]{}, explore.first()) - rev := explore.Received.(*routing.EventExploreFindCloserFailure[kadt.Key, kadt.PeerID]) + rev := explore.first().(*routing.EventExploreFindCloserFailure[kadt.Key, kadt.PeerID]) require.Equal(t, peer.ID(nodes[1].NodeID), peer.ID(rev.NodeID)) require.Equal(t, failure, rev.Error) }