Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow injection of functions into composite values, refactor PublicKey based on it #2878

Merged
merged 9 commits into from
Oct 25, 2023
37 changes: 13 additions & 24 deletions runtime/convertValues.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,8 @@

case *interpreter.CompositeValue:
fieldValue = v.GetField(inter, locationRange, fieldName)
if fieldValue == nil && v.ComputedFields != nil {
if computedField, ok := v.ComputedFields[fieldName]; ok {
fieldValue = computedField(inter, locationRange)
}
if fieldValue == nil {
fieldValue = v.GetComputedField(inter, locationRange, fieldName)

Check warning on line 416 in runtime/convertValues.go

View check run for this annotation

Codecov / codecov/patch

runtime/convertValues.go#L416

Added line #L416 was not covered by tests
}
}

Expand Down Expand Up @@ -459,69 +457,62 @@
structure, err := cadence.NewMeteredStruct(
inter,
len(fieldNames),
func() ([]cadence.Value, error) {
return makeFields()
},
makeFields,
)
if err != nil {
return nil, err
}
return structure.WithType(t.(*cadence.StructType)), nil

case common.CompositeKindResource:
resource, err := cadence.NewMeteredResource(
inter,
len(fieldNames),
func() ([]cadence.Value, error) {
return makeFields()
},
makeFields,
)
if err != nil {
return nil, err
}
return resource.WithType(t.(*cadence.ResourceType)), nil

case common.CompositeKindAttachment:
attachment, err := cadence.NewMeteredAttachment(
inter,
len(fieldNames),
func() ([]cadence.Value, error) {
return makeFields()
},
makeFields,
)
if err != nil {
return nil, err
}
return attachment.WithType(t.(*cadence.AttachmentType)), nil

case common.CompositeKindEvent:
event, err := cadence.NewMeteredEvent(
inter,
len(fieldNames),
func() ([]cadence.Value, error) {
return makeFields()
},
makeFields,

Check warning on line 493 in runtime/convertValues.go

View check run for this annotation

Codecov / codecov/patch

runtime/convertValues.go#L493

Added line #L493 was not covered by tests
)
if err != nil {
return nil, err
}
return event.WithType(t.(*cadence.EventType)), nil

case common.CompositeKindContract:
contract, err := cadence.NewMeteredContract(
inter,
len(fieldNames),
func() ([]cadence.Value, error) {
return makeFields()
},
makeFields,

Check warning on line 504 in runtime/convertValues.go

View check run for this annotation

Codecov / codecov/patch

runtime/convertValues.go#L504

Added line #L504 was not covered by tests
)
if err != nil {
return nil, err
}
return contract.WithType(t.(*cadence.ContractType)), nil

case common.CompositeKindEnum:
enum, err := cadence.NewMeteredEnum(
inter,
len(fieldNames),
func() ([]cadence.Value, error) {
return makeFields()
},
makeFields,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1536,8 +1527,6 @@
publicKeyValue,
signAlgoValue,
i.standardLibraryHandler,
i.standardLibraryHandler,
i.standardLibraryHandler,
), nil
}

Expand Down
2 changes: 0 additions & 2 deletions runtime/convertValues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,6 @@ func TestExportValue(t *testing.T) {
PublicKey: []byte{1, 2, 3},
SignAlgo: 2,
},
nil,
nil,
),
hashAlgorithm,
interpreter.NewUnmeteredUFix64ValueWithInteger(10, interpreter.EmptyLocationRange),
Expand Down
30 changes: 30 additions & 0 deletions runtime/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ import (
type Environment interface {
ArgumentDecoder

SetCompositeValueFunctionsHandler(
typeID common.TypeID,
handler stdlib.CompositeValueFunctionsHandler,
)
DeclareValue(
valueDeclaration stdlib.StandardLibraryValue,
location common.Location,
Expand Down Expand Up @@ -121,6 +125,7 @@ type interpreterEnvironment struct {
deployedContractConstructorInvocation *stdlib.DeployedContractConstructorInvocation
stackDepthLimiter *stackDepthLimiter
checkedImports importResolutionResults
compositeValueFunctionsHandlers stdlib.CompositeValueFunctionsHandlers
config Config
}

Expand Down Expand Up @@ -156,6 +161,7 @@ func newInterpreterEnvironment(config Config) *interpreterEnvironment {
}
env.InterpreterConfig = env.newInterpreterConfig()
env.CheckerConfig = env.newCheckerConfig()
env.compositeValueFunctionsHandlers = stdlib.DefaultStandardLibraryCompositeValueFunctionHandlers(env)
return env
}

Expand All @@ -175,6 +181,7 @@ func (e *interpreterEnvironment) newInterpreterConfig() *interpreter.Config {
OnRecordTrace: e.newOnRecordTraceHandler(),
OnResourceOwnerChange: e.newResourceOwnerChangedHandler(),
CompositeTypeHandler: e.newCompositeTypeHandler(),
CompositeValueFunctionsHandler: e.newCompositeValueFunctionsHandler(),
TracingEnabled: e.config.TracingEnabled,
AtreeValueValidationEnabled: e.config.AtreeValidationEnabled,
// NOTE: ignore e.config.AtreeValidationEnabled here,
Expand Down Expand Up @@ -296,6 +303,13 @@ func (e *interpreterEnvironment) interpreterBaseActivationFor(
return baseActivation
}

func (e *interpreterEnvironment) SetCompositeValueFunctionsHandler(
typeID common.TypeID,
handler stdlib.CompositeValueFunctionsHandler,
) {
e.compositeValueFunctionsHandlers[typeID] = handler
}

func (e *interpreterEnvironment) NewAuthAccountValue(address interpreter.AddressValue) interpreter.Value {
return stdlib.NewAuthAccountValue(e, e, address)
}
Expand Down Expand Up @@ -1007,6 +1021,22 @@ func (e *interpreterEnvironment) newCompositeTypeHandler() interpreter.Composite
}
}

func (e *interpreterEnvironment) newCompositeValueFunctionsHandler() interpreter.CompositeValueFunctionsHandlerFunc {
return func(
inter *interpreter.Interpreter,
locationRange interpreter.LocationRange,
compositeValue *interpreter.CompositeValue,
) map[string]interpreter.FunctionValue {

handler := e.compositeValueFunctionsHandlers[compositeValue.TypeID()]
if handler == nil {
return nil
}

return handler(inter, locationRange, compositeValue)
}
}

func (e *interpreterEnvironment) loadContract(
inter *interpreter.Interpreter,
compositeType *sema.CompositeType,
Expand Down
8 changes: 5 additions & 3 deletions runtime/interpreter/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ type Config struct {
// UUIDHandler is used to handle the generation of UUIDs
UUIDHandler UUIDHandlerFunc
// CompositeTypeHandler is used to load composite types
CompositeTypeHandler CompositeTypeHandlerFunc
BaseActivationHandler func(location common.Location) *VariableActivation
Debugger *Debugger
CompositeTypeHandler CompositeTypeHandlerFunc
// CompositeValueFunctionsHandler is used to load composite value functions
CompositeValueFunctionsHandler CompositeValueFunctionsHandlerFunc
BaseActivationHandler func(location common.Location) *VariableActivation
Debugger *Debugger
// OnStatement is triggered when a statement is about to be executed
OnStatement OnStatementFunc
// OnLoopIteration is triggered when a loop iteration is about to be executed
Expand Down
83 changes: 82 additions & 1 deletion runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ type UUIDHandlerFunc func() (uint64, error)
// CompositeTypeHandlerFunc is a function that loads composite types.
type CompositeTypeHandlerFunc func(location common.Location, typeID TypeID) *sema.CompositeType

// CompositeValueFunctionsHandlerFunc is a function that loads composite value functions.
type CompositeValueFunctionsHandlerFunc func(
inter *Interpreter,
locationRange LocationRange,
compositeValue *CompositeValue,
) map[string]FunctionValue

// CompositeTypeCode contains the "prepared" / "callable" "code"
// for the functions and the destructor of a composite
// (contract, struct, resource, event).
Expand Down Expand Up @@ -1280,7 +1287,7 @@ func (interpreter *Interpreter) declareNonEnumCompositeValue(
address,
)

value.InjectedFields = injectedFields
value.injectedFields = injectedFields
value.Functions = functions
value.Destructor = destructorFunction

Expand Down Expand Up @@ -4791,6 +4798,80 @@ func (interpreter *Interpreter) GetContractComposite(contractLocation common.Add
return contractValue, nil
}

func GetNativeCompositeValueComputedFields(v *CompositeValue) map[string]ComputedField {
switch v.QualifiedIdentifier {
case sema.PublicKeyType.Identifier:
return map[string]ComputedField{
sema.PublicKeyTypePublicKeyFieldName: func(interpreter *Interpreter, locationRange LocationRange) Value {
publicKeyValue := v.GetField(interpreter, locationRange, sema.PublicKeyTypePublicKeyFieldName)
return publicKeyValue.Transfer(
interpreter,
locationRange,
atree.Address{},
false,
nil,
nil,
)
},
}
}

return nil
}

func (interpreter *Interpreter) GetCompositeValueComputedFields(v *CompositeValue) map[string]ComputedField {

var computedFields map[string]ComputedField
if v.Location == nil {
computedFields = GetNativeCompositeValueComputedFields(v)
if computedFields != nil {
return computedFields
}
}

// TODO: add handler to config
turbolent marked this conversation as resolved.
Show resolved Hide resolved

return nil
}

func (interpreter *Interpreter) GetCompositeValueInjectedFields(v *CompositeValue) map[string]Value {
config := interpreter.SharedState.Config
injectedCompositeFieldsHandler := config.InjectedCompositeFieldsHandler
if injectedCompositeFieldsHandler == nil {
return nil
}

return injectedCompositeFieldsHandler(
interpreter,
v.Location,
v.QualifiedIdentifier,
v.Kind,
)
}

func (interpreter *Interpreter) GetCompositeValueFunctions(
v *CompositeValue,
locationRange LocationRange,
) map[string]FunctionValue {

var functions map[string]FunctionValue

typeID := v.TypeID()

sharedState := interpreter.SharedState

compositeValueFunctionsHandler := sharedState.Config.CompositeValueFunctionsHandler
if compositeValueFunctionsHandler != nil {
functions = compositeValueFunctionsHandler(interpreter, locationRange, v)
if functions != nil {
return functions
}
}

compositeCodes := sharedState.typeCodes.CompositeCodes
return compositeCodes[typeID].CompositeFunctions
}

func (interpreter *Interpreter) GetCompositeType(
location common.Location,
qualifiedIdentifier string,
Expand Down
Loading
Loading