Skip to content

Commit

Permalink
cherry pick pingcap#11343 to release-2.1
Browse files Browse the repository at this point in the history
Signed-off-by: sre-bot <sre-bot@pingcap.com>
  • Loading branch information
AndrewDi authored and sre-bot committed Apr 8, 2020
1 parent 4f03354 commit 2fcd794
Show file tree
Hide file tree
Showing 4 changed files with 425 additions and 0 deletions.
36 changes: 36 additions & 0 deletions cmd/explaintest/r/tpch.result
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ s_name,
p_partkey
limit 100;
id count task operator info
<<<<<<< HEAD
Projection_34 100.00 root tpch.supplier.s_acctbal, tpch.supplier.s_name, tpch.nation.n_name, tpch.part.p_partkey, tpch.part.p_mfgr, tpch.supplier.s_address, tpch.supplier.s_phone, tpch.supplier.s_comment
└─TopN_37 100.00 root tpch.supplier.s_acctbal:desc, tpch.nation.n_name:asc, tpch.supplier.s_name:asc, tpch.part.p_partkey:asc, offset:0, count:100
└─HashRightJoin_39 125109.42 root inner join, inner:HashLeftJoin_44, equal:[eq(tpch.part.p_partkey, tpch.partsupp.ps_partkey) eq(tpch.partsupp.ps_supplycost, min(ps_supplycost))]
Expand Down Expand Up @@ -215,6 +216,41 @@ Projection_34 100.00 root tpch.supplier.s_acctbal, tpch.supplier.s_name, tpch.na
│ └─TableScan_107 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false
└─TableReader_110 40000000.00 root data:TableScan_109
└─TableScan_109 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false
=======
Projection_37 100.00 root tpch.supplier.s_acctbal, tpch.supplier.s_name, tpch.nation.n_name, tpch.part.p_partkey, tpch.part.p_mfgr, tpch.supplier.s_address, tpch.supplier.s_phone, tpch.supplier.s_comment
└─TopN_40 100.00 root tpch.supplier.s_acctbal:desc, tpch.nation.n_name:asc, tpch.supplier.s_name:asc, tpch.part.p_partkey:asc, offset:0, count:100
└─HashRightJoin_45 155496.00 root inner join, inner:HashLeftJoin_51, equal:[eq(tpch.part.p_partkey, tpch.partsupp.ps_partkey) eq(tpch.partsupp.ps_supplycost, min(ps_supplycost))]
├─HashLeftJoin_51 155496.00 root inner join, inner:TableReader_74, equal:[eq(tpch.partsupp.ps_partkey, tpch.part.p_partkey)]
│ ├─HashRightJoin_54 8155010.44 root inner join, inner:HashRightJoin_56, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)]
│ │ ├─HashRightJoin_56 100000.00 root inner join, inner:HashRightJoin_62, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)]
│ │ │ ├─HashRightJoin_62 5.00 root inner join, inner:TableReader_67, equal:[eq(tpch.region.r_regionkey, tpch.nation.n_regionkey)]
│ │ │ │ ├─TableReader_67 1.00 root data:Selection_66
│ │ │ │ │ └─Selection_66 1.00 cop eq(tpch.region.r_name, "ASIA")
│ │ │ │ │ └─TableScan_65 5.00 cop table:region, range:[-inf,+inf], keep order:false
│ │ │ │ └─TableReader_64 25.00 root data:TableScan_63
│ │ │ │ └─TableScan_63 25.00 cop table:nation, range:[-inf,+inf], keep order:false
│ │ │ └─TableReader_69 500000.00 root data:TableScan_68
│ │ │ └─TableScan_68 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false
│ │ └─TableReader_71 40000000.00 root data:TableScan_70
│ │ └─TableScan_70 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false
│ └─TableReader_74 155496.00 root data:Selection_73
│ └─Selection_73 155496.00 cop eq(tpch.part.p_size, 30), like(tpch.part.p_type, "%STEEL", 92)
│ └─TableScan_72 10000000.00 cop table:part, range:[-inf,+inf], keep order:false
└─Selection_75 6524008.35 root not(isnull(min(ps_supplycost)))
└─HashAgg_78 8155010.44 root group by:tpch.partsupp.ps_partkey, funcs:min(tpch.partsupp.ps_supplycost), firstrow(tpch.partsupp.ps_partkey)
└─HashRightJoin_82 8155010.44 root inner join, inner:HashRightJoin_84, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)]
├─HashRightJoin_84 100000.00 root inner join, inner:HashRightJoin_90, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)]
│ ├─HashRightJoin_90 5.00 root inner join, inner:TableReader_95, equal:[eq(tpch.region.r_regionkey, tpch.nation.n_regionkey)]
│ │ ├─TableReader_95 1.00 root data:Selection_94
│ │ │ └─Selection_94 1.00 cop eq(tpch.region.r_name, "ASIA")
│ │ │ └─TableScan_93 5.00 cop table:region, range:[-inf,+inf], keep order:false
│ │ └─TableReader_92 25.00 root data:TableScan_91
│ │ └─TableScan_91 25.00 cop table:nation, range:[-inf,+inf], keep order:false
│ └─TableReader_97 500000.00 root data:TableScan_96
│ └─TableScan_96 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false
└─TableReader_99 40000000.00 root data:TableScan_98
└─TableScan_98 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false
>>>>>>> 829ba98... expression: remove the NotNullFlag for aggregation func MAX/MIN when inferring type (#11343)
/*
Q3 Shipping Priority Query
This query retrieves the 10 unshipped orders with the highest value.
Expand Down
332 changes: 332 additions & 0 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,332 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggregation

import (
"bytes"
"math"
"strings"

"github.com/cznic/mathutil"
"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
)

// baseFuncDesc describes an function signature, only used in planner.
type baseFuncDesc struct {
// Name represents the function name.
Name string
// Args represents the arguments of the function.
Args []expression.Expression
// RetTp represents the return type of the function.
RetTp *types.FieldType
}

func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) {
b := baseFuncDesc{Name: strings.ToLower(name), Args: args}
err := b.typeInfer(ctx)
return b, err
}

func (a *baseFuncDesc) equal(ctx sessionctx.Context, other *baseFuncDesc) bool {
if a.Name != other.Name || len(a.Args) != len(other.Args) {
return false
}
for i := range a.Args {
if !a.Args[i].Equal(ctx, other.Args[i]) {
return false
}
}
return true
}

func (a *baseFuncDesc) clone() *baseFuncDesc {
clone := *a
newTp := *a.RetTp
clone.RetTp = &newTp
clone.Args = make([]expression.Expression, len(a.Args))
for i := range a.Args {
clone.Args[i] = a.Args[i].Clone()
}
return &clone
}

// String implements the fmt.Stringer interface.
func (a *baseFuncDesc) String() string {
buffer := bytes.NewBufferString(a.Name)
buffer.WriteString("(")
for i, arg := range a.Args {
buffer.WriteString(arg.String())
if i+1 != len(a.Args) {
buffer.WriteString(", ")
}
}
buffer.WriteString(")")
return buffer.String()
}

// typeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error {
switch a.Name {
case ast.AggFuncCount:
a.typeInfer4Count(ctx)
case ast.AggFuncSum:
a.typeInfer4Sum(ctx)
case ast.AggFuncAvg:
a.typeInfer4Avg(ctx)
case ast.AggFuncGroupConcat:
a.typeInfer4GroupConcat(ctx)
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow,
ast.WindowFuncFirstValue, ast.WindowFuncLastValue, ast.WindowFuncNthValue:
a.typeInfer4MaxMin(ctx)
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
a.typeInfer4BitFuncs(ctx)
case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank:
a.typeInfer4NumberFuncs()
case ast.WindowFuncCumeDist:
a.typeInfer4CumeDist()
case ast.WindowFuncNtile:
a.typeInfer4Ntile()
case ast.WindowFuncPercentRank:
a.typeInfer4PercentRank()
case ast.WindowFuncLead, ast.WindowFuncLag:
a.typeInfer4LeadLag(ctx)
default:
return errors.Errorf("unsupported agg function: %s", a.Name)
}
return nil
}

func (a *baseFuncDesc) typeInfer4Count(ctx sessionctx.Context) {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.Flen = 21
types.SetBinChsClnFlag(a.RetTp)
}

// typeInfer4Sum should returns a "decimal", otherwise it returns a "double".
// Because child returns integer or decimal type.
func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
switch a.Args[0].GetType().Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, 0
case mysql.TypeNewDecimal:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, a.Args[0].GetType().Decimal
if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale {
a.RetTp.Decimal = mysql.MaxDecimalScale
}
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal
default:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
}
types.SetBinChsClnFlag(a.RetTp)
}

// typeInfer4Avg should returns a "decimal", otherwise it returns a "double".
// Because child returns integer or decimal type.
func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
switch a.Args[0].GetType().Tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal:
a.RetTp = types.NewFieldType(mysql.TypeNewDecimal)
if a.Args[0].GetType().Decimal < 0 {
a.RetTp.Decimal = mysql.MaxDecimalScale
} else {
a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale)
}
a.RetTp.Flen = mysql.MaxDecimalWidth
case mysql.TypeDouble, mysql.TypeFloat:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal
default:
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
}
types.SetBinChsClnFlag(a.RetTp)
}

func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) {
a.RetTp = types.NewFieldType(mysql.TypeVarString)
a.RetTp.Charset, a.RetTp.Collate = charset.GetDefaultCharsetAndCollate()

a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxBlobWidth, 0
// TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i])
}

func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
_, argIsScalaFunc := a.Args[0].(*expression.ScalarFunction)
if argIsScalaFunc && a.Args[0].GetType().Tp == mysql.TypeFloat {
// For scalar function, the result of "float32" is set to the "float64"
// field in the "Datum". If we do not wrap a cast-as-double function on a.Args[0],
// error would happen when extracting the evaluation of a.Args[0] to a ProjectionExec.
tp := types.NewFieldType(mysql.TypeDouble)
tp.Flen, tp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
types.SetBinChsClnFlag(tp)
a.Args[0] = expression.BuildCastFunction(ctx, a.Args[0], tp)
}
a.RetTp = a.Args[0].GetType()
if (a.Name == ast.AggFuncMax || a.Name == ast.AggFuncMin) && a.RetTp.Tp != mysql.TypeBit {
a.RetTp = a.Args[0].GetType().Clone()
a.RetTp.Flag &^= mysql.NotNullFlag
}
if a.RetTp.Tp == mysql.TypeEnum || a.RetTp.Tp == mysql.TypeSet {
a.RetTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength}
}
}

func (a *baseFuncDesc) typeInfer4BitFuncs(ctx sessionctx.Context) {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.Flen = 21
types.SetBinChsClnFlag(a.RetTp)
a.RetTp.Flag |= mysql.UnsignedFlag | mysql.NotNullFlag
// TODO: a.Args[0] = expression.WrapWithCastAsInt(ctx, a.Args[0])
}

func (a *baseFuncDesc) typeInfer4NumberFuncs() {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.Flen = 21
types.SetBinChsClnFlag(a.RetTp)
}

func (a *baseFuncDesc) typeInfer4CumeDist() {
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
}

func (a *baseFuncDesc) typeInfer4Ntile() {
a.RetTp = types.NewFieldType(mysql.TypeLonglong)
a.RetTp.Flen = 21
types.SetBinChsClnFlag(a.RetTp)
a.RetTp.Flag |= mysql.UnsignedFlag
}

func (a *baseFuncDesc) typeInfer4PercentRank() {
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flag, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
}

func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) {
if len(a.Args) <= 2 {
a.typeInfer4MaxMin(ctx)
} else {
// Merge the type of first and third argument.
a.RetTp = expression.InferType4ControlFuncs(a.Args[0].GetType(), a.Args[2].GetType())
}
}

// GetDefaultValue gets the default value when the function's input is null.
// According to MySQL, default values of the function are listed as follows:
// e.g.
// Table t which is empty:
// +-------+---------+---------+
// | Table | Field | Type |
// +-------+---------+---------+
// | t | a | int(11) |
// +-------+---------+---------+
//
// Query: `select a, avg(a), sum(a), count(a), bit_xor(a), bit_or(a), bit_and(a), max(a), min(a), group_concat(a) from t;`
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
// | a | avg(a) | sum(a) | count(a) | bit_xor(a) | bit_or(a) | bit_and(a) | max(a) | min(a) | group_concat(a) |
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
// | NULL | NULL | NULL | 0 | 0 | 0 | 18446744073709551615 | NULL | NULL | NULL |
// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+
func (a *baseFuncDesc) GetDefaultValue() (v types.Datum) {
switch a.Name {
case ast.AggFuncCount, ast.AggFuncBitOr, ast.AggFuncBitXor:
v = types.NewIntDatum(0)
case ast.AggFuncFirstRow, ast.AggFuncAvg, ast.AggFuncSum, ast.AggFuncMax,
ast.AggFuncMin, ast.AggFuncGroupConcat:
v = types.Datum{}
case ast.AggFuncBitAnd:
v = types.NewUintDatum(uint64(math.MaxUint64))
}
return
}

// We do not need to wrap cast upon these functions,
// since the EvalXXX method called by the arg is determined by the corresponding arg type.
var noNeedCastAggFuncs = map[string]struct{}{
ast.AggFuncCount: {},
ast.AggFuncMax: {},
ast.AggFuncMin: {},
ast.AggFuncFirstRow: {},
ast.WindowFuncNtile: {},
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
if len(a.Args) == 0 {
return
}
if _, ok := noNeedCastAggFuncs[a.Name]; ok {
return
}
var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression
switch retTp := a.RetTp; retTp.EvalType() {
case types.ETInt:
castFunc = expression.WrapWithCastAsInt
case types.ETReal:
castFunc = expression.WrapWithCastAsReal
case types.ETString:
castFunc = expression.WrapWithCastAsString
case types.ETDecimal:
castFunc = expression.WrapWithCastAsDecimal
case types.ETDatetime, types.ETTimestamp:
castFunc = func(ctx sessionctx.Context, expr expression.Expression) expression.Expression {
return expression.WrapWithCastAsTime(ctx, expr, retTp)
}
case types.ETDuration:
castFunc = expression.WrapWithCastAsDuration
case types.ETJson:
castFunc = expression.WrapWithCastAsJSON
default:
panic("should never happen in baseFuncDesc.WrapCastForAggArgs")
}
for i := range a.Args {
// Do not cast the second args of these functions, as they are simply non-negative numbers.
if i == 1 && (a.Name == ast.WindowFuncLead || a.Name == ast.WindowFuncLag || a.Name == ast.WindowFuncNthValue) {
continue
}
a.Args[i] = castFunc(ctx, a.Args[i])
if a.Name != ast.AggFuncAvg && a.Name != ast.AggFuncSum {
continue
}
// After wrapping cast on the argument, flen etc. may not the same
// as the type of the aggregation function. The following part set
// the type of the argument exactly as the type of the aggregation
// function.
// Note: If the `Tp` of argument is the same as the `Tp` of the
// aggregation function, it will not wrap cast function on it
// internally. The reason of the special handling for `Column` is
// that the `RetType` of `Column` refers to the `infoschema`, so we
// need to set a new variable for it to avoid modifying the
// definition in `infoschema`.
if col, ok := a.Args[i].(*expression.Column); ok {
col.RetType = types.NewFieldType(col.RetType.Tp)
}
// originTp is used when the the `Tp` of column is TypeFloat32 while
// the type of the aggregation function is TypeFloat64.
originTp := a.Args[i].GetType().Tp
*(a.Args[i].GetType()) = *(a.RetTp)
a.Args[i].GetType().Tp = originTp
}
}
Loading

0 comments on commit 2fcd794

Please sign in to comment.