Skip to content

Commit

Permalink
change return type to err in verifypredicate
Browse files Browse the repository at this point in the history
  • Loading branch information
ceyonur committed Dec 5, 2023
1 parent 6ff9578 commit b1ff9db
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 50 deletions.
2 changes: 1 addition & 1 deletion core/predicate_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func CheckPredicates(rules params.Rules, predicateContext *precompileconfig.Pred
predicaterContract := rules.Predicaters[address]
bitset := set.NewBits()
for i, predicate := range predicates {
if !predicaterContract.VerifyPredicate(predicateContext, predicate) {
if err := predicaterContract.VerifyPredicate(predicateContext, predicate); err != nil {
bitset.Add(i)
}
}
Expand Down
13 changes: 7 additions & 6 deletions core/predicate_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func TestCheckPredicate(t *testing.T) {
predicater := precompileconfig.NewMockPredicater(gomock.NewController(t))
arg := common.Hash{1}
predicater.EXPECT().PredicateGas(arg[:]).Return(uint64(0), nil).Times(2)
predicater.EXPECT().VerifyPredicate(gomock.Any(), arg[:]).Return(true)
predicater.EXPECT().VerifyPredicate(gomock.Any(), arg[:]).Return(nil)
return map[common.Address]precompileconfig.Predicater{
addr1: predicater,
}
Expand Down Expand Up @@ -187,7 +187,7 @@ func TestCheckPredicate(t *testing.T) {
predicater := precompileconfig.NewMockPredicater(gomock.NewController(t))
arg := common.Hash{1}
predicater.EXPECT().PredicateGas(arg[:]).Return(uint64(0), nil).Times(2)
predicater.EXPECT().VerifyPredicate(gomock.Any(), arg[:]).Return(true)
predicater.EXPECT().VerifyPredicate(gomock.Any(), arg[:]).Return(nil)
return map[common.Address]precompileconfig.Predicater{
addr1: predicater,
addr2: predicater,
Expand All @@ -214,11 +214,11 @@ func TestCheckPredicate(t *testing.T) {
predicate1 := precompileconfig.NewMockPredicater(ctrl)
arg1 := common.Hash{1}
predicate1.EXPECT().PredicateGas(arg1[:]).Return(uint64(0), nil).Times(2)
predicate1.EXPECT().VerifyPredicate(gomock.Any(), arg1[:]).Return(true)
predicate1.EXPECT().VerifyPredicate(gomock.Any(), arg1[:]).Return(nil)
predicate2 := precompileconfig.NewMockPredicater(ctrl)
arg2 := common.Hash{2}
predicate2.EXPECT().PredicateGas(arg2[:]).Return(uint64(0), nil).Times(2)
predicate2.EXPECT().VerifyPredicate(gomock.Any(), arg2[:]).Return(false)
predicate2.EXPECT().VerifyPredicate(gomock.Any(), arg2[:]).Return(testErr)
return map[common.Address]precompileconfig.Predicater{
addr1: predicate1,
addr2: predicate2,
Expand Down Expand Up @@ -323,6 +323,7 @@ func TestCheckPredicate(t *testing.T) {
}

func TestCheckPredicatesOutput(t *testing.T) {
testErr := errors.New("test error")
addr1 := common.HexToAddress("0xaa")
addr2 := common.HexToAddress("0xbb")
validHash := common.Hash{1}
Expand Down Expand Up @@ -431,10 +432,10 @@ func TestCheckPredicatesOutput(t *testing.T) {
var predicateHash common.Hash
if tuple.isValidPredicate {
predicateHash = validHash
predicater.EXPECT().VerifyPredicate(gomock.Any(), validHash[:]).Return(true)
predicater.EXPECT().VerifyPredicate(gomock.Any(), validHash[:]).Return(nil)
} else {
predicateHash = invalidHash
predicater.EXPECT().VerifyPredicate(gomock.Any(), invalidHash[:]).Return(false)
predicater.EXPECT().VerifyPredicate(gomock.Any(), invalidHash[:]).Return(testErr)
}
txAccessList = append(txAccessList, types.AccessTuple{
Address: tuple.address,
Expand Down
2 changes: 1 addition & 1 deletion precompile/precompileconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type PredicateContext struct {
// rely on this. Designed for use only by precompiles that ship with subnet-evm.
type Predicater interface {
PredicateGas(predicateBytes []byte) (uint64, error)
VerifyPredicate(predicateContext *PredicateContext, predicateBytes []byte) bool
VerifyPredicate(predicateContext *PredicateContext, predicateBytes []byte) error
}

// SharedMemoryWriter defines an interface to allow a precompile's Accepter to write operations
Expand Down
4 changes: 2 additions & 2 deletions precompile/precompileconfig/mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions precompile/testutils/test_predicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type PredicateTest struct {
PredicateBytes []byte
Gas uint64
GasErr error
ExpectedRes bool
ExpectedErr error
}

func (test PredicateTest) Run(t testing.TB) {
Expand All @@ -37,7 +37,7 @@ func (test PredicateTest) Run(t testing.TB) {
require.Equal(test.Gas, predicateGas)

predicateRes := predicate.VerifyPredicate(test.PredicateContext, test.PredicateBytes)
require.Equal(test.ExpectedRes, predicateRes)
require.ErrorIs(predicateRes, test.ExpectedErr)
}

func RunPredicateTests(t *testing.T, predicateTests map[string]PredicateTest) {
Expand Down
60 changes: 29 additions & 31 deletions x/warp/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ var (
errOverflowSignersGasCost = errors.New("overflow calculating warp signers gas cost")
errInvalidPredicateBytes = errors.New("cannot unpack predicate bytes")
errInvalidWarpMsg = errors.New("cannot unpack warp message")
errCannotParseWarpMsg = errors.New("cannot parse warp message")
errInvalidWarpMsgPayload = errors.New("cannot unpack warp message payload")
errInvalidAddressedPayload = errors.New("cannot unpack addressed payload")
errInvalidBlockHashPayload = errors.New("cannot unpack block hash payload")
errCannotGetNumSigners = errors.New("cannot fetch num signers from warp message")
errWarpCannotBeActivated = errors.New("warp cannot be activated before DUpgrade")
errFailedVerification = errors.New("cannot verify warp signature")
)

// Config implements the precompileconfig.Config interface and
Expand Down Expand Up @@ -124,32 +126,6 @@ func (c *Config) Accept(acceptCtx *precompileconfig.AcceptContext, blockHash com
return nil
}

// verifyWarpMessage checks that [warpMsg] can be parsed as an addressed payload and verifies the Warp Message Signature
// within [predicateContext].
func (c *Config) verifyWarpMessage(predicateContext *precompileconfig.PredicateContext, warpMsg *warp.Message) bool {
// Use default quorum numerator unless config specifies a non-default option
quorumNumerator := params.WarpDefaultQuorumNumerator
if c.QuorumNumerator != 0 {
quorumNumerator = c.QuorumNumerator
}

log.Debug("verifying warp message", "warpMsg", warpMsg, "quorumNum", quorumNumerator, "quorumDenom", params.WarpQuorumDenominator)
if err := warpMsg.Signature.Verify(
context.Background(),
&warpMsg.UnsignedMessage,
predicateContext.SnowCtx.NetworkID,
warpValidators.NewState(predicateContext.SnowCtx), // Wrap validators.State on the chain snow context to special case the Primary Network
predicateContext.ProposerVMBlockCtx.PChainHeight,
quorumNumerator,
params.WarpQuorumDenominator,
); err != nil {
log.Debug("failed to verify warp signature", "msgID", warpMsg.ID(), "err", err)
return false
}

return true
}

// PredicateGas returns the amount of gas necessary to verify the predicate
// PredicateGas charges for:
// 1. Base cost of the message
Expand Down Expand Up @@ -199,16 +175,38 @@ func (c *Config) PredicateGas(predicateBytes []byte) (uint64, error) {
}

// VerifyPredicate returns whether the predicate described by [predicateBytes] passes verification.
func (c *Config) VerifyPredicate(predicateContext *precompileconfig.PredicateContext, predicateBytes []byte) bool {
func (c *Config) VerifyPredicate(predicateContext *precompileconfig.PredicateContext, predicateBytes []byte) error {
unpackedPredicateBytes, err := predicate.UnpackPredicate(predicateBytes)
if err != nil {
return false
return fmt.Errorf("%w: %w", errInvalidPredicateBytes, err)
}

// Note: PredicateGas should be called before VerifyPredicate, so we should never reach an error case here.
warpMessage, err := warp.ParseMessage(unpackedPredicateBytes)
warpMsg, err := warp.ParseMessage(unpackedPredicateBytes)
if err != nil {
return false
return fmt.Errorf("%w: %w", errCannotParseWarpMsg, err)
}
return c.verifyWarpMessage(predicateContext, warpMessage)

quorumNumerator := params.WarpDefaultQuorumNumerator
if c.QuorumNumerator != 0 {
quorumNumerator = c.QuorumNumerator
}

log.Debug("verifying warp message", "warpMsg", warpMsg, "quorumNum", quorumNumerator, "quorumDenom", params.WarpQuorumDenominator)
err = warpMsg.Signature.Verify(
context.Background(),
&warpMsg.UnsignedMessage,
predicateContext.SnowCtx.NetworkID,
warpValidators.NewState(predicateContext.SnowCtx), // Wrap validators.State on the chain snow context to special case the Primary Network
predicateContext.ProposerVMBlockCtx.PChainHeight,
quorumNumerator,
params.WarpQuorumDenominator,
)

if err != nil {
log.Debug("failed to verify warp signature", "msgID", warpMsg.ID(), "err", err)
return fmt.Errorf("%w: %w", errFailedVerification, err)
}

return nil
}
27 changes: 20 additions & 7 deletions x/warp/predicate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func createValidPredicateTest(snowCtx *snow.Context, numKeys uint64, predicateBy
PredicateBytes: predicateBytes,
Gas: GasCostPerSignatureVerification + uint64(len(predicateBytes))*GasCostPerWarpMessageBytes + numKeys*GasCostPerWarpSigner,
GasErr: nil,
ExpectedRes: true,
ExpectedErr: nil,
}
}

Expand Down Expand Up @@ -289,7 +289,7 @@ func TestWarpMessageFromPrimaryNetwork(t *testing.T) {
PredicateBytes: predicateBytes,
Gas: GasCostPerSignatureVerification + uint64(len(predicateBytes))*GasCostPerWarpMessageBytes + uint64(numKeys)*GasCostPerWarpSigner,
GasErr: nil,
ExpectedRes: true,
ExpectedErr: nil,
}

test.Run(t)
Expand Down Expand Up @@ -466,7 +466,12 @@ func TestWarpSignatureWeightsDefaultQuorumNumerator(t *testing.T) {
} {
predicateBytes := createPredicate(numSigners)
// The predicate is valid iff the number of signers is >= the required numerator and does not exceed the denominator.
isValid := numSigners >= int(params.WarpDefaultQuorumNumerator) && numSigners <= int(params.WarpQuorumDenominator)
var expectedErr error
if numSigners >= int(params.WarpDefaultQuorumNumerator) && numSigners <= int(params.WarpQuorumDenominator) {
expectedErr = nil
} else {
expectedErr = errFailedVerification
}

tests[fmt.Sprintf("default quorum %d signature(s)", numSigners)] = testutils.PredicateTest{
Config: NewDefaultConfig(subnetEVMUtils.NewUint64(0)),
Expand All @@ -479,7 +484,7 @@ func TestWarpSignatureWeightsDefaultQuorumNumerator(t *testing.T) {
PredicateBytes: predicateBytes,
Gas: GasCostPerSignatureVerification + uint64(len(predicateBytes))*GasCostPerWarpMessageBytes + uint64(numSigners)*GasCostPerWarpSigner,
GasErr: nil,
ExpectedRes: isValid,
ExpectedErr: expectedErr,
}
}
testutils.RunPredicateTests(t, tests)
Expand Down Expand Up @@ -514,13 +519,16 @@ func TestWarpMultiplePredicates(t *testing.T) {
var (
predicate []byte
expectedGas uint64
expectedErr error
)
if valid {
predicate = validPredicateBytes
expectedGas = GasCostPerSignatureVerification + uint64(len(validPredicateBytes))*GasCostPerWarpMessageBytes + uint64(numSigners)*GasCostPerWarpSigner
expectedErr = nil
} else {
expectedGas = GasCostPerSignatureVerification + uint64(len(invalidPredicateBytes))*GasCostPerWarpMessageBytes + uint64(1)*GasCostPerWarpSigner
predicate = invalidPredicateBytes
expectedErr = errFailedVerification
}

tests[fmt.Sprintf("multiple predicates %v", validMessageIndices)] = testutils.PredicateTest{
Expand All @@ -534,7 +542,7 @@ func TestWarpMultiplePredicates(t *testing.T) {
PredicateBytes: predicate,
Gas: expectedGas,
GasErr: nil,
ExpectedRes: valid,
ExpectedErr: expectedErr,
}
}
}
Expand All @@ -559,7 +567,12 @@ func TestWarpSignatureWeightsNonDefaultQuorumNumerator(t *testing.T) {
for _, numSigners := range []int{nonDefaultQuorumNumerator, nonDefaultQuorumNumerator + 1, 99, 100, 101} {
predicateBytes := createPredicate(numSigners)
// The predicate is valid iff the number of signers is >= the required numerator and does not exceed the denominator.
isValid := numSigners >= nonDefaultQuorumNumerator && numSigners <= int(params.WarpQuorumDenominator)
var expectedErr error
if numSigners >= nonDefaultQuorumNumerator && numSigners <= int(params.WarpQuorumDenominator) {
expectedErr = nil
} else {
expectedErr = errFailedVerification
}

name := fmt.Sprintf("non-default quorum %d signature(s)", numSigners)
tests[name] = testutils.PredicateTest{
Expand All @@ -573,7 +586,7 @@ func TestWarpSignatureWeightsNonDefaultQuorumNumerator(t *testing.T) {
PredicateBytes: predicateBytes,
Gas: GasCostPerSignatureVerification + uint64(len(predicateBytes))*GasCostPerWarpMessageBytes + uint64(numSigners)*GasCostPerWarpSigner,
GasErr: nil,
ExpectedRes: isValid,
ExpectedErr: expectedErr,
}
}

Expand Down

0 comments on commit b1ff9db

Please sign in to comment.