diff --git a/core/core.go b/core/core.go index 30106c79413..36b4d825e16 100644 --- a/core/core.go +++ b/core/core.go @@ -26,6 +26,7 @@ import ( routing "github.com/jbenet/go-ipfs/routing" dht "github.com/jbenet/go-ipfs/routing/dht" u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" ) var log = u.Logger("core") @@ -71,17 +72,16 @@ type IpfsNode struct { // the pinning manager Pinning pin.Pinner + + ctxc.ContextCloser } // NewIpfsNode constructs a new IpfsNode based on the given config. -func NewIpfsNode(cfg *config.Config, online bool) (*IpfsNode, error) { - // derive this from a higher context. - // cancel if we need to fail early. - ctx, cancel := context.WithCancel(context.TODO()) +func NewIpfsNode(cfg *config.Config, online bool) (n *IpfsNode, err error) { success := false // flip to true after all sub-system inits succeed defer func() { - if !success { - cancel() + if !success && n != nil { + n.Close() } }() @@ -89,101 +89,81 @@ func NewIpfsNode(cfg *config.Config, online bool) (*IpfsNode, error) { return nil, fmt.Errorf("configuration required") } - d, err := makeDatastore(cfg.Datastore) - if err != nil { + // derive this from a higher context. + ctx := context.TODO() + n = &IpfsNode{ + Config: cfg, + ContextCloser: ctxc.NewContextCloser(ctx, nil), + } + + // setup datastore. + if n.Datastore, err = makeDatastore(cfg.Datastore); err != nil { return nil, err } - peerstore := peer.NewPeerstore() - local, err := initIdentity(cfg, peerstore, online) + // setup peerstore + local peer identity + n.Peerstore = peer.NewPeerstore() + n.Identity, err = initIdentity(n.Config, n.Peerstore, online) if err != nil { return nil, err } - // FIXME(brian): This is a bit dangerous. If any of the vars declared in - // this block are assigned inside of the "if online" block using the ":=" - // declaration syntax, the compiler permits re-declaration. This is rather - // undesirable - var ( - net inet.Network - // TODO: refactor so we can use IpfsRouting interface instead of being DHT-specific - route *dht.IpfsDHT - exchangeSession exchange.Interface - diagnostics *diag.Diagnostics - network inet.Network - ) - + // setup online services if online { - dhtService := netservice.NewService(nil) // nil handler for now, need to patch it - exchangeService := netservice.NewService(nil) // nil handler for now, need to patch it - diagService := netservice.NewService(nil) - - if err := dhtService.Start(ctx); err != nil { - return nil, err - } - if err := exchangeService.Start(ctx); err != nil { - return nil, err - } - if err := diagService.Start(ctx); err != nil { - return nil, err - } + dhtService := netservice.NewService(ctx, nil) // nil handler for now, need to patch it + exchangeService := netservice.NewService(ctx, nil) // nil handler for now, need to patch it + diagService := netservice.NewService(ctx, nil) // nil handler for now, need to patch it - net, err = inet.NewIpfsNetwork(ctx, local, peerstore, &mux.ProtocolMap{ + muxMap := &mux.ProtocolMap{ mux.ProtocolID_Routing: dhtService, mux.ProtocolID_Exchange: exchangeService, mux.ProtocolID_Diagnostic: diagService, // add protocol services here. - }) + } + + // setup the network + n.Network, err = inet.NewIpfsNetwork(ctx, n.Identity, n.Peerstore, muxMap) if err != nil { return nil, err } - network = net + n.AddCloserChild(n.Network) - diagnostics = diag.NewDiagnostics(local, net, diagService) - diagService.SetHandler(diagnostics) + // setup diagnostics service + n.Diagnostics = diag.NewDiagnostics(n.Identity, n.Network, diagService) + diagService.SetHandler(n.Diagnostics) - route = dht.NewDHT(ctx, local, peerstore, net, dhtService, d) + // setup routing service + dhtRouting := dht.NewDHT(ctx, n.Identity, n.Peerstore, n.Network, dhtService, n.Datastore) // TODO(brian): perform this inside NewDHT factory method - dhtService.SetHandler(route) // wire the handler to the service. + dhtService.SetHandler(dhtRouting) // wire the handler to the service. + n.Routing = dhtRouting + n.AddCloserChild(dhtRouting) + // setup exchange service const alwaysSendToPeer = true // use YesManStrategy - exchangeSession = bitswap.NetMessageSession(ctx, local, net, exchangeService, route, d, alwaysSendToPeer) + n.Exchange = bitswap.NetMessageSession(ctx, n.Identity, n.Network, exchangeService, n.Routing, n.Datastore, alwaysSendToPeer) + // ok, this function call is ridiculous o/ consider making it simpler. - // TODO(brian): pass a context to initConnections - go initConnections(ctx, cfg, peerstore, route) + go initConnections(ctx, n.Config, n.Peerstore, dhtRouting) } // TODO(brian): when offline instantiate the BlockService with a bitswap // session that simply doesn't return blocks - bs, err := bserv.NewBlockService(d, exchangeSession) + n.Blocks, err = bserv.NewBlockService(n.Datastore, n.Exchange) if err != nil { return nil, err } - dag := merkledag.NewDAGService(bs) - ns := namesys.NewNameSystem(route) - p, err := pin.LoadPinner(d, dag) + n.DAG = merkledag.NewDAGService(n.Blocks) + n.Namesys = namesys.NewNameSystem(n.Routing) + n.Pinning, err = pin.LoadPinner(n.Datastore, n.DAG) if err != nil { - p = pin.NewPinner(d, dag) + n.Pinning = pin.NewPinner(n.Datastore, n.DAG) } success = true - return &IpfsNode{ - Config: cfg, - Peerstore: peerstore, - Datastore: d, - Blocks: bs, - DAG: dag, - Resolver: &path.Resolver{DAG: dag}, - Exchange: exchangeSession, - Identity: local, - Routing: route, - Namesys: ns, - Diagnostics: diagnostics, - Network: network, - Pinning: p, - }, nil + return n, nil } func initIdentity(cfg *config.Config, peers peer.Peerstore, online bool) (peer.Peer, error) { diff --git a/net/interface.go b/net/interface.go index bac53497552..11918231d9c 100644 --- a/net/interface.go +++ b/net/interface.go @@ -5,10 +5,12 @@ import ( mux "github.com/jbenet/go-ipfs/net/mux" srv "github.com/jbenet/go-ipfs/net/service" peer "github.com/jbenet/go-ipfs/peer" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" ) // Network is the interface IPFS uses for connecting to the world. type Network interface { + ctxc.ContextCloser // Listen handles incoming connections on given Multiaddr. // Listen(*ma.Muliaddr) error @@ -35,9 +37,6 @@ type Network interface { // SendMessage sends given Message out SendMessage(msg.NetMessage) error - - // Close terminates all network operation - Close() error } // Sender interface for network services. diff --git a/net/mux/mux.go b/net/mux/mux.go index e717e67fba8..a8865bb7393 100644 --- a/net/mux/mux.go +++ b/net/mux/mux.go @@ -7,6 +7,7 @@ import ( msg "github.com/jbenet/go-ipfs/net/message" pb "github.com/jbenet/go-ipfs/net/mux/internal/pb" u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" @@ -14,6 +15,8 @@ import ( var log = u.Logger("muxer") +// ProtocolIDs used to identify each protocol. +// These should probably be defined elsewhere. var ( ProtocolID_Routing = pb.ProtocolID_Routing ProtocolID_Exchange = pb.ProtocolID_Exchange @@ -38,11 +41,6 @@ type Muxer struct { // Protocols are the multiplexed services. Protocols ProtocolMap - // cancel is the function to stop the Muxer - cancel context.CancelFunc - ctx context.Context - wg sync.WaitGroup - bwiLock sync.Mutex bwIn uint64 @@ -50,45 +48,33 @@ type Muxer struct { bwOut uint64 *msg.Pipe + ctxc.ContextCloser } // NewMuxer constructs a muxer given a protocol map. -func NewMuxer(mp ProtocolMap) *Muxer { - return &Muxer{ - Protocols: mp, - Pipe: msg.NewPipe(10), +func NewMuxer(ctx context.Context, mp ProtocolMap) *Muxer { + m := &Muxer{ + Protocols: mp, + Pipe: msg.NewPipe(10), + ContextCloser: ctxc.NewContextCloser(ctx, nil), } -} - -// GetPipe implements the Protocol interface -func (m *Muxer) GetPipe() *msg.Pipe { - return m.Pipe -} -// Start kicks off the Muxer goroutines. -func (m *Muxer) Start(ctx context.Context) error { - if m == nil { - panic("nix muxer") - } - - if m.cancel != nil { - return errors.New("Muxer already started.") - } - - // make a cancellable context. - m.ctx, m.cancel = context.WithCancel(ctx) - m.wg = sync.WaitGroup{} - - m.wg.Add(1) + m.Children().Add(1) go m.handleIncomingMessages() for pid, proto := range m.Protocols { - m.wg.Add(1) + m.Children().Add(1) go m.handleOutgoingMessages(pid, proto) } - return nil + return m } +// GetPipe implements the Protocol interface +func (m *Muxer) GetPipe() *msg.Pipe { + return m.Pipe +} + +// GetBandwidthTotals return the in/out bandwidth measured over this muxer. func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) { m.bwiLock.Lock() in = m.bwIn @@ -100,19 +86,6 @@ func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) { return } -// Stop stops muxer activity. -func (m *Muxer) Stop() { - if m.cancel == nil { - panic("muxer stopped twice.") - } - // issue cancel, and wipe func. - m.cancel() - m.cancel = context.CancelFunc(nil) - - // wait for everything to wind down. - m.wg.Wait() -} - // AddProtocol adds a Protocol with given ProtocolID to the Muxer. func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error { if _, found := m.Protocols[pid]; found { @@ -126,28 +99,26 @@ func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error { // handleIncoming consumes the messages on the m.Incoming channel and // routes them appropriately (to the protocols). func (m *Muxer) handleIncomingMessages() { - defer m.wg.Done() + defer m.Children().Done() for { - if m == nil { - panic("nil muxer") - } - select { + case <-m.Closing(): + return + case msg, more := <-m.Incoming: if !more { return } + m.Children().Add(1) go m.handleIncomingMessage(msg) - - case <-m.ctx.Done(): - return } } } // handleIncomingMessage routes message to the appropriate protocol. func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { + defer m.Children().Done() m.bwiLock.Lock() // TODO: compensate for overhead @@ -169,8 +140,7 @@ func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { select { case proto.GetPipe().Incoming <- m2: - case <-m.ctx.Done(): - log.Error(m.ctx.Err()) + case <-m.Closing(): return } } @@ -178,7 +148,7 @@ func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { // handleOutgoingMessages consumes the messages on the proto.Outgoing channel, // wraps them and sends them out. func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) { - defer m.wg.Done() + defer m.Children().Done() for { select { @@ -186,9 +156,10 @@ func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) { if !more { return } + m.Children().Add(1) go m.handleOutgoingMessage(pid, msg) - case <-m.ctx.Done(): + case <-m.Closing(): return } } @@ -196,6 +167,8 @@ func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) { // handleOutgoingMessage wraps out a message and sends it out the func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) { + defer m.Children().Done() + data, err := wrapData(m1.Data(), pid) if err != nil { log.Errorf("muxer serializing error: %v", err) @@ -204,13 +177,14 @@ func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) { m.bwoLock.Lock() // TODO: compensate for overhead + // TODO(jbenet): switch this to a goroutine to prevent sync waiting. m.bwOut += uint64(len(data)) m.bwoLock.Unlock() m2 := msg.New(m1.Peer(), data) select { case m.GetPipe().Outgoing <- m2: - case <-m.ctx.Done(): + case <-m.Closing(): return } } diff --git a/net/mux/mux_test.go b/net/mux/mux_test.go index 72187893b7e..3b0235820e6 100644 --- a/net/mux/mux_test.go +++ b/net/mux/mux_test.go @@ -54,23 +54,20 @@ func testWrappedMsg(t *testing.T, m msg.NetMessage, pid pb.ProtocolID, data []by } func TestSimpleMuxer(t *testing.T) { + ctx := context.Background() // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Routing - mux1 := NewMuxer(ProtocolMap{ + mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") - // run muxer - ctx := context.Background() - mux1.Start(ctx) - // test outgoing p1 for _, s := range []string{"foo", "bar", "baz"} { p1.Outgoing <- msg.New(peer1, []byte(s)) @@ -105,23 +102,21 @@ func TestSimpleMuxer(t *testing.T) { } func TestSimultMuxer(t *testing.T) { + // run muxer + ctx, cancel := context.WithCancel(context.Background()) // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Identify - mux1 := NewMuxer(ProtocolMap{ + mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") - // run muxer - ctx, cancel := context.WithCancel(context.Background()) - mux1.Start(ctx) - // counts total := 10000 speed := time.Microsecond * 1 @@ -214,22 +209,20 @@ func TestSimultMuxer(t *testing.T) { } func TestStopping(t *testing.T) { + ctx := context.Background() // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Identify - mux1 := NewMuxer(ProtocolMap{ + mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") - // run muxer - mux1.Start(context.Background()) - // test outgoing p1 for _, s := range []string{"foo1", "bar1", "baz1"} { p1.Outgoing <- msg.New(peer1, []byte(s)) @@ -246,10 +239,7 @@ func TestStopping(t *testing.T) { testMsg(t, <-p1.Incoming, []byte(s)) } - mux1.Stop() - if mux1.cancel != nil { - t.Error("mux.cancel should be nil") - } + mux1.Close() // waits // test outgoing p1 for _, s := range []string{"foo3", "bar3", "baz3"} { @@ -274,5 +264,4 @@ func TestStopping(t *testing.T) { case <-time.After(time.Millisecond): } } - } diff --git a/net/net.go b/net/net.go index de433546a8b..a6154a7802a 100644 --- a/net/net.go +++ b/net/net.go @@ -1,12 +1,11 @@ package net import ( - "errors" - msg "github.com/jbenet/go-ipfs/net/message" mux "github.com/jbenet/go-ipfs/net/mux" swarm "github.com/jbenet/go-ipfs/net/swarm" peer "github.com/jbenet/go-ipfs/peer" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" ) @@ -23,36 +22,30 @@ type IpfsNetwork struct { // peer connection multiplexing swarm *swarm.Swarm - // network context - ctx context.Context - cancel context.CancelFunc + // network context closer + ctxc.ContextCloser } // NewIpfsNetwork is the structure that implements the network interface func NewIpfsNetwork(ctx context.Context, local peer.Peer, peers peer.Peerstore, pmap *mux.ProtocolMap) (*IpfsNetwork, error) { - ctx, cancel := context.WithCancel(ctx) - in := &IpfsNetwork{ - local: local, - muxer: mux.NewMuxer(*pmap), - ctx: ctx, - cancel: cancel, - } - - err := in.muxer.Start(ctx) - if err != nil { - cancel() - return nil, err + local: local, + muxer: mux.NewMuxer(ctx, *pmap), + ContextCloser: ctxc.NewContextCloser(ctx, nil), } + var err error in.swarm, err = swarm.NewSwarm(ctx, local, peers) if err != nil { - cancel() + in.Close() return nil, err } + in.AddCloserChild(in.swarm) + in.AddCloserChild(in.muxer) + // remember to wire components together. in.muxer.Pipe.ConnectTo(in.swarm.Pipe) @@ -94,20 +87,6 @@ func (n *IpfsNetwork) SendMessage(m msg.NetMessage) error { return nil } -// Close terminates all network operation -func (n *IpfsNetwork) Close() error { - if n.cancel == nil { - return errors.New("Network already closed.") - } - - n.swarm.Close() - n.muxer.Stop() - - n.cancel() - n.cancel = nil - return nil -} - // GetPeerList returns the networks list of connected peers func (n *IpfsNetwork) GetPeerList() []peer.Peer { return n.swarm.GetPeerList() diff --git a/net/service/service.go b/net/service/service.go index 18ddc00b683..deca061fb14 100644 --- a/net/service/service.go +++ b/net/service/service.go @@ -2,10 +2,12 @@ package service import ( "errors" + "fmt" "sync" msg "github.com/jbenet/go-ipfs/net/message" u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" ) @@ -39,10 +41,7 @@ type Sender interface { // incomig (SetHandler) requests. type Service interface { Sender - - // Start + Stop Service - Start(ctx context.Context) error - Stop() + ctxc.ContextCloser // GetPipe GetPipe() *msg.Pipe @@ -56,45 +55,30 @@ type Service interface { // messages over the same channel, and to issue + handle requests. type service struct { // Handler is the object registered to handle incoming requests. - Handler Handler + Handler Handler + HandlerLock sync.RWMutex // Requests are all the pending requests on this service. Requests RequestMap RequestsLock sync.RWMutex - // cancel is the function to stop the Service - cancel context.CancelFunc - // Message Pipe (connected to the outside world) *msg.Pipe + ctxc.ContextCloser } // NewService creates a service object with given type ID and Handler -func NewService(h Handler) Service { - return &service{ - Handler: h, - Requests: RequestMap{}, - Pipe: msg.NewPipe(10), - } -} - -// Start kicks off the Service goroutines. -func (s *service) Start(ctx context.Context) error { - if s.cancel != nil { - return errors.New("Service already started.") +func NewService(ctx context.Context, h Handler) Service { + s := &service{ + Handler: h, + Requests: RequestMap{}, + Pipe: msg.NewPipe(10), + ContextCloser: ctxc.NewContextCloser(ctx, nil), } - // make a cancellable context. - ctx, s.cancel = context.WithCancel(ctx) - - go s.handleIncomingMessages(ctx) - return nil -} - -// Stop stops Service activity. -func (s *service) Stop() { - s.cancel() - s.cancel = context.CancelFunc(nil) + s.Children().Add(1) + go s.handleIncomingMessages() + return s } // GetPipe implements the mux.Protocol interface @@ -132,6 +116,15 @@ func (s *service) SendMessage(ctx context.Context, m msg.NetMessage) error { // SendRequest sends a request message out and awaits a response. func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMessage, error) { + // check if we should bail given our contexts + select { + default: + case <-s.Closing(): + return nil, fmt.Errorf("service closed: %s", s.Context().Err()) + case <-ctx.Done(): + return nil, ctx.Err() + } + // create a request r, err := NewRequest(m.Peer().ID()) if err != nil { @@ -153,6 +146,8 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes // check if we should bail after waiting for mutex select { default: + case <-s.Closing(): + return nil, fmt.Errorf("service closed: %s", s.Context().Err()) case <-ctx.Done(): return nil, ctx.Err() } @@ -165,6 +160,8 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes err = nil select { case m = <-r.Response: + case <-s.Closed(): + err = fmt.Errorf("service closed: %s", s.Context().Err()) case <-ctx.Done(): err = ctx.Err() } @@ -178,43 +175,50 @@ func (s *service) SendRequest(ctx context.Context, m msg.NetMessage) (msg.NetMes // handleIncoming consumes the messages on the s.Incoming channel and // routes them appropriately (to requests, or handler). -func (s *service) handleIncomingMessages(ctx context.Context) { +func (s *service) handleIncomingMessages() { + defer s.Children().Done() + for { select { case m, more := <-s.Incoming: if !more { return } - go s.handleIncomingMessage(ctx, m) + s.Children().Add(1) + go s.handleIncomingMessage(m) - case <-ctx.Done(): + case <-s.Closing(): return } } } -func (s *service) handleIncomingMessage(ctx context.Context, m msg.NetMessage) { +func (s *service) handleIncomingMessage(m msg.NetMessage) { + defer s.Children().Done() // unwrap the incoming message data, rid, err := unwrapData(m.Data()) if err != nil { - log.Errorf("de-serializing error: %v", err) + log.Errorf("service de-serializing error: %v", err) + return } + m2 := msg.New(m.Peer(), data) // if it's a request (or has no RequestID), handle it if rid == nil || rid.IsRequest() { - if s.Handler == nil { + handler := s.GetHandler() + if handler == nil { log.Errorf("service dropped msg: %v", m) return // no handler, drop it. } // should this be "go HandleMessage ... ?" - r1 := s.Handler.HandleMessage(ctx, m2) + r1 := handler.HandleMessage(s.Context(), m2) // if handler gave us a response, send it back out! if r1 != nil { - err := s.sendMessage(ctx, r1, rid.Response()) + err := s.sendMessage(s.Context(), r1, rid.Response()) if err != nil { log.Errorf("error sending response message: %v", err) } @@ -239,16 +243,20 @@ func (s *service) handleIncomingMessage(ctx context.Context, m msg.NetMessage) { select { case r.Response <- m2: - case <-ctx.Done(): + case <-s.Closing(): } } // SetHandler assigns the request Handler for this service. func (s *service) SetHandler(h Handler) { + s.HandlerLock.Lock() + defer s.HandlerLock.Unlock() s.Handler = h } // GetHandler returns the request Handler for this service. func (s *service) GetHandler() Handler { + s.HandlerLock.RLock() + defer s.HandlerLock.RUnlock() return s.Handler } diff --git a/net/service/service_test.go b/net/service/service_test.go index ddcd93b8929..7152289aafc 100644 --- a/net/service/service_test.go +++ b/net/service/service_test.go @@ -38,13 +38,9 @@ func newPeer(t *testing.T, id string) peer.Peer { func TestServiceHandler(t *testing.T) { ctx := context.Background() h := &ReverseHandler{} - s := NewService(h) + s := NewService(ctx, h) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") - if err := s.Start(ctx); err != nil { - t.Error(err) - } - d, err := wrapData([]byte("beep"), nil) if err != nil { t.Error(err) @@ -70,16 +66,8 @@ func TestServiceHandler(t *testing.T) { func TestServiceRequest(t *testing.T) { ctx := context.Background() - s1 := NewService(&ReverseHandler{}) - s2 := NewService(&ReverseHandler{}) - - if err := s1.Start(ctx); err != nil { - t.Error(err) - } - - if err := s2.Start(ctx); err != nil { - t.Error(err) - } + s1 := NewService(ctx, &ReverseHandler{}) + s2 := NewService(ctx, &ReverseHandler{}) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") @@ -110,18 +98,10 @@ func TestServiceRequest(t *testing.T) { func TestServiceRequestTimeout(t *testing.T) { ctx, _ := context.WithTimeout(context.Background(), time.Millisecond) - s1 := NewService(&ReverseHandler{}) - s2 := NewService(&ReverseHandler{}) + s1 := NewService(ctx, &ReverseHandler{}) + s2 := NewService(ctx, &ReverseHandler{}) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") - if err := s1.Start(ctx); err != nil { - t.Error(err) - } - - if err := s2.Start(ctx); err != nil { - t.Error(err) - } - // patch services together go func() { for { @@ -143,3 +123,41 @@ func TestServiceRequestTimeout(t *testing.T) { t.Error("should've timed out") } } + +func TestServiceClose(t *testing.T) { + ctx := context.Background() + s1 := NewService(ctx, &ReverseHandler{}) + s2 := NewService(ctx, &ReverseHandler{}) + + peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") + + // patch services together + go func() { + for { + select { + case m := <-s1.GetPipe().Outgoing: + s2.GetPipe().Incoming <- m + case m := <-s2.GetPipe().Outgoing: + s1.GetPipe().Incoming <- m + case <-ctx.Done(): + return + } + } + }() + + m1 := msg.New(peer1, []byte("beep")) + m2, err := s1.SendRequest(ctx, m1) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(m2.Data(), []byte("peeb")) { + t.Errorf("service handler data incorrect: %v != %v", m2.Data(), "oof") + } + + s1.Close() + s2.Close() + + <-s1.Closed() + <-s2.Closed() +} diff --git a/routing/dht/dht.go b/routing/dht/dht.go index fdb9f96f229..76cde7fb54e 100644 --- a/routing/dht/dht.go +++ b/routing/dht/dht.go @@ -14,6 +14,7 @@ import ( pb "github.com/jbenet/go-ipfs/routing/dht/pb" kb "github.com/jbenet/go-ipfs/routing/kbucket" u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" ds "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-datastore" @@ -56,7 +57,7 @@ type IpfsDHT struct { //lock to make diagnostics work better diaglock sync.Mutex - ctx context.Context + ctxc.ContextCloser } // NewDHT creates a new DHT object with the given peer as the 'local' host @@ -67,9 +68,10 @@ func NewDHT(ctx context.Context, p peer.Peer, ps peer.Peerstore, dialer inet.Dia dht.datastore = dstore dht.self = p dht.peerstore = ps - dht.ctx = ctx + dht.ContextCloser = ctxc.NewContextCloser(ctx, nil) - dht.providers = NewProviderManager(p.ID()) + dht.providers = NewProviderManager(dht.Context(), p.ID()) + dht.AddCloserChild(dht.providers) dht.routingTables = make([]*kb.RoutingTable, 3) dht.routingTables[0] = kb.NewRoutingTable(20, kb.ConvertPeerID(p.ID()), time.Millisecond*1000) @@ -78,6 +80,7 @@ func NewDHT(ctx context.Context, p peer.Peer, ps peer.Peerstore, dialer inet.Dia dht.birth = time.Now() if doPinging { + dht.Children().Add(1) go dht.PingRoutine(time.Second * 10) } return dht @@ -516,6 +519,8 @@ func (dht *IpfsDHT) loadProvidableKeys() error { // PingRoutine periodically pings nearest neighbors. func (dht *IpfsDHT) PingRoutine(t time.Duration) { + defer dht.Children().Done() + tick := time.Tick(t) for { select { @@ -524,13 +529,13 @@ func (dht *IpfsDHT) PingRoutine(t time.Duration) { rand.Read(id) peers := dht.routingTables[0].NearestPeers(kb.ConvertKey(u.Key(id)), 5) for _, p := range peers { - ctx, _ := context.WithTimeout(dht.ctx, time.Second*5) + ctx, _ := context.WithTimeout(dht.Context(), time.Second*5) err := dht.Ping(ctx, p) if err != nil { log.Errorf("Ping error: %s", err) } } - case <-dht.ctx.Done(): + case <-dht.Closing(): return } } diff --git a/routing/dht/dht_test.go b/routing/dht/dht_test.go index 507db4eec58..2b873233813 100644 --- a/routing/dht/dht_test.go +++ b/routing/dht/dht_test.go @@ -23,11 +23,7 @@ import ( func setupDHT(ctx context.Context, t *testing.T, p peer.Peer) *IpfsDHT { peerstore := peer.NewPeerstore() - dhts := netservice.NewService(nil) // nil handler for now, need to patch it - if err := dhts.Start(ctx); err != nil { - t.Fatal(err) - } - + dhts := netservice.NewService(ctx, nil) // nil handler for now, need to patch it net, err := inet.NewIpfsNetwork(ctx, p, peerstore, &mux.ProtocolMap{ mux.ProtocolID_Routing: dhts, }) @@ -96,8 +92,8 @@ func TestPing(t *testing.T) { dhtA := setupDHT(ctx, t, peerA) dhtB := setupDHT(ctx, t, peerB) - defer dhtA.Halt() - defer dhtB.Halt() + defer dhtA.Close() + defer dhtB.Close() defer dhtA.dialer.(inet.Network).Close() defer dhtB.dialer.(inet.Network).Close() @@ -140,8 +136,8 @@ func TestValueGetSet(t *testing.T) { dhtA := setupDHT(ctx, t, peerA) dhtB := setupDHT(ctx, t, peerB) - defer dhtA.Halt() - defer dhtB.Halt() + defer dhtA.Close() + defer dhtB.Close() defer dhtA.dialer.(inet.Network).Close() defer dhtB.dialer.(inet.Network).Close() @@ -183,7 +179,7 @@ func TestProvides(t *testing.T) { _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { - dhts[i].Halt() + dhts[i].Close() defer dhts[i].dialer.(inet.Network).Close() } }() @@ -243,7 +239,7 @@ func TestProvidesAsync(t *testing.T) { _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { - dhts[i].Halt() + dhts[i].Close() defer dhts[i].dialer.(inet.Network).Close() } }() @@ -306,7 +302,7 @@ func TestLayeredGet(t *testing.T) { _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { - dhts[i].Halt() + dhts[i].Close() defer dhts[i].dialer.(inet.Network).Close() } }() @@ -359,7 +355,7 @@ func TestFindPeer(t *testing.T) { _, peers, dhts := setupDHTS(ctx, 4, t) defer func() { for i := 0; i < 4; i++ { - dhts[i].Halt() + dhts[i].Close() dhts[i].dialer.(inet.Network).Close() } }() @@ -447,8 +443,8 @@ func TestConnectCollision(t *testing.T) { t.Fatal("Timeout received!") } - dhtA.Halt() - dhtB.Halt() + dhtA.Close() + dhtB.Close() dhtA.dialer.(inet.Network).Close() dhtB.dialer.(inet.Network).Close() diff --git a/routing/dht/handlers.go b/routing/dht/handlers.go index d5db8d1da9f..fe628eeef57 100644 --- a/routing/dht/handlers.go +++ b/routing/dht/handlers.go @@ -205,9 +205,3 @@ func (dht *IpfsDHT) handleAddProvider(p peer.Peer, pmes *pb.Message) (*pb.Messag return pmes, nil // send back same msg as confirmation. } - -// Halt stops all communications from this peer and shut down -// TODO -- remove this in favor of context -func (dht *IpfsDHT) Halt() { - dht.providers.Halt() -} diff --git a/routing/dht/providers.go b/routing/dht/providers.go index 204fdf7d5da..f7d491d6a71 100644 --- a/routing/dht/providers.go +++ b/routing/dht/providers.go @@ -5,6 +5,9 @@ import ( peer "github.com/jbenet/go-ipfs/peer" u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" ) type ProviderManager struct { @@ -14,8 +17,8 @@ type ProviderManager struct { getlocal chan chan []u.Key newprovs chan *addProv getprovs chan *getProv - halt chan struct{} period time.Duration + ctxc.ContextCloser } type addProv struct { @@ -28,19 +31,24 @@ type getProv struct { resp chan []peer.Peer } -func NewProviderManager(local peer.ID) *ProviderManager { +func NewProviderManager(ctx context.Context, local peer.ID) *ProviderManager { pm := new(ProviderManager) pm.getprovs = make(chan *getProv) pm.newprovs = make(chan *addProv) pm.providers = make(map[u.Key][]*providerInfo) pm.getlocal = make(chan chan []u.Key) pm.local = make(map[u.Key]struct{}) - pm.halt = make(chan struct{}) + pm.ContextCloser = ctxc.NewContextCloser(ctx, nil) + + pm.Children().Add(1) go pm.run() + return pm } func (pm *ProviderManager) run() { + defer pm.Children().Done() + tick := time.NewTicker(time.Hour) for { select { @@ -53,6 +61,7 @@ func (pm *ProviderManager) run() { pi.Value = np.val arr := pm.providers[np.k] pm.providers[np.k] = append(arr, pi) + case gp := <-pm.getprovs: var parr []peer.Peer provs := pm.providers[gp.k] @@ -60,12 +69,14 @@ func (pm *ProviderManager) run() { parr = append(parr, p.Value) } gp.resp <- parr + case lc := <-pm.getlocal: var keys []u.Key for k, _ := range pm.local { keys = append(keys, k) } lc <- keys + case <-tick.C: for k, provs := range pm.providers { var filtered []*providerInfo @@ -76,7 +87,8 @@ func (pm *ProviderManager) run() { } pm.providers[k] = filtered } - case <-pm.halt: + + case <-pm.Closing(): return } } @@ -102,7 +114,3 @@ func (pm *ProviderManager) GetLocal() []u.Key { pm.getlocal <- resp return <-resp } - -func (pm *ProviderManager) Halt() { - pm.halt <- struct{}{} -} diff --git a/routing/dht/providers_test.go b/routing/dht/providers_test.go index b37327d2e7e..c4ae53910a8 100644 --- a/routing/dht/providers_test.go +++ b/routing/dht/providers_test.go @@ -5,16 +5,19 @@ import ( "github.com/jbenet/go-ipfs/peer" u "github.com/jbenet/go-ipfs/util" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" ) func TestProviderManager(t *testing.T) { + ctx := context.Background() mid := peer.ID("testing") - p := NewProviderManager(mid) + p := NewProviderManager(ctx, mid) a := u.Key("test") p.AddProvider(a, peer.WithIDString("testingprovider")) resp := p.GetProviders(a) if len(resp) != 1 { t.Fatal("Could not retrieve provider.") } - p.Halt() + p.Close() } diff --git a/util/ctxcloser/closer.go b/util/ctxcloser/closer.go index e04178c2444..ca80368eae2 100644 --- a/util/ctxcloser/closer.go +++ b/util/ctxcloser/closer.go @@ -9,6 +9,8 @@ import ( // CloseFunc is a function used to close a ContextCloser type CloseFunc func() error +var nilCloseFunc = func() error { return nil } + // ContextCloser is an interface for services able to be opened and closed. // It has a parent Context, and Children. But ContextCloser is not a proper // "tree" like the Context tree. It is more like a Context-WaitGroup hybrid. @@ -48,10 +50,25 @@ type ContextCloser interface { // Children is a sync.Waitgroup for all children goroutines that should // shut down completely before this service is said to be "closed". // Follows the semantics of WaitGroup: + // // Children().Add(1) // add one more dependent child // Children().Done() // child signals it is done + // Children() *sync.WaitGroup + // AddCloserChild registers a dependent ContextCloser child. The child will + // be closed when this parent is closed, and waited upon to finish. It is + // the functional equivalent of the following: + // + // go func(parent, child ContextCloser) { + // parent.Children().Add(1) // add one more dependent child + // <-parent.Closing() // wait until parent is closing + // child.Close() // signal child to close + // parent.Children().Done() // child signals it is done + // }(a, b) + // + AddCloserChild(c ContextCloser) + // Close is a method to call when you wish to stop this ContextCloser Close() error @@ -92,6 +109,9 @@ type contextCloser struct { // NewContextCloser constructs and returns a ContextCloser. It will call // cf CloseFunc before its Done() Wait signals fire. func NewContextCloser(ctx context.Context, cf CloseFunc) ContextCloser { + if cf == nil { + cf = nilCloseFunc + } ctx, cancel := context.WithCancel(ctx) c := &contextCloser{ ctx: ctx, @@ -112,6 +132,15 @@ func (c *contextCloser) Children() *sync.WaitGroup { return &c.children } +func (c *contextCloser) AddCloserChild(child ContextCloser) { + c.children.Add(1) + go func(parent, child ContextCloser) { + <-parent.Closing() // wait until parent is closing + child.Close() // signal child to close + parent.Children().Done() // child signals it is done + }(c, child) +} + // Close is the external close function. it's a wrapper around internalClose // that waits on Closed() func (c *contextCloser) Close() error {