Skip to content

Commit

Permalink
Support rand() as column default expr.
Browse files Browse the repository at this point in the history
  • Loading branch information
CbcWestwolf committed Mar 29, 2022
1 parent fbaaa11 commit 1720c3e
Show file tree
Hide file tree
Showing 4 changed files with 8,083 additions and 7,889 deletions.
16 changes: 16 additions & 0 deletions ddl/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,22 @@ func TestDefaultValueIsString(t *testing.T) {
require.Equal(t, "1", tbl.Meta().Columns[0].DefaultValue)
}

func TestDefaultColumnWithRand(t *testing.T) {
// Related issue: https://github.com/pingcap/tidb/issues/10377
store, clean := testkit.CreateMockStoreWithSchemaLease(t, testLease)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")

tk.MustExec("drop table if exists t, t1, t2")
tk.MustExec("create table t (a int(10) default rand())")
tk.MustExec("create table t1 (c int, c1 double default rand())")
tk.MustExec("create table t2 (c int, c1 double default rand(1))")
tk.MustExec("insert into t1(c) values (1),(2),(3)")

tk.MustExec("alter table t add column b int(10) default rand(2)")
}

func TestChangingDBCharset(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
Expand Down
95 changes: 64 additions & 31 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ const (
tiflashCheckPendingTablesRetry = 7
)

// ColumnDefaultType is the default type for column
type ColumnDefaultType int

const (
// ColumnDefaultLiteral indicates that the default value is a literal like integer, float or string, etc.
ColumnDefaultLiteral ColumnDefaultType = iota
// ColumnDefaultExpression indicates that the default value is a function call like 'current_timestamp(6)'
ColumnDefaultExpression
// ColumnDefaultSequence indicates that the default value is a next value of a sequence
ColumnDefaultSequence
)

func (d *ddl) CreateSchema(ctx sessionctx.Context, schema model.CIStr, charsetInfo *ast.CharsetOpt, placementPolicyRef *model.PolicyRefInfo) (err error) {
dbInfo := &model.DBInfo{Name: schema}
if charsetInfo != nil {
Expand Down Expand Up @@ -1038,65 +1050,86 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
// getDefaultValue will get the default value for column.
// 1: get the expr restored string for the column which uses sequence next value as default value.
// 2: get specific default value for the other column.
func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, bool, error) {
func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, ColumnDefaultType, error) {
tp, fsp := col.FieldType.Tp, col.FieldType.Decimal
if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime {
colDefaultType := ColumnDefaultLiteral
switch x := c.Expr.(type) {
case *ast.FuncCallExpr:
if x.FnName.L == ast.CurrentTimestamp {
if x.FnName.L == ast.CurrentTimestamp || x.FnName.L == ast.Now {
colDefaultType = ColumnDefaultExpression
defaultFsp := 0
if len(x.Args) == 1 {
if val := x.Args[0].(*driver.ValueExpr); val != nil {
defaultFsp = int(val.GetInt64())
}
}
if defaultFsp != fsp {
return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
return nil, colDefaultType, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
}
}
vd, err := expression.GetTimeValue(ctx, c.Expr, tp, fsp)
value := vd.GetValue()
if err != nil {
return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
return nil, colDefaultType, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}

// Value is nil means `default null`.
if value == nil {
return nil, false, nil
return nil, colDefaultType, nil
}

// If value is types.Time, convert it to string.
if vv, ok := value.(types.Time); ok {
return vv.String(), false, nil
return vv.String(), colDefaultType, nil
}

return value, false, nil
return value, colDefaultType, nil
}

switch x := c.Expr.(type) {
case *ast.FuncCallExpr:
if x.FnName.L == ast.Rand {
if err := expression.VerifyArgsWrapper(ast.Rand, len(x.Args)); err != nil {
return nil, ColumnDefaultExpression, expression.ErrIncorrectParameterCount.GenWithStackByArgs(ast.Rand)
}
col.DefaultIsExpr = true
var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)
if err := c.Expr.Restore(restoreCtx); err != nil {
return "", ColumnDefaultExpression, err
}
return sb.String(), ColumnDefaultExpression, nil
}
}

// handle default next value of sequence. (keep the expr string)
str, isSeqExpr, err := tryToGetSequenceDefaultValue(c)
str, defaultType, err := tryToGetSequenceDefaultValue(c)
if err != nil {
return nil, false, errors.Trace(err)
return nil, ColumnDefaultLiteral, errors.Trace(err)
}
if isSeqExpr {
return str, true, nil
if defaultType == ColumnDefaultSequence {
return str, ColumnDefaultSequence, nil
}

// evaluate the non-sequence expr to a certain value.
v, err := expression.EvalAstExpr(ctx, c.Expr)
if err != nil {
return nil, false, errors.Trace(err)
return nil, ColumnDefaultLiteral, errors.Trace(err)
}

if v.IsNull() {
return nil, false, nil
return nil, ColumnDefaultLiteral, nil
}

if v.Kind() == types.KindBinaryLiteral || v.Kind() == types.KindMysqlBit {
if types.IsTypeBlob(tp) || tp == mysql.TypeJSON {
// BLOB/TEXT/JSON column cannot have a default value.
// Skip the unnecessary decode procedure.
return v.GetString(), false, err
return v.GetString(), ColumnDefaultLiteral, err
}
if tp == mysql.TypeBit || tp == mysql.TypeString || tp == mysql.TypeVarchar ||
tp == mysql.TypeVarString || tp == mysql.TypeEnum || tp == mysql.TypeSet {
Expand All @@ -1106,50 +1139,50 @@ func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOpt
// Overwrite the decoding error with invalid default value error.
err = dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O)
}
return str, false, err
return str, ColumnDefaultLiteral, err
}
// For other kind of fields (e.g. INT), we supply its integer as string value.
value, err := v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx)
if err != nil {
return nil, false, err
return nil, ColumnDefaultLiteral, err
}
return strconv.FormatUint(value, 10), false, nil
return strconv.FormatUint(value, 10), ColumnDefaultLiteral, nil
}

switch tp {
case mysql.TypeSet:
val, err := getSetDefaultValue(v, col)
return val, false, err
return val, ColumnDefaultLiteral, err
case mysql.TypeEnum:
val, err := getEnumDefaultValue(v, col)
return val, false, err
return val, ColumnDefaultLiteral, err
case mysql.TypeDuration:
if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil {
return "", false, errors.Trace(err)
return "", ColumnDefaultLiteral, errors.Trace(err)
}
case mysql.TypeBit:
if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 {
// For BIT fields, convert int into BinaryLiteral.
return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), false, nil
return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), ColumnDefaultLiteral, nil
}
}

val, err := v.ToString()
return val, false, err
return val, ColumnDefaultLiteral, err
}

func tryToGetSequenceDefaultValue(c *ast.ColumnOption) (expr string, isExpr bool, err error) {
func tryToGetSequenceDefaultValue(c *ast.ColumnOption) (expr string, tp ColumnDefaultType, err error) {
if f, ok := c.Expr.(*ast.FuncCallExpr); ok && f.FnName.L == ast.NextVal {
var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)
if err := c.Expr.Restore(restoreCtx); err != nil {
return "", true, err
return "", ColumnDefaultSequence, err
}
return sb.String(), true, nil
return sb.String(), ColumnDefaultSequence, nil
}
return "", false, nil
return "", ColumnDefaultLiteral, nil
}

// getSetDefaultValue gets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html.
Expand Down Expand Up @@ -3408,11 +3441,11 @@ func checkAndCreateNewColumn(ctx sessionctx.Context, ti ast.Ident, schema *model
// known rows with specific sequence next value under current add column logic.
// More explanation can refer: TestSequenceDefaultLogic's comment in sequence_test.go
if option.Tp == ast.ColumnOptionDefaultValue {
_, isSeqExpr, err := tryToGetSequenceDefaultValue(option)
_, columnDefaultType, err := tryToGetSequenceDefaultValue(option)
if err != nil {
return nil, errors.Trace(err)
}
if isSeqExpr {
if columnDefaultType == ColumnDefaultSequence {
return nil, errors.Trace(dbterror.ErrAddColumnWithSequenceAsDefault.GenWithStackByArgs(specNewColumn.Name.Name.O))
}
}
Expand Down Expand Up @@ -4174,15 +4207,15 @@ func checkModifyTypes(ctx sessionctx.Context, origin *types.FieldType, to *types

func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (bool, error) {
hasDefaultValue := false
value, isSeqExpr, err := getDefaultValue(ctx, col, option)
value, columnDefaultType, err := getDefaultValue(ctx, col, option)
if err != nil {
return false, errors.Trace(err)
}
if isSeqExpr {
if columnDefaultType == ColumnDefaultSequence {
if err := checkSequenceDefaultValue(col); err != nil {
return false, errors.Trace(err)
}
col.DefaultIsExpr = isSeqExpr
col.DefaultIsExpr = true
}

if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil {
Expand Down
Loading

0 comments on commit 1720c3e

Please sign in to comment.