Skip to content

Commit

Permalink
Merge pull request #1689 from nspcc-dev/overload
Browse files Browse the repository at this point in the history
core: allow to overload contract methods
  • Loading branch information
roman-khimov authored Jan 27, 2021
2 parents 32e8678 + dd1e2ce commit f1792b3
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 43 deletions.
7 changes: 4 additions & 3 deletions cli/wallet/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/wallet"
Expand Down Expand Up @@ -419,9 +420,9 @@ func importDeployed(ctx *cli.Context) error {
if err != nil {
return cli.NewExitError(fmt.Errorf("can't fetch contract info: %w", err), 1)
}
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify)
if md == nil {
return cli.NewExitError("contract has no `verify` method", 1)
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
if md == nil || md.ReturnType != smartcontract.BoolType {
return cli.NewExitError("contract has no `verify` method with boolean return", 1)
}
acc.Address = address.Uint160ToString(cs.Hash)
acc.Contract.Script = cs.NEF.Script
Expand Down
7 changes: 4 additions & 3 deletions pkg/core/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
Expand Down Expand Up @@ -1667,11 +1668,11 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160,
if err != nil {
return ErrUnknownVerificationContract
}
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify)
if md == nil {
md := cs.Manifest.ABI.GetMethod(manifest.MethodVerify, -1)
if md == nil || md.ReturnType != smartcontract.BoolType {
return ErrInvalidVerificationContract
}
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit)
initMD := cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
v.LoadScriptWithHash(cs.NEF.Script, hash, callflag.ReadStates)
v.Context().NEF = &cs.NEF
v.Jump(v.Context(), md.Offset)
Expand Down
8 changes: 4 additions & 4 deletions pkg/core/interop/contract/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func Call(ic *interop.Context) error {
if strings.HasPrefix(method, "_") {
return errors.New("invalid method name (starts with '_')")
}
md := cs.Manifest.ABI.GetMethod(method)
md := cs.Manifest.ABI.GetMethod(method, len(args))
if md == nil {
return errors.New("method not found")
}
Expand All @@ -68,7 +68,7 @@ func Call(ic *interop.Context) error {

func callInternal(ic *interop.Context, cs *state.Contract, name string, f callflag.CallFlag,
hasReturn bool, args []stackitem.Item) error {
md := cs.Manifest.ABI.GetMethod(name)
md := cs.Manifest.ABI.GetMethod(name, len(args))
if md.Safe {
f &^= callflag.WriteStates
} else if ctx := ic.VM.Context(); ctx != nil && ctx.IsDeployed() {
Expand All @@ -85,7 +85,7 @@ func callInternal(ic *interop.Context, cs *state.Contract, name string, f callfl
// callExFromNative calls a contract with flags using provided calling hash.
func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contract,
name string, args []stackitem.Item, f callflag.CallFlag, hasReturn bool) error {
md := cs.Manifest.ABI.GetMethod(name)
md := cs.Manifest.ABI.GetMethod(name, len(args))
if md == nil {
return fmt.Errorf("method '%s' not found", name)
}
Expand Down Expand Up @@ -119,7 +119,7 @@ func callExFromNative(ic *interop.Context, caller util.Uint160, cs *state.Contra
ic.VM.Context().RetCount = 0
}

md = cs.Manifest.ABI.GetMethod(manifest.MethodInit)
md = cs.Manifest.ABI.GetMethod(manifest.MethodInit, 0)
if md != nil {
ic.VM.Call(ic.VM.Context(), md.Offset)
}
Expand Down
51 changes: 41 additions & 10 deletions pkg/core/interop_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,8 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
emit.Opcodes(w.BinWriter, opcode.ABORT)
addOff := w.Len()
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.RET)
addMultiOff := w.Len()
emit.Opcodes(w.BinWriter, opcode.ADD, opcode.ADD, opcode.RET)
ret7Off := w.Len()
emit.Opcodes(w.BinWriter, opcode.PUSH7, opcode.RET)
dropOff := w.Len()
Expand Down Expand Up @@ -533,6 +535,16 @@ func getTestContractState(bc *Blockchain) (*state.Contract, *state.Contract) {
},
ReturnType: smartcontract.IntegerType,
},
{
Name: "add",
Offset: addMultiOff,
Parameters: []manifest.Parameter{
manifest.NewParameter("addend1", smartcontract.IntegerType),
manifest.NewParameter("addend2", smartcontract.IntegerType),
manifest.NewParameter("addend3", smartcontract.IntegerType),
},
ReturnType: smartcontract.IntegerType,
},
{
Name: "ret7",
Offset: ret7Off,
Expand Down Expand Up @@ -731,16 +743,31 @@ func TestContractCall(t *testing.T) {

addArgs := stackitem.NewArray([]stackitem.Item{stackitem.Make(1), stackitem.Make(2)})
t.Run("Good", func(t *testing.T) {
loadScript(ic, currScript, 42)
ic.VM.Estack().PushVal(addArgs)
ic.VM.Estack().PushVal(callflag.All)
ic.VM.Estack().PushVal("add")
ic.VM.Estack().PushVal(h.BytesBE())
require.NoError(t, contract.Call(ic))
require.NoError(t, ic.VM.Run())
require.Equal(t, 2, ic.VM.Estack().Len())
require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
t.Run("2 arguments", func(t *testing.T) {
loadScript(ic, currScript, 42)
ic.VM.Estack().PushVal(addArgs)
ic.VM.Estack().PushVal(callflag.All)
ic.VM.Estack().PushVal("add")
ic.VM.Estack().PushVal(h.BytesBE())
require.NoError(t, contract.Call(ic))
require.NoError(t, ic.VM.Run())
require.Equal(t, 2, ic.VM.Estack().Len())
require.Equal(t, big.NewInt(3), ic.VM.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
})
t.Run("3 arguments", func(t *testing.T) {
loadScript(ic, currScript, 42)
ic.VM.Estack().PushVal(stackitem.NewArray(
append(addArgs.Value().([]stackitem.Item), stackitem.Make(3))))
ic.VM.Estack().PushVal(callflag.All)
ic.VM.Estack().PushVal("add")
ic.VM.Estack().PushVal(h.BytesBE())
require.NoError(t, contract.Call(ic))
require.NoError(t, ic.VM.Run())
require.Equal(t, 2, ic.VM.Estack().Len())
require.Equal(t, big.NewInt(6), ic.VM.Estack().Pop().Value())
require.Equal(t, big.NewInt(42), ic.VM.Estack().Pop().Value())
})
})

t.Run("CallExInvalidFlag", func(t *testing.T) {
Expand Down Expand Up @@ -778,6 +805,10 @@ func TestContractCall(t *testing.T) {
t.Run("Arguments", runInvalid(1, "add", h.BytesBE()))
t.Run("NotEnoughArguments", runInvalid(
stackitem.NewArray([]stackitem.Item{stackitem.Make(1)}), "add", h.BytesBE()))
t.Run("TooMuchArguments", runInvalid(
stackitem.NewArray([]stackitem.Item{
stackitem.Make(1), stackitem.Make(2), stackitem.Make(3), stackitem.Make(4)}),
"add", h.BytesBE()))
})

t.Run("ReturnValues", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/native/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func (m *Management) setMinimumDeploymentFee(ic *interop.Context, args []stackit
}

func (m *Management) callDeploy(ic *interop.Context, cs *state.Contract, isUpdate bool) {
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy)
md := cs.Manifest.ABI.GetMethod(manifest.MethodDeploy, 1)
if md != nil {
err := contract.CallFromNative(ic, m.Hash, cs, manifest.MethodDeploy,
[]stackitem.Item{stackitem.NewBool(isUpdate)}, false)
Expand Down
3 changes: 1 addition & 2 deletions pkg/core/native_management_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,8 @@ func TestContractDeploy(t *testing.T) {
Offset: 0,
Parameters: []manifest.Parameter{
manifest.NewParameter("isUpdate", smartcontract.BoolType),
manifest.NewParameter("param", smartcontract.IntegerType),
},
ReturnType: smartcontract.VoidType,
ReturnType: smartcontract.ArrayType,
},
}
nefD, err := nef.NewFile(deployScript)
Expand Down
4 changes: 2 additions & 2 deletions pkg/smartcontract/manifest/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ func DefaultManifest(name string) *Manifest {
}

// GetMethod returns methods with the specified name.
func (a *ABI) GetMethod(name string) *Method {
func (a *ABI) GetMethod(name string, paramCount int) *Method {
for i := range a.Methods {
if a.Methods[i].Name == name {
if a.Methods[i].Name == name && (paramCount == -1 || len(a.Methods[i].Parameters) == paramCount) {
return &a.Methods[i]
}
}
Expand Down
7 changes: 2 additions & 5 deletions pkg/smartcontract/manifest/standard/comply.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,12 @@ func Check(m *manifest.Manifest, standards ...string) error {
func Comply(m, st *manifest.Manifest) error {
for _, stm := range st.ABI.Methods {
name := stm.Name
md := m.ABI.GetMethod(name)
md := m.ABI.GetMethod(name, len(stm.Parameters))
if md == nil {
return fmt.Errorf("%w: '%s'", ErrMethodMissing, name)
return fmt.Errorf("%w: '%s' with %d parameters", ErrMethodMissing, name, len(stm.Parameters))
} else if stm.ReturnType != md.ReturnType {
return fmt.Errorf("%w: '%s' (expected %s, got %s)", ErrInvalidReturnType,
name, stm.ReturnType, md.ReturnType)
} else if len(stm.Parameters) != len(md.Parameters) {
return fmt.Errorf("%w: '%s' (expected %d, got %d)", ErrInvalidParameterCount,
name, len(stm.Parameters), len(md.Parameters))
}
for i := range stm.Parameters {
if stm.Parameters[i].Type != md.Parameters[i].Type {
Expand Down
12 changes: 6 additions & 6 deletions pkg/smartcontract/manifest/standard/comply_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,25 @@ func fooMethodBarEvent() *manifest.Manifest {

func TestComplyMissingMethod(t *testing.T) {
m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Name = "notafoo"
m.ABI.GetMethod("foo", -1).Name = "notafoo"
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrMethodMissing))
}

func TestComplyInvalidReturnType(t *testing.T) {
m := fooMethodBarEvent()
m.ABI.GetMethod("foo").ReturnType = smartcontract.VoidType
m.ABI.GetMethod("foo", -1).ReturnType = smartcontract.VoidType
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidReturnType))
}

func TestComplyMethodParameterCount(t *testing.T) {
t.Run("Method", func(t *testing.T) {
m := fooMethodBarEvent()
f := m.ABI.GetMethod("foo")
f := m.ABI.GetMethod("foo", -1)
f.Parameters = append(f.Parameters, manifest.Parameter{Type: smartcontract.BoolType})
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidParameterCount))
require.True(t, errors.Is(err, ErrMethodMissing))
})
t.Run("Event", func(t *testing.T) {
m := fooMethodBarEvent()
Expand All @@ -69,7 +69,7 @@ func TestComplyMethodParameterCount(t *testing.T) {
func TestComplyParameterType(t *testing.T) {
t.Run("Method", func(t *testing.T) {
m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Parameters[0].Type = smartcontract.InteropInterfaceType
m.ABI.GetMethod("foo", -1).Parameters[0].Type = smartcontract.InteropInterfaceType
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrInvalidParameterType))
})
Expand All @@ -90,7 +90,7 @@ func TestMissingEvent(t *testing.T) {

func TestSafeFlag(t *testing.T) {
m := fooMethodBarEvent()
m.ABI.GetMethod("foo").Safe = false
m.ABI.GetMethod("foo", -1).Safe = false
err := Comply(m, fooMethodBarEvent())
require.True(t, errors.Is(err, ErrSafeMethodMismatch))
}
Expand Down
14 changes: 7 additions & 7 deletions pkg/vm/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,25 +389,25 @@ func handleRun(c *ishell.Context) {
runCurrent = c.Args[0] != "_"
)

params, err = parseArgs(c.Args[1:])
if err != nil {
c.Err(err)
return
}
if runCurrent {
md := m.ABI.GetMethod(c.Args[0])
md := m.ABI.GetMethod(c.Args[0], len(params))
if md == nil {
c.Err(fmt.Errorf("%w: method not found", ErrInvalidParameter))
return
}
offset = md.Offset
}
params, err = parseArgs(c.Args[1:])
if err != nil {
c.Err(err)
return
}
for i := len(params) - 1; i >= 0; i-- {
v.Estack().PushVal(params[i])
}
if runCurrent {
v.Jump(v.Context(), offset)
if initMD := m.ABI.GetMethod(manifest.MethodInit); initMD != nil {
if initMD := m.ABI.GetMethod(manifest.MethodInit, 0); initMD != nil {
v.Call(v.Context(), initMD.Offset)
}
}
Expand Down

0 comments on commit f1792b3

Please sign in to comment.