diff --git a/protocol/x/subaccounts/keeper/subaccount.go b/protocol/x/subaccounts/keeper/subaccount.go index a67c89d9e6a..505605f2120 100644 --- a/protocol/x/subaccounts/keeper/subaccount.go +++ b/protocol/x/subaccounts/keeper/subaccount.go @@ -432,32 +432,6 @@ func (k Keeper) CanUpdateSubaccounts( return success, successPerUpdate, err } -func checkPositionUpdatable( - ctx sdk.Context, - pk types.ProductKeeper, - p types.PositionSize, -) ( - err error, -) { - updatable, err := pk.IsPositionUpdatable( - ctx, - p.GetId(), - ) - if err != nil { - return err - } - - if !updatable { - return errorsmod.Wrapf( - types.ErrProductPositionNotUpdatable, - "type: %v, id: %d", - p.GetProductType(), - p.GetId(), - ) - } - return nil -} - // internalCanUpdateSubaccounts will validate all `updates` to the relevant subaccounts and compute // if any of the updates led to an isolated perpetual position being opened or closed. // The `updates` do not have to contain `Subaccounts` with unique `SubaccountIds`. @@ -587,18 +561,32 @@ func (k Keeper) internalCanUpdateSubaccounts( for i, u := range settledUpdates { // Check all updated perps are updatable. for _, perpUpdate := range u.PerpetualUpdates { - err := checkPositionUpdatable(ctx, k.perpetualsKeeper, perpUpdate) + updatable, err := k.perpetualsKeeper.IsPositionUpdatable(ctx, perpUpdate.GetId()) if err != nil { return false, nil, err } + if !updatable { + return false, nil, errorsmod.Wrapf( + types.ErrProductPositionNotUpdatable, + "type: PerpetualPosition, id: %d", + perpUpdate.GetId(), + ) + } } // Check all updated assets are updatable. for _, assetUpdate := range u.AssetUpdates { - err := checkPositionUpdatable(ctx, k.assetsKeeper, assetUpdate) + updatable, err := k.assetsKeeper.IsPositionUpdatable(ctx, assetUpdate.GetId()) if err != nil { return false, nil, err } + if !updatable { + return false, nil, errorsmod.Wrapf( + types.ErrProductPositionNotUpdatable, + "type: AssetPosition, id: %d", + assetUpdate.GetId(), + ) + } } // Get the new collateralization and margin requirements with the update applied. diff --git a/protocol/x/subaccounts/lib/updates.go b/protocol/x/subaccounts/lib/updates.go index 5185bea237e..e7afcab07ea 100644 --- a/protocol/x/subaccounts/lib/updates.go +++ b/protocol/x/subaccounts/lib/updates.go @@ -1,11 +1,9 @@ package lib import ( - "fmt" "math/big" "sort" - errorsmod "cosmossdk.io/errors" "github.com/dydxprotocol/v4-chain/protocol/dtypes" "github.com/dydxprotocol/v4-chain/protocol/lib" "github.com/dydxprotocol/v4-chain/protocol/lib/margin" @@ -161,19 +159,17 @@ func ApplyUpdatesToPositions[ P types.PositionSize, U types.PositionSize, ](positions []P, updates []U) ([]types.PositionSize, error) { - var result []types.PositionSize = make([]types.PositionSize, 0, len(positions)+len(updates)) + updateIds := lib.MapSlice(updates, func(u U) uint32 { return u.GetId() }) + if lib.ContainsDuplicates(updateIds) { + return nil, types.ErrNonUniqueUpdatesPosition + } + var result []types.PositionSize = make([]types.PositionSize, 0, len(positions)+len(updates)) updateMap := make(map[uint32]types.PositionSize, len(updates)) updateIndexMap := make(map[uint32]int, len(updates)) for i, update := range updates { // Check for non-unique updates (two updates to the same position). id := update.GetId() - _, exists := updateMap[id] - if exists { - errMsg := fmt.Sprintf("Multiple updates exist for position %v", update.GetId()) - return nil, errorsmod.Wrap(types.ErrNonUniqueUpdatesPosition, errMsg) - } - updateMap[id] = update updateIndexMap[id] = i result = append(result, update) diff --git a/protocol/x/subaccounts/types/position_size.go b/protocol/x/subaccounts/types/position_size.go index 39688ae2eb6..80ef9d38b4a 100644 --- a/protocol/x/subaccounts/types/position_size.go +++ b/protocol/x/subaccounts/types/position_size.go @@ -8,12 +8,6 @@ import ( "github.com/dydxprotocol/v4-chain/protocol/dtypes" ) -const ( - AssetProductType = "asset" - PerpetualProductType = "perpetual" - UnknownProductTYpe = "unknown" -) - // PositionSize is an interface for expressing the size of a position type PositionSize interface { // Returns true if and only if the position size is positive. @@ -21,7 +15,6 @@ type PositionSize interface { // Returns the signed position size in big.Int. GetBigQuantums() *big.Int GetId() uint32 - GetProductType() string } type PositionUpdate struct { @@ -74,10 +67,6 @@ func (m *AssetPosition) GetIsLong() bool { return m.GetBigQuantums().Sign() > 0 } -func (m *AssetPosition) GetProductType() string { - return AssetProductType -} - func (m *PerpetualPosition) GetId() uint32 { return m.GetPerpetualId() } @@ -110,10 +99,6 @@ func (m *PerpetualPosition) GetIsLong() bool { return m.GetBigQuantums().Sign() > 0 } -func (m *PerpetualPosition) GetProductType() string { - return PerpetualProductType -} - func (au AssetUpdate) GetIsLong() bool { return au.GetBigQuantums().Sign() > 0 } @@ -126,10 +111,6 @@ func (au AssetUpdate) GetId() uint32 { return au.AssetId } -func (au AssetUpdate) GetProductType() string { - return AssetProductType -} - func (pu PerpetualUpdate) GetBigQuantums() *big.Int { return pu.BigQuantumsDelta } @@ -142,10 +123,6 @@ func (pu PerpetualUpdate) GetIsLong() bool { return pu.GetBigQuantums().Sign() > 0 } -func (pu PerpetualUpdate) GetProductType() string { - return PerpetualProductType -} - func (pu PositionUpdate) GetId() uint32 { return pu.Id } @@ -161,7 +138,3 @@ func (pu PositionUpdate) SetBigQuantums(bigQuantums *big.Int) { func (pu PositionUpdate) GetBigQuantums() *big.Int { return pu.BigQuantums } -func (pu PositionUpdate) GetProductType() string { - // PositionUpdate is generic and doesn't have a product type. - return UnknownProductTYpe -}