diff --git a/peer/client.go b/peer/client.go index 0af5814eb0..41ebc7d6d9 100644 --- a/peer/client.go +++ b/peer/client.go @@ -66,16 +66,8 @@ func (c *client) SendAppRequestAny(ctx context.Context, minVersion *version.Appl if err != nil { return nil, nodeID, err } - - select { - case <-ctx.Done(): - return nil, nodeID, ctx.Err() - case response := <-waitingHandler.responseChan: - if waitingHandler.failed { - return nil, nodeID, ErrRequestFailed - } - return response, nodeID, nil - } + response, err := waitingHandler.WaitForResult(ctx) + return response, nodeID, err } // SendAppRequest synchronously sends request to the specified nodeID @@ -85,16 +77,7 @@ func (c *client) SendAppRequest(ctx context.Context, nodeID ids.NodeID, request if err := c.network.SendAppRequest(ctx, nodeID, request, waitingHandler); err != nil { return nil, err } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case response := <-waitingHandler.responseChan: - if waitingHandler.failed { - return nil, ErrRequestFailed - } - return response, nil - } + return waitingHandler.WaitForResult(ctx) } // SendCrossChainRequest synchronously sends request to the specified chainID @@ -104,15 +87,7 @@ func (c *client) SendCrossChainRequest(ctx context.Context, chainID ids.ID, requ if err := c.network.SendCrossChainRequest(ctx, chainID, request, waitingHandler); err != nil { return nil, err } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case response := <-waitingHandler.responseChan: - if waitingHandler.failed { - return nil, ErrRequestFailed - } - return response, nil - } + return waitingHandler.WaitForResult(ctx) } func (c *client) Gossip(gossip []byte) error { diff --git a/peer/waiting_handler.go b/peer/waiting_handler.go index e6a7d9fd87..5e020e9861 100644 --- a/peer/waiting_handler.go +++ b/peer/waiting_handler.go @@ -4,6 +4,8 @@ package peer import ( + "context" + "github.com/ava-labs/subnet-evm/plugin/evm/message" ) @@ -18,6 +20,16 @@ type waitingResponseHandler struct { failed bool // whether the original request is failed } +// newWaitingResponseHandler returns new instance of the waitingResponseHandler +func newWaitingResponseHandler() *waitingResponseHandler { + return &waitingResponseHandler{ + // Make buffer length 1 so that OnResponse can complete + // even if no goroutine is waiting on the channel (i.e. + // the context of a request is cancelled.) + responseChan: make(chan []byte, 1), + } +} + // OnResponse passes the response bytes to the responseChan and closes the channel func (w *waitingResponseHandler) OnResponse(response []byte) error { w.responseChan <- response @@ -32,12 +44,14 @@ func (w *waitingResponseHandler) OnFailure() error { return nil } -// newWaitingResponseHandler returns new instance of the waitingResponseHandler -func newWaitingResponseHandler() *waitingResponseHandler { - return &waitingResponseHandler{ - // Make buffer length 1 so that OnResponse can complete - // even if no goroutine is waiting on the channel (i.e. - // the context of a request is cancelled.) - responseChan: make(chan []byte, 1), +func (waitingHandler *waitingResponseHandler) WaitForResult(ctx context.Context) ([]byte, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-waitingHandler.responseChan: + if waitingHandler.failed { + return nil, ErrRequestFailed + } + return response, nil } }