From a326931d51b100c202cf86b319ea438130b4e210 Mon Sep 17 00:00:00 2001 From: affan Date: Fri, 6 Sep 2024 09:50:23 -0400 Subject: [PATCH] Check affiliate tiers are strictly increasing before updating tiers --- protocol/x/affiliates/keeper/keeper.go | 21 +++++-- protocol/x/affiliates/keeper/keeper_test.go | 69 +++++++++++++++++---- 2 files changed, 73 insertions(+), 17 deletions(-) diff --git a/protocol/x/affiliates/keeper/keeper.go b/protocol/x/affiliates/keeper/keeper.go index 24559c325a1..c00f8472f63 100644 --- a/protocol/x/affiliates/keeper/keeper.go +++ b/protocol/x/affiliates/keeper/keeper.go @@ -218,9 +218,22 @@ func (k Keeper) GetTierForAffiliate( // UpdateAffiliateTiers updates the affiliate tiers. // Used primarily through governance. -func (k Keeper) UpdateAffiliateTiers(ctx sdk.Context, affiliateTiers types.AffiliateTiers) { +func (k Keeper) UpdateAffiliateTiers(ctx sdk.Context, affiliateTiers types.AffiliateTiers) error { store := ctx.KVStore(k.storeKey) - // TODO(OTE-779): Check strictly increasing volume and - // staking requirements hold in UpdateAffiliateTiers - store.Set([]byte(types.AffiliateTiersKey), k.cdc.MustMarshal(&affiliateTiers)) + affiliateTiersBytes, err := k.cdc.Marshal(&affiliateTiers) + if err != nil { + return errorsmod.Wrapf(types.ErrInvalidAffiliateTiers, + "error marshalling affiliate tiers: %s", err) + } + tiers := affiliateTiers.GetTiers() + // start at 1, since 0 is the default tier. + for i := 1; i < len(tiers); i++ { + if tiers[i].ReqReferredVolumeQuoteQuantums <= tiers[i-1].ReqReferredVolumeQuoteQuantums || + tiers[i].ReqStakedWholeCoins <= tiers[i-1].ReqStakedWholeCoins { + return errorsmod.Wrapf(types.ErrInvalidAffiliateTiers, + "tiers values must be strictly increasing") + } + } + store.Set([]byte(types.AffiliateTiersKey), affiliateTiersBytes) + return nil } diff --git a/protocol/x/affiliates/keeper/keeper_test.go b/protocol/x/affiliates/keeper/keeper_test.go index 421f62cf55d..61d792982a1 100644 --- a/protocol/x/affiliates/keeper/keeper_test.go +++ b/protocol/x/affiliates/keeper/keeper_test.go @@ -145,10 +145,11 @@ func TestGetTakerFeeShareViaReferredVolume(t *testing.T) { k := tApp.App.AffiliatesKeeper // Set up affiliate tiers affiliateTiers := types.DefaultAffiliateTiers - k.UpdateAffiliateTiers(ctx, affiliateTiers) + err := k.UpdateAffiliateTiers(ctx, affiliateTiers) + require.NoError(t, err) stakingKeeper := tApp.App.StakingKeeper - err := stakingKeeper.SetDelegation(ctx, + err = stakingKeeper.SetDelegation(ctx, stakingtypes.NewDelegation(constants.AliceAccAddress.String(), constants.AliceValAddress.String(), math.LegacyNewDecFromBigInt( new(big.Int).Mul( @@ -194,13 +195,14 @@ func TestGetTakerFeeShareViaStakedAmount(t *testing.T) { ctx = ctx.WithBlockTime(time.Now()) // Set up affiliate tiers affiliateTiers := types.DefaultAffiliateTiers - k.UpdateAffiliateTiers(ctx, affiliateTiers) + err := k.UpdateAffiliateTiers(ctx, affiliateTiers) + require.NoError(t, err) // Register affiliate and referee affiliate := constants.AliceAccAddress.String() referee := constants.BobAccAddress.String() stakingKeeper := tApp.App.StakingKeeper - err := stakingKeeper.SetDelegation(ctx, + err = stakingKeeper.SetDelegation(ctx, stakingtypes.NewDelegation(constants.AliceAccAddress.String(), constants.AliceValAddress.String(), math.LegacyNewDecFromBigInt( new(big.Int).Mul( @@ -248,12 +250,13 @@ func TestGetTierForAffiliate_VolumeAndStake(t *testing.T) { k := tApp.App.AffiliatesKeeper affiliateTiers := types.DefaultAffiliateTiers - k.UpdateAffiliateTiers(ctx, affiliateTiers) + err := k.UpdateAffiliateTiers(ctx, affiliateTiers) + require.NoError(t, err) affiliate := constants.AliceAccAddress.String() referee := constants.BobAccAddress.String() stakingKeeper := tApp.App.StakingKeeper - err := stakingKeeper.SetDelegation(ctx, + err = stakingKeeper.SetDelegation(ctx, stakingtypes.NewDelegation(constants.AliceAccAddress.String(), constants.AliceValAddress.String(), math.LegacyNewDecFromBigInt( new(big.Int).Mul( @@ -296,12 +299,52 @@ func TestUpdateAffiliateTiers(t *testing.T) { ctx := tApp.InitChain() k := tApp.App.AffiliatesKeeper - // Set up valid affiliate tiers - validTiers := types.DefaultAffiliateTiers - k.UpdateAffiliateTiers(ctx, validTiers) + tests := []struct { + name string + affiliateTiers types.AffiliateTiers + expectedError error + }{ + { + name: "Valid tiers", + affiliateTiers: types.DefaultAffiliateTiers, + expectedError: nil, + }, + { + name: "Invalid tiers - decreasing volume requirement", + affiliateTiers: types.AffiliateTiers{ + Tiers: []types.AffiliateTiers_Tier{ + {ReqReferredVolumeQuoteQuantums: 1000, ReqStakedWholeCoins: 100, TakerFeeSharePpm: 100}, + {ReqReferredVolumeQuoteQuantums: 500, ReqStakedWholeCoins: 200, TakerFeeSharePpm: 200}, + }, + }, + expectedError: types.ErrInvalidAffiliateTiers, + }, + { + name: "Invalid tiers - decreasing staking requirement", + affiliateTiers: types.AffiliateTiers{ + Tiers: []types.AffiliateTiers_Tier{ + {ReqReferredVolumeQuoteQuantums: 1000, ReqStakedWholeCoins: 200, TakerFeeSharePpm: 100}, + {ReqReferredVolumeQuoteQuantums: 2000, ReqStakedWholeCoins: 100, TakerFeeSharePpm: 200}, + }, + }, + expectedError: types.ErrInvalidAffiliateTiers, + }, + } - // Retrieve and validate updated tiers - updatedTiers, err := k.GetAllAffiliateTiers(ctx) - require.NoError(t, err) - require.Equal(t, validTiers, updatedTiers) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := k.UpdateAffiliateTiers(ctx, tc.affiliateTiers) + + if tc.expectedError != nil { + require.ErrorIs(t, err, tc.expectedError) + } else { + require.NoError(t, err) + + // Retrieve and validate updated tiers + updatedTiers, err := k.GetAllAffiliateTiers(ctx) + require.NoError(t, err) + require.Equal(t, tc.affiliateTiers, updatedTiers) + } + }) + } }