Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: migrator length check to use consensus type #2556

Merged
merged 15 commits into from
Jul 26, 2024
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
* [2515](https://github.com/zeta-chain/node/pull/2515) - replace chainName by chainID for ChainNonces indexing
* [2541](https://github.com/zeta-chain/node/pull/2541) - deprecate ChainName field in Chain object
* [2542](https://github.com/zeta-chain/node/pull/2542) - adjust permissions to be more restrictive
* [2556](https://github.com/zeta-chain/node/pull/2556) - refactor migrator length check to use consensus type

### Tests

Expand Down
54 changes: 54 additions & 0 deletions pkg/chains/chain_filters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package chains

// ChainFilter is a function that filters chains based on some criteria
type ChainFilter func(c Chain) bool

// FilterExternalChains filters chains that are external
func FilterExternalChains(c Chain) bool {
return c.IsExternal
}

// FilterByGateway filters chains by gateway
func FilterByGateway(gw CCTXGateway) ChainFilter {
return func(chain Chain) bool { return chain.CctxGateway == gw }
}

// FilterByConsensus filters chains by consensus type
func FilterByConsensus(cs Consensus) ChainFilter {
return func(chain Chain) bool { return chain.Consensus == cs }
}

// FilterChains applies a list of filters to a list of chains
func FilterChains(chainList []Chain, filters ...ChainFilter) []Chain {
// Apply each filter to the list of supported chains
for _, filter := range filters {
var filteredChains []Chain
for _, chain := range chainList {
if filter(chain) {
filteredChains = append(filteredChains, chain)
}
}
chainList = filteredChains
}

// Return the filtered list of chains
return chainList
}

// CombineFilterChains combines multiple lists of chains into a single list
func CombineFilterChains(chainLists ...[]Chain) []Chain {
chainMap := make(map[Chain]bool)
var combinedChains []Chain

// Add chains from each slice to remove duplicates
for _, chains := range chainLists {
for _, chain := range chains {
if !chainMap[chain] {
chainMap[chain] = true
combinedChains = append(combinedChains, chain)
}
}
}

return combinedChains
}
233 changes: 233 additions & 0 deletions pkg/chains/chain_filters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
package chains_test

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/zeta-chain/zetacore/pkg/chains"
)

func TestFilterChains(t *testing.T) {
tt := []struct {
name string
filters []chains.ChainFilter
expected func() []chains.Chain
}{
{
name: "Filter external chains",
filters: []chains.ChainFilter{chains.FilterExternalChains},
expected: func() []chains.Chain {
return chains.ExternalChainList([]chains.Chain{})
},
},
{
name: "Filter gateway observer chains",
filters: []chains.ChainFilter{chains.FilterByGateway(chains.CCTXGateway_observers)},
expected: func() []chains.Chain {
return chains.ChainListByGateway(chains.CCTXGateway_observers, []chains.Chain{})
},
},
{
name: "Filter consensus ethereum chains",
filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_ethereum)},
expected: func() []chains.Chain {
return chains.ChainListByConsensus(chains.Consensus_ethereum, []chains.Chain{})
},
},
{
name: "Filter consensus bitcoin chains",
filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_bitcoin)},
expected: func() []chains.Chain {
return chains.ChainListByConsensus(chains.Consensus_bitcoin, []chains.Chain{})
},
},
{
name: "Filter consensus solana chains",
filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_solana_consensus)},
expected: func() []chains.Chain {
return chains.ChainListByConsensus(chains.Consensus_solana_consensus, []chains.Chain{})
},
},
{
name: "Apply multiple filters external chains and gateway observer",
filters: []chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterByGateway(chains.CCTXGateway_observers),
},
expected: func() []chains.Chain {
externalChains := chains.ExternalChainList([]chains.Chain{})
var gatewayObserverChains []chains.Chain
for _, chain := range externalChains {
if chain.CctxGateway == chains.CCTXGateway_observers {
gatewayObserverChains = append(gatewayObserverChains, chain)
}
}
return gatewayObserverChains
},
},
{
name: "Apply multiple filters external chains with gateway observer and consensus ethereum",
filters: []chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterByGateway(chains.CCTXGateway_observers),
chains.FilterByConsensus(chains.Consensus_ethereum),
},
expected: func() []chains.Chain {
externalChains := chains.ExternalChainList([]chains.Chain{})
var filterMultipleChains []chains.Chain
for _, chain := range externalChains {
if chain.CctxGateway == chains.CCTXGateway_observers &&
chain.Consensus == chains.Consensus_ethereum {
filterMultipleChains = append(filterMultipleChains, chain)
}
}
return filterMultipleChains
},
},
{
name: "Apply multiple filters external chains with gateway observer and consensus bitcoin",
filters: []chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterByGateway(chains.CCTXGateway_observers),
chains.FilterByConsensus(chains.Consensus_bitcoin),
},
expected: func() []chains.Chain {
externalChains := chains.ExternalChainList([]chains.Chain{})
var filterMultipleChains []chains.Chain
for _, chain := range externalChains {
if chain.CctxGateway == chains.CCTXGateway_observers &&
chain.Consensus == chains.Consensus_bitcoin {
filterMultipleChains = append(filterMultipleChains, chain)
}
}
return filterMultipleChains
},
},
{
name: "test three same filters",
filters: []chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterExternalChains,
chains.FilterExternalChains,
},
expected: func() []chains.Chain {
externalChains := chains.ExternalChainList([]chains.Chain{})
return externalChains
},
},
{
name: "Test multiple filters in random order",
filters: []chains.ChainFilter{
chains.FilterByGateway(chains.CCTXGateway_observers),
chains.FilterByConsensus(chains.Consensus_ethereum),
chains.FilterExternalChains,
},
expected: func() []chains.Chain {
externalChains := chains.ExternalChainList([]chains.Chain{})
var filterMultipleChains []chains.Chain
for _, chain := range externalChains {
if chain.CctxGateway == chains.CCTXGateway_observers &&
chain.Consensus == chains.Consensus_ethereum {
filterMultipleChains = append(filterMultipleChains, chain)
}
}
return filterMultipleChains
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
chainList := chains.ExternalChainList([]chains.Chain{})
filteredChains := chains.FilterChains(chainList, tc.filters...)
require.ElementsMatch(t, tc.expected(), filteredChains)
require.Len(t, filteredChains, len(tc.expected()))
})
}
}

func TestCombineFilterChains(t *testing.T) {
tt := []struct {
name string
chainLists func() [][]chains.Chain
expected func() []chains.Chain
}{
{
name: "test support TSS migration filter",
chainLists: func() [][]chains.Chain {
return [][]chains.Chain{
chains.FilterChains(
chains.ExternalChainList([]chains.Chain{}),
[]chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterByGateway(chains.CCTXGateway_observers),
chains.FilterByConsensus(chains.Consensus_ethereum),
}...),
chains.FilterChains(
chains.ExternalChainList([]chains.Chain{}),
[]chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterByGateway(chains.CCTXGateway_observers),
chains.FilterByConsensus(chains.Consensus_bitcoin),
}...),
}
},
expected: func() []chains.Chain {
chainList := chains.ExternalChainList([]chains.Chain{})
var filterMultipleChains []chains.Chain
for _, chain := range chainList {
if chain.CctxGateway == chains.CCTXGateway_observers &&
(chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin) {
filterMultipleChains = append(filterMultipleChains, chain)
}
}
return filterMultipleChains
},
},
{
name: "test support TSS migration filter with solana",
chainLists: func() [][]chains.Chain {
return [][]chains.Chain{
chains.FilterChains(
chains.ExternalChainList([]chains.Chain{}),
[]chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterByGateway(chains.CCTXGateway_observers),
chains.FilterByConsensus(chains.Consensus_ethereum),
}...),
chains.FilterChains(
chains.ExternalChainList([]chains.Chain{}),
[]chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterByGateway(chains.CCTXGateway_observers),
chains.FilterByConsensus(chains.Consensus_bitcoin),
}...),
chains.FilterChains(
chains.ExternalChainList([]chains.Chain{}),
[]chains.ChainFilter{
chains.FilterExternalChains,
chains.FilterByGateway(chains.CCTXGateway_observers),
chains.FilterByConsensus(chains.Consensus_solana_consensus),
}...),
}
},
expected: func() []chains.Chain {
chainList := chains.ExternalChainList([]chains.Chain{})
var filterMultipleChains []chains.Chain
for _, chain := range chainList {
if chain.CctxGateway == chains.CCTXGateway_observers &&
(chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin || chain.Consensus == chains.Consensus_solana_consensus) {
filterMultipleChains = append(filterMultipleChains, chain)
}
}
return filterMultipleChains
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
chainLists := tc.chainLists()
combinedChains := chains.CombineFilterChains(chainLists...)
require.ElementsMatch(t, tc.expected(), combinedChains)
})
}
}
10 changes: 10 additions & 0 deletions pkg/chains/chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,16 @@ func ChainListByConsensus(consensus Consensus, additionalChains []Chain) []Chain
return chainList
}

func ChainListByGateway(gateway CCTXGateway, additionalChains []Chain) []Chain {
var chainList []Chain
for _, chain := range CombineDefaultChainsList(additionalChains) {
if chain.CctxGateway == gateway {
chainList = append(chainList, chain)
}
}
return chainList
}

// ChainListForHeaderSupport returns a list of chains that support headers
func ChainListForHeaderSupport(additionalChains []Chain) []Chain {
var chainList []Chain
Expand Down
50 changes: 50 additions & 0 deletions pkg/chains/chains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,56 @@ func TestDefaultChainList(t *testing.T) {
}, chains.DefaultChainsList())
}

func TestChainListByGateway(t *testing.T) {
listTests := []struct {
name string
gateway chains.CCTXGateway
expected []chains.Chain
}{
{
"observers",
chains.CCTXGateway_observers,
[]chains.Chain{
chains.BitcoinMainnet,
chains.BscMainnet,
chains.Ethereum,
chains.BitcoinTestnet,
chains.Mumbai,
chains.Amoy,
chains.BscTestnet,
chains.Goerli,
chains.Sepolia,
chains.BitcoinRegtest,
chains.GoerliLocalnet,
chains.Polygon,
chains.OptimismMainnet,
chains.OptimismSepolia,
chains.BaseMainnet,
chains.BaseSepolia,
chains.SolanaMainnet,
chains.SolanaDevnet,
chains.SolanaLocalnet,
},
},
{
"zevm",
chains.CCTXGateway_zevm,
[]chains.Chain{
chains.ZetaChainMainnet,
chains.ZetaChainTestnet,
chains.ZetaChainDevnet,
chains.ZetaChainPrivnet,
},
},
}

for _, lt := range listTests {
t.Run(lt.name, func(t *testing.T) {
require.ElementsMatch(t, lt.expected, chains.ChainListByGateway(lt.gateway, []chains.Chain{}))
})
}
}

func TestExternalChainList(t *testing.T) {
require.ElementsMatch(t, []chains.Chain{
chains.BitcoinMainnet,
Expand Down
20 changes: 0 additions & 20 deletions testutil/keeper/mocks/crosschain/observer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading