Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CT-1321] subscribe to market prices streaming services #2592

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 102 additions & 1 deletion protocol/streaming/full_node_streaming_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/dydxprotocol/v4-chain/protocol/lib"
pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"

"cosmossdk.io/log"
Expand Down Expand Up @@ -52,6 +53,8 @@ type FullNodeStreamingManagerImpl struct {
clobPairIdToSubscriptionIdMapping map[uint32][]uint32
// map from subaccount id to subscription ids.
subaccountIdToSubscriptionIdMapping map[satypes.SubaccountId][]uint32
// map from market id to subscription ids.
marketIdToSubscriptionIdMapping map[uint32][]uint32

maxUpdatesInCache uint32
maxSubscriptionChannelSize uint32
Expand Down Expand Up @@ -79,6 +82,9 @@ type OrderbookSubscription struct {
// Subaccount ids to subscribe to.
subaccountIds []satypes.SubaccountId

// market ids to subscribe to.
marketIds []uint32

// Stream
messageSender types.OutgoingMessageSender

Expand Down Expand Up @@ -114,6 +120,7 @@ func NewFullNodeStreamingManager(
streamUpdateSubscriptionCache: make([][]uint32, 0),
clobPairIdToSubscriptionIdMapping: make(map[uint32][]uint32),
subaccountIdToSubscriptionIdMapping: make(map[satypes.SubaccountId][]uint32),
marketIdToSubscriptionIdMapping: make(map[uint32][]uint32),

maxUpdatesInCache: maxUpdatesInCache,
maxSubscriptionChannelSize: maxSubscriptionChannelSize,
Expand Down Expand Up @@ -184,6 +191,7 @@ func (sm *FullNodeStreamingManagerImpl) getNextAvailableSubscriptionId() uint32
func (sm *FullNodeStreamingManagerImpl) Subscribe(
clobPairIds []uint32,
subaccountIds []*satypes.SubaccountId,
marketIds []uint32,
messageSender types.OutgoingMessageSender,
) (
err error,
Expand All @@ -206,6 +214,7 @@ func (sm *FullNodeStreamingManagerImpl) Subscribe(
initialized: &atomic.Bool{}, // False by default.
clobPairIds: clobPairIds,
subaccountIds: sIds,
marketIds: marketIds,
messageSender: messageSender,
updatesChannel: make(chan []clobtypes.StreamUpdate, sm.maxSubscriptionChannelSize),
}
Expand All @@ -231,6 +240,17 @@ func (sm *FullNodeStreamingManagerImpl) Subscribe(
subscription.subscriptionId,
)
}
for _, marketId := range marketIds {
// if subaccountId exists in the map, append the subscription id to the slice
// otherwise, create a new slice with the subscription id
if _, ok := sm.marketIdToSubscriptionIdMapping[marketId]; !ok {
sm.marketIdToSubscriptionIdMapping[marketId] = []uint32{}
}
sm.marketIdToSubscriptionIdMapping[marketId] = append(
sm.marketIdToSubscriptionIdMapping[marketId],
subscription.subscriptionId,
)
}

sm.logger.Info(
fmt.Sprintf(
Expand Down Expand Up @@ -325,6 +345,21 @@ func (sm *FullNodeStreamingManagerImpl) removeSubscription(
}
}

// Iterate over the marketIdToSubscriptionIdMapping to remove the subscriptionIdToRemove
for marketId, subscriptionIds := range sm.marketIdToSubscriptionIdMapping {
for i, id := range subscriptionIds {
if id == subscriptionIdToRemove {
// Remove the subscription ID from the slice
sm.marketIdToSubscriptionIdMapping[marketId] = append(subscriptionIds[:i], subscriptionIds[i+1:]...)
break
}
}
// If the list is empty after removal, delete the key from the map
if len(sm.marketIdToSubscriptionIdMapping[marketId]) == 0 {
delete(sm.marketIdToSubscriptionIdMapping, marketId)
}
}

sm.logger.Info(
fmt.Sprintf("Removed streaming subscription id %+v", subscriptionIdToRemove),
)
Expand Down Expand Up @@ -372,6 +407,24 @@ func toSubaccountStreamUpdates(
return streamUpdates
}

func toPriceStreamUpdates(
priceUpdates []*pricestypes.StreamPriceUpdate,
blockHeight uint32,
execMode sdk.ExecMode,
) []clobtypes.StreamUpdate {
streamUpdates := make([]clobtypes.StreamUpdate, 0)
for _, update := range priceUpdates {
streamUpdates = append(streamUpdates, clobtypes.StreamUpdate{
UpdateMessage: &clobtypes.StreamUpdate_PriceUpdate{
PriceUpdate: update,
},
BlockHeight: blockHeight,
ExecMode: uint32(execMode),
})
}
return streamUpdates
}

func (sm *FullNodeStreamingManagerImpl) sendStreamUpdates(
subscriptionId uint32,
streamUpdates []clobtypes.StreamUpdate,
Expand Down Expand Up @@ -466,6 +519,7 @@ func (sm *FullNodeStreamingManagerImpl) GetStagedFinalizeBlockEvents(
func (sm *FullNodeStreamingManagerImpl) SendCombinedSnapshot(
offchainUpdates *clobtypes.OffchainUpdates,
saUpdates []*satypes.StreamSubaccountUpdate,
priceUpdates []*pricestypes.StreamPriceUpdate,
subscriptionId uint32,
blockHeight uint32,
execMode sdk.ExecMode,
Expand All @@ -479,6 +533,7 @@ func (sm *FullNodeStreamingManagerImpl) SendCombinedSnapshot(
var streamUpdates []clobtypes.StreamUpdate
streamUpdates = append(streamUpdates, toOrderbookStreamUpdate(offchainUpdates, blockHeight, execMode)...)
streamUpdates = append(streamUpdates, toSubaccountStreamUpdates(saUpdates, blockHeight, execMode)...)
streamUpdates = append(streamUpdates, toPriceStreamUpdates(priceUpdates, blockHeight, execMode)...)
sm.sendStreamUpdates(subscriptionId, streamUpdates)
}

Expand Down Expand Up @@ -863,6 +918,30 @@ func (sm *FullNodeStreamingManagerImpl) GetSubaccountSnapshotsForInitStreams(
return ret
}

func (sm *FullNodeStreamingManagerImpl) GetPriceSnapshotsForInitStreams(
getPriceSnapshot func(marketId uint32) *pricestypes.StreamPriceUpdate,
) map[uint32]*pricestypes.StreamPriceUpdate {
sm.Lock()
defer sm.Unlock()

ret := make(map[uint32]*pricestypes.StreamPriceUpdate)
for _, subscription := range sm.orderbookSubscriptions {
// If the subscription has been initialized, no need to grab the price snapshot.
if alreadyInitialized := subscription.initialized.Load(); alreadyInitialized {
continue
}

for _, marketId := range subscription.marketIds {
if _, exists := ret[marketId]; exists {
continue
}

ret[marketId] = getPriceSnapshot(marketId)
}
}
return ret
}

// cacheStreamUpdatesByClobPairWithLock adds stream updates to cache,
// and store corresponding clob pair Ids.
// This method requires the lock and assumes that the lock has already been
Expand Down Expand Up @@ -1003,6 +1082,7 @@ func (sm *FullNodeStreamingManagerImpl) getStagedEventsFromFinalizeBlock(
func (sm *FullNodeStreamingManagerImpl) InitializeNewStreams(
getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates,
subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate,
pricesSnapshots map[uint32]*pricestypes.StreamPriceUpdate,
blockHeight uint32,
execMode sdk.ExecMode,
) {
Expand Down Expand Up @@ -1038,7 +1118,28 @@ func (sm *FullNodeStreamingManagerImpl) InitializeNewStreams(
}
}

sm.SendCombinedSnapshot(allUpdates, saUpdates, subscriptionId, blockHeight, execMode)
priceUpdates := []*pricestypes.StreamPriceUpdate{}
for _, marketId := range subscription.marketIds {
if priceUpdate, ok := pricesSnapshots[marketId]; ok {
priceUpdates = append(priceUpdates, priceUpdate)
} else {
sm.logger.Error(
fmt.Sprintf(
"Price update not found for market id %v. This should not happen.",
marketId,
),
)
}
}

Comment on lines +1124 to +1134
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Handle missing price updates more gracefully

When a price update is not found for a market ID, the code logs an error stating "This should not happen." Consider handling this scenario more gracefully by implementing appropriate error handling or fallback logic to ensure robustness.

sm.SendCombinedSnapshot(
allUpdates,
saUpdates,
priceUpdates,
subscriptionId,
blockHeight,
execMode,
)

if sm.snapshotBlockInterval != 0 {
subscription.nextSnapshotBlock = blockHeight + sm.snapshotBlockInterval
Expand Down
9 changes: 9 additions & 0 deletions protocol/streaming/noop_streaming_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/streaming/types"
clobtypes "github.com/dydxprotocol/v4-chain/protocol/x/clob/types"
pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
)

Expand All @@ -22,6 +23,7 @@ func (sm *NoopGrpcStreamingManager) Enabled() bool {
func (sm *NoopGrpcStreamingManager) Subscribe(
_ []uint32,
_ []*satypes.SubaccountId,
_ []uint32,
_ types.OutgoingMessageSender,
) (
err error,
Expand Down Expand Up @@ -58,9 +60,16 @@ func (sm *NoopGrpcStreamingManager) GetSubaccountSnapshotsForInitStreams(
return nil
}

func (sm *NoopGrpcStreamingManager) GetPriceSnapshotsForInitStreams(
_ func(_ uint32) *pricestypes.StreamPriceUpdate,
) map[uint32]*pricestypes.StreamPriceUpdate {
return nil
}

func (sm *NoopGrpcStreamingManager) InitializeNewStreams(
getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates,
subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate,
priceSnapshots map[uint32]*pricestypes.StreamPriceUpdate,
blockHeight uint32,
execMode sdk.ExecMode,
) {
Expand Down
6 changes: 6 additions & 0 deletions protocol/streaming/types/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package types
import (
sdk "github.com/cosmos/cosmos-sdk/types"
clobtypes "github.com/dydxprotocol/v4-chain/protocol/x/clob/types"
pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types"
)

Expand All @@ -14,6 +15,7 @@ type FullNodeStreamingManager interface {
Subscribe(
clobPairIds []uint32,
subaccountIds []*satypes.SubaccountId,
marketIds []uint32,
srv OutgoingMessageSender,
) (
err error,
Expand All @@ -23,12 +25,16 @@ type FullNodeStreamingManager interface {
InitializeNewStreams(
getOrderbookSnapshot func(clobPairId clobtypes.ClobPairId) *clobtypes.OffchainUpdates,
subaccountSnapshots map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate,
priceSnapshots map[uint32]*pricestypes.StreamPriceUpdate,
blockHeight uint32,
execMode sdk.ExecMode,
)
GetSubaccountSnapshotsForInitStreams(
getSubaccountSnapshot func(subaccountId satypes.SubaccountId) *satypes.StreamSubaccountUpdate,
) map[satypes.SubaccountId]*satypes.StreamSubaccountUpdate
GetPriceSnapshotsForInitStreams(
getPriceSnapshot func(marketId uint32) *pricestypes.StreamPriceUpdate,
) map[uint32]*pricestypes.StreamPriceUpdate
SendOrderbookUpdates(
offchainUpdates *clobtypes.OffchainUpdates,
ctx sdk.Context,
Expand Down
40 changes: 29 additions & 11 deletions protocol/streaming/ws/websocket_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ import (
"github.com/gorilla/websocket"
)

const (
CLOB_PAIR_IDS_QUERY_PARAM = "clobPairIds"
MARKET_IDS_QUERY_PARAM = "marketIds"
)

var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
Expand Down Expand Up @@ -61,7 +66,7 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) {
conn.SetReadLimit(10 * 1024 * 1024)

// Parse clobPairIds from query parameters
clobPairIds, err := parseClobPairIds(r)
clobPairIds, err := parseUint32(r, CLOB_PAIR_IDS_QUERY_PARAM)
if err != nil {
ws.logger.Error(
"Error parsing clobPairIds",
Expand All @@ -70,6 +75,18 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// Parse marketIds from query parameters
marketIds, err := parseUint32(r, MARKET_IDS_QUERY_PARAM)
if err != nil {
ws.logger.Error(
"Error parsing marketIds",
"err", err,
)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// Parse subaccountIds from query parameters
subaccountIds, err := parseSubaccountIds(r)
if err != nil {
Expand All @@ -93,6 +110,7 @@ func (ws *WebsocketServer) Handler(w http.ResponseWriter, r *http.Request) {
err = ws.streamingManager.Subscribe(
clobPairIds,
subaccountIds,
marketIds,
websocketMessageSender,
)
if err != nil {
Expand Down Expand Up @@ -136,26 +154,26 @@ func parseSubaccountIds(r *http.Request) ([]*satypes.SubaccountId, error) {
return subaccountIds, nil
}

// parseClobPairIds is a helper function to parse the clobPairIds from the query parameters.
func parseClobPairIds(r *http.Request) ([]uint32, error) {
clobPairIdsParam := r.URL.Query().Get("clobPairIds")
if clobPairIdsParam == "" {
// parseUint32 is a helper function to parse the uint32 from the query parameters.
func parseUint32(r *http.Request, queryParam string) ([]uint32, error) {
param := r.URL.Query().Get(queryParam)
if param == "" {
return []uint32{}, nil
}
idStrs := strings.Split(clobPairIdsParam, ",")
clobPairIds := make([]uint32, 0)
idStrs := strings.Split(param, ",")
ids := make([]uint32, 0)
for _, idStr := range idStrs {
id, err := strconv.Atoi(idStr)
if err != nil {
return nil, fmt.Errorf("invalid clobPairId: %s", idStr)
return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
}
if id < 0 || id > math.MaxInt32 {
return nil, fmt.Errorf("invalid clob pair id: %s", idStr)
return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
}
clobPairIds = append(clobPairIds, uint32(id))
ids = append(ids, uint32(id))
}

return clobPairIds, nil
return ids, nil
Comment on lines +157 to +176
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider additional input validation in parseUint32.

While the function handles basic parsing and range validation, consider adding the following safeguards:

  1. Duplicate ID validation
  2. Maximum input length check
  3. Empty string element validation after split
  4. Explicit overflow handling
 func parseUint32(r *http.Request, queryParam string) ([]uint32, error) {
 	param := r.URL.Query().Get(queryParam)
 	if param == "" {
 		return []uint32{}, nil
 	}
+	// Add maximum length check
+	if len(param) > 1000 {
+		return nil, fmt.Errorf("%s parameter too long: max 1000 characters", queryParam)
+	}
+
 	idStrs := strings.Split(param, ",")
+	// Add maximum number of IDs check
+	if len(idStrs) > 100 {
+		return nil, fmt.Errorf("too many %s values: max 100", queryParam)
+	}
+
 	ids := make([]uint32, 0)
+	seen := make(map[uint32]bool)
 	for _, idStr := range idStrs {
+		// Check for empty elements
+		if idStr == "" {
+			return nil, fmt.Errorf("empty value in %s list", queryParam)
+		}
+
 		id, err := strconv.Atoi(idStr)
 		if err != nil {
 			return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
 		}
 		if id < 0 || id > math.MaxInt32 {
 			return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
 		}
-		ids = append(ids, uint32(id))
+		
+		// Check for duplicates
+		uid := uint32(id)
+		if seen[uid] {
+			return nil, fmt.Errorf("duplicate %s: %d", queryParam, id)
+		}
+		seen[uid] = true
+		ids = append(ids, uid)
 	}
 	return ids, nil
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// parseUint32 is a helper function to parse the uint32 from the query parameters.
func parseUint32(r *http.Request, queryParam string) ([]uint32, error) {
param := r.URL.Query().Get(queryParam)
if param == "" {
return []uint32{}, nil
}
idStrs := strings.Split(clobPairIdsParam, ",")
clobPairIds := make([]uint32, 0)
idStrs := strings.Split(param, ",")
ids := make([]uint32, 0)
for _, idStr := range idStrs {
id, err := strconv.Atoi(idStr)
if err != nil {
return nil, fmt.Errorf("invalid clobPairId: %s", idStr)
return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
}
if id < 0 || id > math.MaxInt32 {
return nil, fmt.Errorf("invalid clob pair id: %s", idStr)
return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
}
clobPairIds = append(clobPairIds, uint32(id))
ids = append(ids, uint32(id))
}
return clobPairIds, nil
return ids, nil
// parseUint32 is a helper function to parse the uint32 from the query parameters.
func parseUint32(r *http.Request, queryParam string) ([]uint32, error) {
param := r.URL.Query().Get(queryParam)
if param == "" {
return []uint32{}, nil
}
// Add maximum length check
if len(param) > 1000 {
return nil, fmt.Errorf("%s parameter too long: max 1000 characters", queryParam)
}
idStrs := strings.Split(param, ",")
// Add maximum number of IDs check
if len(idStrs) > 100 {
return nil, fmt.Errorf("too many %s values: max 100", queryParam)
}
ids := make([]uint32, 0)
seen := make(map[uint32]bool)
for _, idStr := range idStrs {
// Check for empty elements
if idStr == "" {
return nil, fmt.Errorf("empty value in %s list", queryParam)
}
id, err := strconv.Atoi(idStr)
if err != nil {
return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
}
if id < 0 || id > math.MaxInt32 {
return nil, fmt.Errorf("invalid %s: %s", queryParam, idStr)
}
// Check for duplicates
uid := uint32(id)
if seen[uid] {
return nil, fmt.Errorf("duplicate %s: %d", queryParam, id)
}
seen[uid] = true
ids = append(ids, uid)
}
return ids, nil
}

}

// Start the websocket server in a separate goroutine.
Expand Down
1 change: 1 addition & 0 deletions protocol/x/clob/keeper/grpc_stream_orderbook.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ func (k Keeper) StreamOrderbookUpdates(
err := k.GetFullNodeStreamingManager().Subscribe(
req.GetClobPairId(),
req.GetSubaccountIds(),
req.GetMarketIds(),
stream,
)
if err != nil {
Expand Down
Loading
Loading