From 025fb3314677d49c46b9e4c94b8172f9f93b68d9 Mon Sep 17 00:00:00 2001 From: sukun Date: Sat, 10 Aug 2024 20:56:50 +0530 Subject: [PATCH] basic_host: close swarm on Close Using the `BasicHost` constructor transfers the ownership of the swarm. This is similar to how using `libp2p.New` transfers the ownership of user provided config options like `ResourceManager`, all of which are closed on `host.Close` --- config/config.go | 20 +++++++++----------- p2p/host/basic/basic_host.go | 4 ++++ p2p/host/basic/basic_host_test.go | 9 ++++++++- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/config/config.go b/config/config.go index 4b08c076ff..d0a71664f7 100644 --- a/config/config.go +++ b/config/config.go @@ -368,7 +368,7 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { fxopts = append(fxopts, cfg.QUICReuse...) } else { fxopts = append(fxopts, - fx.Provide(func(key quic.StatelessResetKey, tokenGenerator quic.TokenGeneratorKey, _ *swarm.Swarm, lifecycle fx.Lifecycle) (*quicreuse.ConnManager, error) { + fx.Provide(func(key quic.StatelessResetKey, tokenGenerator quic.TokenGeneratorKey, lifecycle fx.Lifecycle) (*quicreuse.ConnManager, error) { var opts []quicreuse.Option if !cfg.DisableMetrics { opts = append(opts, quicreuse.EnableMetrics(cfg.PrometheusRegisterer)) @@ -469,18 +469,17 @@ func (cfg *Config) NewNode() (host.Host, error) { fx.Provide(func() event.Bus { return eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer(eventbus.WithRegisterer(cfg.PrometheusRegisterer)))) }), - fx.Provide(func(eventBus event.Bus, lifecycle fx.Lifecycle) (*swarm.Swarm, error) { - sw, err := cfg.makeSwarm(eventBus, !cfg.DisableMetrics) - if err != nil { - return nil, err - } - lifecycle.Append(fx.StopHook(sw.Close)) - return sw, nil + fx.Provide(func() crypto.PrivKey { + return cfg.PeerKey }), // Make sure the swarm constructor depends on the quicreuse.ConnManager. // That way, the ConnManager will be started before the swarm, and more importantly, // the swarm will be stopped before the ConnManager. - fx.Decorate(func(sw *swarm.Swarm, _ *quicreuse.ConnManager, lifecycle fx.Lifecycle) *swarm.Swarm { + fx.Provide(func(eventBus event.Bus, _ *quicreuse.ConnManager, lifecycle fx.Lifecycle) (*swarm.Swarm, error) { + sw, err := cfg.makeSwarm(eventBus, !cfg.DisableMetrics) + if err != nil { + return nil, err + } lifecycle.Append(fx.Hook{ OnStart: func(context.Context) error { // TODO: This method succeeds if listening on one address succeeds. We @@ -491,14 +490,13 @@ func (cfg *Config) NewNode() (host.Host, error) { return sw.Close() }, }) - return sw + return sw, nil }), fx.Provide(cfg.newBasicHost), fx.Provide(func(bh *bhost.BasicHost) host.Host { return bh }), fx.Provide(func(h *swarm.Swarm) peer.ID { return h.LocalPeer() }), - fx.Provide(func(h *swarm.Swarm) crypto.PrivKey { return h.Peerstore().PrivKey(h.LocalPeer()) }), } transportOpts, err := cfg.addTransports() if err != nil { diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 6d8ee06d5e..7b7f8855fb 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -1080,6 +1080,10 @@ func (h *BasicHost) Close() error { _ = h.emitters.evtLocalProtocolsUpdated.Close() _ = h.emitters.evtLocalAddrsUpdated.Close() + if err := h.network.Close(); err != nil { + log.Errorf("swarm close failed: %v", err) + } + h.psManager.Close() if h.Peerstore() != nil { h.Peerstore().Close() diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index c9cff26a9a..ba3eac0cd5 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -83,7 +83,14 @@ func TestMultipleClose(t *testing.T) { require.NoError(t, h.Close()) require.NoError(t, h.Close()) - require.NoError(t, h.Close()) + h2, err := NewHost(swarmt.GenSwarm(t), nil) + defer h2.Close() + require.Error(t, h.Connect(context.Background(), peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})) + h.Network().Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.PermanentAddrTTL) + _, err = h.NewStream(context.Background(), h2.ID()) + require.Error(t, err) + require.Empty(t, h.Addrs()) + require.Empty(t, h.AllAddrs()) } func TestSignedPeerRecordWithNoListenAddrs(t *testing.T) {