Skip to content

Commit

Permalink
Merge pull request #35 from pingcap/siddontang/cleanup-compare
Browse files Browse the repository at this point in the history
cleanup compare
  • Loading branch information
ngaut committed Sep 7, 2015
2 parents 2396b31 + 213793b commit 6b0dce0
Show file tree
Hide file tree
Showing 13 changed files with 387 additions and 404 deletions.
101 changes: 1 addition & 100 deletions expression/expressions/binop.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,105 +346,6 @@ func (o *BinaryOperation) evalLogicOp(ctx context.Context, args map[interface{}]
}
}

func compareFloatString(a float64, s string) (int, error) {
// MySQL will convert string to a float point value
// MySQL use a very loose conversation, e.g, 123.abc -> 123
// we should do a trade off whether supporting this feature or using a strict mode
// now we use a strict mode
b, err := types.StrToFloat(s)
if err != nil {
return 0, err
}
return types.CompareFloat64(a, b), nil
}

func compareStringFloat(s string, a float64) (int, error) {
n, err := compareFloatString(a, s)
return -n, err
}

// See https://dev.mysql.com/doc/refman/5.7/en/type-conversion.html
func evalCompare(a interface{}, b interface{}) (int, error) {
// TODO: support compare time type with other types
switch x := a.(type) {
case float64:
switch y := b.(type) {
case float64:
return types.CompareFloat64(x, y), nil
case string:
return compareFloatString(x, y)
}
case int64:
switch y := b.(type) {
case int64:
return types.CompareInt64(x, y), nil
case uint64:
return types.CompareInteger(x, y), nil
case string:
return compareFloatString(float64(x), y)
}
case uint64:
switch y := b.(type) {
case uint64:
return types.CompareUint64(x, y), nil
case int64:
return -types.CompareInteger(y, x), nil
case string:
return compareFloatString(float64(x), y)
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
return x.Cmp(y), nil
case string:
f, err := mysql.ConvertToDecimal(y)
if err != nil {
return 0, err
}
return x.Cmp(f), nil
}
case string:
switch y := b.(type) {
case string:
return types.CompareString(x, y), nil
case int64:
return compareStringFloat(x, float64(y))
case uint64:
return compareStringFloat(x, float64(y))
case float64:
return compareStringFloat(x, y)
case mysql.Decimal:
f, err := mysql.ConvertToDecimal(x)
if err != nil {
return 0, err
}
return f.Cmp(y), nil
case mysql.Time:
n, err := y.CompareString(x)
return -n, err
case mysql.Duration:
n, err := y.CompareString(x)
return -n, err
}
case mysql.Time:
switch y := b.(type) {
case mysql.Time:
return x.Compare(y), nil
case string:
return x.CompareString(y)
}
case mysql.Duration:
switch y := b.(type) {
case mysql.Duration:
return x.Compare(y), nil
case string:
return x.CompareString(y)
}
}

return 0, errors.Errorf("invalid compare type %T cmp %T", a, b)
}

// operator: >=, >, <=, <, !=, <>, = <=>, etc.
// see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html
func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
Expand All @@ -459,7 +360,7 @@ func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interfa
return nil, nil
}

n, err := evalCompare(a, b)
n, err := types.Compare(a, b)
if err != nil {
return nil, o.traceErr(err)
}
Expand Down
42 changes: 0 additions & 42 deletions expression/expressions/binop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package expressions

import (
"errors"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/tidb/expression"
Expand Down Expand Up @@ -106,44 +105,6 @@ func (s *testBinOpSuite) TestComparisonOp(c *C) {
c.Assert(v, IsNil)
}

// test evalCompare function
cmpTbl := []struct {
lhs interface{}
rhs interface{}
ret int // 0, 1, -1
}{
{float64(1), float64(1), 0},
{float64(1), "1", 0},
{int64(1), int64(1), 0},
{int64(-1), uint64(1), -1},
{int64(-1), "-1", 0},
{uint64(1), uint64(1), 0},
{uint64(1), int64(-1), 1},
{uint64(1), "1", 0},
{mysql.NewDecimalFromInt(1, 0), mysql.NewDecimalFromInt(1, 0), 0},
{mysql.NewDecimalFromInt(1, 0), "1", 0},
{"1", "1", 0},
{"1", int64(-1), 1},
{"1", float64(2), -1},
{"1", uint64(1), 0},
{"1", mysql.NewDecimalFromInt(1, 0), 0},
{"2011-01-01 11:11:11", mysql.Time{Time: time.Now(), Type: mysql.TypeDatetime, Fsp: 0}, -1},
{"12:00:00", mysql.ZeroDuration, 1},
{mysql.ZeroDuration, mysql.ZeroDuration, 0},
{mysql.Time{Time: time.Now().Add(time.Second * 10), Type: mysql.TypeDatetime, Fsp: 0},
mysql.Time{Time: time.Now(), Type: mysql.TypeDatetime, Fsp: 0}, 1},
}

for _, t := range cmpTbl {
ret, err := evalCompare(t.lhs, t.rhs)
c.Assert(err, IsNil)
c.Assert(ret, Equals, t.ret)

ret, err = evalCompare(t.rhs, t.lhs)
c.Assert(err, IsNil)
c.Assert(ret, Equals, -t.ret)
}

// test error
mock := mockExpr{
isStatic: false,
Expand Down Expand Up @@ -194,9 +155,6 @@ func (s *testBinOpSuite) TestComparisonOp(c *C) {
expr := BinaryOperation{Op: opcode.Plus, L: Value{1}, R: Value{1}}
_, err := expr.evalComparisonOp(nil, nil)
c.Assert(err, NotNil)

_, err = evalCompare(1, 1)
c.Assert(err, NotNil)
}

func (s *testBinOpSuite) TestIdentRelOp(c *C) {
Expand Down
5 changes: 1 addition & 4 deletions expression/expressions/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ func builtinNullIf(args []interface{}, m map[interface{}]interface{}) (interface
return v1, nil
}

// coerce for later eval compare
x, y := types.Coerce(v1, v2)

if n, err := evalCompare(x, y); err != nil || n == 0 {
if n, err := types.Compare(v1, v2); err != nil || n == 0 {
return nil, err
}

Expand Down
12 changes: 10 additions & 2 deletions expression/expressions/builtin_groupby.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,11 @@ func builtinMax(args []interface{}, ctx map[interface{}]interface{}) (v interfac
if max == nil {
max = y
} else {
if types.Compare(max, y) < 0 {
n, err := types.Compare(max, y)
if err != nil {
return nil, errors.Trace(err)
}
if n < 0 {
max = y
}
}
Expand Down Expand Up @@ -288,7 +292,11 @@ func builtinMin(args []interface{}, ctx map[interface{}]interface{}) (v interfac
if min == nil {
min = y
} else {
if types.Compare(min, y) > 0 {
n, err := types.Compare(min, y)
if err != nil {
return nil, errors.Trace(err)
}
if n > 0 {
min = y
}
}
Expand Down
3 changes: 2 additions & 1 deletion expression/expressions/unary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/pingcap/tidb/expression"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)

var _ = Suite(&testUnaryOperationSuite{})
Expand Down Expand Up @@ -111,7 +112,7 @@ func (s *testUnaryOperationSuite) TestUnaryOp(c *C) {
result, err := exprc.Eval(nil, nil)
c.Assert(err, IsNil)

ret, err := evalCompare(result, t.result)
ret, err := types.Compare(result, t.result)
c.Assert(err, IsNil)
c.Assert(ret, Equals, 0)
}
Expand Down
10 changes: 9 additions & 1 deletion plan/plans/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package plans

import (
"fmt"

"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
Expand Down Expand Up @@ -151,7 +153,13 @@ func indexCompare(a interface{}, b interface{}) int {
return -1
}

return types.Compare(a, b)
n, err := types.Compare(a, b)
if err != nil {
// Old compare panics if err, so here we do the same thing now.
// TODO: return err instead of panic.
panic(fmt.Sprintf("should never happend %v", err))
}
return n
}

func (r *indexPlan) doSpan(ctx context.Context, txn kv.Transaction, span *indexSpan, f plan.RowIterFunc) error {
Expand Down
9 changes: 8 additions & 1 deletion plan/plans/orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strings"

"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/expressions"
Expand Down Expand Up @@ -102,7 +103,13 @@ func (t *orderByTable) Less(i, j int) bool {
v1 := t.Rows[i].Key[index]
v2 := t.Rows[j].Key[index]

ret := types.Compare(v1, v2)
ret, err := types.Compare(v1, v2)
if err != nil {
// we just have to log this error and skip it.
// TODO: record this error and handle it out later.
log.Errorf("compare %v %v err %v", v1, v2, err)
}

if !asc {
ret = -ret
}
Expand Down
7 changes: 6 additions & 1 deletion stmt/stmts/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ func updateRecord(ctx context.Context, h int64, data []interface{}, t table.Tabl
continue
}
od := oldData[i]
if types.Compare(d, od) != 0 {
n, err := types.Compare(d, od)
if err != nil {
return errors.Trace(err)
}

if n != 0 {
rowChanged = true
break
}
Expand Down
Loading

0 comments on commit 6b0dce0

Please sign in to comment.