Skip to content

Commit

Permalink
load_data: fix the bug that column list does not work in load data. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Jun 24, 2022
1 parent 6aacaf5 commit d16bcd8
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 0 deletions.
1 change: 1 addition & 0 deletions errno/errcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,7 @@ const (
ErrWarnOptimizerHintParseError = 8064
ErrWarnOptimizerHintInvalidInteger = 8065
ErrUnsupportedSecondArgumentType = 8066
ErrColumnNotMatched = 8067
ErrInvalidPluginID = 8101
ErrInvalidPluginManifest = 8102
ErrInvalidPluginName = 8103
Expand Down
1 change: 1 addition & 0 deletions errno/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,7 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{
ErrInvalidWildCard: mysql.Message("Wildcard fields without any table name appears in wrong place", nil),
ErrMixOfGroupFuncAndFieldsIncompatible: mysql.Message("In aggregated query without GROUP BY, expression #%d of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by", nil),
ErrUnsupportedSecondArgumentType: mysql.Message("JSON_OBJECTAGG: unsupported second argument type %v", nil),
ErrColumnNotMatched: mysql.Message("Load data: unmatched columns", nil),
ErrLockExpire: mysql.Message("TTL manager has timed out, pessimistic locks may expire, please commit or rollback this transaction", nil),
ErrTableOptionUnionUnsupported: mysql.Message("CREATE/ALTER table with union option is not supported", nil),
ErrTableOptionInsertMethodUnsupported: mysql.Message("CREATE/ALTER table with insert method option is not supported", nil),
Expand Down
5 changes: 5 additions & 0 deletions errors.toml
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,11 @@ error = '''
TiDB admin check table failed.
'''

["executor:8067"]
error = '''
Load data: unmatched columns
'''

["executor:8114"]
error = '''
Unknown plan
Expand Down
1 change: 1 addition & 0 deletions executor/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
ErrUnsupportedPs = dbterror.ClassExecutor.NewStd(mysql.ErrUnsupportedPs)
ErrSubqueryMoreThan1Row = dbterror.ClassExecutor.NewStd(mysql.ErrSubqueryNo1Row)
ErrIllegalGrantForTable = dbterror.ClassExecutor.NewStd(mysql.ErrIllegalGrantForTable)
ErrColumnsNotMatched = dbterror.ClassExecutor.NewStd(mysql.ErrColumnNotMatched)

ErrCantCreateUserWithGrant = dbterror.ClassExecutor.NewStd(mysql.ErrCantCreateUserWithGrant)
ErrPasswordNoMatch = dbterror.ClassExecutor.NewStd(mysql.ErrPasswordNoMatch)
Expand Down
37 changes: 37 additions & 0 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,36 @@ type FieldMapping struct {
UserVar *ast.VariableExpr
}

// reorderColumns reorders e.insertColumns so that it follows the order of columnNames.
//
// Names are matched against each column's lower-cased name (col.Name.L), so the
// match is case-insensitive. The caller is expected to provide a one-to-one mapping,
// by column name, between e.insertColumns and columnNames.
//
// It returns ErrColumnsNotMatched when the two lists differ in length, or when a
// column in e.insertColumns is not named in columnNames. On error e.insertColumns
// is left unmodified. A nil columnNames (which implies an empty e.insertColumns
// once the length check has passed) is a no-op.
func (e *LoadDataInfo) reorderColumns(columnNames []string) error {
	cols := e.insertColumns

	if len(cols) != len(columnNames) {
		return ErrColumnsNotMatched
	}

	if columnNames == nil {
		return nil
	}

	// Lower-cased column name -> position requested by the user.
	mapping := make(map[string]int, len(columnNames))
	for idx, colName := range columnNames {
		mapping[strings.ToLower(colName)] = idx
	}

	reorderedColumns := make([]*table.Column, len(cols))
	for _, col := range cols {
		idx, ok := mapping[col.Name.L]
		if !ok {
			// A failed lookup yields the zero value 0, which would silently
			// overwrite slot 0 and leave a nil entry in another slot.
			return ErrColumnsNotMatched
		}
		reorderedColumns[idx] = col
	}

	e.insertColumns = reorderedColumns

	return nil
}

// initLoadColumns sets columns which the input fields loaded to.
func (e *LoadDataInfo) initLoadColumns(columnNames []string) error {
var cols []*table.Column
Expand Down Expand Up @@ -163,6 +193,13 @@ func (e *LoadDataInfo) initLoadColumns(columnNames []string) error {
break
}
}

// e.insertColumns is appended according to the original table's column sequence.
// We have to reorder it to follow the user-specified column order given in columnNames.
if err = e.reorderColumns(columnNames); err != nil {
return err
}

e.rowLen = len(e.insertColumns)
// Check that each column is specified only once.
err = table.CheckOnce(cols)
Expand Down
138 changes: 138 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,144 @@ func (cli *testServerClient) checkRows(t *testing.T, rows *sql.Rows, expectedRow
require.Equal(t, strings.Join(expectedRows, "\n"), strings.Join(result, "\n"))
}

// runTestLoadDataWithColumnList checks that LOAD DATA assigns input fields to
// columns in the order given by the statement's column list, rather than in the
// table's own column order. It covers a full column list, a partial column
// list, and a column list whose names differ from the table's only by case.
func (cli *testServerClient) runTestLoadDataWithColumnList(t *testing.T, _ *Server) {
	fp, err := os.CreateTemp("", "load_data_test.csv")
	require.NoError(t, err)
	require.NotNil(t, fp)
	path := fp.Name()
	defer func() {
		require.NoError(t, fp.Close())
		require.NoError(t, os.Remove(path))
	}()

	// The first line is junk that every statement below skips via IGNORE 1 LINES.
	_, err = fp.WriteString("dsadasdas\n" +
		"\"1\",\"1\",,\"2022-04-19\",\"a\",\"2022-04-19 00:00:01\"\n" +
		"\"1\",\"2\",\"a\",\"2022-04-19\",\"a\",\"2022-04-19 00:00:01\"\n" +
		"\"1\",\"3\",\"a\",\"2022-04-19\",\"a\",\"2022-04-19 00:00:01\"\n" +
		"\"1\",\"4\",\"a\",\"2022-04-19\",\"a\",\"2022-04-19 00:00:01\"")
	require.NoError(t, err)

	// Expected rows are written in file-field order (k,id,c,dt,vv,ts), matching
	// the order of the `columns` slice built inside the check below.
	allColumnsLoaded := []string{
		"1,1,,2022-04-19,a,2022-04-19 00:00:01",
		"1,2,a,2022-04-19,a,2022-04-19 00:00:01",
		"1,3,a,2022-04-19,a,2022-04-19 00:00:01",
		"1,4,a,2022-04-19,a,2022-04-19 00:00:01",
	}

	cases := []struct {
		columnList string
		expected   []string
	}{
		// Every column is listed, in an order different from the table's.
		{columnList: "k,id,c,dt,vv,ts", expected: allColumnsLoaded},
		// The column list only specifies some of the columns.
		{columnList: "k,id,c", expected: []string{"1,1,,,,", "1,2,a,,,", "1,3,a,,,", "1,4,a,,,"}},
		// Upper and lower case are mixed in the column list to test case-insensitivity.
		{columnList: "K,Id,c,dT,Vv,Ts", expected: allColumnsLoaded},
	}

	for _, tc := range cases {
		cli.runTestsOnNewDB(t, func(config *mysql.Config) {
			config.AllowAllFiles = true
			config.Params["sql_mode"] = "''"
		}, "LoadData", func(db *testkit.DBTestKit) {
			db.MustExec("use test")
			db.MustExec("drop table if exists t66")
			db.MustExec("create table t66 (id int primary key,k int,c varchar(10),dt date,vv char(1),ts datetime)")
			db.MustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE t66 FIELDS TERMINATED BY ',' ENCLOSED BY '\\\"' IGNORE 1 LINES (%s)", path, tc.columnList))

			rows := db.MustQuery("select * from t66")
			defer func() {
				require.NoError(t, rows.Close())
			}()

			var id, k, c, dt, vv, ts sql.NullString
			// Scan follows the table's column order (id first); `columns` follows
			// the file's field order (k first) so it lines up with tc.expected.
			columns := []*sql.NullString{&k, &id, &c, &dt, &vv, &ts}
			for _, want := range tc.expected {
				require.Truef(t, rows.Next(), "unexpected data")
				require.NoError(t, rows.Scan(&id, &k, &c, &dt, &vv, &ts))
				columnsAsExpected(t, columns, strings.Split(want, ","))
			}
		})
	}
}

// columnsAsExpected asserts that columns and expected have the same length and
// that the String field of each scanned column equals the expected value at the
// same position. The Valid flag is not inspected.
func columnsAsExpected(t *testing.T, columns []*sql.NullString, expected []string) {
	require.Equal(t, len(columns), len(expected))

	for pos, column := range columns {
		require.Equal(t, expected[pos], column.String)
	}
}

func (cli *testServerClient) runTestLoadData(t *testing.T, server *Server) {
// create a file and write data.
path := "/tmp/load_data_test.csv"
Expand Down
1 change: 1 addition & 0 deletions server/tidb_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func TestLoadData(t *testing.T) {
ts, cleanup := createTidbTestSuite(t)
defer cleanup()

ts.runTestLoadDataWithColumnList(t, ts.server)
ts.runTestLoadData(t, ts.server)
ts.runTestLoadDataWithSelectIntoOutfile(t, ts.server)
ts.runTestLoadDataForSlowLog(t, ts.server)
Expand Down

0 comments on commit d16bcd8

Please sign in to comment.