diff --git a/cmd/client/main.go b/cmd/client/main.go index c39ca87..50889cf 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -38,7 +38,7 @@ func run(raddr string, p string) error { return err } - t, err := libp2pquic.NewTransport(priv, nil, nil) + t, err := libp2pquic.NewTransport(priv, nil, nil, nil) if err != nil { return err } diff --git a/cmd/server/main.go b/cmd/server/main.go index ec995bd..cb5b528 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -38,7 +38,7 @@ func run(port string) error { return err } - t, err := libp2pquic.NewTransport(priv, nil, nil) + t, err := libp2pquic.NewTransport(priv, nil, nil, nil) if err != nil { return err } diff --git a/conn.go b/conn.go index 4f74d25..94f695a 100644 --- a/conn.go +++ b/conn.go @@ -5,16 +5,18 @@ import ( ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/mux" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" ) type conn struct { sess quic.Session transport tpt.Transport + rcmgr network.ResourceManager localPeer peer.ID privKey ic.PrivKey @@ -38,14 +40,28 @@ func (c *conn) IsClosed() bool { // OpenStream creates a new stream. func (c *conn) OpenStream(ctx context.Context) (mux.MuxedStream, error) { + scope, err := c.rcmgr.OpenStream(c.remotePeerID, network.DirOutbound) + if err != nil { + return nil, err + } qstr, err := c.sess.OpenStreamSync(ctx) + if err != nil { + scope.Done() + return nil, err + } return &stream{Stream: qstr}, err } // AcceptStream accepts a stream opened by the other side. func (c *conn) AcceptStream() (mux.MuxedStream, error) { qstr, err := c.sess.AcceptStream(context.Background()) - return &stream{Stream: qstr}, err + if err != nil { + return nil, err + } + if _, err := c.rcmgr.OpenStream(c.remotePeerID, network.DirInbound); err != nil { + return nil, err + } + return &stream{Stream: qstr}, nil } // LocalPeer returns our peer ID diff --git a/conn_test.go b/conn_test.go index 82e2d62..cf5789b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -69,13 +69,13 @@ var _ = Describe("Connection", func() { }) It("handshakes on IPv4", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) + serverTransport, err := NewTransport(serverKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil, nil) + clientTransport, err := NewTransport(clientKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer clientTransport.(io.Closer).Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) @@ -95,13 +95,13 @@ var _ = Describe("Connection", func() { }) It("handshakes on IPv6", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) + serverTransport, err := NewTransport(serverKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() ln := runServer(serverTransport, "/ip6/::1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil, nil) + clientTransport, err := NewTransport(clientKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer clientTransport.(io.Closer).Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) @@ -121,13 +121,13 @@ var _ = Describe("Connection", func() { }) It("opens and accepts streams", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) + serverTransport, err := NewTransport(serverKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil, nil) + clientTransport, err := NewTransport(clientKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer clientTransport.(io.Closer).Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) @@ -152,12 +152,12 @@ var _ = Describe("Connection", func() { It("fails if the peer ID doesn't match", func() { thirdPartyID, _ := createPeer() - serverTransport, err := NewTransport(serverKey, nil, nil) + serverTransport, err := NewTransport(serverKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") - clientTransport, err := NewTransport(clientKey, nil, nil) + clientTransport, err := NewTransport(clientKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) // dial, but expect the wrong peer ID _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID) @@ -179,7 +179,7 @@ var _ = Describe("Connection", func() { It("gates accepted connections", func() { cg := NewMockConnectionGater(mockCtrl) cg.EXPECT().InterceptAccept(gomock.Any()) - serverTransport, err := NewTransport(serverKey, nil, cg) + serverTransport, err := NewTransport(serverKey, nil, cg, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") @@ -193,7 +193,7 @@ var _ = Describe("Connection", func() { Expect(err).ToNot(HaveOccurred()) }() - clientTransport, err := NewTransport(clientKey, nil, nil) + clientTransport, err := NewTransport(clientKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer clientTransport.(io.Closer).Close() // make sure that connection attempts fails @@ -214,7 +214,7 @@ var _ = Describe("Connection", func() { }) It("gates secured connections", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) + serverTransport, err := NewTransport(serverKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") @@ -223,7 +223,7 @@ var _ = Describe("Connection", func() { cg := NewMockConnectionGater(mockCtrl) cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()) - clientTransport, err := NewTransport(clientKey, nil, cg) + clientTransport, err := NewTransport(clientKey, nil, cg, nil) Expect(err).ToNot(HaveOccurred()) defer clientTransport.(io.Closer).Close() @@ -243,12 +243,12 @@ var _ = Describe("Connection", func() { It("dials to two servers at the same time", func() { serverID2, serverKey2 := createPeer() - serverTransport, err := NewTransport(serverKey, nil, nil) + serverTransport, err := NewTransport(serverKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln1.Close() - serverTransport2, err := NewTransport(serverKey2, nil, nil) + serverTransport2, err := NewTransport(serverKey2, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport2.(io.Closer).Close() ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic") @@ -275,7 +275,7 @@ var _ = Describe("Connection", func() { } }() - clientTransport, err := NewTransport(clientKey, nil, nil) + clientTransport, err := NewTransport(clientKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer clientTransport.(io.Closer).Close() c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID) @@ -305,7 +305,7 @@ var _ = Describe("Connection", func() { }) It("sends stateless resets", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) + serverTransport, err := NewTransport(serverKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") @@ -322,7 +322,7 @@ var _ = Describe("Connection", func() { defer proxy.Close() // establish a connection - clientTransport, err := NewTransport(clientKey, nil, nil) + clientTransport, err := NewTransport(clientKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer clientTransport.(io.Closer).Close() proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr()) @@ -365,7 +365,7 @@ var _ = Describe("Connection", func() { }) It("hole punches", func() { - t1, err := NewTransport(serverKey, nil, nil) + t1, err := NewTransport(serverKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer t1.(io.Closer).Close() laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") @@ -381,7 +381,7 @@ var _ = Describe("Connection", func() { } }() - t2, err := NewTransport(clientKey, nil, nil) + t2, err := NewTransport(clientKey, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) defer t2.(io.Closer).Close() ln2, err := t2.Listen(laddr) diff --git a/go.mod b/go.mod index d2c50b0..7384fbe 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/google/gopacket v1.1.17 github.com/ipfs/go-log/v2 v2.4.0 github.com/klauspost/compress v1.11.7 - github.com/libp2p/go-libp2p-core v0.10.0 + github.com/libp2p/go-libp2p-core v0.13.1-0.20211231090304-48c94b6fddec github.com/libp2p/go-libp2p-tls v0.3.0 github.com/libp2p/go-netroute v0.1.3 github.com/lucas-clemente/quic-go v0.24.0 diff --git a/go.sum b/go.sum index 713c1db..e4c1263 100644 --- a/go.sum +++ b/go.sum @@ -218,8 +218,9 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/libp2p/go-buffer-pool v0.0.2 h1:QNK2iAFa8gjAe1SPz6mHSMuCcjs+X1wlHzeOSqcmlfs= github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM= github.com/libp2p/go-flow-metrics v0.0.3/go.mod h1:HeoSNUrOJVK1jEpDqVEiUOIXqhbnS27omG0uWU5slZs= -github.com/libp2p/go-libp2p-core v0.10.0 h1:jFy7v5Muq58GTeYkPhGzIH8Qq4BFfziqc0ixPd/pP9k= github.com/libp2p/go-libp2p-core v0.10.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= +github.com/libp2p/go-libp2p-core v0.13.1-0.20211231090304-48c94b6fddec h1:3TxVyBAlNFAJ1stkc/uoGwFqbXlvBg7bQH05MHMU2dQ= +github.com/libp2p/go-libp2p-core v0.13.1-0.20211231090304-48c94b6fddec/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= github.com/libp2p/go-libp2p-tls v0.3.0 h1:8BgvUJiOTcj0Gp6XvEicF0rL5aUtRg/UzEdeZDmDlC8= github.com/libp2p/go-libp2p-tls v0.3.0/go.mod h1:fwF5X6PWGxm6IDRwF3V8AVCCj/hOd5oFlg+wo2FxJDY= github.com/libp2p/go-maddr-filter v0.1.0/go.mod h1:VzZhTXkMucEGGEOSKddrwGiOv0tUhgnKqNEmIAz/bPU= diff --git a/integrationtests/main.go b/integrationtests/main.go index a9e3cc6..adae52c 100644 --- a/integrationtests/main.go +++ b/integrationtests/main.go @@ -78,7 +78,7 @@ func readKeys(hostKeyFile, peerKeyFile string) (crypto.PrivKey, crypto.PubKey, e } func runServer(hostKey crypto.PrivKey, peerKey crypto.PubKey, addr ma.Multiaddr, test string) error { - tr, err := libp2pquic.NewTransport(hostKey, nil, nil) + tr, err := libp2pquic.NewTransport(hostKey, nil, nil, nil) if err != nil { return err } @@ -129,7 +129,7 @@ func runServer(hostKey crypto.PrivKey, peerKey crypto.PubKey, addr ma.Multiaddr, } func runClient(hostKey crypto.PrivKey, serverKey crypto.PubKey, addr ma.Multiaddr, test string) error { - tr, err := libp2pquic.NewTransport(hostKey, nil, nil) + tr, err := libp2pquic.NewTransport(hostKey, nil, nil, nil) if err != nil { return err } diff --git a/listener.go b/listener.go index 039f717..bcbb49a 100644 --- a/listener.go +++ b/listener.go @@ -6,7 +6,7 @@ import ( "net" ic "github.com/libp2p/go-libp2p-core/crypto" - n "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" @@ -23,6 +23,7 @@ type listener struct { quicListener quic.Listener conn *reuseConn transport *transport + rcmgr network.ResourceManager privKey ic.PrivKey localPeer peer.ID localMultiaddr ma.Multiaddr @@ -30,7 +31,10 @@ type listener struct { var _ tpt.Listener = &listener{} -func newListener(rconn *reuseConn, t *transport, localPeer peer.ID, key ic.PrivKey, identity *p2ptls.Identity) (tpt.Listener, error) { +func newListener(rconn *reuseConn, t *transport, localPeer peer.ID, key ic.PrivKey, identity *p2ptls.Identity, rcmgr network.ResourceManager) (tpt.Listener, error) { + if rcmgr == nil { + panic("nil rcmgr") + } var tlsConf tls.Config tlsConf.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) { // return a tls.Config that verifies the peer's certificate chain. @@ -52,6 +56,7 @@ func newListener(rconn *reuseConn, t *transport, localPeer peer.ID, key ic.PrivK conn: rconn, quicListener: ln, transport: t, + rcmgr: rcmgr, privKey: key, localPeer: localPeer, localMultiaddr: localMultiaddr, @@ -70,7 +75,13 @@ func (l *listener) Accept() (tpt.CapableConn, error) { sess.CloseWithError(0, err.Error()) continue } - if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(conn) && l.transport.gater.InterceptSecured(n.DirInbound, conn.remotePeerID, conn)) { + connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false) + if err != nil { + sess.CloseWithError(0, err.Error()) + continue + } + if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(conn) && l.transport.gater.InterceptSecured(network.DirInbound, conn.remotePeerID, conn)) { + connScope.Done() sess.CloseWithError(errorCodeConnectionGating, "connection gated") continue } @@ -116,6 +127,7 @@ func (l *listener) setupConn(sess quic.Session) (*conn, error) { return &conn{ sess: sess, transport: l.transport, + rcmgr: l.rcmgr, localPeer: l.localPeer, localMultiaddr: l.localMultiaddr, privKey: l.privKey, diff --git a/listener_test.go b/listener_test.go index f8d2041..de18982 100644 --- a/listener_test.go +++ b/listener_test.go @@ -35,7 +35,7 @@ var _ = Describe("Listener", func() { Expect(err).ToNot(HaveOccurred()) key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) Expect(err).ToNot(HaveOccurred()) - t, err = NewTransport(key, nil, nil) + t, err = NewTransport(key, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) }) diff --git a/transport.go b/transport.go index 13efc95..d93e6d9 100644 --- a/transport.go +++ b/transport.go @@ -10,21 +10,24 @@ import ( "sync" "time" - "github.com/minio/sha256-simd" "golang.org/x/crypto/hkdf" - logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p-core/connmgr" ic "github.com/libp2p/go-libp2p-core/crypto" - n "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/pnet" tpt "github.com/libp2p/go-libp2p-core/transport" + p2ptls "github.com/libp2p/go-libp2p-tls" - "github.com/lucas-clemente/quic-go" + ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" + + logging "github.com/ipfs/go-log/v2" + "github.com/lucas-clemente/quic-go" + "github.com/minio/sha256-simd" ) var log = logging.Logger("quic-transport") @@ -109,6 +112,7 @@ type transport struct { serverConfig *quic.Config clientConfig *quic.Config gater connmgr.ConnectionGater + rcmgr network.ResourceManager holePunchingMx sync.Mutex holePunching map[holePunchKey]*activeHolePunch @@ -127,7 +131,7 @@ type activeHolePunch struct { } // NewTransport creates a new QUIC transport -func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) (tpt.Transport, error) { +func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Transport, error) { if len(psk) > 0 { log.Error("QUIC doesn't support private networks yet.") return nil, errors.New("QUIC doesn't support private networks yet") @@ -144,6 +148,9 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) ( if err != nil { return nil, err } + if rcmgr == nil { + rcmgr = network.NullResourceManager + } config := quicConfig.Clone() keyBytes, err := key.Raw() if err != nil { @@ -164,17 +171,18 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) ( serverConfig: config, clientConfig: config.Clone(), gater: gater, + rcmgr: rcmgr, holePunching: make(map[holePunchKey]*activeHolePunch), }, nil } // Dial dials a new QUIC connection func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { - network, host, err := manet.DialArgs(raddr) + netw, host, err := manet.DialArgs(raddr) if err != nil { return nil, err } - addr, err := net.ResolveUDPAddr(network, host) + addr, err := net.ResolveUDPAddr(netw, host) if err != nil { return nil, err } @@ -183,17 +191,25 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } tlsConf, keyCh := t.identity.ConfigForPeer(p) - - if ok, isClient, _ := n.GetSimultaneousConnect(ctx); ok && !isClient { - return t.holePunch(ctx, network, addr, p) + if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient { + return t.holePunch(ctx, netw, addr, p) } - pconn, err := t.connManager.Dial(network, addr) + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false) + if err != nil { + return nil, err + } + if err := scope.SetPeer(p); err != nil { + scope.Done() + return nil, err + } + pconn, err := t.connManager.Dial(netw, addr) if err != nil { return nil, err } sess, err := quicDialContext(ctx, pconn, addr, host, tlsConf, t.clientConfig) if err != nil { + scope.Done() pconn.DecreaseCount() return nil, err } @@ -205,21 +221,25 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp } if remotePubKey == nil { pconn.DecreaseCount() + scope.Done() return nil, errors.New("go-libp2p-quic-transport BUG: expected remote pub key to be set") } go func() { <-sess.Context().Done() + scope.Done() pconn.DecreaseCount() }() localMultiaddr, err := toQuicMultiaddr(pconn.LocalAddr()) if err != nil { + scope.Done() sess.CloseWithError(0, "") return nil, err } conn := &conn{ sess: sess, transport: t, + rcmgr: t.rcmgr, privKey: t.privKey, localPeer: t.localPeer, localMultiaddr: localMultiaddr, @@ -227,7 +247,8 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp remotePeerID: p, remoteMultiaddr: remoteMultiaddr, } - if t.gater != nil && !t.gater.InterceptSecured(n.DirOutbound, p, conn) { + if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, conn) { + scope.Done() sess.CloseWithError(errorCodeConnectionGating, "connection gated") return nil, fmt.Errorf("secured connection gated") } @@ -332,7 +353,7 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { if err != nil { return nil, err } - ln, err := newListener(conn, t, t.localPeer, t.privKey, t.identity) + ln, err := newListener(conn, t, t.localPeer, t.privKey, t.identity, t.rcmgr) if err != nil { conn.DecreaseCount() return nil, err diff --git a/transport_test.go b/transport_test.go index 5f52ce0..64c668e 100644 --- a/transport_test.go +++ b/transport_test.go @@ -27,7 +27,7 @@ var _ = Describe("Transport", func() { Expect(err).ToNot(HaveOccurred()) key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) Expect(err).ToNot(HaveOccurred()) - t, err = NewTransport(key, nil, nil) + t, err = NewTransport(key, nil, nil, nil) Expect(err).ToNot(HaveOccurred()) })