From 61be7dd9130af41517ad7b455f51f032400b0641 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 5 Jul 2023 12:25:25 +0300 Subject: [PATCH 01/35] Added websocket handler, common http handler, refactored rest, added subscribe_events route --- .../node_builder/access_node_builder.go | 77 +++++++---- cmd/observer/node_builder/observer_builder.go | 3 + engine/access/rest/events.go | 3 +- engine/access/rest/handler.go | 112 ++------------- engine/access/rest/http_handler.go | 130 ++++++++++++++++++ engine/access/rest/request/event_type.go | 58 ++++++++ engine/access/rest/request/get_events.go | 17 +-- engine/access/rest/request/request.go | 6 + .../access/rest/request/subscribe_events.go | 57 ++++++++ engine/access/rest/router.go | 34 ++++- engine/access/rest/server.go | 13 +- engine/access/rest/subscribe_events.go | 92 +++++++++++++ engine/access/rest/test_helpers.go | 10 +- engine/access/rest/websocket_handler.go | 74 ++++++++++ engine/access/rpc/engine.go | 17 ++- engine/access/state_stream/engine.go | 75 ++-------- engine/access/state_stream/handler.go | 56 +++----- engine/access/state_stream/mock/api.go | 19 ++- .../state_stream/backend.go | 34 ++++- .../state_stream/backend_events.go | 0 .../state_stream/backend_events_test.go | 4 +- .../state_stream/backend_executiondata.go | 0 .../backend_executiondata_test.go | 6 +- .../{access => common}/state_stream/event.go | 0 .../state_stream/event_test.go | 2 +- .../{access => common}/state_stream/filter.go | 0 .../state_stream/filter_test.go | 2 +- .../state_stream/streamer.go | 0 .../state_stream/streamer_test.go | 3 +- .../common/state_stream/subscribe_handler.go | 43 ++++++ .../state_stream/subscription.go | 0 .../state_stream/subscription_test.go | 3 +- go.mod | 2 +- integration/localnet/builder/bootstrap.go | 4 +- integration/testnet/network.go | 2 +- .../execution_data_requester_test.go | 3 +- .../jobs/execution_data_reader_test.go | 2 +- 37 files changed, 690 insertions(+), 273 deletions(-) create mode 100644 engine/access/rest/http_handler.go create mode 100644 engine/access/rest/request/event_type.go create mode 100644 engine/access/rest/request/subscribe_events.go create mode 100644 engine/access/rest/subscribe_events.go create mode 100644 engine/access/rest/websocket_handler.go rename engine/{access => common}/state_stream/backend.go (85%) rename engine/{access => common}/state_stream/backend_events.go (100%) rename engine/{access => common}/state_stream/backend_events_test.go (98%) rename engine/{access => common}/state_stream/backend_executiondata.go (100%) rename engine/{access => common}/state_stream/backend_executiondata_test.go (98%) rename engine/{access => common}/state_stream/event.go (100%) rename engine/{access => common}/state_stream/event_test.go (97%) rename engine/{access => common}/state_stream/filter.go (100%) rename engine/{access => common}/state_stream/filter_test.go (98%) rename engine/{access => common}/state_stream/streamer.go (100%) rename engine/{access => common}/state_stream/streamer_test.go (98%) create mode 100644 engine/common/state_stream/subscribe_handler.go rename engine/{access => common}/state_stream/subscription.go (100%) rename engine/{access => common}/state_stream/subscription_test.go (98%) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index 84aff969161..a8d8cc8d31d 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -35,6 +35,7 @@ import ( "github.com/onflow/flow-go/consensus/hotstuff/verification" recovery "github.com/onflow/flow-go/consensus/recovery/protocol" "github.com/onflow/flow-go/crypto" + "github.com/onflow/flow-go/engine" "github.com/onflow/flow-go/engine/access/ingestion" pingeng "github.com/onflow/flow-go/engine/access/ping" "github.com/onflow/flow-go/engine/access/rpc" @@ -42,6 +43,7 @@ import ( "github.com/onflow/flow-go/engine/access/state_stream" followereng "github.com/onflow/flow-go/engine/common/follower" "github.com/onflow/flow-go/engine/common/requester" + common_state_stream "github.com/onflow/flow-go/engine/common/state_stream" synceng "github.com/onflow/flow-go/engine/common/synchronization" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/model/flow/filter" @@ -115,7 +117,8 @@ type AccessNodeConfig struct { apiRatelimits map[string]int apiBurstlimits map[string]int rpcConf rpc.Config - stateStreamConf state_stream.Config + stateStreamConf common_state_stream.Config + stateStreamBackend common_state_stream.API stateStreamFilterConf map[string]int ExecutionNodeAddress string // deprecated HistoricalAccessRPCs []access.AccessAPIClient @@ -161,15 +164,16 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { ArchiveAddressList: nil, MaxMsgSize: grpcutils.DefaultMaxMsgSize, }, - stateStreamConf: state_stream.Config{ + stateStreamConf: common_state_stream.Config{ MaxExecutionDataMsgSize: grpcutils.DefaultMaxMsgSize, - ExecutionDataCacheSize: state_stream.DefaultCacheSize, - ClientSendTimeout: state_stream.DefaultSendTimeout, - ClientSendBufferSize: state_stream.DefaultSendBufferSize, - MaxGlobalStreams: state_stream.DefaultMaxGlobalStreams, - EventFilterConfig: state_stream.DefaultEventFilterConfig, - ResponseLimit: state_stream.DefaultResponseLimit, + ExecutionDataCacheSize: common_state_stream.DefaultCacheSize, + ClientSendTimeout: common_state_stream.DefaultSendTimeout, + ClientSendBufferSize: common_state_stream.DefaultSendBufferSize, + MaxGlobalStreams: common_state_stream.DefaultMaxGlobalStreams, + EventFilterConfig: common_state_stream.DefaultEventFilterConfig, + ResponseLimit: common_state_stream.DefaultResponseLimit, }, + stateStreamBackend: nil, stateStreamFilterConf: nil, ExecutionNodeAddress: "localhost:9000", logTxTimeToFinalized: false, @@ -600,20 +604,35 @@ func (builder *FlowAccessNodeBuilder) BuildExecutionDataRequester() *FlowAccessN return nil, fmt.Errorf("could not get highest consecutive height: %w", err) } - stateStreamEng, err := state_stream.NewEng( + broadcaster := engine.NewBroadcaster() + + backend, err := common_state_stream.New( node.Logger, builder.stateStreamConf, - builder.ExecutionDataStore, - executionDataCache, node.State, node.Storage.Headers, node.Storage.Seals, node.Storage.Results, - node.RootChainID, + builder.ExecutionDataStore, + executionDataCache, + broadcaster, builder.executionDataConfig.InitialBlockHeight, highestAvailableHeight, + ) + if err != nil { + return nil, fmt.Errorf("could not create state stream backend: %w", err) + } + + stateStreamEng, err := state_stream.NewEng( + node.Logger, + builder.stateStreamConf, + executionDataCache, + node.Storage.Headers, + node.RootChainID, builder.apiRatelimits, builder.apiBurstlimits, + backend, + broadcaster, ) if err != nil { return nil, fmt.Errorf("could not create state stream engine: %w", err) @@ -656,7 +675,7 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { flags.UintVar(&builder.executionGRPCPort, "execution-ingress-port", defaultConfig.executionGRPCPort, "the grpc ingress port for all execution nodes") flags.StringVarP(&builder.rpcConf.UnsecureGRPCListenAddr, "rpc-addr", "r", defaultConfig.rpcConf.UnsecureGRPCListenAddr, "the address the unsecured gRPC server listens on") flags.StringVar(&builder.rpcConf.SecureGRPCListenAddr, "secure-rpc-addr", defaultConfig.rpcConf.SecureGRPCListenAddr, "the address the secure gRPC server listens on") - flags.StringVar(&builder.stateStreamConf.ListenAddr, "state-stream-addr", defaultConfig.stateStreamConf.ListenAddr, "the address the state stream server listens on (if empty the server will not be started)") + flags.StringVar(&builder.stateStreamConf.ListenAddr, "state_stream-addr", defaultConfig.stateStreamConf.ListenAddr, "the address the state stream server listens on (if empty the server will not be started)") flags.StringVarP(&builder.rpcConf.HTTPListenAddr, "http-addr", "h", defaultConfig.rpcConf.HTTPListenAddr, "the address the http proxy server listens on") flags.StringVar(&builder.rpcConf.RESTListenAddr, "rest-addr", defaultConfig.rpcConf.RESTListenAddr, "the address the REST server listens on (if empty the REST server will not be started)") flags.StringVarP(&builder.rpcConf.CollectionAddr, "static-collection-ingress-addr", "", defaultConfig.rpcConf.CollectionAddr, "the address (of the collection node) to send transactions to") @@ -694,12 +713,12 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { // Execution State Streaming API flags.Uint32Var(&builder.stateStreamConf.ExecutionDataCacheSize, "execution-data-cache-size", defaultConfig.stateStreamConf.ExecutionDataCacheSize, "block execution data cache size") - flags.Uint32Var(&builder.stateStreamConf.MaxGlobalStreams, "state-stream-global-max-streams", defaultConfig.stateStreamConf.MaxGlobalStreams, "global maximum number of concurrent streams") - flags.UintVar(&builder.stateStreamConf.MaxExecutionDataMsgSize, "state-stream-max-message-size", defaultConfig.stateStreamConf.MaxExecutionDataMsgSize, "maximum size for a gRPC message containing block execution data") - flags.StringToIntVar(&builder.stateStreamFilterConf, "state-stream-event-filter-limits", defaultConfig.stateStreamFilterConf, "event filter limits for ExecutionData SubscribeEvents API e.g. EventTypes=100,Addresses=100,Contracts=100 etc.") - flags.DurationVar(&builder.stateStreamConf.ClientSendTimeout, "state-stream-send-timeout", defaultConfig.stateStreamConf.ClientSendTimeout, "maximum wait before timing out while sending a response to a streaming client e.g. 30s") - flags.UintVar(&builder.stateStreamConf.ClientSendBufferSize, "state-stream-send-buffer-size", defaultConfig.stateStreamConf.ClientSendBufferSize, "maximum number of responses to buffer within a stream") - flags.Float64Var(&builder.stateStreamConf.ResponseLimit, "state-stream-response-limit", defaultConfig.stateStreamConf.ResponseLimit, "max number of responses per second to send over streaming endpoints. this helps manage resources consumed by each client querying data not in the cache e.g. 3 or 0.5. 0 means no limit") + flags.Uint32Var(&builder.stateStreamConf.MaxGlobalStreams, "state_stream-global-max-streams", defaultConfig.stateStreamConf.MaxGlobalStreams, "global maximum number of concurrent streams") + flags.UintVar(&builder.stateStreamConf.MaxExecutionDataMsgSize, "state_stream-max-message-size", defaultConfig.stateStreamConf.MaxExecutionDataMsgSize, "maximum size for a gRPC message containing block execution data") + flags.StringToIntVar(&builder.stateStreamFilterConf, "state_stream-event-filter-limits", defaultConfig.stateStreamFilterConf, "event filter limits for ExecutionData SubscribeEvents API e.g. EventTypes=100,Addresses=100,Contracts=100 etc.") + flags.DurationVar(&builder.stateStreamConf.ClientSendTimeout, "state_stream-send-timeout", defaultConfig.stateStreamConf.ClientSendTimeout, "maximum wait before timing out while sending a response to a streaming client e.g. 30s") + flags.UintVar(&builder.stateStreamConf.ClientSendBufferSize, "state_stream-send-buffer-size", defaultConfig.stateStreamConf.ClientSendBufferSize, "maximum number of responses to buffer within a stream") + flags.Float64Var(&builder.stateStreamConf.ResponseLimit, "state_stream-response-limit", defaultConfig.stateStreamConf.ResponseLimit, "max number of responses per second to send over streaming endpoints. this helps manage resources consumed by each client querying data not in the cache e.g. 3 or 0.5. 0 means no limit") }).ValidateFlags(func() error { if builder.supportsObserver && (builder.PublicNetworkConfig.BindAddress == cmd.NotSet || builder.PublicNetworkConfig.BindAddress == "") { return errors.New("public-network-address must be set if supports-observer is true") @@ -726,23 +745,23 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { return errors.New("execution-data-cache-size must be greater than 0") } if builder.stateStreamConf.ClientSendBufferSize == 0 { - return errors.New("state-stream-send-buffer-size must be greater than 0") + return errors.New("state_stream-send-buffer-size must be greater than 0") } if len(builder.stateStreamFilterConf) > 3 { - return errors.New("state-stream-event-filter-limits must have at most 3 keys (EventTypes, Addresses, Contracts)") + return errors.New("state_stream-event-filter-limits must have at most 3 keys (EventTypes, Addresses, Contracts)") } for key, value := range builder.stateStreamFilterConf { switch key { case "EventTypes", "Addresses", "Contracts": if value <= 0 { - return fmt.Errorf("state-stream-event-filter-limits %s must be greater than 0", key) + return fmt.Errorf("state_stream-event-filter-limits %s must be greater than 0", key) } default: - return errors.New("state-stream-event-filter-limits may only contain the keys EventTypes, Addresses, Contracts") + return errors.New("state_stream-event-filter-limits may only contain the keys EventTypes, Addresses, Contracts") } } if builder.stateStreamConf.ResponseLimit < 0 { - return errors.New("state-stream-response-limit must be greater than or equal to 0") + return errors.New("state_stream-response-limit must be greater than or equal to 0") } } @@ -892,6 +911,10 @@ func (builder *FlowAccessNodeBuilder) enqueueRelayNetwork() { } func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { + if builder.executionDataSyncEnabled { + builder.BuildExecutionDataRequester() + } + builder. BuildConsensusFollower(). Module("collection node client", func(node *cmd.NodeConfig) error { @@ -1008,6 +1031,8 @@ func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { builder.apiRatelimits, builder.apiBurstlimits, builder.Me, + builder.stateStreamBackend, + builder.stateStreamConf, ) if err != nil { return nil, err @@ -1093,10 +1118,6 @@ func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { }) } - if builder.executionDataSyncEnabled { - builder.BuildExecutionDataRequester() - } - builder.Component("ping engine", func(node *cmd.NodeConfig) (module.ReadyDoneAware, error) { ping, err := pingeng.New( node.Logger, diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index c0a6d62b2b4..41582e9f24a 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -31,6 +31,7 @@ import ( "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" "github.com/onflow/flow-go/engine/common/follower" + "github.com/onflow/flow-go/engine/common/state_stream" synceng "github.com/onflow/flow-go/engine/common/synchronization" "github.com/onflow/flow-go/engine/protocol" "github.com/onflow/flow-go/model/encodable" @@ -862,6 +863,8 @@ func (builder *ObserverServiceBuilder) enqueueRPCServer() { builder.apiRatelimits, builder.apiBurstlimits, builder.Me, + nil, + state_stream.Config{}, ) if err != nil { return nil, err diff --git a/engine/access/rest/events.go b/engine/access/rest/events.go index 2a79939bc21..44ad933affa 100644 --- a/engine/access/rest/events.go +++ b/engine/access/rest/events.go @@ -3,10 +3,9 @@ package rest import ( "fmt" + "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/request" - - "github.com/onflow/flow-go/access" ) const blockQueryParam = "block_ids" diff --git a/engine/access/rest/handler.go b/engine/access/rest/handler.go index 028176fc9e0..a4c6e7569cf 100644 --- a/engine/access/rest/handler.go +++ b/engine/access/rest/handler.go @@ -1,26 +1,17 @@ package rest import ( - "encoding/json" - "errors" - "fmt" "net/http" + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/access/rest/util" - fvmErrors "github.com/onflow/flow-go/fvm/errors" "github.com/onflow/flow-go/model/flow" - - "github.com/rs/zerolog" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/onflow/flow-go/access" ) -const MaxRequestSize = 2 << 20 // 2MB - // ApiHandlerFunc is a function that contains endpoint handling logic, // it fetches necessary resources and returns an error or response model. type ApiHandlerFunc func( @@ -33,11 +24,10 @@ type ApiHandlerFunc func( // Handler function allows easier handling of errors and responses as it // wraps functionality for handling error and responses outside of endpoint handling. type Handler struct { - logger zerolog.Logger + *HttpHandler backend access.API linkGenerator models.LinkGenerator apiHandlerFunc ApiHandlerFunc - chain flow.Chain } func NewHandler( @@ -47,31 +37,26 @@ func NewHandler( generator models.LinkGenerator, chain flow.Chain, ) *Handler { - return &Handler{ - logger: logger, + handler := &Handler{ backend: backend, apiHandlerFunc: handlerFunc, linkGenerator: generator, - chain: chain, } + handler.HttpHandler = NewHttpHandler(logger, chain) + return handler } // ServerHTTP function acts as a wrapper to each request providing common handling functionality // such as logging, error handling, request decorators func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // create a logger - errLog := h.logger.With().Str("request_url", r.URL.String()).Logger() + errLog := h.Logger.With().Str("request_url", r.URL.String()).Logger() - // limit requested body size - r.Body = http.MaxBytesReader(w, r.Body, MaxRequestSize) - err := r.ParseForm() + err := h.VerifyRequest(w, r) if err != nil { - h.errorHandler(w, err, errLog) return } - - // create request decorator with parsed values - decoratedRequest := request.Decorate(r, h.chain) + decoratedRequest := request.Decorate(r, h.Chain) // execute handler function and check for error response, err := h.apiHandlerFunc(decoratedRequest, h.backend, h.linkGenerator) @@ -90,80 +75,3 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // write response to response stream h.jsonResponse(w, http.StatusOK, response, errLog) } - -func (h *Handler) errorHandler(w http.ResponseWriter, err error, errorLogger zerolog.Logger) { - // rest status type error should be returned with status and user message provided - var statusErr StatusError - if errors.As(err, &statusErr) { - h.errorResponse(w, statusErr.Status(), statusErr.UserMessage(), errorLogger) - return - } - - // handle cadence errors - cadenceError := fvmErrors.Find(err, fvmErrors.ErrCodeCadenceRunTimeError) - if cadenceError != nil { - msg := fmt.Sprintf("Cadence error: %s", cadenceError.Error()) - h.errorResponse(w, http.StatusBadRequest, msg, errorLogger) - return - } - - // handle grpc status error returned from the backend calls, we are forwarding the message to the client - if se, ok := status.FromError(err); ok { - if se.Code() == codes.NotFound { - msg := fmt.Sprintf("Flow resource not found: %s", se.Message()) - h.errorResponse(w, http.StatusNotFound, msg, errorLogger) - return - } - if se.Code() == codes.InvalidArgument { - msg := fmt.Sprintf("Invalid Flow argument: %s", se.Message()) - h.errorResponse(w, http.StatusBadRequest, msg, errorLogger) - return - } - if se.Code() == codes.Internal { - msg := fmt.Sprintf("Invalid Flow request: %s", se.Message()) - h.errorResponse(w, http.StatusBadRequest, msg, errorLogger) - return - } - } - - // stop going further - catch all error - msg := "internal server error" - errorLogger.Error().Err(err).Msg(msg) - h.errorResponse(w, http.StatusInternalServerError, msg, errorLogger) -} - -// jsonResponse builds a JSON response and send it to the client -func (h *Handler) jsonResponse(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger) { - w.Header().Set("Content-Type", "application/json; charset=UTF-8") - - // serialize response to JSON and handler errors - encodedResponse, err := json.MarshalIndent(response, "", "\t") - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - errLogger.Error().Err(err).Str("response", string(encodedResponse)).Msg("failed to indent response") - return - } - - w.WriteHeader(code) - // write response to response stream - _, err = w.Write(encodedResponse) - if err != nil { - errLogger.Error().Err(err).Str("response", string(encodedResponse)).Msg("failed to write http response") - } -} - -// errorResponse sends an HTTP error response to the client with the given return code -// and a model error with the given response message in the response body -func (h *Handler) errorResponse( - w http.ResponseWriter, - returnCode int, - responseMessage string, - logger zerolog.Logger, -) { - // create error response model - modelError := models.ModelError{ - Code: int32(returnCode), - Message: responseMessage, - } - h.jsonResponse(w, returnCode, modelError, logger) -} diff --git a/engine/access/rest/http_handler.go b/engine/access/rest/http_handler.go new file mode 100644 index 00000000000..a9fb0a79b94 --- /dev/null +++ b/engine/access/rest/http_handler.go @@ -0,0 +1,130 @@ +package rest + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/rs/zerolog" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/onflow/flow-go/engine/access/rest/models" + fvmErrors "github.com/onflow/flow-go/fvm/errors" + "github.com/onflow/flow-go/model/flow" +) + +const MaxRequestSize = 2 << 20 // 2MB + +// HttpHandler is custom http handler implementing custom handler function. +// HttpHandler function allows easier handling of errors and responses as it +// wraps functionality for handling error and responses outside of endpoint handling. +type HttpHandler struct { + Logger zerolog.Logger + Chain flow.Chain +} + +func NewHttpHandler( + logger zerolog.Logger, + chain flow.Chain, +) *HttpHandler { + return &HttpHandler{ + Logger: logger, + Chain: chain, + } +} + +// VerifyRequest function acts as a wrapper to each request providing common handling functionality +// such as logging, error handling +func (h *HttpHandler) VerifyRequest(w http.ResponseWriter, r *http.Request) error { + // create a logger + errLog := h.Logger.With().Str("request_url", r.URL.String()).Logger() + + // limit requested body size + r.Body = http.MaxBytesReader(w, r.Body, MaxRequestSize) + err := r.ParseForm() + if err != nil { + h.errorHandler(w, err, errLog) + return err + } + return nil +} + +func (h *HttpHandler) errorHandler(w http.ResponseWriter, err error, errorLogger zerolog.Logger) { + // rest status type error should be returned with status and user message provided + var statusErr StatusError + if errors.As(err, &statusErr) { + h.errorResponse(w, statusErr.Status(), statusErr.UserMessage(), errorLogger) + return + } + + // handle cadence errors + cadenceError := fvmErrors.Find(err, fvmErrors.ErrCodeCadenceRunTimeError) + if cadenceError != nil { + msg := fmt.Sprintf("Cadence error: %s", cadenceError.Error()) + h.errorResponse(w, http.StatusBadRequest, msg, errorLogger) + return + } + + // handle grpc status error returned from the backend calls, we are forwarding the message to the client + if se, ok := status.FromError(err); ok { + if se.Code() == codes.NotFound { + msg := fmt.Sprintf("Flow resource not found: %s", se.Message()) + h.errorResponse(w, http.StatusNotFound, msg, errorLogger) + return + } + if se.Code() == codes.InvalidArgument { + msg := fmt.Sprintf("Invalid Flow argument: %s", se.Message()) + h.errorResponse(w, http.StatusBadRequest, msg, errorLogger) + return + } + if se.Code() == codes.Internal { + msg := fmt.Sprintf("Invalid Flow request: %s", se.Message()) + h.errorResponse(w, http.StatusBadRequest, msg, errorLogger) + return + } + } + + // stop going further - catch all error + msg := "internal server error" + errorLogger.Error().Err(err).Msg(msg) + h.errorResponse(w, http.StatusInternalServerError, msg, errorLogger) +} + +// jsonResponse builds a JSON response and send it to the client +func (h *HttpHandler) jsonResponse(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + + // serialize response to JSON and handler errors + encodedResponse, err := json.MarshalIndent(response, "", "\t") + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + errLogger.Error().Err(err).Str("response", string(encodedResponse)).Msg("failed to indent response") + return + } + + w.WriteHeader(code) + // write response to response stream + _, err = w.Write(encodedResponse) + if err != nil { + errLogger.Error().Err(err).Str("response", string(encodedResponse)).Msg("failed to write http response") + } +} + +// errorResponse sends an HTTP error response to the client with the given return code +// and a model error with the given response message in the response body +func (h *HttpHandler) errorResponse( + w http.ResponseWriter, + returnCode int, + responseMessage string, + logger zerolog.Logger, +) { + // create error response model + modelError := models.ModelError{ + Code: int32(returnCode), + Message: responseMessage, + } + h.jsonResponse(w, returnCode, modelError, logger) +} diff --git a/engine/access/rest/request/event_type.go b/engine/access/rest/request/event_type.go new file mode 100644 index 00000000000..b1c99978a86 --- /dev/null +++ b/engine/access/rest/request/event_type.go @@ -0,0 +1,58 @@ +package request + +import ( + "fmt" + "regexp" +) + +type EventType string + +func (e *EventType) Parse(raw string) error { + basic, _ := regexp.MatchString(`[A-Z]\.[a-f0-9]{16}\.[\w+]*\.[\w+]*`, raw) + // match core events flow.event + core, _ := regexp.MatchString(`flow\.[\w]*`, raw) + if !core && !basic { + return fmt.Errorf("invalid event type format") + } + *e = EventType(raw) + return nil +} + +func (e EventType) Flow() string { + return string(e) +} + +type EventTypes []EventType + +func (e *EventTypes) Parse(raw []string) error { + if len(raw) > MaxIDsLength { + return fmt.Errorf("at most %d event types can be requested at a time", MaxIDsLength) + } + + // make a map to have only unique values as keys + eventTypes := make(EventTypes, 0) + uniqueTypes := make(map[string]bool) + for _, r := range raw { + var eType EventType + err := eType.Parse(r) + if err != nil { + return err + } + + if !uniqueTypes[eType.Flow()] { + uniqueTypes[eType.Flow()] = true + eventTypes = append(eventTypes, eType) + } + } + + *e = eventTypes + return nil +} + +func (e EventTypes) Flow() []string { + eventTypes := make([]string, len(e)) + for j, eType := range e { + eventTypes[j] = eType.Flow() + } + return eventTypes +} diff --git a/engine/access/rest/request/get_events.go b/engine/access/rest/request/get_events.go index db4839343a1..b9adcbe0f19 100644 --- a/engine/access/rest/request/get_events.go +++ b/engine/access/rest/request/get_events.go @@ -2,7 +2,6 @@ package request import ( "fmt" - "regexp" "github.com/onflow/flow-go/model/flow" ) @@ -57,19 +56,15 @@ func (g *GetEvents) Parse(rawType string, rawStart string, rawEnd string, rawBlo return fmt.Errorf("must provide either block IDs or start and end height range") } - g.Type = rawType - if g.Type == "" { + if rawType == "" { return fmt.Errorf("event type must be provided") } - - // match basic format A.address.contract.event (ignore err since regex will always compile) - basic, _ := regexp.MatchString(`[A-Z]\.[a-f0-9]{16}\.[\w+]*\.[\w+]*`, g.Type) - // match core events flow.event - core, _ := regexp.MatchString(`flow\.[\w]*`, g.Type) - - if !core && !basic { - return fmt.Errorf("invalid event type format") + var eventType EventType + err = eventType.Parse(rawType) + if err != nil { + return err } + g.Type = eventType.Flow() // validate start end height option if g.StartHeight != EmptyHeight && g.EndHeight != EmptyHeight { diff --git a/engine/access/rest/request/request.go b/engine/access/rest/request/request.go index b7500206fac..0aed3c0dff6 100644 --- a/engine/access/rest/request/request.go +++ b/engine/access/rest/request/request.go @@ -90,6 +90,12 @@ func (rd *Request) CreateTransactionRequest() (CreateTransaction, error) { return req, err } +func (rd *Request) SubscribeEventsRequest() (SubscribeEvents, error) { + var req SubscribeEvents + err := req.Build(rd) + return req, err +} + func (rd *Request) Expands(field string) bool { return rd.ExpandFields[field] } diff --git a/engine/access/rest/request/subscribe_events.go b/engine/access/rest/request/subscribe_events.go new file mode 100644 index 00000000000..d10efcf76bc --- /dev/null +++ b/engine/access/rest/request/subscribe_events.go @@ -0,0 +1,57 @@ +package request + +import ( + "fmt" + + "github.com/onflow/flow-go/model/flow" +) + +const startBlockIdQuery = "start_block_id" +const eventTypesQuery = "event_types" + +type SubscribeEvents struct { + StartBlockID flow.Identifier + StartHeight uint64 + + EventTypes []string + //Uliana: TODO: add events filter - contracts, addresses +} + +func (g *SubscribeEvents) Build(r *Request) error { + return g.Parse( + r.GetQueryParam(startBlockIdQuery), + r.GetQueryParam(startHeightQuery), + r.GetQueryParams(eventTypesQuery), + ) +} + +func (g *SubscribeEvents) Parse(rawBlockID string, rawStart string, rawTypes []string) error { + var height Height + err := height.Parse(rawStart) + if err != nil { + return fmt.Errorf("invalid start height: %w", err) + } + g.StartHeight = height.Flow() + + var startBlockID ID + err = startBlockID.Parse(rawBlockID) + if err != nil { + return err + } + g.StartBlockID = startBlockID.Flow() + + // if both height and one or both of start and end height are provided + if len(startBlockID) > 0 && g.StartHeight != EmptyHeight { + return fmt.Errorf("can only provide either block ID or start height range") + } + + var eventTypes EventTypes + err = eventTypes.Parse(rawTypes) + if err != nil { + return err + } + + g.EventTypes = eventTypes.Flow() + + return nil +} diff --git a/engine/access/rest/router.go b/engine/access/rest/router.go index f51f1c65f3e..afbcd252ffa 100644 --- a/engine/access/rest/router.go +++ b/engine/access/rest/router.go @@ -12,11 +12,19 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/middleware" "github.com/onflow/flow-go/engine/access/rest/models" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" ) -func newRouter(backend access.API, logger zerolog.Logger, chain flow.Chain, restCollector module.RestMetrics) (*mux.Router, error) { +func newRouter(backend access.API, + logger zerolog.Logger, + chain flow.Chain, + restCollector module.RestMetrics, + stateStreamApi state_stream.API, + conf state_stream.EventFilterConfig, + maxGlobalStreams uint32, +) (*mux.Router, error) { router := mux.NewRouter().StrictSlash(true) v1SubRouter := router.PathPrefix("/v1").Subrouter() @@ -36,6 +44,16 @@ func newRouter(backend access.API, logger zerolog.Logger, chain flow.Chain, rest Name(r.Name). Handler(h) } + + for _, r := range WSRoutes { + h := NewWSHandler(logger, r.Handler, chain, stateStreamApi, conf, maxGlobalStreams) + v1SubRouter. + Methods(r.Method). + Path(r.Pattern). + Name(r.Name). + Handler(h) + } + return router, nil } @@ -46,6 +64,13 @@ type route struct { Handler ApiHandlerFunc } +type wsroute struct { + Name string + Method string + Pattern string + Handler SubscribeHandlerFunc +} + var Routes = []route{{ Method: http.MethodGet, Pattern: "/transactions/{id}", @@ -118,6 +143,13 @@ var Routes = []route{{ Handler: GetNodeVersionInfo, }} +var WSRoutes = []wsroute{{ + Method: http.MethodPost, + Pattern: "/subscribe_events", + Name: "subscribeEvents", + Handler: SubscribeEvents, +}} + var routeUrlMap = map[string]string{} var routeRE = regexp.MustCompile(`(?i)/v1/(\w+)(/(\w+)(/(\w+))?)?`) diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index a1aa83710d8..51322246215 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -8,14 +8,23 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/access" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" ) // NewServer returns an HTTP server initialized with the REST API handler -func NewServer(backend access.API, listenAddress string, logger zerolog.Logger, chain flow.Chain, restCollector module.RestMetrics) (*http.Server, error) { +func NewServer(backend access.API, + listenAddress string, + logger zerolog.Logger, + chain flow.Chain, + restCollector module.RestMetrics, + api state_stream.API, + conf state_stream.EventFilterConfig, + maxGlobalStreams uint32, +) (*http.Server, error) { - router, err := newRouter(backend, logger, chain, restCollector) + router, err := newRouter(backend, logger, chain, restCollector, api, conf, maxGlobalStreams) if err != nil { return nil, err } diff --git a/engine/access/rest/subscribe_events.go b/engine/access/rest/subscribe_events.go new file mode 100644 index 00000000000..b437cefc4e3 --- /dev/null +++ b/engine/access/rest/subscribe_events.go @@ -0,0 +1,92 @@ +package rest + +import ( + "fmt" + "net/http" + + "github.com/gorilla/websocket" + + "github.com/onflow/flow-go/engine/access/rest/request" + "github.com/onflow/flow-go/engine/common/rpc/convert" + "github.com/onflow/flow-go/engine/common/state_stream" + + executiondata "github.com/onflow/flow/protobuf/go/flow/executiondata" +) + +func SubscribeEvents(r *request.Request, w http.ResponseWriter, h *state_stream.SubscribeHandler) (interface{}, error) { + req, err := r.SubscribeEventsRequest() + if err != nil { + return nil, NewBadRequestError(err) + } + + // Upgrade the HTTP connection to a WebSocket connection + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r.Request, nil) + if err != nil { + err = fmt.Errorf("webSocket upgrade error: %s", err) + return nil, err + } + defer conn.Close() + + var filter state_stream.EventFilter + // Retrieve the filter parameters from the request, if provided + + emptyString := make([]string, 0) //Uliana: TODO: remove it + filter, err = state_stream.NewEventFilter( + h.EventFilterConfig, + r.Chain, + req.EventTypes, + emptyString, //Uliana: TODO: parsed addresses + emptyString, //Uliana: TODO: parsed contracts + ) + if err != nil { + err = fmt.Errorf("invalid event filter: %s", err) + return nil, err + } + + sub, err := h.SubscribeEvents(r.Context(), req.StartBlockID, req.StartHeight, filter) + + // Write messages to the WebSocket connection + writeToWebSocket := func(resp *state_stream.EventsResponse) error { + // Prepare the response message + response := &executiondata.SubscribeEventsResponse{ + BlockHeight: resp.Height, + BlockId: convert.IdentifierToMessage(resp.BlockID), + Events: convert.EventsToMessages(resp.Events), + } + + // Send the response message over the WebSocket connection + return conn.WriteJSON(response) + } + + for { + //select { + //case + v, ok := <-sub.Channel() + if !ok { + if sub.Err() != nil { + err = fmt.Errorf("stream encountered an error: %w", sub.Err()) + return nil, err + } + return nil, err + } + + resp, ok := v.(*state_stream.EventsResponse) + if !ok { + err = fmt.Errorf("unexpected response type: %s", v) + return nil, err + } + + // Write the response to the WebSocket connection + err := writeToWebSocket(resp) + if err != nil { + err = fmt.Errorf("failed to send response: %w", err) + return nil, err + } + + //case <-conn.ReadyState(): + // // WebSocket connection closed or in a non-writable state + // return + //} + } +} diff --git a/engine/access/rest/test_helpers.go b/engine/access/rest/test_helpers.go index 88170769c99..c188c8a0ebe 100644 --- a/engine/access/rest/test_helpers.go +++ b/engine/access/rest/test_helpers.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/flow-go/access/mock" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/metrics" ) @@ -30,7 +31,14 @@ func executeRequest(req *http.Request, backend *mock.API) (*httptest.ResponseRec var b bytes.Buffer logger := zerolog.New(&b) restCollector := metrics.NewNoopCollector() - router, err := newRouter(backend, logger, flow.Testnet.Chain(), restCollector) + stateStreamConfig := state_stream.Config{} //Uliana: TODO: add test for subscribe_events, provide state_stream backend + router, err := newRouter(backend, + logger, + flow.Testnet.Chain(), + restCollector, + nil, + stateStreamConfig.EventFilterConfig, + stateStreamConfig.MaxGlobalStreams) if err != nil { return nil, err } diff --git a/engine/access/rest/websocket_handler.go b/engine/access/rest/websocket_handler.go new file mode 100644 index 00000000000..3f93751d0da --- /dev/null +++ b/engine/access/rest/websocket_handler.go @@ -0,0 +1,74 @@ +package rest + +import ( + "net/http" + + "github.com/rs/zerolog" + + "github.com/onflow/flow-go/engine/access/rest/request" + "github.com/onflow/flow-go/engine/access/rest/util" + "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/model/flow" +) + +// SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, +// it fetches necessary resources and returns an error. +type SubscribeHandlerFunc func( + r *request.Request, + w http.ResponseWriter, + h *state_stream.SubscribeHandler, +) (interface{}, error) + +// WSHandler is custom http handler implementing custom handler function. +// Handler function allows easier handling of errors and responses as it +// wraps functionality for handling error and responses outside of endpoint handling. +type WSHandler struct { + *HttpHandler + *state_stream.SubscribeHandler + subscribeFunc SubscribeHandlerFunc +} + +func NewWSHandler( + logger zerolog.Logger, + subscribeFunc SubscribeHandlerFunc, + chain flow.Chain, + api state_stream.API, + conf state_stream.EventFilterConfig, + maxGlobalStreams uint32, +) *WSHandler { + handler := &WSHandler{ + subscribeFunc: subscribeFunc, + } + handler.HttpHandler = NewHttpHandler(logger, chain) + handler.SubscribeHandler = state_stream.NewSubscribeHandler(api, chain, conf, maxGlobalStreams) + return handler +} + +// ServerHTTP function acts as a wrapper to each request providing common handling functionality +// such as logging, error handling, request decorators +func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // create a logger + errLog := h.Logger.With().Str("request_url", r.URL.String()).Logger() + + err := h.VerifyRequest(w, r) + if err != nil { + return + } + decoratedRequest := request.Decorate(r, h.HttpHandler.Chain) + + response, err := h.subscribeFunc(decoratedRequest, w, h.SubscribeHandler) + if err != nil { + h.errorHandler(w, err, errLog) + return + } + + // apply the select filter if any select fields have been specified + response, err = util.SelectFilter(response, decoratedRequest.Selects()) + if err != nil { + h.errorHandler(w, err, errLog) + return + } + + // write response to response stream + h.jsonResponse(w, http.StatusOK, response, errLog) +} diff --git a/engine/access/rpc/engine.go b/engine/access/rpc/engine.go index d4c812df997..82947be9f47 100644 --- a/engine/access/rpc/engine.go +++ b/engine/access/rpc/engine.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/onflow/flow-go/engine/common/state_stream" "net" "net/http" "sync" @@ -74,6 +75,9 @@ type Engine struct { unsecureGrpcAddress net.Addr secureGrpcAddress net.Addr restAPIAddress net.Addr + + stateStreamBackend state_stream.API + stateStreamConfig state_stream.Config } // NewBuilder returns a new RPC engine builder. @@ -97,6 +101,8 @@ func NewBuilder(log zerolog.Logger, apiRatelimits map[string]int, // the api rate limit (max calls per second) for each of the Access API e.g. Ping->100, GetTransaction->300 apiBurstLimits map[string]int, // the api burst limit (max calls at the same time) for each of the Access API e.g. Ping->50, GetTransaction->10 me module.Local, + stateStreamBackend state_stream.API, + stateStreamConfig state_stream.Config, ) (*RPCEngineBuilder, error) { log = log.With().Str("engine", "rpc").Logger() @@ -210,6 +216,8 @@ func NewBuilder(log zerolog.Logger, config: config, chain: chainID.Chain(), restCollector: accessMetrics, + stateStreamBackend: stateStreamBackend, + stateStreamConfig: stateStreamConfig, } backendNotifierActor, backendNotifierWorker := events.NewFinalizationActor(eng.notifyBackendOnBlockFinalized) eng.backendNotifierActor = backendNotifierActor @@ -384,7 +392,14 @@ func (e *Engine) serveREST(ctx irrecoverable.SignalerContext, ready component.Re e.log.Info().Str("rest_api_address", e.config.RESTListenAddr).Msg("starting REST server on address") - r, err := rest.NewServer(e.backend, e.config.RESTListenAddr, e.log, e.chain, e.restCollector) + r, err := rest.NewServer(e.backend, + e.config.RESTListenAddr, + e.log, + e.chain, + e.restCollector, + e.stateStreamBackend, + e.stateStreamConfig.EventFilterConfig, + e.stateStreamConfig.MaxGlobalStreams) if err != nil { e.log.Err(err).Msg("failed to initialize the REST server") ctx.Throw(err) diff --git a/engine/access/state_stream/engine.go b/engine/access/state_stream/engine.go index e993d6cbece..f21a64f3af8 100644 --- a/engine/access/state_stream/engine.go +++ b/engine/access/state_stream/engine.go @@ -3,56 +3,26 @@ package state_stream import ( "fmt" "net" - "time" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - access "github.com/onflow/flow/protobuf/go/flow/executiondata" + "github.com/rs/zerolog" + "google.golang.org/grpc" "github.com/onflow/flow-go/engine" "github.com/onflow/flow-go/engine/common/rpc" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/component" "github.com/onflow/flow-go/module/executiondatasync/execution_data" "github.com/onflow/flow-go/module/executiondatasync/execution_data/cache" "github.com/onflow/flow-go/module/irrecoverable" - "github.com/onflow/flow-go/state/protocol" "github.com/onflow/flow-go/storage" "github.com/onflow/flow-go/utils/logging" -) - -// Config defines the configurable options for the ingress server. -type Config struct { - EventFilterConfig - - // ListenAddr is the address the GRPC server will listen on as host:port - ListenAddr string - - // MaxExecutionDataMsgSize is the max message size for block execution data API - MaxExecutionDataMsgSize uint - - // RpcMetricsEnabled specifies whether to enable the GRPC metrics - RpcMetricsEnabled bool - - // MaxGlobalStreams defines the global max number of streams that can be open at the same time. - MaxGlobalStreams uint32 - // ExecutionDataCacheSize is the max number of objects for the execution data cache. - ExecutionDataCacheSize uint32 - - // ClientSendTimeout is the timeout for sending a message to the client. After the timeout, - // the stream is closed with an error. - ClientSendTimeout time.Duration - - // ClientSendBufferSize is the size of the response buffer for sending messages to the client. - ClientSendBufferSize uint - - // ResponseLimit is the max responses per second allowed on a stream. After exceeding the limit, - // the stream is paused until more capacity is available. Searches of past data can be CPU - // intensive, so this helps manage the impact. - ResponseLimit float64 -} + access "github.com/onflow/flow/protobuf/go/flow/executiondata" +) // Engine exposes the server with the state stream API. // By default, this engine is not enabled. @@ -60,9 +30,9 @@ type Config struct { type Engine struct { *component.ComponentManager log zerolog.Logger - backend *StateStreamBackend + backend *state_stream.StateStreamBackend server *grpc.Server - config Config + config state_stream.Config chain flow.Chain handler *Handler @@ -76,18 +46,14 @@ type Engine struct { // NewEng returns a new ingress server. func NewEng( log zerolog.Logger, - config Config, - execDataStore execution_data.ExecutionDataStore, + config state_stream.Config, execDataCache *cache.ExecutionDataCache, - state protocol.State, headers storage.Headers, - seals storage.Seals, - results storage.ExecutionResults, chainID flow.ChainID, - initialBlockHeight uint64, - highestBlockHeight uint64, apiRatelimits map[string]int, // the api rate limit (max calls per second) for each of the gRPC API e.g. Ping->100, GetExecutionDataByBlockID->300 apiBurstLimits map[string]int, // the api burst limit (max calls at the same time) for each of the gRPC API e.g. Ping->50, GetExecutionDataByBlockID->10 + backend *state_stream.StateStreamBackend, + broadcaster *engine.Broadcaster, ) (*Engine, error) { logger := log.With().Str("engine", "state_stream_rpc").Logger() @@ -119,25 +85,6 @@ func NewEng( server := grpc.NewServer(grpcOpts...) - broadcaster := engine.NewBroadcaster() - - backend, err := New( - logger, - config, - state, - headers, - seals, - results, - execDataStore, - execDataCache, - broadcaster, - initialBlockHeight, - highestBlockHeight, - ) - if err != nil { - return nil, fmt.Errorf("could not create state stream backend: %w", err) - } - e := &Engine{ log: logger, backend: backend, @@ -175,7 +122,7 @@ func (e *Engine) OnExecutionData(executionData *execution_data.BlockExecutionDat return } - if ok := e.backend.setHighestHeight(header.Height); !ok { + if ok := e.backend.SetHighestHeight(header.Height); !ok { // this means that the height was lower than the current highest height // OnExecutionData is guaranteed by the requester to be called in order, but may be called // multiple times for the same block. diff --git a/engine/access/state_stream/handler.go b/engine/access/state_stream/handler.go index df7c4dd9f6b..bf656b2d994 100644 --- a/engine/access/state_stream/handler.go +++ b/engine/access/state_stream/handler.go @@ -2,36 +2,26 @@ package state_stream import ( "context" - "sync/atomic" access "github.com/onflow/flow/protobuf/go/flow/executiondata" executiondata "github.com/onflow/flow/protobuf/go/flow/executiondata" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/onflow/flow-go/engine/common/rpc" "github.com/onflow/flow-go/engine/common/rpc/convert" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" ) type Handler struct { - api API - chain flow.Chain - - eventFilterConfig EventFilterConfig - - maxStreams int32 - streamCount atomic.Int32 + *state_stream.SubscribeHandler } -func NewHandler(api API, chain flow.Chain, conf EventFilterConfig, maxGlobalStreams uint32) *Handler { - h := &Handler{ - api: api, - chain: chain, - eventFilterConfig: conf, - maxStreams: int32(maxGlobalStreams), - streamCount: atomic.Int32{}, - } +func NewHandler(api state_stream.API, chain flow.Chain, conf state_stream.EventFilterConfig, maxGlobalStreams uint32) *Handler { + h := &Handler{} + h.SubscribeHandler = state_stream.NewSubscribeHandler(api, chain, conf, maxGlobalStreams) return h } @@ -41,7 +31,7 @@ func (h *Handler) GetExecutionDataByBlockID(ctx context.Context, request *access return nil, status.Errorf(codes.InvalidArgument, "could not convert block ID: %v", err) } - execData, err := h.api.GetExecutionDataByBlockID(ctx, blockID) + execData, err := h.Api.GetExecutionDataByBlockID(ctx, blockID) if err != nil { return nil, rpc.ConvertError(err, "could no get execution data", codes.Internal) } @@ -56,11 +46,11 @@ func (h *Handler) GetExecutionDataByBlockID(ctx context.Context, request *access func (h *Handler) SubscribeExecutionData(request *access.SubscribeExecutionDataRequest, stream access.ExecutionDataAPI_SubscribeExecutionDataServer) error { // check if the maximum number of streams is reached - if h.streamCount.Load() >= h.maxStreams { + if h.StreamCount.Load() >= h.MaxStreams { return status.Errorf(codes.ResourceExhausted, "maximum number of streams reached") } - h.streamCount.Add(1) - defer h.streamCount.Add(-1) + h.StreamCount.Add(1) + defer h.StreamCount.Add(-1) startBlockID := flow.ZeroID if request.GetStartBlockId() != nil { @@ -71,7 +61,7 @@ func (h *Handler) SubscribeExecutionData(request *access.SubscribeExecutionDataR startBlockID = blockID } - sub := h.api.SubscribeExecutionData(stream.Context(), startBlockID, request.GetStartBlockHeight()) + sub := h.Api.SubscribeExecutionData(stream.Context(), startBlockID, request.GetStartBlockHeight()) for { v, ok := <-sub.Channel() @@ -82,7 +72,7 @@ func (h *Handler) SubscribeExecutionData(request *access.SubscribeExecutionDataR return nil } - resp, ok := v.(*ExecutionDataResponse) + resp, ok := v.(*state_stream.ExecutionDataResponse) if !ok { return status.Errorf(codes.Internal, "unexpected response type: %T", v) } @@ -103,13 +93,6 @@ func (h *Handler) SubscribeExecutionData(request *access.SubscribeExecutionDataR } func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream access.ExecutionDataAPI_SubscribeEventsServer) error { - // check if the maximum number of streams is reached - if h.streamCount.Load() >= h.maxStreams { - return status.Errorf(codes.ResourceExhausted, "maximum number of streams reached") - } - h.streamCount.Add(1) - defer h.streamCount.Add(-1) - startBlockID := flow.ZeroID if request.GetStartBlockId() != nil { blockID, err := convert.BlockID(request.GetStartBlockId()) @@ -119,13 +102,13 @@ func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream startBlockID = blockID } - filter := EventFilter{} + filter := state_stream.EventFilter{} if request.GetFilter() != nil { var err error reqFilter := request.GetFilter() - filter, err = NewEventFilter( - h.eventFilterConfig, - h.chain, + filter, err = state_stream.NewEventFilter( + h.EventFilterConfig, + h.Chain, reqFilter.GetEventType(), reqFilter.GetAddress(), reqFilter.GetContract(), @@ -135,7 +118,10 @@ func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream } } - sub := h.api.SubscribeEvents(stream.Context(), startBlockID, request.GetStartBlockHeight(), filter) + sub, err := h.SubscribeHandler.SubscribeEvents(stream.Context(), startBlockID, request.GetStartBlockHeight(), filter) + if err != nil { + return err + } for { v, ok := <-sub.Channel() @@ -146,7 +132,7 @@ func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream return nil } - resp, ok := v.(*EventsResponse) + resp, ok := v.(*state_stream.EventsResponse) if !ok { return status.Errorf(codes.Internal, "unexpected response type: %T", v) } diff --git a/engine/access/state_stream/mock/api.go b/engine/access/state_stream/mock/api.go index 5b57efc917f..8ddbe1dfb86 100644 --- a/engine/access/state_stream/mock/api.go +++ b/engine/access/state_stream/mock/api.go @@ -4,13 +4,12 @@ package mock import ( context "context" + state_stream2 "github.com/onflow/flow-go/engine/common/state_stream" flow "github.com/onflow/flow-go/model/flow" execution_data "github.com/onflow/flow-go/module/executiondatasync/execution_data" mock "github.com/stretchr/testify/mock" - - state_stream "github.com/onflow/flow-go/engine/access/state_stream" ) // API is an autogenerated mock type for the API type @@ -45,15 +44,15 @@ func (_m *API) GetExecutionDataByBlockID(ctx context.Context, blockID flow.Ident } // SubscribeEvents provides a mock function with given fields: ctx, startBlockID, startHeight, filter -func (_m *API) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier, startHeight uint64, filter state_stream.EventFilter) state_stream.Subscription { +func (_m *API) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier, startHeight uint64, filter state_stream2.EventFilter) state_stream2.Subscription { ret := _m.Called(ctx, startBlockID, startHeight, filter) - var r0 state_stream.Subscription - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64, state_stream.EventFilter) state_stream.Subscription); ok { + var r0 state_stream2.Subscription + if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64, state_stream2.EventFilter) state_stream2.Subscription); ok { r0 = rf(ctx, startBlockID, startHeight, filter) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(state_stream.Subscription) + r0 = ret.Get(0).(state_stream2.Subscription) } } @@ -61,15 +60,15 @@ func (_m *API) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier } // SubscribeExecutionData provides a mock function with given fields: ctx, startBlockID, startBlockHeight -func (_m *API) SubscribeExecutionData(ctx context.Context, startBlockID flow.Identifier, startBlockHeight uint64) state_stream.Subscription { +func (_m *API) SubscribeExecutionData(ctx context.Context, startBlockID flow.Identifier, startBlockHeight uint64) state_stream2.Subscription { ret := _m.Called(ctx, startBlockID, startBlockHeight) - var r0 state_stream.Subscription - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64) state_stream.Subscription); ok { + var r0 state_stream2.Subscription + if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64) state_stream2.Subscription); ok { r0 = rf(ctx, startBlockID, startBlockHeight) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(state_stream.Subscription) + r0 = ret.Get(0).(state_stream2.Subscription) } } diff --git a/engine/access/state_stream/backend.go b/engine/common/state_stream/backend.go similarity index 85% rename from engine/access/state_stream/backend.go rename to engine/common/state_stream/backend.go index 33c5e18cb77..2eb1f33d107 100644 --- a/engine/access/state_stream/backend.go +++ b/engine/common/state_stream/backend.go @@ -35,6 +35,38 @@ const ( DefaultResponseLimit = float64(0) ) +// Config defines the configurable options for the ingress server. +type Config struct { + EventFilterConfig + + // ListenAddr is the address the GRPC server will listen on as host:port + ListenAddr string + + // MaxExecutionDataMsgSize is the max message size for block execution data API + MaxExecutionDataMsgSize uint + + // RpcMetricsEnabled specifies whether to enable the GRPC metrics + RpcMetricsEnabled bool + + // MaxGlobalStreams defines the global max number of streams that can be open at the same time. + MaxGlobalStreams uint32 + + // ExecutionDataCacheSize is the max number of objects for the execution data cache. + ExecutionDataCacheSize uint32 + + // ClientSendTimeout is the timeout for sending a message to the client. After the timeout, + // the stream is closed with an error. + ClientSendTimeout time.Duration + + // ClientSendBufferSize is the size of the response buffer for sending messages to the client. + ClientSendBufferSize uint + + // ResponseLimit is the max responses per second allowed on a stream. After exceeding the limit, + // the stream is paused until more capacity is available. Searches of past data can be CPU + // intensive, so this helps manage the impact. + ResponseLimit float64 +} + type GetExecutionDataFunc func(context.Context, uint64) (*execution_data.BlockExecutionDataEntity, error) type GetStartHeightFunc func(flow.Identifier, uint64) (uint64, error) @@ -190,6 +222,6 @@ func (b *StateStreamBackend) getStartHeight(startBlockID flow.Identifier, startH } // SetHighestHeight sets the highest height for which execution data is available. -func (b *StateStreamBackend) setHighestHeight(height uint64) bool { +func (b *StateStreamBackend) SetHighestHeight(height uint64) bool { return b.highestHeight.Set(height) } diff --git a/engine/access/state_stream/backend_events.go b/engine/common/state_stream/backend_events.go similarity index 100% rename from engine/access/state_stream/backend_events.go rename to engine/common/state_stream/backend_events.go diff --git a/engine/access/state_stream/backend_events_test.go b/engine/common/state_stream/backend_events_test.go similarity index 98% rename from engine/access/state_stream/backend_events_test.go rename to engine/common/state_stream/backend_events_test.go index 68ca0a789cb..69a85d55907 100644 --- a/engine/access/state_stream/backend_events_test.go +++ b/engine/common/state_stream/backend_events_test.go @@ -108,7 +108,7 @@ func (s *BackendEventsSuite) TestSubscribeEvents() { // this simulates a subscription on a past block for i := 0; i <= test.highestBackfill; i++ { s.T().Logf("backfilling block %d", i) - s.backend.setHighestHeight(s.blocks[i].Header.Height) + s.backend.SetHighestHeight(s.blocks[i].Header.Height) } subCtx, subCancel := context.WithCancel(ctx) @@ -121,7 +121,7 @@ func (s *BackendEventsSuite) TestSubscribeEvents() { // simulate new exec data received. // exec data for all blocks with index <= highestBackfill were already received if i > test.highestBackfill { - s.backend.setHighestHeight(b.Header.Height) + s.backend.SetHighestHeight(b.Header.Height) s.broadcaster.Publish() } diff --git a/engine/access/state_stream/backend_executiondata.go b/engine/common/state_stream/backend_executiondata.go similarity index 100% rename from engine/access/state_stream/backend_executiondata.go rename to engine/common/state_stream/backend_executiondata.go diff --git a/engine/access/state_stream/backend_executiondata_test.go b/engine/common/state_stream/backend_executiondata_test.go similarity index 98% rename from engine/access/state_stream/backend_executiondata_test.go rename to engine/common/state_stream/backend_executiondata_test.go index b619a94e322..9f3b4ded4cc 100644 --- a/engine/access/state_stream/backend_executiondata_test.go +++ b/engine/common/state_stream/backend_executiondata_test.go @@ -255,7 +255,7 @@ func (s *BackendExecutionDataSuite) TestGetExecutionDataByBlockID() { execData := s.execDataMap[block.ID()] // notify backend block is available - s.backend.setHighestHeight(block.Header.Height) + s.backend.SetHighestHeight(block.Header.Height) var err error s.Run("happy path TestGetExecutionDataByBlockID success", func() { @@ -331,7 +331,7 @@ func (s *BackendExecutionDataSuite) TestSubscribeExecutionData() { // this simulates a subscription on a past block for i := 0; i <= test.highestBackfill; i++ { s.T().Logf("backfilling block %d", i) - s.backend.setHighestHeight(s.blocks[i].Header.Height) + s.backend.SetHighestHeight(s.blocks[i].Header.Height) } subCtx, subCancel := context.WithCancel(ctx) @@ -345,7 +345,7 @@ func (s *BackendExecutionDataSuite) TestSubscribeExecutionData() { // simulate new exec data received. // exec data for all blocks with index <= highestBackfill were already received if i > test.highestBackfill { - s.backend.setHighestHeight(b.Header.Height) + s.backend.SetHighestHeight(b.Header.Height) s.broadcaster.Publish() } diff --git a/engine/access/state_stream/event.go b/engine/common/state_stream/event.go similarity index 100% rename from engine/access/state_stream/event.go rename to engine/common/state_stream/event.go diff --git a/engine/access/state_stream/event_test.go b/engine/common/state_stream/event_test.go similarity index 97% rename from engine/access/state_stream/event_test.go rename to engine/common/state_stream/event_test.go index 3dbccd34406..65c8629989a 100644 --- a/engine/access/state_stream/event_test.go +++ b/engine/common/state_stream/event_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" ) diff --git a/engine/access/state_stream/filter.go b/engine/common/state_stream/filter.go similarity index 100% rename from engine/access/state_stream/filter.go rename to engine/common/state_stream/filter.go diff --git a/engine/access/state_stream/filter_test.go b/engine/common/state_stream/filter_test.go similarity index 98% rename from engine/access/state_stream/filter_test.go rename to engine/common/state_stream/filter_test.go index d25c272a06f..982687ab756 100644 --- a/engine/access/state_stream/filter_test.go +++ b/engine/common/state_stream/filter_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/onflow/flow-go/engine/access/state_stream" + state_stream "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) diff --git a/engine/access/state_stream/streamer.go b/engine/common/state_stream/streamer.go similarity index 100% rename from engine/access/state_stream/streamer.go rename to engine/common/state_stream/streamer.go diff --git a/engine/access/state_stream/streamer_test.go b/engine/common/state_stream/streamer_test.go similarity index 98% rename from engine/access/state_stream/streamer_test.go rename to engine/common/state_stream/streamer_test.go index 6c80feec7ed..f728b39c0a3 100644 --- a/engine/access/state_stream/streamer_test.go +++ b/engine/common/state_stream/streamer_test.go @@ -3,6 +3,7 @@ package state_stream_test import ( "context" "fmt" + "testing" "time" @@ -11,8 +12,8 @@ import ( "github.com/stretchr/testify/mock" "github.com/onflow/flow-go/engine" - "github.com/onflow/flow-go/engine/access/state_stream" streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" + state_stream "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/utils/unittest" ) diff --git a/engine/common/state_stream/subscribe_handler.go b/engine/common/state_stream/subscribe_handler.go new file mode 100644 index 00000000000..1e8da756225 --- /dev/null +++ b/engine/common/state_stream/subscribe_handler.go @@ -0,0 +1,43 @@ +package state_stream + +import ( + "context" + "sync/atomic" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/onflow/flow-go/model/flow" +) + +type SubscribeHandler struct { + Api API + Chain flow.Chain + + EventFilterConfig EventFilterConfig + + MaxStreams int32 + StreamCount atomic.Int32 +} + +func NewSubscribeHandler(api API, chain flow.Chain, conf EventFilterConfig, maxGlobalStreams uint32) *SubscribeHandler { + h := &SubscribeHandler{ + Api: api, + Chain: chain, + EventFilterConfig: conf, + MaxStreams: int32(maxGlobalStreams), + StreamCount: atomic.Int32{}, + } + return h +} + +func (h *SubscribeHandler) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier, startBlockHeight uint64, filter EventFilter) (Subscription, error) { + // check if the maximum number of streams is reached + if h.StreamCount.Load() >= h.MaxStreams { + return nil, status.Errorf(codes.ResourceExhausted, "maximum number of streams reached") + } + h.StreamCount.Add(1) + defer h.StreamCount.Add(-1) + + return h.Api.SubscribeEvents(ctx, startBlockID, startBlockHeight, filter), nil +} diff --git a/engine/access/state_stream/subscription.go b/engine/common/state_stream/subscription.go similarity index 100% rename from engine/access/state_stream/subscription.go rename to engine/common/state_stream/subscription.go diff --git a/engine/access/state_stream/subscription_test.go b/engine/common/state_stream/subscription_test.go similarity index 98% rename from engine/access/state_stream/subscription_test.go rename to engine/common/state_stream/subscription_test.go index d5ef7296cf3..81b0edf013b 100644 --- a/engine/access/state_stream/subscription_test.go +++ b/engine/common/state_stream/subscription_test.go @@ -3,6 +3,7 @@ package state_stream_test import ( "context" "fmt" + "sync" "testing" "time" @@ -10,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/utils/unittest" ) diff --git a/go.mod b/go.mod index 71a7d0daf68..b978c7c9f87 100644 --- a/go.mod +++ b/go.mod @@ -100,6 +100,7 @@ require ( require ( github.com/coreos/go-semver v0.3.0 github.com/go-playground/validator/v10 v10.14.1 + github.com/gorilla/websocket v1.5.0 github.com/mitchellh/mapstructure v1.5.0 github.com/onflow/wal v0.0.0-20230529184820-bc9f8244608d github.com/slok/go-http-metrics v0.10.0 @@ -166,7 +167,6 @@ require ( github.com/google/gopacket v1.1.19 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.7.1 // indirect - github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/huin/goupnp v1.0.3 // indirect diff --git a/integration/localnet/builder/bootstrap.go b/integration/localnet/builder/bootstrap.go index 0bd709b9203..f11d8fee9c7 100644 --- a/integration/localnet/builder/bootstrap.go +++ b/integration/localnet/builder/bootstrap.go @@ -414,7 +414,7 @@ func prepareAccessService(container testnet.ContainerConfig, i int, n int) Servi fmt.Sprintf("--secure-rpc-addr=%s:%s", container.ContainerName, testnet.GRPCSecurePort), fmt.Sprintf("--http-addr=%s:%s", container.ContainerName, testnet.GRPCWebPort), fmt.Sprintf("--rest-addr=%s:%s", container.ContainerName, testnet.RESTPort), - fmt.Sprintf("--state-stream-addr=%s:%s", container.ContainerName, testnet.ExecutionStatePort), + fmt.Sprintf("--state_stream-addr=%s:%s", container.ContainerName, testnet.ExecutionStatePort), fmt.Sprintf("--collection-ingress-port=%s", testnet.GRPCPort), "--supports-observer=true", fmt.Sprintf("--public-network-address=%s:%s", container.ContainerName, testnet.PublicNetworkPort), @@ -423,7 +423,7 @@ func prepareAccessService(container testnet.ContainerConfig, i int, n int) Servi "--log-tx-time-to-finalized-executed", "--execution-data-sync-enabled=true", "--execution-data-dir=/data/execution-data", - fmt.Sprintf("--state-stream-addr=%s:%s", container.ContainerName, testnet.ExecutionStatePort), + fmt.Sprintf("--state_stream-addr=%s:%s", container.ContainerName, testnet.ExecutionStatePort), ) service.AddExposedPorts( diff --git a/integration/testnet/network.go b/integration/testnet/network.go index 9f060dd0532..4f00abb4f36 100644 --- a/integration/testnet/network.go +++ b/integration/testnet/network.go @@ -873,7 +873,7 @@ func (net *FlowNetwork) AddNode(t *testing.T, bootstrapDir string, nodeConf Cont nodeContainer.AddFlag("rest-addr", nodeContainer.ContainerAddr(RESTPort)) nodeContainer.exposePort(ExecutionStatePort, testingdock.RandomPort(t)) - nodeContainer.AddFlag("state-stream-addr", nodeContainer.ContainerAddr(ExecutionStatePort)) + nodeContainer.AddFlag("state_stream-addr", nodeContainer.ContainerAddr(ExecutionStatePort)) // uncomment line below to point the access node exclusively to a single collection node // nodeContainer.AddFlag("static-collection-ingress-addr", "collection_1:9000") diff --git a/module/state_synchronization/requester/execution_data_requester_test.go b/module/state_synchronization/requester/execution_data_requester_test.go index f116da7a297..72ad81294bc 100644 --- a/module/state_synchronization/requester/execution_data_requester_test.go +++ b/module/state_synchronization/requester/execution_data_requester_test.go @@ -3,6 +3,7 @@ package requester_test import ( "context" "fmt" + "math/rand" "sync" "testing" @@ -18,7 +19,7 @@ import ( "github.com/onflow/flow-go/consensus/hotstuff/model" "github.com/onflow/flow-go/consensus/hotstuff/notifications/pubsub" - "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/module/blobs" diff --git a/module/state_synchronization/requester/jobs/execution_data_reader_test.go b/module/state_synchronization/requester/jobs/execution_data_reader_test.go index 365e0358ee6..14201942712 100644 --- a/module/state_synchronization/requester/jobs/execution_data_reader_test.go +++ b/module/state_synchronization/requester/jobs/execution_data_reader_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/onflow/flow-go/engine/access/state_stream" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/executiondatasync/execution_data" "github.com/onflow/flow-go/module/executiondatasync/execution_data/cache" From 6fcc308ae32030392dae520e8a6d804ebe193d03 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 7 Jul 2023 16:42:48 +0300 Subject: [PATCH 02/35] Updated last commit --- .../node_builder/access_node_builder.go | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index a8d8cc8d31d..f1a1985ca84 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -43,7 +43,7 @@ import ( "github.com/onflow/flow-go/engine/access/state_stream" followereng "github.com/onflow/flow-go/engine/common/follower" "github.com/onflow/flow-go/engine/common/requester" - common_state_stream "github.com/onflow/flow-go/engine/common/state_stream" + cstatestream "github.com/onflow/flow-go/engine/common/state_stream" synceng "github.com/onflow/flow-go/engine/common/synchronization" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/model/flow/filter" @@ -117,8 +117,8 @@ type AccessNodeConfig struct { apiRatelimits map[string]int apiBurstlimits map[string]int rpcConf rpc.Config - stateStreamConf common_state_stream.Config - stateStreamBackend common_state_stream.API + stateStreamConf cstatestream.Config + stateStreamBackend cstatestream.API stateStreamFilterConf map[string]int ExecutionNodeAddress string // deprecated HistoricalAccessRPCs []access.AccessAPIClient @@ -164,14 +164,14 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { ArchiveAddressList: nil, MaxMsgSize: grpcutils.DefaultMaxMsgSize, }, - stateStreamConf: common_state_stream.Config{ + stateStreamConf: cstatestream.Config{ MaxExecutionDataMsgSize: grpcutils.DefaultMaxMsgSize, - ExecutionDataCacheSize: common_state_stream.DefaultCacheSize, - ClientSendTimeout: common_state_stream.DefaultSendTimeout, - ClientSendBufferSize: common_state_stream.DefaultSendBufferSize, - MaxGlobalStreams: common_state_stream.DefaultMaxGlobalStreams, - EventFilterConfig: common_state_stream.DefaultEventFilterConfig, - ResponseLimit: common_state_stream.DefaultResponseLimit, + ExecutionDataCacheSize: cstatestream.DefaultCacheSize, + ClientSendTimeout: cstatestream.DefaultSendTimeout, + ClientSendBufferSize: cstatestream.DefaultSendBufferSize, + MaxGlobalStreams: cstatestream.DefaultMaxGlobalStreams, + EventFilterConfig: cstatestream.DefaultEventFilterConfig, + ResponseLimit: cstatestream.DefaultResponseLimit, }, stateStreamBackend: nil, stateStreamFilterConf: nil, @@ -606,7 +606,7 @@ func (builder *FlowAccessNodeBuilder) BuildExecutionDataRequester() *FlowAccessN broadcaster := engine.NewBroadcaster() - backend, err := common_state_stream.New( + backend, err := cstatestream.New( node.Logger, builder.stateStreamConf, node.State, @@ -675,7 +675,7 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { flags.UintVar(&builder.executionGRPCPort, "execution-ingress-port", defaultConfig.executionGRPCPort, "the grpc ingress port for all execution nodes") flags.StringVarP(&builder.rpcConf.UnsecureGRPCListenAddr, "rpc-addr", "r", defaultConfig.rpcConf.UnsecureGRPCListenAddr, "the address the unsecured gRPC server listens on") flags.StringVar(&builder.rpcConf.SecureGRPCListenAddr, "secure-rpc-addr", defaultConfig.rpcConf.SecureGRPCListenAddr, "the address the secure gRPC server listens on") - flags.StringVar(&builder.stateStreamConf.ListenAddr, "state_stream-addr", defaultConfig.stateStreamConf.ListenAddr, "the address the state stream server listens on (if empty the server will not be started)") + flags.StringVar(&builder.stateStreamConf.ListenAddr, "state-stream-addr", defaultConfig.stateStreamConf.ListenAddr, "the address the state stream server listens on (if empty the server will not be started)") flags.StringVarP(&builder.rpcConf.HTTPListenAddr, "http-addr", "h", defaultConfig.rpcConf.HTTPListenAddr, "the address the http proxy server listens on") flags.StringVar(&builder.rpcConf.RESTListenAddr, "rest-addr", defaultConfig.rpcConf.RESTListenAddr, "the address the REST server listens on (if empty the REST server will not be started)") flags.StringVarP(&builder.rpcConf.CollectionAddr, "static-collection-ingress-addr", "", defaultConfig.rpcConf.CollectionAddr, "the address (of the collection node) to send transactions to") @@ -713,12 +713,12 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { // Execution State Streaming API flags.Uint32Var(&builder.stateStreamConf.ExecutionDataCacheSize, "execution-data-cache-size", defaultConfig.stateStreamConf.ExecutionDataCacheSize, "block execution data cache size") - flags.Uint32Var(&builder.stateStreamConf.MaxGlobalStreams, "state_stream-global-max-streams", defaultConfig.stateStreamConf.MaxGlobalStreams, "global maximum number of concurrent streams") + flags.Uint32Var(&builder.stateStreamConf.MaxGlobalStreams, "state-stream-global-max-streams", defaultConfig.stateStreamConf.MaxGlobalStreams, "global maximum number of concurrent streams") flags.UintVar(&builder.stateStreamConf.MaxExecutionDataMsgSize, "state_stream-max-message-size", defaultConfig.stateStreamConf.MaxExecutionDataMsgSize, "maximum size for a gRPC message containing block execution data") - flags.StringToIntVar(&builder.stateStreamFilterConf, "state_stream-event-filter-limits", defaultConfig.stateStreamFilterConf, "event filter limits for ExecutionData SubscribeEvents API e.g. EventTypes=100,Addresses=100,Contracts=100 etc.") - flags.DurationVar(&builder.stateStreamConf.ClientSendTimeout, "state_stream-send-timeout", defaultConfig.stateStreamConf.ClientSendTimeout, "maximum wait before timing out while sending a response to a streaming client e.g. 30s") - flags.UintVar(&builder.stateStreamConf.ClientSendBufferSize, "state_stream-send-buffer-size", defaultConfig.stateStreamConf.ClientSendBufferSize, "maximum number of responses to buffer within a stream") - flags.Float64Var(&builder.stateStreamConf.ResponseLimit, "state_stream-response-limit", defaultConfig.stateStreamConf.ResponseLimit, "max number of responses per second to send over streaming endpoints. this helps manage resources consumed by each client querying data not in the cache e.g. 3 or 0.5. 0 means no limit") + flags.StringToIntVar(&builder.stateStreamFilterConf, "state-stream-event-filter-limits", defaultConfig.stateStreamFilterConf, "event filter limits for ExecutionData SubscribeEvents API e.g. EventTypes=100,Addresses=100,Contracts=100 etc.") + flags.DurationVar(&builder.stateStreamConf.ClientSendTimeout, "state-stream-send-timeout", defaultConfig.stateStreamConf.ClientSendTimeout, "maximum wait before timing out while sending a response to a streaming client e.g. 30s") + flags.UintVar(&builder.stateStreamConf.ClientSendBufferSize, "state-stream-send-buffer-size", defaultConfig.stateStreamConf.ClientSendBufferSize, "maximum number of responses to buffer within a stream") + flags.Float64Var(&builder.stateStreamConf.ResponseLimit, "state-stream-response-limit", defaultConfig.stateStreamConf.ResponseLimit, "max number of responses per second to send over streaming endpoints. this helps manage resources consumed by each client querying data not in the cache e.g. 3 or 0.5. 0 means no limit") }).ValidateFlags(func() error { if builder.supportsObserver && (builder.PublicNetworkConfig.BindAddress == cmd.NotSet || builder.PublicNetworkConfig.BindAddress == "") { return errors.New("public-network-address must be set if supports-observer is true") @@ -745,23 +745,23 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { return errors.New("execution-data-cache-size must be greater than 0") } if builder.stateStreamConf.ClientSendBufferSize == 0 { - return errors.New("state_stream-send-buffer-size must be greater than 0") + return errors.New("state-stream-send-buffer-size must be greater than 0") } if len(builder.stateStreamFilterConf) > 3 { - return errors.New("state_stream-event-filter-limits must have at most 3 keys (EventTypes, Addresses, Contracts)") + return errors.New("state-stream-event-filter-limits must have at most 3 keys (EventTypes, Addresses, Contracts)") } for key, value := range builder.stateStreamFilterConf { switch key { case "EventTypes", "Addresses", "Contracts": if value <= 0 { - return fmt.Errorf("state_stream-event-filter-limits %s must be greater than 0", key) + return fmt.Errorf("state-stream-event-filter-limits %s must be greater than 0", key) } default: - return errors.New("state_stream-event-filter-limits may only contain the keys EventTypes, Addresses, Contracts") + return errors.New("state-stream-event-filter-limits may only contain the keys EventTypes, Addresses, Contracts") } } if builder.stateStreamConf.ResponseLimit < 0 { - return errors.New("state_stream-response-limit must be greater than or equal to 0") + return errors.New("state-stream-response-limit must be greater than or equal to 0") } } From 3040fe01205b5151008aafa10bc0a3924ad26f85 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 7 Jul 2023 16:46:24 +0300 Subject: [PATCH 03/35] Reverted back flag name --- cmd/access/node_builder/access_node_builder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index f1a1985ca84..b4be08edb34 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -714,7 +714,7 @@ func (builder *FlowAccessNodeBuilder) extraFlags() { // Execution State Streaming API flags.Uint32Var(&builder.stateStreamConf.ExecutionDataCacheSize, "execution-data-cache-size", defaultConfig.stateStreamConf.ExecutionDataCacheSize, "block execution data cache size") flags.Uint32Var(&builder.stateStreamConf.MaxGlobalStreams, "state-stream-global-max-streams", defaultConfig.stateStreamConf.MaxGlobalStreams, "global maximum number of concurrent streams") - flags.UintVar(&builder.stateStreamConf.MaxExecutionDataMsgSize, "state_stream-max-message-size", defaultConfig.stateStreamConf.MaxExecutionDataMsgSize, "maximum size for a gRPC message containing block execution data") + flags.UintVar(&builder.stateStreamConf.MaxExecutionDataMsgSize, "state-stream-max-message-size", defaultConfig.stateStreamConf.MaxExecutionDataMsgSize, "maximum size for a gRPC message containing block execution data") flags.StringToIntVar(&builder.stateStreamFilterConf, "state-stream-event-filter-limits", defaultConfig.stateStreamFilterConf, "event filter limits for ExecutionData SubscribeEvents API e.g. EventTypes=100,Addresses=100,Contracts=100 etc.") flags.DurationVar(&builder.stateStreamConf.ClientSendTimeout, "state-stream-send-timeout", defaultConfig.stateStreamConf.ClientSendTimeout, "maximum wait before timing out while sending a response to a streaming client e.g. 30s") flags.UintVar(&builder.stateStreamConf.ClientSendBufferSize, "state-stream-send-buffer-size", defaultConfig.stateStreamConf.ClientSendBufferSize, "maximum number of responses to buffer within a stream") From 82b29ab7dd647cd3f09cb39d1abd10a07aec1edb Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 7 Jul 2023 17:00:32 +0300 Subject: [PATCH 04/35] Added comments --- engine/access/rest/websocket_handler.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/engine/access/rest/websocket_handler.go b/engine/access/rest/websocket_handler.go index 3f93751d0da..4a64b1ad50b 100644 --- a/engine/access/rest/websocket_handler.go +++ b/engine/access/rest/websocket_handler.go @@ -19,9 +19,8 @@ type SubscribeHandlerFunc func( h *state_stream.SubscribeHandler, ) (interface{}, error) -// WSHandler is custom http handler implementing custom handler function. -// Handler function allows easier handling of errors and responses as it -// wraps functionality for handling error and responses outside of endpoint handling. +// WSHandler is websocket handler implementing custom handler function and allows easier handling of errors and +// responses as it wraps functionality for handling error and responses outside of endpoint handling. type WSHandler struct { *HttpHandler *state_stream.SubscribeHandler @@ -44,7 +43,7 @@ func NewWSHandler( return handler } -// ServerHTTP function acts as a wrapper to each request providing common handling functionality +// ServeHTTP function acts as a wrapper to each request providing common handling functionality // such as logging, error handling, request decorators func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // create a logger From 39e87684507c24a2530ae87b6011a3fa25019918 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 13 Jul 2023 15:57:06 +0300 Subject: [PATCH 05/35] Refactored handlers, added filters for subscribe_events endpoint --- .../node_builder/access_node_builder.go | 25 ++++++++++--------- cmd/observer/node_builder/observer_builder.go | 3 ++- .../access/rest/request/subscribe_events.go | 12 +++++++-- engine/access/rest/router.go | 4 +-- engine/access/rest/server.go | 4 +-- engine/access/rest/subscribe_events.go | 5 ++-- engine/access/rest/test_helpers.go | 22 ++++++++++------ engine/access/rest/websocket_handler.go | 4 +-- engine/access/rest_api_test.go | 4 +++ engine/access/rpc/engine.go | 15 ++++++----- engine/access/rpc/rate_limit_test.go | 3 ++- engine/access/secure_grpcr_test.go | 4 +++ engine/access/state_stream/handler.go | 4 +-- .../common/state_stream/subscribe_handler.go | 4 +-- 14 files changed, 70 insertions(+), 43 deletions(-) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index b4be08edb34..0225482497a 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -43,7 +43,7 @@ import ( "github.com/onflow/flow-go/engine/access/state_stream" followereng "github.com/onflow/flow-go/engine/common/follower" "github.com/onflow/flow-go/engine/common/requester" - cstatestream "github.com/onflow/flow-go/engine/common/state_stream" + common_state_stream "github.com/onflow/flow-go/engine/common/state_stream" synceng "github.com/onflow/flow-go/engine/common/synchronization" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/model/flow/filter" @@ -117,8 +117,8 @@ type AccessNodeConfig struct { apiRatelimits map[string]int apiBurstlimits map[string]int rpcConf rpc.Config - stateStreamConf cstatestream.Config - stateStreamBackend cstatestream.API + stateStreamConf common_state_stream.Config + stateStreamBackend common_state_stream.API stateStreamFilterConf map[string]int ExecutionNodeAddress string // deprecated HistoricalAccessRPCs []access.AccessAPIClient @@ -164,14 +164,14 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { ArchiveAddressList: nil, MaxMsgSize: grpcutils.DefaultMaxMsgSize, }, - stateStreamConf: cstatestream.Config{ + stateStreamConf: common_state_stream.Config{ MaxExecutionDataMsgSize: grpcutils.DefaultMaxMsgSize, - ExecutionDataCacheSize: cstatestream.DefaultCacheSize, - ClientSendTimeout: cstatestream.DefaultSendTimeout, - ClientSendBufferSize: cstatestream.DefaultSendBufferSize, - MaxGlobalStreams: cstatestream.DefaultMaxGlobalStreams, - EventFilterConfig: cstatestream.DefaultEventFilterConfig, - ResponseLimit: cstatestream.DefaultResponseLimit, + ExecutionDataCacheSize: common_state_stream.DefaultCacheSize, + ClientSendTimeout: common_state_stream.DefaultSendTimeout, + ClientSendBufferSize: common_state_stream.DefaultSendBufferSize, + MaxGlobalStreams: common_state_stream.DefaultMaxGlobalStreams, + EventFilterConfig: common_state_stream.DefaultEventFilterConfig, + ResponseLimit: common_state_stream.DefaultResponseLimit, }, stateStreamBackend: nil, stateStreamFilterConf: nil, @@ -606,7 +606,7 @@ func (builder *FlowAccessNodeBuilder) BuildExecutionDataRequester() *FlowAccessN broadcaster := engine.NewBroadcaster() - backend, err := cstatestream.New( + backend, err := common_state_stream.New( node.Logger, builder.stateStreamConf, node.State, @@ -1032,7 +1032,8 @@ func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { builder.apiBurstlimits, builder.Me, builder.stateStreamBackend, - builder.stateStreamConf, + builder.stateStreamConf.EventFilterConfig, + builder.stateStreamConf.MaxGlobalStreams, ) if err != nil { return nil, err diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 41582e9f24a..5d25b7b7176 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -864,7 +864,8 @@ func (builder *ObserverServiceBuilder) enqueueRPCServer() { builder.apiBurstlimits, builder.Me, nil, - state_stream.Config{}, + state_stream.DefaultEventFilterConfig, + 0, ) if err != nil { return nil, err diff --git a/engine/access/rest/request/subscribe_events.go b/engine/access/rest/request/subscribe_events.go index d10efcf76bc..efe6517c102 100644 --- a/engine/access/rest/request/subscribe_events.go +++ b/engine/access/rest/request/subscribe_events.go @@ -8,13 +8,16 @@ import ( const startBlockIdQuery = "start_block_id" const eventTypesQuery = "event_types" +const addressesQuery = "addresses" +const contractsQuery = "contracts" type SubscribeEvents struct { StartBlockID flow.Identifier StartHeight uint64 EventTypes []string - //Uliana: TODO: add events filter - contracts, addresses + Addresses []string + Contracts []string } func (g *SubscribeEvents) Build(r *Request) error { @@ -22,10 +25,12 @@ func (g *SubscribeEvents) Build(r *Request) error { r.GetQueryParam(startBlockIdQuery), r.GetQueryParam(startHeightQuery), r.GetQueryParams(eventTypesQuery), + r.GetQueryParams(addressesQuery), + r.GetQueryParams(contractsQuery), ) } -func (g *SubscribeEvents) Parse(rawBlockID string, rawStart string, rawTypes []string) error { +func (g *SubscribeEvents) Parse(rawBlockID string, rawStart string, rawTypes []string, rawAddresses []string, rawContracts []string) error { var height Height err := height.Parse(rawStart) if err != nil { @@ -53,5 +58,8 @@ func (g *SubscribeEvents) Parse(rawBlockID string, rawStart string, rawTypes []s g.EventTypes = eventTypes.Flow() + g.Addresses = rawAddresses + g.Contracts = rawContracts + return nil } diff --git a/engine/access/rest/router.go b/engine/access/rest/router.go index afbcd252ffa..2e800e141ef 100644 --- a/engine/access/rest/router.go +++ b/engine/access/rest/router.go @@ -22,7 +22,7 @@ func newRouter(backend access.API, chain flow.Chain, restCollector module.RestMetrics, stateStreamApi state_stream.API, - conf state_stream.EventFilterConfig, + eventFilterConfig state_stream.EventFilterConfig, maxGlobalStreams uint32, ) (*mux.Router, error) { router := mux.NewRouter().StrictSlash(true) @@ -46,7 +46,7 @@ func newRouter(backend access.API, } for _, r := range WSRoutes { - h := NewWSHandler(logger, r.Handler, chain, stateStreamApi, conf, maxGlobalStreams) + h := NewWSHandler(logger, r.Handler, chain, stateStreamApi, eventFilterConfig, maxGlobalStreams) v1SubRouter. Methods(r.Method). Path(r.Pattern). diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 51322246215..68be80f47b4 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -20,11 +20,11 @@ func NewServer(backend access.API, chain flow.Chain, restCollector module.RestMetrics, api state_stream.API, - conf state_stream.EventFilterConfig, + eventFilterConfig state_stream.EventFilterConfig, maxGlobalStreams uint32, ) (*http.Server, error) { - router, err := newRouter(backend, logger, chain, restCollector, api, conf, maxGlobalStreams) + router, err := newRouter(backend, logger, chain, restCollector, api, eventFilterConfig, maxGlobalStreams) if err != nil { return nil, err } diff --git a/engine/access/rest/subscribe_events.go b/engine/access/rest/subscribe_events.go index b437cefc4e3..cf9a20cf0fa 100644 --- a/engine/access/rest/subscribe_events.go +++ b/engine/access/rest/subscribe_events.go @@ -31,13 +31,12 @@ func SubscribeEvents(r *request.Request, w http.ResponseWriter, h *state_stream. var filter state_stream.EventFilter // Retrieve the filter parameters from the request, if provided - emptyString := make([]string, 0) //Uliana: TODO: remove it filter, err = state_stream.NewEventFilter( h.EventFilterConfig, r.Chain, req.EventTypes, - emptyString, //Uliana: TODO: parsed addresses - emptyString, //Uliana: TODO: parsed contracts + req.Addresses, + req.Contracts, ) if err != nil { err = fmt.Errorf("invalid event filter: %s", err) diff --git a/engine/access/rest/test_helpers.go b/engine/access/rest/test_helpers.go index c188c8a0ebe..66f8da148ca 100644 --- a/engine/access/rest/test_helpers.go +++ b/engine/access/rest/test_helpers.go @@ -12,7 +12,8 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/flow-go/access/mock" - "github.com/onflow/flow-go/engine/common/state_stream" + mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" + common_state_stream "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/metrics" ) @@ -27,16 +28,21 @@ const ( heightQueryParam = "height" ) -func executeRequest(req *http.Request, backend *mock.API) (*httptest.ResponseRecorder, error) { +func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_state_stream.API) (*httptest.ResponseRecorder, error) { var b bytes.Buffer logger := zerolog.New(&b) restCollector := metrics.NewNoopCollector() - stateStreamConfig := state_stream.Config{} //Uliana: TODO: add test for subscribe_events, provide state_stream backend + + stateStreamConfig := common_state_stream.Config{ + EventFilterConfig: common_state_stream.DefaultEventFilterConfig, + MaxGlobalStreams: common_state_stream.DefaultMaxGlobalStreams, + } + router, err := newRouter(backend, logger, flow.Testnet.Chain(), restCollector, - nil, + stateStreamApi, stateStreamConfig.EventFilterConfig, stateStreamConfig.MaxGlobalStreams) if err != nil { @@ -48,12 +54,12 @@ func executeRequest(req *http.Request, backend *mock.API) (*httptest.ResponseRec return rr, nil } -func assertOKResponse(t *testing.T, req *http.Request, expectedRespBody string, backend *mock.API) { - assertResponse(t, req, http.StatusOK, expectedRespBody, backend) +func assertOKResponse(t *testing.T, req *http.Request, expectedRespBody string, backend *mock.API, stateStreamApi *mock_state_stream.API) { + assertResponse(t, req, http.StatusOK, expectedRespBody, backend, stateStreamApi) } -func assertResponse(t *testing.T, req *http.Request, status int, expectedRespBody string, backend *mock.API) { - rr, err := executeRequest(req, backend) +func assertResponse(t *testing.T, req *http.Request, status int, expectedRespBody string, backend *mock.API, stateStreamApi *mock_state_stream.API) { + rr, err := executeRequest(req, backend, stateStreamApi) assert.NoError(t, err) actualResponseBody := rr.Body.String() diff --git a/engine/access/rest/websocket_handler.go b/engine/access/rest/websocket_handler.go index 4a64b1ad50b..5fa8e2ce906 100644 --- a/engine/access/rest/websocket_handler.go +++ b/engine/access/rest/websocket_handler.go @@ -32,14 +32,14 @@ func NewWSHandler( subscribeFunc SubscribeHandlerFunc, chain flow.Chain, api state_stream.API, - conf state_stream.EventFilterConfig, + eventFilterConfig state_stream.EventFilterConfig, maxGlobalStreams uint32, ) *WSHandler { handler := &WSHandler{ subscribeFunc: subscribeFunc, } handler.HttpHandler = NewHttpHandler(logger, chain) - handler.SubscribeHandler = state_stream.NewSubscribeHandler(api, chain, conf, maxGlobalStreams) + handler.SubscribeHandler = state_stream.NewSubscribeHandler(api, chain, eventFilterConfig, maxGlobalStreams) return handler } diff --git a/engine/access/rest_api_test.go b/engine/access/rest_api_test.go index 5ee8f6d9730..de1d1b07e31 100644 --- a/engine/access/rest_api_test.go +++ b/engine/access/rest_api_test.go @@ -22,6 +22,7 @@ import ( "github.com/onflow/flow-go/engine/access/rest" "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/access/rpc" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" @@ -139,6 +140,9 @@ func (suite *RestAPITestSuite) SetupTest() { nil, nil, suite.me, + nil, + state_stream.DefaultEventFilterConfig, + 0, ) assert.NoError(suite.T(), err) suite.rpcEng, err = rpcEngBuilder.WithLegacy().Build() diff --git a/engine/access/rpc/engine.go b/engine/access/rpc/engine.go index 82947be9f47..8bacf752cf9 100644 --- a/engine/access/rpc/engine.go +++ b/engine/access/rpc/engine.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/onflow/flow-go/engine/common/state_stream" "net" "net/http" "sync" @@ -21,6 +20,7 @@ import ( "github.com/onflow/flow-go/engine/access/rest" "github.com/onflow/flow-go/engine/access/rpc/backend" "github.com/onflow/flow-go/engine/common/rpc" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/module/component" @@ -77,7 +77,8 @@ type Engine struct { restAPIAddress net.Addr stateStreamBackend state_stream.API - stateStreamConfig state_stream.Config + eventFilterConfig state_stream.EventFilterConfig + maxGlobalStreams uint32 } // NewBuilder returns a new RPC engine builder. @@ -102,7 +103,8 @@ func NewBuilder(log zerolog.Logger, apiBurstLimits map[string]int, // the api burst limit (max calls at the same time) for each of the Access API e.g. Ping->50, GetTransaction->10 me module.Local, stateStreamBackend state_stream.API, - stateStreamConfig state_stream.Config, + eventFilterConfig state_stream.EventFilterConfig, + maxGlobalStreams uint32, ) (*RPCEngineBuilder, error) { log = log.With().Str("engine", "rpc").Logger() @@ -217,7 +219,8 @@ func NewBuilder(log zerolog.Logger, chain: chainID.Chain(), restCollector: accessMetrics, stateStreamBackend: stateStreamBackend, - stateStreamConfig: stateStreamConfig, + eventFilterConfig: eventFilterConfig, + maxGlobalStreams: maxGlobalStreams, } backendNotifierActor, backendNotifierWorker := events.NewFinalizationActor(eng.notifyBackendOnBlockFinalized) eng.backendNotifierActor = backendNotifierActor @@ -398,8 +401,8 @@ func (e *Engine) serveREST(ctx irrecoverable.SignalerContext, ready component.Re e.chain, e.restCollector, e.stateStreamBackend, - e.stateStreamConfig.EventFilterConfig, - e.stateStreamConfig.MaxGlobalStreams) + e.eventFilterConfig, + e.maxGlobalStreams) if err != nil { e.log.Err(err).Msg("failed to initialize the REST server") ctx.Throw(err) diff --git a/engine/access/rpc/rate_limit_test.go b/engine/access/rpc/rate_limit_test.go index 3cce6e97fda..5258e679a1b 100644 --- a/engine/access/rpc/rate_limit_test.go +++ b/engine/access/rpc/rate_limit_test.go @@ -20,6 +20,7 @@ import ( "google.golang.org/grpc/status" accessmock "github.com/onflow/flow-go/engine/access/mock" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" @@ -118,7 +119,7 @@ func (suite *RateLimitTestSuite) SetupTest() { suite.snapshot.On("Head").Return(block, nil) rpcEngBuilder, err := NewBuilder(suite.log, suite.state, config, suite.collClient, nil, suite.blocks, suite.headers, suite.collections, suite.transactions, nil, - nil, suite.chainID, suite.metrics, 0, 0, false, false, apiRateLimt, apiBurstLimt, suite.me) + nil, suite.chainID, suite.metrics, 0, 0, false, false, apiRateLimt, apiBurstLimt, suite.me, nil, state_stream.DefaultEventFilterConfig, 0) require.NoError(suite.T(), err) suite.rpcEng, err = rpcEngBuilder.WithLegacy().Build() require.NoError(suite.T(), err) diff --git a/engine/access/secure_grpcr_test.go b/engine/access/secure_grpcr_test.go index b82160668db..ba25b33a7cd 100644 --- a/engine/access/secure_grpcr_test.go +++ b/engine/access/secure_grpcr_test.go @@ -18,6 +18,7 @@ import ( "github.com/onflow/flow-go/crypto" accessmock "github.com/onflow/flow-go/engine/access/mock" "github.com/onflow/flow-go/engine/access/rpc" + "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" @@ -129,6 +130,9 @@ func (suite *SecureGRPCTestSuite) SetupTest() { nil, nil, suite.me, + nil, + state_stream.DefaultEventFilterConfig, + 0, ) assert.NoError(suite.T(), err) suite.rpcEng, err = rpcEngBuilder.WithLegacy().Build() diff --git a/engine/access/state_stream/handler.go b/engine/access/state_stream/handler.go index bf656b2d994..b9a5b8b55ca 100644 --- a/engine/access/state_stream/handler.go +++ b/engine/access/state_stream/handler.go @@ -19,9 +19,9 @@ type Handler struct { *state_stream.SubscribeHandler } -func NewHandler(api state_stream.API, chain flow.Chain, conf state_stream.EventFilterConfig, maxGlobalStreams uint32) *Handler { +func NewHandler(api state_stream.API, chain flow.Chain, config state_stream.EventFilterConfig, maxGlobalStreams uint32) *Handler { h := &Handler{} - h.SubscribeHandler = state_stream.NewSubscribeHandler(api, chain, conf, maxGlobalStreams) + h.SubscribeHandler = state_stream.NewSubscribeHandler(api, chain, config, maxGlobalStreams) return h } diff --git a/engine/common/state_stream/subscribe_handler.go b/engine/common/state_stream/subscribe_handler.go index 1e8da756225..621e8b9555d 100644 --- a/engine/common/state_stream/subscribe_handler.go +++ b/engine/common/state_stream/subscribe_handler.go @@ -20,11 +20,11 @@ type SubscribeHandler struct { StreamCount atomic.Int32 } -func NewSubscribeHandler(api API, chain flow.Chain, conf EventFilterConfig, maxGlobalStreams uint32) *SubscribeHandler { +func NewSubscribeHandler(api API, chain flow.Chain, config EventFilterConfig, maxGlobalStreams uint32) *SubscribeHandler { h := &SubscribeHandler{ Api: api, Chain: chain, - EventFilterConfig: conf, + EventFilterConfig: config, MaxStreams: int32(maxGlobalStreams), StreamCount: atomic.Int32{}, } From 6a68c015b2c103ade611124ba343c0a6ff069490 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 24 Jul 2023 18:25:39 +0300 Subject: [PATCH 06/35] Fixed flags --- integration/localnet/builder/bootstrap.go | 4 ++-- integration/testnet/network.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integration/localnet/builder/bootstrap.go b/integration/localnet/builder/bootstrap.go index e77ee92b004..dc54dba8cdb 100644 --- a/integration/localnet/builder/bootstrap.go +++ b/integration/localnet/builder/bootstrap.go @@ -414,7 +414,7 @@ func prepareAccessService(container testnet.ContainerConfig, i int, n int) Servi fmt.Sprintf("--secure-rpc-addr=%s:%s", container.ContainerName, testnet.GRPCSecurePort), fmt.Sprintf("--http-addr=%s:%s", container.ContainerName, testnet.GRPCWebPort), fmt.Sprintf("--rest-addr=%s:%s", container.ContainerName, testnet.RESTPort), - fmt.Sprintf("--state_stream-addr=%s:%s", container.ContainerName, testnet.ExecutionStatePort), + fmt.Sprintf("--state-stream-addr=%s:%s", container.ContainerName, testnet.ExecutionStatePort), fmt.Sprintf("--collection-ingress-port=%s", testnet.GRPCPort), "--supports-observer=true", fmt.Sprintf("--public-network-address=%s:%s", container.ContainerName, testnet.PublicNetworkPort), @@ -423,7 +423,7 @@ func prepareAccessService(container testnet.ContainerConfig, i int, n int) Servi "--log-tx-time-to-finalized-executed", "--execution-data-sync-enabled=true", "--execution-data-dir=/data/execution-data", - fmt.Sprintf("--state_stream-addr=%s:%s", container.ContainerName, testnet.ExecutionStatePort), + fmt.Sprintf("--state-stream-addr=%s:%s", container.ContainerName, testnet.ExecutionStatePort), ) service.AddExposedPorts( diff --git a/integration/testnet/network.go b/integration/testnet/network.go index f389e97a54b..8b9522d7ba6 100644 --- a/integration/testnet/network.go +++ b/integration/testnet/network.go @@ -876,7 +876,7 @@ func (net *FlowNetwork) AddNode(t *testing.T, bootstrapDir string, nodeConf Cont nodeContainer.AddFlag("rest-addr", nodeContainer.ContainerAddr(RESTPort)) nodeContainer.exposePort(ExecutionStatePort, testingdock.RandomPort(t)) - nodeContainer.AddFlag("state_stream-addr", nodeContainer.ContainerAddr(ExecutionStatePort)) + nodeContainer.AddFlag("state-stream-addr", nodeContainer.ContainerAddr(ExecutionStatePort)) // uncomment line below to point the access node exclusively to a single collection node // nodeContainer.AddFlag("static-collection-ingress-addr", "collection_1:9000") From d6f51da79935dc2a1fc6b6972cd0a0820dddb463 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 24 Jul 2023 18:44:53 +0300 Subject: [PATCH 07/35] Updated tests --- .../node_builder/access_node_builder.go | 1 + engine/access/rest/routes/accounts_test.go | 10 +++++- engine/access/rest/routes/blocks_test.go | 24 +++++++++---- .../access/rest/routes/transactions_test.go | 36 ++----------------- 4 files changed, 29 insertions(+), 42 deletions(-) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index a3695c21ed7..9d796372431 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -181,6 +181,7 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { stateStreamBackend: nil, stateStreamFilterConf: nil, execDataBroadcaster: nil, + executionDataCache: nil, ExecutionNodeAddress: "localhost:9000", logTxTimeToFinalized: false, logTxTimeToExecuted: false, diff --git a/engine/access/rest/routes/accounts_test.go b/engine/access/rest/routes/accounts_test.go index 326c35b96d2..a1803561ed3 100644 --- a/engine/access/rest/routes/accounts_test.go +++ b/engine/access/rest/routes/accounts_test.go @@ -35,7 +35,15 @@ func accountURL(t *testing.T, address string, height string) string { return u.String() } -func TestGetAccount(t *testing.T) { +// TestAccessGetAccount tests local getAccount request. +// +// Runs the following tests: +// 1. Get account by address at latest sealed block. +// 2. Get account by address at latest finalized block. +// 3. Get account by address at height. +// 4. Get account by address at height condensed. +// 5. Get invalid account. +func TestAccessGetAccount(t *testing.T) { backend := &mock.API{} stateStreamBackend := &mock_state_stream.API{} diff --git a/engine/access/rest/routes/blocks_test.go b/engine/access/rest/routes/blocks_test.go index fcdcc829129..43973958854 100644 --- a/engine/access/rest/routes/blocks_test.go +++ b/engine/access/rest/routes/blocks_test.go @@ -32,13 +32,12 @@ type testVector struct { expectedResponse string } -// TestGetBlocks tests the get blocks by ID and get blocks by heights API -func TestGetBlocks(t *testing.T) { - backend := &mock.API{} - stateStreamBackend := &mock_state_stream.API{} - - blkCnt := 10 - blockIDs, heights, blocks, executionResults := generateMocks(backend, blkCnt) +func prepareTestVectors(t *testing.T, + blockIDs []string, + heights []string, + blocks []*flow.Block, + executionResults []*flow.ExecutionResult, + blkCnt int) []testVector { singleBlockExpandedResponse := expectedBlockResponsesExpanded(blocks[:1], executionResults[:1], true, flow.BlockStatusUnknown) singleSealedBlockExpandedResponse := expectedBlockResponsesExpanded(blocks[:1], executionResults[:1], true, flow.BlockStatusSealed) @@ -139,6 +138,17 @@ func TestGetBlocks(t *testing.T) { expectedResponse: fmt.Sprintf(`{"code":400, "message": "at most %d IDs can be requested at a time"}`, request.MaxBlockRequestHeightRange), }, } + return testVectors +} + +// TestGetBlocks tests local get blocks by ID and get blocks by heights API +func TestAccessGetBlocks(t *testing.T) { + backend := &mock.API{} + stateStreamBackend := &mock_state_stream.API{} + + blkCnt := 10 + blockIDs, heights, blocks, executionResults := generateMocks(backend, blkCnt) + testVectors := prepareTestVectors(t, blockIDs, heights, blocks, executionResults, blkCnt) for _, tv := range testVectors { responseRec, err := executeRequest(tv.request, backend, stateStreamBackend) diff --git a/engine/access/rest/routes/transactions_test.go b/engine/access/rest/routes/transactions_test.go index 14f5924960e..0accd2dd72c 100644 --- a/engine/access/rest/routes/transactions_test.go +++ b/engine/access/rest/routes/transactions_test.go @@ -70,38 +70,6 @@ func createTransactionReq(body interface{}) *http.Request { return req } -func validCreateBody(tx flow.TransactionBody) map[string]interface{} { - tx.Arguments = [][]uint8{} // fix how fixture creates nil values - auth := make([]string, len(tx.Authorizers)) - for i, a := range tx.Authorizers { - auth[i] = a.String() - } - - return map[string]interface{}{ - "script": util.ToBase64(tx.Script), - "arguments": tx.Arguments, - "reference_block_id": tx.ReferenceBlockID.String(), - "gas_limit": fmt.Sprintf("%d", tx.GasLimit), - "payer": tx.Payer.String(), - "proposal_key": map[string]interface{}{ - "address": tx.ProposalKey.Address.String(), - "key_index": fmt.Sprintf("%d", tx.ProposalKey.KeyIndex), - "sequence_number": fmt.Sprintf("%d", tx.ProposalKey.SequenceNumber), - }, - "authorizers": auth, - "payload_signatures": []map[string]interface{}{{ - "address": tx.PayloadSignatures[0].Address.String(), - "key_index": fmt.Sprintf("%d", tx.PayloadSignatures[0].KeyIndex), - "signature": util.ToBase64(tx.PayloadSignatures[0].Signature), - }}, - "envelope_signatures": []map[string]interface{}{{ - "address": tx.EnvelopeSignatures[0].Address.String(), - "key_index": fmt.Sprintf("%d", tx.EnvelopeSignatures[0].KeyIndex), - "signature": util.ToBase64(tx.EnvelopeSignatures[0].Signature), - }}, - } -} - func TestGetTransactions(t *testing.T) { stateStreamBackend := &mock_state_stream.API{} @@ -377,7 +345,7 @@ func TestCreateTransaction(t *testing.T) { tx := unittest.TransactionBodyFixture() tx.PayloadSignatures = []flow.TransactionSignature{unittest.TransactionSignatureFixture()} tx.Arguments = [][]uint8{} - req := createTransactionReq(validCreateBody(tx)) + req := createTransactionReq(unittest.CreateSendTxHttpPayload(tx)) backend.Mock. On("SendTransaction", mocks.Anything, &tx). @@ -445,7 +413,7 @@ func TestCreateTransaction(t *testing.T) { for _, test := range tests { tx := unittest.TransactionBodyFixture() tx.PayloadSignatures = []flow.TransactionSignature{unittest.TransactionSignatureFixture()} - testTx := validCreateBody(tx) + testTx := unittest.CreateSendTxHttpPayload(tx) testTx[test.inputField] = test.inputValue req := createTransactionReq(testTx) From 6f298e65bac5c7723004a44b6f09dd6122e1236b Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 24 Jul 2023 19:16:29 +0300 Subject: [PATCH 08/35] Added test, added Hijack impl for response writer in metrics --- engine/access/rest/middleware/logging.go | 14 ++++++++++++++ engine/access/rest/routes/events_test.go | 10 +++++----- engine/access/rest/routes/router_test.go | 10 ++++++++++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/engine/access/rest/middleware/logging.go b/engine/access/rest/middleware/logging.go index 577843e2c86..9c524681562 100644 --- a/engine/access/rest/middleware/logging.go +++ b/engine/access/rest/middleware/logging.go @@ -1,6 +1,9 @@ package middleware import ( + "bufio" + "fmt" + "net" "net/http" "time" @@ -40,6 +43,9 @@ type responseWriter struct { statusCode int } +// http.Hijacker necessary for upgrading gorilla websocket connection for "subscribe_events" route. +var _ http.Hijacker = (*responseWriter)(nil) + func newResponseWriter(w http.ResponseWriter) *responseWriter { return &responseWriter{w, http.StatusOK} } @@ -48,3 +54,11 @@ func (rw *responseWriter) WriteHeader(code int) { rw.statusCode = code rw.ResponseWriter.WriteHeader(code) } + +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("hijacking not supported") + } + return hijacker.Hijack() +} diff --git a/engine/access/rest/routes/events_test.go b/engine/access/rest/routes/events_test.go index ee783550a65..305c330e9b7 100644 --- a/engine/access/rest/routes/events_test.go +++ b/engine/access/rest/routes/events_test.go @@ -9,16 +9,16 @@ import ( "testing" "time" + mocks "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/util" mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" - - mocks "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) func TestGetEvents(t *testing.T) { diff --git a/engine/access/rest/routes/router_test.go b/engine/access/rest/routes/router_test.go index e3c2a2c3fdd..ed73e71a230 100644 --- a/engine/access/rest/routes/router_test.go +++ b/engine/access/rest/routes/router_test.go @@ -84,6 +84,11 @@ func TestParseURL(t *testing.T) { url: "/v1/node_version_info", expected: "getNodeVersionInfo", }, + { + name: "/v1/subscribe_events", + url: "/v1/subscribe_events", + expected: "subscribeEvents", + }, } for _, tt := range tests { @@ -171,6 +176,11 @@ func TestBenchmarkParseURL(t *testing.T) { url: "/v1/node_version_info", expected: "getNodeVersionInfo", }, + { + name: "/v1/subscribe_events", + url: "/v1/subscribe_events", + expected: "subscribeEvents", + }, } for _, tt := range tests { From feb6ddad22a302617fc98673bd1ce20b72de35ed Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 28 Jul 2023 09:51:59 +0300 Subject: [PATCH 09/35] Updated subscribe_events rest route, remove subscribe_handler --- .../node_builder/access_node_builder.go | 151 +++++++++--------- cmd/observer/node_builder/observer_builder.go | 1 + .../access/rest/request/subscribe_events.go | 25 +-- engine/access/rest/routes/http_handler.go | 5 + engine/access/rest/routes/router.go | 19 ++- engine/access/rest/routes/subscribe_events.go | 148 ++++++++++++----- engine/access/rest/routes/test_helpers.go | 80 ++++++++-- .../access/rest/routes/websocket_handler.go | 52 +++--- engine/access/rpc/engine.go | 9 +- engine/access/state_stream/handler.go | 47 ++++-- engine/common/state_stream/backend.go | 5 +- .../common/state_stream/subscribe_handler.go | 43 ----- integration/tests/access/observer_test.go | 2 +- 13 files changed, 357 insertions(+), 230 deletions(-) delete mode 100644 engine/common/state_stream/subscribe_handler.go diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index 9d796372431..217f897a77c 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -18,7 +18,6 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - "github.com/onflow/flow/protobuf/go/flow/access" "github.com/onflow/go-bitswap" "github.com/onflow/flow-go/admin/commands" @@ -89,6 +88,8 @@ import ( "github.com/onflow/flow-go/storage" bstorage "github.com/onflow/flow-go/storage/badger" "github.com/onflow/flow-go/utils/grpcutils" + + "github.com/onflow/flow/protobuf/go/flow/access" ) // AccessNodeBuilder extends cmd.NodeBuilder and declares additional functions needed to bootstrap an Access node. @@ -120,8 +121,6 @@ type AccessNodeConfig struct { rpcConf rpc.Config stateStreamBackend *cstate_stream.StateStreamBackend stateStreamConf cstate_stream.Config - execDataBroadcaster *engine.Broadcaster - executionDataCache *execdatacache.ExecutionDataCache stateStreamFilterConf map[string]int ExecutionNodeAddress string // deprecated HistoricalAccessRPCs []access.AccessAPIClient @@ -180,8 +179,6 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { }, stateStreamBackend: nil, stateStreamFilterConf: nil, - execDataBroadcaster: nil, - executionDataCache: nil, ExecutionNodeAddress: "localhost:9000", logTxTimeToFinalized: false, logTxTimeToExecuted: false, @@ -436,13 +433,14 @@ func (builder *FlowAccessNodeBuilder) BuildConsensusFollower() *FlowAccessNodeBu return builder } -func (builder *FlowAccessNodeBuilder) BuildExecutionDataRequester() *FlowAccessNodeBuilder { +func (builder *FlowAccessNodeBuilder) BuildStateStreamPool() *FlowAccessNodeBuilder { var ds *badger.Datastore var bs network.BlobService var processedBlockHeight storage.ConsumerProgress var processedNotifications storage.ConsumerProgress var bsDependable *module.ProxiedReadyDoneAware var execDataDistributor *edrequester.ExecutionDataDistributor + var execDataCacheBackend *herocache.BlockExecutionData builder. AdminCommand("read-execution-data", func(config *cmd.NodeConfig) commands.AdminCommand { @@ -551,11 +549,27 @@ func (builder *FlowAccessNodeBuilder) BuildExecutionDataRequester() *FlowAccessN execDataDistributor = edrequester.NewExecutionDataDistributor() + var heroCacheCollector module.HeroCacheMetrics = metrics.NewNoopCollector() + if builder.HeroCacheMetricsEnable { + heroCacheCollector = metrics.AccessNodeExecutionDataCacheMetrics(builder.MetricsRegisterer) + } + + execDataCacheBackend = herocache.NewBlockExecutionData(builder.stateStreamConf.ExecutionDataCacheSize, builder.Logger, heroCacheCollector) + // Execution Data cache with a downloader as the backend. This is used by the requester + // to download and cache execution data for each block. + executionDataCache := execdatacache.NewExecutionDataCache( + builder.ExecutionDataDownloader, + builder.Storage.Headers, + builder.Storage.Seals, + builder.Storage.Results, + execDataCacheBackend, + ) + builder.ExecutionDataRequester = edrequester.New( builder.Logger, metrics.NewExecutionDataRequesterCollector(), builder.ExecutionDataDownloader, - builder.executionDataCache, + executionDataCache, processedBlockHeight, processedNotifications, builder.State, @@ -571,16 +585,61 @@ func (builder *FlowAccessNodeBuilder) BuildExecutionDataRequester() *FlowAccessN if builder.stateStreamConf.ListenAddr != "" { builder.Component("exec state stream engine", func(node *cmd.NodeConfig) (module.ReadyDoneAware, error) { + for key, value := range builder.stateStreamFilterConf { + switch key { + case "EventTypes": + builder.stateStreamConf.MaxEventTypes = value + case "Addresses": + builder.stateStreamConf.MaxAddresses = value + case "Contracts": + builder.stateStreamConf.MaxContracts = value + } + } + builder.stateStreamConf.RpcMetricsEnabled = builder.rpcMetricsEnabled + + // Execution Data cache that uses a blobstore as the backend (instead of a downloader) + // This ensures that it simply returns a not found error if the blob doesn't exist + // instead of attempting to download it from the network. It shares a cache backend instance + // with the requester's implementation. + executionDataCache := execdatacache.NewExecutionDataCache( + builder.ExecutionDataStore, + builder.Storage.Headers, + builder.Storage.Seals, + builder.Storage.Results, + execDataCacheBackend, + ) + + highestAvailableHeight, err := builder.ExecutionDataRequester.HighestConsecutiveHeight() + if err != nil { + return nil, fmt.Errorf("could not get highest consecutive height: %w", err) + } + broadcaster := engine.NewBroadcaster() + + builder.stateStreamBackend, err = cstate_stream.New(node.Logger, + builder.stateStreamConf, + node.State, + node.Storage.Headers, + node.Storage.Seals, + node.Storage.Results, + builder.ExecutionDataStore, + executionDataCache, + broadcaster, + builder.executionDataConfig.InitialBlockHeight, + highestAvailableHeight) + + if err != nil { + return nil, fmt.Errorf("could not create state stream backend: %w", err) + } stateStreamEng, err := state_stream.NewEng( node.Logger, builder.stateStreamConf, - builder.executionDataCache, + executionDataCache, node.Storage.Headers, node.RootChainID, builder.stateStreamGrpcServer, builder.stateStreamBackend, - builder.execDataBroadcaster, + broadcaster, ) if err != nil { return nil, fmt.Errorf("could not create state stream engine: %w", err) @@ -859,6 +918,10 @@ func (builder *FlowAccessNodeBuilder) enqueueRelayNetwork() { } func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { + if builder.executionDataSyncEnabled { + builder.BuildStateStreamPool() + } + builder. BuildConsensusFollower(). Module("collection node client", func(node *cmd.NodeConfig) error { @@ -994,72 +1057,6 @@ func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { return nil }). - Module("state stream backend", func(node *cmd.NodeConfig) error { - for key, value := range builder.stateStreamFilterConf { - switch key { - case "EventTypes": - builder.stateStreamConf.MaxEventTypes = value - case "Addresses": - builder.stateStreamConf.MaxAddresses = value - case "Contracts": - builder.stateStreamConf.MaxContracts = value - } - } - builder.stateStreamConf.RpcMetricsEnabled = builder.rpcMetricsEnabled - - var heroCacheCollector module.HeroCacheMetrics = metrics.NewNoopCollector() - if builder.HeroCacheMetricsEnable { - heroCacheCollector = metrics.AccessNodeExecutionDataCacheMetrics(builder.MetricsRegisterer) - } - - execDataCacheBackend := herocache.NewBlockExecutionData(builder.stateStreamConf.ExecutionDataCacheSize, builder.Logger, heroCacheCollector) - // Execution Data cache with a downloader as the backend. This is used by the requester - // to download and cache execution data for each block. - builder.executionDataCache = execdatacache.NewExecutionDataCache( - builder.ExecutionDataDownloader, - builder.Storage.Headers, - builder.Storage.Seals, - builder.Storage.Results, - execDataCacheBackend, - ) - - // Execution Data cache that uses a blobstore as the backend (instead of a downloader) - // This ensures that it simply returns a not found error if the blob doesn't exist - // instead of attempting to download it from the network. It shares a cache backend instance - // with the requester's implementation. - builder.executionDataCache = execdatacache.NewExecutionDataCache( - builder.ExecutionDataStore, - builder.Storage.Headers, - builder.Storage.Seals, - builder.Storage.Results, - execDataCacheBackend, - ) - - highestAvailableHeight, err := builder.ExecutionDataRequester.HighestConsecutiveHeight() - if err != nil { - return fmt.Errorf("could not get highest consecutive height: %w", err) - } - - builder.execDataBroadcaster = engine.NewBroadcaster() - - builder.stateStreamBackend, err = cstate_stream.New( - node.Logger, - builder.stateStreamConf, - node.State, - node.Storage.Headers, - node.Storage.Seals, - node.Storage.Results, - builder.ExecutionDataStore, - builder.executionDataCache, - builder.execDataBroadcaster, - builder.executionDataConfig.InitialBlockHeight, - highestAvailableHeight, - ) - if err != nil { - return fmt.Errorf("could not create state stream backend: %w", err) - } - return nil - }). Component("RPC engine", func(node *cmd.NodeConfig) (module.ReadyDoneAware, error) { config := builder.rpcConf backendConfig := config.BackendConfig @@ -1205,10 +1202,6 @@ func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { }) } - if builder.executionDataSyncEnabled { - builder.BuildExecutionDataRequester() - } - builder.Component("secure grpc server", func(node *cmd.NodeConfig) (module.ReadyDoneAware, error) { return builder.secureGrpcServer, nil }) diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index be361789472..dd49ed99d85 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -945,6 +945,7 @@ func (builder *ObserverServiceBuilder) enqueueRPCServer() { observerCollector, node.RootChainID.Chain()) if err != nil { + return nil, err } diff --git a/engine/access/rest/request/subscribe_events.go b/engine/access/rest/request/subscribe_events.go index efe6517c102..ac73a1842df 100644 --- a/engine/access/rest/request/subscribe_events.go +++ b/engine/access/rest/request/subscribe_events.go @@ -30,24 +30,29 @@ func (g *SubscribeEvents) Build(r *Request) error { ) } -func (g *SubscribeEvents) Parse(rawBlockID string, rawStart string, rawTypes []string, rawAddresses []string, rawContracts []string) error { +func (g *SubscribeEvents) Parse(rawStartBlockID string, rawStartHeight string, rawTypes []string, rawAddresses []string, rawContracts []string) error { + var startBlockID ID + err := startBlockID.Parse(rawStartBlockID) + if err != nil { + return err + } + g.StartBlockID = startBlockID.Flow() + var height Height - err := height.Parse(rawStart) + err = height.Parse(rawStartHeight) if err != nil { return fmt.Errorf("invalid start height: %w", err) } g.StartHeight = height.Flow() - var startBlockID ID - err = startBlockID.Parse(rawBlockID) - if err != nil { - return err + // if both start_block_id and start_height are provided + if g.StartBlockID != flow.ZeroID && g.StartHeight != EmptyHeight { + return fmt.Errorf("can only provide either block ID or start height") } - g.StartBlockID = startBlockID.Flow() - // if both height and one or both of start and end height are provided - if len(startBlockID) > 0 && g.StartHeight != EmptyHeight { - return fmt.Errorf("can only provide either block ID or start height range") + // default to root block + if g.StartHeight == EmptyHeight { + g.StartHeight = 0 } var eventTypes EventTypes diff --git a/engine/access/rest/routes/http_handler.go b/engine/access/rest/routes/http_handler.go index 47af5c5e1ca..f6a190ba0ad 100644 --- a/engine/access/rest/routes/http_handler.go +++ b/engine/access/rest/routes/http_handler.go @@ -85,6 +85,11 @@ func (h *HttpHandler) errorHandler(w http.ResponseWriter, err error, errorLogger h.errorResponse(w, http.StatusBadRequest, msg, errorLogger) return } + if se.Code() == codes.Unavailable { + msg := fmt.Sprintf("Failed to process request: %s", se.Message()) + h.errorResponse(w, http.StatusServiceUnavailable, msg, errorLogger) + return + } } // stop going further - catch all error diff --git a/engine/access/rest/routes/router.go b/engine/access/rest/routes/router.go index 9f5a4eb9f36..82eb54c3a31 100644 --- a/engine/access/rest/routes/router.go +++ b/engine/access/rest/routes/router.go @@ -44,13 +44,16 @@ func NewRouter(backend access.API, Handler(h) } - for _, r := range WSRoutes { - h := NewWSHandler(logger, r.Handler, chain, stateStreamApi, eventFilterConfig, maxGlobalStreams) - v1SubRouter. - Methods(r.Method). - Path(r.Pattern). - Name(r.Name). - Handler(h) + // Note: we add subscribe routes only if stateStreamApi is available + if stateStreamApi != nil { + for _, r := range WSRoutes { + h := NewWSHandler(logger, r.Handler, chain, stateStreamApi, eventFilterConfig, maxGlobalStreams) + v1SubRouter. + Methods(r.Method). + Path(r.Pattern). + Name(r.Name). + Handler(h) + } } return router, nil @@ -143,7 +146,7 @@ var Routes = []route{{ }} var WSRoutes = []wsroute{{ - Method: http.MethodPost, + Method: http.MethodGet, Pattern: "/subscribe_events", Name: "subscribeEvents", Handler: SubscribeEvents, diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 25face152be..92a2ecdbacc 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -1,85 +1,151 @@ package routes import ( + "context" "fmt" "net/http" + "sync/atomic" + "time" "github.com/gorilla/websocket" - - executiondata "github.com/onflow/flow/protobuf/go/flow/executiondata" + "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/request" - "github.com/onflow/flow-go/engine/common/rpc/convert" "github.com/onflow/flow-go/engine/common/state_stream" ) -func SubscribeEvents(r *request.Request, w http.ResponseWriter, h *state_stream.SubscribeHandler) (interface{}, error) { +const ( + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 +) + +func SubscribeEvents(r *request.Request, + w http.ResponseWriter, + logger zerolog.Logger, + api state_stream.API, + eventFilterConfig state_stream.EventFilterConfig, + maxStreams int32, + streamCount *atomic.Int32, + errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), + jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger)) { req, err := r.SubscribeEventsRequest() if err != nil { - return nil, models.NewBadRequestError(err) + errorHandler(w, models.NewBadRequestError(err), logger) + return + } + + logger = logger.With().Str("subscribe events", r.URL.String()).Logger() + if streamCount.Load() >= maxStreams { + err := fmt.Errorf("maximum number of streams reached") + errorHandler(w, models.NewRestError(http.StatusServiceUnavailable, "maximum number of streams reached", err), logger) + return } // Upgrade the HTTP connection to a WebSocket connection upgrader := websocket.Upgrader{} conn, err := upgrader.Upgrade(w, r.Request, nil) if err != nil { - err = fmt.Errorf("webSocket upgrade error: %s", err) - return nil, err + errorHandler(w, models.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) + return } - defer conn.Close() - var filter state_stream.EventFilter // Retrieve the filter parameters from the request, if provided - - filter, err = state_stream.NewEventFilter( - h.EventFilterConfig, + filter, err := state_stream.NewEventFilter( + eventFilterConfig, r.Chain, req.EventTypes, req.Addresses, req.Contracts, ) if err != nil { - err = fmt.Errorf("invalid event filter: %s", err) - return nil, err + errorHandler(w, models.NewRestError(http.StatusInternalServerError, "create event filter error: ", err), logger) + return } - sub, err := h.SubscribeEvents(r.Context(), req.StartBlockID, req.StartHeight, filter) + streamCount.Add(1) // Write messages to the WebSocket connection - writeToWebSocket := func(resp *state_stream.EventsResponse) error { - // Prepare the response message - response := &executiondata.SubscribeEventsResponse{ - BlockHeight: resp.Height, - BlockId: convert.IdentifierToMessage(resp.BlockID), - Events: convert.EventsToMessages(resp.Events), - } + go writeEvents(logger, w, req, r.Context(), conn, api, filter, errorHandler, streamCount) + time.Sleep(1 * time.Second) // wait for creating child context in goroutine + jsonResponse(w, http.StatusOK, "{}", logger) +} + +func writeEvents( + log zerolog.Logger, + w http.ResponseWriter, + req request.SubscribeEvents, + c context.Context, + conn *websocket.Conn, + api state_stream.API, + filter state_stream.EventFilter, + errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), + streamCount *atomic.Int32, +) { + ticker := time.NewTicker(pingPeriod) + ctx, cancel := context.WithCancel(c) - // Send the response message over the WebSocket connection - return conn.WriteJSON(response) + sub := api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) + defer func() { + ticker.Stop() + streamCount.Add(-1) + conn.Close() + cancel() + }() + err := conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message + if err != nil { + errorHandler(w, models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err), log) + return } + conn.SetPongHandler(func(string) error { + err = conn.SetReadDeadline(time.Now().Add(pongWait)) // Reset the read deadline upon receiving a pong message + if err != nil { + errorHandler(w, models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err), log) + conn.Close() + return err + } + return nil + }) for { - v, ok := <-sub.Channel() - if !ok { - if sub.Err() != nil { - err = fmt.Errorf("stream encountered an error: %w", sub.Err()) - return nil, err + select { + case v, ok := <-sub.Channel(): + if !ok { + if sub.Err() != nil { + err := fmt.Errorf("stream encountered an error: %w", sub.Err()) + errorHandler(w, models.NewBadRequestError(err), log) + conn.Close() + return + } + err := fmt.Errorf("subscription channel closed, no error occurred") + errorHandler(w, err, log) + conn.Close() + return } - return nil, err - } - resp, ok := v.(*state_stream.EventsResponse) - if !ok { - err = fmt.Errorf("unexpected response type: %s", v) - return nil, err - } + resp, ok := v.(*state_stream.EventsResponse) + if !ok { + err := fmt.Errorf("unexpected response type: %T", v) + errorHandler(w, err, log) + conn.Close() + return + } - // Write the response to the WebSocket connection - err := writeToWebSocket(resp) - if err != nil { - err = fmt.Errorf("failed to send response: %w", err) - return nil, err + // Write the response to the WebSocket connection + err := conn.WriteJSON(resp) + if err != nil { + errorHandler(w, err, log) + conn.Close() + return + } + case <-ticker.C: + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + conn.Close() + return + } } } } diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index 2914780bc2f..81d0effe069 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -1,20 +1,27 @@ package routes import ( + "bufio" "bytes" "fmt" - "net/http" - "net/http/httptest" - "testing" + "io" + "strings" + "time" + "github.com/gorilla/mux" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + //"io" + "net" + "net/http" + "net/http/httptest" + "testing" + "github.com/onflow/flow-go/access/mock" mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" common_state_stream "github.com/onflow/flow-go/engine/common/state_stream" - "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/metrics" ) @@ -27,9 +34,57 @@ const ( startHeightQueryParam = "start_height" endHeightQueryParam = "end_height" heightQueryParam = "height" + startBlockIdQueryParam = "start_block_id" + eventTypesQueryParams = "event_types" + addressesQueryParams = "addresses" + contractsQueryParams = "contracts" ) -func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_state_stream.API) (*httptest.ResponseRecorder, error) { +type fakeNetConn struct { + io.Reader + io.Writer +} + +func (c fakeNetConn) Close() error { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } +func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } +func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type fakeAddr int + +var ( + localAddr = fakeAddr(1) + remoteAddr = fakeAddr(2) +) + +func (a fakeAddr) Network() string { + return "net" +} + +func (a fakeAddr) String() string { + return "str" +} + +type HijackResponseRecorder struct { + *httptest.ResponseRecorder + brw *bufio.ReadWriter +} + +func (w *HijackResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, w.brw, nil +} + +func NewHijackResponseRecorder(brw *bufio.ReadWriter) *HijackResponseRecorder { + responseRecorder := &HijackResponseRecorder{ + brw: brw, + } + responseRecorder.ResponseRecorder = httptest.NewRecorder() + return responseRecorder +} + +func newRouter(backend *mock.API, stateStreamApi *mock_state_stream.API) (*mux.Router, error) { var b bytes.Buffer logger := zerolog.New(&b) restCollector := metrics.NewNoopCollector() @@ -39,20 +94,27 @@ func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_s MaxGlobalStreams: common_state_stream.DefaultMaxGlobalStreams, } - router, err := NewRouter(backend, + return NewRouter(backend, logger, flow.Testnet.Chain(), restCollector, stateStreamApi, stateStreamConfig.EventFilterConfig, stateStreamConfig.MaxGlobalStreams) +} + +func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_state_stream.API) (*httptest.ResponseRecorder, error) { + router, err := newRouter(backend, stateStreamApi) if err != nil { return nil, err } - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) - return rr, nil + br := bufio.NewReaderSize(strings.NewReader(""), common_state_stream.DefaultSendBufferSize) + bw := bufio.NewWriterSize(&bytes.Buffer{}, common_state_stream.DefaultSendBufferSize) + resp := NewHijackResponseRecorder(bufio.NewReadWriter(br, bw)) + + router.ServeHTTP(resp, req) + return resp.ResponseRecorder, nil } func assertOKResponse(t *testing.T, req *http.Request, expectedRespBody string, backend *mock.API, stateStreamApi *mock_state_stream.API) { diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 8b148d50f4f..095a281f68d 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -2,11 +2,11 @@ package routes import ( "net/http" + "sync/atomic" "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/request" - "github.com/onflow/flow-go/engine/access/rest/util" "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" ) @@ -16,15 +16,26 @@ import ( type SubscribeHandlerFunc func( r *request.Request, w http.ResponseWriter, - h *state_stream.SubscribeHandler, -) (interface{}, error) -// WSHandler is websocket handler implementing custom handler function and allows easier handling of errors and + logger zerolog.Logger, + api state_stream.API, + eventFilterConfig state_stream.EventFilterConfig, + maxStreams int32, + streamCount *atomic.Int32, + errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), + jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger), +) + +// WSHandler is websocket handler implementing custom websocket handler function and allows easier handling of errors and // responses as it wraps functionality for handling error and responses outside of endpoint handling. type WSHandler struct { *HttpHandler - *state_stream.SubscribeHandler subscribeFunc SubscribeHandlerFunc + + api state_stream.API + eventFilterConfig state_stream.EventFilterConfig + maxStreams int32 + streamCount atomic.Int32 } func NewWSHandler( @@ -36,10 +47,14 @@ func NewWSHandler( maxGlobalStreams uint32, ) *WSHandler { handler := &WSHandler{ - subscribeFunc: subscribeFunc, + subscribeFunc: subscribeFunc, + api: api, + eventFilterConfig: eventFilterConfig, + maxStreams: int32(maxGlobalStreams), + streamCount: atomic.Int32{}, } handler.HttpHandler = NewHttpHandler(logger, chain) - handler.SubscribeHandler = state_stream.NewSubscribeHandler(api, chain, eventFilterConfig, maxGlobalStreams) + return handler } @@ -55,19 +70,14 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } decoratedRequest := request.Decorate(r, h.HttpHandler.Chain) - response, err := h.subscribeFunc(decoratedRequest, w, h.SubscribeHandler) - if err != nil { - h.errorHandler(w, err, errLog) - return - } - - // apply the select filter if any select fields have been specified - response, err = util.SelectFilter(response, decoratedRequest.Selects()) - if err != nil { - h.errorHandler(w, err, errLog) - return - } + h.subscribeFunc(decoratedRequest, + w, + errLog, + h.api, + h.eventFilterConfig, + h.maxStreams, + &h.streamCount, + h.errorHandler, + h.jsonResponse) - // write response to response stream - h.jsonResponse(w, http.StatusOK, response, errLog) } diff --git a/engine/access/rpc/engine.go b/engine/access/rpc/engine.go index a4c706de66c..e64c8780e30 100644 --- a/engine/access/rpc/engine.go +++ b/engine/access/rpc/engine.go @@ -223,7 +223,14 @@ func (e *Engine) serveREST(ctx irrecoverable.SignalerContext, ready component.Re e.log.Info().Str("rest_api_address", e.config.RESTListenAddr).Msg("starting REST server on address") - r, err := rest.NewServer(e.restHandler, e.config.RESTListenAddr, e.log, e.chain, e.restCollector, e.stateStreamBackend, e.eventFilterConfig, e.maxGlobalStreams) + r, err := rest.NewServer(e.restHandler, + e.config.RESTListenAddr, + e.log, + e.chain, + e.restCollector, + e.stateStreamBackend, + e.eventFilterConfig, + e.maxGlobalStreams) if err != nil { e.log.Err(err).Msg("failed to initialize the REST server") ctx.Throw(err) diff --git a/engine/access/state_stream/handler.go b/engine/access/state_stream/handler.go index b9a5b8b55ca..446f71adb32 100644 --- a/engine/access/state_stream/handler.go +++ b/engine/access/state_stream/handler.go @@ -2,6 +2,7 @@ package state_stream import ( "context" + "sync/atomic" access "github.com/onflow/flow/protobuf/go/flow/executiondata" executiondata "github.com/onflow/flow/protobuf/go/flow/executiondata" @@ -16,12 +17,23 @@ import ( ) type Handler struct { - *state_stream.SubscribeHandler + api state_stream.API + chain flow.Chain + + eventFilterConfig state_stream.EventFilterConfig + + maxStreams int32 + streamCount atomic.Int32 } -func NewHandler(api state_stream.API, chain flow.Chain, config state_stream.EventFilterConfig, maxGlobalStreams uint32) *Handler { - h := &Handler{} - h.SubscribeHandler = state_stream.NewSubscribeHandler(api, chain, config, maxGlobalStreams) +func NewHandler(api state_stream.API, chain flow.Chain, conf state_stream.EventFilterConfig, maxGlobalStreams uint32) *Handler { + h := &Handler{ + api: api, + chain: chain, + eventFilterConfig: conf, + maxStreams: int32(maxGlobalStreams), + streamCount: atomic.Int32{}, + } return h } @@ -31,7 +43,7 @@ func (h *Handler) GetExecutionDataByBlockID(ctx context.Context, request *access return nil, status.Errorf(codes.InvalidArgument, "could not convert block ID: %v", err) } - execData, err := h.Api.GetExecutionDataByBlockID(ctx, blockID) + execData, err := h.api.GetExecutionDataByBlockID(ctx, blockID) if err != nil { return nil, rpc.ConvertError(err, "could no get execution data", codes.Internal) } @@ -46,11 +58,11 @@ func (h *Handler) GetExecutionDataByBlockID(ctx context.Context, request *access func (h *Handler) SubscribeExecutionData(request *access.SubscribeExecutionDataRequest, stream access.ExecutionDataAPI_SubscribeExecutionDataServer) error { // check if the maximum number of streams is reached - if h.StreamCount.Load() >= h.MaxStreams { + if h.streamCount.Load() >= h.maxStreams { return status.Errorf(codes.ResourceExhausted, "maximum number of streams reached") } - h.StreamCount.Add(1) - defer h.StreamCount.Add(-1) + h.streamCount.Add(1) + defer h.streamCount.Add(-1) startBlockID := flow.ZeroID if request.GetStartBlockId() != nil { @@ -61,7 +73,7 @@ func (h *Handler) SubscribeExecutionData(request *access.SubscribeExecutionDataR startBlockID = blockID } - sub := h.Api.SubscribeExecutionData(stream.Context(), startBlockID, request.GetStartBlockHeight()) + sub := h.api.SubscribeExecutionData(stream.Context(), startBlockID, request.GetStartBlockHeight()) for { v, ok := <-sub.Channel() @@ -93,6 +105,13 @@ func (h *Handler) SubscribeExecutionData(request *access.SubscribeExecutionDataR } func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream access.ExecutionDataAPI_SubscribeEventsServer) error { + // check if the maximum number of streams is reached + if h.streamCount.Load() >= h.maxStreams { + return status.Errorf(codes.ResourceExhausted, "maximum number of streams reached") + } + h.streamCount.Add(1) + defer h.streamCount.Add(-1) + startBlockID := flow.ZeroID if request.GetStartBlockId() != nil { blockID, err := convert.BlockID(request.GetStartBlockId()) @@ -107,8 +126,8 @@ func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream var err error reqFilter := request.GetFilter() filter, err = state_stream.NewEventFilter( - h.EventFilterConfig, - h.Chain, + h.eventFilterConfig, + h.chain, reqFilter.GetEventType(), reqFilter.GetAddress(), reqFilter.GetContract(), @@ -117,11 +136,7 @@ func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream return status.Errorf(codes.InvalidArgument, "invalid event filter: %v", err) } } - - sub, err := h.SubscribeHandler.SubscribeEvents(stream.Context(), startBlockID, request.GetStartBlockHeight(), filter) - if err != nil { - return err - } + sub := h.api.SubscribeEvents(stream.Context(), startBlockID, request.GetStartBlockHeight(), filter) for { v, ok := <-sub.Channel() diff --git a/engine/common/state_stream/backend.go b/engine/common/state_stream/backend.go index 2eb1f33d107..0593472d162 100644 --- a/engine/common/state_stream/backend.go +++ b/engine/common/state_stream/backend.go @@ -187,7 +187,10 @@ func (b *StateStreamBackend) getStartHeight(startBlockID flow.Identifier, startH // if the start block is the root block, there will not be an execution data. skip it and // begin from the next block. // Note: we can skip the block lookup since it was already done in the constructor - if startBlockID == b.rootBlockID || startHeight == b.rootBlockHeight { + if startBlockID == b.rootBlockID || + // Note: when startBlockID is provided and startHeight no needed then startHeight should be 0, otherwise, an + // InvalidArgument error is returned above, so we need also check if startBlockID provided before skip it. + (startHeight == b.rootBlockHeight && startBlockID == flow.ZeroID) { return b.rootBlockHeight + 1, nil } diff --git a/engine/common/state_stream/subscribe_handler.go b/engine/common/state_stream/subscribe_handler.go deleted file mode 100644 index 621e8b9555d..00000000000 --- a/engine/common/state_stream/subscribe_handler.go +++ /dev/null @@ -1,43 +0,0 @@ -package state_stream - -import ( - "context" - "sync/atomic" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/onflow/flow-go/model/flow" -) - -type SubscribeHandler struct { - Api API - Chain flow.Chain - - EventFilterConfig EventFilterConfig - - MaxStreams int32 - StreamCount atomic.Int32 -} - -func NewSubscribeHandler(api API, chain flow.Chain, config EventFilterConfig, maxGlobalStreams uint32) *SubscribeHandler { - h := &SubscribeHandler{ - Api: api, - Chain: chain, - EventFilterConfig: config, - MaxStreams: int32(maxGlobalStreams), - StreamCount: atomic.Int32{}, - } - return h -} - -func (h *SubscribeHandler) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier, startBlockHeight uint64, filter EventFilter) (Subscription, error) { - // check if the maximum number of streams is reached - if h.StreamCount.Load() >= h.MaxStreams { - return nil, status.Errorf(codes.ResourceExhausted, "maximum number of streams reached") - } - h.StreamCount.Add(1) - defer h.StreamCount.Add(-1) - - return h.Api.SubscribeEvents(ctx, startBlockID, startBlockHeight, filter), nil -} diff --git a/integration/tests/access/observer_test.go b/integration/tests/access/observer_test.go index 25bfeab2f3a..771f52029b5 100644 --- a/integration/tests/access/observer_test.go +++ b/integration/tests/access/observer_test.go @@ -227,7 +227,7 @@ func (s *ObserverSuite) TestObserverRest() { require.NoError(t, err) t.Run("HandledByUpstream", func(t *testing.T) { - // verify that we receive StatusInternalServerError, StatusServiceUnavailable errors from all rests handled upstream + // verify that we receive StatusServiceUnavailable errors from all rests handled upstream for _, endpoint := range s.getRestEndpoints() { if _, local := s.localRest[endpoint.name]; local { continue From 88da0b1c9c41b0844229a1a0bc8ea6bbd0a9f6d6 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 2 Aug 2023 10:46:34 +0300 Subject: [PATCH 10/35] Added unit tests --- engine/access/rest/routes/subscribe_events.go | 24 +- .../rest/routes/subscribe_events_test.go | 278 ++++++++++++++++++ engine/access/rest/routes/test_helpers.go | 10 +- integration/go.mod | 2 +- 4 files changed, 301 insertions(+), 13 deletions(-) create mode 100644 engine/access/rest/routes/subscribe_events_test.go diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 92a2ecdbacc..8378df22094 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -23,6 +23,7 @@ const ( pingPeriod = (pongWait * 9) / 10 ) +// SubscribeEvents create websocket connection and write to it requested events. func SubscribeEvents(r *request.Request, w http.ResponseWriter, logger zerolog.Logger, @@ -62,16 +63,25 @@ func SubscribeEvents(r *request.Request, req.Contracts, ) if err != nil { - errorHandler(w, models.NewRestError(http.StatusInternalServerError, "create event filter error: ", err), logger) + err := fmt.Errorf("event filter error") + errorHandler(w, models.NewBadRequestError(err), logger) return } streamCount.Add(1) // Write messages to the WebSocket connection - go writeEvents(logger, w, req, r.Context(), conn, api, filter, errorHandler, streamCount) - time.Sleep(1 * time.Second) // wait for creating child context in goroutine - jsonResponse(w, http.StatusOK, "{}", logger) + go writeEvents(logger, + w, + req, + r.Context(), + conn, + api, + filter, + errorHandler, + jsonResponse, + streamCount) + time.Sleep(2 * time.Second) // wait for creating child context in goroutine } func writeEvents( @@ -83,6 +93,7 @@ func writeEvents( api state_stream.API, filter state_stream.EventFilter, errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), + jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger), streamCount *atomic.Int32, ) { ticker := time.NewTicker(pingPeriod) @@ -115,13 +126,13 @@ func writeEvents( case v, ok := <-sub.Channel(): if !ok { if sub.Err() != nil { - err := fmt.Errorf("stream encountered an error: %w", sub.Err()) + err := fmt.Errorf("stream encountered an error: %v", sub.Err()) errorHandler(w, models.NewBadRequestError(err), log) conn.Close() return } err := fmt.Errorf("subscription channel closed, no error occurred") - errorHandler(w, err, log) + errorHandler(w, models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err), log) conn.Close() return } @@ -141,6 +152,7 @@ func writeEvents( conn.Close() return } + jsonResponse(w, http.StatusOK, "", log) case <-ticker.C: if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { conn.Close() diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go new file mode 100644 index 00000000000..0d96cb56e43 --- /dev/null +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -0,0 +1,278 @@ +package routes + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "testing" + + "golang.org/x/exp/slices" + + "github.com/stretchr/testify/assert" + mocks "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/onflow/flow-go/access/mock" + "github.com/onflow/flow-go/engine/access/rest/request" + mockstatestream "github.com/onflow/flow-go/engine/access/state_stream/mock" + "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" +) + +type testType struct { + name string + startBlockID flow.Identifier + startHeight uint64 + + eventTypes []string + addresses []string + contracts []string +} + +var testEventTypes = []flow.EventType{ + "A.0123456789abcdef.flow.event", + "B.0123456789abcdef.flow.event", + "C.0123456789abcdef.flow.event", +} + +type SubscribeEventsSuite struct { + suite.Suite + + blocks []*flow.Block + blockEvents map[flow.Identifier]flow.EventsList +} + +func TestSubscribeEventsSuite(t *testing.T) { + suite.Run(t, new(SubscribeEventsSuite)) +} + +func (s *SubscribeEventsSuite) SetupTest() { + rootBlock := unittest.BlockFixture() + parent := rootBlock.Header + + blockCount := 5 + + s.blocks = make([]*flow.Block, 0, blockCount) + s.blockEvents = make(map[flow.Identifier]flow.EventsList, blockCount) + + for i := 0; i < blockCount; i++ { + block := unittest.BlockWithParentFixture(parent) + // update for next iteration + parent = block.Header + + result := unittest.ExecutionResultFixture() + blockEvents := unittest.BlockEventsFixture(block.Header, (i%len(testEventTypes))*3+1, testEventTypes...) + + s.blocks = append(s.blocks, block) + s.blockEvents[block.ID()] = blockEvents.Events + + s.T().Logf("adding exec data for block %d %d %v => %v", i, block.Header.Height, block.ID(), result.ExecutionDataID) + } +} + +func (s *SubscribeEventsSuite) TestSubscribeEvents() { + testVectors := []testType{ + { + name: "happy path - all events from root height", + startBlockID: flow.ZeroID, + startHeight: request.EmptyHeight, + }, + { + name: "happy path - all events from startHeight", + startBlockID: flow.ZeroID, + startHeight: s.blocks[0].Header.Height, + }, + { + name: "happy path - all events from startBlockID", + startBlockID: s.blocks[0].ID(), + startHeight: request.EmptyHeight, + }, + } + chain := flow.MonotonicEmulator.Chain() + + // create variations for each of the base test + tests := make([]testType, 0, len(testVectors)*2) + for _, test := range testVectors { + t1 := test + t1.name = fmt.Sprintf("%s - all events", test.name) + tests = append(tests, t1) + + t2 := test + t2.name = fmt.Sprintf("%s - some events", test.name) + t2.eventTypes = []string{string(testEventTypes[0])} + tests = append(tests, t2) + } + + for _, test := range tests { + s.Run(test.name, func() { + stateStreamBackend := &mockstatestream.API{} + backend := &mock.API{} + + subscription := &mockstatestream.Subscription{} + + expectedEvents := flow.EventsList{} + for _, event := range s.blockEvents[s.blocks[0].ID()] { + if slices.Contains(test.eventTypes, string(event.Type)) { + expectedEvents = append(expectedEvents, event) + } + } + + // Create a channel to receive mock EventsResponse objects + ch := make(chan interface{}) + var chReadOnly <-chan interface{} + expectedEventsResponses := []*state_stream.EventsResponse{} + + for i, b := range s.blocks { + s.T().Logf("checking block %d %v", i, b.ID()) + + //simulate EventsResponse + eventResponse := &state_stream.EventsResponse{ + Height: b.Header.Height, + BlockID: b.ID(), + Events: expectedEvents, + } + expectedEventsResponses = append(expectedEventsResponses, eventResponse) + } + + // Simulate sending a mock EventsResponse + go func() { + for _, eventResponse := range expectedEventsResponses { + // Send the mock EventsResponse through the channel + ch <- eventResponse + } + }() + + chReadOnly = ch + + subscription.Mock.On("Channel").Return(chReadOnly) + subscription.Mock.On("Err").Return(nil) + + filter, err := state_stream.NewEventFilter(state_stream.DefaultEventFilterConfig, chain, test.eventTypes, test.addresses, test.contracts) + assert.NoError(s.T(), err) + var startHeight uint64 + if test.startHeight == request.EmptyHeight { + startHeight = 0 + } else { + startHeight = test.startHeight + } + stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, test.startBlockID, startHeight, filter).Return(subscription) + + req := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts) + rr, err := executeRequest(req, backend, stateStreamBackend) + assert.NoError(s.T(), err) + assert.Equal(s.T(), http.StatusOK, rr.Code) + }) + } +} + +func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { + s.Run("returns error for block id and height", func() { + stateStreamBackend := &mockstatestream.API{} + backend := &mock.API{} + + req := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil) + assertResponse(s.T(), req, http.StatusBadRequest, `{"code":400,"message":"can only provide either block ID or start height"}`, backend, stateStreamBackend) + }) + + s.Run("returns error for invalid block id", func() { + stateStreamBackend := &mockstatestream.API{} + backend := &mock.API{} + + invalidBlock := unittest.BlockFixture() + subscription := &mockstatestream.Subscription{} + + ch := make(chan interface{}) + var chReadOnly <-chan interface{} + go func() { + close(ch) + }() + chReadOnly = ch + + subscription.Mock.On("Channel").Return(chReadOnly) + subscription.Mock.On("Err").Return(fmt.Errorf("subscription error")) + stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), mocks.Anything, mocks.Anything).Return(subscription) + + req := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) + assertResponse(s.T(), req, http.StatusBadRequest, `{"code":400,"message":"stream encountered an error: subscription error"}`, backend, stateStreamBackend) + }) + + s.Run("returns error when channel closed", func() { + stateStreamBackend := &mockstatestream.API{} + backend := &mock.API{} + subscription := &mockstatestream.Subscription{} + + ch := make(chan interface{}) + var chReadOnly <-chan interface{} + + go func() { + close(ch) + }() + chReadOnly = ch + + subscription.Mock.On("Channel").Return(chReadOnly) + subscription.Mock.On("Err").Return(nil) + stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), mocks.Anything, mocks.Anything).Return(subscription) + + req := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) + assertResponse(s.T(), req, http.StatusRequestTimeout, `{"code":408,"message":"subscription channel closed"}`, backend, stateStreamBackend) + }) + + s.Run("returns error for unexpected response type", func() { + stateStreamBackend := &mockstatestream.API{} + backend := &mock.API{} + subscription := &mockstatestream.Subscription{} + + ch := make(chan interface{}) + var chReadOnly <-chan interface{} + go func() { + executionDataResponse := &state_stream.ExecutionDataResponse{ + Height: s.blocks[0].Header.Height, + } + ch <- executionDataResponse + }() + chReadOnly = ch + + subscription.Mock.On("Channel").Return(chReadOnly) + subscription.Mock.On("Err").Return(nil) + stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), 0, mocks.Anything).Return(subscription) + + req := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) + assertResponse(s.T(), req, http.StatusInternalServerError, `{"code":500,"message":"internal server error"}`, backend, stateStreamBackend) + }) +} + +func getSubscribeEventsRequest(t *testing.T, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) *http.Request { + u, _ := url.Parse("/v1/subscribe_events") + q := u.Query() + + if startBlockId != flow.ZeroID { + q.Add(startBlockIdQueryParam, startBlockId.String()) + } + + if startHeight != request.EmptyHeight { + q.Add(startHeightQueryParam, fmt.Sprintf("%d", startHeight)) + } + + if len(eventTypes) > 0 { + q.Add(eventTypesQueryParams, strings.Join(eventTypes, ",")) + } + if len(addresses) > 0 { + q.Add(addressesQueryParams, strings.Join(addresses, ",")) + } + if len(contracts) > 0 { + q.Add(contractsQueryParams, strings.Join(contracts, ",")) + } + + u.RawQuery = q.Encode() + + req, err := http.NewRequest("GET", u.String(), nil) + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-Websocket-Version", "13") + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") // Uliana: ask or read about it + require.NoError(t, err) + return req +} diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index 81d0effe069..ed9f8d973b2 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -5,7 +5,11 @@ import ( "bytes" "fmt" "io" + "net" + "net/http" + "net/http/httptest" "strings" + "testing" "time" "github.com/gorilla/mux" @@ -13,12 +17,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - //"io" - "net" - "net/http" - "net/http/httptest" - "testing" - "github.com/onflow/flow-go/access/mock" mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" common_state_stream "github.com/onflow/flow-go/engine/common/state_stream" diff --git a/integration/go.mod b/integration/go.mod index e47e8bac3c3..23b6b18f94e 100644 --- a/integration/go.mod +++ b/integration/go.mod @@ -12,6 +12,7 @@ require ( github.com/docker/go-connections v0.4.0 github.com/go-git/go-git/v5 v5.5.2 github.com/go-yaml/yaml v2.1.0+incompatible + github.com/gorilla/websocket v1.5.0 github.com/ipfs/go-blockservice v0.4.0 github.com/ipfs/go-cid v0.4.1 github.com/ipfs/go-datastore v0.6.0 @@ -138,7 +139,6 @@ require ( github.com/google/uuid v1.3.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.7.1 // indirect - github.com/gorilla/websocket v1.5.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2 v2.0.0-rc.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0-20200501113911-9a95f0fdbfea // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect From fbbc2405d60596c3687bdb411c99cab1edb0e745 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 2 Aug 2023 16:52:00 +0300 Subject: [PATCH 11/35] Added part of intagration tests --- engine/access/rest/routes/subscribe_events.go | 21 ++- .../rest/routes/subscribe_events_test.go | 39 ++++- engine/common/state_stream/backend_events.go | 5 + integration/go.mod | 1 + integration/go.sum | 1 + integration/tests/access/access_test.go | 146 ++++++++++++++++-- 6 files changed, 192 insertions(+), 21 deletions(-) diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 8378df22094..8c496c7dbf0 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -33,8 +33,11 @@ func SubscribeEvents(r *request.Request, streamCount *atomic.Int32, errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger)) { + fmt.Println("+++++SubscribeEvents") + logger.Info().Msg("+++++SubscribeEvents") req, err := r.SubscribeEventsRequest() if err != nil { + fmt.Println(fmt.Sprintf("SubscribeEventsRequest .Err(): %v", err)) errorHandler(w, models.NewBadRequestError(err), logger) return } @@ -50,10 +53,10 @@ func SubscribeEvents(r *request.Request, upgrader := websocket.Upgrader{} conn, err := upgrader.Upgrade(w, r.Request, nil) if err != nil { + fmt.Println(fmt.Sprintf("Upgrade.Err(): %v", err)) errorHandler(w, models.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) return } - // Retrieve the filter parameters from the request, if provided filter, err := state_stream.NewEventFilter( eventFilterConfig, @@ -97,7 +100,8 @@ func writeEvents( streamCount *atomic.Int32, ) { ticker := time.NewTicker(pingPeriod) - ctx, cancel := context.WithCancel(c) + ctx, cancel := context.WithCancel(context.Background()) + //ctx, cancel := context.WithCancel(c) sub := api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) defer func() { @@ -124,34 +128,45 @@ func writeEvents( for { select { case v, ok := <-sub.Channel(): + fmt.Println(fmt.Sprintf("____sub")) if !ok { if sub.Err() != nil { + fmt.Println(fmt.Sprintf("____sub.Err(): %v", sub.Err())) err := fmt.Errorf("stream encountered an error: %v", sub.Err()) - errorHandler(w, models.NewBadRequestError(err), log) + fmt.Println("stream encountered an error:") + errorHandler(w, models.NewRestError(http.StatusRequestTimeout, "bla bla", err), log) + //errorHandler(w, models.NewBadRequestError(err), log) conn.Close() return } err := fmt.Errorf("subscription channel closed, no error occurred") + fmt.Println("subscription channel closed, no error occurred") errorHandler(w, models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err), log) conn.Close() return } + fmt.Println("_____before resp, ok := v.(*state_stream.EventsResponse)") resp, ok := v.(*state_stream.EventsResponse) if !ok { + fmt.Println("____error: resp, ok := v.(*state_stream.EventsResponse)") err := fmt.Errorf("unexpected response type: %T", v) errorHandler(w, err, log) conn.Close() return } + fmt.Println(fmt.Sprintf("____response %v", resp)) // Write the response to the WebSocket connection err := conn.WriteJSON(resp) if err != nil { + fmt.Println("_____error, err := conn.WriteJSON(resp)") + fmt.Println(err) errorHandler(w, err, log) conn.Close() return } + fmt.Println("StatusOK") jsonResponse(w, http.StatusOK, "", log) case <-ticker.C: if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index 0d96cb56e43..b97693045ea 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -1,6 +1,8 @@ package routes import ( + "crypto/rand" + "encoding/base64" "fmt" "net/http" "net/url" @@ -160,7 +162,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { } stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, test.startBlockID, startHeight, filter).Return(subscription) - req := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts) + req, err := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts) + assert.NoError(s.T(), err) rr, err := executeRequest(req, backend, stateStreamBackend) assert.NoError(s.T(), err) assert.Equal(s.T(), http.StatusOK, rr.Code) @@ -173,7 +176,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { stateStreamBackend := &mockstatestream.API{} backend := &mock.API{} - req := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil) + req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil) + assert.NoError(s.T(), err) assertResponse(s.T(), req, http.StatusBadRequest, `{"code":400,"message":"can only provide either block ID or start height"}`, backend, stateStreamBackend) }) @@ -195,7 +199,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { subscription.Mock.On("Err").Return(fmt.Errorf("subscription error")) stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), mocks.Anything, mocks.Anything).Return(subscription) - req := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) + req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) + assert.NoError(s.T(), err) assertResponse(s.T(), req, http.StatusBadRequest, `{"code":400,"message":"stream encountered an error: subscription error"}`, backend, stateStreamBackend) }) @@ -216,7 +221,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { subscription.Mock.On("Err").Return(nil) stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), mocks.Anything, mocks.Anything).Return(subscription) - req := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) + req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) + assert.NoError(s.T(), err) assertResponse(s.T(), req, http.StatusRequestTimeout, `{"code":408,"message":"subscription channel closed"}`, backend, stateStreamBackend) }) @@ -239,12 +245,13 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { subscription.Mock.On("Err").Return(nil) stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), 0, mocks.Anything).Return(subscription) - req := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) + req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) + assert.NoError(s.T(), err) assertResponse(s.T(), req, http.StatusInternalServerError, `{"code":500,"message":"internal server error"}`, backend, stateStreamBackend) }) } -func getSubscribeEventsRequest(t *testing.T, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) *http.Request { +func getSubscribeEventsRequest(t *testing.T, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) (*http.Request, error) { u, _ := url.Parse("/v1/subscribe_events") q := u.Query() @@ -267,12 +274,28 @@ func getSubscribeEventsRequest(t *testing.T, startBlockId flow.Identifier, start } u.RawQuery = q.Encode() + key, err := generateWebSocketKey() + if err != nil { + err := fmt.Errorf("error generating websocket key: %v", err) + return nil, err + } req, err := http.NewRequest("GET", u.String(), nil) req.Header.Set("Connection", "upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-Websocket-Version", "13") - req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") // Uliana: ask or read about it + req.Header.Set("Sec-Websocket-Key", key) require.NoError(t, err) - return req + return req, nil +} + +func generateWebSocketKey() (string, error) { + // Generate 16 random bytes. + keyBytes := make([]byte, 16) + if _, err := rand.Read(keyBytes); err != nil { + return "", err + } + + // Encode the bytes to base64 and return the key as a string. + return base64.StdEncoding.EncodeToString(keyBytes), nil } diff --git a/engine/common/state_stream/backend_events.go b/engine/common/state_stream/backend_events.go index 2691ef5e7d0..3a81b0398f3 100644 --- a/engine/common/state_stream/backend_events.go +++ b/engine/common/state_stream/backend_events.go @@ -30,13 +30,18 @@ type EventsBackend struct { } func (b EventsBackend) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier, startHeight uint64, filter EventFilter) Subscription { + fmt.Println("_____SubscribeEvents_ start") nextHeight, err := b.getStartHeight(startBlockID, startHeight) if err != nil { + fmt.Println("_____SubscribeEvents: getStartHeight failed") return NewFailedSubscription(err, "could not get start height") } + fmt.Println("_____SubscribeEvents: getStartHeight success") sub := NewHeightBasedSubscription(b.sendBufferSize, nextHeight, b.getResponseFactory(filter)) + fmt.Println("_____SubscribeEvents: NewHeightBasedSubscription success") + go NewStreamer(b.log, b.broadcaster, b.sendTimeout, b.responseLimit, sub).Stream(ctx) return sub diff --git a/integration/go.mod b/integration/go.mod index 23b6b18f94e..de81e43a56a 100644 --- a/integration/go.mod +++ b/integration/go.mod @@ -139,6 +139,7 @@ require ( github.com/google/uuid v1.3.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.7.1 // indirect + github.com/gorilla/mux v1.8.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware/providers/zerolog/v2 v2.0.0-rc.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0-20200501113911-9a95f0fdbfea // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect diff --git a/integration/go.sum b/integration/go.sum index 4aac8d7305d..5baed62df2f 100644 --- a/integration/go.sum +++ b/integration/go.sum @@ -635,6 +635,7 @@ github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51 github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1-0.20190629185528-ae1634f6a989/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= diff --git a/integration/tests/access/access_test.go b/integration/tests/access/access_test.go index 82d268d9a65..68d57e45f68 100644 --- a/integration/tests/access/access_test.go +++ b/integration/tests/access/access_test.go @@ -2,13 +2,18 @@ package access import ( "context" - "net" - "testing" - "time" - + "fmt" + "github.com/gorilla/websocket" "github.com/onflow/flow-go/consensus/hotstuff/committees" "github.com/onflow/flow-go/consensus/hotstuff/signature" + "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/common/rpc/convert" + "github.com/onflow/flow-go/engine/common/state_stream" + "net" + "net/url" + "strings" + "testing" + "time" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -55,19 +60,19 @@ func (s *AccessSuite) SetupTest() { }() nodeConfigs := []testnet.NodeConfig{ - testnet.NewNodeConfig(flow.RoleAccess, testnet.WithLogLevel(zerolog.InfoLevel)), + testnet.NewNodeConfig(flow.RoleAccess, testnet.WithLogLevel(zerolog.InfoLevel), testnet.WithAdditionalFlag(fmt.Sprintf("--state-stream-addr=%s", testnet.ExecutionStatePort))), } // need one dummy execution node (unused ghost) - exeConfig := testnet.NewNodeConfig(flow.RoleExecution, testnet.WithLogLevel(zerolog.FatalLevel), testnet.AsGhost()) + exeConfig := testnet.NewNodeConfig(flow.RoleExecution, testnet.WithLogLevel(zerolog.FatalLevel)) nodeConfigs = append(nodeConfigs, exeConfig) // need one dummy verification node (unused ghost) - verConfig := testnet.NewNodeConfig(flow.RoleVerification, testnet.WithLogLevel(zerolog.FatalLevel), testnet.AsGhost()) + verConfig := testnet.NewNodeConfig(flow.RoleVerification, testnet.WithLogLevel(zerolog.FatalLevel)) nodeConfigs = append(nodeConfigs, verConfig) // need one controllable collection node (unused ghost) - collConfig := testnet.NewNodeConfig(flow.RoleCollection, testnet.WithLogLevel(zerolog.FatalLevel), testnet.AsGhost()) + collConfig := testnet.NewNodeConfig(flow.RoleCollection, testnet.WithLogLevel(zerolog.FatalLevel)) nodeConfigs = append(nodeConfigs, collConfig) // need three consensus nodes (unused ghost) @@ -75,8 +80,7 @@ func (s *AccessSuite) SetupTest() { conID := unittest.IdentifierFixture() nodeConfig := testnet.NewNodeConfig(flow.RoleConsensus, testnet.WithLogLevel(zerolog.FatalLevel), - testnet.WithID(conID), - testnet.AsGhost()) + testnet.WithID(conID)) nodeConfigs = append(nodeConfigs, nodeConfig) } @@ -189,6 +193,128 @@ func (s *AccessSuite) TestSignerIndicesDecoding() { assert.ElementsMatch(s.T(), transformed, msg.ParentVoterIds, "response must contain correctly encoded signer IDs") } +// TestRestSubscribeEvents tests event streaming on REST +func (s *AccessSuite) TestRestSubscribeEvents() { + time.Sleep(5 * time.Second) + t := s.T() + + ctx, cancel := context.WithTimeout(s.ctx, 30*time.Second) + defer cancel() + + accessAddr := s.net.ContainerByName(testnet.PrimaryAN).Addr(testnet.RESTPort) + + t.Run("subscribe events", func(t *testing.T) { + startBlockId := flow.ZeroID + startHeight := uint64(0) + url := getSubscribeEventsRequest(accessAddr, startBlockId, startHeight, nil, nil, nil) + + s.log.Info().Msg("================> resp.Request.URL.String()" + url) + client, err := s.getWSClient(ctx, url) + require.NoError(t, err) + var receivedEvents []*state_stream.EventsResponse + eventChan := make(chan *state_stream.EventsResponse) + + go func() { + for { + resp := &state_stream.EventsResponse{} + err := client.ReadJSON(resp) + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + t.Logf("unexpected close error: %v", err) + } + t.Log(fmt.Sprintf("______ReadJSON error %v", err)) + close(eventChan) // Close the event channel when the client connection is closed + return + } + t.Log(fmt.Sprintf("______response %v", resp)) + eventChan <- resp + } + }() + + // Wait for events or timeout + select { + case <-time.After(10 * time.Second): + // Handle the timeout, close the client connection, and break the loop + t.Log("______receivedEvents") + t.Log(receivedEvents) + + client.Close() + t.Log("Client closed") + return + case event := <-eventChan: + receivedEvents = append(receivedEvents, event) + t.Log(fmt.Sprintf("______received events %v", event)) + } + + }) +} + +func (s *AccessSuite) getWSClient(ctx context.Context, address string) (*websocket.Conn, error) { + // helper func to create WebSocket client + client, _, err := websocket.DefaultDialer.DialContext(ctx, strings.Replace(address, "http", "ws", 1), nil) + if err != nil { + return nil, err + } + return client, nil +} + +// +//// Assert that the received events match the expected events +//assert.Equal(s.T(), len(expectedEvents), len(receivedEvents)) +//for i, expected := range expectedEvents { +// received := receivedEvents[i] +// s.T().Logf("expected" + expected.String()) +// s.T().Logf("received: BlockID" + received.BlockID.String() + ", height: " + fmt.Sprint(received.Height)) +// //assert.Equal(s.T(), expected.Height, received.Height) +// //assert.Equal(s.T(), expected.BlockID, received.BlockID) +// //assert.Equal(s.T(), len(expected.Events), len(received.Events)) +// //for j, expectedEvent := range expected.Events { +// // receivedEvent := received.Events[j] +// // // Perform further assertions on each event if needed +// // assert.Equal(s.T(), expectedEvent.Type, receivedEvent.Type) +// // assert.Equal(s.T(), expectedEvent.Data, receivedEvent.Data) +// //} +//} + +//resp, err := makeSubscribeEventsCall(accessAddr, startBlockId, startHeight, nil, nil, nil) +//assert.NoError(t, err) +//assert.Contains(t, [...]int{ +// http.StatusOK, +//}, resp.StatusCode) +//s.log.Info().Msg(fmt.Sprintf("================> %s %d", resp.Status, resp.StatusCode)) + +func getSubscribeEventsRequest(accessAddr string, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) string { + u, _ := url.Parse("http://" + accessAddr + "/v1/subscribe_events") + q := u.Query() + + if startBlockId != flow.ZeroID { + q.Add("start_block_id", startBlockId.String()) + } + + if startHeight != request.EmptyHeight { + q.Add("start_height", fmt.Sprintf("%d", startHeight)) + } + + if len(eventTypes) > 0 { + q.Add("event_types", strings.Join(eventTypes, ",")) + } + if len(addresses) > 0 { + q.Add("addresses", strings.Join(addresses, ",")) + } + if len(contracts) > 0 { + q.Add("contracts", strings.Join(contracts, ",")) + } + + u.RawQuery = q.Encode() + return u.String() +} + +//func makeSubscribeEventsCall(accessAddr string, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) (*http.Response, error) { +// httpClient := http.DefaultClient +// url := getSubscribeEventsRequest(accessAddr, startBlockId, startHeight, eventTypes, addresses, contracts) +// return httpClient.Get(url) +//} + // makeApiRequest is a helper function that encapsulates context creation for grpc client call, used to avoid repeated creation // of new context for each call. func makeApiRequest[Func func(context.Context, *Req, ...grpc.CallOption) (*Resp, error), Req any, Resp any](apiCall Func, ctx context.Context, req *Req) (*Resp, error) { From d07b5faf117dac07e9ec5d982b04471667f799d1 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 3 Aug 2023 17:37:09 +0300 Subject: [PATCH 12/35] Added integration test for rest event streaming, updated unit tests, remove unnecessary logs --- engine/access/rest/routes/subscribe_events.go | 26 +- .../rest/routes/subscribe_events_test.go | 9 +- engine/common/state_stream/backend_events.go | 5 - integration/tests/access/access_test.go | 156 ++---------- .../tests/access/rest_state_stream_test.go | 239 ++++++++++++++++++ 5 files changed, 264 insertions(+), 171 deletions(-) create mode 100644 integration/tests/access/rest_state_stream_test.go diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 8c496c7dbf0..27a7b012b9d 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -33,11 +33,8 @@ func SubscribeEvents(r *request.Request, streamCount *atomic.Int32, errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger)) { - fmt.Println("+++++SubscribeEvents") - logger.Info().Msg("+++++SubscribeEvents") req, err := r.SubscribeEventsRequest() if err != nil { - fmt.Println(fmt.Sprintf("SubscribeEventsRequest .Err(): %v", err)) errorHandler(w, models.NewBadRequestError(err), logger) return } @@ -53,7 +50,6 @@ func SubscribeEvents(r *request.Request, upgrader := websocket.Upgrader{} conn, err := upgrader.Upgrade(w, r.Request, nil) if err != nil { - fmt.Println(fmt.Sprintf("Upgrade.Err(): %v", err)) errorHandler(w, models.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) return } @@ -77,8 +73,8 @@ func SubscribeEvents(r *request.Request, go writeEvents(logger, w, req, - r.Context(), conn, + r.Context(), api, filter, errorHandler, @@ -91,19 +87,18 @@ func writeEvents( log zerolog.Logger, w http.ResponseWriter, req request.SubscribeEvents, - c context.Context, conn *websocket.Conn, + c context.Context, api state_stream.API, filter state_stream.EventFilter, errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger), streamCount *atomic.Int32, ) { - ticker := time.NewTicker(pingPeriod) ctx, cancel := context.WithCancel(context.Background()) - //ctx, cancel := context.WithCancel(c) - + ticker := time.NewTicker(pingPeriod) sub := api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) + defer func() { ticker.Stop() streamCount.Add(-1) @@ -128,45 +123,34 @@ func writeEvents( for { select { case v, ok := <-sub.Channel(): - fmt.Println(fmt.Sprintf("____sub")) if !ok { if sub.Err() != nil { - fmt.Println(fmt.Sprintf("____sub.Err(): %v", sub.Err())) err := fmt.Errorf("stream encountered an error: %v", sub.Err()) - fmt.Println("stream encountered an error:") - errorHandler(w, models.NewRestError(http.StatusRequestTimeout, "bla bla", err), log) - //errorHandler(w, models.NewBadRequestError(err), log) + errorHandler(w, models.NewBadRequestError(err), log) conn.Close() return } err := fmt.Errorf("subscription channel closed, no error occurred") - fmt.Println("subscription channel closed, no error occurred") errorHandler(w, models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err), log) conn.Close() return } - fmt.Println("_____before resp, ok := v.(*state_stream.EventsResponse)") resp, ok := v.(*state_stream.EventsResponse) if !ok { - fmt.Println("____error: resp, ok := v.(*state_stream.EventsResponse)") err := fmt.Errorf("unexpected response type: %T", v) errorHandler(w, err, log) conn.Close() return } - fmt.Println(fmt.Sprintf("____response %v", resp)) // Write the response to the WebSocket connection err := conn.WriteJSON(resp) if err != nil { - fmt.Println("_____error, err := conn.WriteJSON(resp)") - fmt.Println(err) errorHandler(w, err, log) conn.Close() return } - fmt.Println("StatusOK") jsonResponse(w, http.StatusOK, "", log) case <-ticker.C: if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index b97693045ea..e9923ebca6f 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -154,9 +154,10 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { filter, err := state_stream.NewEventFilter(state_stream.DefaultEventFilterConfig, chain, test.eventTypes, test.addresses, test.contracts) assert.NoError(s.T(), err) + var startHeight uint64 if test.startHeight == request.EmptyHeight { - startHeight = 0 + startHeight = uint64(0) } else { startHeight = test.startHeight } @@ -197,7 +198,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { subscription.Mock.On("Channel").Return(chReadOnly) subscription.Mock.On("Err").Return(fmt.Errorf("subscription error")) - stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), mocks.Anything, mocks.Anything).Return(subscription) + stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), uint64(0), mocks.Anything).Return(subscription) req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) @@ -219,7 +220,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { subscription.Mock.On("Channel").Return(chReadOnly) subscription.Mock.On("Err").Return(nil) - stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), mocks.Anything, mocks.Anything).Return(subscription) + stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), uint64(0), mocks.Anything).Return(subscription) req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) @@ -243,7 +244,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { subscription.Mock.On("Channel").Return(chReadOnly) subscription.Mock.On("Err").Return(nil) - stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), 0, mocks.Anything).Return(subscription) + stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), uint64(0), mocks.Anything).Return(subscription) req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) diff --git a/engine/common/state_stream/backend_events.go b/engine/common/state_stream/backend_events.go index 3a81b0398f3..2691ef5e7d0 100644 --- a/engine/common/state_stream/backend_events.go +++ b/engine/common/state_stream/backend_events.go @@ -30,18 +30,13 @@ type EventsBackend struct { } func (b EventsBackend) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier, startHeight uint64, filter EventFilter) Subscription { - fmt.Println("_____SubscribeEvents_ start") nextHeight, err := b.getStartHeight(startBlockID, startHeight) if err != nil { - fmt.Println("_____SubscribeEvents: getStartHeight failed") return NewFailedSubscription(err, "could not get start height") } - fmt.Println("_____SubscribeEvents: getStartHeight success") sub := NewHeightBasedSubscription(b.sendBufferSize, nextHeight, b.getResponseFactory(filter)) - fmt.Println("_____SubscribeEvents: NewHeightBasedSubscription success") - go NewStreamer(b.log, b.broadcaster, b.sendTimeout, b.responseLimit, sub).Stream(ctx) return sub diff --git a/integration/tests/access/access_test.go b/integration/tests/access/access_test.go index 68d57e45f68..2bcb5645f13 100644 --- a/integration/tests/access/access_test.go +++ b/integration/tests/access/access_test.go @@ -2,19 +2,14 @@ package access import ( "context" - "fmt" - "github.com/gorilla/websocket" - "github.com/onflow/flow-go/consensus/hotstuff/committees" - "github.com/onflow/flow-go/consensus/hotstuff/signature" - "github.com/onflow/flow-go/engine/access/rest/request" - "github.com/onflow/flow-go/engine/common/rpc/convert" - "github.com/onflow/flow-go/engine/common/state_stream" "net" - "net/url" - "strings" "testing" "time" + "github.com/onflow/flow-go/consensus/hotstuff/committees" + "github.com/onflow/flow-go/consensus/hotstuff/signature" + "github.com/onflow/flow-go/engine/common/rpc/convert" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -60,19 +55,19 @@ func (s *AccessSuite) SetupTest() { }() nodeConfigs := []testnet.NodeConfig{ - testnet.NewNodeConfig(flow.RoleAccess, testnet.WithLogLevel(zerolog.InfoLevel), testnet.WithAdditionalFlag(fmt.Sprintf("--state-stream-addr=%s", testnet.ExecutionStatePort))), + testnet.NewNodeConfig(flow.RoleAccess, testnet.WithLogLevel(zerolog.InfoLevel)), } // need one dummy execution node (unused ghost) - exeConfig := testnet.NewNodeConfig(flow.RoleExecution, testnet.WithLogLevel(zerolog.FatalLevel)) + exeConfig := testnet.NewNodeConfig(flow.RoleExecution, testnet.WithLogLevel(zerolog.FatalLevel), testnet.AsGhost()) nodeConfigs = append(nodeConfigs, exeConfig) // need one dummy verification node (unused ghost) - verConfig := testnet.NewNodeConfig(flow.RoleVerification, testnet.WithLogLevel(zerolog.FatalLevel)) + verConfig := testnet.NewNodeConfig(flow.RoleVerification, testnet.WithLogLevel(zerolog.FatalLevel), testnet.AsGhost()) nodeConfigs = append(nodeConfigs, verConfig) // need one controllable collection node (unused ghost) - collConfig := testnet.NewNodeConfig(flow.RoleCollection, testnet.WithLogLevel(zerolog.FatalLevel)) + collConfig := testnet.NewNodeConfig(flow.RoleCollection, testnet.WithLogLevel(zerolog.FatalLevel), testnet.AsGhost()) nodeConfigs = append(nodeConfigs, collConfig) // need three consensus nodes (unused ghost) @@ -80,7 +75,8 @@ func (s *AccessSuite) SetupTest() { conID := unittest.IdentifierFixture() nodeConfig := testnet.NewNodeConfig(flow.RoleConsensus, testnet.WithLogLevel(zerolog.FatalLevel), - testnet.WithID(conID)) + testnet.WithID(conID), + testnet.AsGhost()) nodeConfigs = append(nodeConfigs, nodeConfig) } @@ -140,17 +136,17 @@ func (s *AccessSuite) TestSignerIndicesDecoding() { client := accessproto.NewAccessAPIClient(conn) // query latest finalized block - latestFinalizedBlock, err := makeApiRequest(client.GetLatestBlockHeader, ctx, &accessproto.GetLatestBlockHeaderRequest{ + latestFinalizedBlock, err := MakeApiRequest(client.GetLatestBlockHeader, ctx, &accessproto.GetLatestBlockHeaderRequest{ IsSealed: false, }) require.NoError(s.T(), err) - blockByID, err := makeApiRequest(client.GetBlockHeaderByID, ctx, &accessproto.GetBlockHeaderByIDRequest{Id: latestFinalizedBlock.Block.Id}) + blockByID, err := MakeApiRequest(client.GetBlockHeaderByID, ctx, &accessproto.GetBlockHeaderByIDRequest{Id: latestFinalizedBlock.Block.Id}) require.NoError(s.T(), err) require.Equal(s.T(), latestFinalizedBlock, blockByID, "expect to receive same block by ID") - blockByHeight, err := makeApiRequest(client.GetBlockHeaderByHeight, ctx, + blockByHeight, err := MakeApiRequest(client.GetBlockHeaderByHeight, ctx, &accessproto.GetBlockHeaderByHeightRequest{Height: latestFinalizedBlock.Block.Height}) require.NoError(s.T(), err) @@ -193,131 +189,9 @@ func (s *AccessSuite) TestSignerIndicesDecoding() { assert.ElementsMatch(s.T(), transformed, msg.ParentVoterIds, "response must contain correctly encoded signer IDs") } -// TestRestSubscribeEvents tests event streaming on REST -func (s *AccessSuite) TestRestSubscribeEvents() { - time.Sleep(5 * time.Second) - t := s.T() - - ctx, cancel := context.WithTimeout(s.ctx, 30*time.Second) - defer cancel() - - accessAddr := s.net.ContainerByName(testnet.PrimaryAN).Addr(testnet.RESTPort) - - t.Run("subscribe events", func(t *testing.T) { - startBlockId := flow.ZeroID - startHeight := uint64(0) - url := getSubscribeEventsRequest(accessAddr, startBlockId, startHeight, nil, nil, nil) - - s.log.Info().Msg("================> resp.Request.URL.String()" + url) - client, err := s.getWSClient(ctx, url) - require.NoError(t, err) - var receivedEvents []*state_stream.EventsResponse - eventChan := make(chan *state_stream.EventsResponse) - - go func() { - for { - resp := &state_stream.EventsResponse{} - err := client.ReadJSON(resp) - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { - t.Logf("unexpected close error: %v", err) - } - t.Log(fmt.Sprintf("______ReadJSON error %v", err)) - close(eventChan) // Close the event channel when the client connection is closed - return - } - t.Log(fmt.Sprintf("______response %v", resp)) - eventChan <- resp - } - }() - - // Wait for events or timeout - select { - case <-time.After(10 * time.Second): - // Handle the timeout, close the client connection, and break the loop - t.Log("______receivedEvents") - t.Log(receivedEvents) - - client.Close() - t.Log("Client closed") - return - case event := <-eventChan: - receivedEvents = append(receivedEvents, event) - t.Log(fmt.Sprintf("______received events %v", event)) - } - - }) -} - -func (s *AccessSuite) getWSClient(ctx context.Context, address string) (*websocket.Conn, error) { - // helper func to create WebSocket client - client, _, err := websocket.DefaultDialer.DialContext(ctx, strings.Replace(address, "http", "ws", 1), nil) - if err != nil { - return nil, err - } - return client, nil -} - -// -//// Assert that the received events match the expected events -//assert.Equal(s.T(), len(expectedEvents), len(receivedEvents)) -//for i, expected := range expectedEvents { -// received := receivedEvents[i] -// s.T().Logf("expected" + expected.String()) -// s.T().Logf("received: BlockID" + received.BlockID.String() + ", height: " + fmt.Sprint(received.Height)) -// //assert.Equal(s.T(), expected.Height, received.Height) -// //assert.Equal(s.T(), expected.BlockID, received.BlockID) -// //assert.Equal(s.T(), len(expected.Events), len(received.Events)) -// //for j, expectedEvent := range expected.Events { -// // receivedEvent := received.Events[j] -// // // Perform further assertions on each event if needed -// // assert.Equal(s.T(), expectedEvent.Type, receivedEvent.Type) -// // assert.Equal(s.T(), expectedEvent.Data, receivedEvent.Data) -// //} -//} - -//resp, err := makeSubscribeEventsCall(accessAddr, startBlockId, startHeight, nil, nil, nil) -//assert.NoError(t, err) -//assert.Contains(t, [...]int{ -// http.StatusOK, -//}, resp.StatusCode) -//s.log.Info().Msg(fmt.Sprintf("================> %s %d", resp.Status, resp.StatusCode)) - -func getSubscribeEventsRequest(accessAddr string, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) string { - u, _ := url.Parse("http://" + accessAddr + "/v1/subscribe_events") - q := u.Query() - - if startBlockId != flow.ZeroID { - q.Add("start_block_id", startBlockId.String()) - } - - if startHeight != request.EmptyHeight { - q.Add("start_height", fmt.Sprintf("%d", startHeight)) - } - - if len(eventTypes) > 0 { - q.Add("event_types", strings.Join(eventTypes, ",")) - } - if len(addresses) > 0 { - q.Add("addresses", strings.Join(addresses, ",")) - } - if len(contracts) > 0 { - q.Add("contracts", strings.Join(contracts, ",")) - } - - u.RawQuery = q.Encode() - return u.String() -} - -//func makeSubscribeEventsCall(accessAddr string, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) (*http.Response, error) { -// httpClient := http.DefaultClient -// url := getSubscribeEventsRequest(accessAddr, startBlockId, startHeight, eventTypes, addresses, contracts) -// return httpClient.Get(url) -//} - -// makeApiRequest is a helper function that encapsulates context creation for grpc client call, used to avoid repeated creation +// MakeApiRequest is a helper function that encapsulates context creation for grpc client call, used to avoid repeated creation // of new context for each call. -func makeApiRequest[Func func(context.Context, *Req, ...grpc.CallOption) (*Resp, error), Req any, Resp any](apiCall Func, ctx context.Context, req *Req) (*Resp, error) { +func MakeApiRequest[Func func(context.Context, *Req, ...grpc.CallOption) (*Resp, error), Req any, Resp any](apiCall Func, ctx context.Context, req *Req) (*Resp, error) { clientCtx, cancel := context.WithTimeout(ctx, 1*time.Second) resp, err := apiCall(clientCtx, req) cancel() diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go new file mode 100644 index 00000000000..f4f4c6a3411 --- /dev/null +++ b/integration/tests/access/rest_state_stream_test.go @@ -0,0 +1,239 @@ +package access + +import ( + "context" + "fmt" + "net/url" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/onflow/flow-go/engine/access/rest/request" + "github.com/onflow/flow-go/engine/common/rpc/convert" + "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/integration/testnet" + "github.com/onflow/flow-go/model/flow" + "github.com/onflow/flow-go/utils/unittest" + + accessproto "github.com/onflow/flow/protobuf/go/flow/access" +) + +func TestRestStateStream(t *testing.T) { + suite.Run(t, new(RestStateStreamSuite)) +} + +type RestStateStreamSuite struct { + suite.Suite + + log zerolog.Logger + + // root context for the current test + ctx context.Context + cancel context.CancelFunc + + net *testnet.FlowNetwork +} + +func (s *RestStateStreamSuite) TearDownTest() { + s.log.Info().Msg("================> Start TearDownTest") + s.net.Remove() + s.cancel() + s.log.Info().Msg("================> Finish TearDownTest") +} + +func (s *RestStateStreamSuite) SetupTest() { + s.log = unittest.LoggerForTest(s.Suite.T(), zerolog.InfoLevel) + s.log.Info().Msg("================> SetupTest") + defer func() { + s.log.Info().Msg("================> Finish SetupTest") + }() + + // access node + bridgeANConfig := testnet.NewNodeConfig( + flow.RoleAccess, + testnet.WithLogLevel(zerolog.DebugLevel), + testnet.WithAdditionalFlag("--supports-observer=true"), + testnet.WithAdditionalFlag("--execution-data-sync-enabled=true"), + testnet.WithAdditionalFlag(fmt.Sprintf("--execution-data-dir=%s", testnet.DefaultExecutionDataServiceDir)), + testnet.WithAdditionalFlag("--execution-data-retry-delay=1s"), + ) + + // add the ghost (access) node config + ghostNode := testnet.NewNodeConfig( + flow.RoleAccess, + testnet.WithLogLevel(zerolog.FatalLevel), + testnet.AsGhost()) + + consensusConfigs := []func(config *testnet.NodeConfig){ + testnet.WithAdditionalFlag("--cruise-ctl-fallback-proposal-duration=100ms"), + testnet.WithAdditionalFlag(fmt.Sprintf("--required-verification-seal-approvals=%d", 1)), + testnet.WithAdditionalFlag(fmt.Sprintf("--required-construction-seal-approvals=%d", 1)), + testnet.WithLogLevel(zerolog.FatalLevel), + } + + nodeConfigs := []testnet.NodeConfig{ + testnet.NewNodeConfig(flow.RoleCollection, testnet.WithLogLevel(zerolog.FatalLevel)), + testnet.NewNodeConfig(flow.RoleCollection, testnet.WithLogLevel(zerolog.FatalLevel)), + testnet.NewNodeConfig(flow.RoleExecution, testnet.WithLogLevel(zerolog.FatalLevel)), + testnet.NewNodeConfig(flow.RoleExecution, testnet.WithLogLevel(zerolog.FatalLevel)), + testnet.NewNodeConfig(flow.RoleConsensus, consensusConfigs...), + testnet.NewNodeConfig(flow.RoleConsensus, consensusConfigs...), + testnet.NewNodeConfig(flow.RoleConsensus, consensusConfigs...), + testnet.NewNodeConfig(flow.RoleVerification, testnet.WithLogLevel(zerolog.FatalLevel)), + bridgeANConfig, + ghostNode, + } + + conf := testnet.NewNetworkConfig("access_api_test", nodeConfigs) + s.net = testnet.PrepareFlowNetwork(s.T(), conf, flow.Localnet) + + // start the network + s.T().Logf("starting flow network with docker containers") + s.ctx, s.cancel = context.WithCancel(context.Background()) + + s.net.Start(s.ctx) +} + +// TestRestEventStreaming tests event streaming on REST +func (s *RestStateStreamSuite) TestRestEventStreaming() { + ctx, cancel := context.WithTimeout(s.ctx, 1*time.Second) + defer cancel() + + restAddr := s.net.ContainerByName(testnet.PrimaryAN).Addr(testnet.RESTPort) + + s.T().Run("subscribe events", func(t *testing.T) { + startBlockId := flow.ZeroID + startHeight := uint64(0) + url := getSubscribeEventsRequest(restAddr, startBlockId, startHeight, nil, nil, nil) + + client, err := getWSClient(ctx, url) + require.NoError(t, err) + var receivedEventsResponse []*state_stream.EventsResponse + eventChan := make(chan *state_stream.EventsResponse) + + // Start the timeout goroutine + timeoutChan := make(chan struct{}) + go func() { + time.Sleep(10 * time.Second) // Sleep for 10 seconds + close(timeoutChan) // Signal the timeout + }() + + go func() { + for { + resp := &state_stream.EventsResponse{} + err := client.ReadJSON(resp) + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + s.T().Logf("unexpected close error: %v", err) + } + close(eventChan) // Close the event channel when the client connection is closed + return + } + eventChan <- resp + } + }() + + // Wait for events or timeout + for { + select { + case <-timeoutChan: + // Handle the timeout and close the client connection + client.Close() + s.T().Log("Client connection closed") + s.requireEvents(receivedEventsResponse) + return + case eventResponse, ok := <-eventChan: + if !ok { + // Event channel closed, events received + s.T().Log(" Event channel closed, events received") + client.Close() + return + } + receivedEventsResponse = append(receivedEventsResponse, eventResponse) + s.T().Logf(fmt.Sprintf("___event %v", eventResponse)) + } + } + }) +} + +func (s *RestStateStreamSuite) requireEvents(receivedEventsResponse []*state_stream.EventsResponse) { + grpcCtx, grpcCancel := context.WithCancel(s.ctx) + defer grpcCancel() + + grpcAddr := s.net.ContainerByName(testnet.PrimaryAN).Addr(testnet.GRPCPort) + + grpcConn, err := grpc.DialContext(grpcCtx, grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(s.T(), err, "failed to connect to access node") + defer grpcConn.Close() + + grpcClient := accessproto.NewAccessAPIClient(grpcConn) + + for _, receivedEventResponse := range receivedEventsResponse { + // Create a map where key is EventType and value is list of events with this EventType + receivedEventMap := make(map[flow.EventType][]flow.Event) + for _, event := range receivedEventResponse.Events { + eventType := event.Type + receivedEventMap[eventType] = append(receivedEventMap[eventType], event) + } + + for eventType, receivedEventList := range receivedEventMap { + // get events by block id and event type + response, err := MakeApiRequest(grpcClient.GetEventsForBlockIDs, grpcCtx, + &accessproto.GetEventsForBlockIDsRequest{BlockIds: [][]byte{convert.IdentifierToMessage(receivedEventResponse.BlockID)}, + Type: fmt.Sprintf("%s", eventType)}) + require.NoError(s.T(), err) + require.Equal(s.T(), 1, len(response.Results), "expect to get 1 result") + + expectedEventsResult := response.Results[0] + require.Equal(s.T(), expectedEventsResult.BlockHeight, receivedEventResponse.Height, "expect the same block height") + require.Equal(s.T(), len(expectedEventsResult.Events), len(receivedEventList), "expect the same count of events") + + for i, event := range receivedEventList { + require.Equal(s.T(), expectedEventsResult.Events[i].EventIndex, event.EventIndex, "expect the same EventIndex") + require.Equal(s.T(), convert.MessageToIdentifier(expectedEventsResult.Events[i].TransactionId), event.TransactionID, "expect the same TransactionId") + } + } + } +} + +func getWSClient(ctx context.Context, address string) (*websocket.Conn, error) { + // helper func to create WebSocket client + client, _, err := websocket.DefaultDialer.DialContext(ctx, strings.Replace(address, "http", "ws", 1), nil) + if err != nil { + return nil, err + } + return client, nil +} + +func getSubscribeEventsRequest(accessAddr string, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) string { + u, _ := url.Parse("http://" + accessAddr + "/v1/subscribe_events") + q := u.Query() + + if startBlockId != flow.ZeroID { + q.Add("start_block_id", startBlockId.String()) + } + + if startHeight != request.EmptyHeight { + q.Add("start_height", fmt.Sprintf("%d", startHeight)) + } + + if len(eventTypes) > 0 { + q.Add("event_types", strings.Join(eventTypes, ",")) + } + if len(addresses) > 0 { + q.Add("addresses", strings.Join(addresses, ",")) + } + if len(contracts) > 0 { + q.Add("contracts", strings.Join(contracts, ",")) + } + + u.RawQuery = q.Encode() + return u.String() +} From b372ed98afc3b30cc87a9bf044b7c44f8f1f5b24 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 3 Aug 2023 23:45:26 +0300 Subject: [PATCH 13/35] Updated routeUrlMap init for rest --- engine/access/rest/routes/router.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/engine/access/rest/routes/router.go b/engine/access/rest/routes/router.go index 82eb54c3a31..4032aef6165 100644 --- a/engine/access/rest/routes/router.go +++ b/engine/access/rest/routes/router.go @@ -159,6 +159,9 @@ func init() { for _, r := range Routes { routeUrlMap[r.Pattern] = r.Name } + for _, r := range WSRoutes { + routeUrlMap[r.Pattern] = r.Name + } } func URLToRoute(url string) (string, error) { From 379ec43dbb4c8b965857328bd418bb8005aa7a4e Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 4 Aug 2023 11:05:57 +0300 Subject: [PATCH 14/35] Removed unnecessary log --- integration/tests/access/rest_state_stream_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index f4f4c6a3411..41cdf634029 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -157,7 +157,6 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { return } receivedEventsResponse = append(receivedEventsResponse, eventResponse) - s.T().Logf(fmt.Sprintf("___event %v", eventResponse)) } } }) From 8068f97c53952272505c17db1355636ce6fbc474 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 4 Aug 2023 11:25:37 +0300 Subject: [PATCH 15/35] Added more comments --- integration/tests/access/rest_state_stream_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index 41cdf634029..c8773358e4a 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -58,8 +58,7 @@ func (s *RestStateStreamSuite) SetupTest() { // access node bridgeANConfig := testnet.NewNodeConfig( flow.RoleAccess, - testnet.WithLogLevel(zerolog.DebugLevel), - testnet.WithAdditionalFlag("--supports-observer=true"), + testnet.WithLogLevel(zerolog.InfoLevel), testnet.WithAdditionalFlag("--execution-data-sync-enabled=true"), testnet.WithAdditionalFlag(fmt.Sprintf("--execution-data-dir=%s", testnet.DefaultExecutionDataServiceDir)), testnet.WithAdditionalFlag("--execution-data-retry-delay=1s"), @@ -101,7 +100,7 @@ func (s *RestStateStreamSuite) SetupTest() { s.net.Start(s.ctx) } -// TestRestEventStreaming tests event streaming on REST +// TestRestEventStreaming tests event streaming route on REST func (s *RestStateStreamSuite) TestRestEventStreaming() { ctx, cancel := context.WithTimeout(s.ctx, 1*time.Second) defer cancel() @@ -162,6 +161,8 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { }) } +// requireEvents is a helper function that encapsulates logic for comparing received events from rest state streaming and +// events which received from grpc api func (s *RestStateStreamSuite) requireEvents(receivedEventsResponse []*state_stream.EventsResponse) { grpcCtx, grpcCancel := context.WithCancel(s.ctx) defer grpcCancel() @@ -202,6 +203,7 @@ func (s *RestStateStreamSuite) requireEvents(receivedEventsResponse []*state_str } } +// getWSClient is a helper function that creates a websocket client func getWSClient(ctx context.Context, address string) (*websocket.Conn, error) { // helper func to create WebSocket client client, _, err := websocket.DefaultDialer.DialContext(ctx, strings.Replace(address, "http", "ws", 1), nil) @@ -211,6 +213,7 @@ func getWSClient(ctx context.Context, address string) (*websocket.Conn, error) { return client, nil } +// getSubscribeEventsRequest is a helper function that creates SubscribeEventsRequest func getSubscribeEventsRequest(accessAddr string, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) string { u, _ := url.Parse("http://" + accessAddr + "/v1/subscribe_events") q := u.Query() From 7c6442c2e8b1afaf6afd3884ba359648034e083a Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 9 Aug 2023 14:28:38 +0300 Subject: [PATCH 16/35] Moved part of state_stream impl back to access/state_stream package --- .../node_builder/access_node_builder.go | 21 +++++++++---------- cmd/observer/node_builder/observer_builder.go | 2 +- .../integration_unsecure_grpc_server_test.go | 13 ++++++------ engine/access/rest/routes/router.go | 2 +- engine/access/rest/routes/subscribe_events.go | 2 +- .../rest/routes/subscribe_events_test.go | 4 ++-- engine/access/rest/routes/test_helpers.go | 12 +++++------ .../access/rest/routes/websocket_handler.go | 2 +- engine/access/rest/server.go | 2 +- engine/access/rest_api_test.go | 2 +- engine/access/rpc/engine.go | 2 +- engine/access/rpc/rate_limit_test.go | 3 ++- engine/access/secure_grpcr_test.go | 2 +- .../state_stream/backend.go | 0 .../state_stream/backend_events.go | 0 .../state_stream/backend_events_test.go | 0 .../state_stream/backend_executiondata.go | 0 .../backend_executiondata_test.go | 0 engine/access/state_stream/engine.go | 9 ++++---- .../{common => access}/state_stream/event.go | 0 .../state_stream/event_test.go | 2 +- .../{common => access}/state_stream/filter.go | 0 .../state_stream/filter_test.go | 2 +- engine/access/state_stream/handler.go | 15 +++++++------ engine/access/state_stream/mock/api.go | 19 ++++++++--------- .../state_stream/streamer.go | 0 .../state_stream/streamer_test.go | 3 +-- .../state_stream/subscription.go | 0 .../state_stream/subscription_test.go | 3 +-- .../tests/access/rest_state_stream_test.go | 2 +- .../execution_data_requester_test.go | 3 +-- .../jobs/execution_data_reader_test.go | 2 +- 32 files changed, 61 insertions(+), 68 deletions(-) rename engine/{common => access}/state_stream/backend.go (100%) rename engine/{common => access}/state_stream/backend_events.go (100%) rename engine/{common => access}/state_stream/backend_events_test.go (100%) rename engine/{common => access}/state_stream/backend_executiondata.go (100%) rename engine/{common => access}/state_stream/backend_executiondata_test.go (100%) rename engine/{common => access}/state_stream/event.go (100%) rename engine/{common => access}/state_stream/event_test.go (97%) rename engine/{common => access}/state_stream/filter.go (100%) rename engine/{common => access}/state_stream/filter_test.go (98%) rename engine/{common => access}/state_stream/streamer.go (100%) rename engine/{common => access}/state_stream/streamer_test.go (98%) rename engine/{common => access}/state_stream/subscription.go (100%) rename engine/{common => access}/state_stream/subscription_test.go (98%) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index bb0d0d88e21..39d317e9811 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -44,7 +44,6 @@ import ( "github.com/onflow/flow-go/engine/access/state_stream" followereng "github.com/onflow/flow-go/engine/common/follower" "github.com/onflow/flow-go/engine/common/requester" - cstate_stream "github.com/onflow/flow-go/engine/common/state_stream" synceng "github.com/onflow/flow-go/engine/common/synchronization" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/model/flow/filter" @@ -120,8 +119,8 @@ type AccessNodeConfig struct { apiRatelimits map[string]int apiBurstlimits map[string]int rpcConf rpc.Config - stateStreamBackend *cstate_stream.StateStreamBackend - stateStreamConf cstate_stream.Config + stateStreamBackend *state_stream.StateStreamBackend + stateStreamConf state_stream.Config stateStreamFilterConf map[string]int ExecutionNodeAddress string // deprecated HistoricalAccessRPCs []access.AccessAPIClient @@ -175,14 +174,14 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { }, MaxMsgSize: grpcutils.DefaultMaxMsgSize, }, - stateStreamConf: cstate_stream.Config{ + stateStreamConf: state_stream.Config{ MaxExecutionDataMsgSize: grpcutils.DefaultMaxMsgSize, - ExecutionDataCacheSize: cstate_stream.DefaultCacheSize, - ClientSendTimeout: cstate_stream.DefaultSendTimeout, - ClientSendBufferSize: cstate_stream.DefaultSendBufferSize, - MaxGlobalStreams: cstate_stream.DefaultMaxGlobalStreams, - EventFilterConfig: cstate_stream.DefaultEventFilterConfig, - ResponseLimit: cstate_stream.DefaultResponseLimit, + ExecutionDataCacheSize: state_stream.DefaultCacheSize, + ClientSendTimeout: state_stream.DefaultSendTimeout, + ClientSendBufferSize: state_stream.DefaultSendBufferSize, + MaxGlobalStreams: state_stream.DefaultMaxGlobalStreams, + EventFilterConfig: state_stream.DefaultEventFilterConfig, + ResponseLimit: state_stream.DefaultResponseLimit, }, stateStreamBackend: nil, stateStreamFilterConf: nil, @@ -622,7 +621,7 @@ func (builder *FlowAccessNodeBuilder) BuildStateStreamPool() *FlowAccessNodeBuil } broadcaster := engine.NewBroadcaster() - builder.stateStreamBackend, err = cstate_stream.New(node.Logger, + builder.stateStreamBackend, err = state_stream.New(node.Logger, builder.stateStreamConf, node.State, node.Storage.Headers, diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index bb25be78ff7..17d0d2268ed 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -33,8 +33,8 @@ import ( "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" rpcConnection "github.com/onflow/flow-go/engine/access/rpc/connection" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/common/follower" - "github.com/onflow/flow-go/engine/common/state_stream" synceng "github.com/onflow/flow-go/engine/common/synchronization" "github.com/onflow/flow-go/engine/protocol" "github.com/onflow/flow-go/model/encodable" diff --git a/engine/access/integration_unsecure_grpc_server_test.go b/engine/access/integration_unsecure_grpc_server_test.go index 86fd1f4e985..2dbe92f156b 100644 --- a/engine/access/integration_unsecure_grpc_server_test.go +++ b/engine/access/integration_unsecure_grpc_server_test.go @@ -21,7 +21,6 @@ import ( "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" "github.com/onflow/flow-go/engine/access/state_stream" - cstatestream "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/blobs" "github.com/onflow/flow-go/module/executiondatasync/execution_data" @@ -114,7 +113,7 @@ func (suite *SameGRPCPortTestSuite) SetupTest() { suite.broadcaster = engine.NewBroadcaster() - suite.execDataHeroCache = herocache.NewBlockExecutionData(cstatestream.DefaultCacheSize, suite.log, metrics.NewNoopCollector()) + suite.execDataHeroCache = herocache.NewBlockExecutionData(state_stream.DefaultCacheSize, suite.log, metrics.NewNoopCollector()) suite.execDataCache = cache.NewExecutionDataCache(suite.eds, suite.headers, suite.seals, suite.results, suite.execDataHeroCache) accessIdentity := unittest.IdentityFixture(unittest.WithRole(flow.RoleAccess)) @@ -205,7 +204,7 @@ func (suite *SameGRPCPortTestSuite) SetupTest() { suite.secureGrpcServer, suite.unsecureGrpcServer, nil, - cstatestream.DefaultEventFilterConfig, + state_stream.DefaultEventFilterConfig, 0, ) assert.NoError(suite.T(), err) @@ -228,12 +227,12 @@ func (suite *SameGRPCPortTestSuite) SetupTest() { }, ).Maybe() - conf := cstatestream.Config{ - ClientSendTimeout: cstatestream.DefaultSendTimeout, - ClientSendBufferSize: cstatestream.DefaultSendBufferSize, + conf := state_stream.Config{ + ClientSendTimeout: state_stream.DefaultSendTimeout, + ClientSendBufferSize: state_stream.DefaultSendBufferSize, } - stateStreamBackend, err := cstatestream.New( + stateStreamBackend, err := state_stream.New( suite.log, conf, suite.state, diff --git a/engine/access/rest/routes/router.go b/engine/access/rest/routes/router.go index 4032aef6165..15acb6e9338 100644 --- a/engine/access/rest/routes/router.go +++ b/engine/access/rest/routes/router.go @@ -12,7 +12,7 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/middleware" "github.com/onflow/flow-go/engine/access/rest/models" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" ) diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 27a7b012b9d..b9274017727 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -12,7 +12,7 @@ import ( "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/request" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" ) const ( diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index e9923ebca6f..a65fac232bc 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -18,8 +18,8 @@ import ( "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/request" + "github.com/onflow/flow-go/engine/access/state_stream" mockstatestream "github.com/onflow/flow-go/engine/access/state_stream/mock" - "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -125,7 +125,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { // Create a channel to receive mock EventsResponse objects ch := make(chan interface{}) var chReadOnly <-chan interface{} - expectedEventsResponses := []*state_stream.EventsResponse{} + var expectedEventsResponses []*state_stream.EventsResponse for i, b := range s.blocks { s.T().Logf("checking block %d %v", i, b.ID()) diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index ed9f8d973b2..ba5cfe635ed 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -18,8 +18,8 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/flow-go/access/mock" + "github.com/onflow/flow-go/engine/access/state_stream" mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" - common_state_stream "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/metrics" ) @@ -87,9 +87,9 @@ func newRouter(backend *mock.API, stateStreamApi *mock_state_stream.API) (*mux.R logger := zerolog.New(&b) restCollector := metrics.NewNoopCollector() - stateStreamConfig := common_state_stream.Config{ - EventFilterConfig: common_state_stream.DefaultEventFilterConfig, - MaxGlobalStreams: common_state_stream.DefaultMaxGlobalStreams, + stateStreamConfig := state_stream.Config{ + EventFilterConfig: state_stream.DefaultEventFilterConfig, + MaxGlobalStreams: state_stream.DefaultMaxGlobalStreams, } return NewRouter(backend, @@ -107,8 +107,8 @@ func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_s return nil, err } - br := bufio.NewReaderSize(strings.NewReader(""), common_state_stream.DefaultSendBufferSize) - bw := bufio.NewWriterSize(&bytes.Buffer{}, common_state_stream.DefaultSendBufferSize) + br := bufio.NewReaderSize(strings.NewReader(""), state_stream.DefaultSendBufferSize) + bw := bufio.NewWriterSize(&bytes.Buffer{}, state_stream.DefaultSendBufferSize) resp := NewHijackResponseRecorder(bufio.NewReadWriter(br, bw)) router.ServeHTTP(resp, req) diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 095a281f68d..5dc2fd8c303 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -7,7 +7,7 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/request" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" ) diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 4c23ca03d84..db14cb6c5db 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -9,7 +9,7 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/engine/access/rest/routes" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" ) diff --git a/engine/access/rest_api_test.go b/engine/access/rest_api_test.go index 053848dcdb6..31140959734 100644 --- a/engine/access/rest_api_test.go +++ b/engine/access/rest_api_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc/credentials" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/module/grpcserver" "github.com/onflow/flow-go/utils/grpcutils" @@ -28,7 +29,6 @@ import ( "github.com/onflow/flow-go/engine/access/rest/routes" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" - "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/irrecoverable" "github.com/onflow/flow-go/module/metrics" diff --git a/engine/access/rpc/engine.go b/engine/access/rpc/engine.go index e64c8780e30..b19b2881ef8 100644 --- a/engine/access/rpc/engine.go +++ b/engine/access/rpc/engine.go @@ -16,7 +16,7 @@ import ( "github.com/onflow/flow-go/consensus/hotstuff/model" "github.com/onflow/flow-go/engine/access/rest" "github.com/onflow/flow-go/engine/access/rpc/backend" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/module/component" diff --git a/engine/access/rpc/rate_limit_test.go b/engine/access/rpc/rate_limit_test.go index 9ab2bf1aaa2..2acf69a54b4 100644 --- a/engine/access/rpc/rate_limit_test.go +++ b/engine/access/rpc/rate_limit_test.go @@ -3,6 +3,7 @@ package rpc import ( "context" "fmt" + "io" "os" "testing" @@ -21,7 +22,7 @@ import ( accessmock "github.com/onflow/flow-go/engine/access/mock" "github.com/onflow/flow-go/engine/access/rpc/backend" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/grpcserver" "github.com/onflow/flow-go/module/irrecoverable" diff --git a/engine/access/secure_grpcr_test.go b/engine/access/secure_grpcr_test.go index ef0fbc89f26..a5570849c71 100644 --- a/engine/access/secure_grpcr_test.go +++ b/engine/access/secure_grpcr_test.go @@ -21,7 +21,7 @@ import ( accessmock "github.com/onflow/flow-go/engine/access/mock" "github.com/onflow/flow-go/engine/access/rpc" "github.com/onflow/flow-go/engine/access/rpc/backend" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/grpcserver" "github.com/onflow/flow-go/module/irrecoverable" diff --git a/engine/common/state_stream/backend.go b/engine/access/state_stream/backend.go similarity index 100% rename from engine/common/state_stream/backend.go rename to engine/access/state_stream/backend.go diff --git a/engine/common/state_stream/backend_events.go b/engine/access/state_stream/backend_events.go similarity index 100% rename from engine/common/state_stream/backend_events.go rename to engine/access/state_stream/backend_events.go diff --git a/engine/common/state_stream/backend_events_test.go b/engine/access/state_stream/backend_events_test.go similarity index 100% rename from engine/common/state_stream/backend_events_test.go rename to engine/access/state_stream/backend_events_test.go diff --git a/engine/common/state_stream/backend_executiondata.go b/engine/access/state_stream/backend_executiondata.go similarity index 100% rename from engine/common/state_stream/backend_executiondata.go rename to engine/access/state_stream/backend_executiondata.go diff --git a/engine/common/state_stream/backend_executiondata_test.go b/engine/access/state_stream/backend_executiondata_test.go similarity index 100% rename from engine/common/state_stream/backend_executiondata_test.go rename to engine/access/state_stream/backend_executiondata_test.go diff --git a/engine/access/state_stream/engine.go b/engine/access/state_stream/engine.go index 4fd15309e4c..7afe0c9c225 100644 --- a/engine/access/state_stream/engine.go +++ b/engine/access/state_stream/engine.go @@ -4,7 +4,6 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/engine" - "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/component" "github.com/onflow/flow-go/module/executiondatasync/execution_data" @@ -23,8 +22,8 @@ import ( type Engine struct { *component.ComponentManager log zerolog.Logger - backend *state_stream.StateStreamBackend - config state_stream.Config + backend *StateStreamBackend + config Config chain flow.Chain handler *Handler @@ -36,12 +35,12 @@ type Engine struct { // NewEng returns a new ingress server. func NewEng( log zerolog.Logger, - config state_stream.Config, + config Config, execDataCache *cache.ExecutionDataCache, headers storage.Headers, chainID flow.ChainID, server *grpcserver.GrpcServer, - backend *state_stream.StateStreamBackend, + backend *StateStreamBackend, broadcaster *engine.Broadcaster, ) (*Engine, error) { logger := log.With().Str("engine", "state_stream_rpc").Logger() diff --git a/engine/common/state_stream/event.go b/engine/access/state_stream/event.go similarity index 100% rename from engine/common/state_stream/event.go rename to engine/access/state_stream/event.go diff --git a/engine/common/state_stream/event_test.go b/engine/access/state_stream/event_test.go similarity index 97% rename from engine/common/state_stream/event_test.go rename to engine/access/state_stream/event_test.go index 65c8629989a..3dbccd34406 100644 --- a/engine/common/state_stream/event_test.go +++ b/engine/access/state_stream/event_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" ) diff --git a/engine/common/state_stream/filter.go b/engine/access/state_stream/filter.go similarity index 100% rename from engine/common/state_stream/filter.go rename to engine/access/state_stream/filter.go diff --git a/engine/common/state_stream/filter_test.go b/engine/access/state_stream/filter_test.go similarity index 98% rename from engine/common/state_stream/filter_test.go rename to engine/access/state_stream/filter_test.go index 982687ab756..d25c272a06f 100644 --- a/engine/common/state_stream/filter_test.go +++ b/engine/access/state_stream/filter_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" - state_stream "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) diff --git a/engine/access/state_stream/handler.go b/engine/access/state_stream/handler.go index 446f71adb32..badf4b5d3b5 100644 --- a/engine/access/state_stream/handler.go +++ b/engine/access/state_stream/handler.go @@ -12,21 +12,20 @@ import ( "github.com/onflow/flow-go/engine/common/rpc" "github.com/onflow/flow-go/engine/common/rpc/convert" - "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/model/flow" ) type Handler struct { - api state_stream.API + api API chain flow.Chain - eventFilterConfig state_stream.EventFilterConfig + eventFilterConfig EventFilterConfig maxStreams int32 streamCount atomic.Int32 } -func NewHandler(api state_stream.API, chain flow.Chain, conf state_stream.EventFilterConfig, maxGlobalStreams uint32) *Handler { +func NewHandler(api API, chain flow.Chain, conf EventFilterConfig, maxGlobalStreams uint32) *Handler { h := &Handler{ api: api, chain: chain, @@ -84,7 +83,7 @@ func (h *Handler) SubscribeExecutionData(request *access.SubscribeExecutionDataR return nil } - resp, ok := v.(*state_stream.ExecutionDataResponse) + resp, ok := v.(*ExecutionDataResponse) if !ok { return status.Errorf(codes.Internal, "unexpected response type: %T", v) } @@ -121,11 +120,11 @@ func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream startBlockID = blockID } - filter := state_stream.EventFilter{} + filter := EventFilter{} if request.GetFilter() != nil { var err error reqFilter := request.GetFilter() - filter, err = state_stream.NewEventFilter( + filter, err = NewEventFilter( h.eventFilterConfig, h.chain, reqFilter.GetEventType(), @@ -147,7 +146,7 @@ func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream return nil } - resp, ok := v.(*state_stream.EventsResponse) + resp, ok := v.(*EventsResponse) if !ok { return status.Errorf(codes.Internal, "unexpected response type: %T", v) } diff --git a/engine/access/state_stream/mock/api.go b/engine/access/state_stream/mock/api.go index 8ddbe1dfb86..2548022d5d1 100644 --- a/engine/access/state_stream/mock/api.go +++ b/engine/access/state_stream/mock/api.go @@ -4,8 +4,7 @@ package mock import ( context "context" - state_stream2 "github.com/onflow/flow-go/engine/common/state_stream" - + "github.com/onflow/flow-go/engine/access/state_stream" flow "github.com/onflow/flow-go/model/flow" execution_data "github.com/onflow/flow-go/module/executiondatasync/execution_data" @@ -44,15 +43,15 @@ func (_m *API) GetExecutionDataByBlockID(ctx context.Context, blockID flow.Ident } // SubscribeEvents provides a mock function with given fields: ctx, startBlockID, startHeight, filter -func (_m *API) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier, startHeight uint64, filter state_stream2.EventFilter) state_stream2.Subscription { +func (_m *API) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier, startHeight uint64, filter state_stream.EventFilter) state_stream.Subscription { ret := _m.Called(ctx, startBlockID, startHeight, filter) - var r0 state_stream2.Subscription - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64, state_stream2.EventFilter) state_stream2.Subscription); ok { + var r0 state_stream.Subscription + if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64, state_stream.EventFilter) state_stream.Subscription); ok { r0 = rf(ctx, startBlockID, startHeight, filter) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(state_stream2.Subscription) + r0 = ret.Get(0).(state_stream.Subscription) } } @@ -60,15 +59,15 @@ func (_m *API) SubscribeEvents(ctx context.Context, startBlockID flow.Identifier } // SubscribeExecutionData provides a mock function with given fields: ctx, startBlockID, startBlockHeight -func (_m *API) SubscribeExecutionData(ctx context.Context, startBlockID flow.Identifier, startBlockHeight uint64) state_stream2.Subscription { +func (_m *API) SubscribeExecutionData(ctx context.Context, startBlockID flow.Identifier, startBlockHeight uint64) state_stream.Subscription { ret := _m.Called(ctx, startBlockID, startBlockHeight) - var r0 state_stream2.Subscription - if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64) state_stream2.Subscription); ok { + var r0 state_stream.Subscription + if rf, ok := ret.Get(0).(func(context.Context, flow.Identifier, uint64) state_stream.Subscription); ok { r0 = rf(ctx, startBlockID, startBlockHeight) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(state_stream2.Subscription) + r0 = ret.Get(0).(state_stream.Subscription) } } diff --git a/engine/common/state_stream/streamer.go b/engine/access/state_stream/streamer.go similarity index 100% rename from engine/common/state_stream/streamer.go rename to engine/access/state_stream/streamer.go diff --git a/engine/common/state_stream/streamer_test.go b/engine/access/state_stream/streamer_test.go similarity index 98% rename from engine/common/state_stream/streamer_test.go rename to engine/access/state_stream/streamer_test.go index f728b39c0a3..6c80feec7ed 100644 --- a/engine/common/state_stream/streamer_test.go +++ b/engine/access/state_stream/streamer_test.go @@ -3,7 +3,6 @@ package state_stream_test import ( "context" "fmt" - "testing" "time" @@ -12,8 +11,8 @@ import ( "github.com/stretchr/testify/mock" "github.com/onflow/flow-go/engine" + "github.com/onflow/flow-go/engine/access/state_stream" streammock "github.com/onflow/flow-go/engine/access/state_stream/mock" - state_stream "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/utils/unittest" ) diff --git a/engine/common/state_stream/subscription.go b/engine/access/state_stream/subscription.go similarity index 100% rename from engine/common/state_stream/subscription.go rename to engine/access/state_stream/subscription.go diff --git a/engine/common/state_stream/subscription_test.go b/engine/access/state_stream/subscription_test.go similarity index 98% rename from engine/common/state_stream/subscription_test.go rename to engine/access/state_stream/subscription_test.go index 81b0edf013b..d5ef7296cf3 100644 --- a/engine/common/state_stream/subscription_test.go +++ b/engine/access/state_stream/subscription_test.go @@ -3,7 +3,6 @@ package state_stream_test import ( "context" "fmt" - "sync" "testing" "time" @@ -11,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/utils/unittest" ) diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index c8773358e4a..9660e84805f 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -16,8 +16,8 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/onflow/flow-go/engine/access/rest/request" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/engine/common/rpc/convert" - "github.com/onflow/flow-go/engine/common/state_stream" "github.com/onflow/flow-go/integration/testnet" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" diff --git a/module/state_synchronization/requester/execution_data_requester_test.go b/module/state_synchronization/requester/execution_data_requester_test.go index c1ff44f0771..5ac29329094 100644 --- a/module/state_synchronization/requester/execution_data_requester_test.go +++ b/module/state_synchronization/requester/execution_data_requester_test.go @@ -3,7 +3,6 @@ package requester_test import ( "context" "fmt" - "math/rand" "sync" "testing" @@ -19,7 +18,7 @@ import ( "github.com/onflow/flow-go/consensus/hotstuff/model" "github.com/onflow/flow-go/consensus/hotstuff/notifications/pubsub" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module" "github.com/onflow/flow-go/module/blobs" diff --git a/module/state_synchronization/requester/jobs/execution_data_reader_test.go b/module/state_synchronization/requester/jobs/execution_data_reader_test.go index c5e19ae8c26..90240c83dd8 100644 --- a/module/state_synchronization/requester/jobs/execution_data_reader_test.go +++ b/module/state_synchronization/requester/jobs/execution_data_reader_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/onflow/flow-go/engine/common/state_stream" + "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/executiondatasync/execution_data" "github.com/onflow/flow-go/module/executiondatasync/execution_data/cache" From 296afffe949aaabc72c676d6dfbec599549cf170 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 9 Aug 2023 15:03:06 +0300 Subject: [PATCH 17/35] Updated rest test according to comments, removed unnecessary empty lines, linted --- .../node_builder/access_node_builder.go | 4 +-- cmd/observer/node_builder/observer_builder.go | 2 +- engine/access/rest/middleware/logging.go | 2 +- .../access/rest/request/subscribe_events.go | 1 - engine/access/rest/routes/accounts_test.go | 13 ++++------ engine/access/rest/routes/blocks_test.go | 4 +-- engine/access/rest/routes/collections_test.go | 8 +++--- engine/access/rest/routes/events_test.go | 4 +-- .../rest/routes/execution_result_test.go | 11 +++----- engine/access/rest/routes/network_test.go | 4 +-- .../rest/routes/node_version_info_test.go | 4 +-- engine/access/rest/routes/scripts_test.go | 16 +++++------- .../access/rest/routes/transactions_test.go | 26 ++++++++----------- engine/access/state_stream/backend.go | 2 +- .../state_stream/backend_events_test.go | 4 +-- .../backend_executiondata_test.go | 6 ++--- engine/access/state_stream/engine.go | 2 +- engine/access/state_stream/handler.go | 2 +- 18 files changed, 45 insertions(+), 70 deletions(-) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index 39d317e9811..7b34c37326c 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -439,7 +439,7 @@ func (builder *FlowAccessNodeBuilder) BuildConsensusFollower() *FlowAccessNodeBu return builder } -func (builder *FlowAccessNodeBuilder) BuildStateStreamPool() *FlowAccessNodeBuilder { +func (builder *FlowAccessNodeBuilder) BuildStateStreamComponentsAndModules() *FlowAccessNodeBuilder { var ds *badger.Datastore var bs network.BlobService var processedBlockHeight storage.ConsumerProgress @@ -939,7 +939,7 @@ func (builder *FlowAccessNodeBuilder) enqueueRelayNetwork() { func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { if builder.executionDataSyncEnabled { - builder.BuildStateStreamPool() + builder.BuildStateStreamComponentsAndModules() } builder. diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 17d0d2268ed..3f1b636f1a5 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -975,7 +975,7 @@ func (builder *ObserverServiceBuilder) enqueueRPCServer() { restHandler, builder.secureGrpcServer, builder.unsecureGrpcServer, - nil, + nil, // state streaming is not supported state_stream.DefaultEventFilterConfig, 0, ) diff --git a/engine/access/rest/middleware/logging.go b/engine/access/rest/middleware/logging.go index 9c524681562..e3f7de64341 100644 --- a/engine/access/rest/middleware/logging.go +++ b/engine/access/rest/middleware/logging.go @@ -43,7 +43,7 @@ type responseWriter struct { statusCode int } -// http.Hijacker necessary for upgrading gorilla websocket connection for "subscribe_events" route. +// http.Hijacker necessary for using middleware with gorilla websocket connections. var _ http.Hijacker = (*responseWriter)(nil) func newResponseWriter(w http.ResponseWriter) *responseWriter { diff --git a/engine/access/rest/request/subscribe_events.go b/engine/access/rest/request/subscribe_events.go index ac73a1842df..1d115198fb7 100644 --- a/engine/access/rest/request/subscribe_events.go +++ b/engine/access/rest/request/subscribe_events.go @@ -62,7 +62,6 @@ func (g *SubscribeEvents) Parse(rawStartBlockID string, rawStartHeight string, r } g.EventTypes = eventTypes.Flow() - g.Addresses = rawAddresses g.Contracts = rawContracts diff --git a/engine/access/rest/routes/accounts_test.go b/engine/access/rest/routes/accounts_test.go index a1803561ed3..cbda5d9acf1 100644 --- a/engine/access/rest/routes/accounts_test.go +++ b/engine/access/rest/routes/accounts_test.go @@ -2,7 +2,6 @@ package routes import ( "fmt" - "net/http" "net/url" "strings" @@ -14,7 +13,6 @@ import ( "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/middleware" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -45,7 +43,6 @@ func accountURL(t *testing.T, address string, height string) string { // 5. Get invalid account. func TestAccessGetAccount(t *testing.T) { backend := &mock.API{} - stateStreamBackend := &mock_state_stream.API{} t.Run("get by address at latest sealed block", func(t *testing.T) { account := accountFixture(t) @@ -64,7 +61,7 @@ func TestAccessGetAccount(t *testing.T) { expected := expectedExpandedResponse(account) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -84,7 +81,7 @@ func TestAccessGetAccount(t *testing.T) { expected := expectedExpandedResponse(account) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -99,7 +96,7 @@ func TestAccessGetAccount(t *testing.T) { expected := expectedExpandedResponse(account) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -114,7 +111,7 @@ func TestAccessGetAccount(t *testing.T) { expected := expectedCondensedResponse(account) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -129,7 +126,7 @@ func TestAccessGetAccount(t *testing.T) { for i, test := range tests { req, _ := http.NewRequest("GET", test.url, nil) - rr, err := executeRequest(req, backend, stateStreamBackend) + rr, err := executeRequest(req, backend, nil) assert.NoError(t, err) assert.Equal(t, http.StatusBadRequest, rr.Code) diff --git a/engine/access/rest/routes/blocks_test.go b/engine/access/rest/routes/blocks_test.go index 43973958854..e60837179dc 100644 --- a/engine/access/rest/routes/blocks_test.go +++ b/engine/access/rest/routes/blocks_test.go @@ -20,7 +20,6 @@ import ( "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/middleware" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -144,14 +143,13 @@ func prepareTestVectors(t *testing.T, // TestGetBlocks tests local get blocks by ID and get blocks by heights API func TestAccessGetBlocks(t *testing.T) { backend := &mock.API{} - stateStreamBackend := &mock_state_stream.API{} blkCnt := 10 blockIDs, heights, blocks, executionResults := generateMocks(backend, blkCnt) testVectors := prepareTestVectors(t, blockIDs, heights, blocks, executionResults, blkCnt) for _, tv := range testVectors { - responseRec, err := executeRequest(tv.request, backend, stateStreamBackend) + responseRec, err := executeRequest(tv.request, backend, nil) assert.NoError(t, err) require.Equal(t, tv.expectedStatus, responseRec.Code, "failed test %s: incorrect response code", tv.description) actualResp := responseRec.Body.String() diff --git a/engine/access/rest/routes/collections_test.go b/engine/access/rest/routes/collections_test.go index 8378efebee7..dec00e67ef3 100644 --- a/engine/access/rest/routes/collections_test.go +++ b/engine/access/rest/routes/collections_test.go @@ -16,7 +16,6 @@ import ( mocks "github.com/stretchr/testify/mock" "github.com/onflow/flow-go/access/mock" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/utils/unittest" ) @@ -32,7 +31,6 @@ func getCollectionReq(id string, expandTransactions bool) *http.Request { func TestGetCollections(t *testing.T) { backend := &mock.API{} - stateStreamBackend := &mock_state_stream.API{} t.Run("get by ID", func(t *testing.T) { inputs := []flow.LightCollection{ @@ -64,7 +62,7 @@ func TestGetCollections(t *testing.T) { }`, col.ID(), col.ID(), transactionsStr) req := getCollectionReq(col.ID().String(), false) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocks.AssertExpectationsForObjects(t, backend) } }) @@ -89,7 +87,7 @@ func TestGetCollections(t *testing.T) { Once() req := getCollectionReq(col.ID().String(), true) - rr, err := executeRequest(req, backend, stateStreamBackend) + rr, err := executeRequest(req, backend, nil) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rr.Code) @@ -148,7 +146,7 @@ func TestGetCollections(t *testing.T) { Return(test.mockValue, test.mockErr) } req := getCollectionReq(test.id, false) - assertResponse(t, req, test.status, test.response, backend, stateStreamBackend) + assertResponse(t, req, test.status, test.response, backend, nil) } }) } diff --git a/engine/access/rest/routes/events_test.go b/engine/access/rest/routes/events_test.go index 305c330e9b7..83608331a2c 100644 --- a/engine/access/rest/routes/events_test.go +++ b/engine/access/rest/routes/events_test.go @@ -16,14 +16,12 @@ import ( "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/util" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) func TestGetEvents(t *testing.T) { backend := &mock.API{} - stateStreamBackend := &mock_state_stream.API{} events := generateEventsMocks(backend, 5) allBlockIDs := make([]string, len(events)) @@ -126,7 +124,7 @@ func TestGetEvents(t *testing.T) { for _, test := range testVectors { t.Run(test.description, func(t *testing.T) { - assertResponse(t, test.request, test.expectedStatus, test.expectedResponse, backend, stateStreamBackend) + assertResponse(t, test.request, test.expectedStatus, test.expectedResponse, backend, nil) }) } diff --git a/engine/access/rest/routes/execution_result_test.go b/engine/access/rest/routes/execution_result_test.go index f7553a1682e..c0037727187 100644 --- a/engine/access/rest/routes/execution_result_test.go +++ b/engine/access/rest/routes/execution_result_test.go @@ -14,7 +14,6 @@ import ( "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/util" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -38,7 +37,6 @@ func getResultByIDReq(id string, blockIDs []string) *http.Request { } func TestGetResultByID(t *testing.T) { - stateStreamBackend := &mock_state_stream.API{} t.Run("get by ID", func(t *testing.T) { backend := &mock.API{} @@ -51,7 +49,7 @@ func TestGetResultByID(t *testing.T) { req := getResultByIDReq(id.String(), nil) expected := executionResultExpectedStr(result) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocks.AssertExpectationsForObjects(t, backend) }) @@ -64,13 +62,12 @@ func TestGetResultByID(t *testing.T) { Once() req := getResultByIDReq(id.String(), nil) - assertResponse(t, req, http.StatusNotFound, `{"code":404,"message":"Flow resource not found: block not found"}`, backend, stateStreamBackend) + assertResponse(t, req, http.StatusNotFound, `{"code":404,"message":"Flow resource not found: block not found"}`, backend, nil) mocks.AssertExpectationsForObjects(t, backend) }) } func TestGetResultBlockID(t *testing.T) { - stateStreamBackend := &mock_state_stream.API{} t.Run("get by block ID", func(t *testing.T) { backend := &mock.API{} @@ -85,7 +82,7 @@ func TestGetResultBlockID(t *testing.T) { req := getResultByIDReq("", []string{blockID.String()}) expected := fmt.Sprintf(`[%s]`, executionResultExpectedStr(result)) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocks.AssertExpectationsForObjects(t, backend) }) @@ -98,7 +95,7 @@ func TestGetResultBlockID(t *testing.T) { Once() req := getResultByIDReq("", []string{blockID.String()}) - assertResponse(t, req, http.StatusNotFound, `{"code":404,"message":"Flow resource not found: block not found"}`, backend, stateStreamBackend) + assertResponse(t, req, http.StatusNotFound, `{"code":404,"message":"Flow resource not found: block not found"}`, backend, nil) mocks.AssertExpectationsForObjects(t, backend) }) } diff --git a/engine/access/rest/routes/network_test.go b/engine/access/rest/routes/network_test.go index 4bfba3395b1..0cb3bc2b6c8 100644 --- a/engine/access/rest/routes/network_test.go +++ b/engine/access/rest/routes/network_test.go @@ -11,7 +11,6 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/access/mock" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" ) @@ -24,7 +23,6 @@ func networkURL(t *testing.T) string { func TestGetNetworkParameters(t *testing.T) { backend := &mock.API{} - stateStreamBackend := &mock_state_stream.API{} t.Run("get network parameters on mainnet", func(t *testing.T) { @@ -40,7 +38,7 @@ func TestGetNetworkParameters(t *testing.T) { expected := networkParametersExpectedStr(flow.Mainnet) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) } diff --git a/engine/access/rest/routes/node_version_info_test.go b/engine/access/rest/routes/node_version_info_test.go index 4a86265bc99..5d69131d58f 100644 --- a/engine/access/rest/routes/node_version_info_test.go +++ b/engine/access/rest/routes/node_version_info_test.go @@ -12,7 +12,6 @@ import ( "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/cmd/build" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/utils/unittest" ) @@ -25,7 +24,6 @@ func nodeVersionInfoURL(t *testing.T) string { func TestGetNodeVersionInfo(t *testing.T) { backend := mock.NewAPI(t) - stateStreamBackend := &mock_state_stream.API{} t.Run("get node version info", func(t *testing.T) { req := getNodeVersionInfoRequest(t) @@ -43,7 +41,7 @@ func TestGetNodeVersionInfo(t *testing.T) { expected := nodeVersionInfoExpectedStr(params) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) } diff --git a/engine/access/rest/routes/scripts_test.go b/engine/access/rest/routes/scripts_test.go index 73adb0eeefc..5e1b15ca86f 100644 --- a/engine/access/rest/routes/scripts_test.go +++ b/engine/access/rest/routes/scripts_test.go @@ -9,14 +9,12 @@ import ( "net/url" "testing" - "github.com/onflow/flow-go/engine/access/rest/util" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" - mocks "github.com/stretchr/testify/mock" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/onflow/flow-go/access/mock" + "github.com/onflow/flow-go/engine/access/rest/util" "github.com/onflow/flow-go/model/flow" ) @@ -47,8 +45,6 @@ func TestScripts(t *testing.T) { "arguments": []string{util.ToBase64(validArgs)}, } - stateStreamBackend := &mock_state_stream.API{} - t.Run("get by Latest height", func(t *testing.T) { backend := &mock.API{} backend.Mock. @@ -59,7 +55,7 @@ func TestScripts(t *testing.T) { assertOKResponse(t, req, fmt.Sprintf( "\"%s\"", base64.StdEncoding.EncodeToString([]byte(`hello world`)), - ), backend, stateStreamBackend) + ), backend, nil) }) t.Run("get by height", func(t *testing.T) { @@ -74,7 +70,7 @@ func TestScripts(t *testing.T) { assertOKResponse(t, req, fmt.Sprintf( "\"%s\"", base64.StdEncoding.EncodeToString([]byte(`hello world`)), - ), backend, stateStreamBackend) + ), backend, nil) }) t.Run("get by ID", func(t *testing.T) { @@ -89,7 +85,7 @@ func TestScripts(t *testing.T) { assertOKResponse(t, req, fmt.Sprintf( "\"%s\"", base64.StdEncoding.EncodeToString([]byte(`hello world`)), - ), backend, stateStreamBackend) + ), backend, nil) }) t.Run("get error", func(t *testing.T) { @@ -105,7 +101,7 @@ func TestScripts(t *testing.T) { http.StatusBadRequest, `{"code":400, "message":"Invalid Flow request: internal server error"}`, backend, - stateStreamBackend, + nil, ) }) @@ -130,7 +126,7 @@ func TestScripts(t *testing.T) { for _, test := range tests { req := scriptReq(test.id, test.height, test.body) - assertResponse(t, req, http.StatusBadRequest, test.out, backend, stateStreamBackend) + assertResponse(t, req, http.StatusBadRequest, test.out, backend, nil) } }) } diff --git a/engine/access/rest/routes/transactions_test.go b/engine/access/rest/routes/transactions_test.go index 0accd2dd72c..ca5eec50670 100644 --- a/engine/access/rest/routes/transactions_test.go +++ b/engine/access/rest/routes/transactions_test.go @@ -19,7 +19,6 @@ import ( "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/util" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/utils/unittest" ) @@ -71,7 +70,6 @@ func createTransactionReq(body interface{}) *http.Request { } func TestGetTransactions(t *testing.T) { - stateStreamBackend := &mock_state_stream.API{} t.Run("get by ID without results", func(t *testing.T) { backend := &mock.API{} @@ -115,7 +113,7 @@ func TestGetTransactions(t *testing.T) { }`, tx.ID(), tx.ReferenceBlockID, util.ToBase64(tx.EnvelopeSignatures[0].Signature), tx.ID(), tx.ID()) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) }) t.Run("Get by ID with results", func(t *testing.T) { @@ -184,7 +182,7 @@ func TestGetTransactions(t *testing.T) { } }`, tx.ID(), tx.ReferenceBlockID, util.ToBase64(tx.EnvelopeSignatures[0].Signature), tx.ReferenceBlockID, txr.CollectionID, tx.ID(), tx.ID(), tx.ID()) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) }) t.Run("get by ID Invalid", func(t *testing.T) { @@ -192,7 +190,7 @@ func TestGetTransactions(t *testing.T) { req := getTransactionReq("invalid", false, "", "") expected := `{"code":400, "message":"invalid ID format"}` - assertResponse(t, req, http.StatusBadRequest, expected, backend, stateStreamBackend) + assertResponse(t, req, http.StatusBadRequest, expected, backend, nil) }) t.Run("get by ID non-existing", func(t *testing.T) { @@ -205,7 +203,7 @@ func TestGetTransactions(t *testing.T) { Return(nil, status.Error(codes.NotFound, "transaction not found")) expected := `{"code":404, "message":"Flow resource not found: transaction not found"}` - assertResponse(t, req, http.StatusNotFound, expected, backend, stateStreamBackend) + assertResponse(t, req, http.StatusNotFound, expected, backend, nil) }) } @@ -246,7 +244,6 @@ func TestGetTransactionResult(t *testing.T) { } }`, bid.String(), cid.String(), id.String(), util.ToBase64(txr.Events[0].Payload), id.String()) - stateStreamBackend := &mock_state_stream.API{} t.Run("get by transaction ID", func(t *testing.T) { backend := &mock.API{} req := getTransactionResultReq(id.String(), "", "") @@ -255,7 +252,7 @@ func TestGetTransactionResult(t *testing.T) { On("GetTransactionResult", mocks.Anything, id, flow.ZeroID, flow.ZeroID). Return(txr, nil) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) }) t.Run("get by block ID", func(t *testing.T) { @@ -266,7 +263,7 @@ func TestGetTransactionResult(t *testing.T) { On("GetTransactionResult", mocks.Anything, id, bid, flow.ZeroID). Return(txr, nil) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) }) t.Run("get by collection ID", func(t *testing.T) { @@ -277,7 +274,7 @@ func TestGetTransactionResult(t *testing.T) { On("GetTransactionResult", mocks.Anything, id, flow.ZeroID, cid). Return(txr, nil) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) }) t.Run("get execution statuses", func(t *testing.T) { @@ -324,7 +321,7 @@ func TestGetTransactionResult(t *testing.T) { "_self": "/v1/transaction_results/%s" } }`, bid.String(), cid.String(), err, cases.Title(language.English).String(strings.ToLower(txResult.Status.String())), txResult.ErrorMessage, id.String()) - assertOKResponse(t, req, expectedResp, backend, stateStreamBackend) + assertOKResponse(t, req, expectedResp, backend, nil) } }) @@ -333,12 +330,11 @@ func TestGetTransactionResult(t *testing.T) { req := getTransactionResultReq("invalid", "", "") expected := `{"code":400, "message":"invalid ID format"}` - assertResponse(t, req, http.StatusBadRequest, expected, backend, stateStreamBackend) + assertResponse(t, req, http.StatusBadRequest, expected, backend, nil) }) } func TestCreateTransaction(t *testing.T) { - stateStreamBackend := &mock_state_stream.API{} t.Run("create", func(t *testing.T) { backend := &mock.API{} @@ -389,7 +385,7 @@ func TestCreateTransaction(t *testing.T) { } }`, tx.ID(), tx.ReferenceBlockID, util.ToBase64(tx.PayloadSignatures[0].Signature), util.ToBase64(tx.EnvelopeSignatures[0].Signature), tx.ID(), tx.ID()) - assertOKResponse(t, req, expected, backend, stateStreamBackend) + assertOKResponse(t, req, expected, backend, nil) }) t.Run("post invalid transaction", func(t *testing.T) { @@ -417,7 +413,7 @@ func TestCreateTransaction(t *testing.T) { testTx[test.inputField] = test.inputValue req := createTransactionReq(testTx) - assertResponse(t, req, http.StatusBadRequest, test.output, backend, stateStreamBackend) + assertResponse(t, req, http.StatusBadRequest, test.output, backend, nil) } }) } diff --git a/engine/access/state_stream/backend.go b/engine/access/state_stream/backend.go index 0593472d162..371427ea680 100644 --- a/engine/access/state_stream/backend.go +++ b/engine/access/state_stream/backend.go @@ -225,6 +225,6 @@ func (b *StateStreamBackend) getStartHeight(startBlockID flow.Identifier, startH } // SetHighestHeight sets the highest height for which execution data is available. -func (b *StateStreamBackend) SetHighestHeight(height uint64) bool { +func (b *StateStreamBackend) setHighestHeight(height uint64) bool { return b.highestHeight.Set(height) } diff --git a/engine/access/state_stream/backend_events_test.go b/engine/access/state_stream/backend_events_test.go index 69a85d55907..68ca0a789cb 100644 --- a/engine/access/state_stream/backend_events_test.go +++ b/engine/access/state_stream/backend_events_test.go @@ -108,7 +108,7 @@ func (s *BackendEventsSuite) TestSubscribeEvents() { // this simulates a subscription on a past block for i := 0; i <= test.highestBackfill; i++ { s.T().Logf("backfilling block %d", i) - s.backend.SetHighestHeight(s.blocks[i].Header.Height) + s.backend.setHighestHeight(s.blocks[i].Header.Height) } subCtx, subCancel := context.WithCancel(ctx) @@ -121,7 +121,7 @@ func (s *BackendEventsSuite) TestSubscribeEvents() { // simulate new exec data received. // exec data for all blocks with index <= highestBackfill were already received if i > test.highestBackfill { - s.backend.SetHighestHeight(b.Header.Height) + s.backend.setHighestHeight(b.Header.Height) s.broadcaster.Publish() } diff --git a/engine/access/state_stream/backend_executiondata_test.go b/engine/access/state_stream/backend_executiondata_test.go index ab8f7d2a14a..361cb64aa80 100644 --- a/engine/access/state_stream/backend_executiondata_test.go +++ b/engine/access/state_stream/backend_executiondata_test.go @@ -252,7 +252,7 @@ func (s *BackendExecutionDataSuite) TestGetExecutionDataByBlockID() { execData := s.execDataMap[block.ID()] // notify backend block is available - s.backend.SetHighestHeight(block.Header.Height) + s.backend.setHighestHeight(block.Header.Height) var err error s.Run("happy path TestGetExecutionDataByBlockID success", func() { @@ -328,7 +328,7 @@ func (s *BackendExecutionDataSuite) TestSubscribeExecutionData() { // this simulates a subscription on a past block for i := 0; i <= test.highestBackfill; i++ { s.T().Logf("backfilling block %d", i) - s.backend.SetHighestHeight(s.blocks[i].Header.Height) + s.backend.setHighestHeight(s.blocks[i].Header.Height) } subCtx, subCancel := context.WithCancel(ctx) @@ -342,7 +342,7 @@ func (s *BackendExecutionDataSuite) TestSubscribeExecutionData() { // simulate new exec data received. // exec data for all blocks with index <= highestBackfill were already received if i > test.highestBackfill { - s.backend.SetHighestHeight(b.Header.Height) + s.backend.setHighestHeight(b.Header.Height) s.broadcaster.Publish() } diff --git a/engine/access/state_stream/engine.go b/engine/access/state_stream/engine.go index 7afe0c9c225..80b728098bf 100644 --- a/engine/access/state_stream/engine.go +++ b/engine/access/state_stream/engine.go @@ -84,7 +84,7 @@ func (e *Engine) OnExecutionData(executionData *execution_data.BlockExecutionDat return } - if ok := e.backend.SetHighestHeight(header.Height); !ok { + if ok := e.backend.setHighestHeight(header.Height); !ok { // this means that the height was lower than the current highest height // OnExecutionData is guaranteed by the requester to be called in order, but may be called // multiple times for the same block. diff --git a/engine/access/state_stream/handler.go b/engine/access/state_stream/handler.go index badf4b5d3b5..df7c4dd9f6b 100644 --- a/engine/access/state_stream/handler.go +++ b/engine/access/state_stream/handler.go @@ -6,7 +6,6 @@ import ( access "github.com/onflow/flow/protobuf/go/flow/executiondata" executiondata "github.com/onflow/flow/protobuf/go/flow/executiondata" - "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -135,6 +134,7 @@ func (h *Handler) SubscribeEvents(request *access.SubscribeEventsRequest, stream return status.Errorf(codes.InvalidArgument, "invalid event filter: %v", err) } } + sub := h.api.SubscribeEvents(stream.Context(), startBlockID, request.GetStartBlockHeight(), filter) for { From efcf292359c4a9cb7b114330f09f1f3a36dfb76e Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 9 Aug 2023 15:22:05 +0300 Subject: [PATCH 18/35] Reverted back imports order --- cmd/access/node_builder/access_node_builder.go | 3 +-- cmd/observer/node_builder/observer_builder.go | 1 - engine/access/rest/routes/accounts_test.go | 1 + engine/access/rest/routes/events_test.go | 1 + engine/access/rest/routes/execution_result_test.go | 1 - engine/access/rest/routes/subscribe_events.go | 2 -- engine/access/rest/routes/transactions_test.go | 9 ++++++--- engine/access/rpc/rate_limit_test.go | 1 - engine/access/state_stream/mock/api.go | 4 +++- 9 files changed, 12 insertions(+), 11 deletions(-) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index 7b34c37326c..efd65fd0512 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "github.com/onflow/flow/protobuf/go/flow/access" "github.com/onflow/go-bitswap" "github.com/onflow/flow-go/admin/commands" @@ -88,8 +89,6 @@ import ( "github.com/onflow/flow-go/storage" bstorage "github.com/onflow/flow-go/storage/badger" "github.com/onflow/flow-go/utils/grpcutils" - - "github.com/onflow/flow/protobuf/go/flow/access" ) // AccessNodeBuilder extends cmd.NodeBuilder and declares additional functions needed to bootstrap an Access node. diff --git a/cmd/observer/node_builder/observer_builder.go b/cmd/observer/node_builder/observer_builder.go index 3f1b636f1a5..3b9a3612c3f 100644 --- a/cmd/observer/node_builder/observer_builder.go +++ b/cmd/observer/node_builder/observer_builder.go @@ -959,7 +959,6 @@ func (builder *ObserverServiceBuilder) enqueueRPCServer() { observerCollector, node.RootChainID.Chain()) if err != nil { - return nil, err } diff --git a/engine/access/rest/routes/accounts_test.go b/engine/access/rest/routes/accounts_test.go index cbda5d9acf1..0267a991d4a 100644 --- a/engine/access/rest/routes/accounts_test.go +++ b/engine/access/rest/routes/accounts_test.go @@ -173,6 +173,7 @@ func getAccountRequest(t *testing.T, account *flow.Account, height string, expan q.Add(middleware.ExpandQueryParam, fieldParam) req.URL.RawQuery = q.Encode() } + require.NoError(t, err) return req } diff --git a/engine/access/rest/routes/events_test.go b/engine/access/rest/routes/events_test.go index 83608331a2c..0aec01937af 100644 --- a/engine/access/rest/routes/events_test.go +++ b/engine/access/rest/routes/events_test.go @@ -11,6 +11,7 @@ import ( mocks "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" diff --git a/engine/access/rest/routes/execution_result_test.go b/engine/access/rest/routes/execution_result_test.go index c0037727187..e44804fceb8 100644 --- a/engine/access/rest/routes/execution_result_test.go +++ b/engine/access/rest/routes/execution_result_test.go @@ -37,7 +37,6 @@ func getResultByIDReq(id string, blockIDs []string) *http.Request { } func TestGetResultByID(t *testing.T) { - t.Run("get by ID", func(t *testing.T) { backend := &mock.API{} result := unittest.ExecutionResultFixture() diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index b9274017727..d8752b73acd 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -74,7 +74,6 @@ func SubscribeEvents(r *request.Request, w, req, conn, - r.Context(), api, filter, errorHandler, @@ -88,7 +87,6 @@ func writeEvents( w http.ResponseWriter, req request.SubscribeEvents, conn *websocket.Conn, - c context.Context, api state_stream.API, filter state_stream.EventFilter, errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), diff --git a/engine/access/rest/routes/transactions_test.go b/engine/access/rest/routes/transactions_test.go index ca5eec50670..671324ba29e 100644 --- a/engine/access/rest/routes/transactions_test.go +++ b/engine/access/rest/routes/transactions_test.go @@ -70,7 +70,6 @@ func createTransactionReq(body interface{}) *http.Request { } func TestGetTransactions(t *testing.T) { - t.Run("get by ID without results", func(t *testing.T) { backend := &mock.API{} tx := unittest.TransactionFixture() @@ -118,6 +117,7 @@ func TestGetTransactions(t *testing.T) { t.Run("Get by ID with results", func(t *testing.T) { backend := &mock.API{} + tx := unittest.TransactionFixture() txr := transactionResultFixture(tx) @@ -195,6 +195,7 @@ func TestGetTransactions(t *testing.T) { t.Run("get by ID non-existing", func(t *testing.T) { backend := &mock.API{} + tx := unittest.TransactionFixture() req := getTransactionReq(tx.ID().String(), false, "", "") @@ -257,6 +258,7 @@ func TestGetTransactionResult(t *testing.T) { t.Run("get by block ID", func(t *testing.T) { backend := &mock.API{} + req := getTransactionResultReq(id.String(), bid.String(), "") backend.Mock. @@ -279,6 +281,7 @@ func TestGetTransactionResult(t *testing.T) { t.Run("get execution statuses", func(t *testing.T) { backend := &mock.API{} + testVectors := map[*access.TransactionResult]string{{ Status: flow.TransactionStatusExpired, ErrorMessage: "", @@ -327,6 +330,7 @@ func TestGetTransactionResult(t *testing.T) { t.Run("get by ID Invalid", func(t *testing.T) { backend := &mock.API{} + req := getTransactionResultReq("invalid", "", "") expected := `{"code":400, "message":"invalid ID format"}` @@ -335,9 +339,9 @@ func TestGetTransactionResult(t *testing.T) { } func TestCreateTransaction(t *testing.T) { + backend := &mock.API{} t.Run("create", func(t *testing.T) { - backend := &mock.API{} tx := unittest.TransactionBodyFixture() tx.PayloadSignatures = []flow.TransactionSignature{unittest.TransactionSignatureFixture()} tx.Arguments = [][]uint8{} @@ -389,7 +393,6 @@ func TestCreateTransaction(t *testing.T) { }) t.Run("post invalid transaction", func(t *testing.T) { - backend := &mock.API{} tests := []struct { inputField string inputValue string diff --git a/engine/access/rpc/rate_limit_test.go b/engine/access/rpc/rate_limit_test.go index 2acf69a54b4..9fc97ae1c85 100644 --- a/engine/access/rpc/rate_limit_test.go +++ b/engine/access/rpc/rate_limit_test.go @@ -3,7 +3,6 @@ package rpc import ( "context" "fmt" - "io" "os" "testing" diff --git a/engine/access/state_stream/mock/api.go b/engine/access/state_stream/mock/api.go index 2548022d5d1..5b57efc917f 100644 --- a/engine/access/state_stream/mock/api.go +++ b/engine/access/state_stream/mock/api.go @@ -4,11 +4,13 @@ package mock import ( context "context" - "github.com/onflow/flow-go/engine/access/state_stream" + flow "github.com/onflow/flow-go/model/flow" execution_data "github.com/onflow/flow-go/module/executiondatasync/execution_data" mock "github.com/stretchr/testify/mock" + + state_stream "github.com/onflow/flow-go/engine/access/state_stream" ) // API is an autogenerated mock type for the API type From 31ba626434cd23aa74795c38ff1e6e3398c82b1c Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 16 Aug 2023 09:38:13 +0300 Subject: [PATCH 19/35] Refactored subscribeEvents function --- engine/access/rest/routes/subscribe_events.go | 126 ++++++------------ .../access/rest/routes/websocket_handler.go | 108 ++++++++++++--- 2 files changed, 131 insertions(+), 103 deletions(-) diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index d8752b73acd..6e4cba8f2a2 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -11,112 +11,75 @@ import ( "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/models" - "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/access/state_stream" ) -const ( - // Time allowed to read the next pong message from the peer. - pongWait = 60 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 -) - // SubscribeEvents create websocket connection and write to it requested events. -func SubscribeEvents(r *request.Request, - w http.ResponseWriter, +func SubscribeEvents( logger zerolog.Logger, - api state_stream.API, + h SubscribeHandler, eventFilterConfig state_stream.EventFilterConfig, - maxStreams int32, streamCount *atomic.Int32, - errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), - jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger)) { - req, err := r.SubscribeEventsRequest() - if err != nil { - errorHandler(w, models.NewBadRequestError(err), logger) - return - } - - logger = logger.With().Str("subscribe events", r.URL.String()).Logger() - if streamCount.Load() >= maxStreams { - err := fmt.Errorf("maximum number of streams reached") - errorHandler(w, models.NewRestError(http.StatusServiceUnavailable, "maximum number of streams reached", err), logger) - return - } + errorHandler func(logger zerolog.Logger, conn *websocket.Conn, err error)) { + logger = logger.With().Str("subscribe events", h.request.URL.String()).Logger() + defer func() { + h.conn.Close() + }() - // Upgrade the HTTP connection to a WebSocket connection - upgrader := websocket.Upgrader{} - conn, err := upgrader.Upgrade(w, r.Request, nil) + req, err := h.request.SubscribeEventsRequest() if err != nil { - errorHandler(w, models.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) + errorHandler(logger, h.conn, models.NewBadRequestError(err)) return } // Retrieve the filter parameters from the request, if provided filter, err := state_stream.NewEventFilter( eventFilterConfig, - r.Chain, + h.request.Chain, req.EventTypes, req.Addresses, req.Contracts, ) if err != nil { err := fmt.Errorf("event filter error") - errorHandler(w, models.NewBadRequestError(err), logger) + errorHandler(logger, h.conn, models.NewBadRequestError(err)) return } + if streamCount.Load() >= h.maxStreams { + err := fmt.Errorf("maximum number of streams reached") + errorHandler(logger, h.conn, models.NewRestError(http.StatusServiceUnavailable, "maximum number of streams reached", err)) + return + } streamCount.Add(1) + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + }() + sub := h.api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) + // Write messages to the WebSocket connection - go writeEvents(logger, - w, - req, - conn, - api, - filter, - errorHandler, - jsonResponse, + err = writeEvents( + sub, + h, streamCount) - time.Sleep(2 * time.Second) // wait for creating child context in goroutine + if err != nil { + errorHandler(logger, h.conn, err) + return + } } func writeEvents( - log zerolog.Logger, - w http.ResponseWriter, - req request.SubscribeEvents, - conn *websocket.Conn, - api state_stream.API, - filter state_stream.EventFilter, - errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), - jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger), + sub state_stream.Subscription, + h SubscribeHandler, streamCount *atomic.Int32, -) { - ctx, cancel := context.WithCancel(context.Background()) +) error { ticker := time.NewTicker(pingPeriod) - sub := api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) defer func() { ticker.Stop() streamCount.Add(-1) - conn.Close() - cancel() }() - err := conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message - if err != nil { - errorHandler(w, models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err), log) - return - } - conn.SetPongHandler(func(string) error { - err = conn.SetReadDeadline(time.Now().Add(pongWait)) // Reset the read deadline upon receiving a pong message - if err != nil { - errorHandler(w, models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err), log) - conn.Close() - return err - } - return nil - }) for { select { @@ -124,36 +87,25 @@ func writeEvents( if !ok { if sub.Err() != nil { err := fmt.Errorf("stream encountered an error: %v", sub.Err()) - errorHandler(w, models.NewBadRequestError(err), log) - conn.Close() - return + return models.NewBadRequestError(err) } err := fmt.Errorf("subscription channel closed, no error occurred") - errorHandler(w, models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err), log) - conn.Close() - return + return models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err) } resp, ok := v.(*state_stream.EventsResponse) if !ok { - err := fmt.Errorf("unexpected response type: %T", v) - errorHandler(w, err, log) - conn.Close() - return + return fmt.Errorf("unexpected response type: %T", v) } // Write the response to the WebSocket connection - err := conn.WriteJSON(resp) + err := h.conn.WriteJSON(resp) if err != nil { - errorHandler(w, err, log) - conn.Close() - return + return err } - jsonResponse(w, http.StatusOK, "", log) case <-ticker.C: - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { - conn.Close() - return + if err := h.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return err } } } diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 5dc2fd8c303..52ce6a01d8b 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -1,8 +1,13 @@ package routes import ( + "errors" + "fmt" + "github.com/gorilla/websocket" + "github.com/onflow/flow-go/engine/access/rest/models" "net/http" "sync/atomic" + "time" "github.com/rs/zerolog" @@ -11,19 +16,46 @@ import ( "github.com/onflow/flow-go/model/flow" ) +const ( + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second +) + +type SubscribeHandler struct { + request *request.Request + respWriter http.ResponseWriter + conn *websocket.Conn + + api state_stream.API + maxStreams int32 +} + +func (h *SubscribeHandler) SetReadWriteDeadline() error { + err := h.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message + if err != nil { + return models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err) + } + err = h.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message + if err != nil { + return models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err) + } + return nil +} + // SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, // it fetches necessary resources and returns an error. type SubscribeHandlerFunc func( - r *request.Request, - w http.ResponseWriter, - logger zerolog.Logger, - api state_stream.API, + subscribeHandler SubscribeHandler, eventFilterConfig state_stream.EventFilterConfig, - maxStreams int32, streamCount *atomic.Int32, - errorHandler func(w http.ResponseWriter, err error, errorLogger zerolog.Logger), - jsonResponse func(w http.ResponseWriter, code int, response interface{}, errLogger zerolog.Logger), + errorHandler func(logger zerolog.Logger, conn *websocket.Conn, err error), ) // WSHandler is websocket handler implementing custom websocket handler function and allows easier handling of errors and @@ -52,8 +84,8 @@ func NewWSHandler( eventFilterConfig: eventFilterConfig, maxStreams: int32(maxGlobalStreams), streamCount: atomic.Int32{}, + HttpHandler: NewHttpHandler(logger, chain), } - handler.HttpHandler = NewHttpHandler(logger, chain) return handler } @@ -62,22 +94,66 @@ func NewWSHandler( // such as logging, error handling, request decorators func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // create a logger - errLog := h.Logger.With().Str("request_url", r.URL.String()).Logger() + logger := h.Logger.With().Str("request_url", r.URL.String()).Logger() err := h.VerifyRequest(w, r) if err != nil { return } + + // Upgrade the HTTP connection to a WebSocket connection + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.errorHandler(w, models.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) + return + } + decoratedRequest := request.Decorate(r, h.HttpHandler.Chain) - h.subscribeFunc(decoratedRequest, - w, - errLog, - h.api, + subscribeHandler := SubscribeHandler{ + request: decoratedRequest, + respWriter: w, + conn: conn, + api: h.api, + maxStreams: h.maxStreams, + } + + err = subscribeHandler.SetReadWriteDeadline() + if err != nil { + h.errorHandler(w, err, logger) + conn.Close() + } + + go h.subscribeFunc( + logger, + subscribeHandler, h.eventFilterConfig, - h.maxStreams, &h.streamCount, - h.errorHandler, - h.jsonResponse) + h.sendError) +} +func (h *WSHandler) sendError( + logger zerolog.Logger, + conn *websocket.Conn, + err error) { + // rest status type error should be returned with status and user message provided + var statusErr models.StatusError + var errMsg models.ModelError + if errors.As(err, &statusErr) { + errMsg = models.ModelError{ + Code: int32(statusErr.Status()), + Message: statusErr.UserMessage(), + } + } else { + errMsg = models.ModelError{ + Code: http.StatusInternalServerError, + Message: "internal server error", + } + } + + err = conn.WriteJSON(errMsg) + if err != nil { + logger.Error().Err(err).Msg(fmt.Sprintf("error sending WebSocket error: %v", err)) + } } From 1a9f924eac2a95a697ac81cd813fb31865b4bbf6 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 16 Aug 2023 10:15:37 +0300 Subject: [PATCH 20/35] Added more comments, linted --- .../node_builder/access_node_builder.go | 5 ++-- .../access/rest/routes/account_keys_test.go | 20 ++++++++-------- engine/access/rest/routes/handler.go | 3 ++- engine/access/rest/routes/subscribe_events.go | 4 ++++ .../access/rest/routes/websocket_handler.go | 23 +++++++++++-------- engine/access/rest/server.go | 1 - 6 files changed, 31 insertions(+), 25 deletions(-) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index c9b24b4a89a..265afb529c8 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -439,7 +439,7 @@ func (builder *FlowAccessNodeBuilder) BuildConsensusFollower() *FlowAccessNodeBu return builder } -func (builder *FlowAccessNodeBuilder) BuildStateStreamComponentsAndModules() *FlowAccessNodeBuilder { +func (builder *FlowAccessNodeBuilder) BuildStateStreamPipeline() *FlowAccessNodeBuilder { var ds *badger.Datastore var bs network.BlobService var processedBlockHeight storage.ConsumerProgress @@ -632,7 +632,6 @@ func (builder *FlowAccessNodeBuilder) BuildStateStreamComponentsAndModules() *Fl broadcaster, builder.executionDataConfig.InitialBlockHeight, highestAvailableHeight) - if err != nil { return nil, fmt.Errorf("could not create state stream backend: %w", err) } @@ -940,7 +939,7 @@ func (builder *FlowAccessNodeBuilder) enqueueRelayNetwork() { func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { if builder.executionDataSyncEnabled { - builder.BuildStateStreamComponentsAndModules() + builder.BuildStateStreamPipeline() } builder. diff --git a/engine/access/rest/routes/account_keys_test.go b/engine/access/rest/routes/account_keys_test.go index 241e40240e5..0675edb7ea7 100644 --- a/engine/access/rest/routes/account_keys_test.go +++ b/engine/access/rest/routes/account_keys_test.go @@ -48,7 +48,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { expected := expectedAccountKeyResponse(account) - assertOKResponse(t, req, expected, backend) + assertOKResponse(t, req, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -69,7 +69,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { expected := expectedAccountKeyResponse(account) - assertOKResponse(t, req, expected, backend) + assertOKResponse(t, req, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -97,7 +97,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, index) - assertResponse(t, req, statusCode, expected, backend) + assertResponse(t, req, statusCode, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -125,7 +125,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, index) - assertResponse(t, req, statusCode, expected, backend) + assertResponse(t, req, statusCode, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -154,7 +154,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, account.Address) - assertResponse(t, req, statusCode, expected, backend) + assertResponse(t, req, statusCode, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -183,7 +183,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, account.Address) - assertResponse(t, req, statusCode, expected, backend) + assertResponse(t, req, statusCode, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -198,8 +198,8 @@ func TestGetAccountKeyByIndex(t *testing.T) { expected := expectedAccountKeyResponse(account) - assertOKResponse(t, req, expected, backend) - mocktestify.AssertExpectationsForObjects(t, backend) + assertOKResponse(t, req, expected, backend, nil) + mocktestify.AssertExpectationsForObjects(t, backend, nil) }) t.Run("get key by address and index at missing block", func(t *testing.T) { @@ -222,7 +222,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, finalHeight) - assertResponse(t, req, statusCode, expected, backend) + assertResponse(t, req, statusCode, expected, backend, nil) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -261,7 +261,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { req, _ := http.NewRequest("GET", test.url, nil) - rr, err := executeRequest(req, backend) + rr, err := executeRequest(req, backend, nil) assert.NoError(t, err) assert.Equal(t, http.StatusBadRequest, rr.Code) diff --git a/engine/access/rest/routes/handler.go b/engine/access/rest/routes/handler.go index c78df359cab..2779fe32699 100644 --- a/engine/access/rest/routes/handler.go +++ b/engine/access/rest/routes/handler.go @@ -41,8 +41,9 @@ func NewHandler( backend: backend, apiHandlerFunc: handlerFunc, linkGenerator: generator, + HttpHandler: NewHttpHandler(logger, chain), } - handler.HttpHandler = NewHttpHandler(logger, chain) + return handler } diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 6e4cba8f2a2..0ad56fd9aba 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -69,6 +69,10 @@ func SubscribeEvents( } } +// writeEvents use for writes events and pings to the WebSocket connection. It listens to a subscription's channel for +// events and writes them to the connection. If an error occurs or the subscription channel is closed, it handles the +// error or termination accordingly. +// The function uses a ticker to periodically send ping messages to the client to maintain the connection. func writeEvents( sub state_stream.Subscription, h SubscribeHandler, diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 52ce6a01d8b..8be446ed59c 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -3,14 +3,14 @@ package routes import ( "errors" "fmt" - "github.com/gorilla/websocket" - "github.com/onflow/flow-go/engine/access/rest/models" "net/http" "sync/atomic" "time" + "github.com/gorilla/websocket" "github.com/rs/zerolog" + "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/access/state_stream" "github.com/onflow/flow-go/model/flow" @@ -27,15 +27,19 @@ const ( writeWait = 10 * time.Second ) +// SubscribeHandler holds the necessary components and parameters for handling a WebSocket subscription. +// It manages the communication between the server and the WebSocket client for subscribing. type SubscribeHandler struct { - request *request.Request - respWriter http.ResponseWriter - conn *websocket.Conn - - api state_stream.API - maxStreams int32 + request *request.Request // the incoming HTTP request containing the subscription details. + respWriter http.ResponseWriter // the HTTP response writer to communicate back to the client. + conn *websocket.Conn // the established WebSocket connection for communication with the client. + api state_stream.API // the state_stream.API instance for managing event subscriptions. + maxStreams int32 // the maximum number of streams allowed. } +// SetReadWriteDeadline used to set read and write deadlines for WebSocket connections. These methods allow you to +// specify a time limit for reading from or writing to a WebSocket connection. If the operation (reading or writing) +// takes longer than the specified deadline, the connection will be closed. func (h *SubscribeHandler) SetReadWriteDeadline() error { err := h.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message if err != nil { @@ -48,8 +52,7 @@ func (h *SubscribeHandler) SetReadWriteDeadline() error { return nil } -// SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, -// it fetches necessary resources and returns an error. +// SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, fetches necessary resources type SubscribeHandlerFunc func( logger zerolog.Logger, subscribeHandler SubscribeHandler, diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index db14cb6c5db..e46582e092a 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -25,7 +25,6 @@ func NewServer(serverAPI access.API, maxGlobalStreams uint32, ) (*http.Server, error) { router, err := routes.NewRouter(serverAPI, logger, chain, restCollector, stateStreamApi, eventFilterConfig, maxGlobalStreams) - if err != nil { return nil, err } From 4477f31a8c1f27cbe2f553d7909e04e287ba9e88 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 17 Aug 2023 10:03:31 +0300 Subject: [PATCH 21/35] Updated unit tests, linted, added more comments --- .../access/rest/routes/account_keys_test.go | 2 +- engine/access/rest/routes/subscribe_events.go | 9 +-- .../rest/routes/subscribe_events_test.go | 25 +++++++-- engine/access/rest/routes/test_helpers.go | 37 +++++++----- .../access/rest/routes/websocket_handler.go | 56 ++++++++++++------- 5 files changed, 82 insertions(+), 47 deletions(-) diff --git a/engine/access/rest/routes/account_keys_test.go b/engine/access/rest/routes/account_keys_test.go index 0675edb7ea7..d6a1ca25077 100644 --- a/engine/access/rest/routes/account_keys_test.go +++ b/engine/access/rest/routes/account_keys_test.go @@ -199,7 +199,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { expected := expectedAccountKeyResponse(account) assertOKResponse(t, req, expected, backend, nil) - mocktestify.AssertExpectationsForObjects(t, backend, nil) + mocktestify.AssertExpectationsForObjects(t, backend) }) t.Run("get key by address and index at missing block", func(t *testing.T) { diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 0ad56fd9aba..68c14cbba77 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -22,9 +22,7 @@ func SubscribeEvents( streamCount *atomic.Int32, errorHandler func(logger zerolog.Logger, conn *websocket.Conn, err error)) { logger = logger.With().Str("subscribe events", h.request.URL.String()).Logger() - defer func() { - h.conn.Close() - }() + defer h.conn.Close() req, err := h.request.SubscribeEventsRequest() if err != nil { @@ -53,9 +51,8 @@ func SubscribeEvents( streamCount.Add(1) ctx, cancel := context.WithCancel(context.Background()) - defer func() { - cancel() - }() + defer cancel() + sub := h.api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) // Write messages to the WebSocket connection diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index a65fac232bc..b1ff9375345 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -165,9 +165,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { req, err := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts) assert.NoError(s.T(), err) - rr, err := executeRequest(req, backend, stateStreamBackend) + _, err = executeRequest(req, backend, stateStreamBackend) assert.NoError(s.T(), err) - assert.Equal(s.T(), http.StatusOK, rr.Code) }) } } @@ -179,7 +178,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil) assert.NoError(s.T(), err) - assertResponse(s.T(), req, http.StatusBadRequest, `{"code":400,"message":"can only provide either block ID or start height"}`, backend, stateStreamBackend) + respRecorder, err := executeRequest(req, backend, stateStreamBackend) + assert.NoError(s.T(), err) + requireError(s.T(), respRecorder, "can only provide either block ID or start height") }) s.Run("returns error for invalid block id", func() { @@ -202,7 +203,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) - assertResponse(s.T(), req, http.StatusBadRequest, `{"code":400,"message":"stream encountered an error: subscription error"}`, backend, stateStreamBackend) + respRecorder, err := executeRequest(req, backend, stateStreamBackend) + assert.NoError(s.T(), err) + requireError(s.T(), respRecorder, "stream encountered an error: subscription error") }) s.Run("returns error when channel closed", func() { @@ -224,7 +227,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) - assertResponse(s.T(), req, http.StatusRequestTimeout, `{"code":408,"message":"subscription channel closed"}`, backend, stateStreamBackend) + respRecorder, err := executeRequest(req, backend, stateStreamBackend) + assert.NoError(s.T(), err) + requireError(s.T(), respRecorder, "subscription channel closed") }) s.Run("returns error for unexpected response type", func() { @@ -248,7 +253,10 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) - assertResponse(s.T(), req, http.StatusInternalServerError, `{"code":500,"message":"internal server error"}`, backend, stateStreamBackend) + + respRecorder, err := executeRequest(req, backend, stateStreamBackend) + assert.NoError(s.T(), err) + requireError(s.T(), respRecorder, "unexpected response type: *state_stream.ExecutionDataResponse") }) } @@ -300,3 +308,8 @@ func generateWebSocketKey() (string, error) { // Encode the bytes to base64 and return the key as a string. return base64.StdEncoding.EncodeToString(keyBytes), nil } + +func requireError(t *testing.T, recorder *HijackResponseRecorder, expected string) { + <-recorder.closed + require.Contains(t, recorder.responseBuff.String(), expected) +} diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index ba5cfe635ed..f198ee1cad1 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -41,9 +41,14 @@ const ( type fakeNetConn struct { io.Reader io.Writer + closed chan struct{} } -func (c fakeNetConn) Close() error { return nil } +// Close closes the fakeNetConn and signals its closure by closing the "closed" channel. +func (c fakeNetConn) Close() error { + close(c.closed) + return nil +} func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } @@ -65,19 +70,28 @@ func (a fakeAddr) String() string { return "str" } +// HijackResponseRecorder is a custom ResponseRecorder that implements the http.Hijacker interface +// for testing WebSocket connections and hijacking. type HijackResponseRecorder struct { *httptest.ResponseRecorder - brw *bufio.ReadWriter + closed chan struct{} + responseBuff *bytes.Buffer } +// Hijack implements the http.Hijacker interface by returning a fakeNetConn and a bufio.ReadWriter +// that simulate a hijacked connection. func (w *HijackResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, w.brw, nil + br := bufio.NewReaderSize(strings.NewReader(""), state_stream.DefaultSendBufferSize) + bw := bufio.NewWriterSize(&bytes.Buffer{}, state_stream.DefaultSendBufferSize) + w.responseBuff = bytes.NewBuffer(make([]byte, 0)) + w.closed = make(chan struct{}, 1) + + return fakeNetConn{strings.NewReader(""), w.responseBuff, w.closed}, bufio.NewReadWriter(br, bw), nil } -func NewHijackResponseRecorder(brw *bufio.ReadWriter) *HijackResponseRecorder { - responseRecorder := &HijackResponseRecorder{ - brw: brw, - } +// NewHijackResponseRecorder creates a new instance of HijackResponseRecorder. +func NewHijackResponseRecorder() *HijackResponseRecorder { + responseRecorder := &HijackResponseRecorder{} responseRecorder.ResponseRecorder = httptest.NewRecorder() return responseRecorder } @@ -101,18 +115,15 @@ func newRouter(backend *mock.API, stateStreamApi *mock_state_stream.API) (*mux.R stateStreamConfig.MaxGlobalStreams) } -func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_state_stream.API) (*httptest.ResponseRecorder, error) { +func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_state_stream.API) (*HijackResponseRecorder, error) { router, err := newRouter(backend, stateStreamApi) if err != nil { return nil, err } - br := bufio.NewReaderSize(strings.NewReader(""), state_stream.DefaultSendBufferSize) - bw := bufio.NewWriterSize(&bytes.Buffer{}, state_stream.DefaultSendBufferSize) - resp := NewHijackResponseRecorder(bufio.NewReadWriter(br, bw)) - + resp := NewHijackResponseRecorder() router.ServeHTTP(resp, req) - return resp.ResponseRecorder, nil + return resp, nil } func assertOKResponse(t *testing.T, req *http.Request, expectedRespBody string, backend *mock.API, stateStreamApi *mock_state_stream.API) { diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 8be446ed59c..ffe46d532ad 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -30,11 +30,10 @@ const ( // SubscribeHandler holds the necessary components and parameters for handling a WebSocket subscription. // It manages the communication between the server and the WebSocket client for subscribing. type SubscribeHandler struct { - request *request.Request // the incoming HTTP request containing the subscription details. - respWriter http.ResponseWriter // the HTTP response writer to communicate back to the client. - conn *websocket.Conn // the established WebSocket connection for communication with the client. - api state_stream.API // the state_stream.API instance for managing event subscriptions. - maxStreams int32 // the maximum number of streams allowed. + request *request.Request // the incoming HTTP request containing the subscription details. + conn *websocket.Conn // the established WebSocket connection for communication with the client. + api state_stream.API // the state_stream.API instance for managing event subscriptions. + maxStreams int32 // the maximum number of streams allowed. } // SetReadWriteDeadline used to set read and write deadlines for WebSocket connections. These methods allow you to @@ -112,11 +111,8 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - decoratedRequest := request.Decorate(r, h.HttpHandler.Chain) - subscribeHandler := SubscribeHandler{ - request: decoratedRequest, - respWriter: w, + request: request.Decorate(r, h.HttpHandler.Chain), conn: conn, api: h.api, maxStreams: h.maxStreams, @@ -124,7 +120,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { err = subscribeHandler.SetReadWriteDeadline() if err != nil { - h.errorHandler(w, err, logger) + h.wsErrorHandler(logger, conn, err) conn.Close() } @@ -133,29 +129,47 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { subscribeHandler, h.eventFilterConfig, &h.streamCount, - h.sendError) + h.wsErrorHandler) } -func (h *WSHandler) sendError( +// wsErrorHandler handles WebSocket errors by sending an appropriate close message +// to the client WebSocket connection. +// +// If the error is an instance of models.StatusError, the function extracts the +// relevant information like status code and user message to construct the WebSocket +// close code and message. If the error is not a models.StatusError, a default +// internal server error close code and the error's message are used. +// The connection is then closed using WriteControl to send a CloseMessage with the +// constructed close code and message. Any errors that occur during the closing +// process are logged using the provided logger. +func (h *WSHandler) wsErrorHandler( logger zerolog.Logger, conn *websocket.Conn, err error) { // rest status type error should be returned with status and user message provided var statusErr models.StatusError - var errMsg models.ModelError + var wsCode int + var wsMsg string + if errors.As(err, &statusErr) { - errMsg = models.ModelError{ - Code: int32(statusErr.Status()), - Message: statusErr.UserMessage(), + if statusErr.Status() == http.StatusBadRequest { + wsCode = websocket.CloseUnsupportedData } - } else { - errMsg = models.ModelError{ - Code: http.StatusInternalServerError, - Message: "internal server error", + if statusErr.Status() == http.StatusServiceUnavailable { + wsCode = websocket.CloseTryAgainLater } + if statusErr.Status() == http.StatusRequestTimeout { + wsCode = websocket.CloseGoingAway + } + wsMsg = statusErr.UserMessage() + + } else { + wsCode = websocket.CloseInternalServerErr + wsMsg = err.Error() } - err = conn.WriteJSON(errMsg) + // Close the connection with the CloseError message + err = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(wsCode, wsMsg), time.Now().Add(time.Second)) if err != nil { logger.Error().Err(err).Msg(fmt.Sprintf("error sending WebSocket error: %v", err)) } From 6322ba914b72a06e0dc727f90f59b94496e4c6df Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 17 Aug 2023 13:22:47 +0300 Subject: [PATCH 22/35] Updated tests, linted --- .../rest/routes/subscribe_events_test.go | 96 ++++++++++++++----- .../tests/access/rest_state_stream_test.go | 1 + 2 files changed, 72 insertions(+), 25 deletions(-) diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index b1ff9375345..5cc09509675 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -3,13 +3,14 @@ package routes import ( "crypto/rand" "encoding/base64" + "encoding/json" "fmt" "net/http" "net/url" + "regexp" "strings" "testing" - - "golang.org/x/exp/slices" + "time" "github.com/stretchr/testify/assert" mocks "github.com/stretchr/testify/mock" @@ -112,33 +113,51 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { s.Run(test.name, func() { stateStreamBackend := &mockstatestream.API{} backend := &mock.API{} - subscription := &mockstatestream.Subscription{} - expectedEvents := flow.EventsList{} - for _, event := range s.blockEvents[s.blocks[0].ID()] { - if slices.Contains(test.eventTypes, string(event.Type)) { - expectedEvents = append(expectedEvents, event) - } - } + filter, err := state_stream.NewEventFilter(state_stream.DefaultEventFilterConfig, chain, test.eventTypes, test.addresses, test.contracts) + assert.NoError(s.T(), err) - // Create a channel to receive mock EventsResponse objects - ch := make(chan interface{}) - var chReadOnly <-chan interface{} var expectedEventsResponses []*state_stream.EventsResponse + startBlockFound := test.startBlockID == flow.ZeroID - for i, b := range s.blocks { - s.T().Logf("checking block %d %v", i, b.ID()) + // Helper function to check if a string is present in a slice + addExpectedEvent := func(slice []string, item string) bool { + if slice == nil { + return true // Include all events when test.eventTypes is nil + } + for _, s := range slice { + if s == item { + return true + } + } + return false + } - //simulate EventsResponse - eventResponse := &state_stream.EventsResponse{ - Height: b.Header.Height, - BlockID: b.ID(), - Events: expectedEvents, + // construct expected event responses based on the provided test configuration + for _, block := range s.blocks { + if startBlockFound || block.ID() == test.startBlockID { + startBlockFound = true + if test.startHeight == request.EmptyHeight || block.Header.Height >= test.startHeight { + eventsForBlock := flow.EventsList{} + for _, event := range s.blockEvents[block.ID()] { + if addExpectedEvent(test.eventTypes, string(event.Type)) { + eventsForBlock = append(eventsForBlock, event) + } + } + eventResponse := &state_stream.EventsResponse{ + Height: block.Header.Height, + BlockID: block.ID(), + Events: eventsForBlock, + } + expectedEventsResponses = append(expectedEventsResponses, eventResponse) + } } - expectedEventsResponses = append(expectedEventsResponses, eventResponse) } + // Create a channel to receive mock EventsResponse objects + ch := make(chan interface{}) + var chReadOnly <-chan interface{} // Simulate sending a mock EventsResponse go func() { for _, eventResponse := range expectedEventsResponses { @@ -148,13 +167,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { }() chReadOnly = ch - subscription.Mock.On("Channel").Return(chReadOnly) subscription.Mock.On("Err").Return(nil) - filter, err := state_stream.NewEventFilter(state_stream.DefaultEventFilterConfig, chain, test.eventTypes, test.addresses, test.contracts) - assert.NoError(s.T(), err) - var startHeight uint64 if test.startHeight == request.EmptyHeight { startHeight = uint64(0) @@ -165,8 +180,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { req, err := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts) assert.NoError(s.T(), err) - _, err = executeRequest(req, backend, stateStreamBackend) + respRecorder, err := executeRequest(req, backend, stateStreamBackend) assert.NoError(s.T(), err) + requireResponse(s.T(), respRecorder, expectedEventsResponses) }) } } @@ -313,3 +329,33 @@ func requireError(t *testing.T, recorder *HijackResponseRecorder, expected strin <-recorder.closed require.Contains(t, recorder.responseBuff.String(), expected) } + +func requireResponse(t *testing.T, recorder *HijackResponseRecorder, expected []*state_stream.EventsResponse) { + time.Sleep(1 * time.Second) + // Convert the actual response from respRecorder to JSON bytes + actualJSON := recorder.responseBuff.Bytes() + // Define a regular expression pattern to match JSON objects + pattern := `\{"BlockID":".*?","Height":\d+,"Events":\[\{.*?\}\]\}` + matches := regexp.MustCompile(pattern).FindAll(actualJSON, -1) + + // Unmarshal each matched JSON into []state_stream.EventsResponse + var actual []state_stream.EventsResponse + for _, match := range matches { + var response state_stream.EventsResponse + if err := json.Unmarshal(match, &response); err == nil { + actual = append(actual, response) + } + } + + // Compare the count of expected and actual responses + assert.Equal(t, len(expected), len(actual)) + + // Compare the BlockID and Events count for each response + for i := 0; i < len(expected); i++ { + expected := expected[i] + actual := actual[i] + + assert.Equal(t, expected.BlockID, actual.BlockID) + assert.Equal(t, len(expected.Events), len(actual.Events)) + } +} diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index 9660e84805f..712e3066a51 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -153,6 +153,7 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { // Event channel closed, events received s.T().Log(" Event channel closed, events received") client.Close() + require.Equal(s.T(), len(receivedEventsResponse) > 0, "expect some events ") return } receivedEventsResponse = append(receivedEventsResponse, eventResponse) From 3cd7b32ad84dba7e7dde16229d9c08ce05ff9d4a Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 18 Aug 2023 17:41:13 +0300 Subject: [PATCH 23/35] Upgraded state streaming impl --- engine/access/rest/routes/subscribe_events.go | 107 +++------- .../rest/routes/subscribe_events_test.go | 27 --- .../access/rest/routes/websocket_handler.go | 188 +++++++++++------- 3 files changed, 146 insertions(+), 176 deletions(-) diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 68c14cbba77..1a8ec7d0dfd 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -4,110 +4,63 @@ import ( "context" "fmt" "net/http" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" - "github.com/rs/zerolog" "github.com/onflow/flow-go/engine/access/rest/models" + "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/access/state_stream" ) // SubscribeEvents create websocket connection and write to it requested events. func SubscribeEvents( - logger zerolog.Logger, - h SubscribeHandler, - eventFilterConfig state_stream.EventFilterConfig, - streamCount *atomic.Int32, - errorHandler func(logger zerolog.Logger, conn *websocket.Conn, err error)) { - logger = logger.With().Str("subscribe events", h.request.URL.String()).Logger() - defer h.conn.Close() - - req, err := h.request.SubscribeEventsRequest() + request *request.Request, + wsCtx *WebsocketContext) { + req, err := request.SubscribeEventsRequest() if err != nil { - errorHandler(logger, h.conn, models.NewBadRequestError(err)) + wsCtx.wsErrorHandler(models.NewBadRequestError(err)) return } // Retrieve the filter parameters from the request, if provided filter, err := state_stream.NewEventFilter( - eventFilterConfig, - h.request.Chain, + wsCtx.eventFilterConfig, + request.Chain, req.EventTypes, req.Addresses, req.Contracts, ) if err != nil { err := fmt.Errorf("event filter error") - errorHandler(logger, h.conn, models.NewBadRequestError(err)) + wsCtx.wsErrorHandler(models.NewBadRequestError(err)) return } - if streamCount.Load() >= h.maxStreams { + if wsCtx.streamCount.Load() >= wsCtx.maxStreams { err := fmt.Errorf("maximum number of streams reached") - errorHandler(logger, h.conn, models.NewRestError(http.StatusServiceUnavailable, "maximum number of streams reached", err)) + wsCtx.wsErrorHandler(models.NewRestError(http.StatusServiceUnavailable, "maximum number of streams reached", err)) return } - streamCount.Add(1) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - sub := h.api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) - - // Write messages to the WebSocket connection - err = writeEvents( - sub, - h, - streamCount) - if err != nil { - errorHandler(logger, h.conn, err) - return - } -} + wsCtx.streamCount.Add(1) -// writeEvents use for writes events and pings to the WebSocket connection. It listens to a subscription's channel for -// events and writes them to the connection. If an error occurs or the subscription channel is closed, it handles the -// error or termination accordingly. -// The function uses a ticker to periodically send ping messages to the client to maintain the connection. -func writeEvents( - sub state_stream.Subscription, - h SubscribeHandler, - streamCount *atomic.Int32, -) error { - ticker := time.NewTicker(pingPeriod) + ctx := context.Background() + sub := wsCtx.api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) - defer func() { - ticker.Stop() - streamCount.Add(-1) - }() - - for { - select { - case v, ok := <-sub.Channel(): - if !ok { - if sub.Err() != nil { - err := fmt.Errorf("stream encountered an error: %v", sub.Err()) - return models.NewBadRequestError(err) + go func() { + for { + select { + case <-ctx.Done(): + return + case event, ok := <-sub.Channel(): + if !ok { + if sub.Err() != nil { + err := fmt.Errorf("stream encountered an error: %v", sub.Err()) + wsCtx.wsErrorHandler(models.NewBadRequestError(err)) + return + } + err := fmt.Errorf("subscription channel closed, no error occurred") + wsCtx.wsErrorHandler(models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err)) + return } - err := fmt.Errorf("subscription channel closed, no error occurred") - return models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err) - } - - resp, ok := v.(*state_stream.EventsResponse) - if !ok { - return fmt.Errorf("unexpected response type: %T", v) - } - - // Write the response to the WebSocket connection - err := h.conn.WriteJSON(resp) - if err != nil { - return err - } - case <-ticker.C: - if err := h.conn.WriteMessage(websocket.PingMessage, nil); err != nil { - return err + wsCtx.send <- event } } - } + }() } diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index 5cc09509675..2e33cbda6fc 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -247,33 +247,6 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "subscription channel closed") }) - - s.Run("returns error for unexpected response type", func() { - stateStreamBackend := &mockstatestream.API{} - backend := &mock.API{} - subscription := &mockstatestream.Subscription{} - - ch := make(chan interface{}) - var chReadOnly <-chan interface{} - go func() { - executionDataResponse := &state_stream.ExecutionDataResponse{ - Height: s.blocks[0].Header.Height, - } - ch <- executionDataResponse - }() - chReadOnly = ch - - subscription.Mock.On("Channel").Return(chReadOnly) - subscription.Mock.On("Err").Return(nil) - stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), uint64(0), mocks.Anything).Return(subscription) - - req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) - assert.NoError(s.T(), err) - - respRecorder, err := executeRequest(req, backend, stateStreamBackend) - assert.NoError(s.T(), err) - requireError(s.T(), respRecorder, "unexpected response type: *state_stream.ExecutionDataResponse") - }) } func getSubscribeEventsRequest(t *testing.T, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) (*http.Request, error) { diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index ffe46d532ad..b8c16aac429 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "sync" "sync/atomic" "time" @@ -18,7 +19,7 @@ import ( const ( // Time allowed to read the next pong message from the peer. - pongWait = 60 * time.Second + pongWait = 10 * time.Second // Send pings to peer with this period. Must be less than pongWait. pingPeriod = (pongWait * 9) / 10 @@ -27,37 +28,87 @@ const ( writeWait = 10 * time.Second ) -// SubscribeHandler holds the necessary components and parameters for handling a WebSocket subscription. +// WebsocketContext holds the necessary components and parameters for handling a WebSocket subscription. // It manages the communication between the server and the WebSocket client for subscribing. -type SubscribeHandler struct { - request *request.Request // the incoming HTTP request containing the subscription details. - conn *websocket.Conn // the established WebSocket connection for communication with the client. - api state_stream.API // the state_stream.API instance for managing event subscriptions. - maxStreams int32 // the maximum number of streams allowed. +type WebsocketContext struct { + logger zerolog.Logger + conn *websocket.Conn // the WebSocket connection for communication with the client + api state_stream.API // the state_stream.API instance for managing event subscriptions + eventFilterConfig state_stream.EventFilterConfig // the configuration for filtering events + maxStreams int32 // the maximum number of streams allowed + streamCount *atomic.Int32 // the current number of active streams + send chan interface{} // channel for sending messages to the client + + wg sync.WaitGroup } -// SetReadWriteDeadline used to set read and write deadlines for WebSocket connections. These methods allow you to -// specify a time limit for reading from or writing to a WebSocket connection. If the operation (reading or writing) -// takes longer than the specified deadline, the connection will be closed. -func (h *SubscribeHandler) SetReadWriteDeadline() error { - err := h.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message +// SetWebsocketConf used to set read and write deadlines for WebSocket connections and establishes a Pong handler to +// manage incoming Pong messages. These methods allow to specify a time limit for reading from or writing to a WebSocket +// connection. If the operation (reading or writing) takes longer than the specified deadline, the connection will be closed. +func (ctx *WebsocketContext) SetWebsocketConf() error { + err := ctx.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message if err != nil { return models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err) } - err = h.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message + err = ctx.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message if err != nil { return models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err) } + // Establish a Pong handler + ctx.conn.SetPongHandler(func(string) error { + err := ctx.conn.SetReadDeadline(time.Now().Add(pongWait)) + if err != nil { + return err + } + return nil + }) return nil } +// wsErrorHandler handles WebSocket errors by sending an appropriate close message +// to the client WebSocket connection. +// +// If the error is an instance of models.StatusError, the function extracts the +// relevant information like status code and user message to construct the WebSocket +// close code and message. If the error is not a models.StatusError, a default +// internal server error close code and the error's message are used. +// The connection is then closed using WriteControl to send a CloseMessage with the +// constructed close code and message. Any errors that occur during the closing +// process are logged using the provided logger. +func (ctx *WebsocketContext) wsErrorHandler(err error) { + // rest status type error should be returned with status and user message provided + var statusErr models.StatusError + var wsCode int + var wsMsg string + + if errors.As(err, &statusErr) { + if statusErr.Status() == http.StatusBadRequest { + wsCode = websocket.CloseUnsupportedData + } + if statusErr.Status() == http.StatusServiceUnavailable { + wsCode = websocket.CloseTryAgainLater + } + if statusErr.Status() == http.StatusRequestTimeout { + wsCode = websocket.CloseGoingAway + } + wsMsg = statusErr.UserMessage() + + } else { + wsCode = websocket.CloseInternalServerErr + wsMsg = err.Error() + } + + // Close the connection with the CloseError message + err = ctx.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(wsCode, wsMsg), time.Now().Add(time.Second)) + if err != nil { + ctx.logger.Error().Err(err).Msg(fmt.Sprintf("error sending WebSocket error: %v", err)) + } +} + // SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, fetches necessary resources type SubscribeHandlerFunc func( - logger zerolog.Logger, - subscribeHandler SubscribeHandler, - eventFilterConfig state_stream.EventFilterConfig, - streamCount *atomic.Int32, - errorHandler func(logger zerolog.Logger, conn *websocket.Conn, err error), + request *request.Request, + wsCtx *WebsocketContext, ) // WSHandler is websocket handler implementing custom websocket handler function and allows easier handling of errors and @@ -69,7 +120,7 @@ type WSHandler struct { api state_stream.API eventFilterConfig state_stream.EventFilterConfig maxStreams int32 - streamCount atomic.Int32 + streamCount *atomic.Int32 } func NewWSHandler( @@ -85,7 +136,7 @@ func NewWSHandler( api: api, eventFilterConfig: eventFilterConfig, maxStreams: int32(maxGlobalStreams), - streamCount: atomic.Int32{}, + streamCount: &atomic.Int32{}, HttpHandler: NewHttpHandler(logger, chain), } @@ -97,6 +148,7 @@ func NewWSHandler( func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // create a logger logger := h.Logger.With().Str("request_url", r.URL.String()).Logger() + //logger := wsCtx.logger.With().Str("subscribe events", request.URL.String()).Logger() err := h.VerifyRequest(w, r) if err != nil { @@ -111,66 +163,58 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - subscribeHandler := SubscribeHandler{ - request: request.Decorate(r, h.HttpHandler.Chain), - conn: conn, - api: h.api, - maxStreams: h.maxStreams, + wsCtx := &WebsocketContext{ + logger: logger, + conn: conn, + api: h.api, + eventFilterConfig: h.eventFilterConfig, + maxStreams: h.maxStreams, + streamCount: h.streamCount, + send: make(chan interface{}), + wg: sync.WaitGroup{}, } - err = subscribeHandler.SetReadWriteDeadline() + err = wsCtx.SetWebsocketConf() if err != nil { - h.wsErrorHandler(logger, conn, err) + wsCtx.wsErrorHandler(err) conn.Close() } - go h.subscribeFunc( - logger, - subscribeHandler, - h.eventFilterConfig, - &h.streamCount, - h.wsErrorHandler) -} + wsCtx.wg.Add(1) + go wsCtx.writeEvents() + wsCtx.wg.Wait() -// wsErrorHandler handles WebSocket errors by sending an appropriate close message -// to the client WebSocket connection. -// -// If the error is an instance of models.StatusError, the function extracts the -// relevant information like status code and user message to construct the WebSocket -// close code and message. If the error is not a models.StatusError, a default -// internal server error close code and the error's message are used. -// The connection is then closed using WriteControl to send a CloseMessage with the -// constructed close code and message. Any errors that occur during the closing -// process are logged using the provided logger. -func (h *WSHandler) wsErrorHandler( - logger zerolog.Logger, - conn *websocket.Conn, - err error) { - // rest status type error should be returned with status and user message provided - var statusErr models.StatusError - var wsCode int - var wsMsg string + h.subscribeFunc(request.Decorate(r, h.HttpHandler.Chain), wsCtx) +} - if errors.As(err, &statusErr) { - if statusErr.Status() == http.StatusBadRequest { - wsCode = websocket.CloseUnsupportedData - } - if statusErr.Status() == http.StatusServiceUnavailable { - wsCode = websocket.CloseTryAgainLater - } - if statusErr.Status() == http.StatusRequestTimeout { - wsCode = websocket.CloseGoingAway +// writeEvents use for writes events and pings to the WebSocket connection. It listens to a subscription's channel for +// events and writes them to the connection. If an error occurs or the subscription channel is closed, it handles the +// error or termination accordingly. +// The function uses a ticker to periodically send ping messages to the client to maintain the connection. +func (ctx *WebsocketContext) writeEvents() { + ticker := time.NewTicker(pingPeriod) + + defer func() { + ticker.Stop() + ctx.streamCount.Add(-1) + ctx.conn.Close() + }() + + ctx.wg.Done() + for { + select { + case v := <-ctx.send: + // Write the response to the WebSocket connection + err := ctx.conn.WriteJSON(v) + if err != nil { + ctx.wsErrorHandler(err) + return + } + case <-ticker.C: + if err := ctx.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + ctx.wsErrorHandler(err) + return + } } - wsMsg = statusErr.UserMessage() - - } else { - wsCode = websocket.CloseInternalServerErr - wsMsg = err.Error() - } - - // Close the connection with the CloseError message - err = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(wsCode, wsMsg), time.Now().Add(time.Second)) - if err != nil { - logger.Error().Err(err).Msg(fmt.Sprintf("error sending WebSocket error: %v", err)) } } From 83f980580b0ee936badeffc253585906592ca2a0 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 18 Aug 2023 17:44:12 +0300 Subject: [PATCH 24/35] Remove unnecessary comment --- engine/access/rest/routes/websocket_handler.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index b8c16aac429..9aa2e5f6818 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -147,8 +147,7 @@ func NewWSHandler( // such as logging, error handling, request decorators func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // create a logger - logger := h.Logger.With().Str("request_url", r.URL.String()).Logger() - //logger := wsCtx.logger.With().Str("subscribe events", request.URL.String()).Logger() + logger := h.Logger.With().Str("subscribe_url", r.URL.String()).Logger() err := h.VerifyRequest(w, r) if err != nil { From 75c757bb7400a8373042008ad1f208c820cf24ed Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Tue, 22 Aug 2023 12:21:08 +0300 Subject: [PATCH 25/35] Removed unnecessary check for event types --- engine/access/rest/request/event_type.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/engine/access/rest/request/event_type.go b/engine/access/rest/request/event_type.go index b1c99978a86..b57c7fa6aff 100644 --- a/engine/access/rest/request/event_type.go +++ b/engine/access/rest/request/event_type.go @@ -25,10 +25,6 @@ func (e EventType) Flow() string { type EventTypes []EventType func (e *EventTypes) Parse(raw []string) error { - if len(raw) > MaxIDsLength { - return fmt.Errorf("at most %d event types can be requested at a time", MaxIDsLength) - } - // make a map to have only unique values as keys eventTypes := make(EventTypes, 0) uniqueTypes := make(map[string]bool) From f387b4f24cdbfb7e25290011c9beed6bf427481e Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 23 Aug 2023 15:22:42 +0300 Subject: [PATCH 26/35] Linted integration test --- integration/tests/access/rest_state_stream_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index 712e3066a51..8ab7c0ea20a 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -188,7 +188,7 @@ func (s *RestStateStreamSuite) requireEvents(receivedEventsResponse []*state_str // get events by block id and event type response, err := MakeApiRequest(grpcClient.GetEventsForBlockIDs, grpcCtx, &accessproto.GetEventsForBlockIDsRequest{BlockIds: [][]byte{convert.IdentifierToMessage(receivedEventResponse.BlockID)}, - Type: fmt.Sprintf("%s", eventType)}) + Type: string(eventType)}) require.NoError(s.T(), err) require.Equal(s.T(), 1, len(response.Results), "expect to get 1 result") From df964394b6298a352d0fe05ffc4b92402a2107ae Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Mon, 4 Sep 2023 13:43:35 +0300 Subject: [PATCH 27/35] Refactored according to comments --- .../node_builder/access_node_builder.go | 8 +- engine/access/rest/request/event_type.go | 11 +- engine/access/rest/routes/subscribe_events.go | 45 +----- .../access/rest/routes/websocket_handler.go | 146 ++++++++++++------ engine/access/state_stream/backend.go | 6 +- 5 files changed, 117 insertions(+), 99 deletions(-) diff --git a/cmd/access/node_builder/access_node_builder.go b/cmd/access/node_builder/access_node_builder.go index 78745cb8a4a..70b148f0185 100644 --- a/cmd/access/node_builder/access_node_builder.go +++ b/cmd/access/node_builder/access_node_builder.go @@ -118,7 +118,6 @@ type AccessNodeConfig struct { apiRatelimits map[string]int apiBurstlimits map[string]int rpcConf rpc.Config - stateStreamBackend *state_stream.StateStreamBackend stateStreamConf state_stream.Config stateStreamFilterConf map[string]int ExecutionNodeAddress string // deprecated @@ -183,7 +182,6 @@ func DefaultAccessNodeConfig() *AccessNodeConfig { EventFilterConfig: state_stream.DefaultEventFilterConfig, ResponseLimit: state_stream.DefaultResponseLimit, }, - stateStreamBackend: nil, stateStreamFilterConf: nil, ExecutionNodeAddress: "localhost:9000", logTxTimeToFinalized: false, @@ -260,6 +258,8 @@ type FlowAccessNodeBuilder struct { secureGrpcServer *grpcserver.GrpcServer unsecureGrpcServer *grpcserver.GrpcServer stateStreamGrpcServer *grpcserver.GrpcServer + + stateStreamBackend *state_stream.StateStreamBackend } func (builder *FlowAccessNodeBuilder) buildFollowerState() *FlowAccessNodeBuilder { @@ -439,7 +439,7 @@ func (builder *FlowAccessNodeBuilder) BuildConsensusFollower() *FlowAccessNodeBu return builder } -func (builder *FlowAccessNodeBuilder) BuildStateStreamPipeline() *FlowAccessNodeBuilder { +func (builder *FlowAccessNodeBuilder) BuildExecutionSyncComponents() *FlowAccessNodeBuilder { var ds *badger.Datastore var bs network.BlobService var processedBlockHeight storage.ConsumerProgress @@ -940,7 +940,7 @@ func (builder *FlowAccessNodeBuilder) enqueueRelayNetwork() { func (builder *FlowAccessNodeBuilder) Build() (cmd.Node, error) { if builder.executionDataSyncEnabled { - builder.BuildStateStreamPipeline() + builder.BuildExecutionSyncComponents() } builder. diff --git a/engine/access/rest/request/event_type.go b/engine/access/rest/request/event_type.go index b57c7fa6aff..31d1dad784d 100644 --- a/engine/access/rest/request/event_type.go +++ b/engine/access/rest/request/event_type.go @@ -7,11 +7,11 @@ import ( type EventType string +var basicEventRe = regexp.MustCompile(`[A-Z]\.[a-f0-9]{16}\.[\w+]*\.[\w+]*`) +var flowEventRe = regexp.MustCompile(`flow\.[\w]*`) + func (e *EventType) Parse(raw string) error { - basic, _ := regexp.MatchString(`[A-Z]\.[a-f0-9]{16}\.[\w+]*\.[\w+]*`, raw) - // match core events flow.event - core, _ := regexp.MatchString(`flow\.[\w]*`, raw) - if !core && !basic { + if !basicEventRe.MatchString(raw) && !flowEventRe.MatchString(raw) { return fmt.Errorf("invalid event type format") } *e = EventType(raw) @@ -28,10 +28,11 @@ func (e *EventTypes) Parse(raw []string) error { // make a map to have only unique values as keys eventTypes := make(EventTypes, 0) uniqueTypes := make(map[string]bool) - for _, r := range raw { + for i, r := range raw { var eType EventType err := eType.Parse(r) if err != nil { + err = fmt.Errorf("%v at index %d ", i, err) return err } diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 1a8ec7d0dfd..7ad21355911 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -3,7 +3,6 @@ package routes import ( "context" "fmt" - "net/http" "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/request" @@ -13,54 +12,24 @@ import ( // SubscribeEvents create websocket connection and write to it requested events. func SubscribeEvents( request *request.Request, - wsCtx *WebsocketContext) { + ctx context.Context, + wsController *WebsocketController) (state_stream.Subscription, error) { req, err := request.SubscribeEventsRequest() if err != nil { - wsCtx.wsErrorHandler(models.NewBadRequestError(err)) - return + return nil, models.NewBadRequestError(err) } // Retrieve the filter parameters from the request, if provided filter, err := state_stream.NewEventFilter( - wsCtx.eventFilterConfig, + wsController.eventFilterConfig, request.Chain, req.EventTypes, req.Addresses, req.Contracts, ) if err != nil { - err := fmt.Errorf("event filter error") - wsCtx.wsErrorHandler(models.NewBadRequestError(err)) - return + err := fmt.Errorf("invalid event filter") + return nil, models.NewBadRequestError(err) } - if wsCtx.streamCount.Load() >= wsCtx.maxStreams { - err := fmt.Errorf("maximum number of streams reached") - wsCtx.wsErrorHandler(models.NewRestError(http.StatusServiceUnavailable, "maximum number of streams reached", err)) - return - } - wsCtx.streamCount.Add(1) - - ctx := context.Background() - sub := wsCtx.api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter) - - go func() { - for { - select { - case <-ctx.Done(): - return - case event, ok := <-sub.Channel(): - if !ok { - if sub.Err() != nil { - err := fmt.Errorf("stream encountered an error: %v", sub.Err()) - wsCtx.wsErrorHandler(models.NewBadRequestError(err)) - return - } - err := fmt.Errorf("subscription channel closed, no error occurred") - wsCtx.wsErrorHandler(models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err)) - return - } - wsCtx.send <- event - } - } - }() + return wsController.api.SubscribeEvents(ctx, req.StartBlockID, req.StartHeight, filter), nil } diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 9aa2e5f6818..d8b879a96d7 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -1,10 +1,10 @@ package routes import ( + "context" "errors" "fmt" "net/http" - "sync" "sync/atomic" "time" @@ -28,9 +28,9 @@ const ( writeWait = 10 * time.Second ) -// WebsocketContext holds the necessary components and parameters for handling a WebSocket subscription. +// WebsocketController holds the necessary components and parameters for handling a WebSocket subscription. // It manages the communication between the server and the WebSocket client for subscribing. -type WebsocketContext struct { +type WebsocketController struct { logger zerolog.Logger conn *websocket.Conn // the WebSocket connection for communication with the client api state_stream.API // the state_stream.API instance for managing event subscriptions @@ -38,25 +38,23 @@ type WebsocketContext struct { maxStreams int32 // the maximum number of streams allowed streamCount *atomic.Int32 // the current number of active streams send chan interface{} // channel for sending messages to the client - - wg sync.WaitGroup } // SetWebsocketConf used to set read and write deadlines for WebSocket connections and establishes a Pong handler to // manage incoming Pong messages. These methods allow to specify a time limit for reading from or writing to a WebSocket // connection. If the operation (reading or writing) takes longer than the specified deadline, the connection will be closed. -func (ctx *WebsocketContext) SetWebsocketConf() error { - err := ctx.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message +func (wsController *WebsocketController) SetWebsocketConf() error { + err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) // Set the initial write deadline for the first ping message if err != nil { return models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err) } - err = ctx.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message + err = wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) // Set the initial read deadline for the first pong message if err != nil { return models.NewRestError(http.StatusInternalServerError, "Set the initial read deadline error: ", err) } // Establish a Pong handler - ctx.conn.SetPongHandler(func(string) error { - err := ctx.conn.SetReadDeadline(time.Now().Add(pongWait)) + wsController.conn.SetPongHandler(func(string) error { + err := wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) if err != nil { return err } @@ -75,7 +73,7 @@ func (ctx *WebsocketContext) SetWebsocketConf() error { // The connection is then closed using WriteControl to send a CloseMessage with the // constructed close code and message. Any errors that occur during the closing // process are logged using the provided logger. -func (ctx *WebsocketContext) wsErrorHandler(err error) { +func (wsController *WebsocketController) wsErrorHandler(err error) { // rest status type error should be returned with status and user message provided var statusErr models.StatusError var wsCode int @@ -99,17 +97,57 @@ func (ctx *WebsocketContext) wsErrorHandler(err error) { } // Close the connection with the CloseError message - err = ctx.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(wsCode, wsMsg), time.Now().Add(time.Second)) + err = wsController.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(wsCode, wsMsg), time.Now().Add(time.Second)) if err != nil { - ctx.logger.Error().Err(err).Msg(fmt.Sprintf("error sending WebSocket error: %v", err)) + wsController.logger.Error().Err(err).Msg(fmt.Sprintf("error sending WebSocket error: %v", err)) + } +} + +// writeEvents use for writes events and pings to the WebSocket connection. It listens to a subscription's channel for +// events and writes them to the connection. If an error occurs or the subscription channel is closed, it handles the +// error or termination accordingly. +// The function uses a ticker to periodically send ping messages to the client to maintain the connection. +func (wsController *WebsocketController) writeEvents() { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + + for { + select { + case v, ok := <-wsController.send: + err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err != nil { + wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err)) + return + } + if !ok { + return + } + // Write the response to the WebSocket connection + err = wsController.conn.WriteJSON(v) + if err != nil { + wsController.wsErrorHandler(err) + return + } + case <-ticker.C: + err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err != nil { + wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err)) + return + } + if err := wsController.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + wsController.wsErrorHandler(err) + return + } + } } } // SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, fetches necessary resources type SubscribeHandlerFunc func( request *request.Request, - wsCtx *WebsocketContext, -) + ctx context.Context, + wsController *WebsocketController, +) (state_stream.Subscription, error) // WSHandler is websocket handler implementing custom websocket handler function and allows easier handling of errors and // responses as it wraps functionality for handling error and responses outside of endpoint handling. @@ -162,7 +200,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - wsCtx := &WebsocketContext{ + wsController := &WebsocketController{ logger: logger, conn: conn, api: h.api, @@ -170,50 +208,58 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { maxStreams: h.maxStreams, streamCount: h.streamCount, send: make(chan interface{}), - wg: sync.WaitGroup{}, } - err = wsCtx.SetWebsocketConf() + err = wsController.SetWebsocketConf() if err != nil { - wsCtx.wsErrorHandler(err) + wsController.wsErrorHandler(err) conn.Close() } - wsCtx.wg.Add(1) - go wsCtx.writeEvents() - wsCtx.wg.Wait() + if wsController.streamCount.Load() >= wsController.maxStreams { + err := fmt.Errorf("maximum number of streams reached") + wsController.wsErrorHandler(models.NewRestError(http.StatusServiceUnavailable, err.Error(), err)) + } + wsController.streamCount.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + wsController.streamCount.Add(-1) + wsController.conn.Close() + cancel() + }() + + sub, err := h.subscribeFunc(request.Decorate(r, h.HttpHandler.Chain), ctx, wsController) + if err != nil { + wsController.wsErrorHandler(err) + return + } - h.subscribeFunc(request.Decorate(r, h.HttpHandler.Chain), wsCtx) + go h.handleSubscription(sub, wsController) + wsController.writeEvents() } -// writeEvents use for writes events and pings to the WebSocket connection. It listens to a subscription's channel for -// events and writes them to the connection. If an error occurs or the subscription channel is closed, it handles the -// error or termination accordingly. -// The function uses a ticker to periodically send ping messages to the client to maintain the connection. -func (ctx *WebsocketContext) writeEvents() { - ticker := time.NewTicker(pingPeriod) +// handleSubscription is responsible for managing event subscriptions and sending events to the WebSocket connection. +// It continuously listens to the provided `state_stream.Subscription` channel, processes incoming events, +// and sends them to the WebSocket client via the `WebsocketController`. +// +// This function runs as a goroutine and handles various scenarios, including the receipt of events, +// errors in the subscription, and closure of the subscription channel. In case of an error or channel closure, +// appropriate error handling and cleanup are performed. +func (h *WSHandler) handleSubscription(sub state_stream.Subscription, wsController *WebsocketController) { + defer close(wsController.send) - defer func() { - ticker.Stop() - ctx.streamCount.Add(-1) - ctx.conn.Close() - }() + for event := range sub.Channel() { + wsController.send <- event + } - ctx.wg.Done() - for { - select { - case v := <-ctx.send: - // Write the response to the WebSocket connection - err := ctx.conn.WriteJSON(v) - if err != nil { - ctx.wsErrorHandler(err) - return - } - case <-ticker.C: - if err := ctx.conn.WriteMessage(websocket.PingMessage, nil); err != nil { - ctx.wsErrorHandler(err) - return - } - } + // The loop will keep running until there's a channel closure. + // Handle any potential errors or channel closure here + if sub.Err() != nil { + err := fmt.Errorf("stream encountered an error: %v", sub.Err()) + wsController.wsErrorHandler(models.NewBadRequestError(err)) + } else { + err := fmt.Errorf("subscription channel closed, no error occurred") + wsController.wsErrorHandler(models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err)) } } diff --git a/engine/access/state_stream/backend.go b/engine/access/state_stream/backend.go index 371427ea680..035fe45bad1 100644 --- a/engine/access/state_stream/backend.go +++ b/engine/access/state_stream/backend.go @@ -188,8 +188,10 @@ func (b *StateStreamBackend) getStartHeight(startBlockID flow.Identifier, startH // begin from the next block. // Note: we can skip the block lookup since it was already done in the constructor if startBlockID == b.rootBlockID || - // Note: when startBlockID is provided and startHeight no needed then startHeight should be 0, otherwise, an - // InvalidArgument error is returned above, so we need also check if startBlockID provided before skip it. + // Note: there is a corner case when rootBlockHeight == 0: + // since the default value of an uint64 is 0, when checking if startHeight matches the root block + // we also need to check that startBlockID is unset, otherwise we may incorrectly set the start height + // for non-matching startBlockIDs. (startHeight == b.rootBlockHeight && startBlockID == flow.ZeroID) { return b.rootBlockHeight + 1, nil } From 212d98fb23b8eebf376827e2e7f512b47eec9703 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Tue, 5 Sep 2023 15:09:31 +0300 Subject: [PATCH 28/35] Added checking connection for closing from client side --- engine/access/rest/routes/test_helpers.go | 5 ++ .../access/rest/routes/websocket_handler.go | 71 ++++++++++--------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index f198ee1cad1..e3bc67e07bb 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -122,7 +122,12 @@ func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_s } resp := NewHijackResponseRecorder() + go func() { + time.Sleep(5 * time.Second) + //close(resp.closed) + }() router.ServeHTTP(resp, req) + //<-resp.closed return resp, nil } diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index d8b879a96d7..34ed559f6d5 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -37,7 +37,7 @@ type WebsocketController struct { eventFilterConfig state_stream.EventFilterConfig // the configuration for filtering events maxStreams int32 // the maximum number of streams allowed streamCount *atomic.Int32 // the current number of active streams - send chan interface{} // channel for sending messages to the client + read chan interface{} // channel for read close message from the client } // SetWebsocketConf used to set read and write deadlines for WebSocket connections and establishes a Pong handler to @@ -60,6 +60,21 @@ func (wsController *WebsocketController) SetWebsocketConf() error { } return nil }) + + // Start a goroutine to handle the WebSocket connection + go func() { + defer close(wsController.read) // notify websocket about closed connection + + wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) + wsController.conn.SetPongHandler(func(string) error { wsController.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + + for { + _, _, err := wsController.conn.ReadMessage() + if err != nil { + return + } + } + }() return nil } @@ -106,24 +121,36 @@ func (wsController *WebsocketController) wsErrorHandler(err error) { // writeEvents use for writes events and pings to the WebSocket connection. It listens to a subscription's channel for // events and writes them to the connection. If an error occurs or the subscription channel is closed, it handles the // error or termination accordingly. +// TODO: comment for added part // The function uses a ticker to periodically send ping messages to the client to maintain the connection. -func (wsController *WebsocketController) writeEvents() { +func (wsController *WebsocketController) writeEvents(sub state_stream.Subscription) { ticker := time.NewTicker(pingPeriod) defer ticker.Stop() for { select { - case v, ok := <-wsController.send: - err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err != nil { - wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err)) + case _, ok := <-wsController.read: + if !ok { return } + case event, ok := <-sub.Channel(): if !ok { + if sub.Err() != nil { + err := fmt.Errorf("stream encountered an error: %v", sub.Err()) + wsController.wsErrorHandler(models.NewBadRequestError(err)) + return + } + err := fmt.Errorf("subscription channel closed, no error occurred") + wsController.wsErrorHandler(models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err)) + return + } + err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err != nil { + wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err)) return } // Write the response to the WebSocket connection - err = wsController.conn.WriteJSON(v) + err = wsController.conn.WriteJSON(event) if err != nil { wsController.wsErrorHandler(err) return @@ -207,7 +234,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { eventFilterConfig: h.eventFilterConfig, maxStreams: h.maxStreams, streamCount: h.streamCount, - send: make(chan interface{}), + read: make(chan interface{}), } err = wsController.SetWebsocketConf() @@ -235,31 +262,5 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - go h.handleSubscription(sub, wsController) - wsController.writeEvents() -} - -// handleSubscription is responsible for managing event subscriptions and sending events to the WebSocket connection. -// It continuously listens to the provided `state_stream.Subscription` channel, processes incoming events, -// and sends them to the WebSocket client via the `WebsocketController`. -// -// This function runs as a goroutine and handles various scenarios, including the receipt of events, -// errors in the subscription, and closure of the subscription channel. In case of an error or channel closure, -// appropriate error handling and cleanup are performed. -func (h *WSHandler) handleSubscription(sub state_stream.Subscription, wsController *WebsocketController) { - defer close(wsController.send) - - for event := range sub.Channel() { - wsController.send <- event - } - - // The loop will keep running until there's a channel closure. - // Handle any potential errors or channel closure here - if sub.Err() != nil { - err := fmt.Errorf("stream encountered an error: %v", sub.Err()) - wsController.wsErrorHandler(models.NewBadRequestError(err)) - } else { - err := fmt.Errorf("subscription channel closed, no error occurred") - wsController.wsErrorHandler(models.NewRestError(http.StatusRequestTimeout, "subscription channel closed", err)) - } + wsController.writeEvents(sub) } From a4257913670aab40a34063089c3b90a6dd217b55 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Tue, 5 Sep 2023 17:34:06 +0300 Subject: [PATCH 29/35] Fixed unit test for subscribe events --- .../access/rest/routes/account_keys_test.go | 3 +- engine/access/rest/routes/accounts_test.go | 3 +- engine/access/rest/routes/blocks_test.go | 3 +- engine/access/rest/routes/collections_test.go | 3 +- .../rest/routes/subscribe_events_test.go | 19 ++++++-- engine/access/rest/routes/test_helpers.go | 37 ++++++++------- .../access/rest/routes/websocket_handler.go | 47 ++++++++++--------- 7 files changed, 66 insertions(+), 49 deletions(-) diff --git a/engine/access/rest/routes/account_keys_test.go b/engine/access/rest/routes/account_keys_test.go index d6a1ca25077..692691fa6ca 100644 --- a/engine/access/rest/routes/account_keys_test.go +++ b/engine/access/rest/routes/account_keys_test.go @@ -261,7 +261,8 @@ func TestGetAccountKeyByIndex(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { req, _ := http.NewRequest("GET", test.url, nil) - rr, err := executeRequest(req, backend, nil) + rr := NewHijackResponseRecorder() + err := executeRequest(req, backend, nil, rr) assert.NoError(t, err) assert.Equal(t, http.StatusBadRequest, rr.Code) diff --git a/engine/access/rest/routes/accounts_test.go b/engine/access/rest/routes/accounts_test.go index 0267a991d4a..482e9f6376a 100644 --- a/engine/access/rest/routes/accounts_test.go +++ b/engine/access/rest/routes/accounts_test.go @@ -126,7 +126,8 @@ func TestAccessGetAccount(t *testing.T) { for i, test := range tests { req, _ := http.NewRequest("GET", test.url, nil) - rr, err := executeRequest(req, backend, nil) + rr := NewHijackResponseRecorder() + err := executeRequest(req, backend, nil, rr) assert.NoError(t, err) assert.Equal(t, http.StatusBadRequest, rr.Code) diff --git a/engine/access/rest/routes/blocks_test.go b/engine/access/rest/routes/blocks_test.go index e60837179dc..5ef3e851ce2 100644 --- a/engine/access/rest/routes/blocks_test.go +++ b/engine/access/rest/routes/blocks_test.go @@ -149,7 +149,8 @@ func TestAccessGetBlocks(t *testing.T) { testVectors := prepareTestVectors(t, blockIDs, heights, blocks, executionResults, blkCnt) for _, tv := range testVectors { - responseRec, err := executeRequest(tv.request, backend, nil) + responseRec := NewHijackResponseRecorder() + err := executeRequest(tv.request, backend, nil, responseRec) assert.NoError(t, err) require.Equal(t, tv.expectedStatus, responseRec.Code, "failed test %s: incorrect response code", tv.description) actualResp := responseRec.Body.String() diff --git a/engine/access/rest/routes/collections_test.go b/engine/access/rest/routes/collections_test.go index dec00e67ef3..3ff17198386 100644 --- a/engine/access/rest/routes/collections_test.go +++ b/engine/access/rest/routes/collections_test.go @@ -87,7 +87,8 @@ func TestGetCollections(t *testing.T) { Once() req := getCollectionReq(col.ID().String(), true) - rr, err := executeRequest(req, backend, nil) + rr := NewHijackResponseRecorder() + err := executeRequest(req, backend, nil, rr) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rr.Code) diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index 2e33cbda6fc..e1e5932999a 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -180,7 +180,13 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { req, err := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts) assert.NoError(s.T(), err) - respRecorder, err := executeRequest(req, backend, stateStreamBackend) + respRecorder := NewHijackResponseRecorder() + // closing the connection after 5 seconds + go func() { + time.Sleep(5 * time.Second) + close(respRecorder.closed) + }() + err = executeRequest(req, backend, stateStreamBackend, respRecorder) assert.NoError(s.T(), err) requireResponse(s.T(), respRecorder, expectedEventsResponses) }) @@ -194,7 +200,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil) assert.NoError(s.T(), err) - respRecorder, err := executeRequest(req, backend, stateStreamBackend) + respRecorder := NewHijackResponseRecorder() + err = executeRequest(req, backend, stateStreamBackend, respRecorder) assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "can only provide either block ID or start height") }) @@ -219,7 +226,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) - respRecorder, err := executeRequest(req, backend, stateStreamBackend) + respRecorder := NewHijackResponseRecorder() + err = executeRequest(req, backend, stateStreamBackend, respRecorder) assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "stream encountered an error: subscription error") }) @@ -243,7 +251,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) - respRecorder, err := executeRequest(req, backend, stateStreamBackend) + respRecorder := NewHijackResponseRecorder() + err = executeRequest(req, backend, stateStreamBackend, respRecorder) assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "subscription channel closed") }) @@ -304,7 +313,7 @@ func requireError(t *testing.T, recorder *HijackResponseRecorder, expected strin } func requireResponse(t *testing.T, recorder *HijackResponseRecorder, expected []*state_stream.EventsResponse) { - time.Sleep(1 * time.Second) + <-recorder.closed // Convert the actual response from respRecorder to JSON bytes actualJSON := recorder.responseBuff.Bytes() // Define a regular expression pattern to match JSON objects diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index e3bc67e07bb..4526aba9a25 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -39,14 +39,17 @@ const ( ) type fakeNetConn struct { - io.Reader io.Writer closed chan struct{} } // Close closes the fakeNetConn and signals its closure by closing the "closed" channel. func (c fakeNetConn) Close() error { - close(c.closed) + select { + case <-c.closed: + default: + close(c.closed) + } return nil } func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } @@ -54,6 +57,10 @@ func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } +func (c fakeNetConn) Read(p []byte) (n int, err error) { + <-c.closed + return 0, fmt.Errorf("closed") +} type fakeAddr int @@ -86,14 +93,14 @@ func (w *HijackResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { w.responseBuff = bytes.NewBuffer(make([]byte, 0)) w.closed = make(chan struct{}, 1) - return fakeNetConn{strings.NewReader(""), w.responseBuff, w.closed}, bufio.NewReadWriter(br, bw), nil + return fakeNetConn{w.responseBuff, w.closed}, bufio.NewReadWriter(br, bw), nil } // NewHijackResponseRecorder creates a new instance of HijackResponseRecorder. func NewHijackResponseRecorder() *HijackResponseRecorder { - responseRecorder := &HijackResponseRecorder{} - responseRecorder.ResponseRecorder = httptest.NewRecorder() - return responseRecorder + return &HijackResponseRecorder{ + ResponseRecorder: httptest.NewRecorder(), + } } func newRouter(backend *mock.API, stateStreamApi *mock_state_stream.API) (*mux.Router, error) { @@ -115,20 +122,13 @@ func newRouter(backend *mock.API, stateStreamApi *mock_state_stream.API) (*mux.R stateStreamConfig.MaxGlobalStreams) } -func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_state_stream.API) (*HijackResponseRecorder, error) { +func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_state_stream.API, responseRecorder *HijackResponseRecorder) error { router, err := newRouter(backend, stateStreamApi) if err != nil { - return nil, err + return err } - - resp := NewHijackResponseRecorder() - go func() { - time.Sleep(5 * time.Second) - //close(resp.closed) - }() - router.ServeHTTP(resp, req) - //<-resp.closed - return resp, nil + router.ServeHTTP(responseRecorder, req) + return nil } func assertOKResponse(t *testing.T, req *http.Request, expectedRespBody string, backend *mock.API, stateStreamApi *mock_state_stream.API) { @@ -136,7 +136,8 @@ func assertOKResponse(t *testing.T, req *http.Request, expectedRespBody string, } func assertResponse(t *testing.T, req *http.Request, status int, expectedRespBody string, backend *mock.API, stateStreamApi *mock_state_stream.API) { - rr, err := executeRequest(req, backend, stateStreamApi) + rr := NewHijackResponseRecorder() + err := executeRequest(req, backend, stateStreamApi, rr) assert.NoError(t, err) actualResponseBody := rr.Body.String() require.JSONEq(t, diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 34ed559f6d5..8eb677b77bd 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -37,7 +37,7 @@ type WebsocketController struct { eventFilterConfig state_stream.EventFilterConfig // the configuration for filtering events maxStreams int32 // the maximum number of streams allowed streamCount *atomic.Int32 // the current number of active streams - read chan interface{} // channel for read close message from the client + readChannel chan interface{} // channel which notify closing connection by the client } // SetWebsocketConf used to set read and write deadlines for WebSocket connections and establishes a Pong handler to @@ -60,21 +60,6 @@ func (wsController *WebsocketController) SetWebsocketConf() error { } return nil }) - - // Start a goroutine to handle the WebSocket connection - go func() { - defer close(wsController.read) // notify websocket about closed connection - - wsController.conn.SetReadDeadline(time.Now().Add(pongWait)) - wsController.conn.SetPongHandler(func(string) error { wsController.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - - for { - _, _, err := wsController.conn.ReadMessage() - if err != nil { - return - } - } - }() return nil } @@ -118,10 +103,9 @@ func (wsController *WebsocketController) wsErrorHandler(err error) { } } -// writeEvents use for writes events and pings to the WebSocket connection. It listens to a subscription's channel for -// events and writes them to the connection. If an error occurs or the subscription channel is closed, it handles the -// error or termination accordingly. -// TODO: comment for added part +// writeEvents use for writes events and pings to the WebSocket connection for a given subscription. +// It listens to the subscription's channel for events and writes them to the WebSocket connection. +// If an error occurs or the subscription channel is closed, it handles the error or termination accordingly. // The function uses a ticker to periodically send ping messages to the client to maintain the connection. func (wsController *WebsocketController) writeEvents(sub state_stream.Subscription) { ticker := time.NewTicker(pingPeriod) @@ -129,7 +113,7 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti for { select { - case _, ok := <-wsController.read: + case _, ok := <-wsController.readChannel: if !ok { return } @@ -169,6 +153,24 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti } } +// read function handles WebSocket messages from the client. +// It continuously reads messages from the WebSocket connection and closes +// the associated read channel when the connection is closed. +// +// This method should be called after establishing the WebSocket connection +// to handle incoming messages asynchronously. +func (wsController *WebsocketController) read() { + // Start a goroutine to handle the WebSocket connection + defer close(wsController.readChannel) // notify websocket about closed connection + + for { + _, _, err := wsController.conn.ReadMessage() + if err != nil { + return + } + } +} + // SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, fetches necessary resources type SubscribeHandlerFunc func( request *request.Request, @@ -234,7 +236,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { eventFilterConfig: h.eventFilterConfig, maxStreams: h.maxStreams, streamCount: h.streamCount, - read: make(chan interface{}), + readChannel: make(chan interface{}), } err = wsController.SetWebsocketConf() @@ -262,5 +264,6 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + go wsController.read() wsController.writeEvents(sub) } From 23cad55bd696bbe01c788c77892dafb50ea7cea1 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Tue, 12 Sep 2023 17:00:39 +0300 Subject: [PATCH 30/35] Updated according to comments --- engine/access/rest/request/event_type.go | 3 +- engine/access/rest/routes/subscribe_events.go | 2 +- .../rest/routes/subscribe_events_test.go | 18 ++------- .../access/rest/routes/websocket_handler.go | 38 ++++++++++++------- .../tests/access/rest_state_stream_test.go | 29 ++++++-------- 5 files changed, 41 insertions(+), 49 deletions(-) diff --git a/engine/access/rest/request/event_type.go b/engine/access/rest/request/event_type.go index 31d1dad784d..c3f425d81c8 100644 --- a/engine/access/rest/request/event_type.go +++ b/engine/access/rest/request/event_type.go @@ -32,8 +32,7 @@ func (e *EventTypes) Parse(raw []string) error { var eType EventType err := eType.Parse(r) if err != nil { - err = fmt.Errorf("%v at index %d ", i, err) - return err + return fmt.Errorf("error at index %d: %w", i, err) } if !uniqueTypes[eType.Flow()] { diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index 7ad21355911..e9d29c92c3e 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -11,8 +11,8 @@ import ( // SubscribeEvents create websocket connection and write to it requested events. func SubscribeEvents( - request *request.Request, ctx context.Context, + request *request.Request, wsController *WebsocketController) (state_stream.Subscription, error) { req, err := request.SubscribeEventsRequest() if err != nil { diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index e1e5932999a..7ac7e6404b1 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "golang.org/x/exp/slices" + "github.com/stretchr/testify/assert" mocks "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -121,19 +123,6 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { var expectedEventsResponses []*state_stream.EventsResponse startBlockFound := test.startBlockID == flow.ZeroID - // Helper function to check if a string is present in a slice - addExpectedEvent := func(slice []string, item string) bool { - if slice == nil { - return true // Include all events when test.eventTypes is nil - } - for _, s := range slice { - if s == item { - return true - } - } - return false - } - // construct expected event responses based on the provided test configuration for _, block := range s.blocks { if startBlockFound || block.ID() == test.startBlockID { @@ -141,7 +130,8 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { if test.startHeight == request.EmptyHeight || block.Header.Height >= test.startHeight { eventsForBlock := flow.EventsList{} for _, event := range s.blockEvents[block.ID()] { - if addExpectedEvent(test.eventTypes, string(event.Type)) { + if slices.Contains(test.eventTypes, string(event.Type)) || + len(test.eventTypes) == 0 { //Include all events eventsForBlock = append(eventsForBlock, event) } } diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 8eb677b77bd..9c4a400f0ea 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -37,7 +37,7 @@ type WebsocketController struct { eventFilterConfig state_stream.EventFilterConfig // the configuration for filtering events maxStreams int32 // the maximum number of streams allowed streamCount *atomic.Int32 // the current number of active streams - readChannel chan interface{} // channel which notify closing connection by the client + readChannel chan struct{} // channel which notify closing connection by the client } // SetWebsocketConf used to set read and write deadlines for WebSocket connections and establishes a Pong handler to @@ -130,7 +130,7 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti } err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err != nil { - wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err)) + wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline error: ", err)) return } // Write the response to the WebSocket connection @@ -142,7 +142,7 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti case <-ticker.C: err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err != nil { - wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "Set the initial write deadline error: ", err)) + wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline error: ", err)) return } if err := wsController.conn.WriteMessage(websocket.PingMessage, nil); err != nil { @@ -155,7 +155,8 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti // read function handles WebSocket messages from the client. // It continuously reads messages from the WebSocket connection and closes -// the associated read channel when the connection is closed. +// the associated read channel when the connection is closed by client or when an +// any additional message is received from the client. // // This method should be called after establishing the WebSocket connection // to handle incoming messages asynchronously. @@ -164,17 +165,25 @@ func (wsController *WebsocketController) read() { defer close(wsController.readChannel) // notify websocket about closed connection for { - _, _, err := wsController.conn.ReadMessage() + // reads messages from the WebSocket connection when the connection is closed by client or when an + // 1) when the connection is closed by client + // 2) when an any additional message is received from the client + _, msg, err := wsController.conn.ReadMessage() if err != nil { return } + + // Check the message from the client, if is any just close the connection + if len(msg) > 0 { + return + } } } // SubscribeHandlerFunc is a function that contains endpoint handling logic for subscribes, fetches necessary resources type SubscribeHandlerFunc func( - request *request.Request, ctx context.Context, + request *request.Request, wsController *WebsocketController, ) (state_stream.Subscription, error) @@ -228,6 +237,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.errorHandler(w, models.NewRestError(http.StatusInternalServerError, "webSocket upgrade error: ", err), logger) return } + defer conn.Close() wsController := &WebsocketController{ logger: logger, @@ -236,29 +246,29 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { eventFilterConfig: h.eventFilterConfig, maxStreams: h.maxStreams, streamCount: h.streamCount, - readChannel: make(chan interface{}), + readChannel: make(chan struct{}), } err = wsController.SetWebsocketConf() if err != nil { wsController.wsErrorHandler(err) - conn.Close() + return } if wsController.streamCount.Load() >= wsController.maxStreams { err := fmt.Errorf("maximum number of streams reached") wsController.wsErrorHandler(models.NewRestError(http.StatusServiceUnavailable, err.Error(), err)) + return } wsController.streamCount.Add(1) + defer wsController.streamCount.Add(-1) + // cancelling the context passed into the `subscribeFunc` to ensure when the client disconnect it's time the shutdown + // gorountines setup by the backend are cleaned up if the client disconnects first. ctx, cancel := context.WithCancel(context.Background()) - defer func() { - wsController.streamCount.Add(-1) - wsController.conn.Close() - cancel() - }() + defer cancel() - sub, err := h.subscribeFunc(request.Decorate(r, h.HttpHandler.Chain), ctx, wsController) + sub, err := h.subscribeFunc(ctx, request.Decorate(r, h.HttpHandler.Chain), wsController) if err != nil { wsController.wsErrorHandler(err) return diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index 8ab7c0ea20a..dbcac0115c5 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -114,16 +114,18 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { client, err := getWSClient(ctx, url) require.NoError(t, err) + var receivedEventsResponse []*state_stream.EventsResponse - eventChan := make(chan *state_stream.EventsResponse) - // Start the timeout goroutine - timeoutChan := make(chan struct{}) go func() { - time.Sleep(10 * time.Second) // Sleep for 10 seconds - close(timeoutChan) // Signal the timeout + time.Sleep(10 * time.Second) + // close connection after 10 seconds + client.Close() + // check events + s.requireEvents(receivedEventsResponse) }() + eventChan := make(chan *state_stream.EventsResponse) go func() { for { resp := &state_stream.EventsResponse{} @@ -139,21 +141,12 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { } }() - // Wait for events or timeout + // collect received events during 10 seconds for { select { - case <-timeoutChan: - // Handle the timeout and close the client connection - client.Close() - s.T().Log("Client connection closed") - s.requireEvents(receivedEventsResponse) - return case eventResponse, ok := <-eventChan: + // Event channel closed if !ok { - // Event channel closed, events received - s.T().Log(" Event channel closed, events received") - client.Close() - require.Equal(s.T(), len(receivedEventsResponse) > 0, "expect some events ") return } receivedEventsResponse = append(receivedEventsResponse, eventResponse) @@ -197,8 +190,8 @@ func (s *RestStateStreamSuite) requireEvents(receivedEventsResponse []*state_str require.Equal(s.T(), len(expectedEventsResult.Events), len(receivedEventList), "expect the same count of events") for i, event := range receivedEventList { - require.Equal(s.T(), expectedEventsResult.Events[i].EventIndex, event.EventIndex, "expect the same EventIndex") - require.Equal(s.T(), convert.MessageToIdentifier(expectedEventsResult.Events[i].TransactionId), event.TransactionID, "expect the same TransactionId") + require.Equal(s.T(), expectedEventsResult.Events[i].EventIndex, event.EventIndex, "expect the same event index") + require.Equal(s.T(), convert.MessageToIdentifier(expectedEventsResult.Events[i].TransactionId), event.TransactionID, "expect the same transaction id") } } } From f6325b642b7581b71784edc2161ada5aa7fe2eeb Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Tue, 12 Sep 2023 18:26:55 +0300 Subject: [PATCH 31/35] Updated error according to comment --- engine/access/rest/routes/websocket_handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 9c4a400f0ea..832f3d1a747 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -121,7 +121,7 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti if !ok { if sub.Err() != nil { err := fmt.Errorf("stream encountered an error: %v", sub.Err()) - wsController.wsErrorHandler(models.NewBadRequestError(err)) + wsController.wsErrorHandler(err) return } err := fmt.Errorf("subscription channel closed, no error occurred") From c630e4a216cadf1d1a87100701b78e176be56fae Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Wed, 13 Sep 2023 19:25:03 +0300 Subject: [PATCH 32/35] Added RouterBuilder and updated rest unit tests. Added fixes according to last comments. Linted --- .../access/rest/routes/account_keys_test.go | 21 +++---- engine/access/rest/routes/accounts_test.go | 12 ++-- engine/access/rest/routes/blocks_test.go | 10 +--- engine/access/rest/routes/collections_test.go | 10 ++-- engine/access/rest/routes/events_test.go | 2 +- .../rest/routes/execution_result_test.go | 8 +-- engine/access/rest/routes/network_test.go | 2 +- .../rest/routes/node_version_info_test.go | 2 +- engine/access/rest/routes/router.go | 58 ++++++++++++------- engine/access/rest/routes/scripts_test.go | 9 ++- .../rest/routes/subscribe_events_test.go | 17 ++---- engine/access/rest/routes/test_helpers.go | 58 +++++++++---------- .../access/rest/routes/transactions_test.go | 22 +++---- .../access/rest/routes/websocket_handler.go | 22 ++++--- engine/access/rest/server.go | 8 +-- .../tests/access/rest_state_stream_test.go | 16 ++--- 16 files changed, 135 insertions(+), 142 deletions(-) diff --git a/engine/access/rest/routes/account_keys_test.go b/engine/access/rest/routes/account_keys_test.go index 692691fa6ca..9b6d4ba34c7 100644 --- a/engine/access/rest/routes/account_keys_test.go +++ b/engine/access/rest/routes/account_keys_test.go @@ -48,7 +48,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { expected := expectedAccountKeyResponse(account) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -69,7 +69,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { expected := expectedAccountKeyResponse(account) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -97,7 +97,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, index) - assertResponse(t, req, statusCode, expected, backend, nil) + assertResponse(t, req, statusCode, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -125,7 +125,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, index) - assertResponse(t, req, statusCode, expected, backend, nil) + assertResponse(t, req, statusCode, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -154,7 +154,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, account.Address) - assertResponse(t, req, statusCode, expected, backend, nil) + assertResponse(t, req, statusCode, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -183,7 +183,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, account.Address) - assertResponse(t, req, statusCode, expected, backend, nil) + assertResponse(t, req, statusCode, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -198,7 +198,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { expected := expectedAccountKeyResponse(account) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -222,7 +222,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { } `, statusCode, finalHeight) - assertResponse(t, req, statusCode, expected, backend, nil) + assertResponse(t, req, statusCode, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -261,10 +261,7 @@ func TestGetAccountKeyByIndex(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { req, _ := http.NewRequest("GET", test.url, nil) - rr := NewHijackResponseRecorder() - err := executeRequest(req, backend, nil, rr) - assert.NoError(t, err) - + rr := executeRequest(req, backend) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.JSONEq(t, test.out, rr.Body.String()) }) diff --git a/engine/access/rest/routes/accounts_test.go b/engine/access/rest/routes/accounts_test.go index 482e9f6376a..feb9e77eeae 100644 --- a/engine/access/rest/routes/accounts_test.go +++ b/engine/access/rest/routes/accounts_test.go @@ -61,7 +61,7 @@ func TestAccessGetAccount(t *testing.T) { expected := expectedExpandedResponse(account) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -81,7 +81,7 @@ func TestAccessGetAccount(t *testing.T) { expected := expectedExpandedResponse(account) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -96,7 +96,7 @@ func TestAccessGetAccount(t *testing.T) { expected := expectedExpandedResponse(account) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -111,7 +111,7 @@ func TestAccessGetAccount(t *testing.T) { expected := expectedCondensedResponse(account) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) @@ -126,9 +126,7 @@ func TestAccessGetAccount(t *testing.T) { for i, test := range tests { req, _ := http.NewRequest("GET", test.url, nil) - rr := NewHijackResponseRecorder() - err := executeRequest(req, backend, nil, rr) - assert.NoError(t, err) + rr := executeRequest(req, backend) assert.Equal(t, http.StatusBadRequest, rr.Code) assert.JSONEq(t, test.out, rr.Body.String(), fmt.Sprintf("test #%d failed: %v", i, test)) diff --git a/engine/access/rest/routes/blocks_test.go b/engine/access/rest/routes/blocks_test.go index 5ef3e851ce2..7facf06d423 100644 --- a/engine/access/rest/routes/blocks_test.go +++ b/engine/access/rest/routes/blocks_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/access/rest/util" @@ -149,11 +147,9 @@ func TestAccessGetBlocks(t *testing.T) { testVectors := prepareTestVectors(t, blockIDs, heights, blocks, executionResults, blkCnt) for _, tv := range testVectors { - responseRec := NewHijackResponseRecorder() - err := executeRequest(tv.request, backend, nil, responseRec) - assert.NoError(t, err) - require.Equal(t, tv.expectedStatus, responseRec.Code, "failed test %s: incorrect response code", tv.description) - actualResp := responseRec.Body.String() + rr := executeRequest(tv.request, backend) + require.Equal(t, tv.expectedStatus, rr.Code, "failed test %s: incorrect response code", tv.description) + actualResp := rr.Body.String() require.JSONEq(t, tv.expectedResponse, actualResp, "Failed: %s: incorrect response body", tv.description) } } diff --git a/engine/access/rest/routes/collections_test.go b/engine/access/rest/routes/collections_test.go index 3ff17198386..d0c5684f4ed 100644 --- a/engine/access/rest/routes/collections_test.go +++ b/engine/access/rest/routes/collections_test.go @@ -62,7 +62,7 @@ func TestGetCollections(t *testing.T) { }`, col.ID(), col.ID(), transactionsStr) req := getCollectionReq(col.ID().String(), false) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocks.AssertExpectationsForObjects(t, backend) } }) @@ -87,16 +87,14 @@ func TestGetCollections(t *testing.T) { Once() req := getCollectionReq(col.ID().String(), true) - rr := NewHijackResponseRecorder() - err := executeRequest(req, backend, nil, rr) - assert.NoError(t, err) + rr := executeRequest(req, backend) assert.Equal(t, http.StatusOK, rr.Code) // really hacky but we can't build whole response since it's really complex // so we just make sure the transactions are included and have defined values // anyhow we already test transaction responses in transaction tests var res map[string]interface{} - err = json.Unmarshal(rr.Body.Bytes(), &res) + err := json.Unmarshal(rr.Body.Bytes(), &res) assert.NoError(t, err) resTx := res["transactions"].([]interface{}) for i, r := range resTx { @@ -147,7 +145,7 @@ func TestGetCollections(t *testing.T) { Return(test.mockValue, test.mockErr) } req := getCollectionReq(test.id, false) - assertResponse(t, req, test.status, test.response, backend, nil) + assertResponse(t, req, test.status, test.response, backend) } }) } diff --git a/engine/access/rest/routes/events_test.go b/engine/access/rest/routes/events_test.go index 0aec01937af..c4bd95f4d34 100644 --- a/engine/access/rest/routes/events_test.go +++ b/engine/access/rest/routes/events_test.go @@ -125,7 +125,7 @@ func TestGetEvents(t *testing.T) { for _, test := range testVectors { t.Run(test.description, func(t *testing.T) { - assertResponse(t, test.request, test.expectedStatus, test.expectedResponse, backend, nil) + assertResponse(t, test.request, test.expectedStatus, test.expectedResponse, backend) }) } diff --git a/engine/access/rest/routes/execution_result_test.go b/engine/access/rest/routes/execution_result_test.go index e44804fceb8..ba74974af1a 100644 --- a/engine/access/rest/routes/execution_result_test.go +++ b/engine/access/rest/routes/execution_result_test.go @@ -48,7 +48,7 @@ func TestGetResultByID(t *testing.T) { req := getResultByIDReq(id.String(), nil) expected := executionResultExpectedStr(result) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocks.AssertExpectationsForObjects(t, backend) }) @@ -61,7 +61,7 @@ func TestGetResultByID(t *testing.T) { Once() req := getResultByIDReq(id.String(), nil) - assertResponse(t, req, http.StatusNotFound, `{"code":404,"message":"Flow resource not found: block not found"}`, backend, nil) + assertResponse(t, req, http.StatusNotFound, `{"code":404,"message":"Flow resource not found: block not found"}`, backend) mocks.AssertExpectationsForObjects(t, backend) }) } @@ -81,7 +81,7 @@ func TestGetResultBlockID(t *testing.T) { req := getResultByIDReq("", []string{blockID.String()}) expected := fmt.Sprintf(`[%s]`, executionResultExpectedStr(result)) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocks.AssertExpectationsForObjects(t, backend) }) @@ -94,7 +94,7 @@ func TestGetResultBlockID(t *testing.T) { Once() req := getResultByIDReq("", []string{blockID.String()}) - assertResponse(t, req, http.StatusNotFound, `{"code":404,"message":"Flow resource not found: block not found"}`, backend, nil) + assertResponse(t, req, http.StatusNotFound, `{"code":404,"message":"Flow resource not found: block not found"}`, backend) mocks.AssertExpectationsForObjects(t, backend) }) } diff --git a/engine/access/rest/routes/network_test.go b/engine/access/rest/routes/network_test.go index 0cb3bc2b6c8..00d0ca03944 100644 --- a/engine/access/rest/routes/network_test.go +++ b/engine/access/rest/routes/network_test.go @@ -38,7 +38,7 @@ func TestGetNetworkParameters(t *testing.T) { expected := networkParametersExpectedStr(flow.Mainnet) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) } diff --git a/engine/access/rest/routes/node_version_info_test.go b/engine/access/rest/routes/node_version_info_test.go index 5d69131d58f..179d339f94f 100644 --- a/engine/access/rest/routes/node_version_info_test.go +++ b/engine/access/rest/routes/node_version_info_test.go @@ -41,7 +41,7 @@ func TestGetNodeVersionInfo(t *testing.T) { expected := nodeVersionInfoExpectedStr(params) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) mocktestify.AssertExpectationsForObjects(t, backend) }) } diff --git a/engine/access/rest/routes/router.go b/engine/access/rest/routes/router.go index a14bfabee70..e7928cd8de3 100644 --- a/engine/access/rest/routes/router.go +++ b/engine/access/rest/routes/router.go @@ -17,13 +17,15 @@ import ( "github.com/onflow/flow-go/module" ) -func NewRouter(backend access.API, +type RouterBuilder struct { + logger zerolog.Logger + router *mux.Router + v1SubRouter *mux.Router +} + +func NewRouterBuilder( logger zerolog.Logger, - chain flow.Chain, - restCollector module.RestMetrics, - stateStreamApi state_stream.API, - eventFilterConfig state_stream.EventFilterConfig, - maxGlobalStreams uint32) (*mux.Router, error) { + restCollector module.RestMetrics) *RouterBuilder { router := mux.NewRouter().StrictSlash(true) v1SubRouter := router.PathPrefix("/v1").Subrouter() @@ -33,30 +35,46 @@ func NewRouter(backend access.API, v1SubRouter.Use(middleware.QuerySelect()) v1SubRouter.Use(middleware.MetricsMiddleware(restCollector)) - linkGenerator := models.NewLinkGeneratorImpl(v1SubRouter) + return &RouterBuilder{ + logger: logger, + router: router, + v1SubRouter: v1SubRouter, + } +} +func (b *RouterBuilder) AddRestRoutes(backend access.API, chain flow.Chain) *RouterBuilder { + linkGenerator := models.NewLinkGeneratorImpl(b.v1SubRouter) for _, r := range Routes { - h := NewHandler(logger, backend, r.Handler, linkGenerator, chain) - v1SubRouter. + h := NewHandler(b.logger, backend, r.Handler, linkGenerator, chain) + b.v1SubRouter. Methods(r.Method). Path(r.Pattern). Name(r.Name). Handler(h) } + return b +} - // Note: we add subscribe routes only if stateStreamApi is available - if stateStreamApi != nil { - for _, r := range WSRoutes { - h := NewWSHandler(logger, r.Handler, chain, stateStreamApi, eventFilterConfig, maxGlobalStreams) - v1SubRouter. - Methods(r.Method). - Path(r.Pattern). - Name(r.Name). - Handler(h) - } +func (b *RouterBuilder) AddWsRoutes( + chain flow.Chain, + stateStreamApi state_stream.API, + eventFilterConfig state_stream.EventFilterConfig, + maxGlobalStreams uint32) *RouterBuilder { + + for _, r := range WSRoutes { + h := NewWSHandler(b.logger, r.Handler, chain, stateStreamApi, eventFilterConfig, maxGlobalStreams) + b.v1SubRouter. + Methods(r.Method). + Path(r.Pattern). + Name(r.Name). + Handler(h) } - return router, nil + return b +} + +func (b *RouterBuilder) Build() *mux.Router { + return b.router } type route struct { diff --git a/engine/access/rest/routes/scripts_test.go b/engine/access/rest/routes/scripts_test.go index 5e1b15ca86f..3fc48689a86 100644 --- a/engine/access/rest/routes/scripts_test.go +++ b/engine/access/rest/routes/scripts_test.go @@ -55,7 +55,7 @@ func TestScripts(t *testing.T) { assertOKResponse(t, req, fmt.Sprintf( "\"%s\"", base64.StdEncoding.EncodeToString([]byte(`hello world`)), - ), backend, nil) + ), backend) }) t.Run("get by height", func(t *testing.T) { @@ -70,7 +70,7 @@ func TestScripts(t *testing.T) { assertOKResponse(t, req, fmt.Sprintf( "\"%s\"", base64.StdEncoding.EncodeToString([]byte(`hello world`)), - ), backend, nil) + ), backend) }) t.Run("get by ID", func(t *testing.T) { @@ -85,7 +85,7 @@ func TestScripts(t *testing.T) { assertOKResponse(t, req, fmt.Sprintf( "\"%s\"", base64.StdEncoding.EncodeToString([]byte(`hello world`)), - ), backend, nil) + ), backend) }) t.Run("get error", func(t *testing.T) { @@ -101,7 +101,6 @@ func TestScripts(t *testing.T) { http.StatusBadRequest, `{"code":400, "message":"Invalid Flow request: internal server error"}`, backend, - nil, ) }) @@ -126,7 +125,7 @@ func TestScripts(t *testing.T) { for _, test := range tests { req := scriptReq(test.id, test.height, test.body) - assertResponse(t, req, http.StatusBadRequest, test.out, backend, nil) + assertResponse(t, req, http.StatusBadRequest, test.out, backend) } }) } diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index 7ac7e6404b1..4b058b44efd 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -19,7 +19,6 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/rest/request" "github.com/onflow/flow-go/engine/access/state_stream" mockstatestream "github.com/onflow/flow-go/engine/access/state_stream/mock" @@ -113,8 +112,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { for _, test := range tests { s.Run(test.name, func() { - stateStreamBackend := &mockstatestream.API{} - backend := &mock.API{} + stateStreamBackend := mockstatestream.NewAPI(s.T()) subscription := &mockstatestream.Subscription{} filter, err := state_stream.NewEventFilter(state_stream.DefaultEventFilterConfig, chain, test.eventTypes, test.addresses, test.contracts) @@ -176,7 +174,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { time.Sleep(5 * time.Second) close(respRecorder.closed) }() - err = executeRequest(req, backend, stateStreamBackend, respRecorder) + executeWsRequest(req, stateStreamBackend, respRecorder) assert.NoError(s.T(), err) requireResponse(s.T(), respRecorder, expectedEventsResponses) }) @@ -186,20 +184,16 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { s.Run("returns error for block id and height", func() { stateStreamBackend := &mockstatestream.API{} - backend := &mock.API{} - req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil) assert.NoError(s.T(), err) respRecorder := NewHijackResponseRecorder() - err = executeRequest(req, backend, stateStreamBackend, respRecorder) + executeWsRequest(req, stateStreamBackend, respRecorder) assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "can only provide either block ID or start height") }) s.Run("returns error for invalid block id", func() { stateStreamBackend := &mockstatestream.API{} - backend := &mock.API{} - invalidBlock := unittest.BlockFixture() subscription := &mockstatestream.Subscription{} @@ -217,14 +211,13 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) respRecorder := NewHijackResponseRecorder() - err = executeRequest(req, backend, stateStreamBackend, respRecorder) + executeWsRequest(req, stateStreamBackend, respRecorder) assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "stream encountered an error: subscription error") }) s.Run("returns error when channel closed", func() { stateStreamBackend := &mockstatestream.API{} - backend := &mock.API{} subscription := &mockstatestream.Subscription{} ch := make(chan interface{}) @@ -242,7 +235,7 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) assert.NoError(s.T(), err) respRecorder := NewHijackResponseRecorder() - err = executeRequest(req, backend, stateStreamBackend, respRecorder) + executeWsRequest(req, stateStreamBackend, respRecorder) assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "subscription channel closed") }) diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index 4526aba9a25..be2113ddd28 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -12,16 +12,14 @@ import ( "testing" "time" - "github.com/gorilla/mux" - "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/onflow/flow-go/access" "github.com/onflow/flow-go/access/mock" "github.com/onflow/flow-go/engine/access/state_stream" - mock_state_stream "github.com/onflow/flow-go/engine/access/state_stream/mock" "github.com/onflow/flow-go/model/flow" "github.com/onflow/flow-go/module/metrics" + "github.com/onflow/flow-go/utils/unittest" ) const ( @@ -38,11 +36,14 @@ const ( contractsQueryParams = "contracts" ) +// fakeNetConn implements a mocked ws connection that can be injected in testing logic. type fakeNetConn struct { io.Writer closed chan struct{} } +var _ net.Conn = (*fakeNetConn)(nil) + // Close closes the fakeNetConn and signals its closure by closing the "closed" channel. func (c fakeNetConn) Close() error { select { @@ -103,42 +104,37 @@ func NewHijackResponseRecorder() *HijackResponseRecorder { } } -func newRouter(backend *mock.API, stateStreamApi *mock_state_stream.API) (*mux.Router, error) { - var b bytes.Buffer - logger := zerolog.New(&b) - restCollector := metrics.NewNoopCollector() +func executeRequest(req *http.Request, backend access.API) *httptest.ResponseRecorder { + router := NewRouterBuilder( + unittest.Logger(), + metrics.NewNoopCollector(), + ).AddRestRoutes( + backend, + flow.Testnet.Chain(), + ).Build() - stateStreamConfig := state_stream.Config{ - EventFilterConfig: state_stream.DefaultEventFilterConfig, - MaxGlobalStreams: state_stream.DefaultMaxGlobalStreams, - } + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + return rr +} - return NewRouter(backend, - logger, +func executeWsRequest(req *http.Request, stateStreamApi state_stream.API, responseRecorder *HijackResponseRecorder) { + restCollector := metrics.NewNoopCollector() + router := NewRouterBuilder(unittest.Logger(), restCollector).AddWsRoutes( flow.Testnet.Chain(), - restCollector, stateStreamApi, - stateStreamConfig.EventFilterConfig, - stateStreamConfig.MaxGlobalStreams) -} - -func executeRequest(req *http.Request, backend *mock.API, stateStreamApi *mock_state_stream.API, responseRecorder *HijackResponseRecorder) error { - router, err := newRouter(backend, stateStreamApi) - if err != nil { - return err - } + state_stream.DefaultEventFilterConfig, + state_stream.DefaultMaxGlobalStreams, + ).Build() router.ServeHTTP(responseRecorder, req) - return nil } -func assertOKResponse(t *testing.T, req *http.Request, expectedRespBody string, backend *mock.API, stateStreamApi *mock_state_stream.API) { - assertResponse(t, req, http.StatusOK, expectedRespBody, backend, stateStreamApi) +func assertOKResponse(t *testing.T, req *http.Request, expectedRespBody string, backend *mock.API) { + assertResponse(t, req, http.StatusOK, expectedRespBody, backend) } -func assertResponse(t *testing.T, req *http.Request, status int, expectedRespBody string, backend *mock.API, stateStreamApi *mock_state_stream.API) { - rr := NewHijackResponseRecorder() - err := executeRequest(req, backend, stateStreamApi, rr) - assert.NoError(t, err) +func assertResponse(t *testing.T, req *http.Request, status int, expectedRespBody string, backend *mock.API) { + rr := executeRequest(req, backend) actualResponseBody := rr.Body.String() require.JSONEq(t, expectedRespBody, diff --git a/engine/access/rest/routes/transactions_test.go b/engine/access/rest/routes/transactions_test.go index 671324ba29e..c19dd096b95 100644 --- a/engine/access/rest/routes/transactions_test.go +++ b/engine/access/rest/routes/transactions_test.go @@ -112,7 +112,7 @@ func TestGetTransactions(t *testing.T) { }`, tx.ID(), tx.ReferenceBlockID, util.ToBase64(tx.EnvelopeSignatures[0].Signature), tx.ID(), tx.ID()) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) }) t.Run("Get by ID with results", func(t *testing.T) { @@ -182,7 +182,7 @@ func TestGetTransactions(t *testing.T) { } }`, tx.ID(), tx.ReferenceBlockID, util.ToBase64(tx.EnvelopeSignatures[0].Signature), tx.ReferenceBlockID, txr.CollectionID, tx.ID(), tx.ID(), tx.ID()) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) }) t.Run("get by ID Invalid", func(t *testing.T) { @@ -190,7 +190,7 @@ func TestGetTransactions(t *testing.T) { req := getTransactionReq("invalid", false, "", "") expected := `{"code":400, "message":"invalid ID format"}` - assertResponse(t, req, http.StatusBadRequest, expected, backend, nil) + assertResponse(t, req, http.StatusBadRequest, expected, backend) }) t.Run("get by ID non-existing", func(t *testing.T) { @@ -204,7 +204,7 @@ func TestGetTransactions(t *testing.T) { Return(nil, status.Error(codes.NotFound, "transaction not found")) expected := `{"code":404, "message":"Flow resource not found: transaction not found"}` - assertResponse(t, req, http.StatusNotFound, expected, backend, nil) + assertResponse(t, req, http.StatusNotFound, expected, backend) }) } @@ -253,7 +253,7 @@ func TestGetTransactionResult(t *testing.T) { On("GetTransactionResult", mocks.Anything, id, flow.ZeroID, flow.ZeroID). Return(txr, nil) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) }) t.Run("get by block ID", func(t *testing.T) { @@ -265,7 +265,7 @@ func TestGetTransactionResult(t *testing.T) { On("GetTransactionResult", mocks.Anything, id, bid, flow.ZeroID). Return(txr, nil) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) }) t.Run("get by collection ID", func(t *testing.T) { @@ -276,7 +276,7 @@ func TestGetTransactionResult(t *testing.T) { On("GetTransactionResult", mocks.Anything, id, flow.ZeroID, cid). Return(txr, nil) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) }) t.Run("get execution statuses", func(t *testing.T) { @@ -324,7 +324,7 @@ func TestGetTransactionResult(t *testing.T) { "_self": "/v1/transaction_results/%s" } }`, bid.String(), cid.String(), err, cases.Title(language.English).String(strings.ToLower(txResult.Status.String())), txResult.ErrorMessage, id.String()) - assertOKResponse(t, req, expectedResp, backend, nil) + assertOKResponse(t, req, expectedResp, backend) } }) @@ -334,7 +334,7 @@ func TestGetTransactionResult(t *testing.T) { req := getTransactionResultReq("invalid", "", "") expected := `{"code":400, "message":"invalid ID format"}` - assertResponse(t, req, http.StatusBadRequest, expected, backend, nil) + assertResponse(t, req, http.StatusBadRequest, expected, backend) }) } @@ -389,7 +389,7 @@ func TestCreateTransaction(t *testing.T) { } }`, tx.ID(), tx.ReferenceBlockID, util.ToBase64(tx.PayloadSignatures[0].Signature), util.ToBase64(tx.EnvelopeSignatures[0].Signature), tx.ID(), tx.ID()) - assertOKResponse(t, req, expected, backend, nil) + assertOKResponse(t, req, expected, backend) }) t.Run("post invalid transaction", func(t *testing.T) { @@ -416,7 +416,7 @@ func TestCreateTransaction(t *testing.T) { testTx[test.inputField] = test.inputValue req := createTransactionReq(testTx) - assertResponse(t, req, http.StatusBadRequest, test.output, backend, nil) + assertResponse(t, req, http.StatusBadRequest, test.output, backend) } }) } diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 832f3d1a747..2e21b479c76 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -5,11 +5,11 @@ import ( "errors" "fmt" "net/http" - "sync/atomic" "time" "github.com/gorilla/websocket" "github.com/rs/zerolog" + "go.uber.org/atomic" "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/request" @@ -36,7 +36,7 @@ type WebsocketController struct { api state_stream.API // the state_stream.API instance for managing event subscriptions eventFilterConfig state_stream.EventFilterConfig // the configuration for filtering events maxStreams int32 // the maximum number of streams allowed - streamCount *atomic.Int32 // the current number of active streams + activeStreamCount *atomic.Int32 // the current number of active streams readChannel chan struct{} // channel which notify closing connection by the client } @@ -103,7 +103,7 @@ func (wsController *WebsocketController) wsErrorHandler(err error) { } } -// writeEvents use for writes events and pings to the WebSocket connection for a given subscription. +// writeEvents is used for writing events and pings to the WebSocket connection for a given subscription. // It listens to the subscription's channel for events and writes them to the WebSocket connection. // If an error occurs or the subscription channel is closed, it handles the error or termination accordingly. // The function uses a ticker to periodically send ping messages to the client to maintain the connection. @@ -114,6 +114,8 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti for { select { case _, ok := <-wsController.readChannel: + // we use `readChannel` as indicator of client's status, when `readChannel` closes it means that client + // connection has been terminated and we need to stop this goroutine to avoid memory leak. if !ok { return } @@ -196,9 +198,11 @@ type WSHandler struct { api state_stream.API eventFilterConfig state_stream.EventFilterConfig maxStreams int32 - streamCount *atomic.Int32 + activeStreamCount *atomic.Int32 } +var _ http.Handler = (*WSHandler)(nil) + func NewWSHandler( logger zerolog.Logger, subscribeFunc SubscribeHandlerFunc, @@ -212,7 +216,7 @@ func NewWSHandler( api: api, eventFilterConfig: eventFilterConfig, maxStreams: int32(maxGlobalStreams), - streamCount: &atomic.Int32{}, + activeStreamCount: atomic.NewInt32(0), HttpHandler: NewHttpHandler(logger, chain), } @@ -245,7 +249,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { api: h.api, eventFilterConfig: h.eventFilterConfig, maxStreams: h.maxStreams, - streamCount: h.streamCount, + activeStreamCount: h.activeStreamCount, readChannel: make(chan struct{}), } @@ -255,13 +259,13 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if wsController.streamCount.Load() >= wsController.maxStreams { + if wsController.activeStreamCount.Load() >= wsController.maxStreams { err := fmt.Errorf("maximum number of streams reached") wsController.wsErrorHandler(models.NewRestError(http.StatusServiceUnavailable, err.Error(), err)) return } - wsController.streamCount.Add(1) - defer wsController.streamCount.Add(-1) + wsController.activeStreamCount.Add(1) + defer wsController.activeStreamCount.Add(-1) // cancelling the context passed into the `subscribeFunc` to ensure when the client disconnect it's time the shutdown // gorountines setup by the backend are cleaned up if the client disconnects first. diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index 20613cf018f..eb9c7ed239a 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -42,9 +42,9 @@ func NewServer(serverAPI access.API, eventFilterConfig state_stream.EventFilterConfig, maxGlobalStreams uint32, ) (*http.Server, error) { - router, err := routes.NewRouter(serverAPI, logger, chain, restCollector, stateStreamApi, eventFilterConfig, maxGlobalStreams) - if err != nil { - return nil, err + builder := routes.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain) + if stateStreamApi != nil { + builder.AddWsRoutes(chain, stateStreamApi, eventFilterConfig, maxGlobalStreams) } c := cors.New(cors.Options{ @@ -58,7 +58,7 @@ func NewServer(serverAPI access.API, }) return &http.Server{ - Handler: c.Handler(router), + Handler: c.Handler(builder.Build()), Addr: config.ListenAddress, WriteTimeout: config.WriteTimeout, ReadTimeout: config.ReadTimeout, diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index dbcac0115c5..a2ef3441e3c 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -121,8 +121,6 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { time.Sleep(10 * time.Second) // close connection after 10 seconds client.Close() - // check events - s.requireEvents(receivedEventsResponse) }() eventChan := make(chan *state_stream.EventsResponse) @@ -142,16 +140,12 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { }() // collect received events during 10 seconds - for { - select { - case eventResponse, ok := <-eventChan: - // Event channel closed - if !ok { - return - } - receivedEventsResponse = append(receivedEventsResponse, eventResponse) - } + for eventResponse := range eventChan { + receivedEventsResponse = append(receivedEventsResponse, eventResponse) } + + // check events + s.requireEvents(receivedEventsResponse) }) } From 47dc938b1ba0102dabaa0abb54ac2c09e803f366 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Thu, 14 Sep 2023 16:45:07 +0300 Subject: [PATCH 33/35] Updated according to commits --- engine/access/rest/routes/router.go | 4 +- engine/access/rest/routes/subscribe_events.go | 2 - .../rest/routes/subscribe_events_test.go | 76 +++++++++++-------- engine/access/rest/routes/test_helpers.go | 16 ++-- .../access/rest/routes/websocket_handler.go | 39 ++++++---- engine/access/rest/server.go | 2 +- .../tests/access/rest_state_stream_test.go | 9 ++- 7 files changed, 85 insertions(+), 63 deletions(-) diff --git a/engine/access/rest/routes/router.go b/engine/access/rest/routes/router.go index e7928cd8de3..8df7bdd5bb3 100644 --- a/engine/access/rest/routes/router.go +++ b/engine/access/rest/routes/router.go @@ -56,13 +56,13 @@ func (b *RouterBuilder) AddRestRoutes(backend access.API, chain flow.Chain) *Rou } func (b *RouterBuilder) AddWsRoutes( - chain flow.Chain, stateStreamApi state_stream.API, + chain flow.Chain, eventFilterConfig state_stream.EventFilterConfig, maxGlobalStreams uint32) *RouterBuilder { for _, r := range WSRoutes { - h := NewWSHandler(b.logger, r.Handler, chain, stateStreamApi, eventFilterConfig, maxGlobalStreams) + h := NewWSHandler(b.logger, r.Handler, stateStreamApi, chain, eventFilterConfig, maxGlobalStreams) b.v1SubRouter. Methods(r.Method). Path(r.Pattern). diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index e9d29c92c3e..d092993b6f9 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -2,7 +2,6 @@ package routes import ( "context" - "fmt" "github.com/onflow/flow-go/engine/access/rest/models" "github.com/onflow/flow-go/engine/access/rest/request" @@ -27,7 +26,6 @@ func SubscribeEvents( req.Contracts, ) if err != nil { - err := fmt.Errorf("invalid event filter") return nil, models.NewBadRequestError(err) } diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index 4b058b44efd..fc38e8bbd7a 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -14,7 +14,6 @@ import ( "golang.org/x/exp/slices" - "github.com/stretchr/testify/assert" mocks "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -113,10 +112,10 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { for _, test := range tests { s.Run(test.name, func() { stateStreamBackend := mockstatestream.NewAPI(s.T()) - subscription := &mockstatestream.Subscription{} + subscription := mockstatestream.NewSubscription(s.T()) filter, err := state_stream.NewEventFilter(state_stream.DefaultEventFilterConfig, chain, test.eventTypes, test.addresses, test.contracts) - assert.NoError(s.T(), err) + require.NoError(s.T(), err) var expectedEventsResponses []*state_stream.EventsResponse startBlockFound := test.startBlockID == flow.ZeroID @@ -156,7 +155,6 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { chReadOnly = ch subscription.Mock.On("Channel").Return(chReadOnly) - subscription.Mock.On("Err").Return(nil) var startHeight uint64 if test.startHeight == request.EmptyHeight { @@ -167,15 +165,14 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, test.startBlockID, startHeight, filter).Return(subscription) req, err := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts) - assert.NoError(s.T(), err) - respRecorder := NewHijackResponseRecorder() - // closing the connection after 5 seconds + require.NoError(s.T(), err) + respRecorder := newTestHijackResponseRecorder() + // closing the connection after 1 second go func() { - time.Sleep(5 * time.Second) + time.Sleep(1 * time.Second) close(respRecorder.closed) }() executeWsRequest(req, stateStreamBackend, respRecorder) - assert.NoError(s.T(), err) requireResponse(s.T(), respRecorder, expectedEventsResponses) }) } @@ -183,19 +180,18 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { s.Run("returns error for block id and height", func() { - stateStreamBackend := &mockstatestream.API{} + stateStreamBackend := mockstatestream.NewAPI(s.T()) req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil) - assert.NoError(s.T(), err) - respRecorder := NewHijackResponseRecorder() + require.NoError(s.T(), err) + respRecorder := newTestHijackResponseRecorder() executeWsRequest(req, stateStreamBackend, respRecorder) - assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "can only provide either block ID or start height") }) s.Run("returns error for invalid block id", func() { - stateStreamBackend := &mockstatestream.API{} + stateStreamBackend := mockstatestream.NewAPI(s.T()) invalidBlock := unittest.BlockFixture() - subscription := &mockstatestream.Subscription{} + subscription := mockstatestream.NewSubscription(s.T()) ch := make(chan interface{}) var chReadOnly <-chan interface{} @@ -209,16 +205,24 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), uint64(0), mocks.Anything).Return(subscription) req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) - assert.NoError(s.T(), err) - respRecorder := NewHijackResponseRecorder() + require.NoError(s.T(), err) + respRecorder := newTestHijackResponseRecorder() executeWsRequest(req, stateStreamBackend, respRecorder) - assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "stream encountered an error: subscription error") }) + s.Run("returns error for invalid event filter", func() { + stateStreamBackend := mockstatestream.NewAPI(s.T()) + req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, []string{"foo"}, nil, nil) + require.NoError(s.T(), err) + respRecorder := newTestHijackResponseRecorder() + executeWsRequest(req, stateStreamBackend, respRecorder) + requireError(s.T(), respRecorder, "invalid event type format") + }) + s.Run("returns error when channel closed", func() { - stateStreamBackend := &mockstatestream.API{} - subscription := &mockstatestream.Subscription{} + stateStreamBackend := mockstatestream.NewAPI(s.T()) + subscription := mockstatestream.NewSubscription(s.T()) ch := make(chan interface{}) var chReadOnly <-chan interface{} @@ -233,10 +237,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), uint64(0), mocks.Anything).Return(subscription) req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) - assert.NoError(s.T(), err) - respRecorder := NewHijackResponseRecorder() + require.NoError(s.T(), err) + respRecorder := newTestHijackResponseRecorder() executeWsRequest(req, stateStreamBackend, respRecorder) - assert.NoError(s.T(), err) requireError(s.T(), respRecorder, "subscription channel closed") }) } @@ -290,12 +293,12 @@ func generateWebSocketKey() (string, error) { return base64.StdEncoding.EncodeToString(keyBytes), nil } -func requireError(t *testing.T, recorder *HijackResponseRecorder, expected string) { +func requireError(t *testing.T, recorder *testHijackResponseRecorder, expected string) { <-recorder.closed require.Contains(t, recorder.responseBuff.String(), expected) } -func requireResponse(t *testing.T, recorder *HijackResponseRecorder, expected []*state_stream.EventsResponse) { +func requireResponse(t *testing.T, recorder *testHijackResponseRecorder, expected []*state_stream.EventsResponse) { <-recorder.closed // Convert the actual response from respRecorder to JSON bytes actualJSON := recorder.responseBuff.Bytes() @@ -313,14 +316,23 @@ func requireResponse(t *testing.T, recorder *HijackResponseRecorder, expected [] } // Compare the count of expected and actual responses - assert.Equal(t, len(expected), len(actual)) + require.Equal(t, len(expected), len(actual)) // Compare the BlockID and Events count for each response - for i := 0; i < len(expected); i++ { - expected := expected[i] - actual := actual[i] - - assert.Equal(t, expected.BlockID, actual.BlockID) - assert.Equal(t, len(expected.Events), len(actual.Events)) + for responseIndex := range expected { + expectedEventsResponse := expected[responseIndex] + actualEventsResponse := actual[responseIndex] + + require.Equal(t, expectedEventsResponse.BlockID, actualEventsResponse.BlockID) + require.Equal(t, len(expectedEventsResponse.Events), len(actualEventsResponse.Events)) + + for eventIndex, expectedEvent := range expectedEventsResponse.Events { + actualEvent := actualEventsResponse.Events[eventIndex] + require.Equal(t, expectedEvent.Type, actualEvent.Type) + require.Equal(t, expectedEvent.TransactionID, actualEvent.TransactionID) + require.Equal(t, expectedEvent.TransactionIndex, actualEvent.TransactionIndex) + require.Equal(t, expectedEvent.EventIndex, actualEvent.EventIndex) + require.Equal(t, expectedEvent.Payload, actualEvent.Payload) + } } } diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index be2113ddd28..3ad13996d74 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -78,9 +78,9 @@ func (a fakeAddr) String() string { return "str" } -// HijackResponseRecorder is a custom ResponseRecorder that implements the http.Hijacker interface +// testHijackResponseRecorder is a custom ResponseRecorder that implements the http.Hijacker interface // for testing WebSocket connections and hijacking. -type HijackResponseRecorder struct { +type testHijackResponseRecorder struct { *httptest.ResponseRecorder closed chan struct{} responseBuff *bytes.Buffer @@ -88,7 +88,7 @@ type HijackResponseRecorder struct { // Hijack implements the http.Hijacker interface by returning a fakeNetConn and a bufio.ReadWriter // that simulate a hijacked connection. -func (w *HijackResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (w *testHijackResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { br := bufio.NewReaderSize(strings.NewReader(""), state_stream.DefaultSendBufferSize) bw := bufio.NewWriterSize(&bytes.Buffer{}, state_stream.DefaultSendBufferSize) w.responseBuff = bytes.NewBuffer(make([]byte, 0)) @@ -97,9 +97,9 @@ func (w *HijackResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return fakeNetConn{w.responseBuff, w.closed}, bufio.NewReadWriter(br, bw), nil } -// NewHijackResponseRecorder creates a new instance of HijackResponseRecorder. -func NewHijackResponseRecorder() *HijackResponseRecorder { - return &HijackResponseRecorder{ +// newTestHijackResponseRecorder creates a new instance of testHijackResponseRecorder. +func newTestHijackResponseRecorder() *testHijackResponseRecorder { + return &testHijackResponseRecorder{ ResponseRecorder: httptest.NewRecorder(), } } @@ -118,11 +118,11 @@ func executeRequest(req *http.Request, backend access.API) *httptest.ResponseRec return rr } -func executeWsRequest(req *http.Request, stateStreamApi state_stream.API, responseRecorder *HijackResponseRecorder) { +func executeWsRequest(req *http.Request, stateStreamApi state_stream.API, responseRecorder *testHijackResponseRecorder) { restCollector := metrics.NewNoopCollector() router := NewRouterBuilder(unittest.Logger(), restCollector).AddWsRoutes( - flow.Testnet.Chain(), stateStreamApi, + flow.Testnet.Chain(), state_stream.DefaultEventFilterConfig, state_stream.DefaultMaxGlobalStreams, ).Build() diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 2e21b479c76..813f4da1955 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -37,7 +37,7 @@ type WebsocketController struct { eventFilterConfig state_stream.EventFilterConfig // the configuration for filtering events maxStreams int32 // the maximum number of streams allowed activeStreamCount *atomic.Int32 // the current number of active streams - readChannel chan struct{} // channel which notify closing connection by the client + readChannel chan error // channel which notify closing connection by the client and provide errors to the client } // SetWebsocketConf used to set read and write deadlines for WebSocket connections and establishes a Pong handler to @@ -113,12 +113,16 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti for { select { - case _, ok := <-wsController.readChannel: - // we use `readChannel` as indicator of client's status, when `readChannel` closes it means that client + case err, ok := <-wsController.readChannel: + // we use `readChannel` + // 1) as indicator of client's status, when `readChannel` closes it means that client // connection has been terminated and we need to stop this goroutine to avoid memory leak. - if !ok { - return + // 2) as error receiver for any errors that occur during the reading process + if ok { + wsController.wsErrorHandler(err) } + return + case event, ok := <-sub.Channel(): if !ok { if sub.Err() != nil { @@ -132,7 +136,7 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti } err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err != nil { - wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline error: ", err)) + wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err)) return } // Write the response to the WebSocket connection @@ -144,7 +148,7 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti case <-ticker.C: err := wsController.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err != nil { - wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline error: ", err)) + wsController.wsErrorHandler(models.NewRestError(http.StatusInternalServerError, "failed to set the initial write deadline: ", err)) return } if err := wsController.conn.WriteMessage(websocket.PingMessage, nil); err != nil { @@ -167,16 +171,22 @@ func (wsController *WebsocketController) read() { defer close(wsController.readChannel) // notify websocket about closed connection for { - // reads messages from the WebSocket connection when the connection is closed by client or when an - // 1) when the connection is closed by client - // 2) when an any additional message is received from the client + // reads messages from the WebSocket connection when + // 1) the connection is closed by client + // 2) a message is received from the client _, msg, err := wsController.conn.ReadMessage() if err != nil { + if _, ok := err.(*websocket.CloseError); !ok { + wsController.readChannel <- err + } return } // Check the message from the client, if is any just close the connection if len(msg) > 0 { + err := fmt.Errorf("the client sent an unexpected message, connection closed") + wsController.logger.Debug().Err(err) + wsController.readChannel <- err return } } @@ -206,8 +216,8 @@ var _ http.Handler = (*WSHandler)(nil) func NewWSHandler( logger zerolog.Logger, subscribeFunc SubscribeHandlerFunc, - chain flow.Chain, api state_stream.API, + chain flow.Chain, eventFilterConfig state_stream.EventFilterConfig, maxGlobalStreams uint32, ) *WSHandler { @@ -231,6 +241,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { err := h.VerifyRequest(w, r) if err != nil { + // VerifyRequest sets the response error before returning return } @@ -250,7 +261,7 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { eventFilterConfig: h.eventFilterConfig, maxStreams: h.maxStreams, activeStreamCount: h.activeStreamCount, - readChannel: make(chan struct{}), + readChannel: make(chan error), } err = wsController.SetWebsocketConf() @@ -267,8 +278,8 @@ func (h *WSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { wsController.activeStreamCount.Add(1) defer wsController.activeStreamCount.Add(-1) - // cancelling the context passed into the `subscribeFunc` to ensure when the client disconnect it's time the shutdown - // gorountines setup by the backend are cleaned up if the client disconnects first. + // cancelling the context passed into the `subscribeFunc` to ensure that when the client disconnects, + // gorountines setup by the backend are cleaned up. ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/engine/access/rest/server.go b/engine/access/rest/server.go index eb9c7ed239a..daa4983a517 100644 --- a/engine/access/rest/server.go +++ b/engine/access/rest/server.go @@ -44,7 +44,7 @@ func NewServer(serverAPI access.API, ) (*http.Server, error) { builder := routes.NewRouterBuilder(logger, restCollector).AddRestRoutes(serverAPI, chain) if stateStreamApi != nil { - builder.AddWsRoutes(chain, stateStreamApi, eventFilterConfig, maxGlobalStreams) + builder.AddWsRoutes(stateStreamApi, chain, eventFilterConfig, maxGlobalStreams) } c := cors.New(cors.Options{ diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index a2ef3441e3c..a89ce38f603 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -102,9 +102,6 @@ func (s *RestStateStreamSuite) SetupTest() { // TestRestEventStreaming tests event streaming route on REST func (s *RestStateStreamSuite) TestRestEventStreaming() { - ctx, cancel := context.WithTimeout(s.ctx, 1*time.Second) - defer cancel() - restAddr := s.net.ContainerByName(testnet.PrimaryAN).Addr(testnet.RESTPort) s.T().Run("subscribe events", func(t *testing.T) { @@ -112,7 +109,7 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { startHeight := uint64(0) url := getSubscribeEventsRequest(restAddr, startBlockId, startHeight, nil, nil, nil) - client, err := getWSClient(ctx, url) + client, err := getWSClient(s.ctx, url) require.NoError(t, err) var receivedEventsResponse []*state_stream.EventsResponse @@ -131,6 +128,7 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { s.T().Logf("unexpected close error: %v", err) + require.NoError(s.T(), err) } close(eventChan) // Close the event channel when the client connection is closed return @@ -152,6 +150,9 @@ func (s *RestStateStreamSuite) TestRestEventStreaming() { // requireEvents is a helper function that encapsulates logic for comparing received events from rest state streaming and // events which received from grpc api func (s *RestStateStreamSuite) requireEvents(receivedEventsResponse []*state_stream.EventsResponse) { + // make sure there are received events + require.GreaterOrEqual(s.T(), len(receivedEventsResponse), 1, "expect received events") + grpcCtx, grpcCancel := context.WithCancel(s.ctx) defer grpcCancel() From b20cb94af61213e15d26ddaa966c14d51738bf31 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 15 Sep 2023 01:02:05 +0300 Subject: [PATCH 34/35] Updated according to last comments --- engine/access/rest/routes/router.go | 4 ++ engine/access/rest/routes/subscribe_events.go | 3 +- .../rest/routes/subscribe_events_test.go | 44 ++++++++++++++++--- engine/access/rest/routes/test_helpers.go | 2 + 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/engine/access/rest/routes/router.go b/engine/access/rest/routes/router.go index 8df7bdd5bb3..bf10d451115 100644 --- a/engine/access/rest/routes/router.go +++ b/engine/access/rest/routes/router.go @@ -17,12 +17,14 @@ import ( "github.com/onflow/flow-go/module" ) +// RouterBuilder is a utility for building HTTP routers with common middleware and routes. type RouterBuilder struct { logger zerolog.Logger router *mux.Router v1SubRouter *mux.Router } +// NewRouterBuilder creates a new RouterBuilder instance with common middleware and a v1 sub-router. func NewRouterBuilder( logger zerolog.Logger, restCollector module.RestMetrics) *RouterBuilder { @@ -42,6 +44,7 @@ func NewRouterBuilder( } } +// AddRestRoutes adds rest routes to the router. func (b *RouterBuilder) AddRestRoutes(backend access.API, chain flow.Chain) *RouterBuilder { linkGenerator := models.NewLinkGeneratorImpl(b.v1SubRouter) for _, r := range Routes { @@ -55,6 +58,7 @@ func (b *RouterBuilder) AddRestRoutes(backend access.API, chain flow.Chain) *Rou return b } +// AddWsRoutes adds WebSocket routes to the router. func (b *RouterBuilder) AddWsRoutes( stateStreamApi state_stream.API, chain flow.Chain, diff --git a/engine/access/rest/routes/subscribe_events.go b/engine/access/rest/routes/subscribe_events.go index d092993b6f9..fb275a68df9 100644 --- a/engine/access/rest/routes/subscribe_events.go +++ b/engine/access/rest/routes/subscribe_events.go @@ -12,7 +12,8 @@ import ( func SubscribeEvents( ctx context.Context, request *request.Request, - wsController *WebsocketController) (state_stream.Subscription, error) { + wsController *WebsocketController, +) (state_stream.Subscription, error) { req, err := request.SubscribeEventsRequest() if err != nil { return nil, models.NewBadRequestError(err) diff --git a/engine/access/rest/routes/subscribe_events_test.go b/engine/access/rest/routes/subscribe_events_test.go index fc38e8bbd7a..c2e071cd8a4 100644 --- a/engine/access/rest/routes/subscribe_events_test.go +++ b/engine/access/rest/routes/subscribe_events_test.go @@ -76,6 +76,20 @@ func (s *SubscribeEventsSuite) SetupTest() { } } +// TestSubscribeEvents is a happy cases tests for the SubscribeEvents functionality. +// This test function covers various scenarios for subscribing to events via WebSocket. +// +// It tests scenarios: +// - Subscribing to events from the root height. +// - Subscribing to events from a specific start height. +// - Subscribing to events from a specific start block ID. +// +// Every scenario covers the following aspects: +// - Subscribing to all events. +// - Subscribing to events of a specific type (some events). +// +// For each scenario, this test function creates WebSocket requests, simulates WebSocket responses with mock data, +// and validates that the received WebSocket response matches the expected EventsResponses. func (s *SubscribeEventsSuite) TestSubscribeEvents() { testVectors := []testType{ { @@ -114,7 +128,12 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { stateStreamBackend := mockstatestream.NewAPI(s.T()) subscription := mockstatestream.NewSubscription(s.T()) - filter, err := state_stream.NewEventFilter(state_stream.DefaultEventFilterConfig, chain, test.eventTypes, test.addresses, test.contracts) + filter, err := state_stream.NewEventFilter( + state_stream.DefaultEventFilterConfig, + chain, + test.eventTypes, + test.addresses, + test.contracts) require.NoError(s.T(), err) var expectedEventsResponses []*state_stream.EventsResponse @@ -162,7 +181,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEvents() { } else { startHeight = test.startHeight } - stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, test.startBlockID, startHeight, filter).Return(subscription) + stateStreamBackend.Mock. + On("SubscribeEvents", mocks.Anything, test.startBlockID, startHeight, filter). + Return(subscription) req, err := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts) require.NoError(s.T(), err) @@ -202,7 +223,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { subscription.Mock.On("Channel").Return(chReadOnly) subscription.Mock.On("Err").Return(fmt.Errorf("subscription error")) - stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), uint64(0), mocks.Anything).Return(subscription) + stateStreamBackend.Mock. + On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), uint64(0), mocks.Anything). + Return(subscription) req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil) require.NoError(s.T(), err) @@ -234,7 +257,9 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { subscription.Mock.On("Channel").Return(chReadOnly) subscription.Mock.On("Err").Return(nil) - stateStreamBackend.Mock.On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), uint64(0), mocks.Anything).Return(subscription) + stateStreamBackend.Mock. + On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), uint64(0), mocks.Anything). + Return(subscription) req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil) require.NoError(s.T(), err) @@ -244,7 +269,13 @@ func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { }) } -func getSubscribeEventsRequest(t *testing.T, startBlockId flow.Identifier, startHeight uint64, eventTypes []string, addresses []string, contracts []string) (*http.Request, error) { +func getSubscribeEventsRequest(t *testing.T, + startBlockId flow.Identifier, + startHeight uint64, + eventTypes []string, + addresses []string, + contracts []string, +) (*http.Request, error) { u, _ := url.Parse("/v1/subscribe_events") q := u.Query() @@ -298,6 +329,9 @@ func requireError(t *testing.T, recorder *testHijackResponseRecorder, expected s require.Contains(t, recorder.responseBuff.String(), expected) } +// requireResponse validates that the response received from WebSocket communication matches the expected EventsResponses. +// This function compares the BlockID, Events count, and individual event properties for each expected and actual +// EventsResponse. It ensures that the response received from WebSocket matches the expected structure and content. func requireResponse(t *testing.T, recorder *testHijackResponseRecorder, expected []*state_stream.EventsResponse) { <-recorder.closed // Convert the actual response from respRecorder to JSON bytes diff --git a/engine/access/rest/routes/test_helpers.go b/engine/access/rest/routes/test_helpers.go index 3ad13996d74..6b967e45066 100644 --- a/engine/access/rest/routes/test_helpers.go +++ b/engine/access/rest/routes/test_helpers.go @@ -86,6 +86,8 @@ type testHijackResponseRecorder struct { responseBuff *bytes.Buffer } +var _ http.Hijacker = (*testHijackResponseRecorder)(nil) + // Hijack implements the http.Hijacker interface by returning a fakeNetConn and a bufio.ReadWriter // that simulate a hijacked connection. func (w *testHijackResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { From 303c02ff141ee840578785a7ff45d7a3a5d18926 Mon Sep 17 00:00:00 2001 From: UlyanaAndrukhiv Date: Fri, 15 Sep 2023 07:32:52 +0300 Subject: [PATCH 35/35] Added small fixes according to last comments --- engine/access/rest/routes/router.go | 2 +- engine/access/rest/routes/websocket_handler.go | 8 ++++---- integration/tests/access/rest_state_stream_test.go | 9 +++++++++ 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/engine/access/rest/routes/router.go b/engine/access/rest/routes/router.go index bf10d451115..c055cbf2ae3 100644 --- a/engine/access/rest/routes/router.go +++ b/engine/access/rest/routes/router.go @@ -66,7 +66,7 @@ func (b *RouterBuilder) AddWsRoutes( maxGlobalStreams uint32) *RouterBuilder { for _, r := range WSRoutes { - h := NewWSHandler(b.logger, r.Handler, stateStreamApi, chain, eventFilterConfig, maxGlobalStreams) + h := NewWSHandler(b.logger, stateStreamApi, r.Handler, chain, eventFilterConfig, maxGlobalStreams) b.v1SubRouter. Methods(r.Method). Path(r.Pattern). diff --git a/engine/access/rest/routes/websocket_handler.go b/engine/access/rest/routes/websocket_handler.go index 813f4da1955..5e09953e287 100644 --- a/engine/access/rest/routes/websocket_handler.go +++ b/engine/access/rest/routes/websocket_handler.go @@ -113,12 +113,12 @@ func (wsController *WebsocketController) writeEvents(sub state_stream.Subscripti for { select { - case err, ok := <-wsController.readChannel: + case err := <-wsController.readChannel: // we use `readChannel` // 1) as indicator of client's status, when `readChannel` closes it means that client // connection has been terminated and we need to stop this goroutine to avoid memory leak. // 2) as error receiver for any errors that occur during the reading process - if ok { + if err != nil { wsController.wsErrorHandler(err) } return @@ -185,7 +185,7 @@ func (wsController *WebsocketController) read() { // Check the message from the client, if is any just close the connection if len(msg) > 0 { err := fmt.Errorf("the client sent an unexpected message, connection closed") - wsController.logger.Debug().Err(err) + wsController.logger.Debug().Msg(err.Error()) wsController.readChannel <- err return } @@ -215,8 +215,8 @@ var _ http.Handler = (*WSHandler)(nil) func NewWSHandler( logger zerolog.Logger, - subscribeFunc SubscribeHandlerFunc, api state_stream.API, + subscribeFunc SubscribeHandlerFunc, chain flow.Chain, eventFilterConfig state_stream.EventFilterConfig, maxGlobalStreams uint32, diff --git a/integration/tests/access/rest_state_stream_test.go b/integration/tests/access/rest_state_stream_test.go index a89ce38f603..ae286b42a36 100644 --- a/integration/tests/access/rest_state_stream_test.go +++ b/integration/tests/access/rest_state_stream_test.go @@ -164,6 +164,8 @@ func (s *RestStateStreamSuite) requireEvents(receivedEventsResponse []*state_str grpcClient := accessproto.NewAccessAPIClient(grpcConn) + // Variable to keep track of non-empty event response count + nonEmptyResponseCount := 0 for _, receivedEventResponse := range receivedEventsResponse { // Create a map where key is EventType and value is list of events with this EventType receivedEventMap := make(map[flow.EventType][]flow.Event) @@ -188,8 +190,15 @@ func (s *RestStateStreamSuite) requireEvents(receivedEventsResponse []*state_str require.Equal(s.T(), expectedEventsResult.Events[i].EventIndex, event.EventIndex, "expect the same event index") require.Equal(s.T(), convert.MessageToIdentifier(expectedEventsResult.Events[i].TransactionId), event.TransactionID, "expect the same transaction id") } + + // Check if the current response has non-empty events + if len(receivedEventResponse.Events) > 0 { + nonEmptyResponseCount++ + } } } + // Ensure that at least one response had non-empty events + require.GreaterOrEqual(s.T(), nonEmptyResponseCount, 1, "expect at least one response with non-empty events") } // getWSClient is a helper function that creates a websocket client