Skip to content

Commit

Permalink
Prevent mutation while iterating
Browse files Browse the repository at this point in the history
  • Loading branch information
SupunS committed Oct 19, 2023
1 parent a7dc05e commit f8c5886
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 59 deletions.
20 changes: 10 additions & 10 deletions runtime/interpreter/interpreter_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (interpreter *Interpreter) VisitWhileStatement(statement *ast.WhileStatemen

var intOne = NewUnmeteredIntValueFromInt64(1)

func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) StatementResult {
func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) (result StatementResult) {

interpreter.activations.PushNewWithCurrent()
defer interpreter.activations.Pop()
Expand All @@ -339,28 +339,28 @@ func (interpreter *Interpreter) VisitForStatement(statement *ast.ForStatement) S
}

forStmtTypes := interpreter.Program.Elaboration.ForStatementType(statement)
iterator := iterable.Iterator(interpreter, forStmtTypes.ValueVariableType, locationRange)

var index IntValue
if statement.Index != nil {
index = NewIntValueFromInt64(interpreter, 0)
}

for {
value := iterator.Next(interpreter)
if value == nil {
return nil
}

statementResult, done := interpreter.visitForStatementBody(statement, index, value)
executeBody := func(value Value) (done bool) {
var statementResult StatementResult
statementResult, done = interpreter.visitForStatementBody(statement, index, value)
if done {
return statementResult
result = statementResult
}

if statement.Index != nil {
index = index.Plus(interpreter, intOne, locationRange).(IntValue)
}

return
}

iterable.ForEach(interpreter, forStmtTypes.ValueVariableType, executeBody, locationRange)
return
}

func (interpreter *Interpreter) visitForStatementBody(
Expand Down
121 changes: 78 additions & 43 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,13 @@ type ContractValue interface {
// IterableValue is a value which can be iterated over, e.g. with a for-loop
type IterableValue interface {
Value
Iterator(interpreter *Interpreter, resultType sema.Type, locationRange LocationRange) ValueIterator
Iterator(interpreter *Interpreter, locationRange LocationRange) ValueIterator
ForEach(
interpreter *Interpreter,
elementType sema.Type,
procedure func(value Value) bool,
locationRange LocationRange,
)
}

// ValueIterator is an iterator which returns values.
Expand Down Expand Up @@ -1580,12 +1586,31 @@ func (v *StringValue) ConformsToStaticType(
return true
}

func (v *StringValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator {
func (v *StringValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator {
return StringValueIterator{
graphemes: uniseg.NewGraphemes(v.Str),
}
}

func (v *StringValue) ForEach(
interpreter *Interpreter,
_ sema.Type,
procedure func(value Value) bool,
locationRange LocationRange,
) {
iterator := v.Iterator(interpreter, locationRange)
for {
value := iterator.Next(interpreter)
if value == nil {
return
}

if procedure(value) {
return
}
}
}

type StringValueIterator struct {
graphemes *uniseg.Graphemes
}
Expand Down Expand Up @@ -1614,7 +1639,7 @@ type ArrayValueIterator struct {
atreeIterator *atree.ArrayIterator
}

func (v *ArrayValue) Iterator(_ *Interpreter, _ sema.Type, _ LocationRange) ValueIterator {
func (v *ArrayValue) Iterator(_ *Interpreter, _ LocationRange) ValueIterator {
arrayIterator, err := v.array.Iterator()
if err != nil {
panic(errors.NewExternalError(err))
Expand Down Expand Up @@ -3244,6 +3269,15 @@ func (v *ArrayValue) Map(
)
}

func (v *ArrayValue) ForEach(
interpreter *Interpreter,
_ sema.Type,
procedure func(value Value) bool,
_ LocationRange,
) {
v.Iterate(interpreter, procedure)
}

// NumberValue
type NumberValue interface {
ComparableValue
Expand Down Expand Up @@ -20198,34 +20232,51 @@ func (*StorageReferenceValue) DeepRemove(_ *Interpreter) {
func (*StorageReferenceValue) isReference() {}

func (v *StorageReferenceValue) Iterator(
_ *Interpreter,
_ LocationRange,
) ValueIterator {
// Not used for now
panic(errors.NewUnreachableError())
}

func (v *StorageReferenceValue) ForEach(
interpreter *Interpreter,
resultType sema.Type,
elementType sema.Type,
procedure func(value Value) bool,
locationRange LocationRange,
) ValueIterator {
) {
referencedValue := v.mustReferencedValue(interpreter, locationRange)
return referenceValueIterator(interpreter, referencedValue, resultType, locationRange)
forEachReference(interpreter, referencedValue, elementType, procedure, locationRange)
}

func referenceValueIterator(
func forEachReference(
interpreter *Interpreter,
referencedValue Value,
resultType sema.Type,
elementType sema.Type,
procedure func(value Value) bool,
locationRange LocationRange,
) ValueIterator {
) {
referencedIterable, ok := referencedValue.(IterableValue)
if !ok {
panic(errors.NewUnreachableError())
}

referencedValueIterator := referencedIterable.Iterator(interpreter, resultType, locationRange)
referenceType, isResultReference := sema.GetReferenceType(elementType)

_, isResultReference := sema.GetReferenceType(resultType)
updatedProcedure := func(value Value) bool {
if isResultReference {
value = interpreter.getReferenceValue(value, elementType)
}

return ReferenceValueIterator{
iterator: referencedValueIterator,
resultType: resultType,
isResultReference: isResultReference,
return procedure(value)
}

referencedElementType := elementType
if isResultReference {
referencedElementType = referenceType.Type
}

referencedIterable.ForEach(interpreter, referencedElementType, updatedProcedure, locationRange)
}

// EphemeralReferenceValue
Expand Down Expand Up @@ -20575,37 +20626,21 @@ func (*EphemeralReferenceValue) DeepRemove(_ *Interpreter) {
func (*EphemeralReferenceValue) isReference() {}

func (v *EphemeralReferenceValue) Iterator(
interpreter *Interpreter,
resultType sema.Type,
locationRange LocationRange,
_ *Interpreter,
_ LocationRange,
) ValueIterator {
referencedValue := v.MustReferencedValue(interpreter, locationRange)
return referenceValueIterator(interpreter, referencedValue, resultType, locationRange)
}

// ReferenceValueIterator

type ReferenceValueIterator struct {
iterator ValueIterator
resultType sema.Type
isResultReference bool
// Not used for now
panic(errors.NewUnreachableError())
}

var _ ValueIterator = ReferenceValueIterator{}

func (i ReferenceValueIterator) Next(interpreter *Interpreter) Value {
element := i.iterator.Next(interpreter)

if element == nil {
return nil
}

// For non-primitive values, return a reference.
if i.isResultReference {
return interpreter.getReferenceValue(element, i.resultType)
}

return element
func (v *EphemeralReferenceValue) ForEach(
interpreter *Interpreter,
elementType sema.Type,
procedure func(value Value) bool,
locationRange LocationRange,
) {
referencedValue := v.MustReferencedValue(interpreter, locationRange)
forEachReference(interpreter, referencedValue, elementType, procedure, locationRange)
}

// AddressValue
Expand Down
39 changes: 33 additions & 6 deletions runtime/tests/interpreter/for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) {
require.NoError(t, err)
})

t.Run("Moved resource element", func(t *testing.T) {
t.Run("Mutating reference to resource array", func(t *testing.T) {
t.Parallel()

inter := parseCheckAndInterpret(t, `
Expand All @@ -482,6 +482,7 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) {
for element in arrayRef {
// Move the actual element
// This mutation should fail.
let oldElement <- arrayRef.remove(at: 0)
// Use the element reference
Expand All @@ -495,7 +496,33 @@ func TestInterpretEphemeralReferencesInForLoop(t *testing.T) {
`)

_, err := inter.Invoke("main")
require.ErrorAs(t, err, &interpreter.InvalidatedResourceReferenceError{})
require.ErrorAs(t, err, &interpreter.ContainerMutatedDuringIterationError{})
})

t.Run("Mutating reference to struct array", func(t *testing.T) {
t.Parallel()

inter := parseCheckAndInterpret(t, `
struct Foo{
fun sayHello() {}
}
fun main() {
let array = [Foo()]
let arrayRef = &array as auth(Mutate) &[Foo]
for element in arrayRef {
// Move the actual element
let oldElement = arrayRef.remove(at: 0)
// Use the element reference
element.sayHello()
}
}
`)

_, err := inter.Invoke("main")
require.NoError(t, err)
})
}

Expand All @@ -510,7 +537,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) {

inter, _ := testAccount(t, address, true, nil, `
fun test() {
var let = ["Hello", "World", "Foo", "Bar"]
let array = ["Hello", "World", "Foo", "Bar"]
account.storage.save(array, to: /storage/array)
let arrayRef = account.storage.borrow<&[String]>(from: /storage/array)!
Expand All @@ -533,7 +560,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) {
struct Foo{}
fun test() {
var let = [Foo(), Foo()]
let array = [Foo(), Foo()]
account.storage.save(array, to: /storage/array)
let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)!
Expand All @@ -556,7 +583,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) {
resource Foo{}
fun test() {
var let <- [ <- create Foo(), <- create Foo()]
let array <- [ <- create Foo(), <- create Foo()]
account.storage.save(<- array, to: /storage/array)
let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)!
Expand All @@ -579,7 +606,7 @@ func TestInterpretStorageReferencesInForLoop(t *testing.T) {
resource Foo{}
fun test() {
var let <- [ <- create Foo(), <- create Foo()]
let array <- [ <- create Foo(), <- create Foo()]
account.storage.save(<- array, to: /storage/array)
let arrayRef = account.storage.borrow<&[Foo]>(from: /storage/array)!
Expand Down

0 comments on commit f8c5886

Please sign in to comment.