-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathupgrader.go
343 lines (306 loc) · 9.85 KB
/
upgrader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
package upgrader
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
ipnet "github.com/libp2p/go-libp2p/core/pnet"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/pnet"
manet "github.com/multiformats/go-multiaddr/net"
mss "github.com/multiformats/go-multistream"
)
var (
	// ErrNilPeer is returned when attempting to upgrade an outbound connection
	// without specifying a peer ID.
	ErrNilPeer = errors.New("nil peer")

	// AcceptQueueLength is the number of connections to fully set up before the
	// listener stops accepting any new connections.
	AcceptQueueLength = 16
)

const (
	// defaultAcceptTimeout is the acceptTimeout New uses unless an Option overrides it.
	defaultAcceptTimeout = 15 * time.Second
	// defaultNegotiateTimeout bounds the multistream muxer negotiation in negotiateMuxer.
	defaultNegotiateTimeout = 60 * time.Second
)
// Option configures an upgrader while it is being built by New. A non-nil
// error returned by an Option aborts construction.
type Option func(*upgrader) error

// WithAcceptTimeout overrides the maximum duration an Accept is allowed to
// take. The returned Option never fails.
func WithAcceptTimeout(t time.Duration) Option {
	setTimeout := func(u *upgrader) error {
		u.acceptTimeout = t
		return nil
	}
	return setTimeout
}
// StreamMuxer pairs a stream multiplexer implementation with the protocol ID
// under which it is negotiated.
type StreamMuxer struct {
	// ID is the protocol ID registered with the multistream muxer in New and
	// matched against the negotiated protocol by getMuxerByID.
	ID protocol.ID
	// Muxer builds the multiplexed connection (via NewConn) once this muxer
	// has been selected.
	Muxer network.Multiplexer
}
// upgrader is a multistream upgrader that can upgrade an underlying connection
// to a full transport connection (secure and multiplexed).
type upgrader struct {
	// psk, when non-nil, wraps each raw connection in a private-network
	// protector before security negotiation.
	psk ipnet.PSK
	// connGater, when non-nil, is consulted via InterceptSecured after the
	// security handshake.
	connGater connmgr.ConnectionGater
	// rcmgr is the resource manager; New substitutes a NullResourceManager
	// when nil is passed.
	rcmgr network.ResourceManager

	// muxerMuxer answers muxer negotiation when we act as the server.
	muxerMuxer *mss.MultistreamMuxer[protocol.ID]
	// muxers are the configured stream muxers, in the order given to New.
	muxers []StreamMuxer
	// muxerIDs holds the IDs of muxers in the same order; passed to
	// mss.SelectOneOf when we act as the client.
	muxerIDs []protocol.ID

	// security are the configured security transports, in the order given to New.
	security []sec.SecureTransport
	// securityMuxer answers security negotiation when we act as the server.
	securityMuxer *mss.MultistreamMuxer[protocol.ID]
	// securityIDs holds the IDs of security in the same order; passed to
	// mss.SelectOneOf when we act as the client.
	securityIDs []protocol.ID

	// AcceptTimeout is the maximum duration an Accept is allowed to take.
	// This includes the time between accepting the raw network connection,
	// protocol selection as well as the handshake, if applicable.
	//
	// If unset, the default value (15s) is used.
	acceptTimeout time.Duration
}

// Compile-time assertion that *upgrader implements transport.Upgrader.
var _ transport.Upgrader = &upgrader{}
// New builds a transport.Upgrader that secures connections with one of the
// given security transports and multiplexes them with one of the given stream
// muxers. psk enables private-network protection when non-nil; a nil rcmgr is
// replaced by a NullResourceManager; connGater may be nil. Options are applied
// in order, and the first Option error is returned as-is.
func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rcmgr network.ResourceManager, connGater connmgr.ConnectionGater, opts ...Option) (transport.Upgrader, error) {
	up := &upgrader{
		psk:           psk,
		connGater:     connGater,
		rcmgr:         rcmgr,
		muxerMuxer:    mss.NewMultistreamMuxer[protocol.ID](),
		muxers:        muxers,
		security:      security,
		securityMuxer: mss.NewMultistreamMuxer[protocol.ID](),
		acceptTimeout: defaultAcceptTimeout,
	}

	for _, apply := range opts {
		if err := apply(up); err != nil {
			return nil, err
		}
	}

	if up.rcmgr == nil {
		up.rcmgr = &network.NullResourceManager{}
	}

	// Register each muxer for server-side negotiation and record its ID, in
	// order, for client-side selection.
	up.muxerIDs = make([]protocol.ID, len(muxers))
	for i := range muxers {
		id := muxers[i].ID
		up.muxerMuxer.AddHandler(id, nil)
		up.muxerIDs[i] = id
	}

	// Same bookkeeping for the security transports.
	up.securityIDs = make([]protocol.ID, len(security))
	for i, st := range security {
		id := st.ID()
		up.securityMuxer.AddHandler(id, nil)
		up.securityIDs[i] = id
	}

	return up, nil
}
// UpgradeListener wraps a multiaddr-net listener into a full libp2p-transport
// listener. The returned listener owns a cancellable background context and
// starts handling incoming connections in its own goroutine immediately.
func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener {
	ctx, cancel := context.WithCancel(context.Background())

	lis := &listener{
		Listener:  list,
		transport: t,
		upgrader:  u,
		rcmgr:     u.rcmgr,
		threshold: newThreshold(AcceptQueueLength),
		incoming:  make(chan transport.CapableConn),
		ctx:       ctx,
		cancel:    cancel,
	}

	go lis.handleIncoming()
	return lis
}
// Upgrade turns the multiaddr/net connection into a full libp2p-transport
// connection (secured and multiplexed). On failure the connection scope is
// released here, because upgrade does not release it on its own error paths.
func (u *upgrader) Upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, dir network.Direction, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) {
	conn, err := u.upgrade(ctx, t, maconn, dir, p, connScope)
	if err == nil {
		return conn, nil
	}
	connScope.Done()
	return nil, err
}
// upgrade performs the steps behind Upgrade, in order:
//  1. optional private-network protection (psk),
//  2. security negotiation and handshake,
//  3. connection gating (InterceptSecured),
//  4. attributing the connection scope to the remote peer,
//  5. stream-muxer negotiation.
//
// Every error path closes the connection it was handed: either maconn itself,
// or the connection wrapping it. The caller therefore does not need to close
// maconn after a failure. The connection scope is NOT released here; Upgrade
// does that.
func (u *upgrader) upgrade(ctx context.Context, t transport.Transport, maconn manet.Conn, dir network.Direction, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) {
	if dir == network.DirOutbound && p == "" {
		// Close like every other failure path does, so the connection does not leak.
		maconn.Close()
		return nil, ErrNilPeer
	}

	var stat network.ConnStats
	if cs, ok := maconn.(network.ConnStat); ok {
		stat = cs.Stat()
	}

	var conn net.Conn = maconn
	if u.psk != nil {
		pconn, err := pnet.NewProtectedConn(u.psk, conn)
		if err != nil {
			conn.Close()
			return nil, fmt.Errorf("failed to setup private network protector: %w", err)
		}
		conn = pconn
	} else if ipnet.ForcePrivateNetwork {
		log.Error("tried to dial with no Private Network Protector but usage of Private Networks is forced by the environment")
		// We refuse the connection, so we must also release it.
		conn.Close()
		return nil, ipnet.ErrNotInPrivateNetwork
	}

	isServer := dir == network.DirInbound
	sconn, security, err := u.setupSecurity(ctx, conn, p, isServer)
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to negotiate security protocol: %w", err)
	}

	remote := sconn.RemotePeer()

	// Call the connection gater, if one is registered.
	if u.connGater != nil && !u.connGater.InterceptSecured(dir, remote, maconn) {
		if closeErr := maconn.Close(); closeErr != nil {
			log.Errorw("failed to close connection", "peer", p, "addr", maconn.RemoteMultiaddr(), "error", closeErr)
		}
		return nil, fmt.Errorf("gater rejected connection with peer %s and addr %s with direction %d",
			remote, maconn.RemoteMultiaddr(), dir)
	}

	// Only call SetPeer if it hasn't already been set -- this can happen when we don't know
	// the peer in advance and in some bug scenarios.
	if connScope.PeerScope() == nil {
		if err := connScope.SetPeer(remote); err != nil {
			log.Debugw("resource manager blocked connection for peer", "peer", remote, "addr", conn.RemoteAddr(), "error", err)
			if closeErr := maconn.Close(); closeErr != nil {
				log.Errorw("failed to close connection", "peer", p, "addr", maconn.RemoteMultiaddr(), "error", closeErr)
			}
			// Wrap the resource manager's error so callers can inspect it with errors.Is/As.
			return nil, fmt.Errorf("resource manager blocked connection with peer %s and addr %s with direction %d: %w",
				remote, maconn.RemoteMultiaddr(), dir, err)
		}
	}

	muxer, smconn, err := u.setupMuxer(ctx, sconn, isServer, connScope.PeerScope())
	if err != nil {
		sconn.Close()
		return nil, fmt.Errorf("failed to negotiate stream multiplexer: %w", err)
	}

	tc := &transportConn{
		MuxedConn:                 smconn,
		ConnMultiaddrs:            maconn,
		ConnSecurity:              sconn,
		transport:                 t,
		stat:                      stat,
		scope:                     connScope,
		muxer:                     muxer,
		security:                  security,
		usedEarlyMuxerNegotiation: sconn.ConnState().UsedEarlyMuxerNegotiation,
	}
	return tc, nil
}
// setupSecurity negotiates a security protocol on conn, then runs that
// transport's handshake as the server (inbound) or client (outbound) depending
// on isServer. It returns the secured connection, the ID of the chosen security
// transport, and the handshake error, if any.
func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, isServer bool) (sec.SecureConn, protocol.ID, error) {
	st, err := u.negotiateSecurity(ctx, conn, isServer)
	if err != nil {
		return nil, "", err
	}

	handshake := st.SecureOutbound
	if isServer {
		handshake = st.SecureInbound
	}
	sconn, err := handshake(ctx, conn, p)
	return sconn, st.ID(), err
}
// negotiateMuxer uses multistream-select on nc to agree on a stream muxer.
// As the server it answers with u.muxerMuxer; as the client it proposes
// u.muxerIDs.
//
// The whole exchange is bounded by defaultNegotiateTimeout through a deadline
// on nc. The deadline is cleared after a successful exchange. It is returns
// the matching configured StreamMuxer, or an error naming the negotiated
// protocol if no configured muxer has that ID.
func (u *upgrader) negotiateMuxer(nc net.Conn, isServer bool) (*StreamMuxer, error) {
	if err := nc.SetDeadline(time.Now().Add(defaultNegotiateTimeout)); err != nil {
		return nil, err
	}

	var proto protocol.ID
	if isServer {
		selected, _, err := u.muxerMuxer.Negotiate(nc)
		if err != nil {
			return nil, err
		}
		proto = selected
	} else {
		selected, err := mss.SelectOneOf(u.muxerIDs, nc)
		if err != nil {
			return nil, err
		}
		proto = selected
	}

	// Clear the negotiation deadline so it does not affect the muxed connection.
	if err := nc.SetDeadline(time.Time{}); err != nil {
		return nil, err
	}

	if m := u.getMuxerByID(proto); m != nil {
		return m, nil
	}
	// Report which protocol was selected, consistent with setupMuxer and negotiateSecurity.
	return nil, fmt.Errorf("selected protocol we don't have a transport for: %s", proto)
}
// getMuxerByID returns a pointer to a copy of the first configured
// StreamMuxer whose ID equals id, or nil if none matches.
func (u *upgrader) getMuxerByID(id protocol.ID) *StreamMuxer {
	for i := range u.muxers {
		if found := u.muxers[i]; found.ID == id {
			return &found
		}
	}
	return nil
}
// setupMuxer builds the multiplexed connection on top of the secured conn and
// returns the chosen muxer ID together with the muxed connection.
//
// If the security handshake already picked a muxer (ConnState().StreamMultiplexer
// is set), that muxer is used directly. Otherwise multistream negotiation runs in a
// goroutine. If ctx is cancelled first, conn is closed to interrupt the goroutine,
// and the function waits for it to finish before returning ctx.Err().
func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server bool, scope network.PeerScope) (protocol.ID, network.MuxedConn, error) {
	if early := conn.ConnState().StreamMultiplexer; early != "" {
		sm := u.getMuxerByID(early)
		if sm == nil {
			return "", nil, fmt.Errorf("selected a muxer we don't know: %s", early)
		}
		mc, err := sm.Muxer.NewConn(conn, server, scope)
		if err != nil {
			return "", nil, err
		}
		return early, mc, nil
	}

	type outcome struct {
		id  protocol.ID
		mc  network.MuxedConn
		err error
	}
	// Buffered so the goroutine can always deliver its outcome and exit.
	ch := make(chan outcome, 1)

	// TODO: The muxer should take a context.
	go func() {
		sm, err := u.negotiateMuxer(conn, server)
		if err != nil {
			ch <- outcome{err: err}
			return
		}
		mc, err := sm.Muxer.NewConn(conn, server, scope)
		ch <- outcome{id: sm.ID, mc: mc, err: err}
	}()

	select {
	case <-ctx.Done():
		// Closing the connection aborts the in-flight negotiation; then wait
		// for the goroutine so nothing touches conn after we return.
		conn.Close()
		<-ch
		return "", nil, ctx.Err()
	case res := <-ch:
		return res.id, res.mc, res.err
	}
}
// getSecurityByID returns the first configured security transport whose ID()
// equals id, or nil if none matches.
func (u *upgrader) getSecurityByID(id protocol.ID) sec.SecureTransport {
	for i := range u.security {
		if st := u.security[i]; st.ID() == id {
			return st
		}
	}
	return nil
}
// negotiateSecurity uses multistream-select on the insecure connection to pick
// a security transport. As the server it answers with u.securityMuxer; as the
// client it proposes u.securityIDs.
//
// Negotiation runs in a goroutine so ctx can interrupt it. On cancellation the
// connection is closed, and the function waits for the goroutine before
// returning ctx.Err().
func (u *upgrader) negotiateSecurity(ctx context.Context, insecure net.Conn, server bool) (sec.SecureTransport, error) {
	type negotiation struct {
		id  protocol.ID
		err error
	}
	// Buffered so the goroutine can always deliver its result and exit.
	ch := make(chan negotiation, 1)

	go func() {
		var n negotiation
		if server {
			n.id, _, n.err = u.securityMuxer.Negotiate(insecure)
		} else {
			n.id, n.err = mss.SelectOneOf(u.securityIDs, insecure)
		}
		ch <- n
	}()

	select {
	case <-ctx.Done():
		// We *must* do this. We have outstanding work on the connection, and it's no longer safe to use.
		insecure.Close()
		<-ch // wait to stop using the connection.
		return nil, ctx.Err()
	case n := <-ch:
		if n.err != nil {
			return nil, n.err
		}
		st := u.getSecurityByID(n.id)
		if st == nil {
			return nil, fmt.Errorf("selected unknown security transport: %s", n.id)
		}
		return st, nil
	}
}