diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index 7af4f7d992..7f27fe57c3 100644
--- a/accounts/abi/bind/backends/simulated.go
+++ b/accounts/abi/bind/backends/simulated.go
@@ -39,8 +39,10 @@ import (
"github.com/ethereum/go-ethereum/eth/filters"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
)
// This nil assignment ensures compile time that SimulatedBackend implements bind.ContractBackend.
@@ -541,3 +543,15 @@ func (fb *filterBackend) BloomStatus() (uint64, uint64) { return 4096, 0 }
func (fb *filterBackend) ServiceFilter(ctx context.Context, ms *bloombits.MatcherSession) {
panic("not supported")
}
+
+func (fb *filterBackend) AccountExtraDataStateGetterByNumber(context.Context, rpc.BlockNumber) (vm.AccountExtraDataStateGetter, error) {
+ panic("not supported")
+}
+
+func (fb *filterBackend) IsAuthorized(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, attributes ...*multitenancy.ContractSecurityAttribute) (bool, error) {
+ panic("not supported")
+}
+
+func (fb *filterBackend) SupportsMultitenancy(context.Context) (*proto.PreAuthenticatedAuthenticationToken, bool) {
+ panic("not supported")
+}
diff --git a/cmd/geth/main.go b/cmd/geth/main.go
index efd002af53..fe8fd1d2d9 100644
--- a/cmd/geth/main.go
+++ b/cmd/geth/main.go
@@ -42,6 +42,7 @@ import (
"github.com/ethereum/go-ethereum/les"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/permission"
"github.com/ethereum/go-ethereum/plugin"
@@ -169,6 +170,7 @@ var (
utils.PluginPublicKeyFlag,
utils.AllowedFutureBlockTimeFlag,
utils.EVMCallTimeOutFlag,
+ utils.MultitenancyFlag,
// End-Quorum
}
@@ -345,6 +347,8 @@ func geth(ctx *cli.Context) error {
// startNode boots up the system node and all registered protocols, after which
// it unlocks any requested accounts, and starts the RPC/IPC interfaces and the
// miner.
+// Quorum
+// - Enrich eth/les service with ContractAuthorizationProvider for multitenancy support if prequisites are met
func startNode(ctx *cli.Context, stack *node.Node) {
log.DoEmitCheckpoints = ctx.GlobalBool(utils.EmitCheckpointsFlag.Name)
debug.Memsize.Add("node", stack)
@@ -383,6 +387,11 @@ func startNode(ctx *cli.Context, stack *node.Node) {
}
ethClient := ethclient.NewClient(rpcClient)
+ var ethService *eth.Ethereum
+ if err := stack.Service(ðService); err != nil {
+ utils.Fatalf("Failed to retrieve ethereum service: %v", err)
+ }
+ setContractAuthzProviderFunc := ethService.SetContractAuthorizationProvider
// Set contract backend for ethereum service if local node
// is serving LES requests.
if ctx.GlobalInt(utils.LightLegacyServFlag.Name) > 0 || ctx.GlobalInt(utils.LightServeFlag.Name) > 0 {
@@ -400,6 +409,17 @@ func startNode(ctx *cli.Context, stack *node.Node) {
utils.Fatalf("Failed to retrieve light ethereum service: %v", err)
}
lesService.SetContractBackend(ethClient)
+ setContractAuthzProviderFunc = lesService.SetContractAuthorizationManager
+ }
+
+ // Set ContractAuthorizationProvider if multitenancy flag is on AND plugin security is configured
+ if ctx.GlobalBool(utils.MultitenancyFlag.Name) {
+ if stack.PluginManager().IsEnabled(plugin.SecurityPluginInterfaceName) {
+ log.Info("Node supports multitenancy")
+ setContractAuthzProviderFunc(&multitenancy.DefaultContractAuthorizationProvider{})
+ } else {
+ utils.Fatalf("multitenancy requires RPC Security Plugin to be configured")
+ }
}
go func() {
diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go
index 004dd37b12..04af6b1b57 100644
--- a/cmd/geth/usage.go
+++ b/cmd/geth/usage.go
@@ -281,6 +281,7 @@ var AppHelpFlagGroups = []flagGroup{
utils.PluginLocalVerifyFlag,
utils.PluginPublicKeyFlag,
utils.AllowedFutureBlockTimeFlag,
+ utils.MultitenancyFlag,
},
},
{
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index de672838ad..8b9d21f711 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -875,6 +875,11 @@ var (
Usage: "Default minimum difference between two consecutive block's timestamps in seconds",
Value: eth.DefaultConfig.Istanbul.BlockPeriod,
}
+ // Multitenancy setting
+ MultitenancyFlag = cli.BoolFlag{
+ Name: "multitenancy",
+ Usage: "Enable multitenancy support for this node. This requires RPC Security Plugin to also be configured.",
+ }
)
// MakeDataDir retrieves the currently requested data directory, terminating
@@ -1553,7 +1558,7 @@ func setRaft(ctx *cli.Context, cfg *eth.Config) {
func setQuorumConfig(ctx *cli.Context, cfg *eth.Config) {
cfg.EVMCallTimeOut = time.Duration(ctx.GlobalInt(EVMCallTimeOutFlag.Name)) * time.Second
-
+ cfg.EnableMultitenancy = ctx.GlobalBool(MultitenancyFlag.Name)
setIstanbul(ctx, cfg)
setRaft(ctx, cfg)
}
diff --git a/common/slice.go b/common/slice.go
new file mode 100644
index 0000000000..a75f7c6485
--- /dev/null
+++ b/common/slice.go
@@ -0,0 +1,54 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package common
+
+// ContainsAll returns true if all elements in the target are in the source,
+// false otherwise.
+func ContainsAll(source, target []string) bool {
+ mark := make(map[string]bool, len(source))
+ for _, str := range source {
+ mark[str] = true
+ }
+ for _, str := range target {
+ if _, found := mark[str]; !found {
+ return false
+ }
+ }
+ return true
+}
+
+// ContainsAll returns true if all elements in the target are NOT in the source,
+// false otherwise.
+func NotContainsAll(source, target []string) bool {
+ return !ContainsAll(source, target)
+}
+
+// AppendSkipDuplicates appends source with elements with a condition
+// that those elemments must NOT already exist in the source
+func AppendSkipDuplicates(slice []string, elems ...string) (result []string) {
+ mark := make(map[string]bool, len(slice))
+ for _, val := range slice {
+ mark[val] = true
+ }
+ result = slice
+ for _, val := range elems {
+ if _, ok := mark[val]; !ok {
+ result = append(result, val)
+ }
+ }
+ return result
+}
diff --git a/common/slice_test.go b/common/slice_test.go
new file mode 100644
index 0000000000..88a63e7819
--- /dev/null
+++ b/common/slice_test.go
@@ -0,0 +1,93 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package common
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestContainsAll_whenTypical(t *testing.T) {
+ source := []string{"1", "2"}
+ target := []string{"1", "2"}
+
+ assert.True(t, ContainsAll(source, target))
+}
+
+func TestContainsAll_whenNot(t *testing.T) {
+ source := []string{"1", "2"}
+ target := []string{"3", "4"}
+
+ assert.False(t, ContainsAll(source, target))
+}
+
+func TestContainsAll_whenTargetIsSubset(t *testing.T) {
+ source := []string{"1", "2"}
+ target := []string{"1"}
+
+ assert.True(t, ContainsAll(source, target))
+}
+
+func TestContainsAll_whenTargetIsSuperSet(t *testing.T) {
+ source := []string{"2"}
+ target := []string{"1", "2"}
+
+ assert.False(t, ContainsAll(source, target))
+}
+
+func TestContainsAll_whenSourceIsEmpty(t *testing.T) {
+ var source []string
+ target := []string{"1", "2"}
+
+ assert.False(t, ContainsAll(source, target))
+}
+
+func TestContainsAll_whenSourceIsNil(t *testing.T) {
+ target := []string{"1", "2"}
+
+ assert.False(t, ContainsAll(nil, target))
+}
+
+func TestContainsAll_whenTargetIsEmpty(t *testing.T) {
+ source := []string{"1", "2"}
+
+ assert.True(t, ContainsAll(source, []string{}))
+}
+
+func TestContainsAll_whenTargetIsNil(t *testing.T) {
+ source := []string{"1", "2"}
+
+ assert.True(t, ContainsAll(source, nil))
+}
+
+func TestAppendSkipDuplicates_whenTypical(t *testing.T) {
+ source := []string{"1", "2"}
+ additional := []string{"1", "3"}
+
+ assert.Equal(t, []string{"1", "2", "3"}, AppendSkipDuplicates(source, additional...))
+}
+
+func TestAppendSkipDuplicates_whenSourceIsNil(t *testing.T) {
+ additional := []string{"1", "3"}
+
+ assert.Equal(t, []string{"1", "3"}, AppendSkipDuplicates(nil, additional...))
+}
+
+func TestAppendSkipDuplicates_whenElementIsNil(t *testing.T) {
+ assert.Equal(t, []string{"1", "3"}, AppendSkipDuplicates([]string{"1", "3"}, nil...))
+}
diff --git a/common/types.go b/common/types.go
index 6d9ec725fe..5ccf04534c 100644
--- a/common/types.go
+++ b/common/types.go
@@ -21,6 +21,7 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
"math/big"
"math/rand"
@@ -42,6 +43,9 @@ const (
)
var (
+ ErrNotPrivateContract = errors.New("the provided address is not a private contract")
+ ErrNoAccountExtraData = errors.New("no account extra data found")
+
hashT = reflect.TypeOf(Hash{})
addressT = reflect.TypeOf(Address{})
)
diff --git a/core/blockchain.go b/core/blockchain.go
index 6306347d46..0d72ee3fa1 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -18,6 +18,7 @@
package core
import (
+ "context"
"errors"
"fmt"
"io"
@@ -27,6 +28,8 @@ import (
"sync/atomic"
"time"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
+
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/common/mclock"
@@ -180,6 +183,7 @@ type BlockChain struct {
setPrivateState func([]*types.Log, *state.StateDB) // Function to check extension and set private state
privateStateCache state.Database // Private state database to reuse between imports (contains state cache)
+ isMultitenant bool // if this blockchain supports multitenancy
}
// function pointer for updating private state
@@ -314,6 +318,15 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
return bc, nil
}
+func NewMultitenantBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config, shouldPreserve func(block *types.Block) bool) (*BlockChain, error) {
+ bc, err := NewBlockChain(db, cacheConfig, chainConfig, engine, vmConfig, shouldPreserve)
+ if err != nil {
+ return nil, err
+ }
+ bc.isMultitenant = true
+ return bc, err
+}
+
func (bc *BlockChain) getProcInterrupt() bool {
return atomic.LoadInt32(&bc.procInterrupt) == 1
}
@@ -2366,3 +2379,7 @@ func (bc *BlockChain) SubscribeLogsEvent(ch chan<- []*types.Log) event.Subscript
func (bc *BlockChain) SubscribeBlockProcessingEvent(ch chan<- bool) event.Subscription {
return bc.scope.Track(bc.blockProcFeed.Subscribe(ch))
}
+
+func (bc *BlockChain) SupportsMultitenancy(context.Context) (*proto.PreAuthenticatedAuthenticationToken, bool) {
+ return nil, bc.isMultitenant
+}
diff --git a/core/evm.go b/core/evm.go
index b654bbd479..ded12c447b 100644
--- a/core/evm.go
+++ b/core/evm.go
@@ -17,17 +17,21 @@
package core
import (
+ "context"
"math/big"
+ "reflect"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
+ "github.com/ethereum/go-ethereum/multitenancy"
)
// ChainContext supports retrieving headers and consensus parameters from the
// current blockchain to be used during transaction processing.
type ChainContext interface {
+ multitenancy.ContextAware
// Engine retrieves the chain's consensus engine.
Engine() consensus.Engine
@@ -44,6 +48,12 @@ func NewEVMContext(msg Message, header *types.Header, chain ChainContext, author
} else {
beneficiary = *author
}
+ supportsMultitenancy := false
+ // mainly to overcome lost of test cases which pass ChainContext as nil value
+ // nil interface requires this check to make sure we don't get nil pointer reference error
+ if chain != nil && !reflect.ValueOf(chain).IsNil() {
+ _, supportsMultitenancy = chain.SupportsMultitenancy(nil)
+ }
return vm.Context{
CanTransfer: CanTransfer,
Transfer: Transfer,
@@ -55,7 +65,23 @@ func NewEVMContext(msg Message, header *types.Header, chain ChainContext, author
Difficulty: new(big.Int).Set(header.Difficulty),
GasLimit: header.GasLimit,
GasPrice: new(big.Int).Set(msg.GasPrice()),
+
+ SupportsMultitenancy: supportsMultitenancy,
+ }
+}
+
+// Quorum
+//
+// This EVM context is meant for simulation when doing multitenancy check.
+// It enriches the given EVM context with multitenancy-specific references
+func NewMultitenancyAwareEVMContext(ctx context.Context, evmCtx vm.Context) vm.Context {
+ if f, ok := ctx.Value(multitenancy.CtxKeyAuthorizeCreateFunc).(multitenancy.AuthorizeCreateFunc); ok {
+ evmCtx.AuthorizeCreateFunc = f
+ }
+ if f, ok := ctx.Value(multitenancy.CtxKeyAuthorizeMessageCallFunc).(multitenancy.AuthorizeMessageCallFunc); ok {
+ evmCtx.AuthorizeMessageCallFunc = f
}
+ return evmCtx
}
// GetHashFn returns a GetHashFunc which retrieves header hashes by number
diff --git a/core/rawdb/database_quorum.go b/core/rawdb/database_quorum.go
index bf51489bfe..4d5f169839 100644
--- a/core/rawdb/database_quorum.go
+++ b/core/rawdb/database_quorum.go
@@ -27,10 +27,13 @@ import (
)
var (
- privateRootPrefix = []byte("P")
- privateBloomPrefix = []byte("Pb")
- quorumEIP155ActivatedPrefix = []byte("quorum155active")
- privateRootToPrivacyMetadataRootPrefix = []byte("PSR2PMDR")
+ privateRootPrefix = []byte("P")
+ privateBloomPrefix = []byte("Pb")
+ quorumEIP155ActivatedPrefix = []byte("quorum155active")
+ // Quorum
+ // we introduce a generic approach to store extra data for an account. PrivacyMetadata is wrapped.
+ // However, this value is kept as-is to support backward compatibility
+ stateRootToExtraDataRootPrefix = []byte("PSR2PMDR")
// emptyRoot is the known root hash of an empty trie. Duplicate from `trie/trie.go#emptyRoot`
emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
)
@@ -52,8 +55,8 @@ func GetPrivateStateRoot(db ethdb.Database, blockRoot common.Hash) common.Hash {
return common.BytesToHash(root)
}
-func GetPrivacyMetadataStateRootForPrivateStateRoot(db ethdb.KeyValueReader, privateStateRoot common.Hash) common.Hash {
- root, _ := db.Get(append(privateRootToPrivacyMetadataRootPrefix, privateStateRoot[:]...))
+func GetAccountExtraDataRoot(db ethdb.KeyValueReader, stateRoot common.Hash) common.Hash {
+ root, _ := db.Get(append(stateRootToExtraDataRootPrefix, stateRoot[:]...))
return common.BytesToHash(root)
}
@@ -61,8 +64,10 @@ func WritePrivateStateRoot(db ethdb.Database, blockRoot, root common.Hash) error
return db.Put(append(privateRootPrefix, blockRoot[:]...), root[:])
}
-func WritePrivacyMetadataStateRootForPrivateStateRoot(db ethdb.KeyValueWriter, privateStateRoot, privacyMetadataRoot common.Hash) error {
- return db.Put(append(privateRootToPrivacyMetadataRootPrefix, privateStateRoot[:]...), privacyMetadataRoot[:])
+// WriteRootHashMapping stores the mapping between root hash of state trie and
+// root hash of state.AccountExtraData trie to persistent storage
+func WriteRootHashMapping(db ethdb.KeyValueWriter, stateRoot, extraDataRoot common.Hash) error {
+ return db.Put(append(stateRootToExtraDataRootPrefix, stateRoot[:]...), extraDataRoot[:])
}
// WritePrivateBlockBloom creates a bloom filter for the given receipts and saves it to the database
@@ -81,28 +86,39 @@ func GetPrivateBlockBloom(db ethdb.Database, number uint64) (bloom types.Bloom)
return bloom
}
-type PrivacyMetadataLinker interface {
- PrivacyMetadataRootForPrivateStateRoot(privateStateRoot common.Hash) common.Hash
- LinkPrivacyMetadataRootToPrivateStateRoot(privateStateRoot, privacyMetadataRoot common.Hash) error
+// AccountExtraDataLinker maintains mapping between root hash of the state trie
+// and root hash of state.AccountExtraData trie
+type AccountExtraDataLinker interface {
+ // GetAccountExtraDataRoot returns the root hash of the state.AccountExtraData trie from
+ // the given root hash of the state trie.
+ //
+ // It returns an empty hash if not found.
+ GetAccountExtraDataRoot(stateRoot common.Hash) common.Hash
+ // Link saves the mapping between root hash of the state trie and
+ // root hash of state.AccountExtraData trie to the persistent storage.
+ // Don't write the mapping if extraDataRoot is an emptyRoot
+ Link(stateRoot, extraDataRoot common.Hash) error
}
-type ethDBPrivacyMetadataLinker struct {
+// ethdbAccountExtraDataLinker implements AccountExtraDataLinker using ethdb.Database
+// as the persistence storage
+type ethdbAccountExtraDataLinker struct {
db ethdb.Database
}
-func NewPrivacyMetadataLinker(db ethdb.Database) PrivacyMetadataLinker {
- return ðDBPrivacyMetadataLinker{
+func NewAccountExtraDataLinker(db ethdb.Database) AccountExtraDataLinker {
+ return ðdbAccountExtraDataLinker{
db: db,
}
}
-func (pml *ethDBPrivacyMetadataLinker) PrivacyMetadataRootForPrivateStateRoot(privateStateRoot common.Hash) common.Hash {
- return GetPrivacyMetadataStateRootForPrivateStateRoot(pml.db, privateStateRoot)
+func (pml *ethdbAccountExtraDataLinker) GetAccountExtraDataRoot(stateRoot common.Hash) common.Hash {
+ return GetAccountExtraDataRoot(pml.db, stateRoot)
}
-func (pml *ethDBPrivacyMetadataLinker) LinkPrivacyMetadataRootToPrivateStateRoot(privateStateRoot, privacyMetadataRoot common.Hash) error {
- if privacyMetadataRoot != emptyRoot {
- return WritePrivacyMetadataStateRootForPrivateStateRoot(pml.db, privateStateRoot, privacyMetadataRoot)
+func (pml *ethdbAccountExtraDataLinker) Link(stateRoot, extraDataRoot common.Hash) error {
+ if extraDataRoot != emptyRoot {
+ return WriteRootHashMapping(pml.db, stateRoot, extraDataRoot)
}
return nil
}
diff --git a/core/rawdb/database_quorum_test.go b/core/rawdb/database_quorum_test.go
index bc6598f1fe..b44c7f56f8 100644
--- a/core/rawdb/database_quorum_test.go
+++ b/core/rawdb/database_quorum_test.go
@@ -46,39 +46,39 @@ func TestIsQuorumEIP155Active(t *testing.T) {
}
}
-func TestPrivacyMedatadaLinkEmptyRoot(t *testing.T) {
+func TestAccountExtraDataLinker_whenLinkingEmptyRoot(t *testing.T) {
db := NewMemoryDatabase()
psr := common.Hash{1}
- pml := NewPrivacyMetadataLinker(db)
+ linker := NewAccountExtraDataLinker(db)
- err := pml.LinkPrivacyMetadataRootToPrivateStateRoot(psr, emptyRoot)
+ err := linker.Link(psr, emptyRoot)
if err != nil {
t.Fatal("unable to store the link")
}
- value, _ := db.Get(append(privateRootToPrivacyMetadataRootPrefix, psr[:]...))
+ value, _ := db.Get(append(stateRootToExtraDataRootPrefix, psr[:]...))
if value != nil {
t.Fatal("the mapping should not have been stored")
}
}
-func TestPrivacyMedatadaLinkRoot(t *testing.T) {
+func TestAccountExtraDataLinker_whenLinkingRoots(t *testing.T) {
db := NewMemoryDatabase()
psr := common.Hash{1}
pmr := common.Hash{2}
- pml := NewPrivacyMetadataLinker(db)
+ linker := NewAccountExtraDataLinker(db)
- err := pml.LinkPrivacyMetadataRootToPrivateStateRoot(psr, pmr)
+ err := linker.Link(psr, pmr)
if err != nil {
t.Fatal("unable to store the link")
}
- value, _ := db.Get(append(privateRootToPrivacyMetadataRootPrefix, psr[:]...))
+ value, _ := db.Get(append(stateRootToExtraDataRootPrefix, psr[:]...))
if value == nil {
t.Fatal("the mapping should have been stored")
@@ -101,14 +101,14 @@ func (t *ReadOnlyDB) Put(key []byte, value []byte) error {
return errReadOnly
}
-func TestPrivacyMedatadaLinkRootErrorWrapping(t *testing.T) {
+func TestAccountExtraDataLinker_whenError(t *testing.T) {
db := NewDatabase(&ReadOnlyDB{})
psr := common.Hash{1}
pmr := common.Hash{2}
- pml := NewPrivacyMetadataLinker(db)
+ linker := NewAccountExtraDataLinker(db)
- err := pml.LinkPrivacyMetadataRootToPrivateStateRoot(psr, pmr)
+ err := linker.Link(psr, pmr)
if err == nil {
t.Fatal("expecting a read only error to be returned")
@@ -119,35 +119,35 @@ func TestPrivacyMedatadaLinkRootErrorWrapping(t *testing.T) {
}
}
-func TestPrivacyMedatadaRetrievePrivacyMetadataRoot(t *testing.T) {
+func TestAccountExtraDataLinker_whenFinding(t *testing.T) {
db := NewMemoryDatabase()
psr := common.Hash{1}
pmr := common.Hash{2}
- err := db.Put(append(privateRootToPrivacyMetadataRootPrefix, psr[:]...), pmr[:])
+ err := db.Put(append(stateRootToExtraDataRootPrefix, psr[:]...), pmr[:])
if err != nil {
t.Fatal("unable to write to db")
}
- pml := NewPrivacyMetadataLinker(db)
+ pml := NewAccountExtraDataLinker(db)
- pmrRetrieved := pml.PrivacyMetadataRootForPrivateStateRoot(psr)
+ pmrRetrieved := pml.GetAccountExtraDataRoot(psr)
if pmrRetrieved != pmr {
t.Fatal("the mapping should have been retrieved")
}
}
-func TestPrivacyMedatadaRetrieveEmptyPrivacyMetadataRoot(t *testing.T) {
+func TestAccountExtraDataLinker_whenNotFound(t *testing.T) {
db := NewMemoryDatabase()
psr := common.Hash{1}
- pml := NewPrivacyMetadataLinker(db)
+ pml := NewAccountExtraDataLinker(db)
- pmrRetrieved := pml.PrivacyMetadataRootForPrivateStateRoot(psr)
+ pmrRetrieved := pml.GetAccountExtraDataRoot(psr)
if !common.EmptyHash(pmrRetrieved) {
- t.Fatal("the retrieved privacy metadata root should be thg empty hash")
+ t.Fatal("the retrieved privacy metadata root should be the empty hash")
}
}
diff --git a/core/state/account_extra_data.go b/core/state/account_extra_data.go
new file mode 100644
index 0000000000..8e05e8c974
--- /dev/null
+++ b/core/state/account_extra_data.go
@@ -0,0 +1,160 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package state
+
+import (
+ "fmt"
+ "io"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/private/engine"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// Quorum
+// AccountExtraData is to contain extra data that supplements existing Account data.
+// It is also maintained in a trie to support rollback.
+// Note:
+// - update copy() method
+// - update DecodeRLP and EncodeRLP when adding new field
+type AccountExtraData struct {
+ // for privacy enhancements
+ PrivacyMetadata *PrivacyMetadata
+ // list of public keys managed by the corresponding Tessera.
+ // This is for multitenancy
+ ManagedParties []string
+}
+
+func (qmd *AccountExtraData) DecodeRLP(stream *rlp.Stream) error {
+ var dataRLP struct {
+ // from state.PrivacyMetadata, this is required to support
+ // backward compatibility with RLP-encoded state.PrivacyMetadata.
+ // Refer to rlp/doc.go for decoding rules.
+ CreationTxHash *common.EncryptedPayloadHash `rlp:"nil"`
+ // from state.PrivacyMetadata, this is required to support
+ // backward compatibility with RLP-encoded state.PrivacyMetadata.
+ // Refer to rlp/doc.go for decoding rules.
+ PrivacyFlag *engine.PrivacyFlagType `rlp:"nil"`
+
+ Rest []rlp.RawValue `rlp:"tail"` // to maintain forward compatibility
+ }
+ if err := stream.Decode(&dataRLP); err != nil {
+ return err
+ }
+ if dataRLP.CreationTxHash != nil && dataRLP.PrivacyFlag != nil {
+ qmd.PrivacyMetadata = &PrivacyMetadata{
+ CreationTxHash: *dataRLP.CreationTxHash,
+ PrivacyFlag: *dataRLP.PrivacyFlag,
+ }
+ }
+ if len(dataRLP.Rest) > 0 {
+ var managedParties []string
+ if err := rlp.DecodeBytes(dataRLP.Rest[0], &managedParties); err != nil {
+ return fmt.Errorf("fail to decode managedParties with error %v", err)
+ }
+ // As RLP encodes empty slice or nil slice as an empty string (192)
+ // we won't be able to determine when decoding. So we use pragmatic approach
+ // to default to nil value. Downstream usage would deal with it easier.
+ if len(managedParties) == 0 {
+ qmd.ManagedParties = nil
+ } else {
+ qmd.ManagedParties = managedParties
+ }
+ }
+ return nil
+}
+
+func (qmd *AccountExtraData) EncodeRLP(writer io.Writer) error {
+ var (
+ hash *common.EncryptedPayloadHash
+ flag *engine.PrivacyFlagType
+ )
+ if qmd.PrivacyMetadata != nil {
+ hash = &qmd.PrivacyMetadata.CreationTxHash
+ flag = &qmd.PrivacyMetadata.PrivacyFlag
+ }
+ return rlp.Encode(writer, struct {
+ CreationTxHash *common.EncryptedPayloadHash `rlp:"nil"`
+ PrivacyFlag *engine.PrivacyFlagType `rlp:"nil"`
+ ManagedParties []string
+ }{
+ CreationTxHash: hash,
+ PrivacyFlag: flag,
+ ManagedParties: qmd.ManagedParties,
+ })
+}
+
+func (qmd *AccountExtraData) copy() *AccountExtraData {
+ if qmd == nil {
+ return nil
+ }
+ var copyPM *PrivacyMetadata
+ if qmd.PrivacyMetadata != nil {
+ copyPM = &PrivacyMetadata{
+ CreationTxHash: qmd.PrivacyMetadata.CreationTxHash,
+ PrivacyFlag: qmd.PrivacyMetadata.PrivacyFlag,
+ }
+ }
+ copyManagedParties := make([]string, len(qmd.ManagedParties))
+ copy(copyManagedParties, qmd.ManagedParties)
+ return &AccountExtraData{
+ PrivacyMetadata: copyPM,
+ ManagedParties: copyManagedParties,
+ }
+}
+
+// attached to every private contract account
+type PrivacyMetadata struct {
+ CreationTxHash common.EncryptedPayloadHash `json:"creationTxHash"`
+ PrivacyFlag engine.PrivacyFlagType `json:"privacyFlag"`
+}
+
+// Quorum
+// privacyMetadataRLP struct is to make sure
+// field order is preserved regardless changes in the PrivacyMetadata and its internal
+//
+// Edit this struct with care to make sure forward and backward compatibility
+type privacyMetadataRLP struct {
+ CreationTxHash common.EncryptedPayloadHash
+ PrivacyFlag engine.PrivacyFlagType
+
+ Rest []rlp.RawValue `rlp:"tail"` // to maintain forward compatibility
+}
+
+func (p *PrivacyMetadata) DecodeRLP(stream *rlp.Stream) error {
+ var dataRLP privacyMetadataRLP
+ if err := stream.Decode(&dataRLP); err != nil {
+ return err
+ }
+ p.CreationTxHash = dataRLP.CreationTxHash
+ p.PrivacyFlag = dataRLP.PrivacyFlag
+ return nil
+}
+
+func (p *PrivacyMetadata) EncodeRLP(writer io.Writer) error {
+ return rlp.Encode(writer, privacyMetadataRLP{
+ CreationTxHash: p.CreationTxHash,
+ PrivacyFlag: p.PrivacyFlag,
+ })
+}
+
+func NewStatePrivacyMetadata(creationTxHash common.EncryptedPayloadHash, privacyFlag engine.PrivacyFlagType) *PrivacyMetadata {
+ return &PrivacyMetadata{
+ CreationTxHash: creationTxHash,
+ PrivacyFlag: privacyFlag,
+ }
+}
diff --git a/core/state/account_extra_data_test.go b/core/state/account_extra_data_test.go
new file mode 100644
index 0000000000..f158a1c805
--- /dev/null
+++ b/core/state/account_extra_data_test.go
@@ -0,0 +1,193 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package state
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/private/engine"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/stretchr/testify/assert"
+)
+
+type privacyMetadataOld struct {
+ CreationTxHash common.EncryptedPayloadHash
+ PrivacyFlag engine.PrivacyFlagType
+}
+
+// privacyMetadataToBytes is the utility function under test from previous implementation
+func privacyMetadataToBytes(pm *privacyMetadataOld) ([]byte, error) {
+ return rlp.EncodeToBytes(pm)
+}
+
+// bytesToPrivacyMetadata is the utility function under test from previous implementation
+func bytesToPrivacyMetadata(b []byte) (*privacyMetadataOld, error) {
+ var data *privacyMetadataOld
+ if err := rlp.DecodeBytes(b, &data); err != nil {
+ return nil, fmt.Errorf("unable to decode privacy metadata. Cause: %v", err)
+ }
+ return data, nil
+}
+
+func TestRLP_PrivacyMetadata_DecodeBackwardCompatibility(t *testing.T) {
+ existingPM := &privacyMetadataOld{
+ CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("arbitrary-hash")),
+ PrivacyFlag: engine.PrivacyFlagStateValidation,
+ }
+ existing, err := privacyMetadataToBytes(existingPM)
+ assert.NoError(t, err)
+
+ var actual PrivacyMetadata
+ err = rlp.DecodeBytes(existing, &actual)
+
+ assert.NoError(t, err, "Must decode PrivacyMetadata successfully")
+ assert.Equal(t, existingPM.CreationTxHash, actual.CreationTxHash)
+ assert.Equal(t, existingPM.PrivacyFlag, actual.PrivacyFlag)
+}
+
+func TestRLP_PrivacyMetadata_DecodeForwardCompatibility(t *testing.T) {
+ pm := &PrivacyMetadata{
+ CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("arbitrary-hash")),
+ PrivacyFlag: engine.PrivacyFlagStateValidation,
+ }
+ existing, err := rlp.EncodeToBytes(pm)
+ assert.NoError(t, err)
+
+ var actual *privacyMetadataOld
+ actual, err = bytesToPrivacyMetadata(existing)
+
+ assert.NoError(t, err, "Must encode PrivacyMetadata successfully")
+ assert.Equal(t, pm.CreationTxHash, actual.CreationTxHash)
+ assert.Equal(t, pm.PrivacyFlag, actual.PrivacyFlag)
+}
+
+// From initial privacy enhancements, the privacy metadata is RLP encoded
+// we now wrap PrivacyMetadata in a more generic struct. This test is to make sure
+// we support backward compatibility.
+func TestRLP_AccountExtraData_BackwardCompatibility(t *testing.T) {
+ // prepare existing RLP bytes
+ arbitraryExistingMetadata := &PrivacyMetadata{
+ CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("arbitrary-existing-privacy-metadata-creation-hash")),
+ PrivacyFlag: engine.PrivacyFlagPartyProtection,
+ }
+ existing, err := rlp.EncodeToBytes(arbitraryExistingMetadata)
+ assert.NoError(t, err)
+
+ // now try to decode with the new struct
+ var actual AccountExtraData
+ err = rlp.DecodeBytes(existing, &actual)
+
+ assert.NoError(t, err, "Must decode successfully")
+ assert.Equal(t, arbitraryExistingMetadata.CreationTxHash, actual.PrivacyMetadata.CreationTxHash)
+ assert.Equal(t, arbitraryExistingMetadata.PrivacyFlag, actual.PrivacyMetadata.PrivacyFlag)
+}
+
+func TestRLP_AccountExtraData_withField_ManagedParties(t *testing.T) {
+ // prepare existing RLP bytes
+ arbitraryExtraData := &AccountExtraData{
+ PrivacyMetadata: &PrivacyMetadata{
+ CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("arbitrary-existing-privacy-metadata-creation-hash")),
+ PrivacyFlag: engine.PrivacyFlagPartyProtection,
+ },
+ ManagedParties: []string{"Arbitrary PK1", "Arbitrary PK2"},
+ }
+ existing, err := rlp.EncodeToBytes(arbitraryExtraData)
+ assert.NoError(t, err)
+
+ // now try to decode with the new struct
+ var actual AccountExtraData
+ err = rlp.DecodeBytes(existing, &actual)
+
+ assert.NoError(t, err, "Must decode successfully")
+ assert.Equal(t, arbitraryExtraData.PrivacyMetadata.CreationTxHash, actual.PrivacyMetadata.CreationTxHash)
+ assert.Equal(t, arbitraryExtraData.PrivacyMetadata.PrivacyFlag, actual.PrivacyMetadata.PrivacyFlag)
+ assert.Equal(t, arbitraryExtraData.ManagedParties, actual.ManagedParties)
+}
+
+func TestRLP_AccountExtraData_whenTypical(t *testing.T) {
+ expected := AccountExtraData{
+ PrivacyMetadata: &PrivacyMetadata{
+ CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("arbitrary-payload-hash")),
+ PrivacyFlag: engine.PrivacyFlagPartyProtection,
+ },
+ ManagedParties: []string{"XYZ"},
+ }
+
+ data, err := rlp.EncodeToBytes(&expected)
+ assert.NoError(t, err)
+
+ var actual AccountExtraData
+ assert.NoError(t, rlp.DecodeBytes(data, &actual))
+ assert.Equal(t, expected.PrivacyMetadata.CreationTxHash, actual.PrivacyMetadata.CreationTxHash)
+ assert.Equal(t, expected.PrivacyMetadata.PrivacyFlag, actual.PrivacyMetadata.PrivacyFlag)
+ assert.Equal(t, expected.ManagedParties, actual.ManagedParties)
+}
+
+func TestRLP_AccountExtraData_whenHavingPrivacyMetadataOnly(t *testing.T) {
+ expected := AccountExtraData{
+ PrivacyMetadata: &PrivacyMetadata{
+ CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("arbitrary-payload-hash")),
+ PrivacyFlag: engine.PrivacyFlagPartyProtection,
+ },
+ }
+
+ data, err := rlp.EncodeToBytes(&expected)
+ assert.NoError(t, err)
+
+ var actual AccountExtraData
+ assert.NoError(t, rlp.DecodeBytes(data, &actual))
+ assert.Equal(t, expected.PrivacyMetadata.CreationTxHash, actual.PrivacyMetadata.CreationTxHash)
+ assert.Equal(t, expected.PrivacyMetadata.PrivacyFlag, actual.PrivacyMetadata.PrivacyFlag)
+}
+
+func TestRLP_AccountExtraData_whenHavingNilManagedParties(t *testing.T) {
+ expected := AccountExtraData{
+ PrivacyMetadata: nil,
+ ManagedParties: nil,
+ }
+
+ data, err := rlp.EncodeToBytes(&expected)
+ assert.NoError(t, err)
+
+ var actual AccountExtraData
+ assert.NoError(t, rlp.DecodeBytes(data, &actual))
+ assert.Nil(t, actual.ManagedParties)
+ assert.Nil(t, actual.PrivacyMetadata)
+}
+
+func TestRLP_AccountExtraData_whenHavingEmptyManagedParties(t *testing.T) {
+ expected := AccountExtraData{
+ PrivacyMetadata: nil,
+ ManagedParties: []string{},
+ }
+
+ data, err := rlp.EncodeToBytes(&expected)
+ assert.NoError(t, err)
+
+ var actual AccountExtraData
+ assert.NoError(t, rlp.DecodeBytes(data, &actual))
+ assert.Nil(t, actual.ManagedParties)
+ assert.Nil(t, actual.PrivacyMetadata)
+}
+
+func TestCopy_whenNil(t *testing.T) {
+ var testObj *AccountExtraData = nil
+
+ assert.Nil(t, testObj.copy())
+}
diff --git a/core/state/database.go b/core/state/database.go
index fafe90451a..32302a47e2 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -51,8 +51,11 @@ type Database interface {
// TrieDB retrieves the low level trie database used for data storage.
TrieDB() *trie.Database
- // Privacy metadata linker
- PrivacyMetadataLinker() rawdb.PrivacyMetadataLinker
+ // Quorum
+ //
+ // accountExtraDataLinker maintains mapping between root hash of the state trie
+ // and root hash of state.AccountExtraData trie.
+ AccountExtraDataLinker() rawdb.AccountExtraDataLinker
}
// Trie is a Ethereum Merkle Patricia trie.
@@ -113,23 +116,24 @@ func NewDatabase(db ethdb.Database) Database {
func NewDatabaseWithCache(db ethdb.Database, cache int) Database {
csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{
- db: trie.NewDatabaseWithCache(db, cache),
- // Quorum - Privacy Enhancements
- privacyMetadataLinker: rawdb.NewPrivacyMetadataLinker(db),
- codeSizeCache: csc,
+ db: trie.NewDatabaseWithCache(db, cache),
+ accountExtraDataLinker: rawdb.NewAccountExtraDataLinker(db),
+ codeSizeCache: csc,
}
}
type cachingDB struct {
db *trie.Database
- // Quorum: Privacy enhacements introducing privacyMetadataLinker which maintains mapping between private state root and privacy metadata root. As this struct is the backing store for state, this gives the reference to the linker when needed.
- privacyMetadataLinker rawdb.PrivacyMetadataLinker
- codeSizeCache *lru.Cache
+ // Quorum
+ //
+ // accountExtraDataLinker maintains mapping between state root and state.AccountExtraData root.
+ // As this struct is the backing store for state, this gives the reference to the linker when needed.
+ accountExtraDataLinker rawdb.AccountExtraDataLinker
+ codeSizeCache *lru.Cache
}
-// Quorum - Privacy Enhancements
-func (db *cachingDB) PrivacyMetadataLinker() rawdb.PrivacyMetadataLinker {
- return db.privacyMetadataLinker
+func (db *cachingDB) AccountExtraDataLinker() rawdb.AccountExtraDataLinker {
+ return db.accountExtraDataLinker
}
// OpenTrie opens the main account trie at a specific root hash.
diff --git a/core/state/journal.go b/core/state/journal.go
index 116c741155..5831cbbc64 100644
--- a/core/state/journal.go
+++ b/core/state/journal.go
@@ -115,10 +115,10 @@ type (
account *common.Address
prevcode, prevhash []byte
}
- // Quorum - Privacy Enhancements - changes to privacy metadata
- privacyMetadataChange struct {
+ // Quorum - changes to AccountExtraData
+ accountExtraDataChange struct {
account *common.Address
- prev *PrivacyMetadata
+ prev *AccountExtraData
}
// Changes to other state values.
refundChange struct {
@@ -199,12 +199,12 @@ func (ch codeChange) dirtied() *common.Address {
return ch.account
}
-// Quorum - Privacy Enhancements
-func (ch privacyMetadataChange) revert(s *StateDB) {
- s.getStateObject(*ch.account).setStatePrivacyMetadata(ch.prev)
+// Quorum
+func (ch accountExtraDataChange) revert(s *StateDB) {
+ s.getStateObject(*ch.account).setAccountExtraData(ch.prev)
}
-func (ch privacyMetadataChange) dirtied() *common.Address {
+func (ch accountExtraDataChange) dirtied() *common.Address {
return ch.account
}
diff --git a/core/state/state_object.go b/core/state/state_object.go
index 4dc4907dd2..c7732d2ed1 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -18,15 +18,16 @@ package state
import (
"bytes"
+ "errors"
"fmt"
"io"
"math/big"
+ "sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics"
- "github.com/ethereum/go-ethereum/private/engine"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -80,8 +81,12 @@ type stateObject struct {
trie Trie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded
- // Quorum - Privacy Enhancements
- privacyMetadata *PrivacyMetadata
+ // Quorum
+ // contains extra data that is linked to the account
+ accountExtraData *AccountExtraData
+ // as there are many fields in accountExtraData which might be concurrently changed
+ // this is to make sure we can keep track of changes individually.
+ accountExtraDataMutex sync.Mutex
originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction
pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
@@ -95,8 +100,9 @@ type stateObject struct {
suicided bool
touched bool
deleted bool
- // Quroum - Privacy Enhancements
- dirtyPrivacyMetadata bool
+ // Quorum
+ // flag to track changes in AccountExtraData
+ dirtyAccountExtraData bool
}
// empty returns whether the account is considered empty.
@@ -113,12 +119,6 @@ type Account struct {
CodeHash []byte
}
-//attached to every private contract account
-type PrivacyMetadata struct {
- CreationTxHash common.EncryptedPayloadHash `json:"creationTxHash"`
- PrivacyFlag engine.PrivacyFlagType `json:"privacyFlag"`
-}
-
// newObject creates a state object.
func newObject(db *StateDB, address common.Address, data Account) *stateObject {
if data.Balance == nil {
@@ -141,13 +141,6 @@ func newObject(db *StateDB, address common.Address, data Account) *stateObject {
}
}
-func NewStatePrivacyMetadata(creationTxHash common.EncryptedPayloadHash, privacyFlag engine.PrivacyFlagType) *PrivacyMetadata {
- return &PrivacyMetadata{
- CreationTxHash: creationTxHash,
- PrivacyFlag: privacyFlag,
- }
-}
-
// EncodeRLP implements rlp.Encoder.
func (s *stateObject) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, s.data)
@@ -408,9 +401,9 @@ func (s *stateObject) deepCopy(db *StateDB) *stateObject {
stateObject.suicided = s.suicided
stateObject.dirtyCode = s.dirtyCode
stateObject.deleted = s.deleted
- // Quorum - Privacy Enhancements - copy privacy metadata fields
- stateObject.privacyMetadata = s.privacyMetadata
- stateObject.dirtyPrivacyMetadata = s.dirtyPrivacyMetadata
+ // Quorum - copy AccountExtraData
+ stateObject.accountExtraData = s.accountExtraData
+ stateObject.dirtyAccountExtraData = s.dirtyAccountExtraData
return stateObject
}
@@ -468,22 +461,59 @@ func (s *stateObject) setNonce(nonce uint64) {
s.data.Nonce = nonce
}
-// Quorum - Privacy Enhancements
-func (s *stateObject) SetStatePrivacyMetadata(metadata *PrivacyMetadata) {
- prevPM, _ := s.PrivacyMetadata()
- s.db.journal.append(privacyMetadataChange{
+// Quorum
+// SetAccountExtraData modifies the AccountExtraData reference and journals it
+func (s *stateObject) SetAccountExtraData(extraData *AccountExtraData) {
+ current, _ := s.AccountExtraData()
+ s.db.journal.append(accountExtraDataChange{
account: &s.address,
- prev: prevPM,
+ prev: current,
})
- s.setStatePrivacyMetadata(metadata)
+ s.setAccountExtraData(extraData)
}
-func (s *stateObject) setStatePrivacyMetadata(metadata *PrivacyMetadata) {
- s.privacyMetadata = metadata
- s.dirtyPrivacyMetadata = true
+// A new AccountExtraData will be created if not exists.
+// This must be called after successfully acquiring accountExtraDataMutex lock
+func (s *stateObject) journalAccountExtraData() *AccountExtraData {
+ current, _ := s.AccountExtraData()
+ s.db.journal.append(accountExtraDataChange{
+ account: &s.address,
+ prev: current.copy(),
+ })
+ if current == nil {
+ current = &AccountExtraData{}
+ }
+ return current
}
-// End Quorum - Privacy Enhancements
+// Quorum
+// SetStatePrivacyMetadata updates the PrivacyMetadata in AccountExtraData and journals it.
+func (s *stateObject) SetStatePrivacyMetadata(pm *PrivacyMetadata) {
+ s.accountExtraDataMutex.Lock()
+ defer s.accountExtraDataMutex.Unlock()
+
+ newExtraData := s.journalAccountExtraData()
+ newExtraData.PrivacyMetadata = pm
+ s.setAccountExtraData(newExtraData)
+}
+
+// Quorum
+// SetStatePrivacyMetadata updates the PrivacyMetadata in AccountExtraData and journals it.
+func (s *stateObject) SetManagedParties(managedParties []string) {
+ s.accountExtraDataMutex.Lock()
+ defer s.accountExtraDataMutex.Unlock()
+
+ newExtraData := s.journalAccountExtraData()
+ newExtraData.ManagedParties = managedParties
+ s.setAccountExtraData(newExtraData)
+}
+
+// Quorum
+// setAccountExtraData modifies the AccountExtraData reference in this state object
+func (s *stateObject) setAccountExtraData(extraData *AccountExtraData) {
+ s.accountExtraData = extraData
+ s.dirtyAccountExtraData = true
+}
func (s *stateObject) CodeHash() []byte {
return s.data.CodeHash
@@ -497,49 +527,86 @@ func (s *stateObject) Nonce() uint64 {
return s.data.Nonce
}
+// Quorum
+// AccountExtraData returns the extra data in this state object.
+// It will also update the reference by searching the accountExtraDataTrie.
+//
+// This method enforces on returning error and never returns (nil, nil).
+func (s *stateObject) AccountExtraData() (*AccountExtraData, error) {
+ if s.accountExtraData != nil {
+ return s.accountExtraData, nil
+ }
+ val, err := s.getCommittedAccountExtraData()
+ if err != nil {
+ return nil, err
+ }
+ s.accountExtraData = val
+ return val, nil
+}
+
+// Quorum
+// getCommittedAccountExtraData looks for an entry in accountExtraDataTrie.
+//
+// This method enforces on returning error and never returns (nil, nil).
+func (s *stateObject) getCommittedAccountExtraData() (*AccountExtraData, error) {
+ val, err := s.db.accountExtraDataTrie.TryGet(s.address.Bytes())
+ if err != nil {
+ return nil, fmt.Errorf("unable to retrieve data from the accountExtraDataTrie. Cause: %v", err)
+ }
+ if len(val) == 0 {
+ return nil, fmt.Errorf("%s: %w", s.address.Hex(), common.ErrNoAccountExtraData)
+ }
+ var extraData AccountExtraData
+ if err := rlp.DecodeBytes(val, &extraData); err != nil {
+ return nil, fmt.Errorf("unable to decode to AccountExtraData. Cause: %v", err)
+ }
+ return &extraData, nil
+}
+
// Quorum - Privacy Enhancements
+// PrivacyMetadata returns the reference to PrivacyMetadata.
+// It will returrn an error if no PrivacyMetadata is in the AccountExtraData.
func (s *stateObject) PrivacyMetadata() (*PrivacyMetadata, error) {
- if s.privacyMetadata != nil {
- return s.privacyMetadata, nil
+ extraData, err := s.AccountExtraData()
+ if err != nil {
+ return nil, err
}
- val, err := s.GetCommittedPrivacyMetadata()
- if val != nil {
- s.privacyMetadata = val
+ // extraData can't be nil. Refer to s.AccountExtraData()
+ if extraData.PrivacyMetadata == nil {
+ return nil, fmt.Errorf("no privacy metadata data for contract %s", s.address.Hex())
}
- return val, err
+ return extraData.PrivacyMetadata, nil
}
func (s *stateObject) GetCommittedPrivacyMetadata() (*PrivacyMetadata, error) {
- val, err := s.db.privacyMetaDataTrie.TryGet(s.address.Bytes())
+ extraData, err := s.getCommittedAccountExtraData()
if err != nil {
- return nil, fmt.Errorf("unable to retrieve metadata from the privacyMetadataTrie. Cause: %v", err)
+ return nil, err
}
- if len(val) == 0 {
+ if extraData == nil || extraData.PrivacyMetadata == nil {
return nil, fmt.Errorf("The provided contract does not have privacy metadata: %x", s.address)
}
- return bytesToPrivacyMetadata(val)
+ return extraData.PrivacyMetadata, nil
}
// End Quorum - Privacy Enhancements
+// ManagedParties will return empty if no account extra data found
+func (s *stateObject) ManagedParties() ([]string, error) {
+ extraData, err := s.AccountExtraData()
+ if errors.Is(err, common.ErrNoAccountExtraData) {
+ return []string{}, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ // extraData can't be nil. Refer to s.AccountExtraData()
+ return extraData.ManagedParties, nil
+}
+
// Never called, but must be present to allow stateObject to be used
// as a vm.Account interface that also satisfies the vm.ContractRef
// interface. Interfaces are awesome.
func (s *stateObject) Value() *big.Int {
panic("Value on stateObject should never be called")
}
-
-// Quorum - Privacy Enhancements
-func privacyMetadataToBytes(pm *PrivacyMetadata) ([]byte, error) {
- return rlp.EncodeToBytes(pm)
-}
-
-func bytesToPrivacyMetadata(b []byte) (*PrivacyMetadata, error) {
- var data *PrivacyMetadata
- if err := rlp.DecodeBytes(b, &data); err != nil {
- return nil, fmt.Errorf("unable to decode privacy metadata. Cause: %v", err)
- }
- return data, nil
-}
-
-// End Quorum - Privacy Enhancements
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 65f69dde81..b8fb985f2a 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -65,8 +65,8 @@ func (n *proofList) Delete(key []byte) error {
type StateDB struct {
db Database
trie Trie
- // Quorum - Privacy Enhancements - new trie to hold extra account information that cannot be stored in the accounts trie
- privacyMetaDataTrie Trie
+ // Quorum - a trie to hold extra account information that cannot be stored in the accounts trie
+ accountExtraDataTrie Trie
// This map holds 'live' objects, which will get modified while processing a state transition.
stateObjects map[common.Address]*stateObject
stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie
@@ -114,9 +114,9 @@ func New(root common.Hash, db Database) (*StateDB, error) {
}
// Quorum - Privacy Enhancements - retrieve the privacy metadata root corresponding to the account state root
- privacyMetadataRoot := db.PrivacyMetadataLinker().PrivacyMetadataRootForPrivateStateRoot(root)
- log.Debug("Privacy metadata root", "hash", privacyMetadataRoot)
- privacyMetaDataTrie, err := db.OpenTrie(privacyMetadataRoot)
+ extraDataRoot := db.AccountExtraDataLinker().GetAccountExtraDataRoot(root)
+ log.Debug("Account Extra Data root", "hash", extraDataRoot)
+ accountExtraDataTrie, err := db.OpenTrie(extraDataRoot)
if err != nil {
return nil, fmt.Errorf("Unable to open privacy metadata trie: %v", err)
}
@@ -126,13 +126,13 @@ func New(root common.Hash, db Database) (*StateDB, error) {
db: db,
trie: tr,
// Quorum - Privacy Enhancements
- privacyMetaDataTrie: privacyMetaDataTrie,
- stateObjects: make(map[common.Address]*stateObject),
- stateObjectsPending: make(map[common.Address]struct{}),
- stateObjectsDirty: make(map[common.Address]struct{}),
- logs: make(map[common.Hash][]*types.Log),
- preimages: make(map[common.Hash][]byte),
- journal: newJournal(),
+ accountExtraDataTrie: accountExtraDataTrie,
+ stateObjects: make(map[common.Address]*stateObject),
+ stateObjectsPending: make(map[common.Address]struct{}),
+ stateObjectsDirty: make(map[common.Address]struct{}),
+ logs: make(map[common.Hash][]*types.Log),
+ preimages: make(map[common.Hash][]byte),
+ journal: newJournal(),
}, nil
}
@@ -253,7 +253,7 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 {
return 0
}
-func (self *StateDB) GetStatePrivacyMetadata(addr common.Address) (*PrivacyMetadata, error) {
+func (self *StateDB) GetPrivacyMetadata(addr common.Address) (*PrivacyMetadata, error) {
stateObject := self.getStateObject(addr)
if stateObject != nil {
return stateObject.PrivacyMetadata()
@@ -269,6 +269,14 @@ func (self *StateDB) GetCommittedStatePrivacyMetadata(addr common.Address) (*Pri
return nil, nil
}
+func (self *StateDB) GetManagedParties(addr common.Address) ([]string, error) {
+ stateObject := self.getStateObject(addr)
+ if stateObject != nil {
+ return stateObject.ManagedParties()
+ }
+ return nil, nil
+}
+
func (self *StateDB) GetRLPEncodedStateObject(addr common.Address) ([]byte, error) {
stateObject := self.getStateObject(addr)
if stateObject == nil {
@@ -429,13 +437,20 @@ func (self *StateDB) SetNonce(addr common.Address, nonce uint64) {
}
}
-func (self *StateDB) SetStatePrivacyMetadata(addr common.Address, metadata *PrivacyMetadata) {
+func (self *StateDB) SetPrivacyMetadata(addr common.Address, metadata *PrivacyMetadata) {
stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetStatePrivacyMetadata(metadata)
}
}
+func (self *StateDB) SetManagedParties(addr common.Address, managedParties []string) {
+ stateObject := self.GetOrNewStateObject(addr)
+ if stateObject != nil && len(managedParties) > 0 {
+ stateObject.SetManagedParties(managedParties)
+ }
+}
+
func (self *StateDB) SetCode(addr common.Address, code []byte) {
stateObject := self.GetOrNewStateObject(addr)
if stateObject != nil {
@@ -485,6 +500,8 @@ func (self *StateDB) Suicide(addr common.Address) bool {
//
// updateStateObject writes the given object to the trie.
+// Quorum:
+// - update AccountExtraData trie
func (s *StateDB) updateStateObject(obj *stateObject) {
// Track the amount of time wasted on updating the account from the trie
if metrics.EnabledExpensive {
@@ -504,21 +521,22 @@ func (s *StateDB) updateStateObject(obj *stateObject) {
return
}
- if obj.dirtyPrivacyMetadata && obj.privacyMetadata != nil {
- privacyMetadataBytes, err := privacyMetadataToBytes(obj.privacyMetadata)
+ if obj.dirtyAccountExtraData && obj.accountExtraData != nil {
+ extraDataBytes, err := rlp.EncodeToBytes(obj.accountExtraData)
if err != nil {
panic(fmt.Errorf("can't encode privacy metadata at %x: %v", addr[:], err))
}
- err = s.privacyMetaDataTrie.TryUpdate(addr[:], privacyMetadataBytes)
+ err = s.accountExtraDataTrie.TryUpdate(addr[:], extraDataBytes)
if err != nil {
s.setError(err)
return
}
}
- // End Quorum - Privacy Enhancements
}
// deleteStateObject removes the given object from the state trie.
+// Quorum:
+// - delete the data from the extra data trie corresponding to the account address
func (s *StateDB) deleteStateObject(obj *stateObject) {
// Track the amount of time wasted on deleting the account from the trie
if metrics.EnabledExpensive {
@@ -527,13 +545,11 @@ func (s *StateDB) deleteStateObject(obj *stateObject) {
// Delete the account from the trie
addr := obj.Address()
err := s.trie.TryDelete(addr[:])
- // Quorum - Privacy Enhancements - delete the data from the privacy metadata trie corresponding to the account address
if err != nil {
s.setError(err)
return
}
- s.setError(s.privacyMetaDataTrie.TryDelete(addr[:]))
- // End Quorum - Privacy Enhancements
+ s.setError(s.accountExtraDataTrie.TryDelete(addr[:]))
}
// getStateObject retrieves a state object given by the address, returning nil if
@@ -659,15 +675,15 @@ func (self *StateDB) Copy() *StateDB {
db: self.db,
trie: self.db.CopyTrie(self.trie),
// Quorum - Privacy Enhancements
- privacyMetaDataTrie: self.db.CopyTrie(self.privacyMetaDataTrie),
- stateObjects: make(map[common.Address]*stateObject, len(self.journal.dirties)),
- stateObjectsPending: make(map[common.Address]struct{}, len(self.stateObjectsPending)),
- stateObjectsDirty: make(map[common.Address]struct{}, len(self.journal.dirties)),
- refund: self.refund,
- logs: make(map[common.Hash][]*types.Log, len(self.logs)),
- logSize: self.logSize,
- preimages: make(map[common.Hash][]byte, len(self.preimages)),
- journal: newJournal(),
+ accountExtraDataTrie: self.db.CopyTrie(self.accountExtraDataTrie),
+ stateObjects: make(map[common.Address]*stateObject, len(self.journal.dirties)),
+ stateObjectsPending: make(map[common.Address]struct{}, len(self.stateObjectsPending)),
+ stateObjectsDirty: make(map[common.Address]struct{}, len(self.journal.dirties)),
+ refund: self.refund,
+ logs: make(map[common.Hash][]*types.Log, len(self.logs)),
+ logSize: self.logSize,
+ preimages: make(map[common.Hash][]byte, len(self.preimages)),
+ journal: newJournal(),
}
// Copy the dirty states, logs, and preimages
for addr := range self.journal.dirties {
@@ -813,6 +829,8 @@ func (s *StateDB) clearJournalAndRefund() {
}
// Commit writes the state to the underlying in-memory trie database.
+// Quorum:
+// - linking state root and the AccountExtraData root
func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
// Finalize any pending changes and merge everything into the tries
s.IntermediateRoot(deleteEmptyObjects)
@@ -854,23 +872,23 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) {
return nil
})
- // Quorum - Privacy Enhancements
+ // Quorum
+ // linking the state root and the AccountExtraData root
if err == nil {
- // commit the privacy metadata trie
- privacyMetadataTrieRoot, err := s.privacyMetaDataTrie.Commit(nil)
+ // commit the AccountExtraData trie
+ extraDataRoot, err := s.accountExtraDataTrie.Commit(nil)
if err != nil {
- return common.Hash{}, fmt.Errorf("Unable to commit the privacy metadata trie: %v", err)
+ return common.Hash{}, fmt.Errorf("unable to commit the AccountExtraData trie: %v", err)
}
- log.Debug("Privacy metadata root after metadata trie commit", "root", privacyMetadataTrieRoot)
- // link the new state root to the privacy metadata root
- err = s.db.PrivacyMetadataLinker().LinkPrivacyMetadataRootToPrivateStateRoot(root, privacyMetadataTrieRoot)
+ log.Debug("AccountExtraData root after trie commit", "root", extraDataRoot)
+ // link the new state root to the AccountExtraData root
+ err = s.db.AccountExtraDataLinker().Link(root, extraDataRoot)
if err != nil {
return common.Hash{}, fmt.Errorf("Unable to link the state root to the privacy metadata root: %v", err)
}
- // add a reference from the privacy metadata root to the state root so that when the state root is written
- // to the DB the the privacy metadata root is also written
- s.db.TrieDB().Reference(privacyMetadataTrieRoot, root)
+ // add a reference from the AccountExtraData root to the state root so that when the state root is written
+ // to the DB the the AccountExtraData root is also written
+ s.db.TrieDB().Reference(extraDataRoot, root)
}
- // End Quorum - Privacy Enhancements
return root, err
}
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 7f968e3b9e..cac1d90610 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -326,7 +326,7 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
args: make([]int64, 2),
},
{
- name: "SetStatePrivacyMetadata",
+ name: "SetPrivacyMetadata",
fn: func(a testAction, s *StateDB) {
privFlag := engine.PrivacyFlagType((uint64(a.args[0])%2)*2 + 1) // the only possible values should be 1 and 3
@@ -334,7 +334,7 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
binary.BigEndian.PutUint64(b, uint64(a.args[1]))
hash := common.BytesToEncryptedPayloadHash(b)
- s.SetStatePrivacyMetadata(addr, &PrivacyMetadata{
+ s.SetPrivacyMetadata(addr, &PrivacyMetadata{
CreationTxHash: hash,
PrivacyFlag: privFlag,
})
@@ -482,9 +482,9 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
- statePM, _ := state.GetStatePrivacyMetadata(addr)
- checkStatePM, _ := checkstate.GetStatePrivacyMetadata(addr)
- checkeq("GetStatePrivacyMetadata", statePM, checkStatePM)
+ statePM, _ := state.GetPrivacyMetadata(addr)
+ checkStatePM, _ := checkstate.GetPrivacyMetadata(addr)
+ checkeq("GetPrivacyMetadata", statePM, checkStatePM)
// Check storage.
if obj := state.getStateObject(addr); obj != nil {
state.ForEachStorage(addr, func(key, value common.Hash) bool {
@@ -744,7 +744,7 @@ func TestPrivacyMetadataIsSavedOnStateDbCommit(t *testing.T) {
state.CreateAccount(addr)
state.SetNonce(addr, uint64(1))
- state.SetStatePrivacyMetadata(addr, &PrivacyMetadata{
+ state.SetPrivacyMetadata(addr, &PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagPartyProtection,
CreationTxHash: common.EncryptedPayloadHash{1},
})
@@ -771,7 +771,7 @@ func TestPrivacyMetadataIsUpdatedOnAccountReCreateWithDifferentPrivacyMetadata(t
state.CreateAccount(addr)
state.SetNonce(addr, uint64(1))
- state.SetStatePrivacyMetadata(addr, &PrivacyMetadata{
+ state.SetPrivacyMetadata(addr, &PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagPartyProtection,
CreationTxHash: common.EncryptedPayloadHash{1},
})
@@ -784,7 +784,7 @@ func TestPrivacyMetadataIsUpdatedOnAccountReCreateWithDifferentPrivacyMetadata(t
state.CreateAccount(addr)
state.SetNonce(addr, uint64(1))
- state.SetStatePrivacyMetadata(addr, &PrivacyMetadata{
+ state.SetPrivacyMetadata(addr, &PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagStateValidation,
CreationTxHash: common.EncryptedPayloadHash{1},
})
@@ -808,7 +808,7 @@ func TestPrivacyMetadataIsRemovedOnAccountSuicide(t *testing.T) {
state.CreateAccount(addr)
state.SetNonce(addr, uint64(1))
- state.SetStatePrivacyMetadata(addr, &PrivacyMetadata{
+ state.SetPrivacyMetadata(addr, &PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagPartyProtection,
CreationTxHash: common.EncryptedPayloadHash{1},
})
@@ -837,7 +837,7 @@ func TestPrivacyMetadataChangesAreRolledBackOnRevert(t *testing.T) {
state.CreateAccount(addr)
state.SetNonce(addr, uint64(1))
- state.SetStatePrivacyMetadata(addr, &PrivacyMetadata{
+ state.SetPrivacyMetadata(addr, &PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagPartyProtection,
CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("one")),
})
@@ -849,7 +849,7 @@ func TestPrivacyMetadataChangesAreRolledBackOnRevert(t *testing.T) {
}
// update privacy metadata
- state.SetStatePrivacyMetadata(addr, &PrivacyMetadata{
+ state.SetPrivacyMetadata(addr, &PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagStateValidation,
CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("two")),
})
@@ -857,18 +857,18 @@ func TestPrivacyMetadataChangesAreRolledBackOnRevert(t *testing.T) {
// record the snapshot
snapshot := state.Snapshot()
- privMetaData, _ = state.GetStatePrivacyMetadata(addr)
+ privMetaData, _ = state.GetPrivacyMetadata(addr)
if privMetaData.CreationTxHash != common.BytesToEncryptedPayloadHash([]byte("two")) {
t.Errorf("current privacy metadata creation tx hash does not match the expected value")
}
// update the metadata
- state.SetStatePrivacyMetadata(addr, &PrivacyMetadata{
+ state.SetPrivacyMetadata(addr, &PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagStateValidation,
CreationTxHash: common.BytesToEncryptedPayloadHash([]byte("three")),
})
- privMetaData, _ = state.GetStatePrivacyMetadata(addr)
+ privMetaData, _ = state.GetPrivacyMetadata(addr)
if privMetaData.CreationTxHash != common.BytesToEncryptedPayloadHash([]byte("three")) {
t.Errorf("current privacy metadata creation tx hash does not match the expected value")
}
@@ -876,7 +876,7 @@ func TestPrivacyMetadataChangesAreRolledBackOnRevert(t *testing.T) {
// revert to snapshot
state.RevertToSnapshot(snapshot)
- privMetaData, _ = state.GetStatePrivacyMetadata(addr)
+ privMetaData, _ = state.GetPrivacyMetadata(addr)
if privMetaData.CreationTxHash != common.BytesToEncryptedPayloadHash([]byte("two")) {
t.Errorf("current privacy metadata creation tx hash does not match the expected value")
}
diff --git a/core/state_transition.go b/core/state_transition.go
index 662c3e2ba7..5248ad6e6d 100644
--- a/core/state_transition.go
+++ b/core/state_transition.go
@@ -20,12 +20,14 @@ import (
"errors"
"math"
"math/big"
+ "strings"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/private"
)
@@ -200,6 +202,9 @@ func (st *StateTransition) preCheck() error {
// and NOT the actual private payload
// 2. For private transactions, we only deduct intrinsic gas from the gas pool
// regardless the current node is party to the transaction or not
+// 3. With multitenancy support, we enforce the party set in the contract index must contain all
+// parties from the transaction. This is to detect unauthorized access from a legit proxy contract
+// to an unauthorized contract.
func (st *StateTransition) TransitionDb() (ret []byte, usedGas uint64, failed bool, err error) {
if err = st.preCheck(); err != nil {
return
@@ -210,16 +215,18 @@ func (st *StateTransition) TransitionDb() (ret []byte, usedGas uint64, failed bo
istanbul := st.evm.ChainConfig().IsIstanbul(st.evm.BlockNumber)
contractCreation := msg.To() == nil
isQuorum := st.evm.ChainConfig().IsQuorum
+ snapshot := st.evm.StateDB.Snapshot()
var data []byte
+ var managedPartiesInTx []string
isPrivate := false
publicState := st.state
pmh := newPMH(st)
if msg, ok := msg.(PrivateMessage); ok && isQuorum && msg.IsPrivate() {
isPrivate = true
- pmh.snapshot = st.evm.StateDB.Snapshot()
+ pmh.snapshot = snapshot
pmh.eph = common.BytesToEncryptedPayloadHash(st.data)
- data, pmh.receivedPrivacyMetadata, err = private.P.Receive(pmh.eph)
+ _, managedPartiesInTx, data, pmh.receivedPrivacyMetadata, err = private.P.Receive(pmh.eph)
// Increment the public account nonce if:
// 1. Tx is private and *not* a participant of the group and either call or create
// 2. Tx is private we are part of the group and is a call
@@ -291,6 +298,9 @@ func (st *StateTransition) TransitionDb() (ret []byte, usedGas uint64, failed bo
if vmerr == vm.ErrInsufficientBalance {
return nil, 0, false, vmerr
}
+ if errors.Is(vmerr, multitenancy.ErrNotAuthorized) {
+ return nil, 0, false, vmerr
+ }
}
// Quorum - Privacy Enhancements
@@ -304,6 +314,28 @@ func (st *StateTransition) TransitionDb() (ret []byte, usedGas uint64, failed bo
}
// End Quorum - Privacy Enhancements
+ // do the affected contract managed party checks
+ if msg, ok := msg.(PrivateMessage); ok && isQuorum && st.evm.SupportsMultitenancy && msg.IsPrivate() {
+ if len(managedPartiesInTx) > 0 {
+ for _, address := range evm.AffectedContracts() {
+ managedPartiesInContract, err := st.evm.StateDB.GetManagedParties(address)
+ if err != nil {
+ return nil, 0, true, err
+ }
+ // managed parties for public transactions is empty so nothing to check there
+ if len(managedPartiesInContract) > 0 {
+ if common.NotContainsAll(managedPartiesInContract, managedPartiesInTx) {
+ log.Debug("Managed parties check has failed for contract", "addr", address, "EPH",
+ pmh.eph.TerminalString(), "contractMP", managedPartiesInContract, "txMP", managedPartiesInTx)
+ st.evm.RevertToSnapshot(snapshot)
+ // TODO - see whether we can find a way to store this error and make it available via customizations to getTransactionReceipt
+ return nil, 0, true, nil
+ }
+ }
+ }
+ }
+ }
+
// Pay gas used during contract creation or execution (st.gas tracks remaining gas)
// However, if private contract then we don't want to do this else we can get
// a mismatch between a (non-participant) minter and (participant) validator,
@@ -315,6 +347,19 @@ func (st *StateTransition) TransitionDb() (ret []byte, usedGas uint64, failed bo
st.refundGas()
st.state.AddBalance(st.evm.Coinbase, new(big.Int).Mul(new(big.Int).SetUint64(st.gasUsed()), st.gasPrice))
+ // for all contracts being created as the result of the transaction execution
+ // we build the index for them if multitenancy is enabled
+ if st.evm.SupportsMultitenancy {
+ addresses := evm.CreatedContracts()
+ for _, address := range addresses {
+ log.Debug("Save to extra data",
+ "address", strings.ToLower(address.Hex()),
+ "isPrivate", isPrivate,
+ "parties", managedPartiesInTx)
+ st.evm.StateDB.SetManagedParties(address, managedPartiesInTx)
+ }
+ }
+
if isPrivate {
return ret, 0, vmerr != nil, err
}
@@ -354,7 +399,7 @@ func (st *StateTransition) RevertToSnapshot(snapshot int) {
st.evm.StateDB.RevertToSnapshot(snapshot)
}
func (st *StateTransition) GetStatePrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error) {
- return st.evm.StateDB.GetStatePrivacyMetadata(addr)
+ return st.evm.StateDB.GetPrivacyMetadata(addr)
}
func (st *StateTransition) CalculateMerkleRoot() (common.Hash, error) {
return st.evm.CalculateMerkleRoot()
diff --git a/core/state_transition_pmh.go b/core/state_transition_pmh.go
index ec16dba007..f97d4d7a3e 100644
--- a/core/state_transition_pmh.go
+++ b/core/state_transition_pmh.go
@@ -86,7 +86,7 @@ func (pmh *privateMessageHandler) verify(vmerr error) (bool, error) {
log.Trace("Verify hashes of affected contracts", "expectedHashes", pmh.receivedPrivacyMetadata.ACHashes, "numberOfAffectedAddresses", len(actualACAddresses))
privacyFlag := pmh.receivedPrivacyMetadata.PrivacyFlag
for _, addr := range actualACAddresses {
- // GetStatePrivacyMetadata is invoked on the privateState (as the tx is private) and it returns:
+ // GetPrivacyMetadata is invoked on the privateState (as the tx is private) and it returns:
// 1. public contacts: privacyMetadata = nil, err = nil
// 2. private contracts of type:
// 2.1. StandardPrivate: privacyMetadata = nil, err = "The provided contract does not have privacy metadata"
diff --git a/core/state_transition_test.go b/core/state_transition_test.go
index a5ebcc5561..053f4dad60 100644
--- a/core/state_transition_test.go
+++ b/core/state_transition_test.go
@@ -8,25 +8,21 @@ import (
"testing"
"time"
- "github.com/ethereum/go-ethereum/core/types"
- "github.com/ethereum/go-ethereum/log"
- "github.com/ethereum/go-ethereum/rlp"
- "github.com/ethereum/go-ethereum/trie"
-
- "github.com/ethereum/go-ethereum/private"
- "github.com/ethereum/go-ethereum/private/engine"
- "github.com/ethereum/go-ethereum/private/engine/notinuse"
-
"github.com/ethereum/go-ethereum/accounts/abi"
-
- "github.com/ethereum/go-ethereum/common/math"
-
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
+ "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/private"
+ "github.com/ethereum/go-ethereum/private/engine"
+ "github.com/ethereum/go-ethereum/private/engine/notinuse"
+ "github.com/ethereum/go-ethereum/rlp"
+ "github.com/ethereum/go-ethereum/trie"
testifyassert "github.com/stretchr/testify/assert"
)
@@ -1097,7 +1093,7 @@ func (mpm *mockPrivateTransactionManager) HasFeature(f engine.PrivateTransaction
return true
}
-func (mpm *mockPrivateTransactionManager) Receive(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
+func (mpm *mockPrivateTransactionManager) Receive(data common.EncryptedPayloadHash) (string, []string, []byte, *engine.ExtraMetadata, error) {
mpm.count["Receive"]++
values := mpm.returns["Receive"]
var (
@@ -1114,7 +1110,7 @@ func (mpm *mockPrivateTransactionManager) Receive(data common.EncryptedPayloadHa
if values[2] != nil {
r3 = values[2].(error)
}
- return r1, r2, r3
+ return "", nil, r1, r2, r3
}
func (mpm *mockPrivateTransactionManager) When(name string) *mockPrivateTransactionManager {
@@ -1344,21 +1340,22 @@ type StubPrivateTransactionManager struct {
responses map[string][]interface{}
}
-func (spm *StubPrivateTransactionManager) Receive(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
+func (spm *StubPrivateTransactionManager) Receive(data common.EncryptedPayloadHash) (string, []string, []byte, *engine.ExtraMetadata, error) {
res := spm.responses["Receive"]
if err, ok := res[1].(error); ok {
- return nil, nil, err
+ return "", nil, nil, nil, err
}
if ret, ok := res[0].([]byte); ok {
- return ret, &engine.ExtraMetadata{
+ return "", nil, ret, &engine.ExtraMetadata{
PrivacyFlag: engine.PrivacyFlagStandardPrivate,
}, nil
}
- return nil, nil, nil
+ return "", nil, nil, nil, nil
}
-func (spm *StubPrivateTransactionManager) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
- return spm.Receive(data)
+func (spm *StubPrivateTransactionManager) ReceiveRaw(hash common.EncryptedPayloadHash) ([]byte, string, *engine.ExtraMetadata, error) {
+ _, sender, data, metadata, err := spm.Receive(hash)
+ return data, sender[0], metadata, err
}
func (spm *StubPrivateTransactionManager) HasFeature(f engine.PrivateTransactionManagerFeature) bool {
diff --git a/core/vm/evm.go b/core/vm/evm.go
index 3632d3793d..4c225e6bc8 100644
--- a/core/vm/evm.go
+++ b/core/vm/evm.go
@@ -17,6 +17,7 @@
package vm
import (
+ "fmt"
"math/big"
"sync/atomic"
"time"
@@ -26,6 +27,7 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/trie"
)
@@ -61,9 +63,23 @@ func run(evm *EVM, contract *Contract, input []byte, readOnly bool) ([]byte, err
// Using CodeAddr is favour over contract.Address()
// During DelegateCall() CodeAddr is the address of the delegated account
address := *contract.CodeAddr
- if _, ok := evm.affectedContracts[address]; !ok {
- evm.affectedContracts[address] = MessageCall
+ // during simulation/eth_call, when contract code is empty, there's no execution hence the
+ // multitenancy check will not happen in captureOperationMode().
+ // This additional check to ensure we capture this case
+ if evm.SupportsMultitenancy && evm.AuthorizeMessageCallFunc != nil && len(contract.Code) == 0 {
+ return nil, multitenancy.ErrNotAuthorized
}
+ if err := evm.captureAffectedContract(address, ModeUnknown); err != nil {
+ return nil, err
+ }
+ // When delegatecall, need to capture the operation mode in the context of contract.Address()
+ // the affected contract is required read only mode
+ if address != contract.Address() {
+ evm.pushAddress(contract.Address())
+ } else {
+ evm.pushAddress(address)
+ }
+ defer evm.popAddress()
precompiles := PrecompiledContractsHomestead
if evm.chainRules.IsByzantium {
precompiles = PrecompiledContractsByzantium
@@ -112,6 +128,16 @@ type Context struct {
BlockNumber *big.Int // Provides information for NUMBER
Time *big.Int // Provides information for TIME
Difficulty *big.Int // Provides information for DIFFICULTY
+
+ // Quorum
+ // EVM should consider multitenancy
+ SupportsMultitenancy bool
+ // AuthorizeCreateFunc performs tenancy authorization check for contract creation.
+ // It's only injected during simulation
+ AuthorizeCreateFunc multitenancy.AuthorizeCreateFunc
+ // AuthorizeMessageCallFunc performs tenancy authorization check for message call to a contract.
+ // It's only injected during simulation/eth_call
+ AuthorizeMessageCallFunc multitenancy.AuthorizeMessageCallFunc
}
type PublicState StateDB
@@ -164,19 +190,87 @@ type EVM struct {
quorumReadOnly bool
readOnlyDepth uint
- // these are for privacy enhancements
- affectedContracts map[common.Address]AffectedType // affected contract account address -> type
- currentTx *types.Transaction // transaction currently being applied on this EVM
+ // these are for privacy enhancements and multitenancy
+ affectedContracts map[common.Address]*AffectedType // affected contract account address -> type
+ currentTx *types.Transaction // transaction currently being applied on this EVM
+ addressStack []common.Address // store contract addresses being executed
+ // store last error during EVM execution lifecycle.
+ // we use this to bubble up the error instead of "evm: execution revert" error.
+ // use it with care as it's meant for runtime multitenancy check during simulation.
+ lastError error
}
-type AffectedType byte
+// AffectedType defines attributes indicating how a contract is affected
+// as the result of the contract code execution in an EVM
+type AffectedType struct {
+ // reason captures how the contract is affected.
+ // Default to MessageCall and set to Creation if the contract under execution is newly created.
+ reason AffectedReason
+ // mode captures how the state is operated as the result of contract code execution.
+ // The value is cached as an expectation by performing trial of ModeRead and ModeWrite against access token for a contract before execution.
+ // Runtime multitenancy check uses this value to verify if an opcode execution violates the expectation.
+ // At the end of EVM lifecycle, this reflects the actual mode.
+ mode AffectedMode
+}
+
+func (t AffectedType) String() string {
+ return fmt.Sprintf("reason=%d,mode=%d", t.reason, t.mode)
+}
+
+// AffectedReason defines a type of operation that was applied to a contract.
+type AffectedReason byte
const (
- _ = iota
- Creation AffectedType = iota
+ _ AffectedReason = iota
+ Creation AffectedReason = iota
MessageCall
)
+// AffectedMode defines a mode in which the state is operated as the result of contract code execution.
+type AffectedMode byte
+
+const (
+ // ModeUnknown indicates an auxiliary mode used during initialization of an affected contract
+ ModeUnknown AffectedMode = iota
+ // ModeRead indicates that state has not been modified as the result of contract code execution
+ ModeRead AffectedMode = iota
+ // ModeWrite indicates that state has been modified as the result of contract code execution
+ ModeWrite = ModeRead << 1
+ // ModeUpdated indicates that the affected mode has been setup for multitenancy check.
+ // This is mainly used during simulation and eth_call
+ ModeUpdated = ModeRead << 7
+)
+
+func ModeOf(isWrite bool) AffectedMode {
+ if isWrite {
+ return ModeWrite
+ }
+ return ModeRead
+}
+
+func (mode AffectedMode) IsNotAuthorized(actualMode AffectedMode) bool {
+ return mode.Has(ModeUpdated) && !mode.Has(actualMode)
+}
+
+func (mode AffectedMode) Update(authorizedRead bool, authorizedWrite bool) AffectedMode {
+ newMode := mode | ModeUpdated
+ if authorizedRead {
+ newMode = newMode | ModeRead
+ }
+ if authorizedWrite {
+ newMode = newMode | ModeWrite
+ }
+ return newMode
+}
+
+func (mode AffectedMode) Has(modes ...AffectedMode) bool {
+ expectedMode := ModeUnknown
+ for _, m := range modes {
+ expectedMode = expectedMode | m
+ }
+ return mode&expectedMode == expectedMode
+}
+
// NewEVM returns a new EVM. The returned EVM is not thread safe and should
// only ever be used *once*.
func NewEVM(ctx Context, statedb, privateState StateDB, chainConfig *params.ChainConfig, vmConfig Config) *EVM {
@@ -191,7 +285,8 @@ func NewEVM(ctx Context, statedb, privateState StateDB, chainConfig *params.Chai
publicState: statedb,
privateState: privateState,
- affectedContracts: make(map[common.Address]AffectedType),
+ affectedContracts: make(map[common.Address]*AffectedType),
+ addressStack: make([]common.Address, 0),
}
if chainConfig.IsEWASM(ctx.BlockNumber) {
@@ -492,8 +587,13 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64,
}
// Create a new account on the state
snapshot := evm.StateDB.Snapshot()
+ if evm.SupportsMultitenancy && evm.AuthorizeCreateFunc != nil {
+ if authorized := evm.AuthorizeCreateFunc(); !authorized {
+ return nil, common.Address{}, gas, multitenancy.ErrNotAuthorized
+ }
+ }
evm.StateDB.CreateAccount(address)
- evm.affectedContracts[address] = Creation
+ evm.affectedContracts[address] = newAffectedType(Creation, ModeWrite|ModeRead)
if evm.chainRules.IsEIP158 {
evm.StateDB.SetNonce(address, 1)
}
@@ -501,7 +601,7 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64,
// for calls (reading contract state) or finding the affected contracts there is no transaction
if evm.currentTx.PrivacyMetadata().PrivacyFlag.IsNotStandardPrivate() {
pm := state.NewStatePrivacyMetadata(common.BytesToEncryptedPayloadHash(evm.currentTx.Data()), evm.currentTx.PrivacyMetadata().PrivacyFlag)
- evm.StateDB.SetStatePrivacyMetadata(address, pm)
+ evm.StateDB.SetPrivacyMetadata(address, pm)
log.Trace("Set Privacy Metadata", "key", address, "privacyMetadata", pm)
}
}
@@ -607,78 +707,216 @@ func (evm *EVM) Create2(caller ContractRef, code []byte, gas uint64, endowment *
func (evm *EVM) ChainConfig() *params.ChainConfig { return evm.chainConfig }
// Quorum functions for dual state
-func getDualState(env *EVM, addr common.Address) StateDB {
+func getDualState(evm *EVM, addr common.Address) StateDB {
// priv: (a) -> (b) (private)
// pub: a -> [b] (private -> public)
// priv: (a) -> b (public)
- state := env.StateDB
+ state := evm.StateDB
- if env.PrivateState().Exist(addr) {
- state = env.PrivateState()
- } else if env.PublicState().Exist(addr) {
- state = env.PublicState()
+ if evm.PrivateState().Exist(addr) {
+ state = evm.PrivateState()
+ evm.captureAffectedContract(addr, ModeUnknown)
+ } else if evm.PublicState().Exist(addr) {
+ state = evm.PublicState()
}
return state
}
-func (env *EVM) PublicState() PublicState { return env.publicState }
-func (env *EVM) PrivateState() PrivateState { return env.privateState }
-func (env *EVM) SetCurrentTX(tx *types.Transaction) { env.currentTx = tx }
-func (env *EVM) SetTxPrivacyMetadata(pm *types.PrivacyMetadata) {
- env.currentTx.SetTxPrivacyMetadata(pm)
+func (evm *EVM) PublicState() PublicState { return evm.publicState }
+func (evm *EVM) PrivateState() PrivateState { return evm.privateState }
+func (evm *EVM) SetCurrentTX(tx *types.Transaction) { evm.currentTx = tx }
+func (evm *EVM) SetTxPrivacyMetadata(pm *types.PrivacyMetadata) {
+ evm.currentTx.SetTxPrivacyMetadata(pm)
}
-func (env *EVM) Push(statedb StateDB) {
+func (evm *EVM) Push(statedb StateDB) {
// Quorum : the read only depth to be set up only once for the entire
// op code execution. This will be set first time transition from
// private state to public state happens
// statedb will be the state of the contract being called.
// if a private contract is calling a public contract make it readonly.
- if !env.quorumReadOnly && env.privateState != statedb {
- env.quorumReadOnly = true
- env.readOnlyDepth = env.currentStateDepth
+ if !evm.quorumReadOnly && evm.privateState != statedb {
+ evm.quorumReadOnly = true
+ evm.readOnlyDepth = evm.currentStateDepth
}
if castedStateDb, ok := statedb.(*state.StateDB); ok {
- env.states[env.currentStateDepth] = castedStateDb
- env.currentStateDepth++
+ evm.states[evm.currentStateDepth] = castedStateDb
+ evm.currentStateDepth++
}
- env.StateDB = statedb
+ evm.StateDB = statedb
}
-func (env *EVM) Pop() {
- env.currentStateDepth--
- if env.quorumReadOnly && env.currentStateDepth == env.readOnlyDepth {
- env.quorumReadOnly = false
+func (evm *EVM) Pop() {
+ evm.currentStateDepth--
+ if evm.quorumReadOnly && evm.currentStateDepth == evm.readOnlyDepth {
+ evm.quorumReadOnly = false
}
- env.StateDB = env.states[env.currentStateDepth-1]
+ evm.StateDB = evm.states[evm.currentStateDepth-1]
}
-func (env *EVM) Depth() int { return env.depth }
+func (evm *EVM) Depth() int { return evm.depth }
// We only need to revert the current state because when we call from private
// public state it's read only, there wouldn't be anything to reset.
// (A)->(B)->C->(B): A failure in (B) wouldn't need to reset C, as C was flagged
// read only.
-func (self *EVM) RevertToSnapshot(snapshot int) {
- self.StateDB.RevertToSnapshot(snapshot)
+func (evm *EVM) RevertToSnapshot(snapshot int) {
+ evm.StateDB.RevertToSnapshot(snapshot)
}
-// Returns all affected contracts that are NOT due to creation transaction
-func (evm *EVM) AffectedContracts() []common.Address {
+// Quorum
+//
+// Returns addresses of contracts which are message-called
+func (evm *EVM) CalledContracts() []common.Address {
addr := make([]common.Address, 0, len(evm.affectedContracts))
for a, t := range evm.affectedContracts {
- if t == MessageCall {
+ if t.reason == MessageCall {
addr = append(addr, a)
}
}
return addr[:]
}
+// Quorum
+//
+// Returns addresses of contracts which are newly created
func (evm *EVM) CreatedContracts() []common.Address {
addr := make([]common.Address, 0, len(evm.affectedContracts))
for a, t := range evm.affectedContracts {
- if t == Creation {
+ if t.reason == Creation {
+ addr = append(addr, a)
+ }
+ }
+ return addr[:]
+}
+
+// Quorum
+//
+// pushAddress stores the contract address being affected during EVM execution
+func (evm *EVM) pushAddress(address common.Address) {
+ evm.addressStack = append(evm.addressStack, address)
+}
+
+// Quorum
+//
+// popAddress retrieves the affected contract address from the stack
+func (evm *EVM) popAddress() {
+ l := len(evm.addressStack)
+ if l == 0 {
+ return
+ }
+ evm.addressStack = evm.addressStack[:l-1]
+}
+
+// Quorum
+//
+// peekAddress retrieves the affected contract address from the top of the stack
+func (evm *EVM) peekAddress() common.Address {
+ l := len(evm.addressStack)
+ if l == 0 {
+ return common.Address{}
+ }
+ return evm.addressStack[l-1]
+}
+
+// Quorum
+//
+// captureOperationMode stores the type of operation being applied on the current
+// affected contract whose address is on top of the stack.
+// For multitenancy, it checks if the mode is allowed. Also it bubbles up the last error
+// captured. This helps to avoid "evm: execution revert" generic error
+func (evm *EVM) captureOperationMode(isWriteOperation bool) error {
+ currentAddress := evm.peekAddress()
+ if (currentAddress == common.Address{}) {
+ return evm.lastError
+ }
+ actualMode := ModeOf(isWriteOperation)
+ if t, ok := evm.affectedContracts[currentAddress]; ok {
+ // perform multitenancy check
+ if evm.enforceMultitenancyCheck() {
+ if t.mode.IsNotAuthorized(actualMode) {
+ log.Trace("Multitenancy check for captureOperationMode()", "address", currentAddress.Hex(), "actual", actualMode, "expect", t.mode)
+ evm.lastError = multitenancy.ErrNotAuthorized
+ }
+ // bubble up the last error
+ if evm.lastError != nil {
+ return evm.lastError
+ }
+ }
+ t.mode = t.mode | actualMode
+ }
+ return nil
+}
+
+// Quorum
+//
+// captureAffectedContract stores the contract address to the affectedContract list if not yet there.
+// The affected mode is also updated if required.
+// Default affected reason is MessageCall.
+// In simulation/eth_call for multitenancy, it sets the expectation of AffectedMode
+// to be verified later when an opcode is executed.
+func (evm *EVM) captureAffectedContract(address common.Address, mode AffectedMode) error {
+ affectedType, found := evm.affectedContracts[address]
+ if !found {
+ affectedType = newAffectedType(MessageCall, mode)
+ evm.affectedContracts[address] = affectedType
+ }
+ if affectedType.mode != ModeUnknown {
+ return nil
+ }
+ if evm.SupportsMultitenancy && evm.AuthorizeMessageCallFunc != nil {
+ authorizedRead, authorizedWrite, err := evm.AuthorizeMessageCallFunc(address)
+ if err != nil {
+ return err
+ }
+ // if we don't authorize either read/write, it's unauthorized access
+ // and we need to inform EVM
+ if !authorizedRead && !authorizedWrite {
+ evm.lastError = multitenancy.ErrNotAuthorized
+ log.Debug("Affected contract not authorized", "address", address.Hex(), "read", authorizedRead, "write", authorizedWrite)
+ return multitenancy.ErrNotAuthorized
+ }
+ oldMode := affectedType.mode
+ affectedType.mode = affectedType.mode.Update(authorizedRead, authorizedWrite)
+ log.Debug("AffectedMode changed", "address", address.Hex(), "old", oldMode, "new", affectedType.mode)
+ }
+ return nil
+}
+
+// enforceMultitenancyCheck returns true if EVM is enforced to do multitenancy check
+// during simulation/eth_call, false otherwise
+func (evm *EVM) enforceMultitenancyCheck() bool {
+ return evm.AuthorizeCreateFunc != nil || evm.AuthorizeMessageCallFunc != nil
+}
+
+// Quorum
+//
+// AffecteMode returns the type of operation (read/write) which was applied on the given
+// contract address. It returns ModeUnknown if the contract is not affected during
+// the lifecycle of this EVM instance
+func (evm *EVM) AffectedMode(a common.Address) (AffectedMode, error) {
+ if t, ok := evm.affectedContracts[a]; ok {
+ return t.mode, nil
+ }
+ return ModeUnknown, fmt.Errorf("address not found")
+}
+
+func newAffectedType(r AffectedReason, m AffectedMode) *AffectedType {
+ return &AffectedType{
+ reason: r,
+ mode: m,
+ }
+}
+
+// Quorum
+//
+// AffectedContracts returns all affected contracts that are the results of
+// MessageCall transaction
+func (evm *EVM) AffectedContracts() []common.Address {
+ addr := make([]common.Address, 0, len(evm.affectedContracts))
+ for a, t := range evm.affectedContracts {
+ if t.reason == MessageCall {
addr = append(addr, a)
}
}
diff --git a/core/vm/evm_test.go b/core/vm/evm_test.go
new file mode 100644
index 0000000000..e280e25e89
--- /dev/null
+++ b/core/vm/evm_test.go
@@ -0,0 +1,23 @@
+package vm
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestAffectedMode_Update_whenTypical(t *testing.T) {
+ testObject := ModeUnknown
+ authorizedReads := []bool{true, false}
+ authorizedWrites := []bool{true, false}
+ for _, authorizedRead := range authorizedReads {
+ for _, authorizedWrite := range authorizedWrites {
+ actual := testObject.Update(authorizedRead, authorizedWrite)
+
+ assert.True(t, actual.Has(ModeUpdated))
+ assert.Equal(t, authorizedRead, actual.Has(ModeRead))
+ assert.Equal(t, authorizedWrite, actual.Has(ModeWrite))
+ assert.False(t, testObject.Has(ModeUpdated))
+ }
+ }
+}
diff --git a/core/vm/instructions.go b/core/vm/instructions.go
index 74127d1596..d953f4b9c5 100644
--- a/core/vm/instructions.go
+++ b/core/vm/instructions.go
@@ -551,10 +551,11 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, contract *Contract,
func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) {
slot := stack.peek()
address := common.BigToAddress(slot)
- if interpreter.evm.StateDB.Empty(address) {
+ stateDB := getDualState(interpreter.evm, address)
+ if stateDB.Empty(address) {
slot.SetUint64(0)
} else {
- slot.SetBytes(interpreter.evm.StateDB.GetCodeHash(address).Bytes())
+ slot.SetBytes(stateDB.GetCodeHash(address).Bytes())
}
return nil, nil
}
diff --git a/core/vm/interface.go b/core/vm/interface.go
index c95adce8b0..af0c373fbe 100644
--- a/core/vm/interface.go
+++ b/core/vm/interface.go
@@ -14,6 +14,8 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see .
+//go:generate mockgen -source interface.go -destination mock_interface.go -package vm
+
package vm
import (
@@ -24,9 +26,22 @@ import (
"github.com/ethereum/go-ethereum/core/types"
)
+type AccountExtraDataStateGetter interface {
+ // Return nil for public contract
+ GetPrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error)
+ GetManagedParties(addr common.Address) ([]string, error)
+}
+
+type AccountExtraDataStateSetter interface {
+ SetPrivacyMetadata(addr common.Address, pm *state.PrivacyMetadata)
+ SetManagedParties(addr common.Address, managedParties []string)
+}
+
// Quorum uses a cut-down StateDB, MinimalApiState. We leave the methods in StateDB commented out so they'll produce a
// conflict when upstream changes.
type MinimalApiState interface {
+ AccountExtraDataStateGetter
+
GetBalance(addr common.Address) *big.Int
SetBalance(addr common.Address, balance *big.Int)
GetCode(addr common.Address) []byte
@@ -34,8 +49,6 @@ type MinimalApiState interface {
GetNonce(addr common.Address) uint64
SetNonce(addr common.Address, nonce uint64)
SetCode(common.Address, []byte)
- // Return nil for public contract
- GetStatePrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error)
// RLP-encoded of the state object in a given address
// Throw error if no state object is found
@@ -52,6 +65,8 @@ type MinimalApiState interface {
// StateDB is an EVM database for full state querying.
type StateDB interface {
MinimalApiState
+ AccountExtraDataStateSetter
+
CreateAccount(common.Address)
SubBalance(common.Address, *big.Int)
@@ -61,7 +76,6 @@ type StateDB interface {
//GetNonce(common.Address) uint64
//SetNonce(common.Address, uint64)
- SetStatePrivacyMetadata(common.Address, *state.PrivacyMetadata)
//GetCodeHash(common.Address) common.Hash
//GetCode(common.Address) []byte
//SetCode(common.Address, []byte)
diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go
index 33a7ec7ee3..99680cf065 100644
--- a/core/vm/interpreter.go
+++ b/core/vm/interpreter.go
@@ -217,7 +217,9 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) (
if in.evm.quorumReadOnly && operation.writes {
return nil, fmt.Errorf("VM in read-only mode. Mutating opcode prohibited")
}
-
+ if err := in.evm.captureOperationMode(operation.writes); err != nil {
+ return nil, err
+ }
// If the operation is valid, enforce and write restrictions
if in.readOnly && in.evm.chainRules.IsByzantium {
// If the interpreter is operating in readonly mode, make sure no
diff --git a/core/vm/mock_interface.go b/core/vm/mock_interface.go
new file mode 100644
index 0000000000..5060315167
--- /dev/null
+++ b/core/vm/mock_interface.go
@@ -0,0 +1,957 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: interface.go
+
+// Package vm is a generated GoMock package.
+package vm
+
+import (
+ big "math/big"
+ reflect "reflect"
+
+ common "github.com/ethereum/go-ethereum/common"
+ state "github.com/ethereum/go-ethereum/core/state"
+ types "github.com/ethereum/go-ethereum/core/types"
+ gomock "github.com/golang/mock/gomock"
+)
+
+// MockAccountExtraDataStateGetter is a mock of AccountExtraDataStateGetter interface.
+type MockAccountExtraDataStateGetter struct {
+ ctrl *gomock.Controller
+ recorder *MockAccountExtraDataStateGetterMockRecorder
+}
+
+// MockAccountExtraDataStateGetterMockRecorder is the mock recorder for MockAccountExtraDataStateGetter.
+type MockAccountExtraDataStateGetterMockRecorder struct {
+ mock *MockAccountExtraDataStateGetter
+}
+
+// NewMockAccountExtraDataStateGetter creates a new mock instance.
+func NewMockAccountExtraDataStateGetter(ctrl *gomock.Controller) *MockAccountExtraDataStateGetter {
+ mock := &MockAccountExtraDataStateGetter{ctrl: ctrl}
+ mock.recorder = &MockAccountExtraDataStateGetterMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockAccountExtraDataStateGetter) EXPECT() *MockAccountExtraDataStateGetterMockRecorder {
+ return m.recorder
+}
+
+// GetPrivacyMetadata mocks base method.
+func (m *MockAccountExtraDataStateGetter) GetPrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetPrivacyMetadata", addr)
+ ret0, _ := ret[0].(*state.PrivacyMetadata)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetPrivacyMetadata indicates an expected call of GetPrivacyMetadata.
+func (mr *MockAccountExtraDataStateGetterMockRecorder) GetPrivacyMetadata(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivacyMetadata", reflect.TypeOf((*MockAccountExtraDataStateGetter)(nil).GetPrivacyMetadata), addr)
+}
+
+// GetManagedParties mocks base method.
+func (m *MockAccountExtraDataStateGetter) GetManagedParties(addr common.Address) ([]string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetManagedParties", addr)
+ ret0, _ := ret[0].([]string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetManagedParties indicates an expected call of GetManagedParties.
+func (mr *MockAccountExtraDataStateGetterMockRecorder) GetManagedParties(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedParties", reflect.TypeOf((*MockAccountExtraDataStateGetter)(nil).GetManagedParties), addr)
+}
+
+// MockAccountExtraDataStateSetter is a mock of AccountExtraDataStateSetter interface.
+type MockAccountExtraDataStateSetter struct {
+ ctrl *gomock.Controller
+ recorder *MockAccountExtraDataStateSetterMockRecorder
+}
+
+// MockAccountExtraDataStateSetterMockRecorder is the mock recorder for MockAccountExtraDataStateSetter.
+type MockAccountExtraDataStateSetterMockRecorder struct {
+ mock *MockAccountExtraDataStateSetter
+}
+
+// NewMockAccountExtraDataStateSetter creates a new mock instance.
+func NewMockAccountExtraDataStateSetter(ctrl *gomock.Controller) *MockAccountExtraDataStateSetter {
+ mock := &MockAccountExtraDataStateSetter{ctrl: ctrl}
+ mock.recorder = &MockAccountExtraDataStateSetterMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockAccountExtraDataStateSetter) EXPECT() *MockAccountExtraDataStateSetterMockRecorder {
+ return m.recorder
+}
+
+// SetPrivacyMetadata mocks base method.
+func (m *MockAccountExtraDataStateSetter) SetPrivacyMetadata(addr common.Address, pm *state.PrivacyMetadata) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetPrivacyMetadata", addr, pm)
+}
+
+// SetPrivacyMetadata indicates an expected call of SetPrivacyMetadata.
+func (mr *MockAccountExtraDataStateSetterMockRecorder) SetPrivacyMetadata(addr, pm interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPrivacyMetadata", reflect.TypeOf((*MockAccountExtraDataStateSetter)(nil).SetPrivacyMetadata), addr, pm)
+}
+
+// SetManagedParties mocks base method.
+func (m *MockAccountExtraDataStateSetter) SetManagedParties(addr common.Address, managedParties []string) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetManagedParties", addr, managedParties)
+}
+
+// SetManagedParties indicates an expected call of SetManagedParties.
+func (mr *MockAccountExtraDataStateSetterMockRecorder) SetManagedParties(addr, managedParties interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetManagedParties", reflect.TypeOf((*MockAccountExtraDataStateSetter)(nil).SetManagedParties), addr, managedParties)
+}
+
+// MockMinimalApiState is a mock of MinimalApiState interface.
+type MockMinimalApiState struct {
+ ctrl *gomock.Controller
+ recorder *MockMinimalApiStateMockRecorder
+}
+
+// MockMinimalApiStateMockRecorder is the mock recorder for MockMinimalApiState.
+type MockMinimalApiStateMockRecorder struct {
+ mock *MockMinimalApiState
+}
+
+// NewMockMinimalApiState creates a new mock instance.
+func NewMockMinimalApiState(ctrl *gomock.Controller) *MockMinimalApiState {
+ mock := &MockMinimalApiState{ctrl: ctrl}
+ mock.recorder = &MockMinimalApiStateMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockMinimalApiState) EXPECT() *MockMinimalApiStateMockRecorder {
+ return m.recorder
+}
+
+// GetPrivacyMetadata mocks base method.
+func (m *MockMinimalApiState) GetPrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetPrivacyMetadata", addr)
+ ret0, _ := ret[0].(*state.PrivacyMetadata)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetPrivacyMetadata indicates an expected call of GetPrivacyMetadata.
+func (mr *MockMinimalApiStateMockRecorder) GetPrivacyMetadata(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivacyMetadata", reflect.TypeOf((*MockMinimalApiState)(nil).GetPrivacyMetadata), addr)
+}
+
+// GetManagedParties mocks base method.
+func (m *MockMinimalApiState) GetManagedParties(addr common.Address) ([]string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetManagedParties", addr)
+ ret0, _ := ret[0].([]string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetManagedParties indicates an expected call of GetManagedParties.
+func (mr *MockMinimalApiStateMockRecorder) GetManagedParties(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedParties", reflect.TypeOf((*MockMinimalApiState)(nil).GetManagedParties), addr)
+}
+
+// GetBalance mocks base method.
+func (m *MockMinimalApiState) GetBalance(addr common.Address) *big.Int {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetBalance", addr)
+ ret0, _ := ret[0].(*big.Int)
+ return ret0
+}
+
+// GetBalance indicates an expected call of GetBalance.
+func (mr *MockMinimalApiStateMockRecorder) GetBalance(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalance", reflect.TypeOf((*MockMinimalApiState)(nil).GetBalance), addr)
+}
+
+// SetBalance mocks base method.
+func (m *MockMinimalApiState) SetBalance(addr common.Address, balance *big.Int) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetBalance", addr, balance)
+}
+
+// SetBalance indicates an expected call of SetBalance.
+func (mr *MockMinimalApiStateMockRecorder) SetBalance(addr, balance interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBalance", reflect.TypeOf((*MockMinimalApiState)(nil).SetBalance), addr, balance)
+}
+
+// GetCode mocks base method.
+func (m *MockMinimalApiState) GetCode(addr common.Address) []byte {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetCode", addr)
+ ret0, _ := ret[0].([]byte)
+ return ret0
+}
+
+// GetCode indicates an expected call of GetCode.
+func (mr *MockMinimalApiStateMockRecorder) GetCode(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCode", reflect.TypeOf((*MockMinimalApiState)(nil).GetCode), addr)
+}
+
+// GetState mocks base method.
+func (m *MockMinimalApiState) GetState(a common.Address, b common.Hash) common.Hash {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetState", a, b)
+ ret0, _ := ret[0].(common.Hash)
+ return ret0
+}
+
+// GetState indicates an expected call of GetState.
+func (mr *MockMinimalApiStateMockRecorder) GetState(a, b interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MockMinimalApiState)(nil).GetState), a, b)
+}
+
+// GetNonce mocks base method.
+func (m *MockMinimalApiState) GetNonce(addr common.Address) uint64 {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetNonce", addr)
+ ret0, _ := ret[0].(uint64)
+ return ret0
+}
+
+// GetNonce indicates an expected call of GetNonce.
+func (mr *MockMinimalApiStateMockRecorder) GetNonce(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNonce", reflect.TypeOf((*MockMinimalApiState)(nil).GetNonce), addr)
+}
+
+// SetNonce mocks base method.
+func (m *MockMinimalApiState) SetNonce(addr common.Address, nonce uint64) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetNonce", addr, nonce)
+}
+
+// SetNonce indicates an expected call of SetNonce.
+func (mr *MockMinimalApiStateMockRecorder) SetNonce(addr, nonce interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNonce", reflect.TypeOf((*MockMinimalApiState)(nil).SetNonce), addr, nonce)
+}
+
+// SetCode mocks base method.
+func (m *MockMinimalApiState) SetCode(arg0 common.Address, arg1 []byte) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetCode", arg0, arg1)
+}
+
+// SetCode indicates an expected call of SetCode.
+func (mr *MockMinimalApiStateMockRecorder) SetCode(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCode", reflect.TypeOf((*MockMinimalApiState)(nil).SetCode), arg0, arg1)
+}
+
+// GetRLPEncodedStateObject mocks base method.
+func (m *MockMinimalApiState) GetRLPEncodedStateObject(addr common.Address) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetRLPEncodedStateObject", addr)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetRLPEncodedStateObject indicates an expected call of GetRLPEncodedStateObject.
+func (mr *MockMinimalApiStateMockRecorder) GetRLPEncodedStateObject(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRLPEncodedStateObject", reflect.TypeOf((*MockMinimalApiState)(nil).GetRLPEncodedStateObject), addr)
+}
+
+// GetProof mocks base method.
+func (m *MockMinimalApiState) GetProof(arg0 common.Address) ([][]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetProof", arg0)
+ ret0, _ := ret[0].([][]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetProof indicates an expected call of GetProof.
+func (mr *MockMinimalApiStateMockRecorder) GetProof(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProof", reflect.TypeOf((*MockMinimalApiState)(nil).GetProof), arg0)
+}
+
+// GetStorageProof mocks base method.
+func (m *MockMinimalApiState) GetStorageProof(arg0 common.Address, arg1 common.Hash) ([][]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetStorageProof", arg0, arg1)
+ ret0, _ := ret[0].([][]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetStorageProof indicates an expected call of GetStorageProof.
+func (mr *MockMinimalApiStateMockRecorder) GetStorageProof(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStorageProof", reflect.TypeOf((*MockMinimalApiState)(nil).GetStorageProof), arg0, arg1)
+}
+
+// StorageTrie mocks base method.
+func (m *MockMinimalApiState) StorageTrie(addr common.Address) state.Trie {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "StorageTrie", addr)
+ ret0, _ := ret[0].(state.Trie)
+ return ret0
+}
+
+// StorageTrie indicates an expected call of StorageTrie.
+func (mr *MockMinimalApiStateMockRecorder) StorageTrie(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StorageTrie", reflect.TypeOf((*MockMinimalApiState)(nil).StorageTrie), addr)
+}
+
+// Error mocks base method.
+func (m *MockMinimalApiState) Error() error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Error")
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Error indicates an expected call of Error.
+func (mr *MockMinimalApiStateMockRecorder) Error() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockMinimalApiState)(nil).Error))
+}
+
+// GetCodeHash mocks base method.
+func (m *MockMinimalApiState) GetCodeHash(arg0 common.Address) common.Hash {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetCodeHash", arg0)
+ ret0, _ := ret[0].(common.Hash)
+ return ret0
+}
+
+// GetCodeHash indicates an expected call of GetCodeHash.
+func (mr *MockMinimalApiStateMockRecorder) GetCodeHash(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCodeHash", reflect.TypeOf((*MockMinimalApiState)(nil).GetCodeHash), arg0)
+}
+
+// SetState mocks base method.
+func (m *MockMinimalApiState) SetState(arg0 common.Address, arg1, arg2 common.Hash) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetState", arg0, arg1, arg2)
+}
+
+// SetState indicates an expected call of SetState.
+func (mr *MockMinimalApiStateMockRecorder) SetState(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetState", reflect.TypeOf((*MockMinimalApiState)(nil).SetState), arg0, arg1, arg2)
+}
+
+// SetStorage mocks base method.
+func (m *MockMinimalApiState) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetStorage", addr, storage)
+}
+
+// SetStorage indicates an expected call of SetStorage.
+func (mr *MockMinimalApiStateMockRecorder) SetStorage(addr, storage interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStorage", reflect.TypeOf((*MockMinimalApiState)(nil).SetStorage), addr, storage)
+}
+
+// MockStateDB is a mock of StateDB interface.
+type MockStateDB struct {
+ ctrl *gomock.Controller
+ recorder *MockStateDBMockRecorder
+}
+
+// MockStateDBMockRecorder is the mock recorder for MockStateDB.
+type MockStateDBMockRecorder struct {
+ mock *MockStateDB
+}
+
+// NewMockStateDB creates a new mock instance.
+func NewMockStateDB(ctrl *gomock.Controller) *MockStateDB {
+ mock := &MockStateDB{ctrl: ctrl}
+ mock.recorder = &MockStateDBMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockStateDB) EXPECT() *MockStateDBMockRecorder {
+ return m.recorder
+}
+
+// GetPrivacyMetadata mocks base method.
+func (m *MockStateDB) GetPrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetPrivacyMetadata", addr)
+ ret0, _ := ret[0].(*state.PrivacyMetadata)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetPrivacyMetadata indicates an expected call of GetPrivacyMetadata.
+func (mr *MockStateDBMockRecorder) GetPrivacyMetadata(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivacyMetadata", reflect.TypeOf((*MockStateDB)(nil).GetPrivacyMetadata), addr)
+}
+
+// GetManagedParties mocks base method.
+func (m *MockStateDB) GetManagedParties(addr common.Address) ([]string, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetManagedParties", addr)
+ ret0, _ := ret[0].([]string)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetManagedParties indicates an expected call of GetManagedParties.
+func (mr *MockStateDBMockRecorder) GetManagedParties(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedParties", reflect.TypeOf((*MockStateDB)(nil).GetManagedParties), addr)
+}
+
+// GetBalance mocks base method.
+func (m *MockStateDB) GetBalance(addr common.Address) *big.Int {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetBalance", addr)
+ ret0, _ := ret[0].(*big.Int)
+ return ret0
+}
+
+// GetBalance indicates an expected call of GetBalance.
+func (mr *MockStateDBMockRecorder) GetBalance(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalance", reflect.TypeOf((*MockStateDB)(nil).GetBalance), addr)
+}
+
+// SetBalance mocks base method.
+func (m *MockStateDB) SetBalance(addr common.Address, balance *big.Int) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetBalance", addr, balance)
+}
+
+// SetBalance indicates an expected call of SetBalance.
+func (mr *MockStateDBMockRecorder) SetBalance(addr, balance interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBalance", reflect.TypeOf((*MockStateDB)(nil).SetBalance), addr, balance)
+}
+
+// GetCode mocks base method.
+func (m *MockStateDB) GetCode(addr common.Address) []byte {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetCode", addr)
+ ret0, _ := ret[0].([]byte)
+ return ret0
+}
+
+// GetCode indicates an expected call of GetCode.
+func (mr *MockStateDBMockRecorder) GetCode(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCode", reflect.TypeOf((*MockStateDB)(nil).GetCode), addr)
+}
+
+// GetState mocks base method.
+func (m *MockStateDB) GetState(a common.Address, b common.Hash) common.Hash {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetState", a, b)
+ ret0, _ := ret[0].(common.Hash)
+ return ret0
+}
+
+// GetState indicates an expected call of GetState.
+func (mr *MockStateDBMockRecorder) GetState(a, b interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MockStateDB)(nil).GetState), a, b)
+}
+
+// GetNonce mocks base method.
+func (m *MockStateDB) GetNonce(addr common.Address) uint64 {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetNonce", addr)
+ ret0, _ := ret[0].(uint64)
+ return ret0
+}
+
+// GetNonce indicates an expected call of GetNonce.
+func (mr *MockStateDBMockRecorder) GetNonce(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNonce", reflect.TypeOf((*MockStateDB)(nil).GetNonce), addr)
+}
+
+// SetNonce mocks base method.
+func (m *MockStateDB) SetNonce(addr common.Address, nonce uint64) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetNonce", addr, nonce)
+}
+
+// SetNonce indicates an expected call of SetNonce.
+func (mr *MockStateDBMockRecorder) SetNonce(addr, nonce interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNonce", reflect.TypeOf((*MockStateDB)(nil).SetNonce), addr, nonce)
+}
+
+// SetCode mocks base method.
+func (m *MockStateDB) SetCode(arg0 common.Address, arg1 []byte) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetCode", arg0, arg1)
+}
+
+// SetCode indicates an expected call of SetCode.
+func (mr *MockStateDBMockRecorder) SetCode(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCode", reflect.TypeOf((*MockStateDB)(nil).SetCode), arg0, arg1)
+}
+
+// GetRLPEncodedStateObject mocks base method.
+func (m *MockStateDB) GetRLPEncodedStateObject(addr common.Address) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetRLPEncodedStateObject", addr)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetRLPEncodedStateObject indicates an expected call of GetRLPEncodedStateObject.
+func (mr *MockStateDBMockRecorder) GetRLPEncodedStateObject(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRLPEncodedStateObject", reflect.TypeOf((*MockStateDB)(nil).GetRLPEncodedStateObject), addr)
+}
+
+// GetProof mocks base method.
+func (m *MockStateDB) GetProof(arg0 common.Address) ([][]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetProof", arg0)
+ ret0, _ := ret[0].([][]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetProof indicates an expected call of GetProof.
+func (mr *MockStateDBMockRecorder) GetProof(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProof", reflect.TypeOf((*MockStateDB)(nil).GetProof), arg0)
+}
+
+// GetStorageProof mocks base method.
+func (m *MockStateDB) GetStorageProof(arg0 common.Address, arg1 common.Hash) ([][]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetStorageProof", arg0, arg1)
+ ret0, _ := ret[0].([][]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetStorageProof indicates an expected call of GetStorageProof.
+func (mr *MockStateDBMockRecorder) GetStorageProof(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStorageProof", reflect.TypeOf((*MockStateDB)(nil).GetStorageProof), arg0, arg1)
+}
+
+// StorageTrie mocks base method.
+func (m *MockStateDB) StorageTrie(addr common.Address) state.Trie {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "StorageTrie", addr)
+ ret0, _ := ret[0].(state.Trie)
+ return ret0
+}
+
+// StorageTrie indicates an expected call of StorageTrie.
+func (mr *MockStateDBMockRecorder) StorageTrie(addr interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StorageTrie", reflect.TypeOf((*MockStateDB)(nil).StorageTrie), addr)
+}
+
+// Error mocks base method.
+func (m *MockStateDB) Error() error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Error")
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Error indicates an expected call of Error.
+func (mr *MockStateDBMockRecorder) Error() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockStateDB)(nil).Error))
+}
+
+// GetCodeHash mocks base method.
+func (m *MockStateDB) GetCodeHash(arg0 common.Address) common.Hash {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetCodeHash", arg0)
+ ret0, _ := ret[0].(common.Hash)
+ return ret0
+}
+
+// GetCodeHash indicates an expected call of GetCodeHash.
+func (mr *MockStateDBMockRecorder) GetCodeHash(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCodeHash", reflect.TypeOf((*MockStateDB)(nil).GetCodeHash), arg0)
+}
+
+// SetState mocks base method.
+func (m *MockStateDB) SetState(arg0 common.Address, arg1, arg2 common.Hash) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetState", arg0, arg1, arg2)
+}
+
+// SetState indicates an expected call of SetState.
+func (mr *MockStateDBMockRecorder) SetState(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetState", reflect.TypeOf((*MockStateDB)(nil).SetState), arg0, arg1, arg2)
+}
+
+// SetStorage mocks base method.
+func (m *MockStateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetStorage", addr, storage)
+}
+
+// SetStorage indicates an expected call of SetStorage.
+func (mr *MockStateDBMockRecorder) SetStorage(addr, storage interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStorage", reflect.TypeOf((*MockStateDB)(nil).SetStorage), addr, storage)
+}
+
+// SetPrivacyMetadata mocks base method.
+func (m *MockStateDB) SetPrivacyMetadata(addr common.Address, pm *state.PrivacyMetadata) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetPrivacyMetadata", addr, pm)
+}
+
+// SetPrivacyMetadata indicates an expected call of SetPrivacyMetadata.
+func (mr *MockStateDBMockRecorder) SetPrivacyMetadata(addr, pm interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPrivacyMetadata", reflect.TypeOf((*MockStateDB)(nil).SetPrivacyMetadata), addr, pm)
+}
+
+// SetManagedParties mocks base method.
+func (m *MockStateDB) SetManagedParties(addr common.Address, managedParties []string) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SetManagedParties", addr, managedParties)
+}
+
+// SetManagedParties indicates an expected call of SetManagedParties.
+func (mr *MockStateDBMockRecorder) SetManagedParties(addr, managedParties interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetManagedParties", reflect.TypeOf((*MockStateDB)(nil).SetManagedParties), addr, managedParties)
+}
+
+// CreateAccount mocks base method.
+func (m *MockStateDB) CreateAccount(arg0 common.Address) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "CreateAccount", arg0)
+}
+
+// CreateAccount indicates an expected call of CreateAccount.
+func (mr *MockStateDBMockRecorder) CreateAccount(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAccount", reflect.TypeOf((*MockStateDB)(nil).CreateAccount), arg0)
+}
+
+// SubBalance mocks base method.
+func (m *MockStateDB) SubBalance(arg0 common.Address, arg1 *big.Int) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SubBalance", arg0, arg1)
+}
+
+// SubBalance indicates an expected call of SubBalance.
+func (mr *MockStateDBMockRecorder) SubBalance(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubBalance", reflect.TypeOf((*MockStateDB)(nil).SubBalance), arg0, arg1)
+}
+
+// AddBalance mocks base method.
+func (m *MockStateDB) AddBalance(arg0 common.Address, arg1 *big.Int) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "AddBalance", arg0, arg1)
+}
+
+// AddBalance indicates an expected call of AddBalance.
+func (mr *MockStateDBMockRecorder) AddBalance(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddBalance", reflect.TypeOf((*MockStateDB)(nil).AddBalance), arg0, arg1)
+}
+
+// GetCodeSize mocks base method.
+func (m *MockStateDB) GetCodeSize(arg0 common.Address) int {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetCodeSize", arg0)
+ ret0, _ := ret[0].(int)
+ return ret0
+}
+
+// GetCodeSize indicates an expected call of GetCodeSize.
+func (mr *MockStateDBMockRecorder) GetCodeSize(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCodeSize", reflect.TypeOf((*MockStateDB)(nil).GetCodeSize), arg0)
+}
+
+// AddRefund mocks base method.
+func (m *MockStateDB) AddRefund(arg0 uint64) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "AddRefund", arg0)
+}
+
+// AddRefund indicates an expected call of AddRefund.
+func (mr *MockStateDBMockRecorder) AddRefund(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRefund", reflect.TypeOf((*MockStateDB)(nil).AddRefund), arg0)
+}
+
+// SubRefund mocks base method.
+func (m *MockStateDB) SubRefund(arg0 uint64) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "SubRefund", arg0)
+}
+
+// SubRefund indicates an expected call of SubRefund.
+func (mr *MockStateDBMockRecorder) SubRefund(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubRefund", reflect.TypeOf((*MockStateDB)(nil).SubRefund), arg0)
+}
+
+// GetRefund mocks base method.
+func (m *MockStateDB) GetRefund() uint64 {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetRefund")
+ ret0, _ := ret[0].(uint64)
+ return ret0
+}
+
+// GetRefund indicates an expected call of GetRefund.
+func (mr *MockStateDBMockRecorder) GetRefund() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRefund", reflect.TypeOf((*MockStateDB)(nil).GetRefund))
+}
+
+// GetCommittedState mocks base method.
+func (m *MockStateDB) GetCommittedState(arg0 common.Address, arg1 common.Hash) common.Hash {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetCommittedState", arg0, arg1)
+ ret0, _ := ret[0].(common.Hash)
+ return ret0
+}
+
+// GetCommittedState indicates an expected call of GetCommittedState.
+func (mr *MockStateDBMockRecorder) GetCommittedState(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCommittedState", reflect.TypeOf((*MockStateDB)(nil).GetCommittedState), arg0, arg1)
+}
+
+// Suicide mocks base method.
+func (m *MockStateDB) Suicide(arg0 common.Address) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Suicide", arg0)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// Suicide indicates an expected call of Suicide.
+func (mr *MockStateDBMockRecorder) Suicide(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Suicide", reflect.TypeOf((*MockStateDB)(nil).Suicide), arg0)
+}
+
+// HasSuicided mocks base method.
+func (m *MockStateDB) HasSuicided(arg0 common.Address) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "HasSuicided", arg0)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// HasSuicided indicates an expected call of HasSuicided.
+func (mr *MockStateDBMockRecorder) HasSuicided(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSuicided", reflect.TypeOf((*MockStateDB)(nil).HasSuicided), arg0)
+}
+
+// Exist mocks base method.
+func (m *MockStateDB) Exist(arg0 common.Address) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Exist", arg0)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// Exist indicates an expected call of Exist.
+func (mr *MockStateDBMockRecorder) Exist(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exist", reflect.TypeOf((*MockStateDB)(nil).Exist), arg0)
+}
+
+// Empty mocks base method.
+func (m *MockStateDB) Empty(arg0 common.Address) bool {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Empty", arg0)
+ ret0, _ := ret[0].(bool)
+ return ret0
+}
+
+// Empty indicates an expected call of Empty.
+func (mr *MockStateDBMockRecorder) Empty(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Empty", reflect.TypeOf((*MockStateDB)(nil).Empty), arg0)
+}
+
+// RevertToSnapshot mocks base method.
+func (m *MockStateDB) RevertToSnapshot(arg0 int) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "RevertToSnapshot", arg0)
+}
+
+// RevertToSnapshot indicates an expected call of RevertToSnapshot.
+func (mr *MockStateDBMockRecorder) RevertToSnapshot(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevertToSnapshot", reflect.TypeOf((*MockStateDB)(nil).RevertToSnapshot), arg0)
+}
+
+// Snapshot mocks base method.
+func (m *MockStateDB) Snapshot() int {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Snapshot")
+ ret0, _ := ret[0].(int)
+ return ret0
+}
+
+// Snapshot indicates an expected call of Snapshot.
+func (mr *MockStateDBMockRecorder) Snapshot() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Snapshot", reflect.TypeOf((*MockStateDB)(nil).Snapshot))
+}
+
+// AddLog mocks base method.
+func (m *MockStateDB) AddLog(arg0 *types.Log) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "AddLog", arg0)
+}
+
+// AddLog indicates an expected call of AddLog.
+func (mr *MockStateDBMockRecorder) AddLog(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddLog", reflect.TypeOf((*MockStateDB)(nil).AddLog), arg0)
+}
+
+// AddPreimage mocks base method.
+func (m *MockStateDB) AddPreimage(arg0 common.Hash, arg1 []byte) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "AddPreimage", arg0, arg1)
+}
+
+// AddPreimage indicates an expected call of AddPreimage.
+func (mr *MockStateDBMockRecorder) AddPreimage(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPreimage", reflect.TypeOf((*MockStateDB)(nil).AddPreimage), arg0, arg1)
+}
+
+// ForEachStorage mocks base method.
+func (m *MockStateDB) ForEachStorage(arg0 common.Address, arg1 func(common.Hash, common.Hash) bool) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "ForEachStorage", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// ForEachStorage indicates an expected call of ForEachStorage.
+func (mr *MockStateDBMockRecorder) ForEachStorage(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForEachStorage", reflect.TypeOf((*MockStateDB)(nil).ForEachStorage), arg0, arg1)
+}
+
+// MockCallContext is a mock of CallContext interface.
+type MockCallContext struct {
+ ctrl *gomock.Controller
+ recorder *MockCallContextMockRecorder
+}
+
+// MockCallContextMockRecorder is the mock recorder for MockCallContext.
+type MockCallContextMockRecorder struct {
+ mock *MockCallContext
+}
+
+// NewMockCallContext creates a new mock instance.
+func NewMockCallContext(ctrl *gomock.Controller) *MockCallContext {
+ mock := &MockCallContext{ctrl: ctrl}
+ mock.recorder = &MockCallContextMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockCallContext) EXPECT() *MockCallContextMockRecorder {
+ return m.recorder
+}
+
+// Call mocks base method.
+func (m *MockCallContext) Call(env *EVM, me ContractRef, addr common.Address, data []byte, gas, value *big.Int) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Call", env, me, addr, data, gas, value)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// Call indicates an expected call of Call.
+func (mr *MockCallContextMockRecorder) Call(env, me, addr, data, gas, value interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Call", reflect.TypeOf((*MockCallContext)(nil).Call), env, me, addr, data, gas, value)
+}
+
+// CallCode mocks base method.
+func (m *MockCallContext) CallCode(env *EVM, me ContractRef, addr common.Address, data []byte, gas, value *big.Int) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "CallCode", env, me, addr, data, gas, value)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// CallCode indicates an expected call of CallCode.
+func (mr *MockCallContextMockRecorder) CallCode(env, me, addr, data, gas, value interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallCode", reflect.TypeOf((*MockCallContext)(nil).CallCode), env, me, addr, data, gas, value)
+}
+
+// DelegateCall mocks base method.
+func (m *MockCallContext) DelegateCall(env *EVM, me ContractRef, addr common.Address, data []byte, gas *big.Int) ([]byte, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DelegateCall", env, me, addr, data, gas)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// DelegateCall indicates an expected call of DelegateCall.
+func (mr *MockCallContextMockRecorder) DelegateCall(env, me, addr, data, gas interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DelegateCall", reflect.TypeOf((*MockCallContext)(nil).DelegateCall), env, me, addr, data, gas)
+}
+
+// Create mocks base method.
+func (m *MockCallContext) Create(env *EVM, me ContractRef, data []byte, gas, value *big.Int) ([]byte, common.Address, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Create", env, me, data, gas, value)
+ ret0, _ := ret[0].([]byte)
+ ret1, _ := ret[1].(common.Address)
+ ret2, _ := ret[2].(error)
+ return ret0, ret1, ret2
+}
+
+// Create indicates an expected call of Create.
+func (mr *MockCallContextMockRecorder) Create(env, me, data, gas, value interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockCallContext)(nil).Create), env, me, data, gas, value)
+}
diff --git a/core/vm/runtime/evm_privacy_test.go b/core/vm/runtime/evm_privacy_test.go
index 9a5c181001..2f54e2b1ef 100644
--- a/core/vm/runtime/evm_privacy_test.go
+++ b/core/vm/runtime/evm_privacy_test.go
@@ -7,23 +7,16 @@ import (
"strings"
"testing"
+ "github.com/ethereum/go-ethereum/accounts/abi"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/hexutil"
+ "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
-
- "github.com/ethereum/go-ethereum/private/engine"
-
+ "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
-
- "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/vm"
-
"github.com/ethereum/go-ethereum/log"
-
- "github.com/ethereum/go-ethereum/common/hexutil"
-
- "github.com/ethereum/go-ethereum/core/state"
-
- "github.com/ethereum/go-ethereum/accounts/abi"
- "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/private/engine"
testifyassert "github.com/stretchr/testify/assert"
)
@@ -110,7 +103,7 @@ func TestPrivacyEnhancements_CreateC1(t *testing.T) {
var getPrivacyMetadataFunc func(common.Address) (*state.PrivacyMetadata, error)
cfg.onAfterEVM = func(evm *vm.EVM) {
affectedContracts = evm.AffectedContracts()
- getPrivacyMetadataFunc = evm.StateDB.GetStatePrivacyMetadata
+ getPrivacyMetadataFunc = evm.StateDB.GetPrivacyMetadata
}
stubPrivateTx = newTypicalPrivateTx(cfg)
stubPrivateTx.SetTxPrivacyMetadata(&types.PrivacyMetadata{
@@ -180,7 +173,7 @@ func TestPrivacyEnhancements_CreateC1_StandardPrivate(t *testing.T) {
var getPrivacyMetadataFunc func(common.Address) (*state.PrivacyMetadata, error)
cfg.onAfterEVM = func(evm *vm.EVM) {
affectedContracts = evm.AffectedContracts()
- getPrivacyMetadataFunc = evm.StateDB.GetStatePrivacyMetadata
+ getPrivacyMetadataFunc = evm.StateDB.GetPrivacyMetadata
}
stubPrivateTx = newTypicalPrivateTx(cfg)
stubPrivateTx.SetTxPrivacyMetadata(&types.PrivacyMetadata{
diff --git a/eth/api_backend.go b/eth/api_backend.go
index ca10c69171..1e36a70ef4 100644
--- a/eth/api_backend.go
+++ b/eth/api_backend.go
@@ -36,9 +36,12 @@ import (
"github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/params"
pcore "github.com/ethereum/go-ethereum/permission/core"
"github.com/ethereum/go-ethereum/rpc"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
)
// EthAPIBackend implements ethapi.Backend for full nodes
@@ -223,7 +226,10 @@ func (b *EthAPIBackend) GetEVM(ctx context.Context, msg core.Message, state vm.M
from.SetBalance(math.MaxBig256)
vmError := func() error { return nil }
- context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil)
+ evmCtx := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil)
+ if _, ok := b.SupportsMultitenancy(ctx); ok {
+ evmCtx = core.NewMultitenancyAwareEVMContext(ctx, evmCtx)
+ }
// Set the private state to public state if contract address is not present in the private state
to := common.Address{}
@@ -236,7 +242,7 @@ func (b *EthAPIBackend) GetEVM(ctx context.Context, msg core.Message, state vm.M
privateState = statedb.state
}
- return vm.NewEVM(context, statedb.state, privateState, b.eth.blockchain.Config(), *b.eth.blockchain.GetVMConfig()), vmError, nil
+ return vm.NewEVM(evmCtx, statedb.state, privateState, b.eth.blockchain.Config(), *b.eth.blockchain.GetVMConfig()), vmError, nil
}
func (b *EthAPIBackend) SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEvent) event.Subscription {
@@ -356,6 +362,30 @@ func (b *EthAPIBackend) ServiceFilter(ctx context.Context, session *bloombits.Ma
}
}
+// The validation of pre-requisite for multitenancy is done when EthService
+// is being created. So it's safe to use the config value here.
+func (b *EthAPIBackend) SupportsMultitenancy(rpcCtx context.Context) (*proto.PreAuthenticatedAuthenticationToken, bool) {
+ authToken, isPreauthenticated := rpcCtx.Value(rpc.CtxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
+ if isPreauthenticated && b.eth.config.EnableMultitenancy {
+ return authToken, true
+ }
+ return nil, false
+}
+
+func (b *EthAPIBackend) AccountExtraDataStateGetterByNumber(ctx context.Context, number rpc.BlockNumber) (vm.AccountExtraDataStateGetter, error) {
+ s, _, err := b.StateAndHeaderByNumber(ctx, number)
+ return s, err
+}
+
+func (b *EthAPIBackend) IsAuthorized(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, attributes ...*multitenancy.ContractSecurityAttribute) (bool, error) {
+ auth, err := b.eth.contractAuthzProvider.IsAuthorized(ctx, authToken, attributes...)
+ if err != nil {
+ log.Error("failed to perform authorization check", "err", err, "granted", string(authToken.RawToken), "ask", attributes)
+ return false, err
+ }
+ return auth, nil
+}
+
// used by Quorum
type EthAPIState struct {
state, privateState *state.StateDB
@@ -429,11 +459,18 @@ func (s EthAPIState) GetNonce(addr common.Address) uint64 {
return s.state.GetNonce(addr)
}
-func (s EthAPIState) GetStatePrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error) {
+func (s EthAPIState) GetPrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error) {
+ if s.privateState.Exist(addr) {
+ return s.privateState.GetPrivacyMetadata(addr)
+ }
+ return nil, fmt.Errorf("%x: %w", addr, common.ErrNotPrivateContract)
+}
+
+func (s EthAPIState) GetManagedParties(addr common.Address) ([]string, error) {
if s.privateState.Exist(addr) {
- return s.privateState.GetStatePrivacyMetadata(addr)
+ return s.privateState.GetManagedParties(addr)
}
- return nil, fmt.Errorf("The provided address is not a private contract: %x", addr)
+ return nil, fmt.Errorf("%x: %w", addr, common.ErrNotPrivateContract)
}
func (s EthAPIState) GetRLPEncodedStateObject(addr common.Address) ([]byte, error) {
diff --git a/eth/backend.go b/eth/backend.go
index 01aa833b1c..ad468a3aa3 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -48,11 +48,11 @@ import (
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/miner"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/params"
- "github.com/ethereum/go-ethereum/plugin"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
)
@@ -91,8 +91,6 @@ type Ethereum struct {
APIBackend *EthAPIBackend
- securityPlugin *plugin.SecurityPluginTemplate
-
miner *miner.Miner
gasPrice *big.Int
etherbase common.Address
@@ -101,6 +99,10 @@ type Ethereum struct {
netRPCService *ethapi.PublicNetAPI
lock sync.RWMutex // Protects the variadic fields (e.g. gas price and etherbase)
+
+ // Quorum - Multitenancy
+ // contractAuthzProvider is set after node starts instead in New()
+ contractAuthzProvider multitenancy.ContractAuthorizationProvider
}
func (s *Ethereum) AddLesServer(ls LesServer) {
@@ -116,6 +118,13 @@ func (s *Ethereum) SetContractBackend(backend bind.ContractBackend) {
}
}
+// Quorum
+//
+// Set the decision manager for multitenancy support
+func (s *Ethereum) SetContractAuthorizationProvider(dm multitenancy.ContractAuthorizationProvider) {
+ s.contractAuthzProvider = dm
+}
+
// New creates a new Ethereum object (including the
// initialisation of the common Ethereum object)
func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
@@ -214,7 +223,11 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
TrieTimeLimit: config.TrieTimeout,
}
)
- eth.blockchain, err = core.NewBlockChain(chainDb, cacheConfig, chainConfig, eth.engine, vmConfig, eth.shouldPreserve)
+ newBlockChainFunc := core.NewBlockChain
+ if config.EnableMultitenancy {
+ newBlockChainFunc = core.NewMultitenantBlockChain
+ }
+ eth.blockchain, err = newBlockChainFunc(chainDb, cacheConfig, chainConfig, eth.engine, vmConfig, eth.shouldPreserve)
if err != nil {
return nil, err
}
@@ -251,15 +264,6 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
}
eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams)
- // Set Security plugin in eth
- var pluginManager *plugin.PluginManager
- if err := ctx.Service(&pluginManager); err == nil {
- sp := new(plugin.SecurityPluginTemplate)
- if err := pluginManager.GetPluginTemplate(plugin.SecurityPluginInterfaceName, sp); err == nil {
- eth.securityPlugin = sp
- }
- }
-
return eth, nil
}
diff --git a/eth/config.go b/eth/config.go
index 68bf57ec95..ecdcae9058 100644
--- a/eth/config.go
+++ b/eth/config.go
@@ -168,4 +168,6 @@ type Config struct {
// timeout value for call
EVMCallTimeOut time.Duration
+
+ EnableMultitenancy bool
}
diff --git a/eth/filters/api.go b/eth/filters/api.go
index 5ed80a8875..75b8174da3 100644
--- a/eth/filters/api.go
+++ b/eth/filters/api.go
@@ -31,6 +31,7 @@ import (
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/rpc"
)
@@ -346,7 +347,11 @@ func (api *PublicFilterAPI) GetLogs(ctx context.Context, crit FilterCriteria) ([
if err != nil {
return nil, err
}
- return returnLogs(logs), err
+ authLogs, err := api.filterUnAuthorized(ctx, logs)
+ if err != nil {
+ return nil, err
+ }
+ return returnLogs(authLogs), err
}
// UninstallFilter removes the filter with the given filter id.
@@ -401,7 +406,11 @@ func (api *PublicFilterAPI) GetFilterLogs(ctx context.Context, id rpc.ID) ([]*ty
if err != nil {
return nil, err
}
- return returnLogs(logs), nil
+ authLogs, err := api.filterUnAuthorized(ctx, logs)
+ if err != nil {
+ return nil, err
+ }
+ return returnLogs(authLogs), nil
}
// GetFilterChanges returns the logs for the filter with the given id since
@@ -411,7 +420,7 @@ func (api *PublicFilterAPI) GetFilterLogs(ctx context.Context, id rpc.ID) ([]*ty
// (pending)Log filters return []Log.
//
// https://github.com/ethereum/wiki/wiki/JSON-RPC#eth_getfilterchanges
-func (api *PublicFilterAPI) GetFilterChanges(id rpc.ID) (interface{}, error) {
+func (api *PublicFilterAPI) GetFilterChanges(ctx context.Context, id rpc.ID) (interface{}, error) {
api.filtersMu.Lock()
defer api.filtersMu.Unlock()
@@ -431,7 +440,11 @@ func (api *PublicFilterAPI) GetFilterChanges(id rpc.ID) (interface{}, error) {
case LogsSubscription:
logs := f.logs
f.logs = nil
- return returnLogs(logs), nil
+ authLogs, err := api.filterUnAuthorized(ctx, logs)
+ if err != nil {
+ return nil, err
+ }
+ return returnLogs(authLogs), nil
}
}
@@ -574,3 +587,32 @@ func decodeTopic(s string) (common.Hash, error) {
}
return common.BytesToHash(b), err
}
+
+// Quorum
+// Perform authorization check for each logs based on the contract addresses
+func (api *PublicFilterAPI) filterUnAuthorized(ctx context.Context, logs []*types.Log) ([]*types.Log, error) {
+ if len(logs) == 0 {
+ return logs, nil
+ }
+ if authToken, ok := api.backend.SupportsMultitenancy(ctx); ok {
+ filteredLogs := make([]*types.Log, 0)
+ for _, l := range logs {
+ extraDataReader, err := api.backend.AccountExtraDataStateGetterByNumber(ctx, rpc.BlockNumber(l.BlockNumber))
+ if err != nil {
+ return nil, fmt.Errorf("no account extra data reader at block %v: %w", l.BlockNumber, err)
+ }
+ attrBuilder := multitenancy.NewContractSecurityAttributeBuilder().Read().Private()
+ managedParties, err := extraDataReader.GetManagedParties(l.Address)
+ if errors.Is(err, common.ErrNotPrivateContract) {
+ attrBuilder.Public()
+ } else if err != nil {
+ return nil, fmt.Errorf("contract %s not found in the index due to %s", l.Address.Hex(), err.Error())
+ }
+ if ok, _ := api.backend.IsAuthorized(ctx, authToken, attrBuilder.Parties(managedParties).Build()); ok {
+ filteredLogs = append(filteredLogs, l)
+ }
+ }
+ return filteredLogs, nil
+ }
+ return logs, nil
+}
diff --git a/eth/filters/filter.go b/eth/filters/filter.go
index f77a57f564..8fdcf754ed 100644
--- a/eth/filters/filter.go
+++ b/eth/filters/filter.go
@@ -26,12 +26,16 @@ import (
"github.com/ethereum/go-ethereum/core/bloombits"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/rpc"
)
type Backend interface {
+ multitenancy.AuthorizationProvider
+
ChainDb() ethdb.Database
EventMux() *event.TypeMux
HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error)
@@ -46,6 +50,9 @@ type Backend interface {
BloomStatus() (uint64, uint64)
ServiceFilter(ctx context.Context, session *bloombits.MatcherSession)
+
+ // AccountExtraDataStateGetterByNumber returns state getter at a given block height
+ AccountExtraDataStateGetterByNumber(ctx context.Context, number rpc.BlockNumber) (vm.AccountExtraDataStateGetter, error)
}
// Filter can be used to retrieve and filter logs.
diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go
index 93cb43123f..44bc75ed9e 100644
--- a/eth/filters/filter_system_test.go
+++ b/eth/filters/filter_system_test.go
@@ -25,17 +25,20 @@ import (
"testing"
"time"
- ethereum "github.com/ethereum/go-ethereum"
+ "github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/bloombits"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
)
type testBackend struct {
@@ -151,6 +154,18 @@ func (b *testBackend) ServiceFilter(ctx context.Context, session *bloombits.Matc
}()
}
+func (b *testBackend) SupportsMultitenancy(rpcCtx context.Context) (*proto.PreAuthenticatedAuthenticationToken, bool) {
+ return nil, false
+}
+
+func (b *testBackend) AccountExtraDataStateGetterByNumber(context.Context, rpc.BlockNumber) (vm.AccountExtraDataStateGetter, error) {
+ return nil, nil
+}
+
+func (b *testBackend) IsAuthorized(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, attributes ...*multitenancy.ContractSecurityAttribute) (bool, error) {
+ return true, nil
+}
+
// TestBlockSubscription tests if a block subscription returns block hashes for posted chain events.
// It creates multiple subscriptions:
// - one at the start and should receive all posted chain events and a second (blockHashes)
@@ -244,7 +259,7 @@ func TestPendingTxFilter(t *testing.T) {
timeout := time.Now().Add(1 * time.Second)
for {
- results, err := api.GetFilterChanges(fid0)
+ results, err := api.GetFilterChanges(context.Background(), fid0)
if err != nil {
t.Fatalf("Unable to retrieve logs: %v", err)
}
@@ -464,7 +479,7 @@ func TestLogFilter(t *testing.T) {
var fetched []*types.Log
timeout := time.Now().Add(1 * time.Second)
for { // fetch all expected logs
- results, err := api.GetFilterChanges(tt.id)
+ results, err := api.GetFilterChanges(context.Background(), tt.id)
if err != nil {
t.Fatalf("Unable to fetch logs: %v", err)
}
diff --git a/extension/api.go b/extension/api.go
index ee7849f182..e197f85639 100644
--- a/extension/api.go
+++ b/extension/api.go
@@ -1,6 +1,7 @@
package extension
import (
+ "context"
"encoding/base64"
"errors"
"fmt"
@@ -8,7 +9,9 @@ import (
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/internal/ethapi"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/permission/core"
+ "github.com/ethereum/go-ethereum/rpc"
)
var (
@@ -102,9 +105,43 @@ func (api *PrivateExtensionAPI) checkIfPrivateStateExists(toExtend common.Addres
return false
}
+func (api *PrivateExtensionAPI) doMultiTenantChecks(ctx context.Context, address common.Address, txa ethapi.SendTxArgs) error {
+ apiHelper := api.privacyService.apiBackendHelper
+ if authToken, ok := apiHelper.SupportsMultitenancy(ctx); ok {
+ if len(txa.PrivateFrom) == 0 {
+ return errors.New("You must specify 'privateFrom' when running in a multitenant node")
+ }
+ // check whether the user has access to txa.PrivateFrom and the txa.From eth account
+ attributes := multitenancy.FullAccessContractSecurityAttributes(txa.From, txa.PrivateFrom)
+ chainAccessor := api.privacyService.stateFetcher.chainAccessor
+ currentBlock := chainAccessor.CurrentBlock().Number().Int64()
+ extraDataReader, err := apiHelper.AccountExtraDataStateGetterByNumber(ctx, rpc.BlockNumber(currentBlock))
+ if err != nil {
+ return fmt.Errorf("no account extra data reader at block %v: %w", currentBlock, err)
+ }
+
+ managedParties, err := extraDataReader.GetManagedParties(address)
+ if err != nil {
+ return err
+ }
+ attributes = append(attributes,
+ multitenancy.NewContractSecurityAttributeBuilder().FromEOA(txa.From).Private().Write().Parties(managedParties).Build(),
+ multitenancy.NewContractSecurityAttributeBuilder().FromEOA(txa.From).Private().Read().Parties(managedParties).Build())
+
+ if authorized, _ := apiHelper.IsAuthorized(ctx, authToken, attributes...); !authorized {
+ return multitenancy.ErrNotAuthorized
+ }
+ }
+ return nil
+}
+
// ApproveContractExtension submits the vote to the specified extension management contract. The vote indicates whether to extend
// a given contract to a new participant or not
-func (api *PrivateExtensionAPI) ApproveExtension(addressToVoteOn common.Address, vote bool, txa ethapi.SendTxArgs) (string, error) {
+func (api *PrivateExtensionAPI) ApproveExtension(ctx context.Context, addressToVoteOn common.Address, vote bool, txa ethapi.SendTxArgs) (string, error) {
+ err := api.doMultiTenantChecks(ctx, addressToVoteOn, txa)
+ if err != nil {
+ return "", err
+ }
// check if the extension has been completed. if yes
// no acceptance required
status, err := api.checkIfExtensionComplete(addressToVoteOn, txa.From)
@@ -175,8 +212,7 @@ func (api *PrivateExtensionAPI) ApproveExtension(addressToVoteOn common.Address,
// - the contract address we want to extend
// - the new PTM public key
// - the Ethereum addresses of who can vote to extend the contract
-func (api *PrivateExtensionAPI) ExtendContract(toExtend common.Address, newRecipientPtmPublicKey string, recipientAddr common.Address, txa ethapi.SendTxArgs) (string, error) {
-
+func (api *PrivateExtensionAPI) ExtendContract(ctx context.Context, toExtend common.Address, newRecipientPtmPublicKey string, recipientAddr common.Address, txa ethapi.SendTxArgs) (string, error) {
// check if the contract to be extended is already under extension
// if yes throw an error
if api.checkIfContractUnderExtension(toExtend) {
@@ -193,6 +229,11 @@ func (api *PrivateExtensionAPI) ExtendContract(toExtend common.Address, newRecip
return "", errors.New("extending a non-existent private contract!!! not allowed")
}
+ err := api.doMultiTenantChecks(ctx, toExtend, txa)
+ if err != nil {
+ return "", err
+ }
+
// check if recipient address is 0x0
if recipientAddr == (common.Address{0}) {
return "", errors.New("invalid recipient address")
@@ -258,7 +299,12 @@ func (api *PrivateExtensionAPI) ExtendContract(toExtend common.Address, newRecip
// CancelExtension allows the creator to cancel the given extension contract, ensuring
// that no more calls for votes or accepting can be made
-func (api *PrivateExtensionAPI) CancelExtension(extensionContract common.Address, txa ethapi.SendTxArgs) (string, error) {
+func (api *PrivateExtensionAPI) CancelExtension(ctx context.Context, extensionContract common.Address, txa ethapi.SendTxArgs) (string, error) {
+ err := api.doMultiTenantChecks(ctx, extensionContract, txa)
+ if err != nil {
+ return "", err
+ }
+
status, err := api.checkIfExtensionComplete(extensionContract, txa.From)
if err != nil {
return "", err
@@ -309,8 +355,23 @@ func (api *PrivateExtensionAPI) CancelExtension(extensionContract common.Address
}
// Returns the extension status from management contract
-func (api *PrivateExtensionAPI) GetExtensionStatus(extensionContract common.Address) (string, error) {
-
+func (api *PrivateExtensionAPI) GetExtensionStatus(ctx context.Context, extensionContract common.Address) (string, error) {
+ apiHelper := api.privacyService.apiBackendHelper
+ if authToken, ok := apiHelper.SupportsMultitenancy(ctx); ok {
+ currentBlock := apiHelper.CurrentBlock().Number().Int64()
+ extraDataReader, err := apiHelper.AccountExtraDataStateGetterByNumber(ctx, rpc.BlockNumber(currentBlock))
+ if err != nil {
+ return "", fmt.Errorf("no account extra data reader at block %v: %w", currentBlock, err)
+ }
+ managedParties, err := extraDataReader.GetManagedParties(extensionContract)
+ if err != nil {
+ return "", err
+ }
+ if authorized, _ := apiHelper.IsAuthorized(ctx, authToken,
+ multitenancy.NewContractSecurityAttributeBuilder().Private().Read().Parties(managedParties).Build()); !authorized {
+ return "", multitenancy.ErrNotAuthorized
+ }
+ }
status, err := api.checkIfExtensionComplete(extensionContract, common.Address{})
if err != nil {
return "", err
diff --git a/extension/backend.go b/extension/backend.go
index f160c32364..c10ac96b1e 100644
--- a/extension/backend.go
+++ b/extension/backend.go
@@ -1,6 +1,7 @@
package extension
import (
+ "context"
"encoding/hex"
"errors"
"fmt"
@@ -31,6 +32,7 @@ type PrivacyService struct {
managementContractFacade ManagementContractFacade
extClient Client
stopFeed event.Feed
+ apiBackendHelper APIBackendHelper
mu sync.Mutex
currentContracts map[common.Address]*ExtensionContract
@@ -56,13 +58,14 @@ func (service *PrivacyService) subscribeStopEvent() (chan stopEvent, event.Subsc
return c, s
}
-func New(ptm private.PrivateTransactionManager, manager *accounts.Manager, handler DataHandler, fetcher *StateFetcher) (*PrivacyService, error) {
+func New(ptm private.PrivateTransactionManager, manager *accounts.Manager, handler DataHandler, fetcher *StateFetcher, apiBackendHelper APIBackendHelper) (*PrivacyService, error) {
service := &PrivacyService{
currentContracts: make(map[common.Address]*ExtensionContract),
ptm: ptm,
dataHandler: handler,
stateFetcher: fetcher,
accountManager: manager,
+ apiBackendHelper: apiBackendHelper,
}
var err error
@@ -157,14 +160,21 @@ func (service *PrivacyService) watchForNewContracts() error {
log.Error("Extension: unable to fetch all parties for extension management contract", "error", err)
continue
}
+
+ privateFrom, _, _, _, err := service.ptm.Receive(data)
+ if err != nil || len(privateFrom) == 0 {
+ log.Error("Extension: unable to fetch privateFrom(sender) for extension management contract", "error", err)
+ continue
+ }
+
//Find the extension contract in order to interact with it
caller, _ := service.managementContractFacade.Caller(newContractExtension.ManagementContractAddress)
contractCreator, _ := caller.Creator(nil)
- txArgs := ethapi.SendTxArgs{From: contractCreator, PrivateTxArgs: ethapi.PrivateTxArgs{PrivateFor: fetchedParties}}
+ txArgs := ethapi.SendTxArgs{From: contractCreator, PrivateTxArgs: ethapi.PrivateTxArgs{PrivateFor: fetchedParties, PrivateFrom: privateFrom}}
extensionAPI := NewPrivateExtensionAPI(service)
- _, err = extensionAPI.ApproveExtension(newContractExtension.ManagementContractAddress, true, txArgs)
+ _, err = extensionAPI.ApproveExtension(context.Background(), newContractExtension.ManagementContractAddress, true, txArgs)
if err != nil {
log.Error("Extension: initiator vote on management contract failed", "error", err)
@@ -266,7 +276,14 @@ func (service *PrivacyService) watchForCompletionEvents() error {
}
log.Debug("Extension: able to fetch all parties", "parties", fetchedParties)
- txArgs, err := service.GenerateTransactOptions(ethapi.SendTxArgs{From: contractCreator, PrivateTxArgs: ethapi.PrivateTxArgs{PrivateFor: fetchedParties}})
+ privateFrom, _, _, _, err := service.ptm.Receive(payload)
+ if err != nil || len(privateFrom) == 0 {
+ log.Error("Extension: unable to fetch privateFrom(sender) for extension management contract", "error", err)
+ return
+ }
+ log.Debug("Extension: able to fetch privateFrom(sender)", "privateFrom", privateFrom)
+
+ txArgs, err := service.GenerateTransactOptions(ethapi.SendTxArgs{From: contractCreator, PrivateTxArgs: ethapi.PrivateTxArgs{PrivateFor: fetchedParties, PrivateFrom: privateFrom}})
if err != nil {
log.Error("service.accountManager.GenerateTransactOptions", "error", err, "contractCreator", contractCreator.Hex(), "privateFor", fetchedParties)
return
@@ -303,7 +320,7 @@ func (service *PrivacyService) watchForCompletionEvents() error {
extraMetaData.ACMerkleRoot = storageRoot
}
}
- hashOfStateData, err := service.ptm.Send(entireStateData, "", fetchedParties, &extraMetaData)
+ _, _, hashOfStateData, err := service.ptm.Send(entireStateData, privateFrom, fetchedParties, &extraMetaData)
if err != nil {
log.Error("[ptm] service.ptm.Send", "stateDataInHex", hex.EncodeToString(entireStateData[:]), "recipients", fetchedParties, "error", err)
diff --git a/extension/extension_utilities.go b/extension/extension_utilities.go
index e48cf9d74a..ff0a89a37c 100644
--- a/extension/extension_utilities.go
+++ b/extension/extension_utilities.go
@@ -21,7 +21,7 @@ func generateUuid(contractAddress common.Address, privateFrom string, privateFor
return "", err
}
- hash, err := ptm.Send(payloadHash, privateFrom, privateFor, &engine.ExtraMetadata{})
+ _, _, hash, err := ptm.Send(payloadHash, privateFrom, privateFor, &engine.ExtraMetadata{})
if err != nil {
return "", err
}
diff --git a/extension/privacyExtension/state_set_utilities.go b/extension/privacyExtension/state_set_utilities.go
index 3f53fc3f30..1d6143735b 100644
--- a/extension/privacyExtension/state_set_utilities.go
+++ b/extension/privacyExtension/state_set_utilities.go
@@ -8,9 +8,11 @@ import (
"github.com/ethereum/go-ethereum/core/types"
extension "github.com/ethereum/go-ethereum/extension/extensionContracts"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/private"
+ "github.com/ethereum/go-ethereum/private/engine"
)
-func setState(privateState *state.StateDB, accounts map[string]extension.AccountWithMetadata, privacyMetaData *state.PrivacyMetadata) bool {
+func setState(privateState *state.StateDB, accounts map[string]extension.AccountWithMetadata, privacyMetaData *state.PrivacyMetadata, managedParties []string) bool {
log.Debug("Extension: set private state explicitly from state dump")
for key, value := range accounts {
stateDump := value.State
@@ -28,14 +30,19 @@ func setState(privateState *state.StateDB, accounts map[string]extension.Account
for keyStore, valueStore := range stateDump.Storage {
privateState.SetState(contractAddress, keyStore, common.HexToHash(valueStore))
}
- privateState.SetStatePrivacyMetadata(contractAddress, privacyMetaData)
+ if privacyMetaData.PrivacyFlag != engine.PrivacyFlagStandardPrivate {
+ privateState.SetPrivacyMetadata(contractAddress, privacyMetaData)
+ }
+ if managedParties != nil {
+ privateState.SetManagedParties(contractAddress, managedParties)
+ }
}
return true
}
// updates the privacy metadata
func setPrivacyMetadata(privateState *state.StateDB, address common.Address, hash string) {
- privacyMetaData, err := privateState.GetStatePrivacyMetadata(address)
+ privacyMetaData, err := privateState.GetPrivacyMetadata(address)
if err != nil || privacyMetaData.PrivacyFlag.IsStandardPrivate() {
return
}
@@ -46,7 +53,24 @@ func setPrivacyMetadata(privateState *state.StateDB, address common.Address, has
return
}
pm := state.NewStatePrivacyMetadata(ptmHash, privacyMetaData.PrivacyFlag)
- privateState.SetStatePrivacyMetadata(address, pm)
+ privateState.SetPrivacyMetadata(address, pm)
+}
+
+func setManagedParties(ptm private.PrivateTransactionManager, privateState *state.StateDB, address common.Address, hash string) {
+ existingManagedParties, err := privateState.GetManagedParties(address)
+ if err != nil {
+ return
+ }
+
+ ptmHash, err := common.Base64ToEncryptedPayloadHash(hash)
+ if err != nil {
+ log.Error("setting privacy metadata failed", "err", err)
+ return
+ }
+
+ _, managedParties, _, _, err := ptm.Receive(ptmHash)
+ newManagedParties := common.AppendSkipDuplicates(existingManagedParties, managedParties...)
+ privateState.SetManagedParties(address, newManagedParties)
}
func logContainsExtensionTopic(receivedLog *types.Log) bool {
diff --git a/extension/privacyExtension/state_set_utilities_test.go b/extension/privacyExtension/state_set_utilities_test.go
index eea7965f9f..94af7852d9 100644
--- a/extension/privacyExtension/state_set_utilities_test.go
+++ b/extension/privacyExtension/state_set_utilities_test.go
@@ -60,7 +60,7 @@ func createStateDb(t *testing.T) *state.StateDB {
t.Errorf("error when unmarshalling static data: %s", err.Error())
}
- success := setState(statedb, accounts, &state.PrivacyMetadata{})
+ success := setState(statedb, accounts, &state.PrivacyMetadata{}, nil)
if !success {
t.Errorf("unexpected error when setting state")
}
@@ -109,7 +109,7 @@ func TestStateSetWithListedAccountsFailsOnInvalidBalance(t *testing.T) {
t.Errorf("error when unmarshalling static data: %s", err.Error())
}
- success := setState(statedb, accounts, &state.PrivacyMetadata{})
+ success := setState(statedb, accounts, &state.PrivacyMetadata{}, nil)
if success {
t.Errorf("error expected when setting state")
}
@@ -124,16 +124,14 @@ func Test_setPrivacyMetadata(t *testing.T) {
hash := common.BytesToEncryptedPayloadHash(arbitraryBytes1)
setPrivacyMetadata(statedb, address, base64.StdEncoding.EncodeToString(arbitraryBytes1))
- privacyMetaData, err := statedb.GetStatePrivacyMetadata(address)
- if err != nil {
- t.Errorf("expected error to be nil, got err %s", err)
- }
+ // we don't save PrivacyMetadata if it's standardprivate
+ privacyMetaData, err := statedb.GetPrivacyMetadata(address)
+ assert.Error(t, err, common.ErrNoAccountExtraData)
- assert.NotEqual(t, privacyMetaData.CreationTxHash, hash)
- privacyMetaData = &state.PrivacyMetadata{hash, engine.PrivacyFlagPartyProtection}
- statedb.SetStatePrivacyMetadata(address, privacyMetaData)
+ privacyMetaData = &state.PrivacyMetadata{CreationTxHash: hash, PrivacyFlag: engine.PrivacyFlagPartyProtection}
+ statedb.SetPrivacyMetadata(address, privacyMetaData)
- privacyMetaData, err = statedb.GetStatePrivacyMetadata(address)
+ privacyMetaData, err = statedb.GetPrivacyMetadata(address)
if err != nil {
t.Errorf("expected error to be nil, got err %s", err)
}
@@ -144,7 +142,7 @@ func Test_setPrivacyMetadata(t *testing.T) {
newHash := common.BytesToEncryptedPayloadHash(arbitraryBytes2)
setPrivacyMetadata(statedb, address, base64.StdEncoding.EncodeToString(arbitraryBytes2))
- privacyMetaData, err = statedb.GetStatePrivacyMetadata(address)
+ privacyMetaData, err = statedb.GetPrivacyMetadata(address)
if err != nil {
t.Errorf("expected error to be nil, got err %s", err)
}
diff --git a/extension/privacyExtension/state_setter.go b/extension/privacyExtension/state_setter.go
index 4c146a58ae..58e5e12286 100644
--- a/extension/privacyExtension/state_setter.go
+++ b/extension/privacyExtension/state_setter.go
@@ -15,15 +15,20 @@ import (
var DefaultExtensionHandler = NewExtensionHandler(private.P)
type ExtensionHandler struct {
- ptm private.PrivateTransactionManager
+ ptm private.PrivateTransactionManager
+ isMultitenant bool
}
func NewExtensionHandler(transactionManager private.PrivateTransactionManager) *ExtensionHandler {
return &ExtensionHandler{ptm: transactionManager}
}
+func (handler *ExtensionHandler) SupportMultitenancy(b bool) {
+ handler.isMultitenant = b
+}
+
func (handler *ExtensionHandler) CheckExtensionAndSetPrivateState(txLogs []*types.Log, privateState *state.StateDB) {
- privacyMetaDataUpdated := false
+ extraMetaDataUpdated := false
for _, txLog := range txLogs {
if logContainsExtensionTopic(txLog) {
//this is a direct state share
@@ -35,26 +40,32 @@ func (handler *ExtensionHandler) CheckExtensionAndSetPrivateState(txLogs []*type
// check if state exists for the extension address. If yes then skip
// processing
if privateState.GetCode(address) != nil {
- if privacyMetaDataUpdated {
+ if extraMetaDataUpdated {
continue
}
// check the privacy flag of the contract. if its other than
// 0 then need to update the privacy metadata for the contract
//TODO: validate the old and new parties to ensure that all old parties are there
setPrivacyMetadata(privateState, address, hash)
- privacyMetaDataUpdated = true
+ if handler.isMultitenant {
+ setManagedParties(handler.ptm, privateState, address, hash)
+ }
+ extraMetaDataUpdated = true
} else {
- accounts, privacyMetaData, found := handler.FetchStateData(txLog.Address, hash, uuid)
+ managedParties, accounts, privacyMetaData, found := handler.FetchStateData(txLog.Address, hash, uuid)
if !found {
continue
}
+ if !handler.isMultitenant {
+ managedParties = nil
+ }
if !validateAccountsExist([]common.Address{address}, accounts) {
log.Error("Account mismatch", "expected", address, "found", accounts)
continue
}
snapshotId := privateState.Snapshot()
- if success := setState(privateState, accounts, privacyMetaData); !success {
+ if success := setState(privateState, accounts, privacyMetaData, managedParties); !success {
privateState.RevertToSnapshot(snapshotId)
}
}
@@ -63,44 +74,44 @@ func (handler *ExtensionHandler) CheckExtensionAndSetPrivateState(txLogs []*type
}
}
-func (handler *ExtensionHandler) FetchStateData(address common.Address, hash string, uuid string) (map[string]extension.AccountWithMetadata, *state.PrivacyMetadata, bool) {
+func (handler *ExtensionHandler) FetchStateData(address common.Address, hash string, uuid string) ([]string, map[string]extension.AccountWithMetadata, *state.PrivacyMetadata, bool) {
if uuidIsSentByUs := handler.UuidIsOwn(address, uuid); !uuidIsSentByUs {
- return nil, nil, false
+ return nil, nil, nil, false
}
- stateData, privacyMetaData, ok := handler.FetchDataFromPTM(hash)
+ managedParties, stateData, privacyMetaData, ok := handler.FetchDataFromPTM(hash)
if !ok {
//there is nothing to do here, the state wasn't shared with us
log.Error("Extension: No state shared with us")
- return nil, nil, false
+ return nil, nil, nil, false
}
var accounts map[string]extension.AccountWithMetadata
if err := json.Unmarshal(stateData, &accounts); err != nil {
log.Error("Extension: Could not unmarshal data")
- return nil, nil, false
+ return nil, nil, nil, false
}
- return accounts, privacyMetaData, true
+ return managedParties, accounts, privacyMetaData, true
}
// Checks
-func (handler *ExtensionHandler) FetchDataFromPTM(hash string) ([]byte, *state.PrivacyMetadata, bool) {
+func (handler *ExtensionHandler) FetchDataFromPTM(hash string) ([]string, []byte, *state.PrivacyMetadata, bool) {
ptmHash, _ := common.Base64ToEncryptedPayloadHash(hash)
- stateData, extraMetaData, err := handler.ptm.Receive(ptmHash)
+ _, managedParties, stateData, extraMetaData, err := handler.ptm.Receive(ptmHash)
if stateData == nil {
log.Error("No state data found in PTM", "ptm hash", hash)
- return nil, nil, false
+ return nil, nil, nil, false
}
if err != nil {
log.Error("Error receiving state data from PTM", "ptm hash", hash, "err", err.Error())
- return nil, nil, false
+ return nil, nil, nil, false
}
privacyMetaData := state.NewStatePrivacyMetadata(ptmHash, extraMetaData.PrivacyFlag)
- return stateData, privacyMetaData, true
+ return managedParties, stateData, privacyMetaData, true
}
func (handler *ExtensionHandler) UuidIsOwn(address common.Address, uuid string) bool {
@@ -117,7 +128,7 @@ func (handler *ExtensionHandler) UuidIsOwn(address common.Address, uuid string)
return false
}
- encryptedPayload, _, err := handler.ptm.Receive(encryptedTxHash)
+ _, _, encryptedPayload, _, err := handler.ptm.Receive(encryptedTxHash)
if err != nil {
log.Debug("Extension: payload not found", "err", err)
return false
diff --git a/extension/services_factory.go b/extension/services_factory.go
index d0fc1e0a95..8514da25af 100644
--- a/extension/services_factory.go
+++ b/extension/services_factory.go
@@ -1,6 +1,8 @@
package extension
import (
+ "context"
+
"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/extension/privacyExtension"
@@ -30,12 +32,15 @@ func NewServicesFactory(node *node.Node, ptm private.PrivateTransactionManager,
factory.dataHandler = NewJsonFileDataHandler(node.InstanceDir())
factory.stateFetcher = NewStateFetcher(ethService.BlockChain())
- backendService, err := New(ptm, factory.AccountManager(), factory.DataHandler(), factory.StateFetcher())
+ backendService, err := New(ptm, factory.AccountManager(), factory.DataHandler(), factory.StateFetcher(), ethService.APIBackend)
if err != nil {
return nil, err
}
factory.backendService = backendService
+ _, isMultitenant := ethService.BlockChain().SupportsMultitenancy(context.Background())
+ privacyExtension.DefaultExtensionHandler.SupportMultitenancy(isMultitenant)
+
ethService.BlockChain().PopulateSetPrivateState(privacyExtension.DefaultExtensionHandler.CheckExtensionAndSetPrivateState)
go backendService.initialise(node)
diff --git a/extension/state_fetcher.go b/extension/state_fetcher.go
index b1188a933f..37aea7ed3c 100644
--- a/extension/state_fetcher.go
+++ b/extension/state_fetcher.go
@@ -1,17 +1,22 @@
package extension
import (
+ "context"
"encoding/json"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/extension/extensionContracts"
+ "github.com/ethereum/go-ethereum/multitenancy"
+ "github.com/ethereum/go-ethereum/rpc"
)
// ChainAccessor provides methods to fetch state and blocks from the local blockchain
type ChainAccessor interface {
+ multitenancy.ContextAware
// GetBlockByHash retrieves a block from the local chain.
GetBlockByHash(common.Hash) *types.Block
StateAt(root common.Hash) (*state.StateDB, *state.StateDB, error)
@@ -19,6 +24,13 @@ type ChainAccessor interface {
CurrentBlock() *types.Block
}
+// Only extract required methods from ethService.APIBackend
+type APIBackendHelper interface {
+ multitenancy.AuthorizationProvider
+ AccountExtraDataStateGetterByNumber(ctx context.Context, number rpc.BlockNumber) (vm.AccountExtraDataStateGetter, error)
+ CurrentBlock() *types.Block
+}
+
// StateFetcher manages retrieving state from the database and returning it in
// a usable form by the extension API.
type StateFetcher struct {
@@ -84,7 +96,7 @@ func (fetcher *StateFetcher) GetPrivacyMetaData(blockHash common.Hash, address c
return nil, err
}
- privacyMetaData, err := privateState.GetStatePrivacyMetadata(address)
+ privacyMetaData, err := privateState.GetPrivacyMetadata(address)
if err != nil {
return nil, err
}
diff --git a/graphql/graphql.go b/graphql/graphql.go
index be29559fe1..b88ca91076 100644
--- a/graphql/graphql.go
+++ b/graphql/graphql.go
@@ -357,7 +357,7 @@ func (t *Transaction) PrivateInputData(ctx context.Context) (*hexutil.Bytes, err
return &hexutil.Bytes{}, err
}
if tx.IsPrivate() {
- privateInputData, _, err := private.P.Receive(common.BytesToEncryptedPayloadHash(tx.Data()))
+ _, _, privateInputData, _, err := private.P.Receive(common.BytesToEncryptedPayloadHash(tx.Data()))
if err != nil || tx == nil {
return &hexutil.Bytes{}, err
}
@@ -1012,7 +1012,7 @@ func (r *Resolver) SendRawTransaction(ctx context.Context, args struct{ Data hex
if err := rlp.DecodeBytes(args.Data, tx); err != nil {
return common.Hash{}, err
}
- hash, err := ethapi.SubmitTransaction(ctx, r.backend, tx)
+ hash, err := ethapi.SubmitTransaction(ctx, r.backend, tx, "", nil, true)
return hash, err
}
diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go
index 02b19dff28..32ddf3a1ef 100644
--- a/graphql/graphql_test.go
+++ b/graphql/graphql_test.go
@@ -97,19 +97,20 @@ func (spm *StubPrivateTransactionManager) HasFeature(f engine.PrivateTransaction
return true
}
-func (spm *StubPrivateTransactionManager) Receive(txHash common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
+func (spm *StubPrivateTransactionManager) Receive(txHash common.EncryptedPayloadHash) (string, []string, []byte, *engine.ExtraMetadata, error) {
res := spm.responses[txHash]
if err, ok := res[1].(error); ok {
- return nil, nil, err
+ return "", nil, nil, nil, err
}
if ret, ok := res[0].([]byte); ok {
- return ret, &engine.ExtraMetadata{
+ return "", nil, ret, &engine.ExtraMetadata{
PrivacyFlag: engine.PrivacyFlagStandardPrivate,
}, nil
}
- return nil, nil, nil
+ return "", nil, nil, nil, nil
}
-func (spm *StubPrivateTransactionManager) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
- return spm.Receive(data)
+func (spm *StubPrivateTransactionManager) ReceiveRaw(hash common.EncryptedPayloadHash) ([]byte, string, *engine.ExtraMetadata, error) {
+ _, sender, data, metadata, err := spm.Receive(hash)
+ return data, sender[0], metadata, err
}
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index a68b9c621d..19675fa2a3 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -46,12 +46,14 @@ import (
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/private"
"github.com/ethereum/go-ethereum/private/engine"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
"github.com/tyler-smith/go-bip39"
)
@@ -470,7 +472,7 @@ func (s *PrivateAccountAPI) SendTransaction(ctx context.Context, args SendTxArgs
log.Warn("Failed transaction send attempt", "from", args.From, "to", args.To, "value", args.Value.ToInt(), "err", err)
return common.Hash{}, err
}
- return SubmitTransaction(ctx, s.b, signed)
+ return SubmitTransaction(ctx, s.b, signed, args.PrivateFrom, args.PrivateFor, false)
}
// SignTransaction will create a transaction from the given arguments and
@@ -854,6 +856,9 @@ type account struct {
StateDiff *map[common.Hash]common.Hash `json:"stateDiff"`
}
+// Quorum - Multitenancy
+// Before returning the result, we need to inspect the EVM and
+// perform verification check
func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides map[common.Address]account, vmCfg vm.Config, timeout time.Duration, globalGasCap *big.Int) ([]byte, uint64, bool, error) {
defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now())
@@ -943,8 +948,37 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.Blo
// this makes sure resources are cleaned up.
defer cancel()
+ enrichedCtx := ctx
+ // create callbacks to support runtime multitenancy checks during the run
+ if authToken, ok := b.SupportsMultitenancy(ctx); ok {
+ var authorizeMessageCallFunc multitenancy.AuthorizeMessageCallFunc = func(contractAddress common.Address) (bool, bool, error) {
+ var readSecAttr *multitenancy.ContractSecurityAttribute
+ if len(data) == 0 { // public READ
+ readSecAttr = multitenancy.NewContractSecurityAttributeBuilder().FromEOA(addr).ToEOA(*msg.To()).Public().Read().Build()
+ } else {
+ currentBlock := b.CurrentBlock().Number().Int64()
+ extraDataReader, err := b.AccountExtraDataStateGetterByNumber(ctx, rpc.BlockNumber(currentBlock))
+ if err != nil {
+ return false, false, fmt.Errorf("no account extra data reader at block %v: %w", currentBlock, err)
+ }
+ managedParties, err := extraDataReader.GetManagedParties(contractAddress)
+ isPrivate := true
+ if errors.Is(err, common.ErrNotPrivateContract) {
+ isPrivate = false
+ } else if err != nil {
+ return false, false, fmt.Errorf("%s not found in the index, error: %s", contractAddress.Hex(), err.Error())
+ }
+ readSecAttr = multitenancy.NewContractSecurityAttributeBuilder().FromEOA(addr).PrivateIf(isPrivate).PartiesOnlyIf(isPrivate, managedParties).Read().Build()
+ }
+ authorizedRead, _ := b.IsAuthorized(ctx, authToken, readSecAttr)
+ log.Trace("Authorized Message Call", "read", authorizedRead, "address", contractAddress.Hex(), "securityAttribute", readSecAttr)
+ return authorizedRead, false, nil
+ }
+ enrichedCtx = context.WithValue(enrichedCtx, multitenancy.CtxKeyAuthorizeMessageCallFunc, authorizeMessageCallFunc)
+ }
+
// Get a new instance of the EVM.
- evm, vmError, err := b.GetEVM(ctx, msg, state, header)
+ evm, vmError, err := b.GetEVM(enrichedCtx, msg, state, header)
if err != nil {
return nil, 0, false, err
}
@@ -958,7 +992,7 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.Blo
// Setup the gas pool (also for unmetered requests)
// and apply the message.
gp := new(core.GasPool).AddGas(math.MaxUint64)
- res, gas, failed, err := core.ApplyMessage(evm, msg, gp)
+ res, gas, failed, applyErr := core.ApplyMessage(evm, msg, gp)
if err := vmError(); err != nil {
return nil, 0, false, err
}
@@ -966,7 +1000,7 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.Blo
if evm.Cancelled() {
return nil, 0, false, fmt.Errorf("execution aborted (timeout = %v)", timeout)
}
- return res, gas, failed, err
+ return res, gas, failed, applyErr
}
// Call executes the given transaction on the state for the given block number.
@@ -975,14 +1009,17 @@ func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.Blo
//
// Note, this function doesn't make and changes in the state/blockchain and is
// useful to execute and retrieve values.
+// Quorum
+// - replaced the default 5s time out with the value passed in vm.calltimeout
+// - multi tenancy verification
func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *map[common.Address]account) (hexutil.Bytes, error) {
var accounts map[common.Address]account
if overrides != nil {
accounts = *overrides
}
- // Quorum - replaced the default 5s time out with the value passed in vm.calltimeout
result, _, _, err := DoCall(ctx, s.b, args, blockNrOrHash, accounts, vm.Config{}, s.b.CallTimeOut(), s.b.RPCGasCap())
+
return (hexutil.Bytes)(result), err
}
@@ -1386,7 +1423,7 @@ func (s *PublicTransactionPoolAPI) GetContractPrivacyMetadata(ctx context.Contex
if state == nil || err != nil {
return nil, err
}
- return state.GetStatePrivacyMetadata(address)
+ return state.GetPrivacyMetadata(address)
}
// /Quorum
@@ -1475,9 +1512,43 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha
if receipt.ContractAddress != (common.Address{}) {
fields["contractAddress"] = receipt.ContractAddress
}
+ if authToken, ok := s.b.SupportsMultitenancy(ctx); ok {
+ extraDataReader, err := s.b.AccountExtraDataStateGetterByNumber(ctx, rpc.BlockNumber(blockNumber))
+ if err != nil {
+ return nil, fmt.Errorf("no account extra data reader at block %v: %w", blockNumber, err)
+ }
+
+ filteredLogs := make([]*types.Log, 0)
+ for _, l := range receipt.Logs {
+ ok, err := s.isContractAuthorized(ctx, authToken, extraDataReader, l.Address)
+ if err != nil {
+ return nil, err
+ }
+ if ok {
+ filteredLogs = append(filteredLogs, l)
+ }
+ }
+ fields["logs"] = filteredLogs
+ receiptClone := &types.Receipt{PostState: receipt.PostState, Status: receipt.Status, Logs: filteredLogs}
+ fields["logsBloom"] = types.CreateBloom(types.Receipts{receiptClone})
+ }
return fields, nil
}
+func (s *PublicTransactionPoolAPI) isContractAuthorized(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, extraDataReader vm.AccountExtraDataStateGetter, addr common.Address) (bool, error) {
+ attrBuilder := multitenancy.NewContractSecurityAttributeBuilder().Read().Private()
+ managedParties, err := extraDataReader.GetManagedParties(addr)
+ if errors.Is(err, common.ErrNotPrivateContract) {
+ attrBuilder.Public()
+ } else if err != nil {
+ return false, fmt.Errorf("contract %s not found in the index due to %s", addr.Hex(), err.Error())
+ }
+
+ ok, _ := s.b.IsAuthorized(ctx, authToken, attrBuilder.Parties(managedParties).Build())
+
+ return ok, nil
+}
+
// Quorum: if signing a private TX, set with tx.SetPrivate() before calling this method.
// sign is a helper function that signs a transaction with the private key of the given address.
func (s *PublicTransactionPoolAPI) sign(addr common.Address, tx *types.Transaction) (*types.Transaction, error) {
@@ -1625,21 +1696,53 @@ func (args *SendTxArgs) toTransaction() *types.Transaction {
// TODO: this submits a signed transaction, if it is a signed private transaction that should already be recorded in the tx.
// SubmitTransaction is a helper function that submits tx to txPool and logs a message.
-func SubmitTransaction(ctx context.Context, b Backend, tx *types.Transaction) (common.Hash, error) {
- if err := b.SendTx(ctx, tx); err != nil {
+func SubmitTransaction(ctx context.Context, b Backend, tx *types.Transaction, privateFrom string, privateFor []string, isRaw bool) (common.Hash, error) {
+ var signer types.Signer
+ if tx.IsPrivate() {
+ signer = types.QuorumPrivateTxSigner{}
+ } else {
+ signer = types.MakeSigner(b.ChainConfig(), b.CurrentBlock().Number())
+ }
+ from, err := types.Sender(signer, tx)
+ if err != nil {
return common.Hash{}, err
}
- if tx.To() == nil {
- var signer types.Signer
+ if authToken, ok := b.SupportsMultitenancy(ctx); ok {
+ originalTx := tx
+ // for private transaction, private payload will be retrieved from Tessera to build the original transaction
+ // in order to supply to the simulation engine
if tx.IsPrivate() {
- signer = types.QuorumPrivateTxSigner{}
- } else {
- signer = types.MakeSigner(b.ChainConfig(), b.CurrentBlock().Number())
+ if isRaw {
+ // for raw private transaction, the privateFrom will be derived when retrieving the private payload from Tessera
+ originalTx, privateFrom, err = buildPrivateTransactionFromRaw(tx)
+ if err != nil {
+ return common.Hash{}, err
+ }
+ } else {
+ originalTx, err = buildPrivateTransaction(tx)
+ if err != nil {
+ return common.Hash{}, err
+ }
+ }
+ originalTx.SetPrivate()
+ // enforcing privateFrom present
+ if privateFrom == "" {
+ return common.Hash{}, fmt.Errorf("missing privateFrom")
+ }
}
- from, err := types.Sender(signer, tx)
+ err := performMultitenancyChecks(ctx, authToken, b, from, originalTx, &PrivateTxArgs{
+ PrivateFrom: privateFrom,
+ PrivateFor: privateFor,
+ })
if err != nil {
return common.Hash{}, err
}
+ }
+
+ if err := b.SendTx(ctx, tx); err != nil {
+ return common.Hash{}, err
+ }
+ if tx.To() == nil {
addr := crypto.CreateAddress(from, tx.Nonce())
log.Info("Submitted contract creation", "fullhash", tx.Hash().Hex(), "to", addr.Hex())
log.EmitCheckpoint(log.TxCreated, "tx", tx.Hash().Hex(), "to", addr.Hex())
@@ -1648,6 +1751,159 @@ func SubmitTransaction(ctx context.Context, b Backend, tx *types.Transaction) (c
log.EmitCheckpoint(log.TxCreated, "tx", tx.Hash().Hex(), "to", tx.To().Hex())
}
return tx.Hash(), nil
+
+}
+
+// Quorum
+//
+// performMultitenancyChecks is to use the given transaction and construct
+// expected security attributes being checked against entitled ones. The
+// transaction is fed into the simulation engine in order to determine the impact.
+func performMultitenancyChecks(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken,
+ b Backend, fromEOA common.Address, tx *types.Transaction, privateArgs *PrivateTxArgs) error {
+
+ if tx.IsPrivate() {
+ // before running simulation, we verify the ownership of privateFrom
+ // user must be entitled for all actions
+ // READ and WRITE actions are taking Parties into consideration so
+ // we need to populate it with privateFrom
+ if authorized, _ := b.IsAuthorized(ctx, authToken, multitenancy.FullAccessContractSecurityAttributes(fromEOA, privateArgs.PrivateFrom)...); !authorized {
+ return multitenancy.ErrNotAuthorized
+ }
+ }
+ currentBlock := b.CurrentBlock().Number().Int64()
+ extraDataReader, err := b.AccountExtraDataStateGetterByNumber(ctx, rpc.BlockNumber(currentBlock))
+ if err != nil {
+ return fmt.Errorf("no account extra data reader at block %v: %w", currentBlock, err)
+ }
+ // create callbacks to support runtime multitenancy checks during simulation
+ createContractSA := multitenancy.NewContractSecurityAttributeBuilder().
+ FromEOA(fromEOA).PrivateIf(tx.IsPrivate()).Create().PrivateFromOnlyIf(tx.IsPrivate(), privateArgs.PrivateFrom).Build()
+ authorizedCreate, _ := b.IsAuthorized(ctx, authToken, createContractSA)
+ log.Debug("Authorized Contract Creation", "create", authorizedCreate, "securityAttribute", createContractSA)
+ var authorizeCreateFunc multitenancy.AuthorizeCreateFunc = func() bool {
+ return authorizedCreate
+ }
+ var authorizeMessageCallFunc multitenancy.AuthorizeMessageCallFunc = func(contractAddress common.Address) (bool, bool, error) {
+ managedParties, err := extraDataReader.GetManagedParties(contractAddress)
+ isPrivate := true
+ if errors.Is(err, common.ErrNotPrivateContract) {
+ isPrivate = false
+ } else if err != nil {
+ return false, false, fmt.Errorf("%s not found in the index, error: %s", contractAddress.Hex(), err.Error())
+ }
+ readSecAttr := multitenancy.NewContractSecurityAttributeBuilder().FromEOA(fromEOA).PrivateIf(isPrivate).PartiesOnlyIf(isPrivate, managedParties).Read().Build()
+ authorizedRead, _ := b.IsAuthorized(ctx, authToken, readSecAttr)
+ log.Trace("Authorized Message Call", "read", authorizedRead, "address", contractAddress.Hex(), "securityAttribute", readSecAttr)
+ writeSecAttr := multitenancy.NewContractSecurityAttributeBuilder().FromEOA(fromEOA).PrivateIf(isPrivate).PartiesOnlyIf(isPrivate, managedParties).Write().Build()
+ authorizedWrite, _ := b.IsAuthorized(ctx, authToken, writeSecAttr)
+ log.Trace("Authorized Message Call", "write", authorizedWrite, "address", contractAddress.Hex(), "securityAttribute", writeSecAttr)
+ return authorizedRead, authorizedWrite, nil
+ }
+ enrichedCtx := ctx
+ enrichedCtx = context.WithValue(enrichedCtx, multitenancy.CtxKeyAuthorizeCreateFunc, authorizeCreateFunc)
+ enrichedCtx = context.WithValue(enrichedCtx, multitenancy.CtxKeyAuthorizeMessageCallFunc, authorizeMessageCallFunc)
+ if _, err := runSimulation(enrichedCtx, b, fromEOA, tx); err != nil {
+ log.Error("Simulated execution for multitenancy", "error", err)
+ return err
+ }
+ return nil
+}
+
+// Quorum
+//
+// Retrieve private payload and construct the original transaction
+func buildPrivateTransaction(tx *types.Transaction) (*types.Transaction, error) {
+ _, _, privatePayload, _, revErr := private.P.Receive(common.BytesToEncryptedPayloadHash(tx.Data()))
+ if revErr != nil {
+ return nil, revErr
+ }
+ var privateTx *types.Transaction
+ if tx.To() == nil {
+ privateTx = types.NewContractCreation(tx.Nonce(), tx.Value(), tx.Gas(), tx.GasPrice(), privatePayload)
+ } else {
+ privateTx = types.NewTransaction(tx.Nonce(), *tx.To(), tx.Value(), tx.Gas(), tx.GasPrice(), privatePayload)
+ }
+ return privateTx, nil
+}
+
+// Quorum
+//
+// Retrieve private payload and construct the original transaction along with privateFrom information
+func buildPrivateTransactionFromRaw(tx *types.Transaction) (*types.Transaction, string, error) {
+ privatePayload, privateFrom, _, revErr := private.P.ReceiveRaw(common.BytesToEncryptedPayloadHash(tx.Data()))
+ if revErr != nil {
+ return nil, "", revErr
+ }
+ var privateTx *types.Transaction
+ if tx.To() == nil {
+ privateTx = types.NewContractCreation(tx.Nonce(), tx.Value(), tx.Gas(), tx.GasPrice(), privatePayload)
+ } else {
+ privateTx = types.NewTransaction(tx.Nonce(), *tx.To(), tx.Value(), tx.Gas(), tx.GasPrice(), privatePayload)
+ }
+ return privateTx, privateFrom, nil
+}
+
+// runSimulation runs a simulation of the given transaction.
+// It returns the EVM instance upon completion
+func runSimulation(ctx context.Context, b Backend, from common.Address, tx *types.Transaction) (*vm.EVM, error) {
+ defer func(start time.Time) {
+ log.Debug("Simulated Execution EVM call finished", "runtime", time.Since(start))
+ }(time.Now())
+
+ // Set sender address or use a default if none specified
+ addr := from
+ if addr == (common.Address{}) {
+ if wallets := b.AccountManager().Wallets(); len(wallets) > 0 {
+ if accountList := wallets[0].Accounts(); len(accountList) > 0 {
+ addr = accountList[0].Address
+ }
+ }
+ }
+
+ // Create new call message
+ msg := types.NewMessage(addr, tx.To(), tx.Nonce(), tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data(), false)
+
+ // Setup context with timeout as gas un-metered
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, time.Second*5)
+ // Make sure the context is cancelled when the call has completed
+ // this makes sure resources are cleaned up.
+ defer func() { cancel() }()
+
+ // Get a new instance of the EVM.
+ blockNumber := b.CurrentBlock().Number().Uint64()
+ stateAtBlock, header, err := b.StateAndHeaderByNumber(ctx, rpc.BlockNumber(blockNumber))
+ if stateAtBlock == nil || err != nil {
+ return nil, err
+ }
+ evm, _, err := b.GetEVM(ctx, msg, stateAtBlock, header)
+ if err != nil {
+ return nil, err
+ }
+
+ // Wait for the context to be done and cancel the evm. Even if the
+ // EVM has finished, cancelling may be done (repeatedly)
+ go func() {
+ <-ctx.Done()
+ evm.Cancel()
+ }()
+
+ var contractAddr common.Address
+ // even the creation of a contract (init code) can invoke other contracts
+ if tx.To() != nil {
+ // removed contract availability checks as they are performed in checkAndHandlePrivateTransaction
+ _, _, err = evm.Call(vm.AccountRef(addr), *tx.To(), tx.Data(), tx.Gas(), tx.Value())
+ } else {
+ _, contractAddr, _, err = evm.Create(vm.AccountRef(addr), tx.Data(), tx.Gas(), tx.Value())
+ //make sure that nonce is same in simulation as in actual block processing
+ //simulation blockNumber will be behind block processing blockNumber by at least 1
+ //only guaranteed to work for default config where EIP158=1
+ if evm.ChainConfig().IsEIP158(big.NewInt(evm.BlockNumber.Int64() + 1)) {
+ evm.StateDB.SetNonce(contractAddr, 1)
+ }
+ }
+ return evm, err
}
// SendTransaction creates a transaction for the given argument, sign it and submit it to the
@@ -1704,7 +1960,7 @@ func (s *PublicTransactionPoolAPI) SendTransaction(ctx context.Context, args Sen
if err != nil {
return common.Hash{}, err
}
- return SubmitTransaction(ctx, s.b, signed)
+ return SubmitTransaction(ctx, s.b, signed, args.PrivateFrom, args.PrivateFor, false)
}
// FillTransaction fills the defaults (nonce, gas, gasPrice) on a given unsigned transaction,
@@ -1748,7 +2004,7 @@ func (s *PublicTransactionPoolAPI) SendRawTransaction(ctx context.Context, encod
if err := rlp.DecodeBytes(encodedTx, tx); err != nil {
return common.Hash{}, err
}
- return SubmitTransaction(ctx, s.b, tx)
+ return SubmitTransaction(ctx, s.b, tx, "", nil, false)
}
// SendRawPrivateTransaction will add the signed transaction to the transaction pool.
@@ -1770,7 +2026,7 @@ func (s *PublicTransactionPoolAPI) SendRawPrivateTransaction(ctx context.Context
return common.Hash{}, fmt.Errorf("transaction is not private")
}
// /Quorum
- return SubmitTransaction(ctx, s.b, tx)
+ return SubmitTransaction(ctx, s.b, tx, "", args.PrivateFor, true)
}
// Sign calculates an ECDSA signature for:
@@ -2188,7 +2444,7 @@ func (s *PublicBlockChainAPI) GetQuorumPayload(digestHex string) (string, error)
if len(b) != common.EncryptedPayloadHashLength {
return "", fmt.Errorf("Expected a Quorum digest of length 64, but got %d", len(b))
}
- data, _, err := private.P.Receive(common.BytesToEncryptedPayloadHash(b))
+ _, _, data, _, err := private.P.Receive(common.BytesToEncryptedPayloadHash(b))
if err != nil {
return "", err
}
@@ -2256,7 +2512,7 @@ func handlePrivateTransaction(ctx context.Context, b Backend, tx *types.Transact
return
case RawTransaction:
hash = common.BytesToEncryptedPayloadHash(data)
- privatePayload, _, revErr := private.P.ReceiveRaw(hash)
+ privatePayload, _, _, revErr := private.P.ReceiveRaw(hash)
if revErr != nil {
return common.EncryptedPayloadHash{}, revErr
}
@@ -2267,13 +2523,13 @@ func handlePrivateTransaction(ctx context.Context, b Backend, tx *types.Transact
} else {
privateTx = types.NewTransaction(tx.Nonce(), *tx.To(), tx.Value(), tx.Gas(), tx.GasPrice(), privatePayload)
}
- affectedCATxHashes, merkleRoot, err = simulateExecution(ctx, b, from, privateTx, privateTxArgs)
+ affectedCATxHashes, merkleRoot, err = simulateExecutionForPE(ctx, b, from, privateTx, privateTxArgs)
log.Trace("after simulation", "affectedCATxHashes", affectedCATxHashes, "merkleRoot", merkleRoot, "privacyFlag", privateTxArgs.PrivacyFlag, "error", err)
if err != nil {
return
}
- data, err = private.P.SendSignedTx(hash, privateTxArgs.PrivateFor, &engine.ExtraMetadata{
+ _, _, data, err = private.P.SendSignedTx(hash, privateTxArgs.PrivateFor, &engine.ExtraMetadata{
ACHashes: affectedCATxHashes,
ACMerkleRoot: merkleRoot,
PrivacyFlag: privateTxArgs.PrivacyFlag,
@@ -2283,13 +2539,13 @@ func handlePrivateTransaction(ctx context.Context, b Backend, tx *types.Transact
}
case NormalTransaction:
- affectedCATxHashes, merkleRoot, err = simulateExecution(ctx, b, from, tx, privateTxArgs)
+ affectedCATxHashes, merkleRoot, err = simulateExecutionForPE(ctx, b, from, tx, privateTxArgs)
log.Trace("after simulation", "affectedCATxHashes", affectedCATxHashes, "merkleRoot", merkleRoot, "privacyFlag", privateTxArgs.PrivacyFlag, "error", err)
if err != nil {
return
}
- hash, err = private.P.Send(data, privateTxArgs.PrivateFrom, privateTxArgs.PrivateFor, &engine.ExtraMetadata{
+ _, _, hash, err = private.P.Send(data, privateTxArgs.PrivateFrom, privateTxArgs.PrivateFor, &engine.ExtraMetadata{
ACHashes: affectedCATxHashes,
ACMerkleRoot: merkleRoot,
PrivacyFlag: privateTxArgs.PrivacyFlag,
@@ -2311,73 +2567,21 @@ func handlePrivateTransaction(ctx context.Context, b Backend, tx *types.Transact
return
}
-// Simulate execution of a private transaction
+// simulateExecutionForPE simulates execution of a private transaction for enhanced privacy
+//
// Returns hashes of encrypted payload of creation transactions for all affected contract accounts
// and the merkle root combining all affected contract accounts after the simulation
-//
-func simulateExecution(ctx context.Context, b Backend, from common.Address, privateTx *types.Transaction, privateTxArgs *PrivateTxArgs) (common.EncryptedPayloadHashes, common.Hash, error) {
- defer func(start time.Time) {
- log.Debug("Simulated Execution EVM call finished", "runtime", time.Since(start))
- }(time.Now())
-
+func simulateExecutionForPE(ctx context.Context, b Backend, from common.Address, privateTx *types.Transaction, privateTxArgs *PrivateTxArgs) (common.EncryptedPayloadHashes, common.Hash, error) {
// skip simulation if privacy enhancements are disabled
if !b.ChainConfig().IsPrivacyEnhancementsEnabled(b.CurrentBlock().Number()) {
return nil, common.Hash{}, nil
}
- // Set sender address or use a default if none specified
- addr := from
- if addr == (common.Address{}) {
- if wallets := b.AccountManager().Wallets(); len(wallets) > 0 {
- if accounts := wallets[0].Accounts(); len(accounts) > 0 {
- addr = accounts[0].Address
- }
- }
- }
-
- // Create new call message
- msg := types.NewMessage(addr, privateTx.To(), privateTx.Nonce(), privateTx.Value(), privateTx.Gas(), privateTx.GasPrice(), privateTx.Data(), false)
-
- // Setup context with timeout as gas un-metered
- var cancel context.CancelFunc
- ctx, cancel = context.WithTimeout(ctx, time.Second*5)
- // Make sure the context is cancelled when the call has completed
- // this makes sure resources are cleaned up.
- defer func() { cancel() }()
-
- // Get a new instance of the EVM.
- blockNumber := b.CurrentBlock().Number().Uint64()
- state, header, err := b.StateAndHeaderByNumber(ctx, rpc.BlockNumber(blockNumber))
- if state == nil || err != nil {
- return nil, common.Hash{}, err
- }
- evm, _, err := b.GetEVM(ctx, msg, state, header)
- if err != nil {
+ evm, err := runSimulation(ctx, b, from, privateTx)
+ if evm == nil {
+ log.Debug("TX Simulation setup failed", "error", err)
return nil, common.Hash{}, err
}
-
- // Wait for the context to be done and cancel the evm. Even if the
- // EVM has finished, cancelling may be done (repeatedly)
- go func() {
- <-ctx.Done()
- evm.Cancel()
- }()
-
- var contractAddr common.Address
- // even the creation of a contract (init code) can invoke other contracts
- if privateTx.To() != nil {
- // removed contract availability checks as they are performed in checkAndHandlePrivateTransaction
- _, _, err = evm.Call(vm.AccountRef(addr), *privateTx.To(), privateTx.Data(), privateTx.Gas(), privateTx.Value())
- } else {
- _, contractAddr, _, err = evm.Create(vm.AccountRef(addr), privateTx.Data(), privateTx.Gas(), privateTx.Value())
- //make sure that nonce is same in simulation as in actual block processing
- //simulation blockNumber will be behind block processing blockNumber by at least 1
- //only guaranteed to work for default config where EIP158=1
- if evm.ChainConfig().IsEIP158(big.NewInt(evm.BlockNumber.Int64() + 1)) {
- evm.StateDB.SetNonce(contractAddr, 1)
- }
- }
-
if err != nil {
if privateTxArgs.PrivacyFlag.IsStandardPrivate() {
log.Debug("An error occurred during StandardPrivate transaction simulation. "+
@@ -2393,12 +2597,12 @@ func simulateExecution(ctx context.Context, b Backend, from common.Address, priv
privacyFlag := privateTxArgs.PrivacyFlag
log.Trace("after simulation run", "numberOfAffectedContracts", len(addresses), "privacyFlag", privacyFlag)
for _, addr := range addresses {
- // GetStatePrivacyMetadata is invoked directly on the privateState (as the tx is private) and it returns:
+ // GetPrivacyMetadata is invoked directly on the privateState (as the tx is private) and it returns:
// 1. public contacts: privacyMetadata = nil, err = nil
// 2. private contracts of type:
// 2.1. StandardPrivate: privacyMetadata = nil, err = "The provided contract does not have privacy metadata"
// 2.2. PartyProtection/PSV: privacyMetadata = , err = nil
- privacyMetadata, err := evm.StateDB.GetStatePrivacyMetadata(addr)
+ privacyMetadata, err := evm.StateDB.GetPrivacyMetadata(addr)
log.Debug("Found affected contract", "address", addr.Hex(), "privacyMetadata", privacyMetadata)
//privacyMetadata not found=non-party, or another db error
if err != nil && privacyFlag.IsNotStandardPrivate() {
diff --git a/internal/ethapi/api_test.go b/internal/ethapi/api_test.go
index 06e1bda561..9dcc84f8c0 100644
--- a/internal/ethapi/api_test.go
+++ b/internal/ethapi/api_test.go
@@ -21,18 +21,22 @@ import (
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/private"
"github.com/ethereum/go-ethereum/private/engine"
"github.com/ethereum/go-ethereum/private/engine/notinuse"
"github.com/ethereum/go-ethereum/rpc"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
"github.com/stretchr/testify/assert"
)
var (
- arbitraryCtx = context.Background()
- privateTxArgs = &PrivateTxArgs{
- PrivateFor: []string{"arbitrary party 1", "arbitrary party 2"},
+ arbitraryCtx = context.Background()
+ arbitraryPrivateFrom = "arbitrary private from"
+ privateTxArgs = &PrivateTxArgs{
+ PrivateFrom: arbitraryPrivateFrom,
+ PrivateFor: []string{"arbitrary party 1", "arbitrary party 2"},
}
arbitraryFrom = common.BytesToAddress([]byte("arbitrary address"))
@@ -131,7 +135,7 @@ func TestSimulateExecution_whenStandardPrivateCreation(t *testing.T) {
assert := assert.New(t)
privateTxArgs.PrivacyFlag = engine.PrivacyFlagStandardPrivate
- affectedCACreationTxHashes, merkleRoot, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractCreationTx, privateTxArgs)
+ affectedCACreationTxHashes, merkleRoot, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractCreationTx, privateTxArgs)
assert.NoError(err, "simulate execution")
assert.Empty(affectedCACreationTxHashes, "creation tx should not have any affected contract creation tx hashes")
@@ -142,7 +146,7 @@ func TestSimulateExecution_whenPartyProtectionCreation(t *testing.T) {
assert := assert.New(t)
privateTxArgs.PrivacyFlag = engine.PrivacyFlagPartyProtection
- affectedCACreationTxHashes, merkleRoot, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractCreationTx, privateTxArgs)
+ affectedCACreationTxHashes, merkleRoot, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractCreationTx, privateTxArgs)
assert.NoError(err, "simulation execution")
assert.Empty(affectedCACreationTxHashes, "creation tx should not have any affected contract creation tx hashes")
@@ -153,7 +157,7 @@ func TestSimulateExecution_whenCreationWithStateValidation(t *testing.T) {
assert := assert.New(t)
privateTxArgs.PrivacyFlag = engine.PrivacyFlagStateValidation
- affectedCACreationTxHashes, merkleRoot, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractCreationTx, privateTxArgs)
+ affectedCACreationTxHashes, merkleRoot, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractCreationTx, privateTxArgs)
assert.NoError(err, "simulate execution")
assert.Empty(affectedCACreationTxHashes, "creation tx should not have any affected contract creation tx hashes")
@@ -168,7 +172,7 @@ func TestSimulateExecution_whenStandardPrivateMessageCall(t *testing.T) {
privateStateDB.SetState(arbitraryStandardPrivateSimpleStorageContractAddress, common.Hash{0}, common.Hash{100})
privateStateDB.Commit(true)
- affectedCACreationTxHashes, merkleRoot, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, standardPrivateSimpleStorageContractMessageCallTx, privateTxArgs)
+ affectedCACreationTxHashes, merkleRoot, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, standardPrivateSimpleStorageContractMessageCallTx, privateTxArgs)
log.Debug("state", "state", privateStateDB.GetState(arbitraryStandardPrivateSimpleStorageContractAddress, common.Hash{0}))
@@ -181,7 +185,7 @@ func TestSimulateExecution_StandardPrivateMessageCallSucceedsWheContractNotAvail
assert := assert.New(t)
privateTxArgs.PrivacyFlag = engine.PrivacyFlagStandardPrivate
- affectedCACreationTxHashes, merkleRoot, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, standardPrivateSimpleStorageContractMessageCallTx, privateTxArgs)
+ affectedCACreationTxHashes, merkleRoot, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, standardPrivateSimpleStorageContractMessageCallTx, privateTxArgs)
log.Debug("state", "state", privateStateDB.GetState(arbitraryStandardPrivateSimpleStorageContractAddress, common.Hash{0}))
@@ -195,7 +199,7 @@ func TestSimulateExecution_whenPartyProtectionMessageCall(t *testing.T) {
privateTxArgs.PrivacyFlag = engine.PrivacyFlagPartyProtection
privateStateDB.SetCode(arbitrarySimpleStorageContractAddress, hexutil.MustDecode("0x608060405234801561001057600080fd5b506040516020806101618339810180604052602081101561003057600080fd5b81019080805190602001909291905050508060008190555050610109806100586000396000f3fe6080604052600436106049576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff16806360fe47b114604e5780636d4ce63c146099575b600080fd5b348015605957600080fd5b50608360048036036020811015606e57600080fd5b810190808035906020019092919050505060c1565b6040518082815260200191505060405180910390f35b34801560a457600080fd5b5060ab60d4565b6040518082815260200191505060405180910390f35b6000816000819055506000549050919050565b6000805490509056fea165627a7a723058203624ca2e3479d3fa5a12d97cf3dae0d9a6de3a3b8a53c8605b9cd398d9766b9f00290000000000000000000000000000000000000000000000000000000000000001"))
- privateStateDB.SetStatePrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
+ privateStateDB.SetPrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
PrivacyFlag: privateTxArgs.PrivacyFlag,
CreationTxHash: arbitrarySimpleStorageContractEncryptedPayloadHash,
})
@@ -203,7 +207,7 @@ func TestSimulateExecution_whenPartyProtectionMessageCall(t *testing.T) {
privateStateDB.SetState(arbitrarySimpleStorageContractAddress, common.Hash{0}, common.Hash{100})
privateStateDB.Commit(true)
- affectedCACreationTxHashes, merkleRoot, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
+ affectedCACreationTxHashes, merkleRoot, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
expectedCACreationTxHashes := []common.EncryptedPayloadHash{arbitrarySimpleStorageContractEncryptedPayloadHash}
@@ -223,7 +227,7 @@ func TestSimulateExecution_whenPartyProtectionMessageCallAndPrivacyEnhancementsD
defer func() { params.QuorumTestChainConfig.PrivacyEnhancementsBlock = big.NewInt(0) }()
stbBackend := &StubBackend{}
- affectedCACreationTxHashes, merkleRoot, err := simulateExecution(arbitraryCtx, stbBackend, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
+ affectedCACreationTxHashes, merkleRoot, err := simulateExecutionForPE(arbitraryCtx, stbBackend, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
// the simulation returns early without executing the transaction
assert.False(stbBackend.getEVMCalled, "simulation is ended early - before getEVM is called")
@@ -237,7 +241,7 @@ func TestSimulateExecution_whenStateValidationMessageCall(t *testing.T) {
privateTxArgs.PrivacyFlag = engine.PrivacyFlagStateValidation
privateStateDB.SetCode(arbitrarySimpleStorageContractAddress, hexutil.MustDecode("0x608060405234801561001057600080fd5b506040516020806101618339810180604052602081101561003057600080fd5b81019080805190602001909291905050508060008190555050610109806100586000396000f3fe6080604052600436106049576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff16806360fe47b114604e5780636d4ce63c146099575b600080fd5b348015605957600080fd5b50608360048036036020811015606e57600080fd5b810190808035906020019092919050505060c1565b6040518082815260200191505060405180910390f35b34801560a457600080fd5b5060ab60d4565b6040518082815260200191505060405180910390f35b6000816000819055506000549050919050565b6000805490509056fea165627a7a723058203624ca2e3479d3fa5a12d97cf3dae0d9a6de3a3b8a53c8605b9cd398d9766b9f00290000000000000000000000000000000000000000000000000000000000000001"))
- privateStateDB.SetStatePrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
+ privateStateDB.SetPrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
PrivacyFlag: privateTxArgs.PrivacyFlag,
CreationTxHash: arbitrarySimpleStorageContractEncryptedPayloadHash,
})
@@ -245,7 +249,7 @@ func TestSimulateExecution_whenStateValidationMessageCall(t *testing.T) {
privateStateDB.SetState(arbitrarySimpleStorageContractAddress, common.Hash{0}, common.Hash{100})
privateStateDB.Commit(true)
- affectedCACreationTxHashes, merkleRoot, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
+ affectedCACreationTxHashes, merkleRoot, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
expectedCACreationTxHashes := []common.EncryptedPayloadHash{arbitrarySimpleStorageContractEncryptedPayloadHash}
@@ -266,7 +270,7 @@ func TestSimulateExecution_PrivacyFlagPartyProtectionCallingStandardPrivateContr
privateStateDB.SetState(arbitraryStandardPrivateSimpleStorageContractAddress, common.Hash{0}, common.Hash{100})
privateStateDB.Commit(true)
- _, _, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, standardPrivateSimpleStorageContractMessageCallTx, privateTxArgs)
+ _, _, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, standardPrivateSimpleStorageContractMessageCallTx, privateTxArgs)
log.Debug("state", "state", privateStateDB.GetState(arbitraryStandardPrivateSimpleStorageContractAddress, common.Hash{0}))
@@ -278,7 +282,7 @@ func TestSimulateExecution_StandardPrivateFlagCallingPartyProtectionContract_Err
privateTxArgs.PrivacyFlag = engine.PrivacyFlagStandardPrivate
privateStateDB.SetCode(arbitrarySimpleStorageContractAddress, hexutil.MustDecode("0x608060405234801561001057600080fd5b506040516020806101618339810180604052602081101561003057600080fd5b81019080805190602001909291905050508060008190555050610109806100586000396000f3fe6080604052600436106049576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff16806360fe47b114604e5780636d4ce63c146099575b600080fd5b348015605957600080fd5b50608360048036036020811015606e57600080fd5b810190808035906020019092919050505060c1565b6040518082815260200191505060405180910390f35b34801560a457600080fd5b5060ab60d4565b6040518082815260200191505060405180910390f35b6000816000819055506000549050919050565b6000805490509056fea165627a7a723058203624ca2e3479d3fa5a12d97cf3dae0d9a6de3a3b8a53c8605b9cd398d9766b9f00290000000000000000000000000000000000000000000000000000000000000001"))
- privateStateDB.SetStatePrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
+ privateStateDB.SetPrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagPartyProtection,
CreationTxHash: arbitrarySimpleStorageContractEncryptedPayloadHash,
})
@@ -286,7 +290,7 @@ func TestSimulateExecution_StandardPrivateFlagCallingPartyProtectionContract_Err
privateStateDB.SetState(arbitrarySimpleStorageContractAddress, common.Hash{0}, common.Hash{100})
privateStateDB.Commit(true)
- _, _, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
+ _, _, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
assert.Error(err, "simulate execution")
}
@@ -296,7 +300,7 @@ func TestSimulateExecution_StandardPrivateFlagCallingStateValidationContract_Err
privateTxArgs.PrivacyFlag = engine.PrivacyFlagStandardPrivate
privateStateDB.SetCode(arbitrarySimpleStorageContractAddress, hexutil.MustDecode("0x608060405234801561001057600080fd5b506040516020806101618339810180604052602081101561003057600080fd5b81019080805190602001909291905050508060008190555050610109806100586000396000f3fe6080604052600436106049576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff16806360fe47b114604e5780636d4ce63c146099575b600080fd5b348015605957600080fd5b50608360048036036020811015606e57600080fd5b810190808035906020019092919050505060c1565b6040518082815260200191505060405180910390f35b34801560a457600080fd5b5060ab60d4565b6040518082815260200191505060405180910390f35b6000816000819055506000549050919050565b6000805490509056fea165627a7a723058203624ca2e3479d3fa5a12d97cf3dae0d9a6de3a3b8a53c8605b9cd398d9766b9f00290000000000000000000000000000000000000000000000000000000000000001"))
- privateStateDB.SetStatePrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
+ privateStateDB.SetPrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagStateValidation,
CreationTxHash: arbitrarySimpleStorageContractEncryptedPayloadHash,
})
@@ -304,7 +308,7 @@ func TestSimulateExecution_StandardPrivateFlagCallingStateValidationContract_Err
privateStateDB.SetState(arbitrarySimpleStorageContractAddress, common.Hash{0}, common.Hash{100})
privateStateDB.Commit(true)
- _, _, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
+ _, _, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
log.Debug("state", "state", privateStateDB.GetState(arbitrarySimpleStorageContractAddress, common.Hash{0}))
@@ -316,7 +320,7 @@ func TestSimulateExecution_PartyProtectionFlagCallingStateValidationContract_Err
privateTxArgs.PrivacyFlag = engine.PrivacyFlagPartyProtection
privateStateDB.SetCode(arbitrarySimpleStorageContractAddress, hexutil.MustDecode("0x608060405234801561001057600080fd5b506040516020806101618339810180604052602081101561003057600080fd5b81019080805190602001909291905050508060008190555050610109806100586000396000f3fe6080604052600436106049576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff16806360fe47b114604e5780636d4ce63c146099575b600080fd5b348015605957600080fd5b50608360048036036020811015606e57600080fd5b810190808035906020019092919050505060c1565b6040518082815260200191505060405180910390f35b34801560a457600080fd5b5060ab60d4565b6040518082815260200191505060405180910390f35b6000816000819055506000549050919050565b6000805490509056fea165627a7a723058203624ca2e3479d3fa5a12d97cf3dae0d9a6de3a3b8a53c8605b9cd398d9766b9f00290000000000000000000000000000000000000000000000000000000000000001"))
- privateStateDB.SetStatePrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
+ privateStateDB.SetPrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagStateValidation,
CreationTxHash: arbitrarySimpleStorageContractEncryptedPayloadHash,
})
@@ -324,7 +328,7 @@ func TestSimulateExecution_PartyProtectionFlagCallingStateValidationContract_Err
privateStateDB.SetState(arbitrarySimpleStorageContractAddress, common.Hash{0}, common.Hash{100})
privateStateDB.Commit(true)
- _, _, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
+ _, _, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
log.Debug("state", "state", privateStateDB.GetState(arbitrarySimpleStorageContractAddress, common.Hash{0}))
@@ -336,7 +340,7 @@ func TestSimulateExecution_StateValidationFlagCallingPartyProtectionContract_Err
privateTxArgs.PrivacyFlag = engine.PrivacyFlagStateValidation
privateStateDB.SetCode(arbitrarySimpleStorageContractAddress, hexutil.MustDecode("0x608060405234801561001057600080fd5b506040516020806101618339810180604052602081101561003057600080fd5b81019080805190602001909291905050508060008190555050610109806100586000396000f3fe6080604052600436106049576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff16806360fe47b114604e5780636d4ce63c146099575b600080fd5b348015605957600080fd5b50608360048036036020811015606e57600080fd5b810190808035906020019092919050505060c1565b6040518082815260200191505060405180910390f35b34801560a457600080fd5b5060ab60d4565b6040518082815260200191505060405180910390f35b6000816000819055506000549050919050565b6000805490509056fea165627a7a723058203624ca2e3479d3fa5a12d97cf3dae0d9a6de3a3b8a53c8605b9cd398d9766b9f00290000000000000000000000000000000000000000000000000000000000000001"))
- privateStateDB.SetStatePrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
+ privateStateDB.SetPrivacyMetadata(arbitrarySimpleStorageContractAddress, &state.PrivacyMetadata{
PrivacyFlag: engine.PrivacyFlagPartyProtection,
CreationTxHash: arbitrarySimpleStorageContractEncryptedPayloadHash,
})
@@ -344,7 +348,7 @@ func TestSimulateExecution_StateValidationFlagCallingPartyProtectionContract_Err
privateStateDB.SetState(arbitrarySimpleStorageContractAddress, common.Hash{0}, common.Hash{100})
privateStateDB.Commit(true)
- _, _, err := simulateExecution(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
+ _, _, err := simulateExecutionForPE(arbitraryCtx, &StubBackend{}, arbitraryFrom, simpleStorageContractMessageCallTx, privateTxArgs)
//expectedCACreationTxHashes := []common.EncryptedPayloadHash{arbitrarySimpleStorageContractEncryptedPayloadHash}
@@ -441,8 +445,42 @@ func TestHandlePrivateTransaction_whenRawStandardPrivateMessageCall(t *testing.T
}
+// Copy and set private
+func copyTransaction(tx *types.Transaction) *types.Transaction {
+ var privateTx *types.Transaction
+ if tx.To() == nil {
+ privateTx = types.NewContractCreation(tx.Nonce(),
+ tx.Value(),
+ tx.Gas(),
+ tx.GasPrice(),
+ tx.Data())
+ } else {
+ privateTx = types.NewTransaction(tx.Nonce(),
+ *tx.To(),
+ tx.Value(),
+ tx.Gas(),
+ tx.GasPrice(),
+ tx.Data())
+ }
+ privateTx.SetPrivate()
+ return privateTx
+}
+
type StubBackend struct {
- getEVMCalled bool
+ getEVMCalled bool
+ mockAccountExtraDataStateGetter *vm.MockAccountExtraDataStateGetter
+}
+
+func (sb *StubBackend) SupportsMultitenancy(rpcCtx context.Context) (*proto.PreAuthenticatedAuthenticationToken, bool) {
+ panic("implement me")
+}
+
+func (sb *StubBackend) AccountExtraDataStateGetterByNumber(context.Context, rpc.BlockNumber) (vm.AccountExtraDataStateGetter, error) {
+ return sb.mockAccountExtraDataStateGetter, nil
+}
+
+func (sb *StubBackend) IsAuthorized(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, attributes ...*multitenancy.ContractSecurityAttribute) (bool, error) {
+ panic("implement me")
}
func (sb *StubBackend) GetEVM(ctx context.Context, msg core.Message, state vm.MinimalApiState, header *types.Header) (*vm.EVM, func() error, error) {
@@ -642,7 +680,11 @@ func (StubMinimalApiState) SetCode(common.Address, []byte) {
panic("implement me")
}
-func (StubMinimalApiState) GetStatePrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error) {
+func (StubMinimalApiState) GetPrivacyMetadata(addr common.Address) (*state.PrivacyMetadata, error) {
+ panic("implement me")
+}
+
+func (StubMinimalApiState) GetManagedParties(addr common.Address) ([]string, error) {
panic("implement me")
}
@@ -683,8 +725,8 @@ type StubPrivateTransactionManager struct {
creation bool
}
-func (sptm *StubPrivateTransactionManager) Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (common.EncryptedPayloadHash, error) {
- return arbitrarySimpleStorageContractEncryptedPayloadHash, nil
+func (sptm *StubPrivateTransactionManager) Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (string, []string, common.EncryptedPayloadHash, error) {
+ return "", nil, arbitrarySimpleStorageContractEncryptedPayloadHash, nil
}
func (sptm *StubPrivateTransactionManager) EncryptPayload(data []byte, from string, to []string, extra *engine.ExtraMetadata) ([]byte, error) {
@@ -699,15 +741,15 @@ func (sptm *StubPrivateTransactionManager) StoreRaw(data []byte, from string) (c
return arbitrarySimpleStorageContractEncryptedPayloadHash, nil
}
-func (sptm *StubPrivateTransactionManager) SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) ([]byte, error) {
- return arbitrarySimpleStorageContractEncryptedPayloadHash.Bytes(), nil
+func (sptm *StubPrivateTransactionManager) SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) (string, []string, []byte, error) {
+ return "", nil, arbitrarySimpleStorageContractEncryptedPayloadHash.Bytes(), nil
}
-func (sptm *StubPrivateTransactionManager) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
+func (sptm *StubPrivateTransactionManager) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, string, *engine.ExtraMetadata, error) {
if sptm.creation {
- return hexutil.MustDecode("0x6060604052341561000f57600080fd5b604051602080610149833981016040528080519060200190919050505b806000819055505b505b610104806100456000396000f30060606040526000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff1680632a1afcd914605157806360fe47b11460775780636d4ce63c146097575b600080fd5b3415605b57600080fd5b606160bd565b6040518082815260200191505060405180910390f35b3415608157600080fd5b6095600480803590602001909190505060c3565b005b341560a157600080fd5b60a760ce565b6040518082815260200191505060405180910390f35b60005481565b806000819055505b50565b6000805490505b905600a165627a7a72305820d5851baab720bba574474de3d09dbeaabc674a15f4dd93b974908476542c23f00029"), nil, nil
+ return hexutil.MustDecode("0x6060604052341561000f57600080fd5b604051602080610149833981016040528080519060200190919050505b806000819055505b505b610104806100456000396000f30060606040526000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff1680632a1afcd914605157806360fe47b11460775780636d4ce63c146097575b600080fd5b3415605b57600080fd5b606160bd565b6040518082815260200191505060405180910390f35b3415608157600080fd5b6095600480803590602001909190505060c3565b005b341560a157600080fd5b60a760ce565b6040518082815260200191505060405180910390f35b60005481565b806000819055505b50565b6000805490505b905600a165627a7a72305820d5851baab720bba574474de3d09dbeaabc674a15f4dd93b974908476542c23f00029"), "", nil, nil
} else {
- return hexutil.MustDecode("0x60fe47b1000000000000000000000000000000000000000000000000000000000000000e"), nil, nil
+ return hexutil.MustDecode("0x60fe47b1000000000000000000000000000000000000000000000000000000000000000e"), "", nil, nil
}
}
diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go
index 797dd214c5..97d32ee278 100644
--- a/internal/ethapi/backend.go
+++ b/internal/ethapi/backend.go
@@ -31,6 +31,7 @@ import (
"github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc"
)
@@ -38,6 +39,7 @@ import (
// Backend interface provides the common API services (that are provided by
// both full and light clients) with access to necessary functions.
type Backend interface {
+ multitenancy.AuthorizationProvider
// General Ethereum API
Downloader() *downloader.Downloader
ProtocolVersion() int
@@ -85,6 +87,9 @@ type Backend interface {
ChainConfig() *params.ChainConfig
CurrentBlock() *types.Block
+
+ // AccountExtraDataStateGetterByNumber returns state getter at a given block height
+ AccountExtraDataStateGetterByNumber(ctx context.Context, number rpc.BlockNumber) (vm.AccountExtraDataStateGetter, error)
}
func GetAPIs(apiBackend Backend) []rpc.API {
diff --git a/les/api_backend.go b/les/api_backend.go
index dae96927ef..36204751e5 100644
--- a/les/api_backend.go
+++ b/les/api_backend.go
@@ -36,8 +36,10 @@ import (
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/light"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
)
type LesApiBackend struct {
@@ -280,3 +282,24 @@ func (b *LesApiBackend) ServiceFilter(ctx context.Context, session *bloombits.Ma
go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests)
}
}
+
+func (b *LesApiBackend) SupportsMultitenancy(rpcCtx context.Context) (*proto.PreAuthenticatedAuthenticationToken, bool) {
+ authToken, isPreauthenticated := rpcCtx.Value(rpc.CtxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
+ if isPreauthenticated && b.eth.config.EnableMultitenancy {
+ return authToken, true
+ }
+ return nil, false
+}
+
+func (b *LesApiBackend) AccountExtraDataStateGetterByNumber(ctx context.Context, number rpc.BlockNumber) (vm.AccountExtraDataStateGetter, error) {
+ s, _, err := b.StateAndHeaderByNumber(ctx, number)
+ return s, err
+}
+
+func (b *LesApiBackend) IsAuthorized(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, attributes ...*multitenancy.ContractSecurityAttribute) (bool, error) {
+ auth, err := b.eth.contractAuthzProvider.IsAuthorized(ctx, authToken, attributes...)
+ if err != nil {
+ return false, err
+ }
+ return auth, nil
+}
diff --git a/les/client.go b/les/client.go
index 34b1c5daf4..ebaea96c54 100644
--- a/les/client.go
+++ b/les/client.go
@@ -38,11 +38,11 @@ import (
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/log"
+ "github.com/ethereum/go-ethereum/multitenancy"
"github.com/ethereum/go-ethereum/node"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params"
- "github.com/ethereum/go-ethereum/plugin"
"github.com/ethereum/go-ethereum/rpc"
)
@@ -67,7 +67,16 @@ type LightEthereum struct {
accountManager *accounts.Manager
netRPCService *ethapi.PublicNetAPI
- securityPlugin *plugin.SecurityPluginTemplate // Quorum: to dispose security plugin being used
+ // Quorum - Multitenancy
+ // contractAuthzProvider is set after node starts instead in New()
+ contractAuthzProvider multitenancy.ContractAuthorizationProvider
+}
+
+// Quorum
+//
+// Set the decision manager for multitenancy support
+func (s *LightEthereum) SetContractAuthorizationManager(dm multitenancy.ContractAuthorizationProvider) {
+ s.contractAuthzProvider = dm
}
func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
@@ -112,9 +121,13 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
if checkpoint == nil {
checkpoint = params.TrustedCheckpoints[genesisHash]
}
+ newChainFunc := light.NewLightChain
+ if config.EnableMultitenancy {
+ newChainFunc = light.NewMultitenantLightChain
+ }
// Note: NewLightChain adds the trusted checkpoint so it needs an ODR with
// indexers already set but not started yet
- if leth.blockchain, err = light.NewLightChain(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil {
+ if leth.blockchain, err = newChainFunc(leth.odr, leth.chainConfig, leth.engine, checkpoint); err != nil {
return nil, err
}
leth.chainReader = leth.blockchain
@@ -151,15 +164,6 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
}
leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams)
- // Set Security plugin in eth
- var pluginManager *plugin.PluginManager
- if err := ctx.Service(&pluginManager); err == nil {
- sp := new(plugin.SecurityPluginTemplate)
- if err := pluginManager.GetPluginTemplate(plugin.SecurityPluginInterfaceName, sp); err == nil {
- leth.securityPlugin = sp
- }
- }
-
return leth, nil
}
diff --git a/light/lightchain.go b/light/lightchain.go
index e06247fc33..4412fd2249 100644
--- a/light/lightchain.go
+++ b/light/lightchain.go
@@ -26,6 +26,8 @@ import (
"sync/atomic"
"time"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
+
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/consensus"
"github.com/ethereum/go-ethereum/core"
@@ -72,6 +74,9 @@ type LightChain struct {
running int32 // whether LightChain is running or stopped
procInterrupt int32 // interrupts chain insert
disableCheckFreq int32 // disables header verification
+
+ // Quorum
+ isMultitenant bool
}
// NewLightChain returns a fully initialised light chain using information
@@ -118,6 +123,15 @@ func NewLightChain(odr OdrBackend, config *params.ChainConfig, engine consensus.
return bc, nil
}
+func NewMultitenantLightChain(odr OdrBackend, config *params.ChainConfig, engine consensus.Engine, checkpoint *params.TrustedCheckpoint) (*LightChain, error) {
+ bc, err := NewLightChain(odr, config, engine, checkpoint)
+ if err != nil {
+ return nil, err
+ }
+ bc.isMultitenant = true
+ return bc, nil
+}
+
// AddTrustedCheckpoint adds a trusted checkpoint to the blockchain
func (lc *LightChain) AddTrustedCheckpoint(cp *params.TrustedCheckpoint) {
if lc.odr.ChtIndexer() != nil {
@@ -542,3 +556,7 @@ func (lc *LightChain) DisableCheckFreq() {
func (lc *LightChain) EnableCheckFreq() {
atomic.StoreInt32(&lc.disableCheckFreq, 0)
}
+
+func (lc *LightChain) SupportsMultitenancy(context.Context) (*proto.PreAuthenticatedAuthenticationToken, bool) {
+ return nil, lc.isMultitenant
+}
diff --git a/light/trie.go b/light/trie.go
index 04081b73eb..3a5e6e1b0a 100644
--- a/light/trie.go
+++ b/light/trie.go
@@ -91,28 +91,25 @@ func (db *odrDatabase) TrieDB() *trie.Database {
return nil
}
-// Quorum - Privacy Enhancements
-type stubPrivacyMetadataLinker struct {
+type stubAccountExtraDataLinker struct {
}
-func newPrivacyMetadataLinkerStub() rawdb.PrivacyMetadataLinker {
- return &stubPrivacyMetadataLinker{}
+func newAccountExtraDataLinkerStub() rawdb.AccountExtraDataLinker {
+ return &stubAccountExtraDataLinker{}
}
-func (pml *stubPrivacyMetadataLinker) PrivacyMetadataRootForPrivateStateRoot(privateStateRoot common.Hash) common.Hash {
+func (pml *stubAccountExtraDataLinker) GetAccountExtraDataRoot(_ common.Hash) common.Hash {
return common.Hash{}
}
-func (pml *stubPrivacyMetadataLinker) LinkPrivacyMetadataRootToPrivateStateRoot(privateStateRoot, privacyMetadataRoot common.Hash) error {
+func (pml *stubAccountExtraDataLinker) Link(_, _ common.Hash) error {
return nil
}
-func (db *odrDatabase) PrivacyMetadataLinker() rawdb.PrivacyMetadataLinker {
- return newPrivacyMetadataLinkerStub()
+func (db *odrDatabase) AccountExtraDataLinker() rawdb.AccountExtraDataLinker {
+ return newAccountExtraDataLinkerStub()
}
-// End Quorum - Privacy Enhancements
-
type odrTrie struct {
db *odrDatabase
id *TrieID
diff --git a/multitenancy/authorization_provider.go b/multitenancy/authorization_provider.go
new file mode 100644
index 0000000000..a3aed78c29
--- /dev/null
+++ b/multitenancy/authorization_provider.go
@@ -0,0 +1,222 @@
+package multitenancy
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/url"
+ "strings"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
+)
+
+var (
+ ErrNotAuthorized = errors.New("not authorized")
+ CtxKeyAuthorizeCreateFunc = "AUTHORIZE_CREATE_FUNC"
+ CtxKeyAuthorizeMessageCallFunc = "AUTHORIZE_MESSAGE_CALL_FUNC"
+)
+
+// AccountAuthorizationProvider performs authorization checks for Ethereum Account
+// based on what is entitled in the proto.PreAuthenticatedAuthenticationToken
+// and what is asked in ContractSecurityAttribute list.
+// Note: place holder for future, this is to protect Value Transfer between accounts.
+type AccountAuthorizationProvider interface {
+ IsAuthorized(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, attr *AccountStateSecurityAttribute) (bool, error)
+}
+
+type AuthorizeCreateFunc func() bool
+
+// AuthorizeMessageCallFunc returns if a contract is authorized to be read / write
+type AuthorizeMessageCallFunc func(contractAddress common.Address) (authorizedRead bool, authorizedWrite bool, err error)
+
+// ContractAuthorizationProvider performs authorization checks for contract
+// based on what is entitled in the proto.PreAuthenticatedAuthenticationToken
+// and what is asked in ContractSecurityAttribute list.
+type ContractAuthorizationProvider interface {
+ IsAuthorized(ctx context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, attributes ...*ContractSecurityAttribute) (bool, error)
+}
+
+type DefaultContractAuthorizationProvider struct {
+}
+
+// isAuthorized performs authorization check for one security attribute against
+// the granted access inside the pre-authenticated access token.
+func (cm *DefaultContractAuthorizationProvider) isAuthorized(authToken *proto.PreAuthenticatedAuthenticationToken, attr *ContractSecurityAttribute) (bool, error) {
+ query := url.Values{}
+ switch attr.Visibility {
+ case VisibilityPublic:
+ switch attr.Action {
+ case ActionRead, ActionWrite, ActionCreate:
+ if (attr.To == common.Address{}) {
+ query.Set(QueryOwnedEOA, toHexAddress(attr.From))
+ } else {
+ query.Set(QueryOwnedEOA, toHexAddress(attr.To))
+ }
+ }
+ case VisibilityPrivate:
+ switch attr.Action {
+ case ActionRead, ActionWrite:
+ if (attr.To == common.Address{}) {
+ query.Set(QueryOwnedEOA, toHexAddress(attr.From))
+ } else {
+ query.Set(QueryOwnedEOA, toHexAddress(attr.To))
+ }
+ for _, tm := range attr.Parties {
+ query.Add(QueryFromTM, tm)
+ }
+ case ActionCreate:
+ query.Set(QueryFromTM, attr.PrivateFrom)
+ }
+ }
+ // construct request permission identifier
+ request, err := url.Parse(fmt.Sprintf("%s://%s/%s/%s?%s", attr.Visibility, toHexAddress(attr.From), attr.Action, "contracts", query.Encode()))
+ if err != nil {
+ return false, err
+ }
+ // compare the contract security attribute with the consolidate list
+ for _, granted := range authToken.GetAuthorities() {
+ pi, err := url.Parse(granted.GetRaw())
+ if err != nil {
+ continue
+ }
+ granted := pi.String()
+ ask := request.String()
+ isMatched := match(attr, request, pi)
+ log.Debug("Checking contract access", "passed", isMatched, "granted", granted, "ask", ask)
+ if isMatched {
+ return true, nil
+ }
+ }
+ return false, nil
+}
+
+// IsAuthorized performs authorization check for each security attribute against
+// the granted access inside the pre-authenticated access token.
+//
+// All security attributes must pass.
+func (cm *DefaultContractAuthorizationProvider) IsAuthorized(_ context.Context, authToken *proto.PreAuthenticatedAuthenticationToken, attributes ...*ContractSecurityAttribute) (bool, error) {
+ if len(attributes) == 0 {
+ return false, nil
+ }
+ for _, attr := range attributes {
+ isMatched, err := cm.isAuthorized(authToken, attr)
+ if err != nil {
+ return false, err
+ }
+ if !isMatched {
+ return false, nil
+ }
+ }
+ return true, nil
+}
+
+func toHexAddress(a common.Address) string {
+ if (a == common.Address{}) {
+ return AnyEOAAddress
+ }
+ return strings.ToLower(a.Hex())
+}
+
+func match(attr *ContractSecurityAttribute, ask, granted *url.URL) bool {
+ askScheme := strings.ToLower(ask.Scheme)
+ if allowedPublic(askScheme) {
+ return true
+ }
+
+ isPathMatched := matchPath(strings.ToLower(ask.Path), strings.ToLower(granted.Path))
+ return askScheme == strings.ToLower(granted.Scheme) && //Note: "askScheme" here is "private" since we checked VisibilityPublic above.
+ matchHost(attr.Action, strings.ToLower(ask.Host), strings.ToLower(granted.Host)) && //whether i have permission to execute using this ethereum address
+ isPathMatched && //is our permission for the same action (read, write, deploy)
+ matchQuery(attr, ask.Query(), granted.Query())
+}
+
+func allowedPublic(scheme string) bool {
+ return scheme == string(VisibilityPublic)
+}
+
+func matchHost(a ContractAction, ask string, granted string) bool {
+ // for READ action, we use owned.eoa query param instead
+ return granted == AnyEOAAddress || ask == granted || a == ActionRead
+}
+
+func matchPath(ask string, granted string) bool {
+ return strings.HasPrefix(granted, "/_") || ask == granted
+}
+
+func matchQuery(attr *ContractSecurityAttribute, ask, granted url.Values) bool {
+ // if asking nothing, we should bail out
+ if len(ask) == 0 || len(ask[QueryFromTM]) == 0 {
+ return false
+ }
+ // possible scenarios:
+ // 1. read/write -> from.tm -> at least 1 of the same key must appear in both lists
+ // 2. read/write - owned.eoa/to.eoa -> check subset
+ // 3. create -> from.tm/owned.eoa/to.eoa -> check subset
+ for k, askValues := range ask {
+ grantedValues := granted[k]
+ switch attr.Action {
+ case ActionRead, ActionWrite:
+ // Scenario 1
+ if k == QueryFromTM {
+ if isIntersectionEmpty(grantedValues, askValues) {
+ return false
+ }
+ }
+ //Scenario 2
+ if k == QueryOwnedEOA || k == QueryToEOA {
+ if !subset(grantedValues, askValues) {
+ return false
+ }
+ }
+ case ActionCreate:
+ //Scenario 3
+ if !subset(grantedValues, askValues) {
+ return false
+ }
+ default:
+ // we don't know, better reject
+ log.Error("unsupported action", "action", attr.Action)
+ return false
+ }
+ }
+ return true
+}
+
+func subset(grantedValues, askValues []string) bool {
+ for _, askValue := range askValues {
+ found := false
+ sanitizedAskValue := askValue
+ if strings.HasPrefix(askValue, "0x") {
+ sanitizedAskValue = strings.ToLower(askValue)
+ }
+ for _, grantedValue := range grantedValues {
+ sanitizedGrantedValue := grantedValue
+ if strings.HasPrefix(grantedValue, "0x") {
+ sanitizedGrantedValue = strings.ToLower(grantedValue)
+ }
+ if sanitizedGrantedValue == AnyEOAAddress || sanitizedAskValue == sanitizedGrantedValue {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return false
+ }
+ }
+ return true
+}
+
+func isIntersectionEmpty(grantedValues, askValues []string) bool {
+ grantedMap := make(map[string]bool)
+ for _, grantedVal := range grantedValues {
+ grantedMap[grantedVal] = true
+ }
+ for _, askVal := range askValues {
+ if grantedMap[askVal] {
+ return false
+ }
+ }
+ return true
+}
diff --git a/multitenancy/authorization_provider_test.go b/multitenancy/authorization_provider_test.go
new file mode 100644
index 0000000000..43ac7fbb1c
--- /dev/null
+++ b/multitenancy/authorization_provider_test.go
@@ -0,0 +1,861 @@
+package multitenancy
+
+import (
+ "context"
+ "net/url"
+ "os"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
+ "github.com/stretchr/testify/assert"
+)
+
+func init() {
+ log.Root().SetHandler(log.StreamHandler(os.Stdout, log.TerminalFormat(false)))
+}
+
+type testCase struct {
+ msg string
+ granted []string
+ ask []*ContractSecurityAttribute
+ isAuthorized bool
+}
+
+func TestMatch_whenTypical(t *testing.T) {
+ granted, _ := url.Parse("private://0xa1b1c1/create/contracts?from.tm=A/")
+ ask, _ := url.Parse("private://0xa1b1c1/create/contracts?from.tm=A%2F")
+
+ assert.True(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+}
+
+func TestMatch_whenAskNothing(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/_/contracts?from.tm=A&owned.eoa=0x0")
+ ask, _ := url.Parse("private://0xa1b1c1/write/contracts?owned.eoa=0xe1e1e1")
+
+ assert.False(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+
+ ask, _ = url.Parse("private://0xa1b1c1/write/contracts")
+
+ assert.False(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+}
+
+func TestMatch_whenGrantNothing(t *testing.T) {
+ granted, _ := url.Parse("private://0xa1b1c1/write/contracts")
+ ask, _ := url.Parse("private://0xa1b1c1/write/contracts?from.tm=A")
+
+ assert.False(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+}
+
+func TestMatch_whenAnyAction(t *testing.T) {
+ granted, _ := url.Parse("private://0xa1b1c1/_/contracts?owned.eoa=0x0&from.tm=A1")
+ ask, _ := url.Parse("private://0xa1b1c1/read/contracts?from.tm=A1")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+
+ ask, _ = url.Parse("private://0xa1b1c1/read/contracts?owned.eoa=0x0&from.tm=A1&from.tm=B1")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+
+ ask, _ = url.Parse("private://0xa1b1c1/write/contracts?owned.eoa=0x0&from.tm=A1")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ }, ask, granted))
+}
+
+func TestMatch_whenPathNotMatched(t *testing.T) {
+ granted, _ := url.Parse("private://0xa1b1c1/write/contracts?owned.eoa=0x0&from.tm=A1")
+ ask, _ := url.Parse("private://0xa1b1c1/read/contracts?from.tm=A1")
+
+ assert.False(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+}
+
+func TestMatch_whenSchemeIsNotEqual(t *testing.T) {
+ granted, _ := url.Parse("unknown://0xa1b1c1/create/contracts?from.tm=A")
+ ask, _ := url.Parse("private://0xa1b1c1/create/contracts?from.tm=A")
+
+ assert.False(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+}
+
+func TestMatch_whenContractWritePermission_GrantedIsTheSuperSet(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/write/contracts?owned.eoa=0x0&from.tm=A&from.tm=B")
+ ask, _ := url.Parse("private://0x0/write/contracts?owned.eoa=0x0&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ }, ask, granted), "with write permission")
+
+ granted, _ = url.Parse("private://0x0/read/contracts?owned.eoa=0x0&from.tm=A&from.tm=B")
+ ask, _ = url.Parse("private://0x0/read/contracts?owned.eoa=0x0&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted), "with read permission")
+}
+
+func TestMatch_whenContractReadPermission_AnyAction(t *testing.T) {
+ granted, _ := url.Parse("private://0x1234/_/contracts?owned.eoa=0x0&from.tm=A")
+ ask, _ := url.Parse("private://0x0/read/contracts?owned.eoa=0x1234&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractReadPermission_AnyEoa(t *testing.T) {
+ granted, _ := url.Parse("private://0x1234/_/contracts?owned.eoa=0x0&from.tm=A")
+ ask, _ := url.Parse("private://0x0/read/contracts?owned.eoa=0x0&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractReadPermission_EoaDifferent(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/read/contracts?owned.eoa=0x095e7baea6a6c7c4c2dfeb977efac326af552d87&from.tm=A")
+ ask, _ := url.Parse("private://0x0/read/contracts?owned.eoa=0x945304eb96065b2a98b57a48a06ae28d285a71b5&from.tm=A")
+
+ assert.False(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractReadPermission_EoaSame(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/read/contracts?owned.eoa=0x095e7baea6a6c7c4c2dfeb977efac326af552d87&from.tm=A")
+ ask, _ := url.Parse("private://0x0/read/contracts?owned.eoa=0x095e7baea6a6c7c4c2dfeb977efac326af552d87&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractReadPermission_TmKeysIntersect(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/read/contracts?from.tm=A&from.tm=B")
+ ask, _ := url.Parse("private://0x0/read/contracts?from.tm=B&from.tm=C")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractReadPermission_TmKeysDontIntersect(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/read/contracts?from.tm=A&from.tm=B")
+ ask, _ := url.Parse("private://0x0/read/contracts?from.tm=C&from.tm=D")
+
+ assert.False(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractWritePermission_Same(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/write/contracts?owned.eoa=0x0&from.tm=A")
+ ask, _ := url.Parse("private://0x0/write/contracts?owned.eoa=0x0&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractWritePermission_Different(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/write/contracts?owned.eoa=0x0&from.tm=A")
+ ask, _ := url.Parse("private://0x0/write/contracts?owned.eoa=0x0&from.tm=B")
+
+ assert.False(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractWritePermission_AskIsSuperSet(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/write/contracts?owned.eoa=0x0&from.tm=A")
+ ask, _ := url.Parse("private://0x0/write/contracts?owned.eoa=0x0&from.tm=B&from.tm=C&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractCreatePermission_Same(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/create/contracts?owned.eoa=0x0&from.tm=A")
+ ask, _ := url.Parse("private://0x0/create/contracts?owned.eoa=0x0&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionCreate,
+ }, ask, granted))
+}
+
+func TestMatch_whenContractCreatePermission_Different(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/create/contracts?owned.eoa=0x0&from.tm=A")
+ ask, _ := url.Parse("private://0x0/create/contracts?owned.eoa=0x0&from.tm=B")
+
+ assert.False(t, match(&ContractSecurityAttribute{
+ Visibility: VisibilityPrivate,
+ Action: ActionCreate,
+ }, ask, granted))
+}
+
+func TestMatch_whenUsingWildcardAccount(t *testing.T) {
+ granted, _ := url.Parse("private://0x0/create/contracts?from.tm=dLHrFQpbSda0EhJnLonsBwDjks%2Bf724NipfI5zK5RSs%3D")
+ ask, _ := url.Parse("private://0xed9d02e382b34818e88b88a309c7fe71e65f419d/create/contracts?from.tm=dLHrFQpbSda0EhJnLonsBwDjks%2Bf724NipfI5zK5RSs%3D")
+
+ assert.True(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+
+ granted, _ = url.Parse("private://0x0/read/contract?owned.eoa=0x0&from.tm=A")
+ ask, _ = url.Parse("private://0xa1b1c1/read/contract?owned.eoa=0x1234&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{Action: ActionRead}, ask, granted))
+}
+
+func TestMatch_whenNotUsingWildcardAccount(t *testing.T) {
+ granted, _ := url.Parse("private://0xed9d02e382b34818e88b88a309c7fe71e65f419d/create/contracts?from.tm=dLHrFQpbSda0EhJnLonsBwDjks%2Bf724NipfI5zK5RSs%3D")
+ ask, _ := url.Parse("private://0xed9d02e382b34818e88b88a309c7fe71e65f419d/create/contracts?from.tm=dLHrFQpbSda0EhJnLonsBwDjks%2Bf724NipfI5zK5RSs%3D")
+
+ assert.True(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+
+ granted, _ = url.Parse("private://0x0/read/contract?owned.eoa=0x0&from.tm=A")
+ ask, _ = url.Parse("private://0xa1b1c1/read/contract?owned.eoa=0x1234&from.tm=A")
+
+ assert.True(t, match(&ContractSecurityAttribute{Action: ActionRead}, ask, granted))
+}
+
+func TestMatch_failsWhenAccountsDiffer(t *testing.T) {
+ granted, _ := url.Parse("private://0xed9d02e382b34818e88b88a309c7fe71e65f419d/create/contracts?from.tm=dLHrFQpbSda0EhJnLonsBwDjks%2Bf724NipfI5zK5RSs%3D")
+ ask, _ := url.Parse("private://0xa94f5374fce5edbc8e2a8697c15331677e6ebf0b/create/contracts?from.tm=dLHrFQpbSda0EhJnLonsBwDjks%2Bf724NipfI5zK5RSs%3D")
+
+ assert.False(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+}
+
+func TestMatch_whenPublic(t *testing.T) {
+ granted, _ := url.Parse("private://0xa1b1c1/create/contract?from.tm=A/")
+ ask, _ := url.Parse("public://0x0/create/contract")
+
+ assert.True(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+}
+
+func TestMatch_whenNotEscaped(t *testing.T) {
+ // query not escaped probably in the granted authority resource identitifer
+ granted, _ := url.Parse("private://0xed9d02e382b34818e88b88a309c7fe71e65f419d/create/contracts?from.tm=BULeR8JyUWhiuuCMU/HLA0Q5pzkYT+cHII3ZKBey3Bo=")
+ ask, _ := url.Parse("private://0xed9d02e382b34818e88b88a309c7fe71e65f419d/create/contracts?from.tm=BULeR8JyUWhiuuCMU%2FHLA0Q5pzkYT%2BcHII3ZKBey3Bo%3D")
+
+ assert.False(t, match(&ContractSecurityAttribute{Action: ActionCreate}, ask, granted))
+}
+
+func runTestCases(t *testing.T, testCases []*testCase) {
+ testObject := &DefaultContractAuthorizationProvider{}
+ for _, tc := range testCases {
+ log.Debug("--> Running test case: " + tc.msg)
+ authorities := make([]*proto.GrantedAuthority, 0)
+ for _, a := range tc.granted {
+ authorities = append(authorities, &proto.GrantedAuthority{Raw: a})
+ }
+ b, err := testObject.IsAuthorized(
+ context.Background(),
+ &proto.PreAuthenticatedAuthenticationToken{Authorities: authorities},
+ tc.ask...)
+ if !assert.NoError(t, err, tc.msg) {
+ return
+ }
+ if !assert.Equal(t, tc.isAuthorized, b, tc.msg) {
+ return
+ }
+ }
+}
+
+func TestDefaultAccountAccessDecisionManager_IsAuthorized_forPublicContracts(t *testing.T) {
+ runTestCases(t, []*testCase{
+ canCreatePublicContracts,
+ // canNotCreatePublicContracts,
+ canReadOwnedPublicContracts,
+ canReadOtherPublicContracts,
+ // canNotReadOtherPublicContracts,
+ canWriteOwnedPublicContracts,
+ canWriteOtherPublicContracts1,
+ canWriteOtherPublicContracts2,
+ // canNotWriteOtherPublicContracts,
+ canCreatePublicContractsAndWriteToOthers,
+ })
+}
+
+func TestDefaultAccountAccessDecisionManager_IsAuthorized_forPrivateContracts(t *testing.T) {
+ runTestCases(t, []*testCase{
+ canCreatePrivateContracts,
+ canNotCreatePrivateContracts,
+ canReadOwnedPrivateContracts,
+ canReadOtherPrivateContracts,
+ canNotReadOtherPrivateContracts,
+ canNotReadOtherPrivateContractsNoPrivy,
+ canWriteOwnedPrivateContracts,
+ canWriteOtherPrivateContracts,
+ canWriteOtherPrivateContractsWithOverlappedScope,
+ canNotWriteOtherPrivateContracts,
+ canNotWriteOtherPrivateContractsNoPrivy,
+ })
+}
+
+func TestDefaultAccountAccessDecisionManager_IsAuthorized_forPrivateContracts_wildcards_whenCreate(t *testing.T) {
+ fullAccessToX := []string{
+ "private://0x0/_/contracts?owned.eoa=0x0&from.tm=X",
+ }
+ runTestCases(t, []*testCase{
+ {
+ msg: "X has full access to a private contract when create",
+ isAuthorized: true,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ // create
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionCreate,
+ PrivateFrom: "X",
+ Parties: []string{},
+ },
+ },
+ },
+ {
+ msg: "X can't creat private contract with other TM key",
+ isAuthorized: false,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ // create
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionCreate,
+ PrivateFrom: "A",
+ Parties: []string{},
+ },
+ },
+ },
+ })
+}
+
+func TestDefaultAccountAccessDecisionManager_IsAuthorized_forPrivateContracts_wildcards_whenRead(t *testing.T) {
+ fullAccessToX := []string{
+ "private://0x0/_/contracts?owned.eoa=0x0&from.tm=X",
+ }
+ runTestCases(t, []*testCase{
+ {
+ msg: "X has full access to a private contract when read as one of the participants",
+ isAuthorized: true,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{},
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ PrivateFrom: "X",
+ Parties: []string{"X", "Y"},
+ },
+ },
+ },
+ {
+ msg: "X has full access to a private contract when read as a single participant",
+ isAuthorized: true,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{},
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ PrivateFrom: "X",
+ Parties: []string{"X"},
+ },
+ },
+ },
+ {
+ msg: "X can't read other private contracts",
+ isAuthorized: false,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{},
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ PrivateFrom: "X",
+ Parties: []string{"A", "B"},
+ },
+ },
+ },
+ {
+ msg: "X can't read other private contracts by faking the read",
+ isAuthorized: false,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{},
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ PrivateFrom: "A",
+ Parties: []string{"A", "B"},
+ },
+ },
+ },
+ {
+ msg: "X can't read other private contracts when proxy-read",
+ isAuthorized: false,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ // read its own contract
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{},
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ PrivateFrom: "X",
+ Parties: []string{"X"},
+ },
+ // but using it as proxy to read other contract
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{},
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ PrivateFrom: "X",
+ Parties: []string{"A", "B"},
+ },
+ },
+ },
+ })
+}
+
+func TestDefaultAccountAccessDecisionManager_IsAuthorized_forPrivateContracts_wildcards_whenWrite(t *testing.T) {
+ fullAccessToX := []string{
+ "private://0x0/_/contracts?owned.eoa=0x0&from.tm=X",
+ }
+ runTestCases(t, []*testCase{
+ {
+ msg: "X has full access to a private contract when write as a single participant",
+ isAuthorized: true,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ PrivateFrom: "X",
+ Parties: []string{"X"},
+ },
+ },
+ },
+ {
+ msg: "X has full access to a private contract when write as one of the participants",
+ isAuthorized: true,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ PrivateFrom: "X",
+ Parties: []string{"X", "Y"},
+ },
+ },
+ },
+ {
+ msg: "X must not access other private contracts when faking write",
+ isAuthorized: false,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"), // creator EOA address
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ PrivateFrom: "A",
+ Parties: []string{"A", "B"},
+ },
+ },
+ },
+ {
+ msg: "X can not write to a private contract not privy to X",
+ isAuthorized: false,
+ granted: fullAccessToX,
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"), // creator EOA address
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ PrivateFrom: "X",
+ Parties: []string{"A", "B"},
+ },
+ },
+ },
+ })
+}
+
+var (
+ canCreatePublicContracts = &testCase{
+ msg: "0x0a1a1a1 can create public contracts",
+ granted: []string{
+ "public://0x0000000000000000000000000000000000a1a1a1/create/contracts",
+ },
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPublic,
+ Action: ActionCreate,
+ },
+ },
+ isAuthorized: true,
+ }
+ canCreatePublicContractsAndWriteToOthers = &testCase{
+ msg: "0x0a1a1a1 can create public contracts and write to contracts created by 0xb1b1b1",
+ granted: []string{
+ "public://0x0000000000000000000000000000000000a1a1a1/create/contracts",
+ "public://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&owned.eoa=0x0000000000000000000000000000000000c1c1c1",
+ },
+ ask: []*ContractSecurityAttribute{
+ {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPublic,
+ Action: ActionCreate,
+ }, {
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPublic,
+ Action: ActionWrite,
+ },
+ },
+ isAuthorized: true,
+ }
+ //
+ //canNotCreatePublicContracts = &testCase{
+ // msg: "0xb1b1b1 can not create public contracts",
+ // granted: []string{
+ // "public://0x0000000000000000000000000000000000a1a1a1/create/contracts",
+ // "public://0x0000000000000000000000000000000000b1b1b1/read/contracts?owned.eoa=0x0000000000000000000000000000000000a1a1a1",
+ // },
+ // ask: []*ContractSecurityAttribute{{
+ // AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ // From: common.HexToAddress("0xb1b1b1"),
+ // },
+ // Visibility: VisibilityPublic,
+ // Action: ActionCreate,
+ // }},
+ // isAuthorized: false,
+ //}
+ canReadOwnedPublicContracts = &testCase{
+ msg: "0x0a1a1a1 can read public contracts created by self",
+ granted: []string{
+ "public://0x0000000000000000000000000000000000a1a1a1/read/contracts?owned.eoa=0x0000000000000000000000000000000000a1a1a1",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPublic,
+ Action: ActionRead,
+ }},
+ isAuthorized: true,
+ }
+ canReadOtherPublicContracts = &testCase{
+ msg: "0x0a1a1a1 can read public contracts created by 0xb1b1b1",
+ granted: []string{
+ "public://0x0000000000000000000000000000000000a1a1a1/read/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPublic,
+ Action: ActionRead,
+ }},
+ isAuthorized: true,
+ }
+ //canNotReadOtherPublicContracts = &testCase{
+ // msg: "0x0a1a1a1 can only read public contracts created by self",
+ // granted: []string{
+ // "public://0x0000000000000000000000000000000000a1a1a1/read/contracts?owned.eoa=0x0000000000000000000000000000000000a1a1a1",
+ // },
+ // ask: []*ContractSecurityAttribute{{
+ // AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ // From: common.HexToAddress("0xa1a1a1"),
+ // To: common.HexToAddress("0xb1b1b1"),
+ // },
+ // Visibility: VisibilityPublic,
+ // Action: ActionRead,
+ // }},
+ // isAuthorized: false,
+ //}
+ canWriteOwnedPublicContracts = &testCase{
+ msg: "0x0a1a1a1 can send transactions to public contracts created by self",
+ granted: []string{
+ "public://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000a1a1a1",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPublic,
+ Action: ActionWrite,
+ }},
+ isAuthorized: true,
+ }
+ canWriteOtherPublicContracts1 = &testCase{
+ msg: "0xa1a1a1 can send transactions to public contracts created by 0xb1b1b1",
+ granted: []string{
+ "public://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&owned.eoa=0x0000000000000000000000000000000000c1c1c1",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPublic,
+ Action: ActionWrite,
+ }},
+ isAuthorized: true,
+ }
+ canWriteOtherPublicContracts2 = &testCase{
+ msg: "0xa1a1a1 can send transactions to public contracts created by 0xb1b1b1",
+ granted: []string{
+ "public://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&owned.eoa=0x0000000000000000000000000000000000c1c1c1",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xc1c1c1"),
+ },
+ Visibility: VisibilityPublic,
+ Action: ActionWrite,
+ }},
+ isAuthorized: true,
+ }
+ //canNotWriteOtherPublicContracts = &testCase{
+ // msg: "0x0a1a1a1 can only send transactions to public contracts created by self",
+ // granted: []string{
+ // "public://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000a1a1a1",
+ // "public://0x0000000000000000000000000000000000a1a1a1/read/contracts?owned.eoa=0x0000000000000000000000000000000000a1a1a1",
+ // },
+ // ask: []*ContractSecurityAttribute{{
+ // AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ // From: common.HexToAddress("0xa1a1a1"),
+ // To: common.HexToAddress("0xb1b1b1"),
+ // },
+ // Visibility: VisibilityPublic,
+ // Action: ActionWrite,
+ // }},
+ // isAuthorized: false,
+ //}
+ // private contracts
+ canCreatePrivateContracts = &testCase{
+ msg: "0x0a1a1a1 can create private contracts with sender key A",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/create/contracts?from.tm=A",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionCreate,
+ PrivateFrom: "A",
+ Parties: []string{},
+ }},
+ isAuthorized: true,
+ }
+ canNotCreatePrivateContracts = &testCase{
+ msg: "0x0a1a1a1 can NOT create private contracts with sender key A if only own key B",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/create/contracts?from.tm=B",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionCreate,
+ PrivateFrom: "A",
+ Parties: []string{},
+ }},
+ isAuthorized: false,
+ }
+ canReadOwnedPrivateContracts = &testCase{
+ msg: "0x0a1a1a1 can read private contracts created by self and was privy to a key A",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/read/contracts?owned.eoa=0x0000000000000000000000000000000000a1a1a1&from.tm=A&from.tm=B",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ Parties: []string{"A"},
+ }},
+ isAuthorized: true,
+ }
+ canReadOtherPrivateContracts = &testCase{
+ msg: "0x0a1a1a1 can read private contracts created by 0xb1b1b1 and was privy to a key A",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/read/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&from.tm=A",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ Parties: []string{"A"},
+ }},
+ isAuthorized: true,
+ }
+ canNotReadOtherPrivateContracts = &testCase{
+ msg: "0x0a1a1a1 can NOT read private contracts created by 0xb1b1b1 even it was privy to a key A",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/read/contracts?owned.eoa=0x0000000000000000000000000000000000c1c1c1&from.tm=A",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ Parties: []string{"A"},
+ }},
+ isAuthorized: false,
+ }
+ canNotReadOtherPrivateContractsNoPrivy = &testCase{
+ msg: "0x0a1a1a1 can NOT read private contracts created by 0xb1b1b1 as it was privy to a key B",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/read/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&from.tm=B",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionRead,
+ Parties: []string{"A"},
+ }},
+ isAuthorized: false,
+ }
+ canWriteOwnedPrivateContracts = &testCase{
+ msg: "0x0a1a1a1 can write private contracts created by self and was privy to a key A",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000a1a1a1&from.tm=A&from.tm=B",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ PrivateFrom: "A",
+ Parties: []string{"A"},
+ }},
+ isAuthorized: true,
+ }
+ canWriteOtherPrivateContracts = &testCase{
+ msg: "0x0a1a1a1 can write private contracts created by 0xb1b1b1 and was privy to a key A",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&from.tm=A",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ PrivateFrom: "A",
+ Parties: []string{"A"},
+ }},
+ isAuthorized: true,
+ }
+ canWriteOtherPrivateContractsWithOverlappedScope = &testCase{
+ msg: "0x0a1a1a1 can write private contracts created by 0xb1b1b1 and was privy to a key A",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&from.tm=A",
+ "private://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&from.tm=A&from.tm=B",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ PrivateFrom: "A",
+ Parties: []string{"A"},
+ }},
+ isAuthorized: true,
+ }
+ canNotWriteOtherPrivateContracts = &testCase{
+ msg: "0x0a1a1a1 can NOT write private contracts created by 0xb1b1b1 even it was privy to a key A",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000c1c1c1&from.tm=A",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ Parties: []string{"A"},
+ }},
+ isAuthorized: false,
+ }
+ canNotWriteOtherPrivateContractsNoPrivy = &testCase{
+ msg: "0x0a1a1a1 can NOT write private contracts created by 0xb1b1b1 as it was privy to a key B",
+ granted: []string{
+ "private://0x0000000000000000000000000000000000a1a1a1/write/contracts?owned.eoa=0x0000000000000000000000000000000000b1b1b1&from.tm=B",
+ },
+ ask: []*ContractSecurityAttribute{{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{
+ From: common.HexToAddress("0xa1a1a1"),
+ To: common.HexToAddress("0xb1b1b1"),
+ },
+ Visibility: VisibilityPrivate,
+ Action: ActionWrite,
+ Parties: []string{"A"},
+ }},
+ isAuthorized: false,
+ }
+)
diff --git a/multitenancy/types.go b/multitenancy/types.go
new file mode 100644
index 0000000000..4ee3f625a0
--- /dev/null
+++ b/multitenancy/types.go
@@ -0,0 +1,200 @@
+package multitenancy
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
+)
+
+type ContractVisibility string
+type ContractAction string
+
+const (
+ VisibilityPublic ContractVisibility = "public"
+ VisibilityPrivate ContractVisibility = "private"
+ ActionRead ContractAction = "read"
+ ActionWrite ContractAction = "write"
+ ActionCreate ContractAction = "create"
+
+ // QueryOwnedEOA query parameter is to capture the EOA address
+ // For value transfer, it represents the account owner
+ // For message call, it represents the EOA that signed the contract creation transaction
+ // in other words, the EOA that owns the contract
+ QueryOwnedEOA = "owned.eoa"
+ // QueryToEOA query parameter is to capture the EOA address which is the
+ // target account in value transfer scenarios
+ QueryToEOA = "to.eoa"
+ // QueryFromTM query parameter is to capture the Tessera Public Key
+ // which indicates the sender of a private transaction or participant of a private contract
+ QueryFromTM = "from.tm"
+
+ // AnyEOAAddress represents wild card for EOA address
+ AnyEOAAddress = "0x0"
+)
+
+// Multitenancy support
+type ContextAware interface {
+ SupportsMultitenancy(ctx context.Context) (*proto.PreAuthenticatedAuthenticationToken, bool)
+}
+
+// AuthorizationProvider specifies APIs to be implemented to provide multitenancy capability
+type AuthorizationProvider interface {
+ ContextAware
+ ContractAuthorizationProvider
+}
+
+// AccountStateSecurityAttribute contains security configuration ask
+// which are defined for a secure account state
+type AccountStateSecurityAttribute struct {
+ From common.Address // Ethereum Account Address
+ To common.Address
+}
+
+func (assa *AccountStateSecurityAttribute) String() string {
+ return fmt.Sprintf("from=%s to=%s", assa.From.Hex(), assa.To.Hex())
+}
+
+// ContractSecurityAttribute contains security configuration ask
+// which are defined for a secure contract account
+type ContractSecurityAttribute struct {
+ *AccountStateSecurityAttribute
+ Visibility ContractVisibility // public/private
+ Action ContractAction // create/read/write
+ PrivateFrom string // TM Key, only if Visibility is private, for write/create
+ Parties []string // TM Keys, only if Visibility is private, for read
+}
+
+func (csa *ContractSecurityAttribute) String() string {
+ return fmt.Sprintf("%v visibility=%s action=%s privateFrom=%s parties=%v", csa.AccountStateSecurityAttribute, csa.Visibility, csa.Action, csa.PrivateFrom, csa.Parties)
+}
+
+type ContractSecurityAttributeBuilder struct {
+ secAttr ContractSecurityAttribute
+}
+
+func NewContractSecurityAttributeBuilder() *ContractSecurityAttributeBuilder {
+ return &ContractSecurityAttributeBuilder{
+ secAttr: ContractSecurityAttribute{
+ AccountStateSecurityAttribute: &AccountStateSecurityAttribute{},
+ Parties: make([]string, 0),
+ },
+ }
+}
+
+func (csab *ContractSecurityAttributeBuilder) FromEOA(eoa common.Address) *ContractSecurityAttributeBuilder {
+ csab.secAttr.AccountStateSecurityAttribute.From = eoa
+ return csab
+}
+
+// ethereum account destination
+func (csab *ContractSecurityAttributeBuilder) ToEOA(eoa common.Address) *ContractSecurityAttributeBuilder {
+ csab.secAttr.AccountStateSecurityAttribute.To = eoa
+ return csab
+}
+
+func (csab *ContractSecurityAttributeBuilder) PrivateFrom(tmPubKey string) *ContractSecurityAttributeBuilder {
+ csab.secAttr.PrivateFrom = tmPubKey
+ return csab
+}
+
+// set privateFrom only if b is true, ignore otherwise
+func (csab *ContractSecurityAttributeBuilder) PrivateFromOnlyIf(b bool, tmPubKey string) *ContractSecurityAttributeBuilder {
+ if b {
+ csab.secAttr.PrivateFrom = tmPubKey
+ }
+ return csab
+}
+
+func (csab *ContractSecurityAttributeBuilder) Visibility(v ContractVisibility) *ContractSecurityAttributeBuilder {
+ csab.secAttr.Visibility = v
+ return csab
+}
+
+func (csab *ContractSecurityAttributeBuilder) Private() *ContractSecurityAttributeBuilder {
+ return csab.Visibility(VisibilityPrivate)
+}
+
+// set VisibilityPrivate if b is true, VisibilityPublic otherwise
+func (csab *ContractSecurityAttributeBuilder) PrivateIf(b bool) *ContractSecurityAttributeBuilder {
+ if b {
+ return csab.Visibility(VisibilityPrivate)
+ } else {
+ return csab.Visibility(VisibilityPublic)
+ }
+}
+
+func (csab *ContractSecurityAttributeBuilder) Public() *ContractSecurityAttributeBuilder {
+ return csab.Visibility(VisibilityPublic)
+}
+
+func (csab *ContractSecurityAttributeBuilder) Action(a ContractAction) *ContractSecurityAttributeBuilder {
+ csab.secAttr.Action = a
+ return csab
+}
+
+func (csab *ContractSecurityAttributeBuilder) Create() *ContractSecurityAttributeBuilder {
+ return csab.Action(ActionCreate)
+}
+
+func (csab *ContractSecurityAttributeBuilder) Read() *ContractSecurityAttributeBuilder {
+ return csab.Action(ActionRead)
+}
+
+func (csab *ContractSecurityAttributeBuilder) Write() *ContractSecurityAttributeBuilder {
+ return csab.Action(ActionWrite)
+}
+
+// set ActionRead only if b is true, ignore otherwise
+func (csab *ContractSecurityAttributeBuilder) ReadOnlyIf(b bool) *ContractSecurityAttributeBuilder {
+ if b {
+ return csab.Action(ActionRead)
+ } else {
+ return csab
+ }
+}
+
+// set ActionWrite only if b is true, ignore otherwise
+func (csab *ContractSecurityAttributeBuilder) WriteOnlyIf(b bool) *ContractSecurityAttributeBuilder {
+ if b {
+ return csab.Action(ActionWrite)
+ } else {
+ return csab
+ }
+}
+
+// set Parties only if b is true, ignore otherwise
+func (csab *ContractSecurityAttributeBuilder) PartiesOnlyIf(b bool, tmPubKeys []string) *ContractSecurityAttributeBuilder {
+ if b {
+ return csab.Parties(tmPubKeys)
+ }
+ return csab
+}
+
+func (csab *ContractSecurityAttributeBuilder) Parties(tmPubKeys []string) *ContractSecurityAttributeBuilder {
+ parties := make([]string, len(tmPubKeys))
+ copy(parties, tmPubKeys)
+ csab.secAttr.Parties = parties
+ return csab
+}
+
+func (csab *ContractSecurityAttributeBuilder) Party(tmPubKey string) *ContractSecurityAttributeBuilder {
+ csab.secAttr.Parties = append(csab.secAttr.Parties, tmPubKey)
+ return csab
+}
+
+func (csab *ContractSecurityAttributeBuilder) Build() *ContractSecurityAttribute {
+ return &csab.secAttr
+}
+
+// FullAccessContractSecurityAttributes returns a list of contract security attributes.
+// The attributes are used to verify ownership of a TM key which is going to be used
+// to send a private transaction.
+func FullAccessContractSecurityAttributes(fromEOA common.Address, privateFrom string) []*ContractSecurityAttribute {
+ return []*ContractSecurityAttribute{
+ NewContractSecurityAttributeBuilder().FromEOA(fromEOA).Private().Create().PrivateFrom(privateFrom).Build(),
+ NewContractSecurityAttributeBuilder().FromEOA(fromEOA).Private().Write().Party(privateFrom).Build(),
+ NewContractSecurityAttributeBuilder().FromEOA(fromEOA).Private().Read().Party(privateFrom).Build(),
+ }
+}
diff --git a/plugin/security/service.go b/plugin/security/service.go
index d508bf15d6..30cce8b06d 100644
--- a/plugin/security/service.go
+++ b/plugin/security/service.go
@@ -17,10 +17,10 @@ type AuthenticationManager interface {
IsEnabled(ctx context.Context) (bool, error)
}
-type AuthentiationManagerDeferFunc func() (AuthenticationManager, error)
+type AuthenticationManagerDeferFunc func() (AuthenticationManager, error)
type DeferredAuthenticationManager struct {
- deferFunc AuthentiationManagerDeferFunc
+ deferFunc AuthenticationManagerDeferFunc
}
func (d *DeferredAuthenticationManager) Authenticate(ctx context.Context, token string) (*proto.PreAuthenticatedAuthenticationToken, error) {
@@ -39,7 +39,7 @@ func (d *DeferredAuthenticationManager) IsEnabled(ctx context.Context) (bool, er
return am.IsEnabled(ctx)
}
-func NewDeferredAuthenticationManager(deferFunc AuthentiationManagerDeferFunc) *DeferredAuthenticationManager {
+func NewDeferredAuthenticationManager(deferFunc AuthenticationManagerDeferFunc) *DeferredAuthenticationManager {
return &DeferredAuthenticationManager{
deferFunc: deferFunc,
}
diff --git a/private/engine/common.go b/private/engine/common.go
index 62dc4d7155..c08347f4b5 100644
--- a/private/engine/common.go
+++ b/private/engine/common.go
@@ -21,8 +21,13 @@ type ExtraMetadata struct {
ACHashes common.EncryptedPayloadHashes
// Root Hash of a Merkle Trie containing all affected contract account in state objects
ACMerkleRoot common.Hash
- //Privacy flag for contract: standardPrivate, partyProtection, psv
+ // Privacy flag for contract: standardPrivate, partyProtection, psv
PrivacyFlag PrivacyFlagType
+ // Contract participants that are managed by the corresponding Tessera.
+ // Being used in Multi Tenancy
+ ManagedParties []string
+ // the sender of the transaction
+ Sender string
}
type Client struct {
@@ -82,6 +87,7 @@ type PrivateTransactionManagerFeature uint64
const (
None PrivateTransactionManagerFeature = iota // 0
PrivacyEnhancements PrivateTransactionManagerFeature = 1 << PrivateTransactionManagerFeature(iota-1) // 1
+ MultiTenancy PrivateTransactionManagerFeature = 1 << PrivateTransactionManagerFeature(iota-1) // 2
)
type FeatureSet struct {
diff --git a/private/engine/constellation/constellation.go b/private/engine/constellation/constellation.go
index ce34471245..ef6dcd53ba 100644
--- a/private/engine/constellation/constellation.go
+++ b/private/engine/constellation/constellation.go
@@ -30,20 +30,20 @@ func New(client *engine.Client) *constellation {
}
}
-func (g *constellation) Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (common.EncryptedPayloadHash, error) {
+func (g *constellation) Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (string, []string, common.EncryptedPayloadHash, error) {
if extra.PrivacyFlag.IsNotStandardPrivate() {
- return common.EncryptedPayloadHash{}, engine.ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements
+ return "", nil, common.EncryptedPayloadHash{}, engine.ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements
}
out, err := g.node.SendPayload(data, from, to, extra.ACHashes, extra.ACMerkleRoot)
if err != nil {
- return common.EncryptedPayloadHash{}, err
+ return "", nil, common.EncryptedPayloadHash{}, err
}
cacheKey := string(out.Bytes())
g.c.Set(cacheKey, cache.PrivateCacheItem{
Payload: data,
Extra: *extra,
}, cache.DefaultExpiration)
- return out, nil
+ return "", nil, out, nil
}
func (g *constellation) EncryptPayload(data []byte, from string, to []string, extra *engine.ExtraMetadata) ([]byte, error) {
@@ -58,12 +58,12 @@ func (g *constellation) StoreRaw(data []byte, from string) (common.EncryptedPayl
return common.EncryptedPayloadHash{}, engine.ErrPrivateTxManagerNotSupported
}
-func (g *constellation) SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) (out []byte, err error) {
- return nil, engine.ErrPrivateTxManagerNotSupported
+func (g *constellation) SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) (string, []string, []byte, error) {
+ return "", nil, nil, engine.ErrPrivateTxManagerNotSupported
}
-func (g *constellation) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
- return nil, nil, engine.ErrPrivateTxManagerNotSupported
+func (g *constellation) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, string, *engine.ExtraMetadata, error) {
+ return nil, "", nil, engine.ErrPrivateTxManagerNotSupported
}
func (g *constellation) IsSender(txHash common.EncryptedPayloadHash) (bool, error) {
@@ -74,9 +74,9 @@ func (g *constellation) GetParticipants(txHash common.EncryptedPayloadHash) ([]s
return nil, engine.ErrPrivateTxManagerNotSupported
}
-func (g *constellation) Receive(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
+func (g *constellation) Receive(data common.EncryptedPayloadHash) (string, []string, []byte, *engine.ExtraMetadata, error) {
if common.EmptyEncryptedPayloadHash(data) {
- return nil, nil, nil
+ return "", nil, nil, nil, nil
}
// Ignore this error since not being a recipient of
// a payload isn't an error.
@@ -87,13 +87,13 @@ func (g *constellation) Receive(data common.EncryptedPayloadHash) ([]byte, *engi
if found {
cacheItem, ok := x.(cache.PrivateCacheItem)
if !ok {
- return nil, nil, fmt.Errorf("unknown cache item. expected type PrivateCacheItem")
+ return "", nil, nil, nil, fmt.Errorf("unknown cache item. expected type PrivateCacheItem")
}
- return cacheItem.Payload, &cacheItem.Extra, nil
+ return "", nil, cacheItem.Payload, &cacheItem.Extra, nil
}
privatePayload, acHashes, acMerkleRoot, err := g.node.ReceivePayload(data)
if nil != err {
- return nil, nil, err
+ return "", nil, nil, nil, err
}
extra := engine.ExtraMetadata{
ACHashes: acHashes,
@@ -103,7 +103,7 @@ func (g *constellation) Receive(data common.EncryptedPayloadHash) ([]byte, *engi
Payload: privatePayload,
Extra: extra,
}, cache.DefaultExpiration)
- return privatePayload, &extra, nil
+ return "", nil, privatePayload, &extra, nil
}
func (g *constellation) Name() string {
diff --git a/private/engine/notinuse/notInUsePrivateTxManager.go b/private/engine/notinuse/notInUsePrivateTxManager.go
index 3ed1a9149b..22c7190445 100644
--- a/private/engine/notinuse/notInUsePrivateTxManager.go
+++ b/private/engine/notinuse/notInUsePrivateTxManager.go
@@ -21,8 +21,8 @@ func (ptm *PrivateTransactionManager) GetParticipants(txHash common.EncryptedPay
panic("implement me")
}
-func (ptm *PrivateTransactionManager) Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (common.EncryptedPayloadHash, error) {
- return common.EncryptedPayloadHash{}, engine.ErrPrivateTxManagerNotinUse
+func (ptm *PrivateTransactionManager) Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (string, []string, common.EncryptedPayloadHash, error) {
+ return "", nil, common.EncryptedPayloadHash{}, engine.ErrPrivateTxManagerNotinUse
}
func (ptm *PrivateTransactionManager) EncryptPayload(data []byte, from string, to []string, extra *engine.ExtraMetadata) ([]byte, error) {
@@ -37,17 +37,17 @@ func (ptm *PrivateTransactionManager) StoreRaw(data []byte, from string) (common
return common.EncryptedPayloadHash{}, engine.ErrPrivateTxManagerNotinUse
}
-func (ptm *PrivateTransactionManager) SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) ([]byte, error) {
- return nil, engine.ErrPrivateTxManagerNotinUse
+func (ptm *PrivateTransactionManager) SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) (string, []string, []byte, error) {
+ return "", nil, nil, engine.ErrPrivateTxManagerNotinUse
}
-func (ptm *PrivateTransactionManager) Receive(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
+func (ptm *PrivateTransactionManager) Receive(data common.EncryptedPayloadHash) (string, []string, []byte, *engine.ExtraMetadata, error) {
//error not thrown here, acts as though no private data to fetch
- return nil, nil, nil
+ return "", nil, nil, nil, nil
}
-func (ptm *PrivateTransactionManager) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
- return nil, nil, engine.ErrPrivateTxManagerNotinUse
+func (ptm *PrivateTransactionManager) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, string, *engine.ExtraMetadata, error) {
+ return nil, "", nil, engine.ErrPrivateTxManagerNotinUse
}
func (ptm *PrivateTransactionManager) Name() string {
diff --git a/private/engine/notinuse/notInUsePrivateTxManager_test.go b/private/engine/notinuse/notInUsePrivateTxManager_test.go
index f5215e8a87..e797177d20 100644
--- a/private/engine/notinuse/notInUsePrivateTxManager_test.go
+++ b/private/engine/notinuse/notInUsePrivateTxManager_test.go
@@ -18,7 +18,7 @@ func TestName(t *testing.T) {
func TestSendReturnsError(t *testing.T) {
ptm := &PrivateTransactionManager{}
- _, err := ptm.Send([]byte{}, "", []string{}, nil)
+ _, _, _, err := ptm.Send([]byte{}, "", []string{}, nil)
assert.Equal(t, err, engine.ErrPrivateTxManagerNotinUse, "got wrong error in 'send'")
}
@@ -34,7 +34,7 @@ func TestStoreRawReturnsError(t *testing.T) {
func TestReceiveReturnsEmpty(t *testing.T) {
ptm := &PrivateTransactionManager{}
- data, metadata, err := ptm.Receive(common.EncryptedPayloadHash{})
+ _, _, data, metadata, err := ptm.Receive(common.EncryptedPayloadHash{})
assert.Nil(t, err, "got unexpected error in 'receive'")
assert.Nil(t, data, "got unexpected data in 'receive'")
@@ -44,7 +44,7 @@ func TestReceiveReturnsEmpty(t *testing.T) {
func TestReceiveRawReturnsError(t *testing.T) {
ptm := &PrivateTransactionManager{}
- _, _, err := ptm.ReceiveRaw(common.EncryptedPayloadHash{})
+ _, _, _, err := ptm.ReceiveRaw(common.EncryptedPayloadHash{})
assert.Equal(t, err, engine.ErrPrivateTxManagerNotinUse, "got wrong error in 'send'")
}
@@ -52,7 +52,7 @@ func TestReceiveRawReturnsError(t *testing.T) {
func TestSendSignedTxReturnsError(t *testing.T) {
ptm := &PrivateTransactionManager{}
- _, err := ptm.SendSignedTx(common.EncryptedPayloadHash{}, []string{}, nil)
+ _, _, _, err := ptm.SendSignedTx(common.EncryptedPayloadHash{}, []string{}, nil)
assert.Equal(t, err, engine.ErrPrivateTxManagerNotinUse, "got wrong error in 'SendSignedTx'")
}
diff --git a/private/engine/tessera/model.go b/private/engine/tessera/model.go
index 031e2b4545..a37b40d4b8 100644
--- a/private/engine/tessera/model.go
+++ b/private/engine/tessera/model.go
@@ -32,6 +32,10 @@ type storerawRequest struct {
type sendResponse struct {
// Base64-encoded
Key string `json:"key"`
+ // Public Keys
+ ManagedParties []string `json:"managedParties"`
+ // Sender tessera public key
+ SenderKey string `json:"senderKey"`
}
type receiveResponse struct {
@@ -44,6 +48,11 @@ type receiveResponse struct {
ExecHash string `json:"execHash"`
PrivacyFlag engine.PrivacyFlagType `json:"privacyFlag"`
+
+ // Public Keys
+ ManagedParties []string `json:"managedParties"`
+ // Sender tessera public key
+ SenderKey string `json:"senderKey"`
}
type sendSignedTxRequest struct {
@@ -60,6 +69,10 @@ type sendSignedTxRequest struct {
type sendSignedTxResponse struct {
// Base64-encoded
Key string `json:"key"`
+ // Public Keys
+ ManagedParties []string `json:"managedParties"`
+ // Sender tessera public key
+ SenderKey string `json:"senderKey"`
}
type encryptPayloadResponse struct {
diff --git a/private/engine/tessera/tessera.go b/private/engine/tessera/tessera.go
index 0dac27e47d..f7e4c8b8d1 100644
--- a/private/engine/tessera/tessera.go
+++ b/private/engine/tessera/tessera.go
@@ -43,7 +43,32 @@ func New(client *engine.Client, version []byte) *tesseraPrivateTxManager {
}
func (t *tesseraPrivateTxManager) submitJSON(method, path string, request interface{}, response interface{}) (int, error) {
- req, err := newOptionalJSONRequest(method, t.client.FullPath(path), request)
+ apiVersion := ""
+ if t.features.HasFeature(engine.MultiTenancy) {
+ apiVersion = "vnd.tessera-2.1+"
+ }
+ req, err := newOptionalJSONRequest(method, t.client.FullPath(path), request, apiVersion)
+ if err != nil {
+ return -1, fmt.Errorf("unable to build json request for (method:%s,path:%s). Cause: %v", method, path, err)
+ }
+ res, err := t.client.HttpClient.Do(req)
+ if err != nil {
+ return -1, fmt.Errorf("unable to submit request (method:%s,path:%s). Cause: %v", method, path, err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusCreated {
+ body, _ := ioutil.ReadAll(res.Body)
+ return res.StatusCode, fmt.Errorf("%d status: %s", res.StatusCode, string(body))
+ }
+ if err := json.NewDecoder(res.Body).Decode(response); err != nil {
+ return res.StatusCode, fmt.Errorf("unable to decode response body for (method:%s,path:%s). Cause: %v", method, path, err)
+ }
+ return res.StatusCode, nil
+}
+
+func (t *tesseraPrivateTxManager) submitJSONOld(method, path string, request interface{}, response interface{}) (int, error) {
+ apiVersion := ""
+ req, err := newOptionalJSONRequest(method, t.client.FullPath(path), request, apiVersion)
if err != nil {
return -1, fmt.Errorf("unable to build json request for (method:%s,path:%s). Cause: %v", method, path, err)
}
@@ -62,9 +87,9 @@ func (t *tesseraPrivateTxManager) submitJSON(method, path string, request interf
return res.StatusCode, nil
}
-func (t *tesseraPrivateTxManager) Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (common.EncryptedPayloadHash, error) {
+func (t *tesseraPrivateTxManager) Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (string, []string, common.EncryptedPayloadHash, error) {
if extra.PrivacyFlag.IsNotStandardPrivate() && !t.features.HasFeature(engine.PrivacyEnhancements) {
- return common.EncryptedPayloadHash{}, engine.ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements
+ return "", nil, common.EncryptedPayloadHash{}, engine.ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements
}
response := new(sendResponse)
acMerkleRoot := ""
@@ -79,21 +104,27 @@ func (t *tesseraPrivateTxManager) Send(data []byte, from string, to []string, ex
ExecHash: acMerkleRoot,
PrivacyFlag: extra.PrivacyFlag,
}, response); err != nil {
- return common.EncryptedPayloadHash{}, err
+ return "", nil, common.EncryptedPayloadHash{}, err
}
eph, err := common.Base64ToEncryptedPayloadHash(response.Key)
if err != nil {
- return common.EncryptedPayloadHash{}, fmt.Errorf("unable to decode encrypted payload hash: %s. Cause: %v", response.Key, err)
+ return "", nil, common.EncryptedPayloadHash{}, fmt.Errorf("unable to decode encrypted payload hash: %s. Cause: %v", response.Key, err)
}
cacheKey := eph.Hex()
t.cache.Set(cacheKey, cache.PrivateCacheItem{
Payload: data,
- Extra: *extra,
+ Extra: engine.ExtraMetadata{
+ ACHashes: extra.ACHashes,
+ ACMerkleRoot: extra.ACMerkleRoot,
+ PrivacyFlag: extra.PrivacyFlag,
+ ManagedParties: response.ManagedParties,
+ Sender: response.SenderKey,
+ },
}, gocache.DefaultExpiration)
- return eph, nil
+ return response.SenderKey, response.ManagedParties, eph, nil
}
func (t *tesseraPrivateTxManager) EncryptPayload(data []byte, from string, to []string, extra *engine.ExtraMetadata) ([]byte, error) {
@@ -146,34 +177,39 @@ func (t *tesseraPrivateTxManager) StoreRaw(data []byte, from string) (common.Enc
}
// allow new quorum to send raw transactions when connected to an old tessera
-func (c *tesseraPrivateTxManager) sendSignedPayloadOctetStream(signedPayload []byte, b64To []string) ([]byte, error) {
+func (c *tesseraPrivateTxManager) sendSignedPayloadOctetStream(signedPayload []byte, b64To []string) (string, []string, []byte, error) {
buf := bytes.NewBuffer(signedPayload)
req, err := http.NewRequest("POST", c.client.FullPath("/sendsignedtx"), buf)
if err != nil {
- return nil, err
+ return "", nil, nil, err
}
req.Header.Set("c11n-to", strings.Join(b64To, ","))
req.Header.Set("Content-Type", "application/octet-stream")
res, err := c.client.HttpClient.Do(req)
+ if err != nil {
+ return "", nil, nil, err
+ }
+ defer res.Body.Close()
- if res != nil {
- defer res.Body.Close()
+ if res.StatusCode != 200 {
+ return "", nil, nil, fmt.Errorf("Non-200 status code: %+v", res)
}
+ data, err := ioutil.ReadAll(res.Body)
if err != nil {
- return nil, err
+ return "", nil, nil, err
}
- if res.StatusCode != 200 {
- return nil, fmt.Errorf("Non-200 status code: %+v", res)
+ sender := ""
+ if len(res.Header["Tesserasender"]) > 0 {
+ sender = res.Header["Tesserasender"][0]
}
-
- return ioutil.ReadAll(res.Body)
+ return sender, res.Header["Tesseramanagedparties"], data, nil
}
// also populate cache item with additional extra metadata
-func (t *tesseraPrivateTxManager) SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) ([]byte, error) {
+func (t *tesseraPrivateTxManager) SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) (string, []string, []byte, error) {
if extra.PrivacyFlag.IsNotStandardPrivate() && !t.features.HasFeature(engine.PrivacyEnhancements) {
- return nil, engine.ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements
+ return "", nil, nil, engine.ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements
}
response := new(sendSignedTxResponse)
acMerkleRoot := ""
@@ -190,19 +226,21 @@ func (t *tesseraPrivateTxManager) SendSignedTx(data common.EncryptedPayloadHash,
ExecHash: acMerkleRoot,
PrivacyFlag: extra.PrivacyFlag,
}, response); err != nil {
- return nil, err
+ return "", nil, nil, err
}
} else {
- returnedHash, err := t.sendSignedPayloadOctetStream(data.Bytes(), to)
+ sender, managedParties, returnedHash, err := t.sendSignedPayloadOctetStream(data.Bytes(), to)
if err != nil {
- return nil, err
+ return "", nil, nil, err
}
response.Key = string(returnedHash)
+ response.ManagedParties = managedParties
+ response.SenderKey = sender
}
hashBytes, err := base64.StdEncoding.DecodeString(response.Key)
if err != nil {
- return nil, err
+ return "", nil, nil, err
}
// pull incomplete cache item and inject new cache item with complete information
cacheKey := data.Hex()
@@ -211,28 +249,35 @@ func (t *tesseraPrivateTxManager) SendSignedTx(data common.EncryptedPayloadHash,
if incompleteCacheItem, ok := item.(cache.PrivateCacheItem); ok {
t.cache.Set(cacheKey, cache.PrivateCacheItem{
Payload: incompleteCacheItem.Payload,
- Extra: *extra,
+ Extra: engine.ExtraMetadata{
+ ACHashes: extra.ACHashes,
+ ACMerkleRoot: extra.ACMerkleRoot,
+ PrivacyFlag: extra.PrivacyFlag,
+ ManagedParties: response.ManagedParties,
+ Sender: response.SenderKey,
+ },
}, gocache.DefaultExpiration)
t.cache.Delete(cacheKeyTemp)
}
}
- return hashBytes, err
+ return response.SenderKey, response.ManagedParties, hashBytes, err
}
-func (t *tesseraPrivateTxManager) Receive(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
- return t.receive(data, false)
+func (t *tesseraPrivateTxManager) Receive(hash common.EncryptedPayloadHash) (string, []string, []byte, *engine.ExtraMetadata, error) {
+ return t.receive(hash, false)
}
// retrieve raw will not return information about medata.
// Related to SendSignedTx
-func (t *tesseraPrivateTxManager) ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error) {
- return t.receive(data, true)
+func (t *tesseraPrivateTxManager) ReceiveRaw(hash common.EncryptedPayloadHash) ([]byte, string, *engine.ExtraMetadata, error) {
+ sender, _, data, extra, err := t.receive(hash, true)
+ return data, sender, extra, err
}
// retrieve raw will not return information about medata
-func (t *tesseraPrivateTxManager) receive(data common.EncryptedPayloadHash, isRaw bool) ([]byte, *engine.ExtraMetadata, error) {
+func (t *tesseraPrivateTxManager) receive(data common.EncryptedPayloadHash, isRaw bool) (string, []string, []byte, *engine.ExtraMetadata, error) {
if common.EmptyEncryptedPayloadHash(data) {
- return nil, nil, nil
+ return "", nil, nil, nil, nil
}
cacheKey := data.Hex()
if isRaw {
@@ -242,33 +287,35 @@ func (t *tesseraPrivateTxManager) receive(data common.EncryptedPayloadHash, isRa
if item, found := t.cache.Get(cacheKey); found {
cacheItem, ok := item.(cache.PrivateCacheItem)
if !ok {
- return nil, nil, fmt.Errorf("unknown cache item. expected type PrivateCacheItem")
+ return "", nil, nil, nil, fmt.Errorf("unknown cache item. expected type PrivateCacheItem")
}
- return cacheItem.Payload, &cacheItem.Extra, nil
+ return cacheItem.Extra.Sender, cacheItem.Extra.ManagedParties, cacheItem.Payload, &cacheItem.Extra, nil
}
response := new(receiveResponse)
if statusCode, err := t.submitJSON("GET", fmt.Sprintf("/transaction/%s?isRaw=%v", url.PathEscape(data.ToBase64()), isRaw), nil, response); err != nil {
if statusCode == http.StatusNotFound {
- return nil, nil, nil
+ return "", nil, nil, nil, nil
} else {
- return nil, nil, err
+ return "", nil, nil, nil, err
}
}
var extra engine.ExtraMetadata
if !isRaw {
acHashes, err := common.Base64sToEncryptedPayloadHashes(response.AffectedContractTransactions)
if err != nil {
- return nil, nil, fmt.Errorf("unable to decode ACOTHs %v. Cause: %v", response.AffectedContractTransactions, err)
+ return "", nil, nil, nil, fmt.Errorf("unable to decode ACOTHs %v. Cause: %v", response.AffectedContractTransactions, err)
}
acMerkleRoot, err := common.Base64ToHash(response.ExecHash)
if err != nil {
- return nil, nil, fmt.Errorf("unable to decode execution hash %s. Cause: %v", response.ExecHash, err)
+ return "", nil, nil, nil, fmt.Errorf("unable to decode execution hash %s. Cause: %v", response.ExecHash, err)
}
extra = engine.ExtraMetadata{
- ACHashes: acHashes,
- ACMerkleRoot: acMerkleRoot,
- PrivacyFlag: response.PrivacyFlag,
+ ACHashes: acHashes,
+ ACMerkleRoot: acMerkleRoot,
+ PrivacyFlag: response.PrivacyFlag,
+ ManagedParties: response.ManagedParties,
+ Sender: response.SenderKey,
}
}
@@ -277,7 +324,7 @@ func (t *tesseraPrivateTxManager) receive(data common.EncryptedPayloadHash, isRa
Extra: extra,
}, gocache.DefaultExpiration)
- return response.Payload, &extra, nil
+ return response.SenderKey, response.ManagedParties, response.Payload, &extra, nil
}
// retrieve raw will not return information about medata
@@ -380,7 +427,7 @@ func (t *tesseraPrivateTxManager) HasFeature(f engine.PrivateTransactionManagerF
}
// don't serialize body if nil
-func newOptionalJSONRequest(method string, path string, body interface{}) (*http.Request, error) {
+func newOptionalJSONRequest(method string, path string, body interface{}, apiVersion string) (*http.Request, error) {
buf := new(bytes.Buffer)
if body != nil {
err := json.NewEncoder(buf).Encode(body)
@@ -393,7 +440,7 @@ func newOptionalJSONRequest(method string, path string, body interface{}) (*http
return nil, err
}
request.Header.Set("User-Agent", fmt.Sprintf("quorum-v%s", params.QuorumVersion))
- request.Header.Set("Content-type", "application/json")
- request.Header.Set("Accept", "application/json")
+ request.Header.Set("Content-type", fmt.Sprintf("application/%sjson", apiVersion))
+ request.Header.Set("Accept", fmt.Sprintf("application/%sjson", apiVersion))
return request, nil
}
diff --git a/private/engine/tessera/tessera_test.go b/private/engine/tessera/tessera_test.go
index 38ad154685..4c12cc73ea 100644
--- a/private/engine/tessera/tessera_test.go
+++ b/private/engine/tessera/tessera_test.go
@@ -82,7 +82,8 @@ func MockSendAPIHandlerFunc(response http.ResponseWriter, request *http.Request)
} else {
go func(o *capturedRequest) { sendRequestCaptor <- o }(&capturedRequest{request: actualRequest, header: request.Header})
data, _ := json.Marshal(&sendResponse{
- Key: arbitraryHash.ToBase64(),
+ Key: arbitraryHash.ToBase64(),
+ ManagedParties: []string{"ArbitraryPublicKey"},
})
response.Write(data)
}
@@ -103,7 +104,8 @@ func MockReceiveAPIHandlerFunc(response http.ResponseWriter, request *http.Reque
var data []byte
if actualRequest == arbitraryHashNoPrivateMetadata.ToBase64() {
data, _ = json.Marshal(&receiveResponse{
- Payload: arbitraryPrivatePayload,
+ Payload: arbitraryPrivatePayload,
+ ManagedParties: []string{"ArbitraryPublicKey"},
})
} else {
data, _ = json.Marshal(&receiveResponse{
@@ -111,6 +113,7 @@ func MockReceiveAPIHandlerFunc(response http.ResponseWriter, request *http.Reque
ExecHash: arbitraryExtra.ACMerkleRoot.ToBase64(),
AffectedContractTransactions: arbitraryExtra.ACHashes.ToBase64s(),
PrivacyFlag: arbitraryPrivacyFlag,
+ ManagedParties: []string{"ArbitraryPublicKey"},
})
}
response.Write(data)
@@ -159,10 +162,20 @@ func verifyRequestHeader(h http.Header, t *testing.T) {
}
}
+func verifyRequestHeaderMultiTenancy(h http.Header, t *testing.T) {
+ if h.Get("Content-type") != "application/vnd.tessera-2.1+json" {
+ t.Errorf("expected Content-type header is application/vnd.tessera-2.1+json")
+ }
+
+ if h.Get("Accept") != "application/vnd.tessera-2.1+json" {
+ t.Errorf("expected Accept header is application/vnd.tessera-2.1+json")
+ }
+}
+
func TestSend_whenTypical(t *testing.T) {
assert := testifyassert.New(t)
- actualHash, err := testObject.Send(arbitraryPrivatePayload, arbitraryFrom, arbitraryTo, arbitraryExtra)
+ _, _, actualHash, err := testObject.Send(arbitraryPrivatePayload, arbitraryFrom, arbitraryTo, arbitraryExtra)
if err != nil {
t.Fatalf("%s", err)
}
@@ -185,6 +198,37 @@ func TestSend_whenTypical(t *testing.T) {
assert.Equal(arbitraryHash, actualHash, "returned hash")
}
+func TestSend_whenTypical_MultiTenancy(t *testing.T) {
+ assert := testifyassert.New(t)
+
+ testObjectWithMT := New(&engine.Client{
+ HttpClient: &http.Client{},
+ BaseURL: testServer.URL,
+ }, []byte("2.1"))
+
+ _, _, actualHash, err := testObjectWithMT.Send(arbitraryPrivatePayload, arbitraryFrom, arbitraryTo, arbitraryExtra)
+ if err != nil {
+ t.Fatalf("%s", err)
+ }
+ capturedRequest := <-sendRequestCaptor
+
+ if capturedRequest.err != nil {
+ t.Fatalf("%s", capturedRequest.err)
+ }
+
+ verifyRequestHeaderMultiTenancy(capturedRequest.header, t)
+
+ actualRequest := capturedRequest.request.(*sendRequest)
+
+ assert.Equal(arbitraryPrivatePayload, actualRequest.Payload, "request.payload")
+ assert.Equal(arbitraryFrom, actualRequest.From, "request.from")
+ assert.Equal(arbitraryTo, actualRequest.To, "request.to")
+ assert.Equal(arbitraryPrivacyFlag, actualRequest.PrivacyFlag, "request.privacyFlag")
+ assert.Equal(arbitraryExtra.ACHashes.ToBase64s(), actualRequest.AffectedContractTransactions, "request.affectedContractTransactions")
+ assert.Equal(arbitraryExtra.ACMerkleRoot.ToBase64(), actualRequest.ExecHash, "request.execHash")
+ assert.Equal(arbitraryHash, actualHash, "returned hash")
+}
+
func TestSend_whenTesseraVersionDoesNotSupportPrivacyEnhancements(t *testing.T) {
assert := testifyassert.New(t)
@@ -196,7 +240,7 @@ func TestSend_whenTesseraVersionDoesNotSupportPrivacyEnhancements(t *testing.T)
assert.False(testObjectNoPE.HasFeature(engine.PrivacyEnhancements), "the supplied version does not support privacy enhancements")
// trying to send a party protection transaction
- _, err := testObjectNoPE.Send(arbitraryPrivatePayload, arbitraryFrom, arbitraryTo, arbitraryExtra)
+ _, _, _, err := testObjectNoPE.Send(arbitraryPrivatePayload, arbitraryFrom, arbitraryTo, arbitraryExtra)
if err != engine.ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements {
t.Fatal("Expecting send to raise ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements")
}
@@ -221,7 +265,7 @@ func TestSendRaw_whenTesseraVersionDoesNotSupportPrivacyEnhancements(t *testing.
assert.False(testObjectNoPE.HasFeature(engine.PrivacyEnhancements), "the supplied version does not support privacy enhancements")
// trying to send a party protection transaction
- _, err := testObjectNoPE.SendSignedTx(arbitraryHash, arbitraryTo, arbitraryExtra)
+ _, _, _, err := testObjectNoPE.SendSignedTx(arbitraryHash, arbitraryTo, arbitraryExtra)
if err != engine.ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements {
t.Fatal("Expecting send to raise ErrPrivateTxManagerDoesNotSupportPrivacyEnhancements")
}
@@ -229,14 +273,14 @@ func TestSendRaw_whenTesseraVersionDoesNotSupportPrivacyEnhancements(t *testing.
// send a standard private transaction and check that the old version of the /sendsignedtx is used (using octetstream content type)
// caching incomplete item
- _, _, err = testObjectNoPE.ReceiveRaw(arbitraryHashNoPrivateMetadata)
+ _, _, _, err = testObjectNoPE.ReceiveRaw(arbitraryHashNoPrivateMetadata)
if err != nil {
t.Fatalf("%s", err)
}
<-receiveRequestCaptor
// caching complete item
- _, err = testObjectNoPE.SendSignedTx(arbitraryHashNoPrivateMetadata, arbitraryTo, &engine.ExtraMetadata{
+ _, _, _, err = testObjectNoPE.SendSignedTx(arbitraryHashNoPrivateMetadata, arbitraryTo, &engine.ExtraMetadata{
PrivacyFlag: engine.PrivacyFlagStandardPrivate})
if err != nil {
t.Fatalf("%s", err)
@@ -244,7 +288,7 @@ func TestSendRaw_whenTesseraVersionDoesNotSupportPrivacyEnhancements(t *testing.
req := <-sendSignedTxOctetStreamRequestCaptor
assert.Equal("application/octet-stream", req.header["Content-Type"][0])
- _, actualExtra, err := testObjectNoPE.Receive(arbitraryHashNoPrivateMetadata)
+ _, _, _, actualExtra, err := testObjectNoPE.Receive(arbitraryHashNoPrivateMetadata)
if err != nil {
t.Fatalf("%s", err)
}
@@ -255,7 +299,7 @@ func TestSendRaw_whenTesseraVersionDoesNotSupportPrivacyEnhancements(t *testing.
func TestReceive_whenTypical(t *testing.T) {
assert := testifyassert.New(t)
- _, actualExtra, err := testObject.Receive(arbitraryHash1)
+ _, _, _, actualExtra, err := testObject.Receive(arbitraryHash1)
if err != nil {
t.Fatalf("%s", err)
}
@@ -275,10 +319,38 @@ func TestReceive_whenTypical(t *testing.T) {
assert.Equal(arbitraryExtra.PrivacyFlag, actualExtra.PrivacyFlag, "returned privacy flag")
}
+func TestReceive_whenTypical_Multitenancy(t *testing.T) {
+ assert := testifyassert.New(t)
+
+ testObjectWithMT := New(&engine.Client{
+ HttpClient: &http.Client{},
+ BaseURL: testServer.URL,
+ }, []byte("2.1"))
+
+ _, _, _, actualExtra, err := testObjectWithMT.Receive(arbitraryHash1)
+ if err != nil {
+ t.Fatalf("%s", err)
+ }
+ capturedRequest := <-receiveRequestCaptor
+
+ if capturedRequest.err != nil {
+ t.Fatalf("%s", capturedRequest.err)
+ }
+
+ verifyRequestHeaderMultiTenancy(capturedRequest.header, t)
+
+ actualRequest := capturedRequest.request.(string)
+
+ assert.Equal(arbitraryHash1.ToBase64(), actualRequest, "requested hash")
+ assert.Equal(arbitraryExtra.ACHashes, actualExtra.ACHashes, "returned affected contract transaction hashes")
+ assert.Equal(arbitraryExtra.ACMerkleRoot, actualExtra.ACMerkleRoot, "returned merkle root")
+ assert.Equal(arbitraryExtra.PrivacyFlag, actualExtra.PrivacyFlag, "returned privacy flag")
+}
+
func TestReceive_whenPayloadNotFound(t *testing.T) {
assert := testifyassert.New(t)
- data, _, err := testObject.Receive(arbitraryNotFoundHash)
+ _, _, data, _, err := testObject.Receive(arbitraryNotFoundHash)
if err != nil {
t.Fatalf("%s", err)
}
@@ -299,7 +371,7 @@ func TestReceive_whenPayloadNotFound(t *testing.T) {
func TestReceive_whenEncryptedPayloadHashIsEmpty(t *testing.T) {
assert := testifyassert.New(t)
- data, _, err := testObject.Receive(emptyHash)
+ _, _, data, _, err := testObject.Receive(emptyHash)
if err != nil {
t.Fatalf("%s", err)
}
@@ -310,7 +382,7 @@ func TestReceive_whenEncryptedPayloadHashIsEmpty(t *testing.T) {
func TestReceive_whenHavingPayloadButNoPrivateExtraMetadata(t *testing.T) {
assert := testifyassert.New(t)
- _, actualExtra, err := testObject.Receive(arbitraryHashNoPrivateMetadata)
+ _, _, _, actualExtra, err := testObject.Receive(arbitraryHashNoPrivateMetadata)
if err != nil {
t.Fatalf("%s", err)
}
@@ -332,7 +404,7 @@ func TestReceive_whenHavingPayloadButNoPrivateExtraMetadata(t *testing.T) {
func TestSendSignedTx_whenTypical(t *testing.T) {
assert := testifyassert.New(t)
- _, err := testObject.SendSignedTx(arbitraryHash, arbitraryTo, arbitraryExtra)
+ _, _, _, err := testObject.SendSignedTx(arbitraryHash, arbitraryTo, arbitraryExtra)
if err != nil {
t.Fatalf("%s", err)
}
@@ -355,20 +427,20 @@ func TestReceive_whenCachingRawPayload(t *testing.T) {
assert := testifyassert.New(t)
// caching incomplete item
- _, _, err := testObject.ReceiveRaw(arbitraryHashNoPrivateMetadata)
+ _, _, _, err := testObject.ReceiveRaw(arbitraryHashNoPrivateMetadata)
if err != nil {
t.Fatalf("%s", err)
}
<-receiveRequestCaptor
// caching complete item
- _, err = testObject.SendSignedTx(arbitraryHashNoPrivateMetadata, arbitraryTo, arbitraryExtra)
+ _, _, _, err = testObject.SendSignedTx(arbitraryHashNoPrivateMetadata, arbitraryTo, arbitraryExtra)
if err != nil {
t.Fatalf("%s", err)
}
<-sendSignedTxRequestCaptor
- _, actualExtra, err := testObject.Receive(arbitraryHashNoPrivateMetadata)
+ _, _, _, actualExtra, err := testObject.Receive(arbitraryHashNoPrivateMetadata)
if err != nil {
t.Fatalf("%s", err)
}
diff --git a/private/engine/tessera/tessera_version_checker.go b/private/engine/tessera/tessera_version_checker.go
index 9086e3c260..16b0ec35a0 100644
--- a/private/engine/tessera/tessera_version_checker.go
+++ b/private/engine/tessera/tessera_version_checker.go
@@ -15,9 +15,11 @@ type Version [versionLength]uint64
var (
zero = Version{0, 0, 0}
privacyEnhancementsVersion = Version{2, 0, 0}
+ multitenancyVersion = Version{2, 1, 0}
featureVersions = map[engine.PrivateTransactionManagerFeature]Version{
engine.PrivacyEnhancements: privacyEnhancementsVersion,
+ engine.MultiTenancy: multitenancyVersion,
}
)
diff --git a/private/engine/tessera/tessera_version_checker_test.go b/private/engine/tessera/tessera_version_checker_test.go
index 98ddd22aaf..64e3d55c14 100644
--- a/private/engine/tessera/tessera_version_checker_test.go
+++ b/private/engine/tessera/tessera_version_checker_test.go
@@ -58,15 +58,21 @@ func TestVersionsComparison(t *testing.T) {
func TestTesseraVersionFeatures(t *testing.T) {
res := tesseraVersionFeatures(Version{2, 11, 12})
assert.Contains(t, res, engine.PrivacyEnhancements)
+ assert.Contains(t, res, engine.MultiTenancy)
res = tesseraVersionFeatures(Version{0, 12, 0})
assert.NotContains(t, res, engine.PrivacyEnhancements)
+ assert.NotContains(t, res, engine.MultiTenancy)
res = tesseraVersionFeatures(Version{0, 11, 15})
assert.NotContains(t, res, engine.PrivacyEnhancements)
+ assert.NotContains(t, res, engine.MultiTenancy)
res = tesseraVersionFeatures(Version{2, 0, 0})
assert.Contains(t, res, engine.PrivacyEnhancements)
+ assert.NotContains(t, res, engine.MultiTenancy)
res = tesseraVersionFeatures(Version{2, 1, 1})
assert.Contains(t, res, engine.PrivacyEnhancements)
+ assert.Contains(t, res, engine.MultiTenancy)
res = tesseraVersionFeatures(zero)
assert.NotContains(t, res, engine.PrivacyEnhancements)
+ assert.NotContains(t, res, engine.MultiTenancy)
assert.Empty(t, res)
}
diff --git a/private/private.go b/private/private.go
index 4fecbba685..7c87ac952a 100644
--- a/private/private.go
+++ b/private/private.go
@@ -33,13 +33,13 @@ type Identifiable interface {
type PrivateTransactionManager interface {
Identifiable
- Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (common.EncryptedPayloadHash, error)
+ Send(data []byte, from string, to []string, extra *engine.ExtraMetadata) (string, []string, common.EncryptedPayloadHash, error)
StoreRaw(data []byte, from string) (common.EncryptedPayloadHash, error)
- SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) ([]byte, error)
+ SendSignedTx(data common.EncryptedPayloadHash, to []string, extra *engine.ExtraMetadata) (string, []string, []byte, error)
// Returns nil payload if not found
- Receive(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error)
+ Receive(data common.EncryptedPayloadHash) (string, []string, []byte, *engine.ExtraMetadata, error)
// Returns nil payload if not found
- ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, *engine.ExtraMetadata, error)
+ ReceiveRaw(data common.EncryptedPayloadHash) ([]byte, string, *engine.ExtraMetadata, error)
IsSender(txHash common.EncryptedPayloadHash) (bool, error)
GetParticipants(txHash common.EncryptedPayloadHash) ([]string, error)
EncryptPayload(data []byte, from string, to []string, extra *engine.ExtraMetadata) ([]byte, error)
diff --git a/rpc/handler.go b/rpc/handler.go
index 61a29d55ac..ae95997fea 100644
--- a/rpc/handler.go
+++ b/rpc/handler.go
@@ -312,12 +312,16 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
// handleCall processes method calls.
// Quorum:
// This is where server handle the call requests hence we enforce authorization check
-// before the actual processing of the call
+// before the actual processing of the call. It also populates context with preauthenticated
+// token so the responsible RPC method can leverage if needed (e.g: in multi tenancy)
func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
if r, ok := h.conn.(securityContextResolver); ok {
if err := secureCall(r, msg); err != nil {
return securityErrorMessage(msg, err)
}
+ secCtx := r.Resolve()
+ h.log.Debug("Enrich call context with token from security context")
+ cp.ctx = context.WithValue(cp.ctx, CtxPreauthenticatedToken, secCtx.Value(CtxPreauthenticatedToken))
}
if msg.isSubscribe() {
return h.handleSubscribe(cp, msg)
diff --git a/rpc/security.go b/rpc/security.go
index 7ef2d0cd3d..81a6da1875 100644
--- a/rpc/security.go
+++ b/rpc/security.go
@@ -21,7 +21,7 @@ const (
CtxCredentialsProvider = securityContextKey("CREDENTIALS_PROVIDER") // key to save reference to rpc.HttpCredentialsProviderFunc
// keys used to save values in request context
ctxAuthenticationError = securityContextKey("AUTHENTICATION_ERROR") // key to save error during authentication before processing the request body
- ctxPreauthenticatedToken = securityContextKey("PREAUTHENTICATED_TOKEN") // key to save the preauthenticated token once authenticated
+ CtxPreauthenticatedToken = securityContextKey("PREAUTHENTICATED_TOKEN") // key to save the preauthenticated token once authenticated
)
type securityContextConfigurer interface {
@@ -89,7 +89,7 @@ func secureCall(resolver securityContextResolver, msg *jsonrpcMessage) error {
if err, hasError := secCtx.Value(ctxAuthenticationError).(error); hasError {
return err
}
- if authToken, isPreauthenticated := secCtx.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken); isPreauthenticated {
+ if authToken, isPreauthenticated := secCtx.Value(CtxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken); isPreauthenticated {
if err := verifyExpiration(authToken); err != nil {
return err
}
diff --git a/rpc/security_test.go b/rpc/security_test.go
index 908d7545b1..5d6ea9e019 100644
--- a/rpc/security_test.go
+++ b/rpc/security_test.go
@@ -138,7 +138,7 @@ func TestSecureCall_whenTokenExpired(t *testing.T) {
assert := testifyassert.New(t)
expiredAt, _ := ptypes.TimestampProto(time.Now().Add(-1 * time.Hour))
stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
- {ctxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
+ {CtxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
ExpiredAt: expiredAt,
}},
})
@@ -152,7 +152,7 @@ func TestSecureCall_whenTypical(t *testing.T) {
assert := testifyassert.New(t)
expiredAt, _ := ptypes.TimestampProto(time.Now().Add(1 * time.Hour))
stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
- {ctxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
+ {CtxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
ExpiredAt: expiredAt,
Authorities: []*proto.GrantedAuthority{
{
@@ -172,7 +172,7 @@ func TestSecureCall_whenAccessDenied(t *testing.T) {
assert := testifyassert.New(t)
expiredAt, _ := ptypes.TimestampProto(time.Now().Add(1 * time.Hour))
stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
- {ctxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
+ {CtxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
ExpiredAt: expiredAt,
Authorities: []*proto.GrantedAuthority{
{
@@ -192,7 +192,7 @@ func TestSecureCall_whenMethodInJSONMessageIsNotSupported(t *testing.T) {
assert := testifyassert.New(t)
expiredAt, _ := ptypes.TimestampProto(time.Now().Add(1 * time.Hour))
stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
- {ctxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
+ {CtxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
ExpiredAt: expiredAt,
}},
})
diff --git a/rpc/server.go b/rpc/server.go
index a893282faa..e6a8d19fce 100644
--- a/rpc/server.go
+++ b/rpc/server.go
@@ -164,7 +164,7 @@ func (s *Server) authenticateHttpRequest(r *http.Request, cfg securityContextCon
if authToken, err := s.authenticationManager.Authenticate(context.Background(), token); err != nil {
securityContext = context.WithValue(securityContext, ctxAuthenticationError, &securityError{err.Error()})
} else {
- securityContext = context.WithValue(securityContext, ctxPreauthenticatedToken, authToken)
+ securityContext = context.WithValue(securityContext, CtxPreauthenticatedToken, authToken)
}
} else {
securityContext = context.WithValue(securityContext, ctxAuthenticationError, &securityError{"missing access token"})
diff --git a/rpc/server_test.go b/rpc/server_test.go
index eeee4c9e40..8d1729a1fa 100644
--- a/rpc/server_test.go
+++ b/rpc/server_test.go
@@ -168,7 +168,7 @@ func TestAuthenticateHttpRequest_whenAuthenticationManagerFails(t *testing.T) {
actualErr, hasError := captor.context.Value(ctxAuthenticationError).(error)
assert.True(t, hasError, "must have error")
assert.EqualError(t, actualErr, "internal error")
- _, hasAuthToken := captor.context.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
+ _, hasAuthToken := captor.context.Value(CtxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
assert.False(t, hasAuthToken, "must not be preauthenticated")
}
@@ -182,7 +182,7 @@ func TestAuthenticateHttpRequest_whenTypical(t *testing.T) {
_, hasError := captor.context.Value(ctxAuthenticationError).(error)
assert.False(t, hasError, "must not have error")
- _, hasAuthToken := captor.context.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
+ _, hasAuthToken := captor.context.Value(CtxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
assert.True(t, hasAuthToken, "must be preauthenticated")
}
@@ -195,7 +195,7 @@ func TestAuthenticateHttpRequest_whenAuthenticationManagerIsDisabled(t *testing.
_, hasError := captor.context.Value(ctxAuthenticationError).(error)
assert.False(t, hasError, "must not have error")
- _, hasAuthToken := captor.context.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
+ _, hasAuthToken := captor.context.Value(CtxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
assert.False(t, hasAuthToken, "must not be preauthenticated")
}
@@ -209,7 +209,7 @@ func TestAuthenticateHttpRequest_whenMissingAccessToken(t *testing.T) {
actualErr, hasError := captor.context.Value(ctxAuthenticationError).(error)
assert.True(t, hasError, "must have error")
assert.EqualError(t, actualErr, "missing access token")
- _, hasAuthToken := captor.context.Value(ctxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
+ _, hasAuthToken := captor.context.Value(CtxPreauthenticatedToken).(*proto.PreAuthenticatedAuthenticationToken)
assert.False(t, hasAuthToken, "must not be preauthenticated")
}