Skip to content

Commit

Permalink
go/ssa/interp: implement min/max builtins
Browse files Browse the repository at this point in the history
Updates golang/go#59488.

Change-Id: I68c90ddf0f9dea2c6506b9ab43beb522cbdf5fdd
Reviewed-on: https://go-review.googlesource.com/c/tools/+/497516
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
gopls-CI: kokoro <noreply+kokoro@google.com>
Reviewed-by: Tim King <taking@google.com>
  • Loading branch information
mdempsky committed May 24, 2023
1 parent 9c97539 commit a12e1a6
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 0 deletions.
5 changes: 5 additions & 0 deletions go/ssa/interp/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func init() {
"bytes.IndexByte": ext۰bytes۰IndexByte,
"fmt.Sprint": ext۰fmt۰Sprint,
"math.Abs": ext۰math۰Abs,
"math.Copysign": ext۰math۰Copysign,
"math.Exp": ext۰math۰Exp,
"math.Float32bits": ext۰math۰Float32bits,
"math.Float32frombits": ext۰math۰Float32frombits,
Expand Down Expand Up @@ -158,6 +159,10 @@ func ext۰math۰Abs(fr *frame, args []value) value {
return math.Abs(args[0].(float64))
}

func ext۰math۰Copysign(fr *frame, args []value) value {
return math.Copysign(args[0].(float64), args[1].(float64))
}

func ext۰math۰Exp(fr *frame, args []value) value {
return math.Exp(args[0].(float64))
}
Expand Down
12 changes: 12 additions & 0 deletions go/ssa/interp/interp_go121_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.21
// +build go1.21

package interp_test

func init() {
testdataTests = append(testdataTests, "minmax.go")
}
91 changes: 91 additions & 0 deletions go/ssa/interp/ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,11 @@ func callBuiltin(caller *frame, callpos token.Pos, fn *ssa.Builtin, args []value
panic(fmt.Sprintf("cap: illegal operand: %T", x))
}

case "min":
return foldLeft(min, args)
case "max":
return foldLeft(max, args)

case "real":
switch c := args[0].(type) {
case complex64:
Expand Down Expand Up @@ -1426,3 +1431,89 @@ func checkInterface(i *interpreter, itype *types.Interface, x iface) string {
}
return "" // ok
}

func foldLeft(op func(value, value) value, args []value) value {
x := args[0]
for _, arg := range args[1:] {
x = op(x, arg)
}
return x
}

func min(x, y value) value {
switch x := x.(type) {
case float32:
return fmin(x, y.(float32))
case float64:
return fmin(x, y.(float64))
}

// return (y < x) ? y : x
if binop(token.LSS, nil, y, x).(bool) {
return y
}
return x
}

func max(x, y value) value {
switch x := x.(type) {
case float32:
return fmax(x, y.(float32))
case float64:
return fmax(x, y.(float64))
}

// return (y > x) ? y : x
if binop(token.GTR, nil, y, x).(bool) {
return y
}
return x
}

// copied from $GOROOT/src/runtime/minmax.go

type floaty interface{ ~float32 | ~float64 }

func fmin[F floaty](x, y F) F {
if y != y || y < x {
return y
}
if x != x || x < y || x != 0 {
return x
}
// x and y are both ±0
// if either is -0, return -0; else return +0
return forbits(x, y)
}

func fmax[F floaty](x, y F) F {
if y != y || y > x {
return y
}
if x != x || x > y || x != 0 {
return x
}
// x and y are both ±0
// if both are -0, return -0; else return +0
return fandbits(x, y)
}

func forbits[F floaty](x, y F) F {
switch unsafe.Sizeof(x) {
case 4:
*(*uint32)(unsafe.Pointer(&x)) |= *(*uint32)(unsafe.Pointer(&y))
case 8:
*(*uint64)(unsafe.Pointer(&x)) |= *(*uint64)(unsafe.Pointer(&y))
}
return x
}

func fandbits[F floaty](x, y F) F {
switch unsafe.Sizeof(x) {
case 4:
*(*uint32)(unsafe.Pointer(&x)) &= *(*uint32)(unsafe.Pointer(&y))
case 8:
*(*uint64)(unsafe.Pointer(&x)) &= *(*uint64)(unsafe.Pointer(&y))
}
return x
}
118 changes: 118 additions & 0 deletions go/ssa/interp/testdata/minmax.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
"fmt"
"math"
)

func main() {
TestMinFloat()
TestMaxFloat()
TestMinMaxInt()
TestMinMaxUint8()
TestMinMaxString()
}

func errorf(format string, args ...any) { panic(fmt.Sprintf(format, args...)) }
func fatalf(format string, args ...any) { panic(fmt.Sprintf(format, args...)) }

// derived from $GOROOT/src/runtime/minmax_test.go

var (
zero = math.Copysign(0, +1)
negZero = math.Copysign(0, -1)
inf = math.Inf(+1)
negInf = math.Inf(-1)
nan = math.NaN()
)

var tests = []struct{ min, max float64 }{
{1, 2},
{-2, 1},
{negZero, zero},
{zero, inf},
{negInf, zero},
{negInf, inf},
{1, inf},
{negInf, 1},
}

var all = []float64{1, 2, -1, -2, zero, negZero, inf, negInf, nan}

func eq(x, y float64) bool {
return x == y && math.Signbit(x) == math.Signbit(y)
}

func TestMinFloat() {
for _, tt := range tests {
if z := min(tt.min, tt.max); !eq(z, tt.min) {
errorf("min(%v, %v) = %v, want %v", tt.min, tt.max, z, tt.min)
}
if z := min(tt.max, tt.min); !eq(z, tt.min) {
errorf("min(%v, %v) = %v, want %v", tt.max, tt.min, z, tt.min)
}
}
for _, x := range all {
if z := min(nan, x); !math.IsNaN(z) {
errorf("min(%v, %v) = %v, want %v", nan, x, z, nan)
}
if z := min(x, nan); !math.IsNaN(z) {
errorf("min(%v, %v) = %v, want %v", nan, x, z, nan)
}
}
}

func TestMaxFloat() {
for _, tt := range tests {
if z := max(tt.min, tt.max); !eq(z, tt.max) {
errorf("max(%v, %v) = %v, want %v", tt.min, tt.max, z, tt.max)
}
if z := max(tt.max, tt.min); !eq(z, tt.max) {
errorf("max(%v, %v) = %v, want %v", tt.max, tt.min, z, tt.max)
}
}
for _, x := range all {
if z := max(nan, x); !math.IsNaN(z) {
errorf("min(%v, %v) = %v, want %v", nan, x, z, nan)
}
if z := max(x, nan); !math.IsNaN(z) {
errorf("min(%v, %v) = %v, want %v", nan, x, z, nan)
}
}
}

// testMinMax tests that min/max behave correctly on every pair of
// values in vals.
//
// vals should be a sequence of values in strictly ascending order.
func testMinMax[T int | uint8 | string](vals ...T) {
for i, x := range vals {
for _, y := range vals[i+1:] {
if !(x < y) {
fatalf("values out of order: !(%v < %v)", x, y)
}

if z := min(x, y); z != x {
errorf("min(%v, %v) = %v, want %v", x, y, z, x)
}
if z := min(y, x); z != x {
errorf("min(%v, %v) = %v, want %v", y, x, z, x)
}

if z := max(x, y); z != y {
errorf("max(%v, %v) = %v, want %v", x, y, z, y)
}
if z := max(y, x); z != y {
errorf("max(%v, %v) = %v, want %v", y, x, z, y)
}
}
}
}

func TestMinMaxInt() { testMinMax[int](-7, 0, 9) }
func TestMinMaxUint8() { testMinMax[uint8](0, 1, 2, 4, 7) }
func TestMinMaxString() { testMinMax[string]("a", "b", "c") }
2 changes: 2 additions & 0 deletions go/ssa/interp/testdata/src/math/math.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package math

func Copysign(float64, float64) float64

func NaN() float64

func Inf(int) float64
Expand Down

0 comments on commit a12e1a6

Please sign in to comment.