From b10595e9d77a116151b0752ed0eb43072d29c2b1 Mon Sep 17 00:00:00 2001 From: Laurent Demailly Date: Tue, 3 Sep 2024 13:34:30 -0700 Subject: [PATCH] Reduce greatly the peppering of object.Value() and return the actual type for REFERENCE instead of referenced object's type --- eval/eval.go | 89 +++++++++++++++++++++-------------------- eval/eval_api.go | 5 ++- eval/eval_test.go | 4 +- extensions/extension.go | 8 ++-- object/object.go | 2 +- 5 files changed, 56 insertions(+), 52 deletions(-) diff --git a/eval/eval.go b/eval/eval.go index 783fe177..26d3e74f 100644 --- a/eval/eval.go +++ b/eval/eval.go @@ -57,7 +57,7 @@ func (s *State) evalIndexAssigment(which ast.Node, index, value object.Object) o if index.Type() != object.INTEGER { return s.NewError("index assignment to array with non integer index: " + index.Inspect()) } - idx := object.Value(index).(object.Integer).Value + idx := index.(object.Integer).Value if idx < 0 { idx = int64(object.Len(val)) + idx } @@ -72,7 +72,7 @@ func (s *State) evalIndexAssigment(which ast.Node, index, value object.Object) o } return value case object.MAP: - m := object.Value(val).(object.Map) + m := val.(object.Map) m = m.Set(index, value) oerr := s.env.Set(id.Literal(), m) if oerr.Type() == object.ERROR { @@ -104,7 +104,7 @@ func (s *State) evalPrefixIncrDecr(operator token.Type, node ast.Node) object.Ob log.LogVf("eval prefix %s", ast.DebugString(node)) nv := node.Value() if nv.Type() != token.IDENT { - return s.NewError("can't increment/decrement " + nv.DebugString()) + return s.NewError("can't prefix increment/decrement " + nv.DebugString()) } id := nv.Literal() val, ok := s.env.Get(id) @@ -122,7 +122,7 @@ func (s *State) evalPrefixIncrDecr(operator token.Type, node ast.Node) object.Ob case object.Float: return s.env.Set(id, object.Float{Value: val.Value + float64(toAdd)}) // So PI++ fails not silently. default: - return s.NewError("can't increment/decrement " + val.Type().String()) + return s.NewError("can't prefix increment/decrement " + val.Type().String()) } } @@ -142,15 +142,15 @@ func (s *State) evalPostfixExpression(node *ast.PostfixExpression) object.Object default: return s.NewError("unknown postfix operator: " + node.Type().String()) } - var oerr object.Object val = object.Value(val) + var oerr object.Object switch val := val.(type) { case object.Integer: oerr = s.env.Set(id, object.Integer{Value: val.Value + toAdd}) case object.Float: oerr = s.env.Set(id, object.Float{Value: val.Value + float64(toAdd)}) // So PI++ fails not silently. default: - return s.NewError("can't increment/decrement " + val.Type().String()) + return s.NewError("can't postfix increment/decrement " + val.Type().String()) } if oerr.Type() == object.ERROR { return oerr @@ -196,7 +196,7 @@ func (s *State) evalInternal(node any) object.Object { //nolint:funlen,gocyclo / log.LogVf("eval infix %s", node.DebugString()) // Eval and not evalInternal because we need to unwrap "return". if node.Token.Type() == token.ASSIGN || node.Token.Type() == token.DEFINE { - return s.evalAssignment(s.Eval(node.Right), node) + return s.evalAssignment(s.evalInternal(node.Right), node) } // Humans expect left to right evaluations. left := s.Eval(node.Left) @@ -260,7 +260,7 @@ func (s *State) evalInternal(node any) object.Object { //nolint:funlen,gocyclo / } return fn case *ast.CallExpression: - f := s.evalInternal(node.Function) + f := s.Eval(node.Function) if f.Type() == object.ERROR { return f } @@ -269,7 +269,7 @@ func (s *State) evalInternal(node any) object.Object { //nolint:funlen,gocyclo / return *oerr } if f.Type() == object.EXTENSION { - return s.applyExtension(object.Value(f).(object.Extension), args) + return s.applyExtension(f.(object.Extension), args) } name := node.Function.Value().Literal() return s.applyFunction(name, f, args) @@ -282,7 +282,7 @@ func (s *State) evalInternal(node any) object.Object { //nolint:funlen,gocyclo / case *ast.MapLiteral: return s.evalMapLiteral(node) case *ast.IndexExpression: - return s.evalIndexExpression(s.evalInternal(node.Left), node) + return s.evalIndexExpression(s.Eval(node.Left), node) case *ast.Comment: return object.NULL } @@ -306,7 +306,7 @@ func (s *State) evalIndexExpression(left object.Object, node *ast.IndexExpressio rangeExp := node.Index.(*ast.InfixExpression) return s.evalIndexRangeExpression(left, rangeExp.Left, rangeExp.Right) } - index = s.evalInternal(node.Index) + index = s.Eval(node.Index) if index.Type() == object.ERROR { return index } @@ -345,7 +345,7 @@ func (s *State) evalPrintLogError(node *ast.Builtin) object.Object { return r } if isString := r.Type() == object.STRING; isString { - buf.WriteString(object.Value(r).(object.String).Value) + buf.WriteString(r.(object.String).Value) } else { buf.WriteString(r.Inspect()) } @@ -415,20 +415,20 @@ func (s *State) evalBuiltin(node *ast.Builtin) object.Object { } func (s *State) evalIndexRangeExpression(left object.Object, leftIdx, rightIdx ast.Node) object.Object { - leftIndex := s.evalInternal(leftIdx) + leftIndex := s.Eval(leftIdx) nilRight := (rightIdx == nil) var rightIndex object.Object if nilRight { log.Debugf("eval index %s[%s:]", left.Inspect(), leftIndex.Inspect()) } else { - rightIndex = s.evalInternal(rightIdx) + rightIndex = s.Eval(rightIdx) log.Debugf("eval index %s[%s:%s]", left.Inspect(), leftIndex.Inspect(), rightIndex.Inspect()) } if leftIndex.Type() != object.INTEGER || (!nilRight && rightIndex.Type() != object.INTEGER) { return s.NewError("range index not integer") } num := object.Len(left) - l := object.Value(leftIndex).(object.Integer).Value + l := leftIndex.(object.Integer).Value if l < 0 { // negative is relative to the end. l = int64(num) + l } @@ -436,7 +436,7 @@ func (s *State) evalIndexRangeExpression(left object.Object, leftIdx, rightIdx a if nilRight { r = int64(num) } else { - r = object.Value(rightIndex).(object.Integer).Value + r = rightIndex.(object.Integer).Value if r < 0 { r = int64(num) + r } @@ -448,7 +448,7 @@ func (s *State) evalIndexRangeExpression(left object.Object, leftIdx, rightIdx a r = min(r, int64(num)) switch { case left.Type() == object.STRING: - str := object.Value(left).(object.String).Value + str := left.(object.String).Value return object.String{Value: str[l:r]} case left.Type() == object.ARRAY: return object.NewArray(object.Elements(left)[l:r]) @@ -468,8 +468,8 @@ func (s *State) evalIndexExpressionIdx(left, index object.Object) object.Object } switch { case left.Type() == object.STRING && idxOrZero.Type() == object.INTEGER: - idx := object.Value(idxOrZero).(object.Integer).Value - str := object.Value(left).(object.String).Value + idx := idxOrZero.(object.Integer).Value + str := left.(object.String).Value num := len(str) if idx < 0 { // negative is relative to the end. idx = int64(num) + idx @@ -490,7 +490,7 @@ func (s *State) evalIndexExpressionIdx(left, index object.Object) object.Object } func evalMapIndexExpression(assoc, key object.Object) object.Object { - m := object.Value(assoc).(object.Map) + m := assoc.(object.Map) v, ok := m.Get(key) if !ok { return object.NULL @@ -499,7 +499,7 @@ func evalMapIndexExpression(assoc, key object.Object) object.Object { } func evalArrayIndexExpression(array, index object.Object) object.Object { - idx := object.Value(index).(object.Integer).Value + idx := index.(object.Integer).Value maxV := int64(object.Len(array) - 1) if idx < 0 { // negative is relative to the end. idx = maxV + 1 + idx // elsewhere we use len() but here maxV is len-1 @@ -538,7 +538,7 @@ func (s *State) applyExtension(fn object.Extension, args []object.Object) object } // Auto promote integer to float if needed. if fn.ArgTypes[i] == object.FLOAT && arg.Type() == object.INTEGER { - args[i] = object.Float{Value: float64(object.Value(arg).(object.Integer).Value)} + args[i] = object.Float{Value: float64(arg.(object.Integer).Value)} continue } if fn.ArgTypes[i] != arg.Type() { @@ -553,7 +553,7 @@ func (s *State) applyExtension(fn object.Extension, args []object.Object) object } func (s *State) applyFunction(name string, fn object.Object, args []object.Object) object.Object { - function, ok := object.Value(fn).(object.Function) + function, ok := fn.(object.Function) if !ok { return s.NewError("not a function: " + fn.Type().String() + ":" + fn.Inspect()) } @@ -637,10 +637,11 @@ func extendFunctionEnv( name, len(args), atLeast, n)} } for paramIdx, param := range params { - oerr := env.CreateOrSet(param.Value().Literal(), args[paramIdx], true) + // By definition function parameters are local copies, deref argument values: + oerr := env.CreateOrSet(param.Value().Literal(), object.Value(args[paramIdx]), true) log.LogVf("set %s to %s - %s", param.Value().Literal(), args[paramIdx].Inspect(), oerr.Inspect()) if oerr.Type() == object.ERROR { - oe, _ := object.Value(oerr).(object.Error) + oe, _ := oerr.(object.Error) return nil, &oe } } @@ -662,7 +663,7 @@ func (s *State) evalExpressions(exps []ast.Node) ([]object.Object, *object.Error for _, e := range exps { evaluated := s.evalInternal(e) if rt := evaluated.Type(); rt == object.ERROR { - oerr := object.Value(evaluated).(object.Error) + oerr := evaluated.(object.Error) return nil, &oerr } result = append(result, evaluated) @@ -761,14 +762,14 @@ func (s *State) evalForSpecialForms(fe *ast.ForExpression) (object.Object, bool) if end.Type() != object.INTEGER { return s.NewError("for var = n:m m not an integer: " + end.Inspect()), true } - startInt := object.Value(start).(object.Integer) - return s.evalForInteger(fe, &startInt, object.Value(end).(object.Integer), name), true + startInt := start.(object.Integer) + return s.evalForInteger(fe, &startInt, end.(object.Integer), name), true } // Evaluate: - v := s.Eval(ie.Right) + v := s.evalInternal(ie.Right) switch v.Type() { case object.INTEGER: - return s.evalForInteger(fe, nil, object.Value(v).(object.Integer), name), true + return s.evalForInteger(fe, nil, v.(object.Integer), name), true case object.ERROR: return v, true case object.ARRAY, object.MAP, object.STRING: @@ -837,7 +838,7 @@ func (s *State) evalForExpression(fe *ast.ForExpression) object.Object { case object.ERROR: return condition case object.INTEGER: - return s.evalForInteger(fe, nil, object.Value(condition).(object.Integer), "") + return s.evalForInteger(fe, nil, condition.(object.Integer), "") default: return s.NewError("for condition is not a boolean nor integer nor assignment: " + condition.Inspect()) } @@ -878,7 +879,7 @@ func (s *State) evalPrefixExpression(operator token.Type, right object.Object) o return s.evalMinusPrefixOperatorExpression(right) case token.BITNOT, token.BITXOR: if right.Type() == object.INTEGER { - return object.Integer{Value: ^object.Value(right).(object.Integer).Value} + return object.Integer{Value: ^right.(object.Integer).Value} } return s.NewError("bitwise not of " + right.Inspect()) case token.PLUS: @@ -905,10 +906,10 @@ func (s *State) evalBangOperatorExpression(right object.Object) object.Object { func (s *State) evalMinusPrefixOperatorExpression(right object.Object) object.Object { switch right.Type() { case object.INTEGER: - value := object.Value(right).(object.Integer).Value + value := right.(object.Integer).Value return object.Integer{Value: -value} case object.FLOAT: - value := object.Value(right).(object.Float).Value + value := right.(object.Float).Value return object.Float{Value: -value} default: return s.NewError("minus of " + right.Inspect()) @@ -950,13 +951,13 @@ func (s *State) evalInfixExpression(operator token.Type, left, right object.Obje } func (s *State) evalStringInfixExpression(operator token.Type, left, right object.Object) object.Object { - leftVal := object.Value(left).(object.String).Value + leftVal := left.(object.String).Value switch { case operator == token.PLUS && right.Type() == object.STRING: - rightVal := object.Value(right).(object.String).Value + rightVal := right.(object.String).Value return object.String{Value: leftVal + rightVal} case operator == token.ASTERISK && right.Type() == object.INTEGER: - rightVal := object.Value(right).(object.Integer).Value + rightVal := right.(object.Integer).Value n := len(leftVal) * int(rightVal) if rightVal < 0 { return s.NewError("right operand of * on strings must be a positive integer") @@ -977,7 +978,7 @@ func (s *State) evalArrayInfixExpression(operator token.Type, left, right object return s.NewError("right operand of * on arrays must be an integer") } // TODO: go1.23 use slices.Repeat - rightVal := object.Value(right).(object.Integer).Value + rightVal := right.(object.Integer).Value if rightVal < 0 { return s.NewError("right operand of * on arrays must be a positive integer") } @@ -1000,8 +1001,8 @@ func (s *State) evalArrayInfixExpression(operator token.Type, left, right object } func evalMapInfixExpression(operator token.Type, left, right object.Object) object.Object { - leftMap := object.Value(left).(object.Map) - rightMap := object.Value(right).(object.Map) + leftMap := left.(object.Map) + rightMap := right.(object.Map) switch operator { case token.PLUS: // concat / append return leftMap.Append(rightMap) @@ -1016,8 +1017,8 @@ func evalMapInfixExpression(operator token.Type, left, right object.Object) obje // https://github.com/golang/go/issues/48522 // would need getters/setters which is not very go idiomatic. func (s *State) evalIntegerInfixExpression(operator token.Type, left, right object.Object) object.Object { - leftVal := object.Value(left).(object.Integer).Value - rightVal := object.Value(right).(object.Integer).Value + leftVal := left.(object.Integer).Value + rightVal := right.(object.Integer).Value switch operator { case token.PLUS: @@ -1058,9 +1059,9 @@ func (s *State) evalIntegerInfixExpression(operator token.Type, left, right obje func (s *State) getFloatValue(o object.Object) (float64, *object.Error) { switch o.Type() { case object.INTEGER: - return float64(object.Value(o).(object.Integer).Value), nil + return float64(o.(object.Integer).Value), nil case object.FLOAT: - return object.Value(o).(object.Float).Value, nil + return o.(object.Float).Value, nil default: e := s.NewError("not converting to float: " + o.Type().String()) return math.NaN(), &e diff --git a/eval/eval_api.go b/eval/eval_api.go index f6d0bd13..59a45341 100644 --- a/eval/eval_api.go +++ b/eval/eval_api.go @@ -139,7 +139,10 @@ func (s *State) Eval(node any) object.Object { if returnValue.ControlType != token.RETURN { return s.Errorf("unexpected control type %v outside of for loops", returnValue.ControlType) } - return returnValue.Value + result = returnValue.Value + } + if refValue, ok := result.(object.Reference); ok { + return object.Value(refValue) } return result } diff --git a/eval/eval_test.go b/eval/eval_test.go index e0fe9d4d..ffeb3123 100644 --- a/eval/eval_test.go +++ b/eval/eval_test.go @@ -106,7 +106,7 @@ func testEval(t *testing.T, input string) object.Object { } func testIntegerObject(t *testing.T, obj object.Object, expected int64) bool { - result, ok := object.Value(obj).(object.Integer) + result, ok := obj.(object.Integer) if !ok { t.Errorf("object is not Integer. got=%T (%+v)", obj, obj) return false @@ -175,7 +175,7 @@ func TestBangOperator(t *testing.T) { } func testBooleanObject(t *testing.T, obj object.Object, expected bool) { - result, ok := object.Value(obj).(object.Boolean) + result, ok := obj.(object.Boolean) if !ok { t.Errorf("object is not Boolean. got=%T (%+v)", obj, obj) return diff --git a/extensions/extension.go b/extensions/extension.go index 6766ff12..f024a9ca 100644 --- a/extensions/extension.go +++ b/extensions/extension.go @@ -294,7 +294,7 @@ func createStrFunctions() { if a.Type() != object.STRING { strs[i] = a.Inspect() } else { - strs[i] = object.Value(a).(object.String).Value + strs[i] = a.(object.String).Value } totalLen += len(strs[i]) + sepLen } @@ -353,14 +353,14 @@ func createMisc() { case object.NIL: return object.Integer{Value: 0} case object.BOOLEAN: - if object.Value(o).(object.Boolean).Value { + if o.(object.Boolean).Value { return object.Integer{Value: 1} } return object.Integer{Value: 0} case object.FLOAT: - return object.Integer{Value: int64(object.Value(o).(object.Float).Value)} + return object.Integer{Value: int64(o.(object.Float).Value)} case object.STRING: - i, serr := strconv.ParseInt(object.Value(o).(object.String).Value, 0, 64) + i, serr := strconv.ParseInt(o.(object.String).Value, 0, 64) if serr != nil { return s.Error(serr) } diff --git a/object/object.go b/object/object.go index 35d1bcc6..b856a0d7 100644 --- a/object/object.go +++ b/object/object.go @@ -1057,7 +1057,7 @@ func (r Reference) Value() Object { } func (r Reference) Unwrap(str bool) any { return r.Value().Unwrap(str) } -func (r Reference) Type() Type { return r.Value().Type() } +func (r Reference) Type() Type { return REFERENCE } func (r Reference) Inspect() string { return r.Value().Inspect() } func (r Reference) JSON(w io.Writer) error { return r.Value().JSON(w) }