From 6d6f29f8c9df5fc8651c3a2e80ad980c1728eaaf Mon Sep 17 00:00:00 2001 From: Rod Vagg Date: Fri, 26 Jul 2024 23:27:22 +1000 Subject: [PATCH] fix!: gateway: fix rate limiting, general cleanup Minor API changes: * gateway.NewRateLimiterHandler and gateway.NewConnectionRateLimiterHandler have been replaced with gateway.NewRateLimitHandler. * The handlers returned by both gateway.NewRateLimitHandler and the primary gateway.Handler return an http.Handler augmented with a Shutdown(ctx) method to be used for graceful cleanup of resources. Fix: * --per-conn-rate-limit was previously applied as a global rate limiter, effectively making it have the same impact as --rate-limit. This change fixes the behaviour such that --per-conn-rate-limit is applied as a API call limiter within a single connection (i.e. a WebSocket connection). The rate is specified as tokens-per-second, where tokens are relative to the expense of the API call being made. --- CHANGELOG.md | 7 ++ cmd/lotus-gateway/main.go | 47 ++++---- gateway/handler.go | 226 +++++++++++++++++++++++++------------- gateway/handler_test.go | 42 +++++++ gateway/node.go | 23 ++-- gateway/node_test.go | 34 ++++-- itests/gateway_test.go | 188 ++++++++++++++++++++++++------- 7 files changed, 417 insertions(+), 150 deletions(-) create mode 100644 gateway/handler_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e9dad67f3d..60def96bbc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,13 @@ - fix: add datacap balance to circ supply internal accounting as unCirc #12348 +## Improvements + +- fix!: gateway: fix rate limiting, general cleanup ([filecoin-project/lotus#12315](https://github.com/filecoin-project/lotus/pull/12315)). + - CLI usage documentation has been improved for `lotus-gateway` + - `--per-conn-rate-limit` now works as advertised. + - Some APIs have changed which may impact users consuming Lotus Gateway code as a library. + # v1.28.1 / 2024-07-24 This is the MANDATORY Lotus v1.28.1 release, which will deliver the Filecoin network version 23, codenamed Waffle 🧇. v1.28.1 is also the minimal version that supports nv23. diff --git a/cmd/lotus-gateway/main.go b/cmd/lotus-gateway/main.go index 2c5279ed448..a2c46c6b26c 100644 --- a/cmd/lotus-gateway/main.go +++ b/cmd/lotus-gateway/main.go @@ -132,23 +132,29 @@ var runCmd = &cli.Command{ Value: int64(gateway.DefaultStateWaitLookbackLimit), }, &cli.Int64Flag{ - Name: "rate-limit", - Usage: "rate-limit API calls. Use 0 to disable", + Name: "rate-limit", + Usage: fmt.Sprintf( + "Global API call throttling rate limit (per second), weighted by relative expense of the call, with the most expensive calls counting for %d. Use 0 to disable", + gateway.MaxRateLimitTokens, + ), Value: 0, }, &cli.Int64Flag{ - Name: "per-conn-rate-limit", - Usage: "rate-limit API calls per each connection. Use 0 to disable", + Name: "per-conn-rate-limit", + Usage: fmt.Sprintf( + "API call throttling rate limit (per second) per WebSocket connection, weighted by relative expense of the call, with the most expensive calls counting for %d. Use 0 to disable", + gateway.MaxRateLimitTokens, + ), Value: 0, }, &cli.DurationFlag{ Name: "rate-limit-timeout", - Usage: "the maximum time to wait for the rate limiter before returning an error to clients", + Usage: "The maximum time to wait for the API call throttling rate limiter before returning an error to clients", Value: gateway.DefaultRateLimitTimeout, }, &cli.Int64Flag{ Name: "conn-per-minute", - Usage: "The number of incomming connections to accept from a single IP per minute. Use 0 to disable", + Usage: "A hard limit on the number of incomming connections (requests) to accept per remote host per minute. Use 0 to disable", Value: 0, }, }, @@ -171,13 +177,13 @@ var runCmd = &cli.Command{ defer closer() var ( - lookbackCap = cctx.Duration("api-max-lookback") - address = cctx.String("listen") - waitLookback = abi.ChainEpoch(cctx.Int64("api-wait-lookback-limit")) - rateLimit = cctx.Int64("rate-limit") - perConnRateLimit = cctx.Int64("per-conn-rate-limit") - rateLimitTimeout = cctx.Duration("rate-limit-timeout") - connPerMinute = cctx.Int64("conn-per-minute") + lookbackCap = cctx.Duration("api-max-lookback") + address = cctx.String("listen") + waitLookback = abi.ChainEpoch(cctx.Int64("api-wait-lookback-limit")) + globalRateLimit = cctx.Int("rate-limit") + perConnectionRateLimit = cctx.Int("per-conn-rate-limit") + rateLimitTimeout = cctx.Duration("rate-limit-timeout") + perHostConnectionsPerMinute = cctx.Int("conn-per-minute") ) serverOptions := make([]jsonrpc.ServerOption, 0) @@ -197,21 +203,22 @@ var runCmd = &cli.Command{ return xerrors.Errorf("failed to convert endpoint address to multiaddr: %w", err) } - gwapi := gateway.NewNode(api, subHnd, lookbackCap, waitLookback, rateLimit, rateLimitTimeout) - h, err := gateway.Handler(gwapi, api, perConnRateLimit, connPerMinute, serverOptions...) + gwapi := gateway.NewNode(api, subHnd, lookbackCap, waitLookback, int64(globalRateLimit), rateLimitTimeout) + handler, err := gateway.Handler(gwapi, api, perConnectionRateLimit, perHostConnectionsPerMinute, serverOptions...) if err != nil { return xerrors.Errorf("failed to set up gateway HTTP handler") } - stopFunc, err := node.ServeRPC(h, "lotus-gateway", maddr) + stopFunc, err := node.ServeRPC(handler, "lotus-gateway", maddr) if err != nil { return xerrors.Errorf("failed to serve rpc endpoint: %w", err) } - <-node.MonitorShutdown(nil, node.ShutdownHandler{ - Component: "rpc", - StopFunc: stopFunc, - }) + <-node.MonitorShutdown( + nil, + node.ShutdownHandler{Component: "rpc", StopFunc: stopFunc}, + node.ShutdownHandler{Component: "rpc-handler", StopFunc: handler.Shutdown}, + ) return nil }, } diff --git a/gateway/handler.go b/gateway/handler.go index 2a9ee20807f..d98b9594cf2 100644 --- a/gateway/handler.go +++ b/gateway/handler.go @@ -21,20 +21,48 @@ import ( "github.com/filecoin-project/lotus/node" ) -type perConnLimiterKeyType string +type perConnectionAPIRateLimiterKeyType string +type filterTrackerKeyType string -const perConnLimiterKey perConnLimiterKeyType = "limiter" +const ( + perConnectionAPIRateLimiterKey perConnectionAPIRateLimiterKeyType = "limiter" + statefulCallTrackerKey filterTrackerKeyType = "statefulCallTracker" + connectionLimiterCleanupInterval = 30 * time.Second +) -type filterTrackerKeyType string +// ShutdownHandler is an http.Handler that can be gracefully shutdown. +type ShutdownHandler interface { + http.Handler + + Shutdown(ctx context.Context) error +} -const statefulCallTrackerKey filterTrackerKeyType = "statefulCallTracker" +var _ ShutdownHandler = &statefulCallHandler{} +var _ ShutdownHandler = &RateLimitHandler{} + +// Handler returns a gateway http.Handler, to be mounted as-is on the server. The handler is +// returned as a ShutdownHandler which allows for graceful shutdown of the handler via its +// Shutdown method. +// +// The handler will limit the number of API calls per minute within a single WebSocket connection +// (where API calls are weighted by their relative expense), and the number of connections per +// minute from a single host. +// +// Connection limiting is a hard limit that will reject requests with a 429 status code if the limit +// is exceeded. API call limiting is a soft limit that will delay requests if the limit is exceeded. +func Handler( + gwapi lapi.Gateway, + api lapi.FullNode, + perConnectionAPIRateLimit int, + perHostConnectionsPerMinute int, + opts ...jsonrpc.ServerOption, +) (ShutdownHandler, error) { -// Handler returns a gateway http.Handler, to be mounted as-is on the server. -func Handler(gwapi lapi.Gateway, api lapi.FullNode, rateLimit int64, connPerMinute int64, opts ...jsonrpc.ServerOption) (http.Handler, error) { m := mux.NewRouter() + opts = append(opts, jsonrpc.WithReverseClient[lapi.EthSubscriberMethods]("Filecoin"), jsonrpc.WithServerErrors(lapi.RPCErrors)) serveRpc := func(path string, hnd interface{}) { - rpcServer := jsonrpc.NewServer(append(opts, jsonrpc.WithReverseClient[lapi.EthSubscriberMethods]("Filecoin"), jsonrpc.WithServerErrors(lapi.RPCErrors))...) + rpcServer := jsonrpc.NewServer(opts...) rpcServer.Register("Filecoin", hnd) rpcServer.AliasMethod("rpc.discover", "Filecoin.Discover") @@ -61,104 +89,152 @@ func Handler(gwapi lapi.Gateway, api lapi.FullNode, rateLimit int64, connPerMinu m.Handle("/health/readyz", node.NewReadyHandler(api)) m.PathPrefix("/").Handler(http.DefaultServeMux) - /*ah := &auth.Handler{ - Verify: nodeApi.AuthVerify, - Next: mux.ServeHTTP, - }*/ - - rlh := NewRateLimiterHandler(m, rateLimit) - clh := NewConnectionRateLimiterHandler(rlh, connPerMinute) - return clh, nil + handler := &statefulCallHandler{m} + if perConnectionAPIRateLimit > 0 && perHostConnectionsPerMinute > 0 { + return NewRateLimitHandler( + handler, + perConnectionAPIRateLimit, + perHostConnectionsPerMinute, + connectionLimiterCleanupInterval, + ), nil + } + return handler, nil } -func NewRateLimiterHandler(handler http.Handler, rateLimit int64) *RateLimiterHandler { - limiter := limiterFromRateLimit(rateLimit) - - return &RateLimiterHandler{ - handler: handler, - limiter: limiter, - } +type statefulCallHandler struct { + next http.Handler } -// RateLimiterHandler adds a rate limiter to the request context for per-connection rate limiting -type RateLimiterHandler struct { - handler http.Handler - limiter *rate.Limiter +func (h statefulCallHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(context.WithValue(r.Context(), statefulCallTrackerKey, newStatefulCallTracker())) + h.next.ServeHTTP(w, r) } -func (h RateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(context.WithValue(r.Context(), perConnLimiterKey, h.limiter)) +func (h statefulCallHandler) Shutdown(ctx context.Context) error { + return shutdown(ctx, h.next) +} - // also add a filter tracker to the context - r = r.WithContext(context.WithValue(r.Context(), statefulCallTrackerKey, newStatefulCallTracker())) +type hostLimiter struct { + limiter *rate.Limiter + lastAccess time.Time +} - h.handler.ServeHTTP(w, r) +type RateLimitHandler struct { + cancelFunc context.CancelFunc + mu sync.Mutex + limiters map[string]*hostLimiter + perConnectionAPILimit rate.Limit + perHostConnectionsPerMinute int + next http.Handler + cleanupInterval time.Duration + expiryDuration time.Duration } -// NewConnectionRateLimiterHandler blocks new connections if there have already been too many. -func NewConnectionRateLimiterHandler(handler http.Handler, connPerMinute int64) *ConnectionRateLimiterHandler { - ipmap := make(map[string]int64) - return &ConnectionRateLimiterHandler{ - ipmap: ipmap, - connPerMinute: connPerMinute, - handler: handler, +// NewRateLimitHandler creates a new RateLimitHandler that wraps the +// provided handler and limits the number of API calls per minute within a single WebSocket +// connection (where API calls are weighted by their relative expense), and the number of +// connections per minute from a single host. +// The cleanupInterval determines how often the handler will check for unused limiters to clean up. +func NewRateLimitHandler( + next http.Handler, + perConnectionAPIRateLimit int, + perHostConnectionsPerMinute int, + cleanupInterval time.Duration, +) *RateLimitHandler { + + ctx, cancel := context.WithCancel(context.Background()) + h := &RateLimitHandler{ + cancelFunc: cancel, + limiters: make(map[string]*hostLimiter), + perConnectionAPILimit: rate.Inf, + perHostConnectionsPerMinute: perHostConnectionsPerMinute, + next: next, + cleanupInterval: cleanupInterval, + expiryDuration: 5 * cleanupInterval, + } + if perConnectionAPIRateLimit > 0 { + h.perConnectionAPILimit = rate.Every(time.Second / time.Duration(perConnectionAPIRateLimit)) } + go h.cleanupExpiredLimiters(ctx) + return h } -type ConnectionRateLimiterHandler struct { - mu sync.Mutex - ipmap map[string]int64 - connPerMinute int64 - handler http.Handler +func (h *RateLimitHandler) getLimits(host string) *hostLimiter { + h.mu.Lock() + defer h.mu.Unlock() + + entry, exists := h.limiters[host] + if !exists { + var limiter *rate.Limiter + if h.perHostConnectionsPerMinute > 0 { + requestLimit := rate.Every(time.Minute / time.Duration(h.perHostConnectionsPerMinute)) + limiter = rate.NewLimiter(requestLimit, h.perHostConnectionsPerMinute) + } + entry = &hostLimiter{ + limiter: limiter, + lastAccess: time.Now(), + } + h.limiters[host] = entry + } else { + entry.lastAccess = time.Now() + } + + return entry } -func (h *ConnectionRateLimiterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if h.connPerMinute == 0 { - h.handler.ServeHTTP(w, r) - return - } +func (h *RateLimitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } - h.mu.Lock() - seen, ok := h.ipmap[host] - if !ok { - h.ipmap[host] = 1 - h.mu.Unlock() - h.handler.ServeHTTP(w, r) + limits := h.getLimits(host) + if limits.limiter != nil && !limits.limiter.Allow() { + w.WriteHeader(http.StatusTooManyRequests) return } - // rate limited - if seen > h.connPerMinute { - h.mu.Unlock() - w.WriteHeader(http.StatusTooManyRequests) + + if h.perConnectionAPILimit != rate.Inf { + // new rate limiter for each connection, to throttle a single WebSockets connection; + // allow for a burst of MaxRateLimitTokens + apiLimiter := rate.NewLimiter(h.perConnectionAPILimit, MaxRateLimitTokens) + r = r.WithContext(context.WithValue(r.Context(), perConnectionAPIRateLimiterKey, apiLimiter)) + } + + h.next.ServeHTTP(w, r) +} + +func (h *RateLimitHandler) cleanupExpiredLimiters(ctx context.Context) { + if h.cleanupInterval == 0 { return } - h.ipmap[host] = seen + 1 - h.mu.Unlock() - go func() { + + for { select { - case <-time.After(time.Minute): + case <-ctx.Done(): + return + case <-time.After(h.cleanupInterval): h.mu.Lock() - defer h.mu.Unlock() - h.ipmap[host] = h.ipmap[host] - 1 - if h.ipmap[host] <= 0 { - delete(h.ipmap, host) + now := time.Now() + for host, entry := range h.limiters { + if now.Sub(entry.lastAccess) > h.expiryDuration { + delete(h.limiters, host) + } } + h.mu.Unlock() } - }() - h.handler.ServeHTTP(w, r) + } } -func limiterFromRateLimit(rateLimit int64) *rate.Limiter { - var limit rate.Limit - if rateLimit == 0 { - limit = rate.Inf - } else { - limit = rate.Every(time.Second / time.Duration(rateLimit)) +func (h *RateLimitHandler) Shutdown(ctx context.Context) error { + h.cancelFunc() + return shutdown(ctx, h.next) +} + +func shutdown(ctx context.Context, handler http.Handler) error { + if sh, ok := handler.(ShutdownHandler); ok { + return sh.Shutdown(ctx) } - return rate.NewLimiter(limit, stateRateLimitTokens) + return nil } diff --git a/gateway/handler_test.go b/gateway/handler_test.go new file mode 100644 index 00000000000..65e56836c0a --- /dev/null +++ b/gateway/handler_test.go @@ -0,0 +1,42 @@ +package gateway_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/filecoin-project/lotus/gateway" +) + +func TestRequestRateLimiterHandler(t *testing.T) { + var callCount int + h := gateway.NewRateLimitHandler( + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + callCount++ + }), + 0, // api rate + 2, // request rate (per minute) + 0, // cleanup interval + ) + + runRequest := func(host string, expectedStatus, expectedCallCount int) { + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = host + ":1234" + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + + require.Equal(t, expectedStatus, w.Code, "expected status %v, got %v", expectedStatus, w.Code) + require.Equal(t, expectedCallCount, callCount, "expected callCount to be %v, got %v", expectedCallCount, callCount) + } + + // Test that the handler allows up to 2 requests per minute per host. + runRequest("boop", http.StatusOK, 1) + runRequest("boop", http.StatusOK, 2) + runRequest("beep", http.StatusOK, 3) + runRequest("boop", http.StatusTooManyRequests, 3) + runRequest("beep", http.StatusOK, 4) + runRequest("boop", http.StatusTooManyRequests, 4) + runRequest("beep", http.StatusTooManyRequests, 4) +} diff --git a/gateway/node.go b/gateway/node.go index facf7f3f7b0..0b15918a96b 100644 --- a/gateway/node.go +++ b/gateway/node.go @@ -39,6 +39,9 @@ const ( walletRateLimitTokens = 1 chainRateLimitTokens = 2 stateRateLimitTokens = 3 + + // MaxRateLimitTokens is the number of tokens consumed for the most expensive types of operations + MaxRateLimitTokens = stateRateLimitTokens ) // TargetAPI defines the API methods that the Node depends on @@ -175,11 +178,17 @@ var ( ) // NewNode creates a new gateway node. -func NewNode(api TargetAPI, sHnd *EthSubHandler, lookbackCap time.Duration, stateWaitLookbackLimit abi.ChainEpoch, rateLimit int64, rateLimitTimeout time.Duration) *Node { - var limit rate.Limit - if rateLimit == 0 { - limit = rate.Inf - } else { +func NewNode( + api TargetAPI, + sHnd *EthSubHandler, + lookbackCap time.Duration, + stateWaitLookbackLimit abi.ChainEpoch, + rateLimit int64, + rateLimitTimeout time.Duration, +) *Node { + + limit := rate.Inf + if rateLimit > 0 { limit = rate.Every(time.Second / time.Duration(rateLimit)) } return &Node{ @@ -187,7 +196,7 @@ func NewNode(api TargetAPI, sHnd *EthSubHandler, lookbackCap time.Duration, stat subHnd: sHnd, lookbackCap: lookbackCap, stateWaitLookbackLimit: stateWaitLookbackLimit, - rateLimiter: rate.NewLimiter(limit, stateRateLimitTokens), + rateLimiter: rate.NewLimiter(limit, MaxRateLimitTokens), // allow for a burst of MaxRateLimitTokens rateLimitTimeout: rateLimitTimeout, errLookback: fmt.Errorf("lookbacks of more than %s are disallowed", lookbackCap), } @@ -238,7 +247,7 @@ func (gw *Node) checkTimestamp(at time.Time) error { func (gw *Node) limit(ctx context.Context, tokens int) error { ctx2, cancel := context.WithTimeout(ctx, gw.rateLimitTimeout) defer cancel() - if perConnLimiter, ok := ctx2.Value(perConnLimiterKey).(*rate.Limiter); ok { + if perConnLimiter, ok := ctx2.Value(perConnectionAPIRateLimiterKey).(*rate.Limiter); ok { err := perConnLimiter.WaitN(ctx2, tokens) if err != nil { return fmt.Errorf("connection limited. %w", err) diff --git a/gateway/node_test.go b/gateway/node_test.go index 3b801e19d84..0c60679d012 100644 --- a/gateway/node_test.go +++ b/gateway/node_test.go @@ -260,18 +260,32 @@ func TestGatewayLimitTokensAvailable(t *testing.T) { require.NoError(t, a.limit(ctx, tokens), "requests should not be limited when there are enough tokens available") } -func TestGatewayLimitTokensNotAvailable(t *testing.T) { +func TestGatewayLimitTokensRate(t *testing.T) { ctx := context.Background() mock := &mockGatewayDepsAPI{} tokens := 3 - a := NewNode(mock, nil, DefaultLookbackCap, DefaultStateWaitLookbackLimit, int64(1), time.Millisecond) - var err error - // try to be rate limited - for i := 0; i <= 1000; i++ { - err = a.limit(ctx, tokens) - if err != nil { - break - } + var rateLimit int64 = 200 + rateLimitTimeout := time.Second / time.Duration(rateLimit/3) // large enough to not be hit + a := NewNode(mock, nil, DefaultLookbackCap, DefaultStateWaitLookbackLimit, rateLimit, rateLimitTimeout) + + start := time.Now() + calls := 10 + for i := 0; i < calls; i++ { + require.NoError(t, a.limit(ctx, tokens)) } - require.Error(t, err, "requiests should be rate limited when they hit limits") + // We should be slowed down by the rate limit, but not hard limited because the timeout is + // large; the duration should be roughly the rate limit (per second) times the number of calls, + // with one extra free call because the first one can use up the burst tokens. We'll also add a + // couple more to account for slow test runs. + delayPerToken := time.Second / time.Duration(rateLimit) + expectedDuration := delayPerToken * time.Duration((calls-1)*tokens) + expectedEnd := start.Add(expectedDuration) + require.WithinDuration(t, expectedEnd, time.Now(), delayPerToken*time.Duration(2*tokens), "API calls should be rate limited when they hit limits") + + // In this case our timeout is too short to allow for the rate limit, so we should hit the + // hard rate limit. + rateLimitTimeout = time.Second / time.Duration(rateLimit) + a = NewNode(mock, nil, DefaultLookbackCap, DefaultStateWaitLookbackLimit, rateLimit, rateLimitTimeout) + require.NoError(t, a.limit(ctx, tokens)) + require.ErrorContains(t, a.limit(ctx, tokens), "server busy", "API calls should be hard rate limited when they hit limits") } diff --git a/itests/gateway_test.go b/itests/gateway_test.go index b994d6de3c8..b223fbdc7ab 100644 --- a/itests/gateway_test.go +++ b/itests/gateway_test.go @@ -4,9 +4,12 @@ package itests import ( "bytes" "context" + "encoding/json" "fmt" + "io" "math" "net" + "net/http" "testing" "time" @@ -46,9 +49,8 @@ func TestGatewayWalletMsig(t *testing.T) { //stm: @CHAIN_INCOMING_HANDLE_INCOMING_BLOCKS_001, @CHAIN_INCOMING_VALIDATE_BLOCK_PUBSUB_001, @CHAIN_INCOMING_VALIDATE_MESSAGE_PUBSUB_001 kit.QuietMiningLogs() - blocktime := 5 * time.Millisecond ctx := context.Background() - nodes := startNodes(ctx, t, blocktime, maxLookbackCap, maxStateWaitLookbackLimit) + nodes := startNodes(ctx, t) lite := nodes.lite full := nodes.full @@ -185,51 +187,72 @@ func TestGatewayMsigCLI(t *testing.T) { //stm: @CHAIN_SYNCER_NEW_PEER_HEAD_001, @CHAIN_SYNCER_VALIDATE_MESSAGE_META_001, @CHAIN_SYNCER_STOP_001 kit.QuietMiningLogs() - blocktime := 5 * time.Millisecond ctx := context.Background() - nodes := startNodesWithFunds(ctx, t, blocktime, maxLookbackCap, maxStateWaitLookbackLimit) + nodes := startNodes(ctx, t, withFunds()) lite := nodes.lite multisig.RunMultisigTests(t, lite) } type testNodes struct { - lite *kit.TestFullNode - full *kit.TestFullNode - miner *kit.TestMiner + lite *kit.TestFullNode + full *kit.TestFullNode + miner *kit.TestMiner + gatewayAddr string } -func startNodesWithFunds( - ctx context.Context, - t *testing.T, - blocktime time.Duration, - lookbackCap time.Duration, - stateWaitLookbackLimit abi.ChainEpoch, -) *testNodes { - nodes := startNodes(ctx, t, blocktime, lookbackCap, stateWaitLookbackLimit) +type startOptions struct { + blocktime time.Duration + lookbackCap time.Duration + stateWaitLookbackLimit abi.ChainEpoch + fund bool + perConnectionAPIRateLimit int + perHostRequestsPerMinute int + nodeOpts []kit.NodeOpt +} - // The full node starts with a wallet - fullWalletAddr, err := nodes.full.WalletDefaultAddress(ctx) - require.NoError(t, err) +type startOption func(*startOptions) - // Get the lite node default wallet address. - liteWalletAddr, err := nodes.lite.WalletDefaultAddress(ctx) - require.NoError(t, err) +func applyStartOptions(opts ...startOption) startOptions { + o := startOptions{ + blocktime: 5 * time.Millisecond, + lookbackCap: maxLookbackCap, + stateWaitLookbackLimit: maxStateWaitLookbackLimit, + fund: false, + } + for _, opt := range opts { + opt(&o) + } + return o +} - // Send some funds from the full node to the lite node - err = sendFunds(ctx, nodes.full, fullWalletAddr, liteWalletAddr, types.NewInt(1e18)) - require.NoError(t, err) +func withFunds() startOption { + return func(opts *startOptions) { + opts.fund = true + } +} - return nodes +func withPerConnectionAPIRateLimit(rateLimit int) startOption { + return func(opts *startOptions) { + opts.perConnectionAPIRateLimit = rateLimit + } +} + +func withPerHostRequestsPerMinute(rateLimit int) startOption { + return func(opts *startOptions) { + opts.perHostRequestsPerMinute = rateLimit + } } -func startNodes( - ctx context.Context, - t *testing.T, - blocktime time.Duration, - lookbackCap time.Duration, - stateWaitLookbackLimit abi.ChainEpoch, -) *testNodes { +func withNodeOpts(nodeOpts ...kit.NodeOpt) startOption { + return func(opts *startOptions) { + opts.nodeOpts = nodeOpts + } +} + +func startNodes(ctx context.Context, t *testing.T, opts ...startOption) *testNodes { + options := applyStartOptions(opts...) + var closer jsonrpc.ClientCloser var ( @@ -246,11 +269,13 @@ func startNodes( // create the full node and the miner. var ens *kit.Ensemble full, miner, ens = kit.EnsembleMinimal(t, kit.MockProofs()) - ens.InterconnectAll().BeginMining(blocktime) + ens.InterconnectAll().BeginMining(options.blocktime) + api.RunningNodeType = api.NodeFull // Create a gateway server in front of the full node - gwapi := gateway.NewNode(full, nil, lookbackCap, stateWaitLookbackLimit, 0, time.Minute) - handler, err := gateway.Handler(gwapi, full, 0, 0) + gwapi := gateway.NewNode(full, nil, options.lookbackCap, options.stateWaitLookbackLimit, 0, time.Minute) + handler, err := gateway.Handler(gwapi, full, options.perConnectionAPIRateLimit, options.perHostRequestsPerMinute) + t.Cleanup(func() { _ = handler.Shutdown(ctx) }) require.NoError(t, err) l, err := net.Listen("tcp", "127.0.0.1:0") @@ -264,15 +289,37 @@ func startNodes( require.NoError(t, err) t.Cleanup(closer) - ens.FullNode(&lite, + nodeOpts := append([]kit.NodeOpt{ kit.LiteNode(), kit.ThroughRPC(), kit.ConstructorOpts( node.Override(new(api.Gateway), gapi), ), - ).Start().InterconnectAll() + }, options.nodeOpts...) + ens.FullNode(&lite, nodeOpts...).Start().InterconnectAll() + + nodes := &testNodes{ + lite: &lite, + full: full, + miner: miner, + gatewayAddr: srv.Listener.Addr().String(), + } + + if options.fund { + // The full node starts with a wallet + fullWalletAddr, err := nodes.full.WalletDefaultAddress(ctx) + require.NoError(t, err) + + // Get the lite node default wallet address. + liteWalletAddr, err := nodes.lite.WalletDefaultAddress(ctx) + require.NoError(t, err) - return &testNodes{lite: &lite, full: full, miner: miner} + // Send some funds from the full node to the lite node + err = sendFunds(ctx, nodes.full, fullWalletAddr, liteWalletAddr, types.NewInt(1e18)) + require.NoError(t, err) + } + + return nodes } func sendFunds(ctx context.Context, fromNode *kit.TestFullNode, fromAddr address.Address, toAddr address.Address, amt types.BigInt) error { @@ -297,3 +344,68 @@ func sendFunds(ctx context.Context, fromNode *kit.TestFullNode, fromAddr address return nil } + +func TestGatewayRateLimits(t *testing.T) { + req := require.New(t) + + kit.QuietMiningLogs() + ctx := context.Background() + tokensPerSecond := 10 + requestsPerMinute := 30 // http requests + nodes := startNodes(ctx, t, + withNodeOpts(kit.DisableEthRPC()), + withPerConnectionAPIRateLimit(tokensPerSecond), + withPerHostRequestsPerMinute(requestsPerMinute), + ) + + time.Sleep(time.Second) + + // ChainHead uses chainRateLimitTokens=2. + // But we're also competing with the paymentChannelSettler which listens to the chain uses + // ChainGetBlockMessages on each change, which also uses chainRateLimitTokens=2. + // So each loop should be 4 tokens. + loops := 10 + tokensPerLoop := 4 + start := time.Now() + for i := 0; i < loops; i++ { + _, err := nodes.lite.ChainHead(ctx) + req.NoError(err) + } + tokensUsed := loops * tokensPerLoop + expectedEnd := start.Add(time.Duration(float64(tokensUsed) / float64(tokensPerSecond) * float64(time.Second))) + allowPad := time.Duration(float64(tokensPerLoop) / float64(tokensPerSecond) * float64(time.Second)) // add padding to account for slow test runs + t.Logf("expected end: %s, now: %s, allowPad: %s, actual delta: %s", expectedEnd, time.Now(), allowPad, time.Since(expectedEnd)) + req.WithinDuration(expectedEnd, time.Now(), allowPad) + + client := &http.Client{} + url := fmt.Sprintf("http://%s/rpc/v1", nodes.gatewayAddr) + jsonPayload := []byte(`{"method":"Filecoin.ChainHead","params":[],"id":1,"jsonrpc":"2.0"}`) + var failed bool + for i := 0; i < requestsPerMinute*2 && !failed; i++ { + func() { + request, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload)) + req.NoError(err) + request.Header.Set("Content-Type", "application/json") + response, err := client.Do(request) + req.NoError(err) + defer func() { _ = response.Body.Close() }() + req.NoError(err) + if http.StatusOK == response.StatusCode { + body, err := io.ReadAll(response.Body) + req.NoError(err) + result := map[string]interface{}{} + req.NoError(json.Unmarshal(body, &result)) + req.NoError(err) + req.NotNil(result["result"]) + height, ok := result["result"].(map[string]interface{})["Height"].(float64) + req.True(ok) + req.Greater(int(height), 0) + } else { + req.Equal(http.StatusTooManyRequests, response.StatusCode) + req.LessOrEqual(i, requestsPerMinute+1) + failed = true + } + }() + } + req.True(failed, "expected requests to fail due to rate limiting") +}