Skip to content

Commit

Permalink
chore: add wrapped msg server for x/staking that queues or calls the …
Browse files Browse the repository at this point in the history
…unwrap msg delegate
  • Loading branch information
RafilxTenfen committed Sep 20, 2024
1 parent dec7fa8 commit 7619a51
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 30 deletions.
6 changes: 4 additions & 2 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ import (
paramstypes "github.com/cosmos/cosmos-sdk/x/params/types"
"github.com/cosmos/cosmos-sdk/x/slashing"
slashingtypes "github.com/cosmos/cosmos-sdk/x/slashing/types"
"github.com/cosmos/cosmos-sdk/x/staking"

stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
"github.com/cosmos/gogoproto/proto"
"github.com/cosmos/ibc-go/modules/capability"
Expand All @@ -96,6 +96,8 @@ import (
"github.com/babylonlabs-io/babylon/app/upgrades"
bbn "github.com/babylonlabs-io/babylon/types"

stakingw "github.com/babylonlabs-io/babylon/x/staking"

appkeepers "github.com/babylonlabs-io/babylon/app/keepers"
appparams "github.com/babylonlabs-io/babylon/app/params"
"github.com/babylonlabs-io/babylon/client/docs"
Expand Down Expand Up @@ -292,7 +294,7 @@ func NewBabylonApp(
mint.NewAppModule(appCodec, app.MintKeeper, app.AccountKeeper, nil, app.GetSubspace(minttypes.ModuleName)),
slashing.NewAppModule(appCodec, app.SlashingKeeper, app.AccountKeeper, app.BankKeeper, app.StakingKeeper, app.GetSubspace(slashingtypes.ModuleName), app.interfaceRegistry),
distr.NewAppModule(appCodec, app.DistrKeeper, app.AccountKeeper, app.BankKeeper, app.StakingKeeper, app.GetSubspace(distrtypes.ModuleName)),
staking.NewAppModule(appCodec, app.StakingKeeper, app.AccountKeeper, app.BankKeeper, app.GetSubspace(stakingtypes.ModuleName)),
stakingw.NewAppModule(appCodec, app.StakingKeeper, app.AccountKeeper, app.BankKeeper, app.GetSubspace(stakingtypes.ModuleName), &app.EpochingKeeper),
upgrade.NewAppModule(app.UpgradeKeeper, app.AccountKeeper.AddressCodec()),
evidence.NewAppModule(app.EvidenceKeeper),
params.NewAppModule(app.ParamsKeeper),
Expand Down
10 changes: 5 additions & 5 deletions x/checkpointing/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

type msgServer struct {
k Keeper
Keeper
}

// NewMsgServerImpl returns an implementation of the MsgServer interface
Expand All @@ -25,11 +25,11 @@ var _ types.MsgServer = msgServer{}
// WrappedCreateValidator registers validator's BLS public key
// and forwards corresponding MsgCreateValidator message to
// the epoching module
func (m msgServer) WrappedCreateValidator(goCtx context.Context, msg *types.MsgWrappedCreateValidator) (*types.MsgWrappedCreateValidatorResponse, error) {
func (k Keeper) WrappedCreateValidator(goCtx context.Context, msg *types.MsgWrappedCreateValidator) (*types.MsgWrappedCreateValidatorResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// stateless checks on the inside `MsgCreateValidator` msg
if err := m.k.epochingKeeper.CheckMsgCreateValidator(ctx, msg.MsgCreateValidator); err != nil {
if err := k.epochingKeeper.CheckMsgCreateValidator(ctx, msg.MsgCreateValidator); err != nil {
return nil, err
}

Expand All @@ -39,7 +39,7 @@ func (m msgServer) WrappedCreateValidator(goCtx context.Context, msg *types.MsgW
}

// store BLS public key
err = m.k.CreateRegistration(ctx, *msg.Key.Pubkey, valAddr)
err = k.CreateRegistration(ctx, *msg.Key.Pubkey, valAddr)
if err != nil {
return nil, err
}
Expand All @@ -49,7 +49,7 @@ func (m msgServer) WrappedCreateValidator(goCtx context.Context, msg *types.MsgW
Msg: &epochingtypes.QueuedMessage_MsgCreateValidator{MsgCreateValidator: msg.MsgCreateValidator},
}

m.k.epochingKeeper.EnqueueMsg(ctx, queueMsg)
k.epochingKeeper.EnqueueMsg(ctx, queueMsg)

return &types.MsgWrappedCreateValidatorResponse{}, err
}
3 changes: 1 addition & 2 deletions x/epoching/keeper/drop_validator_msg_decorator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ func NewDropValidatorMsgDecorator(ek Keeper) *DropValidatorMsgDecorator {
// AnteHandle performs an AnteHandler check that rejects all non-wrapped validator-related messages.
// It will reject the following types of messages:
// - MsgCreateValidator
// - MsgDelegate
// - MsgUndelegate
// - MsgBeginRedelegate
// - MsgCancelUnbondingDelegation
Expand All @@ -43,7 +42,7 @@ func (qmd DropValidatorMsgDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simu
// IsValidatorRelatedMsg checks if the given message is of non-wrapped type, which should be rejected
func (qmd DropValidatorMsgDecorator) IsValidatorRelatedMsg(msg sdk.Msg) bool {
switch msg.(type) {
case *stakingtypes.MsgCreateValidator, *stakingtypes.MsgDelegate, *stakingtypes.MsgUndelegate, *stakingtypes.MsgBeginRedelegate, *stakingtypes.MsgCancelUnbondingDelegation:
case *stakingtypes.MsgCreateValidator, *stakingtypes.MsgUndelegate, *stakingtypes.MsgBeginRedelegate, *stakingtypes.MsgCancelUnbondingDelegation:
return true
default:
return false
Expand Down
1 change: 0 additions & 1 deletion x/epoching/keeper/drop_validator_msg_decorator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ func TestDropValidatorMsgDecorator(t *testing.T) {
}{
// wrapped message types that should be rejected
{&stakingtypes.MsgCreateValidator{}, true},
{&stakingtypes.MsgDelegate{}, true},
{&stakingtypes.MsgUndelegate{}, true},
{&stakingtypes.MsgBeginRedelegate{}, true},
{&stakingtypes.MsgCancelUnbondingDelegation{}, true},
Expand Down
6 changes: 5 additions & 1 deletion x/epoching/keeper/epoch_msg_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,13 @@ func (k Keeper) HandleQueuedMsg(ctx context.Context, msg *types.QueuedMessage) (
// get the handler function from router
handler := k.router.Handler(unwrappedMsgWithType)

// tells to the msg server to use the unwrap handler.
sdkCtx := sdk.UnwrapSDKContext(ctx)
sdkCtx = sdkCtx.WithValue(types.CtxKeyUnwrapMsgServer, true)

// Create a new Context based off of the existing Context with a MultiStore branch
// in case message processing fails. At this point, the MultiStore is a branch of a branch.
handlerCtx, msCache := cacheTxContext(sdk.UnwrapSDKContext(ctx), msg.TxId, msg.MsgId, msg.BlockHeight)
handlerCtx, msCache := cacheTxContext(sdkCtx, msg.TxId, msg.MsgId, msg.BlockHeight)

// handle the unwrapped message
result, err := handler(handlerCtx, unwrappedMsgWithType)
Expand Down
38 changes: 19 additions & 19 deletions x/epoching/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func NewMsgServerImpl(keeper Keeper) types.MsgServer {
var _ types.MsgServer = msgServer{}

// WrappedDelegate handles the MsgWrappedDelegate request
func (ms msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrappedDelegate) (*types.MsgWrappedDelegateResponse, error) {
func (k Keeper) WrappedDelegate(goCtx context.Context, msg *types.MsgWrappedDelegate) (*types.MsgWrappedDelegateResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
if msg.Msg == nil {
return nil, types.ErrNoWrappedMsg
Expand All @@ -36,13 +36,13 @@ func (ms msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrapped
if valErr != nil {
return nil, valErr
}
if _, err := ms.stk.GetValidator(ctx, valAddr); err != nil {
if _, err := k.stk.GetValidator(ctx, valAddr); err != nil {
return nil, err
}
if _, err := sdk.AccAddressFromBech32(msg.Msg.DelegatorAddress); err != nil {
return nil, err
}
bondDenom, err := ms.stk.BondDenom(ctx)
bondDenom, err := k.stk.BondDenom(ctx)
if err != nil {
return nil, err
}
Expand All @@ -64,15 +64,15 @@ func (ms msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrapped
return nil, err
}

ms.EnqueueMsg(ctx, queuedMsg)
k.EnqueueMsg(ctx, queuedMsg)

err = ctx.EventManager().EmitTypedEvents(
&types.EventWrappedDelegate{
DelegatorAddress: msg.Msg.DelegatorAddress,
ValidatorAddress: msg.Msg.ValidatorAddress,
Amount: msg.Msg.Amount.Amount.Uint64(),
Denom: msg.Msg.Amount.GetDenom(),
EpochBoundary: ms.GetEpoch(ctx).GetLastBlockHeight(),
EpochBoundary: k.GetEpoch(ctx).GetLastBlockHeight(),
},
)
if err != nil {
Expand All @@ -83,7 +83,7 @@ func (ms msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrapped
}

// WrappedUndelegate handles the MsgWrappedUndelegate request
func (ms msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrappedUndelegate) (*types.MsgWrappedUndelegateResponse, error) {
func (k Keeper) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrappedUndelegate) (*types.MsgWrappedUndelegateResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
if msg.Msg == nil {
return nil, types.ErrNoWrappedMsg
Expand All @@ -98,10 +98,10 @@ func (ms msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrapp
if err != nil {
return nil, err
}
if _, err := ms.stk.ValidateUnbondAmount(ctx, delegatorAddress, valAddr, msg.Msg.Amount.Amount); err != nil {
if _, err := k.stk.ValidateUnbondAmount(ctx, delegatorAddress, valAddr, msg.Msg.Amount.Amount); err != nil {
return nil, err
}
bondDenom, err := ms.stk.BondDenom(ctx)
bondDenom, err := k.stk.BondDenom(ctx)
if err != nil {
return nil, err
}
Expand All @@ -123,15 +123,15 @@ func (ms msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrapp
return nil, err
}

ms.EnqueueMsg(ctx, queuedMsg)
k.EnqueueMsg(ctx, queuedMsg)

err = ctx.EventManager().EmitTypedEvents(
&types.EventWrappedUndelegate{
DelegatorAddress: msg.Msg.DelegatorAddress,
ValidatorAddress: msg.Msg.ValidatorAddress,
Amount: msg.Msg.Amount.Amount.Uint64(),
Denom: msg.Msg.Amount.GetDenom(),
EpochBoundary: ms.GetEpoch(ctx).GetLastBlockHeight(),
EpochBoundary: k.GetEpoch(ctx).GetLastBlockHeight(),
},
)
if err != nil {
Expand All @@ -142,7 +142,7 @@ func (ms msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrapp
}

// WrappedBeginRedelegate handles the MsgWrappedBeginRedelegate request
func (ms msgServer) WrappedBeginRedelegate(goCtx context.Context, msg *types.MsgWrappedBeginRedelegate) (*types.MsgWrappedBeginRedelegateResponse, error) {
func (k Keeper) WrappedBeginRedelegate(goCtx context.Context, msg *types.MsgWrappedBeginRedelegate) (*types.MsgWrappedBeginRedelegateResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
if msg.Msg == nil {
return nil, types.ErrNoWrappedMsg
Expand All @@ -157,10 +157,10 @@ func (ms msgServer) WrappedBeginRedelegate(goCtx context.Context, msg *types.Msg
if err != nil {
return nil, err
}
if _, err := ms.stk.ValidateUnbondAmount(ctx, delegatorAddress, valSrcAddr, msg.Msg.Amount.Amount); err != nil {
if _, err := k.stk.ValidateUnbondAmount(ctx, delegatorAddress, valSrcAddr, msg.Msg.Amount.Amount); err != nil {
return nil, err
}
bondDenom, err := ms.stk.BondDenom(ctx)
bondDenom, err := k.stk.BondDenom(ctx)
if err != nil {
return nil, err
}
Expand All @@ -185,15 +185,15 @@ func (ms msgServer) WrappedBeginRedelegate(goCtx context.Context, msg *types.Msg
return nil, err
}

ms.EnqueueMsg(ctx, queuedMsg)
k.EnqueueMsg(ctx, queuedMsg)
err = ctx.EventManager().EmitTypedEvents(
&types.EventWrappedBeginRedelegate{
DelegatorAddress: msg.Msg.DelegatorAddress,
SourceValidatorAddress: msg.Msg.ValidatorSrcAddress,
DestinationValidatorAddress: msg.Msg.ValidatorDstAddress,
Amount: msg.Msg.Amount.Amount.Uint64(),
Denom: msg.Msg.Amount.GetDenom(),
EpochBoundary: ms.GetEpoch(ctx).GetLastBlockHeight(),
EpochBoundary: k.GetEpoch(ctx).GetLastBlockHeight(),
},
)
if err != nil {
Expand All @@ -204,7 +204,7 @@ func (ms msgServer) WrappedBeginRedelegate(goCtx context.Context, msg *types.Msg
}

// WrappedCancelUnbondingDelegation handles the MsgWrappedCancelUnbondingDelegation request
func (ms msgServer) WrappedCancelUnbondingDelegation(goCtx context.Context, msg *types.MsgWrappedCancelUnbondingDelegation) (*types.MsgWrappedCancelUnbondingDelegationResponse, error) {
func (k Keeper) WrappedCancelUnbondingDelegation(goCtx context.Context, msg *types.MsgWrappedCancelUnbondingDelegation) (*types.MsgWrappedCancelUnbondingDelegationResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
if msg.Msg == nil {
return nil, types.ErrNoWrappedMsg
Expand Down Expand Up @@ -233,7 +233,7 @@ func (ms msgServer) WrappedCancelUnbondingDelegation(goCtx context.Context, msg
)
}

bondDenom, err := ms.stk.BondDenom(ctx)
bondDenom, err := k.stk.BondDenom(ctx)
if err != nil {
return nil, err
}
Expand All @@ -254,14 +254,14 @@ func (ms msgServer) WrappedCancelUnbondingDelegation(goCtx context.Context, msg
return nil, err
}

ms.EnqueueMsg(ctx, queuedMsg)
k.EnqueueMsg(ctx, queuedMsg)
err = ctx.EventManager().EmitTypedEvents(
&types.EventWrappedCancelUnbondingDelegation{
DelegatorAddress: msg.Msg.DelegatorAddress,
ValidatorAddress: msg.Msg.ValidatorAddress,
Amount: msg.Msg.Amount.Amount.Uint64(),
CreationHeight: msg.Msg.CreationHeight,
EpochBoundary: ms.GetEpoch(ctx).GetLastBlockHeight(),
EpochBoundary: k.GetEpoch(ctx).GetLastBlockHeight(),
},
)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions x/epoching/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ const (

// MemStoreKey defines the in-memory store key
MemStoreKey = "mem_epoching"

// CtxKeyUnwrapMsgServer defines to the context that it should use the unwraped msg handler (native from cosmos-sdk).
CtxKeyUnwrapMsgServer = "unwrap"
)

var (
Expand Down
41 changes: 41 additions & 0 deletions x/staking/keeper/msg_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package keeper

import (
"context"

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/staking/keeper"
"github.com/cosmos/cosmos-sdk/x/staking/types"

epochingkeeper "github.com/babylonlabs-io/babylon/x/epoching/keeper"
epochingtypes "github.com/babylonlabs-io/babylon/x/epoching/types"
)

type msgServer struct {
types.MsgServer

epochK *epochingkeeper.Keeper
}

// NewMsgServerImpl returns an implementation of the staking MsgServer interface
// for the provided Keeper.
func NewMsgServerImpl(k *keeper.Keeper, epochK *epochingkeeper.Keeper) types.MsgServer {
return &msgServer{
MsgServer: keeper.NewMsgServerImpl(k),
epochK: epochK,
}
}

// Delegate defines a method for performing a delegation of coins from a delegator to a validator
func (ms msgServer) Delegate(goCtx context.Context, msg *types.MsgDelegate) (*types.MsgDelegateResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)
if ctx.Value(epochingtypes.CtxKeyUnwrapMsgServer).(bool) {
return ms.MsgServer.Delegate(goCtx, msg)
}

_, err := ms.epochK.WrappedDelegate(ctx, epochingtypes.NewMsgWrappedDelegate(msg))
if err != nil {
return nil, err
}
return &types.MsgDelegateResponse{}, nil
}
64 changes: 64 additions & 0 deletions x/staking/module.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package staking

import (
"fmt"

"github.com/CosmWasm/wasmd/x/wasm/exported"
"github.com/cosmos/cosmos-sdk/codec"
"github.com/cosmos/cosmos-sdk/types/module"
stkapp "github.com/cosmos/cosmos-sdk/x/staking"
"github.com/cosmos/cosmos-sdk/x/staking/keeper"
"github.com/cosmos/cosmos-sdk/x/staking/types"

epochingkeeper "github.com/babylonlabs-io/babylon/x/epoching/keeper"
wkeeper "github.com/babylonlabs-io/babylon/x/staking/keeper"
)

type AppModule struct {
stkapp.AppModule

k *keeper.Keeper
// legacySubspace is used solely for migration of x/params managed parameters
legacySubspace exported.Subspace

// Wrapped staking forking needed to queue msgs
epochK *epochingkeeper.Keeper
}

// NewAppModule creates a new AppModule object
func NewAppModule(
cdc codec.Codec,
k *keeper.Keeper,
ak types.AccountKeeper,
bk types.BankKeeper,
ls exported.Subspace,
epochK *epochingkeeper.Keeper,
) AppModule {
return AppModule{
AppModule: stkapp.NewAppModule(cdc, k, ak, bk, ls),
k: k,
legacySubspace: ls,
epochK: epochK,
}
}

// RegisterServices registers module services.
func (am AppModule) RegisterServices(cfg module.Configurator) {
types.RegisterMsgServer(cfg.MsgServer(), wkeeper.NewMsgServerImpl(am.k, am.epochK))
querier := keeper.Querier{Keeper: am.k}
types.RegisterQueryServer(cfg.QueryServer(), querier)

m := keeper.NewMigrator(am.k, am.legacySubspace)
if err := cfg.RegisterMigration(types.ModuleName, 1, m.Migrate1to2); err != nil {
panic(fmt.Sprintf("failed to migrate x/%s from version 1 to 2: %v", types.ModuleName, err))
}
if err := cfg.RegisterMigration(types.ModuleName, 2, m.Migrate2to3); err != nil {
panic(fmt.Sprintf("failed to migrate x/%s from version 2 to 3: %v", types.ModuleName, err))
}
if err := cfg.RegisterMigration(types.ModuleName, 3, m.Migrate3to4); err != nil {
panic(fmt.Sprintf("failed to migrate x/%s from version 3 to 4: %v", types.ModuleName, err))
}
if err := cfg.RegisterMigration(types.ModuleName, 4, m.Migrate4to5); err != nil {
panic(fmt.Sprintf("failed to migrate x/%s from version 4 to 5: %v", types.ModuleName, err))
}
}

0 comments on commit 7619a51

Please sign in to comment.