diff --git a/connection.go b/connection.go index 67d3dbee8..517158052 100644 --- a/connection.go +++ b/connection.go @@ -167,17 +167,31 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, driver.ErrBadConn } if len(args) == 0 { // no args, fastpath - mc.affectedRows = 0 - mc.insertId = 0 + totalAffectedRows := int64(0) + lastInsertId := int64(0) + + queries := strings.Split(query, ";") + for _, singleQuery := range queries { + singleQuery := strings.TrimSpace(singleQuery) + + if len(singleQuery) > 0 { + mc.affectedRows = 0 + mc.insertId = 0 + + err := mc.exec(singleQuery) + if err != nil { + return nil, err + } else { + totalAffectedRows += int64(mc.affectedRows) + lastInsertId = int64(mc.insertId) + } + } - err := mc.exec(query) - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, err } - return nil, err + return &mysqlResult{ + affectedRows: totalAffectedRows, + insertId: lastInsertId, + }, nil } // with args, must use prepared stmt diff --git a/driver_test.go b/driver_test.go index a52cc5cd0..707582f23 100644 --- a/driver_test.go +++ b/driver_test.go @@ -128,6 +128,60 @@ func TestEmptyQuery(t *testing.T) { }) } +func TestSimpleMultipleStatement(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value int(11))") + + // Multiple Statement + query := "INSERT INTO test VALUES(1);INSERT INTO test VALUES(2);" + res := dbt.mustExec(query) + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + + id, err := res.LastInsertId() + if err != nil { + dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) + } + if id != 0 { + dbt.Fatalf("Expected InsertID 0, got %d", id) + } + + }) +} +func TestComplexMultipleStatement(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Multiple Statement + query, err := readFromExternalFile("complex_statement.sql") + if err != nil { + dbt.Fatalf("Unable to read external file, returned error : %s", err.Error()) + } + + res := dbt.mustExec(query) + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + + id, err := res.LastInsertId() + if err != nil { + dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) + } + if id != 0 { + dbt.Fatalf("Expected InsertID 0, got %d", id) + } + + }) +} + func TestCRUD(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // Create Table @@ -1538,3 +1592,8 @@ func TestCustomDial(t *testing.T) { t.Fatalf("Connection failed: %s", err.Error()) } } + +func readFromExternalFile(filename string) (string, error) { + content, err := ioutil.ReadFile("./tests/" + filename) + return string(content), err +} diff --git a/tests/complex_statement.sql b/tests/complex_statement.sql new file mode 100644 index 000000000..bd5473707 --- /dev/null +++ b/tests/complex_statement.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS `test` ( + `id` int(11), + `value` int(11), + PRIMARY KEY (`id`) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +INSERT INTO test VALUES(1,1); +INSERT INTO test VALUES(2,1);