diff --git a/README.md b/README.md index 6df4fa17..7a04fa79 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,8 @@ print, log macros and more all the time (like canonical reformat using `grol -format` and wasm/online version etc) +automatic memoization + See also [sample.gr](examples/sample.gr) and others in that folder, that you can run with ``` gorepl examples/*.gr diff --git a/eval/eval.go b/eval/eval.go index 22062ac9..3a572bb1 100644 --- a/eval/eval.go +++ b/eval/eval.go @@ -1,6 +1,7 @@ package eval import ( + "bytes" "fmt" "io" "math" @@ -14,13 +15,19 @@ import ( ) type State struct { - env *object.Environment - Out io.Writer - NoLog bool // turn log() into print() (for EvalString) + env *object.Environment + Out io.Writer + LogOut io.Writer + NoLog bool // turn log() into print() (for EvalString) + cache Cache } func NewState() *State { - return &State{env: object.NewEnvironment(), Out: os.Stdout} + return &State{env: object.NewEnvironment(), Out: os.Stdout, LogOut: os.Stdout, cache: NewCache()} +} + +func (s *State) ResetCache() { + s.cache = NewCache() } // Forward to env to count the number of bindings. Used mostly to know if there are any macros. @@ -150,7 +157,9 @@ func (s *State) evalInternal(node any) object.Object { case *ast.FunctionLiteral: params := node.Parameters body := node.Body - return object.Function{Parameters: params, Env: s.env, Body: body} + fn := object.Function{Parameters: params, Env: s.env, Body: body} + fn.SetCacheKey() // sets cache key + return fn case *ast.CallExpression: f := s.evalInternal(node.Function) name := node.Function.Value().Literal() @@ -243,14 +252,17 @@ func (s *State) evalBuiltin(node *ast.Builtin) object.Object { } doLog := node.Type() != token.PRINT if s.NoLog && doLog { - doLog = false buf.WriteRune('\n') // log() has a implicit newline when using log.Xxx, print() doesn't. } - if doLog { + if doLog && !s.NoLog { // Consider passing the arguments to log instead of making a string concatenation. log.Printf("%s", buf.String()) } else { - _, err := s.Out.Write([]byte(buf.String())) + where := s.Out + if doLog { + where = s.LogOut + } + _, err := where.Write([]byte(buf.String())) if err != nil { log.Warnf("print: %v", err) } @@ -323,15 +335,28 @@ func (s *State) applyFunction(name string, fn object.Object, args []object.Objec if !ok { return object.Error{Value: ""} } + if v, output, ok := s.cache.Get(function.CacheKey, args); ok { + log.Debugf("Cache hit for %s %v", function.CacheKey, args) + _, _ = s.Out.Write(output) + return v + } nenv, oerr := extendFunctionEnv(name, function, args) if oerr != nil { return *oerr } curState := s.env s.env = nenv + oldOut := s.Out + buf := bytes.Buffer{} + s.Out = &buf res := s.Eval(function.Body) // Need to have the return value unwrapped. Fixes bug #46 // restore the previous env/state. s.env = curState + s.Out = oldOut + output := buf.Bytes() + _, _ = s.Out.Write(output) + s.cache.Set(function.CacheKey, args, res, output) + log.Debugf("Cache miss for %s %v", function.CacheKey, args) return res } diff --git a/eval/eval_test.go b/eval/eval_test.go index c1ad2cf2..36549766 100644 --- a/eval/eval_test.go +++ b/eval/eval_test.go @@ -15,6 +15,7 @@ func TestEvalIntegerExpression(t *testing.T) { input string expected int64 }{ + {`f=func(x) {len(x)}; f([1,2,3])`, 3}, {"(3)\n(4)", 4}, // expression on new line should be... new. {"5 // is 5", 5}, {"10", 10}, @@ -56,7 +57,6 @@ func(n) { }(5) `, 120}, } - for i, tt := range tests { evaluated := testEval(t, tt.input) r := testIntegerObject(t, evaluated, tt.expected) diff --git a/eval/memo.go b/eval/memo.go new file mode 100644 index 00000000..53c2166b --- /dev/null +++ b/eval/memo.go @@ -0,0 +1,54 @@ +package eval + +import ( + "grol.io/grol/object" +) + +const MaxArgs = 4 + +type CacheKey struct { + Fn string + Args [MaxArgs]object.Object +} + +type CacheValue struct { + Result object.Object + Output []byte +} + +type Cache map[CacheKey]CacheValue + +func NewCache() Cache { + return make(Cache) +} + +func (c Cache) Get(fn string, args []object.Object) (object.Object, []byte, bool) { + if len(args) > MaxArgs { + return nil, nil, false + } + key := CacheKey{Fn: fn} + for i, v := range args { + // Can't hash functions, arrays, maps arguments (yet). + if !object.Hashable(v) { + return nil, nil, false + } + key.Args[i] = v + } + result, ok := c[key] + return result.Result, result.Output, ok +} + +func (c Cache) Set(fn string, args []object.Object, result object.Object, output []byte) { + if len(args) > MaxArgs { + return + } + key := CacheKey{Fn: fn} + for i, v := range args { + // Can't hash functions arguments (yet). + if !object.Hashable(v) { + return + } + key.Args[i] = v + } + c[key] = CacheValue{Result: result, Output: output} +} diff --git a/examples/fib.gr b/examples/fib.gr new file mode 100644 index 00000000..957d1ac6 --- /dev/null +++ b/examples/fib.gr @@ -0,0 +1,12 @@ +fib = func(x) { + if x <= 0 { + return 0 + } + if x == 1 { + return 1 + } + fib(x - 1) + fib(x - 2) +} +r = fib(35) +log("fib(35) =", r) +r diff --git a/main_test.txtar b/main_test.txtar index ee7fbe5f..fe05fa39 100644 --- a/main_test.txtar +++ b/main_test.txtar @@ -31,12 +31,19 @@ stdout '>' # sample_test.gr grol sample_test.gr !stderr 'Errors' -cmp stdout sample_test_stdout.gr +cmp stdout sample_test_stdout stderr 'I] Running sample_test.gr' stderr 'called fact 5' stderr 'called fact 1' stderr 'I] All done' +# fib_50.gr +grol fib_50.gr +!stderr 'Errors' +cmp stdout fib50_stdout +stderr 'I] Running fib_50.gr' +stderr 'I] All done' + # Bug repro, return aborts the whole program grol -c 'f=func(){return 1;2};log(f());f();3' stdout '^1\n3$' @@ -86,8 +93,21 @@ first(m["key"]) // get the value from key from map, which is an array, and the f // ^^^ gorepl sample.gr should output 120 --- sample_test_stdout.gr -- +-- fib_50.gr -- +fib = func(x) { + if (x == 0) { + return 0 + } + if (x == 1) { + return 1 + } + fib(x - 1) + fib(x - 2) +} +fib(50) +-- sample_test_stdout -- macro test: greater m is: {73:29,"key":[120,"abc",73]} . Outputting a smiley: 😀 120 +-- fib50_stdout -- +12586269025 diff --git a/object/object.go b/object/object.go index c6260c92..0e95c054 100644 --- a/object/object.go +++ b/object/object.go @@ -48,6 +48,15 @@ type Number interface { } */ +func Hashable(o Object) bool { + switch o.Type() { //nolint:exhaustive // We have all the types that are hashable + default for the others. + case INTEGER, FLOAT, BOOLEAN, NIL, ERROR, RETURN, QUOTE, STRING: + return true + default: + return false + } +} + func NativeBoolToBooleanObject(input bool) Boolean { if input { return TRUE @@ -170,6 +179,7 @@ func (rv ReturnValue) Inspect() string { return rv.Value.Inspect() } type Function struct { Parameters []ast.Node + CacheKey string Body *ast.Statements Env *Environment } @@ -186,9 +196,10 @@ func WriteStrings(out *strings.Builder, list []Object, before, sep, after string } func (f Function) Type() Type { return FUNC } -func (f Function) Inspect() string { - out := strings.Builder{} +// Must be called after the function is fully initialized. +func (f *Function) SetCacheKey() string { + out := strings.Builder{} out.WriteString("func") out.WriteString("(") ps := &ast.PrintState{Out: &out, Compact: true} @@ -196,7 +207,15 @@ func (f Function) Inspect() string { out.WriteString("){") f.Body.PrettyPrint(ps) out.WriteString("}") - return out.String() + f.CacheKey = out.String() + return f.CacheKey +} + +func (f Function) Inspect() string { + if f.CacheKey == "" { + panic("CacheKey not set") + } + return f.CacheKey } type Array struct { diff --git a/repl/repl.go b/repl/repl.go index 7d7ffbe4..d6d3d9b8 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -68,6 +68,7 @@ func EvalString(what string) (res string, errs []string, formatted string) { macroState := eval.NewState() out := &strings.Builder{} s.Out = out + s.LogOut = out s.NoLog = true _, errs, formatted = EvalOne(s, macroState, what, out, Options{All: true, ShowEval: true, NoColor: true, Compact: CompactEvalString}) diff --git a/repl/repl_test.go b/repl/repl_test.go index 1ef52c10..59ce93ea 100644 --- a/repl/repl_test.go +++ b/repl/repl_test.go @@ -29,6 +29,43 @@ Factorial of 5 is 120` + " \n120\n" // there is an extra space before \n that vs } } +func TestEvalMemoPrint(t *testing.T) { + s := ` +fact=func(n) { + log("logger fact", n) // should be actual executions of the function only + print("print fact", n, ".\n") // should get recorded + if (n<=1) { + return 1 + } + n*self(n-1) +} +fact(3) +print("---\n") +result = fact(5) +print("Factorial of 5 is", result, ".\n") // print to stdout +result` + expected := `logger fact 3 +logger fact 2 +logger fact 1 +print fact 3 . +print fact 2 . +print fact 1 . +--- +logger fact 5 +logger fact 4 +print fact 5 . +print fact 4 . +print fact 3 . +print fact 2 . +print fact 1 . +Factorial of 5 is 120 . +120 +` + if got, errs, _ := repl.EvalString(s); got != expected || len(errs) > 0 { + t.Errorf("EvalString() got %v\n---\n%s\n---want---\n%s\n---", errs, got, expected) + } +} + func TestEvalString50(t *testing.T) { s := ` fact=func(n) { // function diff --git a/wasm/grol_wasm.html b/wasm/grol_wasm.html index 33e3532c..1d84aa3d 100644 --- a/wasm/grol_wasm.html +++ b/wasm/grol_wasm.html @@ -119,6 +119,16 @@
Hit enter or click (will also format the code, also try compact) + + +
@@ -129,3 +139,11 @@
GROL
+