diff --git a/protocol/x/clob/types/errors.go b/protocol/x/clob/types/errors.go index 779d95fbf6..792f0a9451 100644 --- a/protocol/x/clob/types/errors.go +++ b/protocol/x/clob/types/errors.go @@ -209,7 +209,7 @@ var ( ErrInvalidBatchCancel = errorsmod.Register( ModuleName, 45, - "Invalid Batch Cancel", + "Invalid batch cancel message", ) // Liquidations errors. diff --git a/protocol/x/clob/types/message_batch_cancel.go b/protocol/x/clob/types/message_batch_cancel.go index 39d266ef23..dacd5b3cac 100644 --- a/protocol/x/clob/types/message_batch_cancel.go +++ b/protocol/x/clob/types/message_batch_cancel.go @@ -26,9 +26,22 @@ func (msg *MsgBatchCancel) ValidateBasic() (err error) { } cancelBatches := msg.GetShortTermCancels() + if len(cancelBatches) == 0 { + return errorsmod.Wrapf( + ErrInvalidBatchCancel, + "Batch cancel cannot have zero orders specified.", + ) + } totalNumberCancels := 0 for _, cancelBatch := range cancelBatches { - totalNumberCancels += len(cancelBatch.GetClientIds()) + numClientIds := len(cancelBatch.GetClientIds()) + if numClientIds == 0 { + return errorsmod.Wrapf( + ErrInvalidBatchCancel, + "Order Batch cannot have zero client ids.", + ) + } + totalNumberCancels += numClientIds seenClientIds := map[uint32]struct{}{} for _, clientId := range cancelBatch.GetClientIds() { if _, seen := seenClientIds[clientId]; seen { diff --git a/protocol/x/clob/types/message_batch_cancel_test.go b/protocol/x/clob/types/message_batch_cancel_test.go index acf3afe404..504ee7c992 100644 --- a/protocol/x/clob/types/message_batch_cancel_test.go +++ b/protocol/x/clob/types/message_batch_cancel_test.go @@ -53,11 +53,11 @@ func TestMsgBatchCancel_ValidateBasic(t *testing.T) { []types.OrderBatch{ { ClobPairId: 0, - ClientIds: oneOverMax[:52], + ClientIds: oneOverMax[:types.MaxMsgBatchCancelBatchSize/2+2], }, { ClobPairId: 1, - ClientIds: oneOverMax[:52], + ClientIds: oneOverMax[:types.MaxMsgBatchCancelBatchSize/2+2], }, }, 10, @@ -70,11 +70,11 @@ func TestMsgBatchCancel_ValidateBasic(t *testing.T) { []types.OrderBatch{ { ClobPairId: 0, - ClientIds: oneOverMax[:50], + ClientIds: oneOverMax[:types.MaxMsgBatchCancelBatchSize/2], }, { ClobPairId: 1, - ClientIds: oneOverMax[:50], + ClientIds: oneOverMax[:types.MaxMsgBatchCancelBatchSize/2], }, }, 10, @@ -109,6 +109,27 @@ func TestMsgBatchCancel_ValidateBasic(t *testing.T) { ), err: types.ErrInvalidBatchCancel, }, + "zero batches in cancel batch": { + msg: *types.NewMsgBatchCancel( + constants.Alice_Num0, + []types.OrderBatch{}, + 10, + ), + err: types.ErrInvalidBatchCancel, + }, + "zero client ids in cancel batch": { + msg: *types.NewMsgBatchCancel( + constants.Alice_Num0, + []types.OrderBatch{ + { + ClobPairId: 0, + ClientIds: []uint32{}, + }, + }, + 10, + ), + err: types.ErrInvalidBatchCancel, + }, } for name, tc := range tests { t.Run(name, func(t *testing.T) {