Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto memoization, "Share" button on wasm/online #86

Merged
merged 12 commits into from
Aug 2, 2024
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ print, log

macros and more all the time (like canonical reformat using `grol -format` and wasm/online version etc)

automatic memoization

See also [sample.gr](examples/sample.gr) and others in that folder, that you can run with
```
gorepl examples/*.gr
Expand Down
41 changes: 33 additions & 8 deletions eval/eval.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package eval

import (
"bytes"
"fmt"
"io"
"math"
Expand All @@ -14,13 +15,19 @@ import (
)

type State struct {
env *object.Environment
Out io.Writer
NoLog bool // turn log() into print() (for EvalString)
env *object.Environment
Out io.Writer
LogOut io.Writer
NoLog bool // turn log() into print() (for EvalString)
cache Cache
}

func NewState() *State {
return &State{env: object.NewEnvironment(), Out: os.Stdout}
return &State{env: object.NewEnvironment(), Out: os.Stdout, LogOut: os.Stdout, cache: NewCache()}
}

func (s *State) ResetCache() {
s.cache = NewCache()
}

// Forward to env to count the number of bindings. Used mostly to know if there are any macros.
Expand Down Expand Up @@ -150,7 +157,9 @@ func (s *State) evalInternal(node any) object.Object {
case *ast.FunctionLiteral:
params := node.Parameters
body := node.Body
return object.Function{Parameters: params, Env: s.env, Body: body}
fn := object.Function{Parameters: params, Env: s.env, Body: body}
fn.SetCacheKey() // sets cache key
return fn
case *ast.CallExpression:
f := s.evalInternal(node.Function)
name := node.Function.Value().Literal()
Expand Down Expand Up @@ -243,14 +252,17 @@ func (s *State) evalBuiltin(node *ast.Builtin) object.Object {
}
doLog := node.Type() != token.PRINT
if s.NoLog && doLog {
doLog = false
buf.WriteRune('\n') // log() has a implicit newline when using log.Xxx, print() doesn't.
}
if doLog {
if doLog && !s.NoLog {
// Consider passing the arguments to log instead of making a string concatenation.
log.Printf("%s", buf.String())
} else {
_, err := s.Out.Write([]byte(buf.String()))
where := s.Out
if doLog {
where = s.LogOut
}
_, err := where.Write([]byte(buf.String()))
if err != nil {
log.Warnf("print: %v", err)
}
Expand Down Expand Up @@ -323,15 +335,28 @@ func (s *State) applyFunction(name string, fn object.Object, args []object.Objec
if !ok {
return object.Error{Value: "<not a function: " + fn.Type().String() + ":" + fn.Inspect() + ">"}
}
if v, output, ok := s.cache.Get(function.CacheKey, args); ok {
log.Debugf("Cache hit for %s %v", function.CacheKey, args)
_, _ = s.Out.Write(output)
return v
}
nenv, oerr := extendFunctionEnv(name, function, args)
if oerr != nil {
return *oerr
}
curState := s.env
s.env = nenv
oldOut := s.Out
buf := bytes.Buffer{}
s.Out = &buf
res := s.Eval(function.Body) // Need to have the return value unwrapped. Fixes bug #46
// restore the previous env/state.
s.env = curState
s.Out = oldOut
output := buf.Bytes()
_, _ = s.Out.Write(output)
s.cache.Set(function.CacheKey, args, res, output)
log.Debugf("Cache miss for %s %v", function.CacheKey, args)
return res
}

Expand Down
2 changes: 1 addition & 1 deletion eval/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func TestEvalIntegerExpression(t *testing.T) {
input string
expected int64
}{
{`f=func(x) {len(x)}; f([1,2,3])`, 3},
{"(3)\n(4)", 4}, // expression on new line should be... new.
{"5 // is 5", 5},
{"10", 10},
Expand Down Expand Up @@ -56,7 +57,6 @@ func(n) {
}(5)
`, 120},
}

for i, tt := range tests {
evaluated := testEval(t, tt.input)
r := testIntegerObject(t, evaluated, tt.expected)
Expand Down
54 changes: 54 additions & 0 deletions eval/memo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package eval

import (
"grol.io/grol/object"
)

const MaxArgs = 4

type CacheKey struct {
Fn string
Args [MaxArgs]object.Object
}

type CacheValue struct {
Result object.Object
Output []byte
}

type Cache map[CacheKey]CacheValue

func NewCache() Cache {
return make(Cache)
}

func (c Cache) Get(fn string, args []object.Object) (object.Object, []byte, bool) {
if len(args) > MaxArgs {
return nil, nil, false
}
key := CacheKey{Fn: fn}
for i, v := range args {
// Can't hash functions, arrays, maps arguments (yet).
if !object.Hashable(v) {
return nil, nil, false
}
key.Args[i] = v
}
result, ok := c[key]
return result.Result, result.Output, ok
}

func (c Cache) Set(fn string, args []object.Object, result object.Object, output []byte) {
if len(args) > MaxArgs {
return
}
key := CacheKey{Fn: fn}
for i, v := range args {
// Can't hash functions arguments (yet).
if !object.Hashable(v) {
return
}
key.Args[i] = v
}
c[key] = CacheValue{Result: result, Output: output}
}
12 changes: 12 additions & 0 deletions examples/fib.gr
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fib = func(x) {
if x <= 0 {
return 0
}
if x == 1 {
return 1
}
fib(x - 1) + fib(x - 2)
}
r = fib(35)
log("fib(35) =", r)
r
24 changes: 22 additions & 2 deletions main_test.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,19 @@ stdout '<err: <identifier not found: foo>>'
# sample_test.gr
grol sample_test.gr
!stderr 'Errors'
cmp stdout sample_test_stdout.gr
cmp stdout sample_test_stdout
stderr 'I] Running sample_test.gr'
stderr 'called fact 5'
stderr 'called fact 1'
stderr 'I] All done'

# fib_50.gr
grol fib_50.gr
!stderr 'Errors'
cmp stdout fib50_stdout
stderr 'I] Running fib_50.gr'
stderr 'I] All done'

# Bug repro, return aborts the whole program
grol -c 'f=func(){return 1;2};log(f());f();3'
stdout '^1\n3$'
Expand Down Expand Up @@ -86,8 +93,21 @@ first(m["key"]) // get the value from key from map, which is an array, and the f

// ^^^ gorepl sample.gr should output 120

-- sample_test_stdout.gr --
-- fib_50.gr --
fib = func(x) {
if (x == 0) {
return 0
}
if (x == 1) {
return 1
}
fib(x - 1) + fib(x - 2)
}
fib(50)
-- sample_test_stdout --
macro test: greater
m is: {73:29,"key":[120,"abc",73]} .
Outputting a smiley: 😀
120
-- fib50_stdout --
12586269025
25 changes: 22 additions & 3 deletions object/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ type Number interface {
}
*/

func Hashable(o Object) bool {
switch o.Type() { //nolint:exhaustive // We have all the types that are hashable + default for the others.
case INTEGER, FLOAT, BOOLEAN, NIL, ERROR, RETURN, QUOTE, STRING:
return true
default:
return false
}
}

func NativeBoolToBooleanObject(input bool) Boolean {
if input {
return TRUE
Expand Down Expand Up @@ -170,6 +179,7 @@ func (rv ReturnValue) Inspect() string { return rv.Value.Inspect() }

type Function struct {
Parameters []ast.Node
CacheKey string
Body *ast.Statements
Env *Environment
}
Expand All @@ -186,17 +196,26 @@ func WriteStrings(out *strings.Builder, list []Object, before, sep, after string
}

func (f Function) Type() Type { return FUNC }
func (f Function) Inspect() string {
out := strings.Builder{}

// Must be called after the function is fully initialized.
func (f *Function) SetCacheKey() string {
out := strings.Builder{}
out.WriteString("func")
out.WriteString("(")
ps := &ast.PrintState{Out: &out, Compact: true}
ps.ComaList(f.Parameters)
out.WriteString("){")
f.Body.PrettyPrint(ps)
out.WriteString("}")
return out.String()
f.CacheKey = out.String()
return f.CacheKey
}

func (f Function) Inspect() string {
if f.CacheKey == "" {
panic("CacheKey not set")
}
return f.CacheKey
}

type Array struct {
Expand Down
1 change: 1 addition & 0 deletions repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func EvalString(what string) (res string, errs []string, formatted string) {
macroState := eval.NewState()
out := &strings.Builder{}
s.Out = out
s.LogOut = out
s.NoLog = true
_, errs, formatted = EvalOne(s, macroState, what, out,
Options{All: true, ShowEval: true, NoColor: true, Compact: CompactEvalString})
Expand Down
37 changes: 37 additions & 0 deletions repl/repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,43 @@ Factorial of 5 is 120` + " \n120\n" // there is an extra space before \n that vs
}
}

func TestEvalMemoPrint(t *testing.T) {
s := `
fact=func(n) {
log("logger fact", n) // should be actual executions of the function only
print("print fact", n, ".\n") // should get recorded
if (n<=1) {
return 1
}
n*self(n-1)
}
fact(3)
print("---\n")
result = fact(5)
print("Factorial of 5 is", result, ".\n") // print to stdout
result`
expected := `logger fact 3
logger fact 2
logger fact 1
print fact 3 .
print fact 2 .
print fact 1 .
---
logger fact 5
logger fact 4
print fact 5 .
print fact 4 .
print fact 3 .
print fact 2 .
print fact 1 .
Factorial of 5 is 120 .
120
`
if got, errs, _ := repl.EvalString(s); got != expected || len(errs) > 0 {
t.Errorf("EvalString() got %v\n---\n%s\n---want---\n%s\n---", errs, got, expected)
}
}

func TestEvalString50(t *testing.T) {
s := `
fact=func(n) { // function
Expand Down
18 changes: 18 additions & 0 deletions wasm/grol_wasm.html
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@
</div>
<div>
Hit enter or click <button onClick="run();" id="runButton" disabled>Run</button> (will also format the code, also try <input type="checkbox" id="compact">compact)
<button id="addParamButton">Share</button>
<script>
document.getElementById('addParamButton').addEventListener('click', () => {
const paramValue = document.getElementById('input').value
const url = new URL(window.location)
url.searchParams.set('c', paramValue)
window.history.pushState({}, '', url)
});
</script>

</div>
<div>
<label for="output">Result:</label>
Expand All @@ -129,3 +139,11 @@
<textarea id="errors" rows="1" cols="80" class="error-textarea"></textarea>
</div>
<div id="version">GROL</div>
<script>
const urlParams = new URLSearchParams(window.location.search)
const paramValue = urlParams.get('c')
console.log('paramValue', paramValue)
if (paramValue) {
document.getElementById('input').value = decodeURIComponent(paramValue)
}
</script>