From 6f1bdf76cf3e1ea0496ca6217b8afdad745a1b5d Mon Sep 17 00:00:00 2001
From: lance6716
Date: Mon, 4 Sep 2023 16:09:46 +0800
Subject: [PATCH] executor: optimize load data assignment expression (#46563)

close pingcap/tidb#46081
---
 executor/importer/BUILD.bazel |  1 +
 executor/importer/import.go   | 22 ++++++++++++
 executor/load_data.go         | 30 +++++++++++-----
 server/server_test.go         | 67 +++++++++++++++++++++++++++++++++++
 4 files changed, 112 insertions(+), 8 deletions(-)

diff --git a/executor/importer/BUILD.bazel b/executor/importer/BUILD.bazel
index 52444cf42442a..7e937232b0cff 100644
--- a/executor/importer/BUILD.bazel
+++ b/executor/importer/BUILD.bazel
@@ -36,6 +36,7 @@ go_library(
         "//parser/terror",
         "//planner/core",
         "//sessionctx",
+        "//sessionctx/stmtctx",
         "//sessionctx/variable",
         "//table",
         "//table/tables",
diff --git a/executor/importer/import.go b/executor/importer/import.go
index 5a536c3752bfb..5ab9de5f23b67 100644
--- a/executor/importer/import.go
+++ b/executor/importer/import.go
@@ -32,6 +32,7 @@ import (
 	"github.com/pingcap/tidb/br/pkg/lightning/mydump"
 	"github.com/pingcap/tidb/br/pkg/storage"
 	"github.com/pingcap/tidb/executor/asyncloaddata"
+	"github.com/pingcap/tidb/expression"
 	tidbkv "github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/parser/ast"
 	"github.com/pingcap/tidb/parser/model"
@@ -39,6 +40,7 @@ import (
 	"github.com/pingcap/tidb/parser/terror"
 	plannercore "github.com/pingcap/tidb/planner/core"
 	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/sessionctx/stmtctx"
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/table"
 	"github.com/pingcap/tidb/util/chunk"
@@ -936,6 +938,26 @@ func (e *LoadDataController) toMyDumpFiles() []mydump.FileInfo {
 	return res
 }
 
+// CreateColAssignExprs creates the column assignment expressions using the session context.
+// RewriteAstExpr rewrites the AST node in place (through xxNode.Accept), but it doesn't change the
+// node's content, so we only need to synchronize access to it.
+func (e *LoadDataController) CreateColAssignExprs(sctx sessionctx.Context) ([]expression.Expression, []stmtctx.SQLWarn, error) {
+	res := make([]expression.Expression, 0, len(e.ColumnAssignments))
+	allWarnings := []stmtctx.SQLWarn{}
+	for _, assign := range e.ColumnAssignments {
+		newExpr, err := expression.RewriteAstExpr(sctx, assign.Expr, nil, nil, false)
+		// The warnings from rewriting the column assignment expressions are static, but they
+		// should be reported for each processed row, so save them here and clear the statement context.
+		allWarnings = append(allWarnings, sctx.GetSessionVars().StmtCtx.GetWarnings()...)
+		sctx.GetSessionVars().StmtCtx.SetWarnings(nil)
+		if err != nil {
+			return nil, nil, err
+		}
+		res = append(res, newExpr)
+	}
+	return res, allWarnings, nil
+}
+
 // JobImportParam is the param of the job import.
 type JobImportParam struct {
 	Job *asyncloaddata.Job
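The new CreateColAssignExprs helper moves the expensive part of the SET clause out of the per-row path: each assignment AST is rewritten into an evaluable expression once, and the warnings produced by that rewrite are saved so they can be replayed for every row later. The standalone Go sketch below is illustrative only, not TiDB code, and every name in it is a stand-in; it just shows the rewrite-once / evaluate-per-row shape that the executor changes in the next file wire up.

// Standalone sketch (not TiDB code; all names are illustrative stand-ins).
// rewriteAssignments plays the role of CreateColAssignExprs: it "rewrites" each
// SET-clause assignment into a reusable evaluator once and captures the static
// warnings, so the per-row loop only evaluates and replays warnings.
package main

import "fmt"

// evalFunc stands in for a prebuilt expression's Eval method.
type evalFunc func(field string) any

func rewriteAssignments(specs []string) (exprs []evalFunc, staticWarnings []string) {
	for _, spec := range specs {
		staticWarnings = append(staticWarnings, "while rewriting "+spec)
		exprs = append(exprs, func(field string) any {
			if field == "" { // NULLIF(@val, '')-like behavior: empty input becomes NULL
				return nil
			}
			return field
		})
	}
	return exprs, staticWarnings
}

func main() {
	exprs, warns := rewriteAssignments([]string{"c2 = NULLIF(@val2, '')"})

	totalWarnings := 0
	for _, field := range []string{"2", "", "3"} { // one SET input per row
		for _, eval := range exprs {
			fmt.Printf("c2 = %v\n", eval(field)) // no per-row AST rewrite here
		}
		totalWarnings += len(warns) // static warnings are re-appended for every row
	}
	fmt.Println("warnings reported:", totalWarnings)
}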
diff --git a/executor/load_data.go b/executor/load_data.go
index 0c14652b221fc..10d518f226c13 100644
--- a/executor/load_data.go
+++ b/executor/load_data.go
@@ -399,10 +399,16 @@ func (ji *logicalJobImporter) initEncodeCommitWorkers(e *LoadDataWorker) (err er
 			return err2
 		}
 		createdSessions = append(createdSessions, commitCore.ctx)
+		colAssignExprs, exprWarnings, err2 := e.controller.CreateColAssignExprs(encodeCore.ctx)
+		if err2 != nil {
+			return err2
+		}
 		encode := &encodeWorker{
-			InsertValues: encodeCore,
-			controller:   e.controller,
-			killed:       &e.UserSctx.GetSessionVars().Killed,
+			InsertValues:   encodeCore,
+			controller:     e.controller,
+			colAssignExprs: colAssignExprs,
+			exprWarnings:   exprWarnings,
+			killed:         &e.UserSctx.GetSessionVars().Killed,
 		}
 		encode.resetBatch()
 		encodeWorkers = append(encodeWorkers, encode)
@@ -627,9 +633,13 @@ func (ji *logicalJobImporter) Close() error {
 // encodeWorker is a sub-worker of LoadDataWorker that dedicated to encode data.
 type encodeWorker struct {
 	*InsertValues
-	controller *importer.LoadDataController
-	killed     *uint32
-	rows       [][]types.Datum
+	controller     *importer.LoadDataController
+	colAssignExprs []expression.Expression
+	// The session context generates warnings when an AST node is rewritten into an expression;
+	// such warnings should be reported for each encoded row.
+	exprWarnings []stmtctx.SQLWarn
+	killed       *uint32
+	rows         [][]types.Datum
 }
 
 // processStream always trys to build a parser from channel and process it. When
@@ -818,9 +828,9 @@ func (w *encodeWorker) parserData2TableData(
 		row = append(row, parserData[i])
 	}
 
-	for i := 0; i < len(w.controller.ColumnAssignments); i++ {
+	for i := 0; i < len(w.colAssignExprs); i++ {
 		// eval expression of `SET` clause
-		d, err := expression.EvalAstExpr(w.ctx, w.controller.ColumnAssignments[i].Expr)
+		d, err := w.colAssignExprs[i].Eval(chunk.Row{})
 		if err != nil {
 			if w.controller.Restrictive {
 				return nil, err
@@ -830,6 +840,10 @@ func (w *encodeWorker) parserData2TableData(
 		row = append(row, d)
 	}
 
+	if len(w.exprWarnings) > 0 {
+		w.ctx.GetSessionVars().StmtCtx.AppendWarnings(w.exprWarnings)
+	}
+
 	// a new row buffer will be allocated in getRow
 	newRow, err := w.getRow(ctx, row)
 	if err != nil {
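The change above is where the optimization lands: the hot loop in parserData2TableData now evaluates prebuilt expressions instead of rewriting and evaluating the AST for every row. Below is a micro-benchmark sketch of that difference, illustrative only and not part of the patch; it uses regexp compilation as a stand-in for expression rewriting, and the package, file, and benchmark names are hypothetical.

// Illustrative micro-benchmark (not TiDB code): regexp compilation stands in for
// the per-row AST rewrite that this patch hoists out of the row loop. Save as
// e.g. loaddata_bench_test.go and run with: go test -bench .
package loaddata

import (
	"regexp"
	"testing"
)

var rows = func() []string {
	r := make([]string, 1000)
	for i := range r {
		r[i] = "value-42"
	}
	return r
}()

// BenchmarkRewritePerRow mirrors the old path: the "expression" is rebuilt for
// every row before it is evaluated.
func BenchmarkRewritePerRow(b *testing.B) {
	for i := 0; i < b.N; i++ {
		for _, r := range rows {
			re := regexp.MustCompile(`value-(\d+)`) // per-row compile
			_ = re.FindStringSubmatch(r)
		}
	}
}

// BenchmarkRewriteOnce mirrors the new path: compile once per worker, then only
// evaluate inside the row loop.
func BenchmarkRewriteOnce(b *testing.B) {
	re := regexp.MustCompile(`value-(\d+)`)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, r := range rows {
			_ = re.FindStringSubmatch(r)
		}
	}
}

The per-row-compile variant is expected to be dominated by compilation cost; that is the kind of overhead this patch removes from LOAD DATA ... SET handling.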
diff --git a/server/server_test.go b/server/server_test.go
index 46efb5950037c..0734b8f4b748d 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -1573,6 +1573,73 @@ func (cli *testServerClient) runTestLoadData(t *testing.T, server *Server) {
 		require.NoError(t, rows.Close())
 		dbt.MustExec("drop table if exists pn")
 	})
+
+	err = fp.Close()
+	require.NoError(t, err)
+	err = os.Remove(path)
+	require.NoError(t, err)
+
+	fp, err = os.Create(path)
+	require.NoError(t, err)
+	require.NotNil(t, fp)
+
+	_, err = fp.WriteString(
+		`1,2` + "\n" +
+			`1,2,,4` + "\n" +
+			`1,2,3` + "\n" +
+			`,,,` + "\n" +
+			`,,3` + "\n" +
+			`1,,,4` + "\n")
+	require.NoError(t, err)
+
+	nullInt32 := func(val int32, valid bool) sql.NullInt32 {
+		return sql.NullInt32{Int32: val, Valid: valid}
+	}
+	expects := []struct {
+		col1 sql.NullInt32
+		col2 sql.NullInt32
+		col3 sql.NullInt32
+		col4 sql.NullInt32
+	}{
+		{nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(0, false)},
+		{nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(4, true)},
+		{nullInt32(1, true), nullInt32(2, true), nullInt32(3, true), nullInt32(0, false)},
+		{nullInt32(0, true), nullInt32(0, false), nullInt32(0, false), nullInt32(0, false)},
+		{nullInt32(0, true), nullInt32(0, false), nullInt32(3, true), nullInt32(0, false)},
+		{nullInt32(1, true), nullInt32(0, false), nullInt32(0, false), nullInt32(4, true)},
+	}
+
+	cli.runTestsOnNewDB(t, func(config *mysql.Config) {
+		config.AllowAllFiles = true
+		config.Params["sql_mode"] = "''"
+	}, "LoadData", func(dbt *testkit.DBTestKit) {
+		dbt.MustExec("drop table if exists pn")
+		dbt.MustExec("create table pn (c1 int, c2 int, c3 int, c4 int)")
+		dbt.MustExec("set @@tidb_dml_batch_size = 1")
+		_, err1 := dbt.GetDB().Exec(fmt.Sprintf(`load data local infile %q into table pn FIELDS TERMINATED BY ',' (c1, @val2, @val3, @val4)
+			SET c2 = NULLIF(@val2, ''), c3 = NULLIF(@val3, ''), c4 = NULLIF(@val4, '')`, path))
+		require.NoError(t, err1)
+		var (
+			a sql.NullInt32
+			b sql.NullInt32
+			c sql.NullInt32
+			d sql.NullInt32
+		)
+		rows := dbt.MustQuery("select * from pn")
+		for _, expect := range expects {
+			require.Truef(t, rows.Next(), "unexpected data")
+			err = rows.Scan(&a, &b, &c, &d)
+			require.NoError(t, err)
+			require.Equal(t, expect.col1, a)
+			require.Equal(t, expect.col2, b)
+			require.Equal(t, expect.col3, c)
+			require.Equal(t, expect.col4, d)
+		}
+
+		require.Falsef(t, rows.Next(), "unexpected data")
+		require.NoError(t, rows.Close())
+		dbt.MustExec("drop table if exists pn")
+	})
 }
 
 func (cli *testServerClient) runTestConcurrentUpdate(t *testing.T) {
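One note on the new regression test: it enables config.AllowAllFiles so the go-sql-driver/mysql client will send any local file, which is convenient inside the test harness. Outside of tests, registering the specific file with the driver is the stricter, more common choice. A minimal client-side sketch follows; it is not part of the patch, and the DSN, file path, and table schema are placeholders.

// Minimal client-side sketch (not part of the patch): allow one specific local
// file for LOAD DATA LOCAL instead of enabling AllowAllFiles globally.
package main

import (
	"database/sql"
	"log"

	"github.com/go-sql-driver/mysql"
)

func main() {
	// Explicitly allow-list the file the statement will read.
	mysql.RegisterLocalFile("/tmp/load.csv")

	// Placeholder DSN pointing at a local TiDB instance.
	db, err := sql.Open("mysql", "root@tcp(127.0.0.1:4000)/test")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	_, err = db.Exec(`LOAD DATA LOCAL INFILE '/tmp/load.csv' INTO TABLE pn
		FIELDS TERMINATED BY ',' (c1, @val2, @val3, @val4)
		SET c2 = NULLIF(@val2, ''), c3 = NULLIF(@val3, ''), c4 = NULLIF(@val4, '')`)
	if err != nil {
		log.Fatal(err)
	}
}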