From 018af98f42fe2b6c21313d9c8daafe84efa3f6eb Mon Sep 17 00:00:00 2001 From: Tanmay Date: Wed, 24 Jul 2024 21:38:57 -0400 Subject: [PATCH 01/14] make solana migration optional --- e2e/e2etests/test_migrate_tss.go | 3 ++- x/crosschain/keeper/msg_server_update_tss.go | 6 +++++- x/crosschain/types/expected_keepers.go | 1 + x/observer/keeper/chain_params.go | 19 +++++++++++++++++++ 4 files changed, 27 insertions(+), 2 deletions(-) diff --git a/e2e/e2etests/test_migrate_tss.go b/e2e/e2etests/test_migrate_tss.go index c72b876f0f..9912e56599 100644 --- a/e2e/e2etests/test_migrate_tss.go +++ b/e2e/e2etests/test_migrate_tss.go @@ -113,9 +113,10 @@ func TestMigrateTSS(r *runner.E2ERunner, _ []string) { r.ZetaTxServer.MustGetAccountAddressFromName(utils.AdminPolicyName), allTss.TssList[1].TssPubkey, ) - _, err = r.ZetaTxServer.BroadcastTx(utils.AdminPolicyName, msgUpdateTss) + res, err := r.ZetaTxServer.BroadcastTx(utils.AdminPolicyName, msgUpdateTss) require.NoError(r, err) + r.Logger.Print("Brodacast tx : ", res.TxHash) // Wait for atleast one block for the TSS to be updated time.Sleep(8 * time.Second) diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index 78c3b833d7..5f3e5c8cc4 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -5,6 +5,7 @@ import ( errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/zeta-chain/zetacore/pkg/chains" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" "github.com/zeta-chain/zetacore/x/crosschain/types" @@ -37,8 +38,11 @@ func (k msgServer) UpdateTssAddress( } tssMigrators := k.zetaObserverKeeper.GetAllTssFundMigrators(ctx) + + supportedChains := append(k.zetaObserverKeeper.GetSupportedChainsByConsensus(ctx, chains.Consensus_ethereum), + k.zetaObserverKeeper.GetSupportedChainsByConsensus(ctx, chains.Consensus_bitcoin)...) // Each connected chain should have its own tss migrator - if len(k.zetaObserverKeeper.GetSupportedForeignChains(ctx)) != len(tssMigrators) { + if len(supportedChains) != len(tssMigrators) { return nil, errorsmod.Wrap( types.ErrUnableToUpdateTss, "cannot update tss address incorrect number of migrations have been created and completed", diff --git a/x/crosschain/types/expected_keepers.go b/x/crosschain/types/expected_keepers.go index b66ab1a12c..57ccd28856 100644 --- a/x/crosschain/types/expected_keepers.go +++ b/x/crosschain/types/expected_keepers.go @@ -104,6 +104,7 @@ type ObserverKeeper interface { GetSupportedChainFromChainID(ctx sdk.Context, chainID int64) (chains.Chain, bool) GetSupportedChains(ctx sdk.Context) []chains.Chain GetSupportedForeignChains(ctx sdk.Context) []chains.Chain + GetSupportedChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain } type FungibleKeeper interface { diff --git a/x/observer/keeper/chain_params.go b/x/observer/keeper/chain_params.go index 9277010c3c..1ba7160b86 100644 --- a/x/observer/keeper/chain_params.go +++ b/x/observer/keeper/chain_params.go @@ -76,6 +76,25 @@ func (k Keeper) GetSupportedChains(ctx sdk.Context) []chains.Chain { return c } +// GetSupportedChains returns the list of supported chains +func (k Keeper) GetSupportedChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain { + cpl, found := k.GetChainParamsList(ctx) + if !found { + return []chains.Chain{} + } + + var c []chains.Chain + for _, cp := range cpl.ChainParams { + if cp.IsSupported { + chain, found := chains.GetChainFromChainID(cp.ChainId, k.GetAuthorityKeeper().GetAdditionalChainList(ctx)) + if found && chain.GetConsensus() == consensus { + c = append(c, chain) + } + } + } + return c +} + // GetSupportedForeignChains returns the list of supported foreign chains func (k Keeper) GetSupportedForeignChains(ctx sdk.Context) []chains.Chain { allChains := k.GetSupportedChains(ctx) From 748331265c16af7992d00587e254937fccc14846 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Thu, 25 Jul 2024 12:39:18 -0400 Subject: [PATCH 02/14] add unit tests --- changelog.md | 1 + e2e/e2etests/test_migrate_tss.go | 3 +- testutil/keeper/mocks/crosschain/observer.go | 20 ++++++ x/crosschain/keeper/msg_server_update_tss.go | 9 ++- .../keeper/msg_server_update_tss_test.go | 22 +++---- x/crosschain/types/expected_keepers.go | 2 +- x/observer/keeper/chain_params.go | 22 +++---- x/observer/keeper/chain_params_test.go | 66 +++++++++++++++++++ 8 files changed, 114 insertions(+), 31 deletions(-) diff --git a/changelog.md b/changelog.md index 29ef9db9b6..92e72f9deb 100644 --- a/changelog.md +++ b/changelog.md @@ -65,6 +65,7 @@ * [2428](https://github.com/zeta-chain/node/pull/2428) - propagate context across codebase & refactor zetacore client * [2464](https://github.com/zeta-chain/node/pull/2464) - move common voting logic to voting.go and add new function VoteOnBallot * [2515](https://github.com/zeta-chain/node/pull/2515) - replace chainName by chainID for ChainNonces indexing +* [2556](https://github.com/zeta-chain/node/pull/2556) - refactor migrator length check to use consensus type ### Tests diff --git a/e2e/e2etests/test_migrate_tss.go b/e2e/e2etests/test_migrate_tss.go index 9912e56599..c72b876f0f 100644 --- a/e2e/e2etests/test_migrate_tss.go +++ b/e2e/e2etests/test_migrate_tss.go @@ -113,10 +113,9 @@ func TestMigrateTSS(r *runner.E2ERunner, _ []string) { r.ZetaTxServer.MustGetAccountAddressFromName(utils.AdminPolicyName), allTss.TssList[1].TssPubkey, ) - res, err := r.ZetaTxServer.BroadcastTx(utils.AdminPolicyName, msgUpdateTss) + _, err = r.ZetaTxServer.BroadcastTx(utils.AdminPolicyName, msgUpdateTss) require.NoError(r, err) - r.Logger.Print("Brodacast tx : ", res.TxHash) // Wait for atleast one block for the TSS to be updated time.Sleep(8 * time.Second) diff --git a/testutil/keeper/mocks/crosschain/observer.go b/testutil/keeper/mocks/crosschain/observer.go index d71673848f..63789912ce 100644 --- a/testutil/keeper/mocks/crosschain/observer.go +++ b/testutil/keeper/mocks/crosschain/observer.go @@ -624,6 +624,26 @@ func (_m *CrosschainObserverKeeper) GetSupportedForeignChains(ctx types.Context) return r0 } +// GetSupportedForeignChainsByConsensus provides a mock function with given fields: ctx, consensus +func (_m *CrosschainObserverKeeper) GetSupportedForeignChainsByConsensus(ctx types.Context, consensus chains.Consensus) []chains.Chain { + ret := _m.Called(ctx, consensus) + + if len(ret) == 0 { + panic("no return value specified for GetSupportedForeignChainsByConsensus") + } + + var r0 []chains.Chain + if rf, ok := ret.Get(0).(func(types.Context, chains.Consensus) []chains.Chain); ok { + r0 = rf(ctx, consensus) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]chains.Chain) + } + } + + return r0 +} + // GetTSS provides a mock function with given fields: ctx func (_m *CrosschainObserverKeeper) GetTSS(ctx types.Context) (observertypes.TSS, bool) { ret := _m.Called(ctx) diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index 5f3e5c8cc4..73fdab6b38 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -39,10 +39,8 @@ func (k msgServer) UpdateTssAddress( tssMigrators := k.zetaObserverKeeper.GetAllTssFundMigrators(ctx) - supportedChains := append(k.zetaObserverKeeper.GetSupportedChainsByConsensus(ctx, chains.Consensus_ethereum), - k.zetaObserverKeeper.GetSupportedChainsByConsensus(ctx, chains.Consensus_bitcoin)...) // Each connected chain should have its own tss migrator - if len(supportedChains) != len(tssMigrators) { + if len(k.GetChainsSupportingMigration(ctx)) != len(tssMigrators) { return nil, errorsmod.Wrap( types.ErrUnableToUpdateTss, "cannot update tss address incorrect number of migrations have been created and completed", @@ -74,3 +72,8 @@ func (k msgServer) UpdateTssAddress( return &types.MsgUpdateTssAddressResponse{}, nil } + +func (k *Keeper) GetChainsSupportingMigration(ctx sdk.Context) []chains.Chain { + return append(k.zetaObserverKeeper.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_ethereum), + k.zetaObserverKeeper.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_bitcoin)...) +} diff --git a/x/crosschain/keeper/msg_server_update_tss_test.go b/x/crosschain/keeper/msg_server_update_tss_test.go index 761ff617c8..352b3873fa 100644 --- a/x/crosschain/keeper/msg_server_update_tss_test.go +++ b/x/crosschain/keeper/msg_server_update_tss_test.go @@ -65,7 +65,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSSHistory(ctx, tssNew) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetObserverKeeper().GetSupportedForeignChains(ctx) { + for _, chain := range k.GetChainsSupportingMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -78,7 +78,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedForeignChains(ctx)), + len(k.GetChainsSupportingMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -109,7 +109,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { + for _, chain := range k.GetChainsSupportingMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -122,7 +122,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedChains(ctx)), + len(k.GetChainsSupportingMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -139,7 +139,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedChains(ctx)), + len(k.GetChainsSupportingMigration(ctx)), ) }) @@ -156,7 +156,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { + for _, chain := range k.GetChainsSupportingMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -169,7 +169,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedChains(ctx)), + len(k.GetChainsSupportingMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -186,7 +186,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedChains(ctx)), + len(k.GetChainsSupportingMigration(ctx)), ) }) @@ -207,7 +207,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) // set a single migrator while there are 2 supported chains - chain := k.GetObserverKeeper().GetSupportedChains(ctx)[0] + chain := k.GetChainsSupportingMigration(ctx)[0] index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -254,7 +254,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSS(ctx, tssOld) setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) - for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { + for _, chain := range k.GetChainsSupportingMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -301,7 +301,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSS(ctx, tssOld) setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) - for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { + for _, chain := range k.GetChainsSupportingMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, diff --git a/x/crosschain/types/expected_keepers.go b/x/crosschain/types/expected_keepers.go index 57ccd28856..b159f84148 100644 --- a/x/crosschain/types/expected_keepers.go +++ b/x/crosschain/types/expected_keepers.go @@ -104,7 +104,7 @@ type ObserverKeeper interface { GetSupportedChainFromChainID(ctx sdk.Context, chainID int64) (chains.Chain, bool) GetSupportedChains(ctx sdk.Context) []chains.Chain GetSupportedForeignChains(ctx sdk.Context) []chains.Chain - GetSupportedChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain + GetSupportedForeignChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain } type FungibleKeeper interface { diff --git a/x/observer/keeper/chain_params.go b/x/observer/keeper/chain_params.go index 1ba7160b86..1c2cf4e7c4 100644 --- a/x/observer/keeper/chain_params.go +++ b/x/observer/keeper/chain_params.go @@ -76,23 +76,17 @@ func (k Keeper) GetSupportedChains(ctx sdk.Context) []chains.Chain { return c } -// GetSupportedChains returns the list of supported chains -func (k Keeper) GetSupportedChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain { - cpl, found := k.GetChainParamsList(ctx) - if !found { - return []chains.Chain{} - } +// GetSupportedChainsByConsensus returns the list of supported chains by consensus +func (k Keeper) GetSupportedForeignChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain { + allChains := k.GetSupportedChains(ctx) - var c []chains.Chain - for _, cp := range cpl.ChainParams { - if cp.IsSupported { - chain, found := chains.GetChainFromChainID(cp.ChainId, k.GetAuthorityKeeper().GetAdditionalChainList(ctx)) - if found && chain.GetConsensus() == consensus { - c = append(c, chain) - } + foreignChains := make([]chains.Chain, 0) + for _, chain := range allChains { + if !chain.IsZetaChain() && chain.GetConsensus() == consensus { + foreignChains = append(foreignChains, chain) } } - return c + return foreignChains } // GetSupportedForeignChains returns the list of supported foreign chains diff --git a/x/observer/keeper/chain_params_test.go b/x/observer/keeper/chain_params_test.go index 733fafc0b8..84f1946a3f 100644 --- a/x/observer/keeper/chain_params_test.go +++ b/x/observer/keeper/chain_params_test.go @@ -110,3 +110,69 @@ func TestKeeper_GetSupportedChains(t *testing.T) { require.EqualValues(t, supported4.ChainId, supportedChains[3].ChainId) }) } + +func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { + t.Run("return empty list if not chans are supported", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + require.Empty(t, k.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_ethereum)) + }) + + t.Run("return list of supported chains for ethereum consensus", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + } + k.SetChainParamsList(ctx, chainParamsList) + consensus := chains.Consensus_ethereum + + supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) + require.NotEmpty(t, supportedChainsList) + + require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) + }) + + t.Run("return list of supported chains for bitcoin consensus", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + } + k.SetChainParamsList(ctx, chainParamsList) + consensus := chains.Consensus_bitcoin + + supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) + require.NotEmpty(t, supportedChainsList) + require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) + }) + + t.Run("return list of supported chains for solana consensus", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + } + k.SetChainParamsList(ctx, chainParamsList) + consensus := chains.Consensus_solana_consensus + + supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) + require.NotEmpty(t, supportedChainsList) + require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) + }) + +} + +func getForeignChains(consensus chains.Consensus) []chains.Chain { + evmChains := chains.ChainListByConsensus(consensus, []chains.Chain{}) + foreignEvmChains := make([]chains.Chain, 0) + + for _, chain := range evmChains { + if !chain.IsZetaChain() { + foreignEvmChains = append(foreignEvmChains, chain) + } + } + return foreignEvmChains +} From fa77cf5446783116330ae2674a80094974954b8b Mon Sep 17 00:00:00 2001 From: Tanmay Date: Thu, 25 Jul 2024 12:44:09 -0400 Subject: [PATCH 03/14] get all foreign chains unit test --- x/observer/keeper/chain_params_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/x/observer/keeper/chain_params_test.go b/x/observer/keeper/chain_params_test.go index 84f1946a3f..395d1848c0 100644 --- a/x/observer/keeper/chain_params_test.go +++ b/x/observer/keeper/chain_params_test.go @@ -162,7 +162,32 @@ func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { require.NotEmpty(t, supportedChainsList) require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) }) +} + +func TestKeeper_GetSupportedForeignChains(t *testing.T) { + t.Run("return empty list if not chans are supported", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + require.Empty(t, k.GetSupportedForeignChains(ctx)) + }) + + t.Run("return list of supported chains", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + } + k.SetChainParamsList(ctx, chainParamsList) + + supportedChainsList := k.GetSupportedForeignChains(ctx) + require.NotEmpty(t, supportedChainsList) + + require.ElementsMatch(t, getAllForeignChains(), supportedChainsList) + }) +} +func getAllForeignChains() []chains.Chain { + return chains.ExternalChainList([]chains.Chain{}) } func getForeignChains(consensus chains.Consensus) []chains.Chain { From 53fbdd29e996fe483ea63004de5bd35f6d838b32 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Thu, 25 Jul 2024 13:29:05 -0400 Subject: [PATCH 04/14] add unit test for chains supporting migration --- x/crosschain/keeper/msg_server_update_tss.go | 1 + .../keeper/msg_server_update_tss_test.go | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index 73fdab6b38..e7e14e13b9 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -73,6 +73,7 @@ func (k msgServer) UpdateTssAddress( return &types.MsgUpdateTssAddressResponse{}, nil } +// GetChainsSupportingMigration returns the chains that support migration. func (k *Keeper) GetChainsSupportingMigration(ctx sdk.Context) []chains.Chain { return append(k.zetaObserverKeeper.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_ethereum), k.zetaObserverKeeper.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_bitcoin)...) diff --git a/x/crosschain/keeper/msg_server_update_tss_test.go b/x/crosschain/keeper/msg_server_update_tss_test.go index 352b3873fa..30cd475842 100644 --- a/x/crosschain/keeper/msg_server_update_tss_test.go +++ b/x/crosschain/keeper/msg_server_update_tss_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" @@ -329,3 +330,22 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal(t, len(k.GetObserverKeeper().GetSupportedChains(ctx)), len(migrators)) }) } + +func TestKeeper_GetChainsSupportingMigration(t *testing.T) { + t.Run("should return supported chains", func(t *testing.T) { + k, ctx, _, zk := keepertest.CrosschainKeeperWithMocks(t, keepertest.CrosschainMockOptions{}) + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + } + zk.ObserverKeeper.SetChainParamsList(ctx, chainParamsList) + + chainsSupportingMigration := k.GetChainsSupportingMigration(ctx) + for _, chain := range chainsSupportingMigration { + require.NotEqual(t, chain.Consensus, chains.Consensus_solana_consensus) + require.NotEqual(t, chain.Consensus, chains.Consensus_op_stack) + require.NotEqual(t, chain.Consensus, chains.Consensus_tendermint) + } + }) +} From 56cb61dc67488c44e8b5b52716d844f507f50373 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Thu, 25 Jul 2024 14:02:17 -0400 Subject: [PATCH 05/14] add unit test for chains supporting migration --- x/crosschain/keeper/msg_server_update_tss.go | 2 +- .../keeper/msg_server_update_tss_test.go | 3 ++- x/observer/keeper/chain_params_test.go | 20 +++++++++++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index e7e14e13b9..e088412e9d 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -5,8 +5,8 @@ import ( errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/zeta-chain/zetacore/pkg/chains" + "github.com/zeta-chain/zetacore/pkg/chains" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" "github.com/zeta-chain/zetacore/x/crosschain/types" ) diff --git a/x/crosschain/keeper/msg_server_update_tss_test.go b/x/crosschain/keeper/msg_server_update_tss_test.go index 30cd475842..bc747e7389 100644 --- a/x/crosschain/keeper/msg_server_update_tss_test.go +++ b/x/crosschain/keeper/msg_server_update_tss_test.go @@ -332,7 +332,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { } func TestKeeper_GetChainsSupportingMigration(t *testing.T) { - t.Run("should return supported chains", func(t *testing.T) { + t.Run("should return only ethereum and bitcoin chains", func(t *testing.T) { k, ctx, _, zk := keepertest.CrosschainKeeperWithMocks(t, keepertest.CrosschainMockOptions{}) chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList @@ -346,6 +346,7 @@ func TestKeeper_GetChainsSupportingMigration(t *testing.T) { require.NotEqual(t, chain.Consensus, chains.Consensus_solana_consensus) require.NotEqual(t, chain.Consensus, chains.Consensus_op_stack) require.NotEqual(t, chain.Consensus, chains.Consensus_tendermint) + require.Equal(t, chain.IsExternal, true) } }) } diff --git a/x/observer/keeper/chain_params_test.go b/x/observer/keeper/chain_params_test.go index 395d1848c0..f4de631043 100644 --- a/x/observer/keeper/chain_params_test.go +++ b/x/observer/keeper/chain_params_test.go @@ -122,7 +122,10 @@ func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { - chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) } k.SetChainParamsList(ctx, chainParamsList) consensus := chains.Consensus_ethereum @@ -138,7 +141,10 @@ func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { - chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) } k.SetChainParamsList(ctx, chainParamsList) consensus := chains.Consensus_bitcoin @@ -153,7 +159,10 @@ func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { - chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) } k.SetChainParamsList(ctx, chainParamsList) consensus := chains.Consensus_solana_consensus @@ -175,7 +184,10 @@ func TestKeeper_GetSupportedForeignChains(t *testing.T) { chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { - chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) } k.SetChainParamsList(ctx, chainParamsList) From e32c2a87997922c99ad8a9e7709df90d10050b69 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 10:55:13 -0400 Subject: [PATCH 06/14] refactor to add a filter chain function --- pkg/chains/chain_filters.go | 19 ++ pkg/chains/chains.go | 10 + testutil/keeper/mocks/crosschain/observer.go | 27 +++ x/crosschain/keeper/msg_server_update_tss.go | 25 +- .../keeper/msg_server_update_tss_test.go | 24 +- x/crosschain/types/expected_keepers.go | 1 + x/observer/keeper/chain_params.go | 19 ++ x/observer/keeper/chain_params_test.go | 213 +++++++++++++++--- 8 files changed, 295 insertions(+), 43 deletions(-) create mode 100644 pkg/chains/chain_filters.go diff --git a/pkg/chains/chain_filters.go b/pkg/chains/chain_filters.go new file mode 100644 index 0000000000..654868b224 --- /dev/null +++ b/pkg/chains/chain_filters.go @@ -0,0 +1,19 @@ +package chains + +type ChainFilter func(c Chain) bool + +func FilterExternalChains(c Chain) bool { + return c.IsExternal +} + +func FilterGatewayObserver(c Chain) bool { + return c.CctxGateway == CCTXGateway_observers +} + +func FilterConsensusEthereum(c Chain) bool { + return c.Consensus == Consensus_ethereum +} + +func FilterConsensusBitcoin(c Chain) bool { return c.Consensus == Consensus_bitcoin } + +func FilterConsensusSolana(c Chain) bool { return c.Consensus == Consensus_solana_consensus } diff --git a/pkg/chains/chains.go b/pkg/chains/chains.go index 2dcd0377ea..5dbc2ea6a1 100644 --- a/pkg/chains/chains.go +++ b/pkg/chains/chains.go @@ -390,6 +390,16 @@ func ChainListByConsensus(consensus Consensus, additionalChains []Chain) []Chain return chainList } +func ChainListByGateway(gateway CCTXGateway, additionalChains []Chain) []Chain { + var chainList []Chain + for _, chain := range CombineDefaultChainsList(additionalChains) { + if chain.CctxGateway == gateway { + chainList = append(chainList, chain) + } + } + return chainList +} + // ChainListForHeaderSupport returns a list of chains that support headers func ChainListForHeaderSupport(additionalChains []Chain) []Chain { var chainList []Chain diff --git a/testutil/keeper/mocks/crosschain/observer.go b/testutil/keeper/mocks/crosschain/observer.go index 63789912ce..189603fca0 100644 --- a/testutil/keeper/mocks/crosschain/observer.go +++ b/testutil/keeper/mocks/crosschain/observer.go @@ -109,6 +109,33 @@ func (_m *CrosschainObserverKeeper) CheckIfTssPubkeyHasBeenGenerated(ctx types.C return r0, r1 } +// FilterChains provides a mock function with given fields: ctx, filters +func (_m *CrosschainObserverKeeper) FilterChains(ctx types.Context, filters ...chains.ChainFilter) []chains.Chain { + _va := make([]interface{}, len(filters)) + for _i := range filters { + _va[_i] = filters[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for FilterChains") + } + + var r0 []chains.Chain + if rf, ok := ret.Get(0).(func(types.Context, ...chains.ChainFilter) []chains.Chain); ok { + r0 = rf(ctx, filters...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]chains.Chain) + } + } + + return r0 +} + // FindBallot provides a mock function with given fields: ctx, index, chain, observationType func (_m *CrosschainObserverKeeper) FindBallot(ctx types.Context, index string, chain chains.Chain, observationType observertypes.ObservationType) (observertypes.Ballot, bool, error) { ret := _m.Called(ctx, index, chain, observationType) diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index e088412e9d..b53eeea885 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -40,7 +40,7 @@ func (k msgServer) UpdateTssAddress( tssMigrators := k.zetaObserverKeeper.GetAllTssFundMigrators(ctx) // Each connected chain should have its own tss migrator - if len(k.GetChainsSupportingMigration(ctx)) != len(tssMigrators) { + if len(k.GetChainsSupportingTSSMigration(ctx)) != len(tssMigrators) { return nil, errorsmod.Wrap( types.ErrUnableToUpdateTss, "cannot update tss address incorrect number of migrations have been created and completed", @@ -73,8 +73,23 @@ func (k msgServer) UpdateTssAddress( return &types.MsgUpdateTssAddressResponse{}, nil } -// GetChainsSupportingMigration returns the chains that support migration. -func (k *Keeper) GetChainsSupportingMigration(ctx sdk.Context) []chains.Chain { - return append(k.zetaObserverKeeper.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_ethereum), - k.zetaObserverKeeper.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_bitcoin)...) +// GetChainsSupportingTSSMigration returns the chains that support tss migration. +// Chains that support tss migration are chains that have the following properties: +// 1. External chains +// 2. Gateway observer +// 3. Consensus is bitcoin or ethereum (Other consensus types are not supported) +func (k *Keeper) GetChainsSupportingTSSMigration(ctx sdk.Context) []chains.Chain { + evmChainsForTSSMigration := k.zetaObserverKeeper.FilterChains(ctx, []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusBitcoin, + }...) + + bitcoinChainsForTSSMigration := k.zetaObserverKeeper.FilterChains(ctx, []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusEthereum, + }...) + + return append(evmChainsForTSSMigration, bitcoinChainsForTSSMigration...) } diff --git a/x/crosschain/keeper/msg_server_update_tss_test.go b/x/crosschain/keeper/msg_server_update_tss_test.go index bc747e7389..faadffbc5f 100644 --- a/x/crosschain/keeper/msg_server_update_tss_test.go +++ b/x/crosschain/keeper/msg_server_update_tss_test.go @@ -66,7 +66,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSSHistory(ctx, tssNew) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetChainsSupportingMigration(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -79,7 +79,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetChainsSupportingMigration(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -110,7 +110,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetChainsSupportingMigration(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -123,7 +123,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetChainsSupportingMigration(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -140,7 +140,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetChainsSupportingMigration(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) }) @@ -157,7 +157,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetChainsSupportingMigration(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -170,7 +170,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetChainsSupportingMigration(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -187,7 +187,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetChainsSupportingMigration(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) }) @@ -208,7 +208,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) // set a single migrator while there are 2 supported chains - chain := k.GetChainsSupportingMigration(ctx)[0] + chain := k.GetChainsSupportingTSSMigration(ctx)[0] index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -255,7 +255,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSS(ctx, tssOld) setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) - for _, chain := range k.GetChainsSupportingMigration(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -302,7 +302,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSS(ctx, tssOld) setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) - for _, chain := range k.GetChainsSupportingMigration(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.ChainName.String() + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -341,7 +341,7 @@ func TestKeeper_GetChainsSupportingMigration(t *testing.T) { } zk.ObserverKeeper.SetChainParamsList(ctx, chainParamsList) - chainsSupportingMigration := k.GetChainsSupportingMigration(ctx) + chainsSupportingMigration := k.GetChainsSupportingTSSMigration(ctx) for _, chain := range chainsSupportingMigration { require.NotEqual(t, chain.Consensus, chains.Consensus_solana_consensus) require.NotEqual(t, chain.Consensus, chains.Consensus_op_stack) diff --git a/x/crosschain/types/expected_keepers.go b/x/crosschain/types/expected_keepers.go index b159f84148..3720c6bc8b 100644 --- a/x/crosschain/types/expected_keepers.go +++ b/x/crosschain/types/expected_keepers.go @@ -105,6 +105,7 @@ type ObserverKeeper interface { GetSupportedChains(ctx sdk.Context) []chains.Chain GetSupportedForeignChains(ctx sdk.Context) []chains.Chain GetSupportedForeignChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain + FilterChains(ctx sdk.Context, filters ...chains.ChainFilter) []chains.Chain } type FungibleKeeper interface { diff --git a/x/observer/keeper/chain_params.go b/x/observer/keeper/chain_params.go index 1c2cf4e7c4..a3578b50e5 100644 --- a/x/observer/keeper/chain_params.go +++ b/x/observer/keeper/chain_params.go @@ -76,6 +76,25 @@ func (k Keeper) GetSupportedChains(ctx sdk.Context) []chains.Chain { return c } +func (k Keeper) FilterChains(ctx sdk.Context, filters ...chains.ChainFilter) []chains.Chain { + // Retrieve all supported chains + supportedChains := k.GetSupportedChains(ctx) + + // Apply each filter to the list of supported chains + for _, filter := range filters { + var filteredChains []chains.Chain + for _, chain := range supportedChains { + if filter(chain) { + filteredChains = append(filteredChains, chain) + } + } + supportedChains = filteredChains + } + + // Return the filtered list of chains + return supportedChains +} + // GetSupportedChainsByConsensus returns the list of supported chains by consensus func (k Keeper) GetSupportedForeignChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain { allChains := k.GetSupportedChains(ctx) diff --git a/x/observer/keeper/chain_params_test.go b/x/observer/keeper/chain_params_test.go index f4de631043..a9a46abb60 100644 --- a/x/observer/keeper/chain_params_test.go +++ b/x/observer/keeper/chain_params_test.go @@ -111,14 +111,27 @@ func TestKeeper_GetSupportedChains(t *testing.T) { }) } -func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { - t.Run("return empty list if not chans are supported", func(t *testing.T) { +func TestKeeper_FilterChains(t *testing.T) { + t.Run("Filter external chains", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - require.Empty(t, k.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_ethereum)) + + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) + } + k.SetChainParamsList(ctx, chainParamsList) + + filteredChains := k.FilterChains(ctx, chains.FilterExternalChains) + require.ElementsMatch(t, chains.ExternalChainList([]chains.Chain{}), filteredChains) }) - t.Run("return list of supported chains for ethereum consensus", func(t *testing.T) { + t.Run("Filter gateway observer chains", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { @@ -128,16 +141,31 @@ func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { ) } k.SetChainParamsList(ctx, chainParamsList) - consensus := chains.Consensus_ethereum - supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) - require.NotEmpty(t, supportedChainsList) + filteredChains := k.FilterChains(ctx, chains.FilterGatewayObserver) + require.ElementsMatch(t, chains.ChainListByGateway(chains.CCTXGateway_observers, []chains.Chain{}), filteredChains) + }) + + t.Run("Filter consensus ethereum chains", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) - require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) + } + k.SetChainParamsList(ctx, chainParamsList) + + filteredChains := k.FilterChains(ctx, chains.FilterConsensusEthereum) + require.ElementsMatch(t, chains.ChainListByConsensus(chains.Consensus_ethereum, []chains.Chain{}), filteredChains) }) - t.Run("return list of supported chains for bitcoin consensus", func(t *testing.T) { + t.Run("Filter consensus bitcoin chains", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { @@ -147,15 +175,14 @@ func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { ) } k.SetChainParamsList(ctx, chainParamsList) - consensus := chains.Consensus_bitcoin - supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) - require.NotEmpty(t, supportedChainsList) - require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) + filteredChains := k.FilterChains(ctx, chains.FilterConsensusBitcoin) + require.ElementsMatch(t, chains.ChainListByConsensus(chains.Consensus_bitcoin, []chains.Chain{}), filteredChains) }) - t.Run("return list of supported chains for solana consensus", func(t *testing.T) { + t.Run("Filter consensus solana chains", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { @@ -165,22 +192,38 @@ func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { ) } k.SetChainParamsList(ctx, chainParamsList) - consensus := chains.Consensus_solana_consensus - supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) - require.NotEmpty(t, supportedChainsList) - require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) + filteredChains := k.FilterChains(ctx, chains.FilterConsensusSolana) + require.ElementsMatch(t, chains.ChainListByConsensus(chains.Consensus_solana_consensus, []chains.Chain{}), filteredChains) }) -} -func TestKeeper_GetSupportedForeignChains(t *testing.T) { - t.Run("return empty list if not chans are supported", func(t *testing.T) { + t.Run("Apply multiple filters external chains with gateway observer", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) - require.Empty(t, k.GetSupportedForeignChains(ctx)) + + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) + } + k.SetChainParamsList(ctx, chainParamsList) + + filteredChains := k.FilterChains(ctx, chains.FilterExternalChains, chains.FilterGatewayObserver) + externalChains := chains.ExternalChainList([]chains.Chain{}) + var gatewayObserverChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers { + gatewayObserverChains = append(gatewayObserverChains, chain) + } + } + require.ElementsMatch(t, gatewayObserverChains, filteredChains) }) - t.Run("return list of supported chains", func(t *testing.T) { + t.Run("Apply multiple filters external chains with gateway observer and consensus ethereum and bitcoin", func(t *testing.T) { k, ctx, _, _ := keepertest.ObserverKeeper(t) + chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { @@ -191,13 +234,131 @@ func TestKeeper_GetSupportedForeignChains(t *testing.T) { } k.SetChainParamsList(ctx, chainParamsList) - supportedChainsList := k.GetSupportedForeignChains(ctx) - require.NotEmpty(t, supportedChainsList) + filteredChainsEVM := k.FilterChains(ctx, chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusEthereum) + filteredChainsBitcoin := k.FilterChains(ctx, chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusBitcoin) + externalChains := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers && (chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin) { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + require.ElementsMatch(t, filterMultipleChains, append(filteredChainsEVM, filteredChainsBitcoin...)) + }) + + t.Run("Apply multiple filters external chains with gateway observer and consensus ethereum and bitcoin in different order", func(t *testing.T) { + k, ctx, _, _ := keepertest.ObserverKeeper(t) + + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) + } + k.SetChainParamsList(ctx, chainParamsList) - require.ElementsMatch(t, getAllForeignChains(), supportedChainsList) + filteredChainsEVM := k.FilterChains(ctx, chains.FilterGatewayObserver, chains.FilterConsensusEthereum, chains.FilterExternalChains) + filteredChainsBitcoin := k.FilterChains(ctx, chains.FilterExternalChains, chains.FilterConsensusBitcoin, chains.FilterGatewayObserver) + externalChains := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers && (chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin) { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + require.ElementsMatch(t, filterMultipleChains, append(filteredChainsEVM, filteredChainsBitcoin...)) }) } +//func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { +// t.Run("return empty list if not chans are supported", func(t *testing.T) { +// k, ctx, _, _ := keepertest.ObserverKeeper(t) +// require.Empty(t, k.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_ethereum)) +// }) +// +// t.Run("return list of supported chains for ethereum consensus", func(t *testing.T) { +// k, ctx, _, _ := keepertest.ObserverKeeper(t) +// chainList := chains.ExternalChainList([]chains.Chain{}) +// var chainParamsList types.ChainParamsList +// for _, chain := range chainList { +// chainParamsList.ChainParams = append( +// chainParamsList.ChainParams, +// sample.ChainParamsSupported(chain.ChainId), +// ) +// } +// k.SetChainParamsList(ctx, chainParamsList) +// consensus := chains.Consensus_ethereum +// +// supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) +// require.NotEmpty(t, supportedChainsList) +// +// require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) +// }) +// +// t.Run("return list of supported chains for bitcoin consensus", func(t *testing.T) { +// k, ctx, _, _ := keepertest.ObserverKeeper(t) +// chainList := chains.ExternalChainList([]chains.Chain{}) +// var chainParamsList types.ChainParamsList +// for _, chain := range chainList { +// chainParamsList.ChainParams = append( +// chainParamsList.ChainParams, +// sample.ChainParamsSupported(chain.ChainId), +// ) +// } +// k.SetChainParamsList(ctx, chainParamsList) +// consensus := chains.Consensus_bitcoin +// +// supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) +// require.NotEmpty(t, supportedChainsList) +// require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) +// }) +// +// t.Run("return list of supported chains for solana consensus", func(t *testing.T) { +// k, ctx, _, _ := keepertest.ObserverKeeper(t) +// chainList := chains.ExternalChainList([]chains.Chain{}) +// var chainParamsList types.ChainParamsList +// for _, chain := range chainList { +// chainParamsList.ChainParams = append( +// chainParamsList.ChainParams, +// sample.ChainParamsSupported(chain.ChainId), +// ) +// } +// k.SetChainParamsList(ctx, chainParamsList) +// consensus := chains.Consensus_solana_consensus +// +// supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) +// require.NotEmpty(t, supportedChainsList) +// require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) +// }) +//} +// +//func TestKeeper_GetSupportedForeignChains(t *testing.T) { +// t.Run("return empty list if not chans are supported", func(t *testing.T) { +// k, ctx, _, _ := keepertest.ObserverKeeper(t) +// require.Empty(t, k.GetSupportedForeignChains(ctx)) +// }) +// +// t.Run("return list of supported chains", func(t *testing.T) { +// k, ctx, _, _ := keepertest.ObserverKeeper(t) +// chainList := chains.ExternalChainList([]chains.Chain{}) +// var chainParamsList types.ChainParamsList +// for _, chain := range chainList { +// chainParamsList.ChainParams = append( +// chainParamsList.ChainParams, +// sample.ChainParamsSupported(chain.ChainId), +// ) +// } +// k.SetChainParamsList(ctx, chainParamsList) +// +// supportedChainsList := k.GetSupportedForeignChains(ctx) +// require.NotEmpty(t, supportedChainsList) +// +// require.ElementsMatch(t, getAllForeignChains(), supportedChainsList) +// }) +//} + func getAllForeignChains() []chains.Chain { return chains.ExternalChainList([]chains.Chain{}) } From 66367e7fe12e2a22939c9679fe96d03b716acb95 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 11:18:53 -0400 Subject: [PATCH 07/14] move filter to pkg --- pkg/chains/chain_filters.go | 17 +++++++++++++++++ .../keeper/grpc_query_cctx_rate_limit.go | 2 +- x/crosschain/keeper/msg_server_update_tss.go | 5 +++-- x/crosschain/types/expected_keepers.go | 1 - x/observer/keeper/chain_params.go | 19 ------------------- 5 files changed, 21 insertions(+), 23 deletions(-) diff --git a/pkg/chains/chain_filters.go b/pkg/chains/chain_filters.go index 654868b224..192261248c 100644 --- a/pkg/chains/chain_filters.go +++ b/pkg/chains/chain_filters.go @@ -17,3 +17,20 @@ func FilterConsensusEthereum(c Chain) bool { func FilterConsensusBitcoin(c Chain) bool { return c.Consensus == Consensus_bitcoin } func FilterConsensusSolana(c Chain) bool { return c.Consensus == Consensus_solana_consensus } + +func FilterChains(chainList []Chain, filters ...ChainFilter) []Chain { + + // Apply each filter to the list of supported chains + for _, filter := range filters { + var filteredChains []Chain + for _, chain := range chainList { + if filter(chain) { + filteredChains = append(filteredChains, chain) + } + } + chainList = filteredChains + } + + // Return the filtered list of chains + return chainList +} diff --git a/x/crosschain/keeper/grpc_query_cctx_rate_limit.go b/x/crosschain/keeper/grpc_query_cctx_rate_limit.go index f2406c7a13..cf07ddabef 100644 --- a/x/crosschain/keeper/grpc_query_cctx_rate_limit.go +++ b/x/crosschain/keeper/grpc_query_cctx_rate_limit.go @@ -74,7 +74,7 @@ func (k Keeper) RateLimiterInput( } // get foreign chains and conversion rates of foreign coins - chains := k.zetaObserverKeeper.GetSupportedForeignChains(ctx) + chains := k.zetaObserverKeeper.FilterChains(ctx, chains.FilterExternalChains) _, assetRates, found := k.GetRateLimiterAssetRateList(ctx) if !found { return nil, status.Error(codes.Internal, "asset rates not found") diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index b53eeea885..6c770797fd 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -79,13 +79,14 @@ func (k msgServer) UpdateTssAddress( // 2. Gateway observer // 3. Consensus is bitcoin or ethereum (Other consensus types are not supported) func (k *Keeper) GetChainsSupportingTSSMigration(ctx sdk.Context) []chains.Chain { - evmChainsForTSSMigration := k.zetaObserverKeeper.FilterChains(ctx, []chains.ChainFilter{ + supportedChains := k.zetaObserverKeeper.GetSupportedChains(ctx) + evmChainsForTSSMigration := chains.FilterChains(supportedChains, []chains.ChainFilter{ chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusBitcoin, }...) - bitcoinChainsForTSSMigration := k.zetaObserverKeeper.FilterChains(ctx, []chains.ChainFilter{ + bitcoinChainsForTSSMigration := chains.FilterChains(supportedChains, []chains.ChainFilter{ chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusEthereum, diff --git a/x/crosschain/types/expected_keepers.go b/x/crosschain/types/expected_keepers.go index 3720c6bc8b..a256d4b2d9 100644 --- a/x/crosschain/types/expected_keepers.go +++ b/x/crosschain/types/expected_keepers.go @@ -104,7 +104,6 @@ type ObserverKeeper interface { GetSupportedChainFromChainID(ctx sdk.Context, chainID int64) (chains.Chain, bool) GetSupportedChains(ctx sdk.Context) []chains.Chain GetSupportedForeignChains(ctx sdk.Context) []chains.Chain - GetSupportedForeignChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain FilterChains(ctx sdk.Context, filters ...chains.ChainFilter) []chains.Chain } diff --git a/x/observer/keeper/chain_params.go b/x/observer/keeper/chain_params.go index a3578b50e5..1c2cf4e7c4 100644 --- a/x/observer/keeper/chain_params.go +++ b/x/observer/keeper/chain_params.go @@ -76,25 +76,6 @@ func (k Keeper) GetSupportedChains(ctx sdk.Context) []chains.Chain { return c } -func (k Keeper) FilterChains(ctx sdk.Context, filters ...chains.ChainFilter) []chains.Chain { - // Retrieve all supported chains - supportedChains := k.GetSupportedChains(ctx) - - // Apply each filter to the list of supported chains - for _, filter := range filters { - var filteredChains []chains.Chain - for _, chain := range supportedChains { - if filter(chain) { - filteredChains = append(filteredChains, chain) - } - } - supportedChains = filteredChains - } - - // Return the filtered list of chains - return supportedChains -} - // GetSupportedChainsByConsensus returns the list of supported chains by consensus func (k Keeper) GetSupportedForeignChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain { allChains := k.GetSupportedChains(ctx) From 12e267dd27e177561f8db9fdb5bb5dc1130fc4cf Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 12:01:17 -0400 Subject: [PATCH 08/14] add more unit tests --- pkg/chains/chain_filters.go | 25 +- pkg/chains/chain_filters_test.go | 187 +++++++++++++ testutil/keeper/mocks/crosschain/observer.go | 67 ----- .../keeper/grpc_query_cctx_rate_limit.go | 9 +- x/crosschain/keeper/msg_server_update_tss.go | 23 +- .../keeper/msg_server_update_tss_test.go | 5 +- x/crosschain/types/expected_keepers.go | 2 - x/observer/keeper/chain_params.go | 26 -- x/observer/keeper/chain_params_test.go | 264 ------------------ 9 files changed, 231 insertions(+), 377 deletions(-) create mode 100644 pkg/chains/chain_filters_test.go diff --git a/pkg/chains/chain_filters.go b/pkg/chains/chain_filters.go index 192261248c..c46a9ec475 100644 --- a/pkg/chains/chain_filters.go +++ b/pkg/chains/chain_filters.go @@ -1,25 +1,31 @@ package chains +// ChainFilter is a function that filters chains based on some criteria type ChainFilter func(c Chain) bool +// FilterExternalChains filters chains that are external func FilterExternalChains(c Chain) bool { return c.IsExternal } +// FilterGatewayObserver filters chains that are gateway observers func FilterGatewayObserver(c Chain) bool { return c.CctxGateway == CCTXGateway_observers } +// FilterConsensusEthereum filters chains that have the ethereum consensus func FilterConsensusEthereum(c Chain) bool { return c.Consensus == Consensus_ethereum } +// FilterConsensusBitcoin filters chains that have the bitcoin consensus func FilterConsensusBitcoin(c Chain) bool { return c.Consensus == Consensus_bitcoin } +// FilterConsensusSolana filters chains that have the solana consensus func FilterConsensusSolana(c Chain) bool { return c.Consensus == Consensus_solana_consensus } +// FilterChains applies a list of filters to a list of chains func FilterChains(chainList []Chain, filters ...ChainFilter) []Chain { - // Apply each filter to the list of supported chains for _, filter := range filters { var filteredChains []Chain @@ -34,3 +40,20 @@ func FilterChains(chainList []Chain, filters ...ChainFilter) []Chain { // Return the filtered list of chains return chainList } + +func CombineFilterChains(chainLists ...[]Chain) []Chain { + chainMap := make(map[Chain]bool) + var combinedChains []Chain + + // Add chains from each slice to remove duplicates + for _, chains := range chainLists { + for _, chain := range chains { + if !chainMap[chain] { + chainMap[chain] = true + combinedChains = append(combinedChains, chain) + } + } + } + + return combinedChains +} diff --git a/pkg/chains/chain_filters_test.go b/pkg/chains/chain_filters_test.go new file mode 100644 index 0000000000..b296f68b87 --- /dev/null +++ b/pkg/chains/chain_filters_test.go @@ -0,0 +1,187 @@ +package chains_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" +) + +func TestFilterChains(t *testing.T) { + tt := []struct { + name string + filters []chains.ChainFilter + expected func() []chains.Chain + }{ + { + name: "Filter external chains", + filters: []chains.ChainFilter{chains.FilterExternalChains}, + expected: func() []chains.Chain { + return chains.ExternalChainList([]chains.Chain{}) + }, + }, + { + name: "Filter gateway observer chains", + filters: []chains.ChainFilter{chains.FilterGatewayObserver}, + expected: func() []chains.Chain { + return chains.ChainListByGateway(chains.CCTXGateway_observers, []chains.Chain{}) + }, + }, + { + name: "Filter consensus ethereum chains", + filters: []chains.ChainFilter{chains.FilterConsensusEthereum}, + expected: func() []chains.Chain { + return chains.ChainListByConsensus(chains.Consensus_ethereum, []chains.Chain{}) + }, + }, + { + name: "Filter consensus bitcoin chains", + filters: []chains.ChainFilter{chains.FilterConsensusBitcoin}, + expected: func() []chains.Chain { + return chains.ChainListByConsensus(chains.Consensus_bitcoin, []chains.Chain{}) + }, + }, + { + name: "Filter consensus solana chains", + filters: []chains.ChainFilter{chains.FilterConsensusSolana}, + expected: func() []chains.Chain { + return chains.ChainListByConsensus(chains.Consensus_solana_consensus, []chains.Chain{}) + }, + }, + { + name: "Apply multiple filters external chains and gateway observer", + filters: []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver}, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + var gatewayObserverChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers { + gatewayObserverChains = append(gatewayObserverChains, chain) + } + } + return gatewayObserverChains + }, + }, + { + name: "Apply multiple filters external chains with gateway observer and consensus ethereum", + filters: []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusEthereum, + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers && + chain.Consensus == chains.Consensus_ethereum { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + { + name: "Apply multiple filters external chains with gateway observer and consensus bitcoin", + filters: []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusBitcoin, + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers && + chain.Consensus == chains.Consensus_bitcoin { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + { + name: "Test multiple filters in random order", + filters: []chains.ChainFilter{ + chains.FilterGatewayObserver, + chains.FilterConsensusEthereum, + chains.FilterExternalChains, + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers && + chain.Consensus == chains.Consensus_ethereum { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + chainList := chains.ExternalChainList([]chains.Chain{}) + filteredChains := chains.FilterChains(chainList, tc.filters...) + require.ElementsMatch(t, tc.expected(), filteredChains) + }) + } +} + +func TestCombineFilterChains(t *testing.T) { + tt := []struct { + name string + chainLists func() [][]chains.Chain + expected func() []chains.Chain + }{ + { + name: "test support TSS migration filter", + chainLists: func() [][]chains.Chain { + return [][]chains.Chain{ + chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusEthereum}...), + chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusBitcoin}...), + } + }, + expected: func() []chains.Chain { + chainList := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range chainList { + if chain.CctxGateway == chains.CCTXGateway_observers && + (chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin) { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + { + name: "test support TSS migration filter with solana", + chainLists: func() [][]chains.Chain { + return [][]chains.Chain{ + chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusEthereum}...), + chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusBitcoin}...), + chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusSolana}...), + } + }, + expected: func() []chains.Chain { + chainList := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range chainList { + if chain.CctxGateway == chains.CCTXGateway_observers && + (chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin || chain.Consensus == chains.Consensus_solana_consensus) { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + chainLists := tc.chainLists() + combinedChains := chains.CombineFilterChains(chainLists...) + require.ElementsMatch(t, tc.expected(), combinedChains) + }) + } +} diff --git a/testutil/keeper/mocks/crosschain/observer.go b/testutil/keeper/mocks/crosschain/observer.go index 189603fca0..c90c15c3a6 100644 --- a/testutil/keeper/mocks/crosschain/observer.go +++ b/testutil/keeper/mocks/crosschain/observer.go @@ -109,33 +109,6 @@ func (_m *CrosschainObserverKeeper) CheckIfTssPubkeyHasBeenGenerated(ctx types.C return r0, r1 } -// FilterChains provides a mock function with given fields: ctx, filters -func (_m *CrosschainObserverKeeper) FilterChains(ctx types.Context, filters ...chains.ChainFilter) []chains.Chain { - _va := make([]interface{}, len(filters)) - for _i := range filters { - _va[_i] = filters[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for FilterChains") - } - - var r0 []chains.Chain - if rf, ok := ret.Get(0).(func(types.Context, ...chains.ChainFilter) []chains.Chain); ok { - r0 = rf(ctx, filters...) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]chains.Chain) - } - } - - return r0 -} - // FindBallot provides a mock function with given fields: ctx, index, chain, observationType func (_m *CrosschainObserverKeeper) FindBallot(ctx types.Context, index string, chain chains.Chain, observationType observertypes.ObservationType) (observertypes.Ballot, bool, error) { ret := _m.Called(ctx, index, chain, observationType) @@ -631,46 +604,6 @@ func (_m *CrosschainObserverKeeper) GetSupportedChains(ctx types.Context) []chai return r0 } -// GetSupportedForeignChains provides a mock function with given fields: ctx -func (_m *CrosschainObserverKeeper) GetSupportedForeignChains(ctx types.Context) []chains.Chain { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for GetSupportedForeignChains") - } - - var r0 []chains.Chain - if rf, ok := ret.Get(0).(func(types.Context) []chains.Chain); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]chains.Chain) - } - } - - return r0 -} - -// GetSupportedForeignChainsByConsensus provides a mock function with given fields: ctx, consensus -func (_m *CrosschainObserverKeeper) GetSupportedForeignChainsByConsensus(ctx types.Context, consensus chains.Consensus) []chains.Chain { - ret := _m.Called(ctx, consensus) - - if len(ret) == 0 { - panic("no return value specified for GetSupportedForeignChainsByConsensus") - } - - var r0 []chains.Chain - if rf, ok := ret.Get(0).(func(types.Context, chains.Consensus) []chains.Chain); ok { - r0 = rf(ctx, consensus) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]chains.Chain) - } - } - - return r0 -} - // GetTSS provides a mock function with given fields: ctx func (_m *CrosschainObserverKeeper) GetTSS(ctx types.Context) (observertypes.TSS, bool) { ret := _m.Called(ctx) diff --git a/x/crosschain/keeper/grpc_query_cctx_rate_limit.go b/x/crosschain/keeper/grpc_query_cctx_rate_limit.go index cf07ddabef..ce4af73732 100644 --- a/x/crosschain/keeper/grpc_query_cctx_rate_limit.go +++ b/x/crosschain/keeper/grpc_query_cctx_rate_limit.go @@ -74,7 +74,8 @@ func (k Keeper) RateLimiterInput( } // get foreign chains and conversion rates of foreign coins - chains := k.zetaObserverKeeper.FilterChains(ctx, chains.FilterExternalChains) + externalSupportedChains := chains.FilterChains(k.GetObserverKeeper().GetSupportedChains(ctx), chains.FilterExternalChains) + _, assetRates, found := k.GetRateLimiterAssetRateList(ctx) if !found { return nil, status.Error(codes.Internal, "asset rates not found") @@ -84,7 +85,7 @@ func (k Keeper) RateLimiterInput( // query pending nonces of each foreign chain and get the lowest height of the pending cctxs lowestPendingCctxHeight := int64(0) pendingNoncesMap := make(map[int64]observertypes.PendingNonces) - for _, chain := range chains { + for _, chain := range externalSupportedChains { pendingNonces, found := k.GetObserverKeeper().GetPendingNonces(ctx, tss.TssPubkey, chain.ChainId) if !found { return nil, status.Error(codes.Internal, "pending nonces not found") @@ -113,7 +114,7 @@ func (k Keeper) RateLimiterInput( cctxsPending := make([]*types.CrossChainTx, 0) // query backwards for pending cctxs of each foreign chain - for _, chain := range chains { + for _, chain := range externalSupportedChains { // we should at least query 1000 prior to find any pending cctx that we might have missed // this logic is needed because a confirmation of higher nonce will automatically update the p.NonceLow // therefore might mask some lower nonce cctx that is still pending. @@ -205,7 +206,7 @@ func (k Keeper) ListPendingCctxWithinRateLimit( totalPending := uint64(0) totalWithdrawInAzeta := sdkmath.NewInt(0) cctxs := make([]*types.CrossChainTx, 0) - foreignChains := k.zetaObserverKeeper.GetSupportedForeignChains(ctx) + foreignChains := chains.FilterChains(k.zetaObserverKeeper.GetSupportedChains(ctx), chains.FilterExternalChains) // check rate limit flags to decide if we should apply rate limit applyLimit := true diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index 6c770797fd..63da5d65eb 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -80,17 +80,16 @@ func (k msgServer) UpdateTssAddress( // 3. Consensus is bitcoin or ethereum (Other consensus types are not supported) func (k *Keeper) GetChainsSupportingTSSMigration(ctx sdk.Context) []chains.Chain { supportedChains := k.zetaObserverKeeper.GetSupportedChains(ctx) - evmChainsForTSSMigration := chains.FilterChains(supportedChains, []chains.ChainFilter{ - chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusBitcoin, + return chains.CombineFilterChains([][]chains.Chain{ + chains.FilterChains(supportedChains, []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusEthereum, + }...), + chains.FilterChains(supportedChains, []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusBitcoin, + }...), }...) - - bitcoinChainsForTSSMigration := chains.FilterChains(supportedChains, []chains.ChainFilter{ - chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusEthereum, - }...) - - return append(evmChainsForTSSMigration, bitcoinChainsForTSSMigration...) } diff --git a/x/crosschain/keeper/msg_server_update_tss_test.go b/x/crosschain/keeper/msg_server_update_tss_test.go index faadffbc5f..7334e5ade8 100644 --- a/x/crosschain/keeper/msg_server_update_tss_test.go +++ b/x/crosschain/keeper/msg_server_update_tss_test.go @@ -337,7 +337,10 @@ func TestKeeper_GetChainsSupportingMigration(t *testing.T) { chainList := chains.ExternalChainList([]chains.Chain{}) var chainParamsList types.ChainParamsList for _, chain := range chainList { - chainParamsList.ChainParams = append(chainParamsList.ChainParams, sample.ChainParamsSupported(chain.ChainId)) + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) } zk.ObserverKeeper.SetChainParamsList(ctx, chainParamsList) diff --git a/x/crosschain/types/expected_keepers.go b/x/crosschain/types/expected_keepers.go index a256d4b2d9..1422538b2f 100644 --- a/x/crosschain/types/expected_keepers.go +++ b/x/crosschain/types/expected_keepers.go @@ -103,8 +103,6 @@ type ObserverKeeper interface { ) (bool, bool, observertypes.Ballot, string, error) GetSupportedChainFromChainID(ctx sdk.Context, chainID int64) (chains.Chain, bool) GetSupportedChains(ctx sdk.Context) []chains.Chain - GetSupportedForeignChains(ctx sdk.Context) []chains.Chain - FilterChains(ctx sdk.Context, filters ...chains.ChainFilter) []chains.Chain } type FungibleKeeper interface { diff --git a/x/observer/keeper/chain_params.go b/x/observer/keeper/chain_params.go index 1c2cf4e7c4..9dc862dd98 100644 --- a/x/observer/keeper/chain_params.go +++ b/x/observer/keeper/chain_params.go @@ -75,29 +75,3 @@ func (k Keeper) GetSupportedChains(ctx sdk.Context) []chains.Chain { } return c } - -// GetSupportedChainsByConsensus returns the list of supported chains by consensus -func (k Keeper) GetSupportedForeignChainsByConsensus(ctx sdk.Context, consensus chains.Consensus) []chains.Chain { - allChains := k.GetSupportedChains(ctx) - - foreignChains := make([]chains.Chain, 0) - for _, chain := range allChains { - if !chain.IsZetaChain() && chain.GetConsensus() == consensus { - foreignChains = append(foreignChains, chain) - } - } - return foreignChains -} - -// GetSupportedForeignChains returns the list of supported foreign chains -func (k Keeper) GetSupportedForeignChains(ctx sdk.Context) []chains.Chain { - allChains := k.GetSupportedChains(ctx) - - foreignChains := make([]chains.Chain, 0) - for _, chain := range allChains { - if !chain.IsZetaChain() { - foreignChains = append(foreignChains, chain) - } - } - return foreignChains -} diff --git a/x/observer/keeper/chain_params_test.go b/x/observer/keeper/chain_params_test.go index a9a46abb60..733fafc0b8 100644 --- a/x/observer/keeper/chain_params_test.go +++ b/x/observer/keeper/chain_params_test.go @@ -110,267 +110,3 @@ func TestKeeper_GetSupportedChains(t *testing.T) { require.EqualValues(t, supported4.ChainId, supportedChains[3].ChainId) }) } - -func TestKeeper_FilterChains(t *testing.T) { - t.Run("Filter external chains", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - - chainList := chains.ExternalChainList([]chains.Chain{}) - var chainParamsList types.ChainParamsList - for _, chain := range chainList { - chainParamsList.ChainParams = append( - chainParamsList.ChainParams, - sample.ChainParamsSupported(chain.ChainId), - ) - } - k.SetChainParamsList(ctx, chainParamsList) - - filteredChains := k.FilterChains(ctx, chains.FilterExternalChains) - require.ElementsMatch(t, chains.ExternalChainList([]chains.Chain{}), filteredChains) - }) - - t.Run("Filter gateway observer chains", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - - chainList := chains.ExternalChainList([]chains.Chain{}) - var chainParamsList types.ChainParamsList - for _, chain := range chainList { - chainParamsList.ChainParams = append( - chainParamsList.ChainParams, - sample.ChainParamsSupported(chain.ChainId), - ) - } - k.SetChainParamsList(ctx, chainParamsList) - - filteredChains := k.FilterChains(ctx, chains.FilterGatewayObserver) - require.ElementsMatch(t, chains.ChainListByGateway(chains.CCTXGateway_observers, []chains.Chain{}), filteredChains) - }) - - t.Run("Filter consensus ethereum chains", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - - chainList := chains.ExternalChainList([]chains.Chain{}) - var chainParamsList types.ChainParamsList - for _, chain := range chainList { - chainParamsList.ChainParams = append( - chainParamsList.ChainParams, - sample.ChainParamsSupported(chain.ChainId), - ) - } - k.SetChainParamsList(ctx, chainParamsList) - - filteredChains := k.FilterChains(ctx, chains.FilterConsensusEthereum) - require.ElementsMatch(t, chains.ChainListByConsensus(chains.Consensus_ethereum, []chains.Chain{}), filteredChains) - }) - - t.Run("Filter consensus bitcoin chains", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - - chainList := chains.ExternalChainList([]chains.Chain{}) - var chainParamsList types.ChainParamsList - for _, chain := range chainList { - chainParamsList.ChainParams = append( - chainParamsList.ChainParams, - sample.ChainParamsSupported(chain.ChainId), - ) - } - k.SetChainParamsList(ctx, chainParamsList) - - filteredChains := k.FilterChains(ctx, chains.FilterConsensusBitcoin) - require.ElementsMatch(t, chains.ChainListByConsensus(chains.Consensus_bitcoin, []chains.Chain{}), filteredChains) - }) - - t.Run("Filter consensus solana chains", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - - chainList := chains.ExternalChainList([]chains.Chain{}) - var chainParamsList types.ChainParamsList - for _, chain := range chainList { - chainParamsList.ChainParams = append( - chainParamsList.ChainParams, - sample.ChainParamsSupported(chain.ChainId), - ) - } - k.SetChainParamsList(ctx, chainParamsList) - - filteredChains := k.FilterChains(ctx, chains.FilterConsensusSolana) - require.ElementsMatch(t, chains.ChainListByConsensus(chains.Consensus_solana_consensus, []chains.Chain{}), filteredChains) - }) - - t.Run("Apply multiple filters external chains with gateway observer", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - - chainList := chains.ExternalChainList([]chains.Chain{}) - var chainParamsList types.ChainParamsList - for _, chain := range chainList { - chainParamsList.ChainParams = append( - chainParamsList.ChainParams, - sample.ChainParamsSupported(chain.ChainId), - ) - } - k.SetChainParamsList(ctx, chainParamsList) - - filteredChains := k.FilterChains(ctx, chains.FilterExternalChains, chains.FilterGatewayObserver) - externalChains := chains.ExternalChainList([]chains.Chain{}) - var gatewayObserverChains []chains.Chain - for _, chain := range externalChains { - if chain.CctxGateway == chains.CCTXGateway_observers { - gatewayObserverChains = append(gatewayObserverChains, chain) - } - } - require.ElementsMatch(t, gatewayObserverChains, filteredChains) - }) - - t.Run("Apply multiple filters external chains with gateway observer and consensus ethereum and bitcoin", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - - chainList := chains.ExternalChainList([]chains.Chain{}) - var chainParamsList types.ChainParamsList - for _, chain := range chainList { - chainParamsList.ChainParams = append( - chainParamsList.ChainParams, - sample.ChainParamsSupported(chain.ChainId), - ) - } - k.SetChainParamsList(ctx, chainParamsList) - - filteredChainsEVM := k.FilterChains(ctx, chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusEthereum) - filteredChainsBitcoin := k.FilterChains(ctx, chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusBitcoin) - externalChains := chains.ExternalChainList([]chains.Chain{}) - var filterMultipleChains []chains.Chain - for _, chain := range externalChains { - if chain.CctxGateway == chains.CCTXGateway_observers && (chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin) { - filterMultipleChains = append(filterMultipleChains, chain) - } - } - require.ElementsMatch(t, filterMultipleChains, append(filteredChainsEVM, filteredChainsBitcoin...)) - }) - - t.Run("Apply multiple filters external chains with gateway observer and consensus ethereum and bitcoin in different order", func(t *testing.T) { - k, ctx, _, _ := keepertest.ObserverKeeper(t) - - chainList := chains.ExternalChainList([]chains.Chain{}) - var chainParamsList types.ChainParamsList - for _, chain := range chainList { - chainParamsList.ChainParams = append( - chainParamsList.ChainParams, - sample.ChainParamsSupported(chain.ChainId), - ) - } - k.SetChainParamsList(ctx, chainParamsList) - - filteredChainsEVM := k.FilterChains(ctx, chains.FilterGatewayObserver, chains.FilterConsensusEthereum, chains.FilterExternalChains) - filteredChainsBitcoin := k.FilterChains(ctx, chains.FilterExternalChains, chains.FilterConsensusBitcoin, chains.FilterGatewayObserver) - externalChains := chains.ExternalChainList([]chains.Chain{}) - var filterMultipleChains []chains.Chain - for _, chain := range externalChains { - if chain.CctxGateway == chains.CCTXGateway_observers && (chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin) { - filterMultipleChains = append(filterMultipleChains, chain) - } - } - require.ElementsMatch(t, filterMultipleChains, append(filteredChainsEVM, filteredChainsBitcoin...)) - }) -} - -//func TestKeeper_GetSupportedForeignChainsByConsensus(t *testing.T) { -// t.Run("return empty list if not chans are supported", func(t *testing.T) { -// k, ctx, _, _ := keepertest.ObserverKeeper(t) -// require.Empty(t, k.GetSupportedForeignChainsByConsensus(ctx, chains.Consensus_ethereum)) -// }) -// -// t.Run("return list of supported chains for ethereum consensus", func(t *testing.T) { -// k, ctx, _, _ := keepertest.ObserverKeeper(t) -// chainList := chains.ExternalChainList([]chains.Chain{}) -// var chainParamsList types.ChainParamsList -// for _, chain := range chainList { -// chainParamsList.ChainParams = append( -// chainParamsList.ChainParams, -// sample.ChainParamsSupported(chain.ChainId), -// ) -// } -// k.SetChainParamsList(ctx, chainParamsList) -// consensus := chains.Consensus_ethereum -// -// supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) -// require.NotEmpty(t, supportedChainsList) -// -// require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) -// }) -// -// t.Run("return list of supported chains for bitcoin consensus", func(t *testing.T) { -// k, ctx, _, _ := keepertest.ObserverKeeper(t) -// chainList := chains.ExternalChainList([]chains.Chain{}) -// var chainParamsList types.ChainParamsList -// for _, chain := range chainList { -// chainParamsList.ChainParams = append( -// chainParamsList.ChainParams, -// sample.ChainParamsSupported(chain.ChainId), -// ) -// } -// k.SetChainParamsList(ctx, chainParamsList) -// consensus := chains.Consensus_bitcoin -// -// supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) -// require.NotEmpty(t, supportedChainsList) -// require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) -// }) -// -// t.Run("return list of supported chains for solana consensus", func(t *testing.T) { -// k, ctx, _, _ := keepertest.ObserverKeeper(t) -// chainList := chains.ExternalChainList([]chains.Chain{}) -// var chainParamsList types.ChainParamsList -// for _, chain := range chainList { -// chainParamsList.ChainParams = append( -// chainParamsList.ChainParams, -// sample.ChainParamsSupported(chain.ChainId), -// ) -// } -// k.SetChainParamsList(ctx, chainParamsList) -// consensus := chains.Consensus_solana_consensus -// -// supportedChainsList := k.GetSupportedForeignChainsByConsensus(ctx, consensus) -// require.NotEmpty(t, supportedChainsList) -// require.ElementsMatch(t, getForeignChains(consensus), supportedChainsList) -// }) -//} -// -//func TestKeeper_GetSupportedForeignChains(t *testing.T) { -// t.Run("return empty list if not chans are supported", func(t *testing.T) { -// k, ctx, _, _ := keepertest.ObserverKeeper(t) -// require.Empty(t, k.GetSupportedForeignChains(ctx)) -// }) -// -// t.Run("return list of supported chains", func(t *testing.T) { -// k, ctx, _, _ := keepertest.ObserverKeeper(t) -// chainList := chains.ExternalChainList([]chains.Chain{}) -// var chainParamsList types.ChainParamsList -// for _, chain := range chainList { -// chainParamsList.ChainParams = append( -// chainParamsList.ChainParams, -// sample.ChainParamsSupported(chain.ChainId), -// ) -// } -// k.SetChainParamsList(ctx, chainParamsList) -// -// supportedChainsList := k.GetSupportedForeignChains(ctx) -// require.NotEmpty(t, supportedChainsList) -// -// require.ElementsMatch(t, getAllForeignChains(), supportedChainsList) -// }) -//} - -func getAllForeignChains() []chains.Chain { - return chains.ExternalChainList([]chains.Chain{}) -} - -func getForeignChains(consensus chains.Consensus) []chains.Chain { - evmChains := chains.ChainListByConsensus(consensus, []chains.Chain{}) - foreignEvmChains := make([]chains.Chain, 0) - - for _, chain := range evmChains { - if !chain.IsZetaChain() { - foreignEvmChains = append(foreignEvmChains, chain) - } - } - return foreignEvmChains -} From a7e184fb8ac8dfd242c327f4f7bd8784fcb6190f Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 12:04:45 -0400 Subject: [PATCH 09/14] format code --- pkg/chains/chain_filters_test.go | 40 ++++++++++++++++--- .../keeper/grpc_query_cctx_rate_limit.go | 5 ++- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/pkg/chains/chain_filters_test.go b/pkg/chains/chain_filters_test.go index b296f68b87..dcb96ad948 100644 --- a/pkg/chains/chain_filters_test.go +++ b/pkg/chains/chain_filters_test.go @@ -139,8 +139,20 @@ func TestCombineFilterChains(t *testing.T) { name: "test support TSS migration filter", chainLists: func() [][]chains.Chain { return [][]chains.Chain{ - chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusEthereum}...), - chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusBitcoin}...), + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusEthereum, + }...), + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusBitcoin, + }...), } }, expected: func() []chains.Chain { @@ -159,9 +171,27 @@ func TestCombineFilterChains(t *testing.T) { name: "test support TSS migration filter with solana", chainLists: func() [][]chains.Chain { return [][]chains.Chain{ - chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusEthereum}...), - chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusBitcoin}...), - chains.FilterChains(chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver, chains.FilterConsensusSolana}...), + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusEthereum, + }...), + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusBitcoin, + }...), + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterGatewayObserver, + chains.FilterConsensusSolana, + }...), } }, expected: func() []chains.Chain { diff --git a/x/crosschain/keeper/grpc_query_cctx_rate_limit.go b/x/crosschain/keeper/grpc_query_cctx_rate_limit.go index ce4af73732..fa97806ee2 100644 --- a/x/crosschain/keeper/grpc_query_cctx_rate_limit.go +++ b/x/crosschain/keeper/grpc_query_cctx_rate_limit.go @@ -74,7 +74,10 @@ func (k Keeper) RateLimiterInput( } // get foreign chains and conversion rates of foreign coins - externalSupportedChains := chains.FilterChains(k.GetObserverKeeper().GetSupportedChains(ctx), chains.FilterExternalChains) + externalSupportedChains := chains.FilterChains( + k.GetObserverKeeper().GetSupportedChains(ctx), + chains.FilterExternalChains, + ) _, assetRates, found := k.GetRateLimiterAssetRateList(ctx) if !found { From 7f8193686580038556efc65a6dde057e6d40c7fb Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 12:15:38 -0400 Subject: [PATCH 10/14] add unit test for ChainListByGateway --- pkg/chains/chain_filters.go | 1 + pkg/chains/chains_test.go | 48 +++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/pkg/chains/chain_filters.go b/pkg/chains/chain_filters.go index c46a9ec475..d99c81d62b 100644 --- a/pkg/chains/chain_filters.go +++ b/pkg/chains/chain_filters.go @@ -41,6 +41,7 @@ func FilterChains(chainList []Chain, filters ...ChainFilter) []Chain { return chainList } +// CombineFilterChains combines multiple lists of chains into a single list func CombineFilterChains(chainLists ...[]Chain) []Chain { chainMap := make(map[Chain]bool) var combinedChains []Chain diff --git a/pkg/chains/chains_test.go b/pkg/chains/chains_test.go index fe947b11be..01c8ef5abf 100644 --- a/pkg/chains/chains_test.go +++ b/pkg/chains/chains_test.go @@ -151,6 +151,54 @@ func TestDefaultChainList(t *testing.T) { }, chains.DefaultChainsList()) } +func TestChainListByGateway(t *testing.T) { + listTests := []struct { + name string + gateway chains.CCTXGateway + expected []chains.Chain + }{ + { + "observers", + chains.CCTXGateway_observers, + []chains.Chain{ + chains.BitcoinMainnet, + chains.BscMainnet, + chains.Ethereum, + chains.BitcoinTestnet, + chains.Mumbai, + chains.Amoy, + chains.BscTestnet, + chains.Goerli, + chains.Sepolia, + chains.BitcoinRegtest, + chains.GoerliLocalnet, + chains.Polygon, + chains.OptimismMainnet, + chains.OptimismSepolia, + chains.BaseMainnet, + chains.BaseSepolia, + chains.SolanaMainnet, + chains.SolanaDevnet, + chains.SolanaLocalnet, + }, + }, + { + "zevm", + chains.CCTXGateway_zevm, + []chains.Chain{ + chains.ZetaChainMainnet, + chains.ZetaChainTestnet, + }, + }, + } + + for _, lt := range listTests { + t.Run(lt.name, func(t *testing.T) { + require.ElementsMatch(t, lt.expected, chains.ChainListByGateway(lt.gateway, []chains.Chain{})) + }) + } +} + func TestExternalChainList(t *testing.T) { require.ElementsMatch(t, []chains.Chain{ chains.BitcoinMainnet, From 1d3fe50323fc3eed64cb11ff7e65fa4b87a1c3c4 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 12:21:39 -0400 Subject: [PATCH 11/14] rename tests --- x/crosschain/keeper/msg_server_update_tss_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x/crosschain/keeper/msg_server_update_tss_test.go b/x/crosschain/keeper/msg_server_update_tss_test.go index 7334e5ade8..75ba70f331 100644 --- a/x/crosschain/keeper/msg_server_update_tss_test.go +++ b/x/crosschain/keeper/msg_server_update_tss_test.go @@ -331,7 +331,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { }) } -func TestKeeper_GetChainsSupportingMigration(t *testing.T) { +func TestKeeper_GetChainsSupportingTSSMigration(t *testing.T) { t.Run("should return only ethereum and bitcoin chains", func(t *testing.T) { k, ctx, _, zk := keepertest.CrosschainKeeperWithMocks(t, keepertest.CrosschainMockOptions{}) chainList := chains.ExternalChainList([]chains.Chain{}) From 54d32d9cb2f4cf9184950b01bc1fb38bc421fde8 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 13:14:46 -0400 Subject: [PATCH 12/14] add test case with repeated filters --- pkg/chains/chain_filters_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pkg/chains/chain_filters_test.go b/pkg/chains/chain_filters_test.go index dcb96ad948..7d5e980708 100644 --- a/pkg/chains/chain_filters_test.go +++ b/pkg/chains/chain_filters_test.go @@ -100,6 +100,18 @@ func TestFilterChains(t *testing.T) { return filterMultipleChains }, }, + { + name: "test three same filters", + filters: []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterExternalChains, + chains.FilterExternalChains, + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + return externalChains + }, + }, { name: "Test multiple filters in random order", filters: []chains.ChainFilter{ @@ -125,6 +137,7 @@ func TestFilterChains(t *testing.T) { chainList := chains.ExternalChainList([]chains.Chain{}) filteredChains := chains.FilterChains(chainList, tc.filters...) require.ElementsMatch(t, tc.expected(), filteredChains) + require.Len(t, filteredChains, len(tc.expected())) }) } } From 88a47ea2b1105249ae793de57c72b69fcd5dabe7 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 14:09:14 -0400 Subject: [PATCH 13/14] add FilterByConsensus and FilterByGateway functions --- pkg/chains/chain_filters.go | 18 +++------ pkg/chains/chain_filters_test.go | 42 ++++++++++---------- x/crosschain/keeper/msg_server_update_tss.go | 8 ++-- 3 files changed, 31 insertions(+), 37 deletions(-) diff --git a/pkg/chains/chain_filters.go b/pkg/chains/chain_filters.go index d99c81d62b..3235fb98b1 100644 --- a/pkg/chains/chain_filters.go +++ b/pkg/chains/chain_filters.go @@ -8,22 +8,16 @@ func FilterExternalChains(c Chain) bool { return c.IsExternal } -// FilterGatewayObserver filters chains that are gateway observers -func FilterGatewayObserver(c Chain) bool { - return c.CctxGateway == CCTXGateway_observers +// FilterByGateway filters chains by gateway +func FilterByGateway(gw CCTXGateway) ChainFilter { + return func(chain Chain) bool { return chain.CctxGateway == gw } } -// FilterConsensusEthereum filters chains that have the ethereum consensus -func FilterConsensusEthereum(c Chain) bool { - return c.Consensus == Consensus_ethereum +// FilterByConsensus filters chains by consensus type +func FilterByConsensus(cs Consensus) ChainFilter { + return func(chain Chain) bool { return chain.Consensus == cs } } -// FilterConsensusBitcoin filters chains that have the bitcoin consensus -func FilterConsensusBitcoin(c Chain) bool { return c.Consensus == Consensus_bitcoin } - -// FilterConsensusSolana filters chains that have the solana consensus -func FilterConsensusSolana(c Chain) bool { return c.Consensus == Consensus_solana_consensus } - // FilterChains applies a list of filters to a list of chains func FilterChains(chainList []Chain, filters ...ChainFilter) []Chain { // Apply each filter to the list of supported chains diff --git a/pkg/chains/chain_filters_test.go b/pkg/chains/chain_filters_test.go index 7d5e980708..b3d555cee7 100644 --- a/pkg/chains/chain_filters_test.go +++ b/pkg/chains/chain_filters_test.go @@ -22,35 +22,35 @@ func TestFilterChains(t *testing.T) { }, { name: "Filter gateway observer chains", - filters: []chains.ChainFilter{chains.FilterGatewayObserver}, + filters: []chains.ChainFilter{chains.FilterByGateway(chains.CCTXGateway_observers)}, expected: func() []chains.Chain { return chains.ChainListByGateway(chains.CCTXGateway_observers, []chains.Chain{}) }, }, { name: "Filter consensus ethereum chains", - filters: []chains.ChainFilter{chains.FilterConsensusEthereum}, + filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_ethereum)}, expected: func() []chains.Chain { return chains.ChainListByConsensus(chains.Consensus_ethereum, []chains.Chain{}) }, }, { name: "Filter consensus bitcoin chains", - filters: []chains.ChainFilter{chains.FilterConsensusBitcoin}, + filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_bitcoin)}, expected: func() []chains.Chain { return chains.ChainListByConsensus(chains.Consensus_bitcoin, []chains.Chain{}) }, }, { name: "Filter consensus solana chains", - filters: []chains.ChainFilter{chains.FilterConsensusSolana}, + filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_solana_consensus)}, expected: func() []chains.Chain { return chains.ChainListByConsensus(chains.Consensus_solana_consensus, []chains.Chain{}) }, }, { name: "Apply multiple filters external chains and gateway observer", - filters: []chains.ChainFilter{chains.FilterExternalChains, chains.FilterGatewayObserver}, + filters: []chains.ChainFilter{chains.FilterExternalChains, chains.FilterByGateway(chains.CCTXGateway_observers)}, expected: func() []chains.Chain { externalChains := chains.ExternalChainList([]chains.Chain{}) var gatewayObserverChains []chains.Chain @@ -66,8 +66,8 @@ func TestFilterChains(t *testing.T) { name: "Apply multiple filters external chains with gateway observer and consensus ethereum", filters: []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusEthereum, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), }, expected: func() []chains.Chain { externalChains := chains.ExternalChainList([]chains.Chain{}) @@ -85,8 +85,8 @@ func TestFilterChains(t *testing.T) { name: "Apply multiple filters external chains with gateway observer and consensus bitcoin", filters: []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusBitcoin, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_bitcoin), }, expected: func() []chains.Chain { externalChains := chains.ExternalChainList([]chains.Chain{}) @@ -115,8 +115,8 @@ func TestFilterChains(t *testing.T) { { name: "Test multiple filters in random order", filters: []chains.ChainFilter{ - chains.FilterGatewayObserver, - chains.FilterConsensusEthereum, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), chains.FilterExternalChains, }, expected: func() []chains.Chain { @@ -156,15 +156,15 @@ func TestCombineFilterChains(t *testing.T) { chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusEthereum, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), }...), chains.FilterChains( chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusBitcoin, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_bitcoin), }...), } }, @@ -188,22 +188,22 @@ func TestCombineFilterChains(t *testing.T) { chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusEthereum, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), }...), chains.FilterChains( chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusBitcoin, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_bitcoin), }...), chains.FilterChains( chains.ExternalChainList([]chains.Chain{}), []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusSolana, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_solana_consensus), }...), } }, diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index 63da5d65eb..f68c7e39d6 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -83,13 +83,13 @@ func (k *Keeper) GetChainsSupportingTSSMigration(ctx sdk.Context) []chains.Chain return chains.CombineFilterChains([][]chains.Chain{ chains.FilterChains(supportedChains, []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusEthereum, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), }...), chains.FilterChains(supportedChains, []chains.ChainFilter{ chains.FilterExternalChains, - chains.FilterGatewayObserver, - chains.FilterConsensusBitcoin, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_bitcoin), }...), }...) } From 72f617fb12f6a88561d226457fba810ccdc05aa7 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 26 Jul 2024 14:38:59 -0400 Subject: [PATCH 14/14] format code 1 --- pkg/chains/chain_filters_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/chains/chain_filters_test.go b/pkg/chains/chain_filters_test.go index b3d555cee7..619d555882 100644 --- a/pkg/chains/chain_filters_test.go +++ b/pkg/chains/chain_filters_test.go @@ -49,8 +49,11 @@ func TestFilterChains(t *testing.T) { }, }, { - name: "Apply multiple filters external chains and gateway observer", - filters: []chains.ChainFilter{chains.FilterExternalChains, chains.FilterByGateway(chains.CCTXGateway_observers)}, + name: "Apply multiple filters external chains and gateway observer", + filters: []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + }, expected: func() []chains.Chain { externalChains := chains.ExternalChainList([]chains.Chain{}) var gatewayObserverChains []chains.Chain