Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce extensible payloads #1667

Merged
merged 3 commits into from
Jan 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ require (
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/golang-lru v0.5.4
github.com/mr-tron/base58 v1.1.2
github.com/nspcc-dev/dbft v0.0.0-20201221101812-e13a1a1c3cb2
github.com/nspcc-dev/dbft v0.0.0-20210122071512-d9a728094f0d
github.com/nspcc-dev/rfc6979 v0.2.0
github.com/pierrec/lz4 v2.5.2+incompatible
github.com/prometheus/client_golang v1.2.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a h1:ajvxgEe9qY4vvoSm
github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a/go.mod h1:/YFK+XOxxg0Bfm6P92lY5eDSLYfp06XOdL8KAVgXjVk=
github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1 h1:yEx9WznS+rjE0jl0dLujCxuZSIb+UTjF+005TJu/nNI=
github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1/go.mod h1:O0qtn62prQSqizzoagHmuuKoz8QMkU3SzBoKdEvm3aQ=
github.com/nspcc-dev/dbft v0.0.0-20201221101812-e13a1a1c3cb2 h1:vbPjd6xbX8w61abcNfzUvSI7WT0QeS9fHWp1Mocv9N0=
github.com/nspcc-dev/dbft v0.0.0-20201221101812-e13a1a1c3cb2/go.mod h1:I5D0W3tu3epdt2RMCTxS//HDr4S+OHRqajouQTOAHI8=
github.com/nspcc-dev/dbft v0.0.0-20210122071512-d9a728094f0d h1:uUaRysqa/9VtHETVARUlteqfbXAgwxR2nvUc4DzK4pI=
github.com/nspcc-dev/dbft v0.0.0-20210122071512-d9a728094f0d/go.mod h1:I5D0W3tu3epdt2RMCTxS//HDr4S+OHRqajouQTOAHI8=
github.com/nspcc-dev/neo-go v0.73.1-pre.0.20200303142215-f5a1b928ce09/go.mod h1:pPYwPZ2ks+uMnlRLUyXOpLieaDQSEaf4NM3zHVbRjmg=
github.com/nspcc-dev/neofs-crypto v0.2.0 h1:ftN+59WqxSWz/RCgXYOfhmltOOqU+udsNQSvN6wkFck=
github.com/nspcc-dev/neofs-crypto v0.2.0/go.mod h1:F/96fUzPM3wR+UGsPi3faVNmFlA9KAEAUQR7dMxZmNA=
Expand Down
2 changes: 1 addition & 1 deletion pkg/consensus/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ func getDifferentPayloads(t *testing.T, n int) (payloads []Payload) {
var sign [signatureSize]byte
random.Fill(sign[:])

payloads[i].message = &message{}
payloads[i].SetValidatorIndex(uint16(i))
payloads[i].SetType(payload.MessageType(commitType))
payloads[i].payload = &commit{
signature: sign,
}
payloads[i].encodeData()
}

return
Expand Down
97 changes: 72 additions & 25 deletions pkg/consensus/consensus.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/io"
npayload "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/emit"
Expand All @@ -39,6 +40,9 @@ const defaultTimePerBlock = 15 * time.Second
// Number of nanoseconds in millisecond.
const nsInMs = 1000000

// Category is message category for extensible payloads.
const Category = "Consensus"

// Service represents consensus instance.
type Service interface {
// Start initializes dBFT and starts event loop for consensus service.
Expand All @@ -48,11 +52,11 @@ type Service interface {
Shutdown()

// OnPayload is a callback to notify Service about new received payload.
OnPayload(p *Payload)
OnPayload(p *npayload.Extensible)
// OnTransaction is a callback to notify Service about new received transaction.
OnTransaction(tx *transaction.Transaction)
// GetPayload returns Payload with specified hash if it is present in the local cache.
GetPayload(h util.Uint256) *Payload
GetPayload(h util.Uint256) *npayload.Extensible
}

type service struct {
Expand Down Expand Up @@ -94,7 +98,7 @@ type Config struct {
Logger *zap.Logger
// Broadcast is a callback which is called to notify server
// about new consensus payload to sent.
Broadcast func(p *Payload)
Broadcast func(p *npayload.Extensible)
// Chain is a core.Blockchainer instance.
Chain blockchainer.Blockchainer
// RequestTx is a callback to which will be called
Expand Down Expand Up @@ -204,15 +208,33 @@ var (
// NewPayload creates new consensus payload for the provided network.
func NewPayload(m netmode.Magic, stateRootEnabled bool) *Payload {
return &Payload{
network: m,
message: &message{
Extensible: npayload.Extensible{
Network: m,
Category: Category,
},
message: message{
stateRootEnabled: stateRootEnabled,
},
}
}

func (s *service) newPayload() payload.ConsensusPayload {
return NewPayload(s.network, s.stateRootEnabled)
func (s *service) newPayload(c *dbft.Context, t payload.MessageType, msg interface{}) payload.ConsensusPayload {
cp := NewPayload(s.network, s.stateRootEnabled)
cp.SetHeight(c.BlockIndex)
cp.SetValidatorIndex(uint16(c.MyIndex))
cp.SetViewNumber(c.ViewNumber)
cp.SetType(t)
if pr, ok := msg.(*prepareRequest); ok {
pr.SetPrevHash(s.dbft.PrevHash)
pr.SetVersion(s.dbft.Version)
}
cp.SetPayload(msg)

cp.Extensible.ValidBlockStart = 0
cp.Extensible.ValidBlockEnd = c.BlockIndex
cp.Extensible.Sender = c.Validators[c.MyIndex].(*publicKey).GetScriptHash()

return cp
}

func (s *service) newPrepareRequest() payload.PrepareRequest {
Expand Down Expand Up @@ -257,7 +279,7 @@ events:
s.dbft.OnTimeout(hv)
case msg := <-s.messages:
fields := []zap.Field{
zap.Uint8("from", msg.validatorIndex),
zap.Uint8("from", msg.message.ValidatorIndex),
zap.Stringer("type", msg.Type()),
}

Expand Down Expand Up @@ -312,14 +334,13 @@ func (s *service) handleChainBlock(b *coreb.Block) {

func (s *service) validatePayload(p *Payload) bool {
validators := s.getValidators()
if int(p.validatorIndex) >= len(validators) {
if int(p.message.ValidatorIndex) >= len(validators) {
return false
}

pub := validators[p.validatorIndex]
pub := validators[p.message.ValidatorIndex]
h := pub.(*publicKey).GetScriptHash()

return s.Chain.VerifyWitness(h, p, &p.Witness, payloadGasLimit) == nil
return p.Sender == h
}

func (s *service) getKeyPair(pubs []crypto.PublicKey) (int, crypto.PrivateKey, crypto.PublicKey) {
Expand All @@ -346,14 +367,27 @@ func (s *service) getKeyPair(pubs []crypto.PublicKey) (int, crypto.PrivateKey, c
return -1, nil, nil
}

func (s *service) payloadFromExtensible(ep *npayload.Extensible) *Payload {
return &Payload{
Extensible: *ep,
message: message{
stateRootEnabled: s.stateRootEnabled,
},
}
}

// OnPayload handles Payload receive.
func (s *service) OnPayload(cp *Payload) {
func (s *service) OnPayload(cp *npayload.Extensible) {
log := s.log.With(zap.Stringer("hash", cp.Hash()))
if s.cache.Has(cp.Hash()) {
log.Debug("payload is already in cache")
return
} else if !s.validatePayload(cp) {
log.Debug("can't validate payload")
}

p := s.payloadFromExtensible(cp)
p.decodeData()
if !s.validatePayload(p) {
log.Info("can't validate payload")
return
}

Expand All @@ -366,14 +400,14 @@ func (s *service) OnPayload(cp *Payload) {
}

// decode payload data into message
if cp.message.payload == nil {
if err := cp.decodeData(); err != nil {
log.Debug("can't decode payload data")
if p.message.payload == nil {
if err := p.decodeData(); err != nil {
log.Info("can't decode payload data")
return
}
}

s.messages <- *cp
s.messages <- *p
}

func (s *service) OnTransaction(tx *transaction.Transaction) {
Expand All @@ -383,13 +417,13 @@ func (s *service) OnTransaction(tx *transaction.Transaction) {
}

// GetPayload returns payload stored in cache.
func (s *service) GetPayload(h util.Uint256) *Payload {
func (s *service) GetPayload(h util.Uint256) *npayload.Extensible {
p := s.cache.Get(h)
if p == nil {
return (*Payload)(nil)
return (*npayload.Extensible)(nil)
}

cp := *p.(*Payload)
cp := *p.(*npayload.Extensible)

return &cp
}
Expand All @@ -399,8 +433,9 @@ func (s *service) broadcast(p payload.ConsensusPayload) {
s.log.Warn("can't sign consensus payload", zap.Error(err))
}

s.cache.Add(p)
s.Config.Broadcast(p.(*Payload))
ep := &p.(*Payload).Extensible
s.cache.Add(ep)
s.Config.Broadcast(ep)
}

func (s *service) getTx(h util.Uint256) block.Transaction {
Expand Down Expand Up @@ -479,14 +514,26 @@ func (s *service) verifyBlock(b block.Block) bool {
return true
}

var (
errInvalidPrevHash = errors.New("invalid PrevHash")
errInvalidVersion = errors.New("invalid Version")
errInvalidStateRoot = errors.New("state root mismatch")
)

func (s *service) verifyRequest(p payload.ConsensusPayload) error {
req := p.GetPrepareRequest().(*prepareRequest)
if req.prevHash != s.dbft.PrevHash {
return errInvalidPrevHash
}
if req.version != s.dbft.Version {
return errInvalidVersion
}
if s.stateRootEnabled {
sr, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1)
if err != nil {
return err
} else if sr.Root != req.stateRoot {
return fmt.Errorf("state root mismatch: %s != %s", sr.Root, req.stateRoot)
return fmt.Errorf("%w: %s != %s", errInvalidStateRoot, sr.Root, req.stateRoot)
}
}
// Save lastProposal for getVerified().
Expand Down
68 changes: 47 additions & 21 deletions pkg/consensus/consensus_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package consensus

import (
"errors"
"testing"
"time"

"github.com/nspcc-dev/dbft/block"
"github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/dbft/timer"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/internal/testchain"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
Expand All @@ -19,6 +21,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/io"
npayload "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/util"
Expand Down Expand Up @@ -180,11 +183,10 @@ func TestService_GetVerified(t *testing.T) {
// Everyone sends a message.
for i := 0; i < 4; i++ {
p := new(Payload)
p.message = &message{}
// One PrepareRequest and three ChangeViews.
if i == 1 {
p.SetType(payload.PrepareRequestType)
p.SetPayload(&prepareRequest{transactionHashes: hashes})
p.SetPayload(&prepareRequest{prevHash: srv.Chain.CurrentBlockHash(), transactionHashes: hashes})
} else {
p.SetType(payload.ChangeViewType)
p.SetPayload(&changeView{newViewNumber: 1, timestamp: uint64(time.Now().UnixNano() / nsInMs)})
Expand Down Expand Up @@ -224,8 +226,7 @@ func TestService_ValidatePayload(t *testing.T) {
srv := newTestService(t)
priv, _ := getTestValidator(1)
p := new(Payload)
p.message = &message{}

p.Sender = priv.GetScriptHash()
p.SetPayload(&prepareRequest{})

t.Run("invalid validator index", func(t *testing.T) {
Expand All @@ -243,8 +244,16 @@ func TestService_ValidatePayload(t *testing.T) {
require.False(t, srv.validatePayload(p))
})

t.Run("invalid sender", func(t *testing.T) {
p.SetValidatorIndex(1)
p.Sender = util.Uint160{}
require.NoError(t, p.Sign(priv))
require.False(t, srv.validatePayload(p))
})

t.Run("normal case", func(t *testing.T) {
p.SetValidatorIndex(1)
p.Sender = priv.GetScriptHash()
require.NoError(t, p.Sign(priv))
require.True(t, srv.validatePayload(p))
})
Expand Down Expand Up @@ -295,22 +304,35 @@ func TestService_PrepareRequest(t *testing.T) {

priv, _ := getTestValidator(1)
p := new(Payload)
p.message = &message{}
p.SetValidatorIndex(1)

p.SetPayload(&prepareRequest{})
require.NoError(t, p.Sign(priv))
require.Error(t, srv.verifyRequest(p), "invalid stateroot setting")
prevHash := srv.Chain.CurrentBlockHash()

p.SetPayload(&prepareRequest{stateRootEnabled: true})
require.NoError(t, p.Sign(priv))
require.Error(t, srv.verifyRequest(p), "invalid state root")
checkRequest := func(t *testing.T, expectedErr error, req *prepareRequest) {
p.SetPayload(req)
require.NoError(t, p.Sign(priv))
err := srv.verifyRequest(p)
if expectedErr == nil {
require.NoError(t, err)
return
}
require.True(t, errors.Is(err, expectedErr), "got: %v", err)
}

checkRequest(t, errInvalidVersion, &prepareRequest{version: 0xFF, prevHash: prevHash})
checkRequest(t, errInvalidPrevHash, &prepareRequest{prevHash: random.Uint256()})
checkRequest(t, errInvalidStateRoot, &prepareRequest{
stateRootEnabled: true,
prevHash: prevHash,
})

sr, err := srv.Chain.GetStateRoot(srv.dbft.BlockIndex - 1)
require.NoError(t, err)
p.SetPayload(&prepareRequest{stateRootEnabled: true, stateRoot: sr.Root})
require.NoError(t, p.Sign(priv))
require.NoError(t, srv.verifyRequest(p))
checkRequest(t, nil, &prepareRequest{
stateRootEnabled: true,
prevHash: prevHash,
stateRoot: sr.Root,
})
}

func TestService_OnPayload(t *testing.T) {
Expand All @@ -322,22 +344,26 @@ func TestService_OnPayload(t *testing.T) {

priv, _ := getTestValidator(1)
p := new(Payload)
p.message = &message{}
p.SetValidatorIndex(1)
p.SetPayload(&prepareRequest{})
p.encodeData()

// payload is not signed
srv.OnPayload(p)
// sender is invalid
srv.OnPayload(&p.Extensible)
shouldNotReceive(t, srv.messages)
require.Nil(t, srv.GetPayload(p.Hash()))

p = new(Payload)
p.SetValidatorIndex(1)
p.Sender = priv.GetScriptHash()
p.SetPayload(&prepareRequest{})
require.NoError(t, p.Sign(priv))
srv.OnPayload(p)
srv.OnPayload(&p.Extensible)
shouldReceive(t, srv.messages)
require.Equal(t, p, srv.GetPayload(p.Hash()))
require.Equal(t, &p.Extensible, srv.GetPayload(p.Hash()))

// payload has already been received
srv.OnPayload(p)
srv.OnPayload(&p.Extensible)
shouldNotReceive(t, srv.messages)
srv.Chain.Close()
}
Expand Down Expand Up @@ -453,7 +479,7 @@ func newTestService(t *testing.T) *service {
func newTestServiceWithChain(t *testing.T, bc *core.Blockchain) *service {
srv, err := NewService(Config{
Logger: zaptest.NewLogger(t),
Broadcast: func(*Payload) {},
Broadcast: func(*npayload.Extensible) {},
Chain: bc,
RequestTx: func(...util.Uint256) {},
TimePerBlock: time.Duration(bc.GetConfig().SecondsPerBlock) * time.Second,
Expand Down
Loading