diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index f1f7f9d..5181518 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -20,7 +20,7 @@ func Eval(node ast.Node) object.Object { switch node := node.(type) { // 2. If the node is a *ast.Program, evaluate the statements case *ast.Program: - return evalStatements(node.Statements) + return evalProgram(node) // 3. If the node is a *ast.ExpressionStatement, recursively evaluate the expression case *ast.ExpressionStatement: return Eval(node.Expression) @@ -37,10 +37,12 @@ func Eval(node ast.Node) object.Object { right := Eval(node.Right) return evalInfixExpression(node.Operator, left, right) case *ast.BlockStatement: - return evalStatements(node.Statements) + return evalBlockStatement(node) case *ast.IfExpression: return evalIfExpression(node) - + case *ast.ReturnStatement: + value := Eval(node.ReturnValue) + return &object.ReturnValue{Value: value} } return nil @@ -154,6 +156,35 @@ func evalStatements(statements []ast.Statement) object.Object { var evaluatedResults object.Object for _, statement := range statements { evaluatedResults = Eval(statement) + + // if the statement is a return statement, return the value + if evaluatedResults, ok := evaluatedResults.(*object.ReturnValue); ok { + return evaluatedResults.Value + } } return evaluatedResults +} + + +func evalBlockStatement(block *ast.BlockStatement) object.Object { + var result object.Object + for _, statement := range block.Statements{ + result = Eval(statement) + if result != nil && result.Type() == object.RETURN_VALUE_OBJ { + return result + } + } + return result +} + +func evalProgram(program *ast.Program) object.Object { + var evaluatedResult object.Object + for _, statement := range program.Statements { + evaluatedResult = Eval(statement) + + if returnValue, ok := evaluatedResult.(*object.ReturnValue); ok { + return returnValue.Value + } + } + return evaluatedResult } \ No newline at end of file diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 5ec9b16..e775c31 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -96,14 +96,31 @@ func TestIfElseExpression(t *testing.T) { } } -func testNullObject(t *testing.T, evaluated object.Object) bool { - if evaluated != NULL { - t.Errorf("object is not NULL. got=%T (%+v)", evaluated, evaluated) - return false +func TestReturnStatements(t *testing.T) { + tests := []struct { + input string + expected int64 + }{ + {"return 10;", 10}, + {"return 10; 9;", 10}, + {"return 50 / 5; 9;", 10}, + {"10; return 50; 5;", 50}, + {` + if (10 > 1) { + if (10 > 1) { + return 10; + } + return 1; + } + `, 10, + }, } - return true -} + for _, test := range tests { + evaluated := testEval(test.input) + testIntegerObject(t, evaluated, test.expected) + } +} // evaluating prefix expressions func TestBangOperator(t *testing.T) { tests := []struct { @@ -155,6 +172,13 @@ func testIntegerObject(t *testing.T, obj object.Object ,expected int64) bool { } +func testNullObject(t *testing.T, evaluated object.Object) bool { + if evaluated != NULL { + t.Errorf("object is not NULL. got=%T (%+v)", evaluated, evaluated) + return false + } + return true +} func testEval(input string) object.Object { lexer := lexer.New(input) parser := parser.New(lexer) diff --git a/object/object.go b/object/object.go index 20f29f9..c4eb641 100644 --- a/object/object.go +++ b/object/object.go @@ -8,6 +8,7 @@ const ( INTEGER_OBJ = "INTEGER" BOOLEAN_OBJ = "BOOLEAN" NULL_OBJ = "NULL" + RETURN_VALUE_OBJ = "RETURN_VALUE" ) type Object interface { @@ -37,4 +38,9 @@ type Null struct {} func (n *Null) Type() ObjectType { return NULL_OBJ } func (n *Null) Inspect() string { return "null" } +type ReturnValue struct { + Value Object +} +func (rv *ReturnValue) Type() ObjectType { return RETURN_VALUE_OBJ } +func (rv *ReturnValue) Inspect() string { return rv.Value.Inspect() }