Skip to content

Commit

Permalink
executor: unify replace into logic for InsertValues and ReplaceExec (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Oct 26, 2023
1 parent f55b7e7 commit b8e6499
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 134 deletions.
6 changes: 6 additions & 0 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,8 @@ func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) Executor {
isLoadData: true,
txnInUse: sync.Mutex{},
}
restrictive := b.ctx.GetSessionVars().SQLMode.HasStrictMode() &&
v.OnDuplicate != ast.OnDuplicateKeyHandlingIgnore
loadDataInfo := &LoadDataInfo{
row: make([]types.Datum, 0, len(insertVal.insertColumns)),
InsertValues: insertVal,
Expand All @@ -937,6 +939,10 @@ func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) Executor {
ColumnAssignments: v.ColumnAssignments,
ColumnsAndUserVars: v.ColumnsAndUserVars,
Ctx: b.ctx,
restrictive: restrictive,
}
if !restrictive {
b.ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true
}
columnNames := loadDataInfo.initFieldMappings()
err := loadDataInfo.initLoadColumns(columnNames)
Expand Down
1 change: 0 additions & 1 deletion executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2100,7 +2100,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.NoZeroDate = vars.SQLMode.HasNoZeroDateMode()
sc.TruncateAsWarning = !vars.StrictSQLMode
case *ast.LoadDataStmt:
sc.DupKeyAsWarning = true
sc.BadNullAsWarning = true
// With IGNORE or LOCAL, data-interpretation errors become warnings and the load operation continues,
// even if the SQL mode is restrictive. For details: https://dev.mysql.com/doc/refman/8.0/en/load-data.html
Expand Down
88 changes: 64 additions & 24 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -1155,12 +1155,15 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D
if r.ignored {
continue
}
skip := false
if r.handleKey != nil {
_, err := txn.Get(ctx, r.handleKey.newKey)
if err == nil {
if replace {
err2 := e.removeRow(ctx, txn, r)
handle, err := tablecodec.DecodeRowKey(r.handleKey.newKey)
if err != nil {
return err
}
_, err2 := e.removeRow(ctx, txn, handle, r, false)
if err2 != nil {
return err2
}
Expand All @@ -1176,19 +1179,40 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D
return err
}
}
skip := false
for _, uk := range r.uniqueKeys {
_, err := txn.Get(ctx, uk.newKey)
if err == nil {
// If duplicate keys were found in BatchGet, mark row = nil.
e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr)
if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic {
// lock duplicated unique key on insert-ignore
txnCtx.AddUnchangedRowKey(uk.newKey)
if replace {
_, handle, err := tables.FetchDuplicatedHandle(
ctx,
uk.newKey,
true,
txn,
e.Table.Meta().ID,
uk.commonHandle,
)
if err != nil {
return err
}
if handle == nil {
continue
}
_, err = e.removeRow(ctx, txn, handle, r, true)
if err != nil {
return err
}
} else {
// If duplicate keys were found in BatchGet, mark row = nil.
e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr)
if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic {
// lock duplicated unique key on insert-ignore
txnCtx.AddUnchangedRowKey(uk.newKey)
}
skip = true
break
}
skip = true
break
}
if !kv.IsErrNotFound(err) {
} else if !kv.IsErrNotFound(err) {
return err
}
}
Expand All @@ -1210,12 +1234,16 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D
return nil
}

func (e *InsertValues) removeRow(ctx context.Context, txn kv.Transaction, r toBeCheckedRow) error {
handle, err := tablecodec.DecodeRowKey(r.handleKey.newKey)
if err != nil {
return err
}

// removeRow removes the duplicate row and cleanup its keys in the key-value map.
// But if the to-be-removed row equals to the to-be-added row, no remove or add
// things to do and return (true, nil).
func (e *InsertValues) removeRow(
ctx context.Context,
txn kv.Transaction,
handle kv.Handle,
r toBeCheckedRow,
inReplace bool,
) (bool, error) {
newRow := r.row
oldRow, err := getOldRow(ctx, e.ctx, txn, r.t, handle, e.GenExprs)
if err != nil {
Expand All @@ -1225,28 +1253,40 @@ func (e *InsertValues) removeRow(ctx context.Context, txn kv.Transaction, r toBe
if kv.IsErrNotFound(err) {
err = errors.NotFoundf("can not be duplicated row, due to old row not found. handle %s", handle)
}
return err
return false, err
}

identical, err := e.equalDatumsAsBinary(oldRow, newRow)
if err != nil {
return err
return false, err
}
if identical {
if inReplace {
e.ctx.GetSessionVars().StmtCtx.AddAffectedRows(1)
}
_, err := appendUnchangedRowForLock(e.ctx, r.t, handle, oldRow)
if err != nil {
return err
return false, err
}
return nil
return true, nil
}

err = r.t.RemoveRecord(e.ctx, handle, oldRow)
if err != nil {
return err
return false, err
}
// need https://github.com/pingcap/tidb/pull/40069
//err = onRemoveRowForFK(e.ctx, oldRow, e.fkChecks, e.fkCascades)
//if err != nil {
// return false, err
//}
if inReplace {
e.ctx.GetSessionVars().StmtCtx.AddAffectedRows(1)
} else {
e.ctx.GetSessionVars().StmtCtx.AddDeletedRows(1)
}
e.ctx.GetSessionVars().StmtCtx.AddDeletedRows(1)

return nil
return false, nil
}

// equalDatumsAsBinary compare if a and b contains the same datum values in binary collation.
Expand Down
44 changes: 34 additions & 10 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ type LoadDataInfo struct {
exprWarnings []stmtctx.SQLWarn
ColumnsAndUserVars []*ast.ColumnNameOrUserVar
FieldMappings []*FieldMapping
// Data interpretation is restrictive if the SQL mode is restrictive and neither
// the IGNORE nor the LOCAL modifier is specified. Errors terminate the load
// operation.
// ref https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-column-assignments
restrictive bool

commitTaskQueue chan CommitTask
StopCh chan struct{}
Expand Down Expand Up @@ -622,16 +627,33 @@ func (e *LoadDataInfo) CheckAndInsertOneBatch(ctx context.Context, rows [][]type
}
e.ctx.GetSessionVars().StmtCtx.AddRecordRows(cnt)

replace := false
if e.OnDuplicate == ast.OnDuplicateKeyHandlingReplace {
replace = true
}

err = e.batchCheckAndInsert(ctx, rows[0:cnt], e.addRecordLD, replace)
if err != nil {
return err
switch e.OnDuplicate {
case ast.OnDuplicateKeyHandlingReplace:
return e.batchCheckAndInsert(ctx, rows[0:cnt], e.addRecordLD, true)
case ast.OnDuplicateKeyHandlingIgnore:
return e.batchCheckAndInsert(ctx, rows[0:cnt], e.addRecordLD, false)
case ast.OnDuplicateKeyHandlingError:
for i, row := range rows[0:cnt] {
sizeHintStep := int(e.Ctx.GetSessionVars().ShardAllocateStep)
if sizeHintStep > 0 && i%sizeHintStep == 0 {
sizeHint := sizeHintStep
remain := len(rows[0:cnt]) - i
if sizeHint > remain {
sizeHint = remain
}
err = e.addRecordWithAutoIDHint(ctx, row, sizeHint)
} else {
err = e.addRecord(ctx, row)
}
if err != nil {
return err
}
e.ctx.GetSessionVars().StmtCtx.AddCopiedRows(1)
}
return nil
default:
return errors.Errorf("unknown on duplicate key handling: %v", e.OnDuplicate)
}
return err
}

// SetMessage sets info message(ERR_LOAD_INFO) generated by LOAD statement, it is public because of the special way that
Expand Down Expand Up @@ -716,8 +738,10 @@ func (e *LoadDataInfo) addRecordLD(ctx context.Context, row []types.Datum) error
}
err := e.addRecord(ctx, row)
if err != nil {
if e.restrictive {
return err
}
e.handleWarning(err)
return err
}
return nil
}
Expand Down
62 changes: 2 additions & 60 deletions executor/replace.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,14 @@ import (
"runtime/trace"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/memory"
"go.uber.org/zap"
)

// ReplaceExec represents a replace executor.
Expand Down Expand Up @@ -65,59 +60,6 @@ func (e *ReplaceExec) Open(ctx context.Context) error {
return nil
}

// removeRow removes the duplicate row and cleanup its keys in the key-value map,
// but if the to-be-removed row equals to the to-be-added row, no remove or add things to do.
func (e *ReplaceExec) removeRow(ctx context.Context, txn kv.Transaction, handle kv.Handle, r toBeCheckedRow) (bool, error) {
newRow := r.row
oldRow, err := getOldRow(ctx, e.ctx, txn, r.t, handle, e.GenExprs)
if err != nil {
logutil.BgLogger().Error("get old row failed when replace",
zap.String("handle", handle.String()),
zap.String("toBeInsertedRow", types.DatumsToStrNoErr(r.row)))
if kv.IsErrNotFound(err) {
err = errors.NotFoundf("can not be duplicated row, due to old row not found. handle %s", handle)
}
return false, err
}

rowUnchanged, err := e.EqualDatumsAsBinary(e.ctx.GetSessionVars().StmtCtx, oldRow, newRow)
if err != nil {
return false, err
}
if rowUnchanged {
e.ctx.GetSessionVars().StmtCtx.AddAffectedRows(1)
_, err := appendUnchangedRowForLock(e.ctx, r.t, handle, oldRow)
if err != nil {
return false, err
}
return true, nil
}

err = r.t.RemoveRecord(e.ctx, handle, oldRow)
if err != nil {
return false, err
}
e.ctx.GetSessionVars().StmtCtx.AddAffectedRows(1)
return false, nil
}

// EqualDatumsAsBinary compare if a and b contains the same datum values in binary collation.
func (e *ReplaceExec) EqualDatumsAsBinary(sc *stmtctx.StatementContext, a []types.Datum, b []types.Datum) (bool, error) {
if len(a) != len(b) {
return false, nil
}
for i, ai := range a {
v, err := ai.Compare(sc, &b[i], collate.GetBinaryCollator())
if err != nil {
return false, errors.Trace(err)
}
if v != 0 {
return false, nil
}
}
return true, nil
}

// replaceRow removes all duplicate rows for one row, then inserts it.
func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error {
txn, err := e.ctx.Txn(true)
Expand All @@ -132,7 +74,7 @@ func (e *ReplaceExec) replaceRow(ctx context.Context, r toBeCheckedRow) error {
}

if _, err := txn.Get(ctx, r.handleKey.newKey); err == nil {
rowUnchanged, err := e.removeRow(ctx, txn, handle, r)
rowUnchanged, err := e.removeRow(ctx, txn, handle, r, true)
if err != nil {
return err
}
Expand Down Expand Up @@ -184,7 +126,7 @@ func (e *ReplaceExec) removeIndexRow(ctx context.Context, txn kv.Transaction, r
if handle == nil {
continue
}
rowUnchanged, err := e.removeRow(ctx, txn, handle, r)
rowUnchanged, err := e.removeRow(ctx, txn, handle, r, true)
if err != nil {
return false, true, err
}
Expand Down
34 changes: 34 additions & 0 deletions executor/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"testing"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/types"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -93,3 +95,35 @@ func TestBatchRetrieverHelper(t *testing.T) {
require.Equal(t, rangeStarts, []int{0})
require.Equal(t, rangeEnds, []int{10})
}

func TestEqualDatumsAsBinary(t *testing.T) {
tests := []struct {
a []interface{}
b []interface{}
same bool
}{
// Positive cases
{[]interface{}{1}, []interface{}{1}, true},
{[]interface{}{1, "aa"}, []interface{}{1, "aa"}, true},
{[]interface{}{1, "aa", 1}, []interface{}{1, "aa", 1}, true},

// negative cases
{[]interface{}{1}, []interface{}{2}, false},
{[]interface{}{1, "a"}, []interface{}{1, "aaaaaa"}, false},
{[]interface{}{1, "aa", 3}, []interface{}{1, "aa", 2}, false},

// Corner cases
{[]interface{}{}, []interface{}{}, true},
{[]interface{}{nil}, []interface{}{nil}, true},
{[]interface{}{}, []interface{}{1}, false},
{[]interface{}{1}, []interface{}{1, 1}, false},
{[]interface{}{nil}, []interface{}{1}, false},
}

e := &InsertValues{baseExecutor: baseExecutor{ctx: core.MockContext()}}
for _, tt := range tests {
res, err := e.equalDatumsAsBinary(types.MakeDatums(tt.a...), types.MakeDatums(tt.b...))
require.NoError(t, err)
require.Equal(t, tt.same, res)
}
}
Loading

0 comments on commit b8e6499

Please sign in to comment.