diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py new file mode 100644 index 0000000000..f65e8ada1a --- /dev/null +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -0,0 +1,29 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + ClientSideStatementType, +) + + +def execute(connection, parsed_statement: ParsedStatement): + """Executes the client side statements by calling the relevant method. + + It is an internal method that can make backwards-incompatible changes. + + :type parsed_statement: ParsedStatement + :param parsed_statement: parsed_statement based on the sql query + """ + if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT: + return connection.commit() diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py new file mode 100644 index 0000000000..e93b71f3e1 --- /dev/null +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -0,0 +1,42 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from google.cloud.spanner_dbapi.parsed_statement import ( + ParsedStatement, + StatementType, + ClientSideStatementType, +) + +RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE) + + +def parse_stmt(query): + """Parses the sql query to check if it matches with any of the client side + statement regex. + + It is an internal method that can make backwards-incompatible changes. + + :type query: str + :param query: sql query + + :rtype: ParsedStatement + :returns: ParsedStatement object. + """ + if RE_COMMIT.match(query): + return ParsedStatement( + StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT + ) + return None diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 330aeb2c72..95d20f5730 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -32,13 +32,14 @@ from google.cloud.spanner_dbapi.exceptions import OperationalError from google.cloud.spanner_dbapi.exceptions import ProgrammingError -from google.cloud.spanner_dbapi import _helpers +from google.cloud.spanner_dbapi import _helpers, client_side_statement_executor from google.cloud.spanner_dbapi._helpers import ColumnInfo from google.cloud.spanner_dbapi._helpers import CODE_TO_DISPLAY_SIZE from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner +from google.cloud.spanner_dbapi.parsed_statement import StatementType from google.cloud.spanner_dbapi.utils import PeekIterator from google.cloud.spanner_dbapi.utils import StreamedManyResultSets @@ -210,7 +211,10 @@ def _batch_DDLs(self, sql): for ddl in sqlparse.split(sql): if ddl: ddl = ddl.rstrip(";") - if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL: + if ( + parse_utils.classify_statement(ddl).statement_type + != StatementType.DDL + ): raise ValueError("Only DDL statements may be batched.") statements.append(ddl) @@ -239,8 +243,12 @@ def execute(self, sql, args=None): self._handle_DQL(sql, args or None) return - class_ = parse_utils.classify_stmt(sql) - if class_ == parse_utils.STMT_DDL: + parsed_statement = parse_utils.classify_statement(sql) + if parsed_statement.statement_type == StatementType.CLIENT_SIDE: + return client_side_statement_executor.execute( + self.connection, parsed_statement + ) + if parsed_statement.statement_type == StatementType.DDL: self._batch_DDLs(sql) if self.connection.autocommit: self.connection.run_prior_DDL_statements() @@ -251,7 +259,7 @@ def execute(self, sql, args=None): # self._run_prior_DDL_statements() self.connection.run_prior_DDL_statements() - if class_ == parse_utils.STMT_UPDATING: + if parsed_statement.statement_type == StatementType.UPDATE: sql = parse_utils.ensure_where_clause(sql) sql, args = sql_pyformat_args_to_spanner(sql, args or None) @@ -276,7 +284,7 @@ def execute(self, sql, args=None): self.connection.retry_transaction() return - if class_ == parse_utils.STMT_NON_UPDATING: + if parsed_statement.statement_type == StatementType.QUERY: self._handle_DQL(sql, args or None) else: self.connection.database.run_in_transaction( @@ -309,19 +317,29 @@ def executemany(self, operation, seq_of_params): self._result_set = None self._row_count = _UNSET_COUNT - class_ = parse_utils.classify_stmt(operation) - if class_ == parse_utils.STMT_DDL: + parsed_statement = parse_utils.classify_statement(operation) + if parsed_statement.statement_type == StatementType.DDL: raise ProgrammingError( "Executing DDL statements with executemany() method is not allowed." ) + if parsed_statement.statement_type == StatementType.CLIENT_SIDE: + raise ProgrammingError( + "Executing the following operation: " + + operation + + ", with executemany() method is not allowed." + ) + # For every operation, we've got to ensure that any prior DDL # statements were run. self.connection.run_prior_DDL_statements() many_result_set = StreamedManyResultSets() - if class_ in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING): + if parsed_statement.statement_type in ( + StatementType.INSERT, + StatementType.UPDATE, + ): statements = [] for params in seq_of_params: diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 84cb2dc7a5..97276e54f6 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -21,8 +21,11 @@ import sqlparse from google.cloud import spanner_v1 as spanner from google.cloud.spanner_v1 import JsonObject +from . import client_side_statement_parser +from deprecated import deprecated from .exceptions import Error +from .parsed_statement import ParsedStatement, StatementType from .types import DateStr, TimestampStr from .utils import sanitize_literals_for_upload @@ -174,12 +177,11 @@ RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL) +@deprecated(reason="This method is deprecated. Use _classify_stmt method") def classify_stmt(query): """Determine SQL query type. - :type query: str :param query: A SQL query. - :rtype: str :returns: The query type name. """ @@ -203,6 +205,39 @@ def classify_stmt(query): return STMT_UPDATING +def classify_statement(query): + """Determine SQL query type. + + It is an internal method that can make backwards-incompatible changes. + + :type query: str + :param query: A SQL query. + + :rtype: ParsedStatement + :returns: parsed statement attributes. + """ + # sqlparse will strip Cloud Spanner comments, + # still, special commenting styles, like + # PostgreSQL dollar quoted comments are not + # supported and will not be stripped. + query = sqlparse.format(query, strip_comments=True).strip() + parsed_statement = client_side_statement_parser.parse_stmt(query) + if parsed_statement is not None: + return parsed_statement + if RE_DDL.match(query): + return ParsedStatement(StatementType.DDL, query) + + if RE_IS_INSERT.match(query): + return ParsedStatement(StatementType.INSERT, query) + + if RE_NON_UPDATE.match(query) or RE_WITH.match(query): + # As of 13-March-2020, Cloud Spanner only supports WITH for DQL + # statements and doesn't yet support WITH for DML statements. + return ParsedStatement(StatementType.QUERY, query) + + return ParsedStatement(StatementType.UPDATE, query) + + def sql_pyformat_args_to_spanner(sql, params): """ Transform pyformat set SQL to named arguments for Cloud Spanner. diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py new file mode 100644 index 0000000000..c36bc1d81c --- /dev/null +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -0,0 +1,36 @@ +# Copyright 20203 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum + + +class StatementType(Enum): + CLIENT_SIDE = 1 + DDL = 2 + QUERY = 3 + UPDATE = 4 + INSERT = 5 + + +class ClientSideStatementType(Enum): + COMMIT = 1 + BEGIN = 2 + + +@dataclass +class ParsedStatement: + statement_type: StatementType + query: str + client_side_statement_type: ClientSideStatementType = None diff --git a/setup.py b/setup.py index 1738eed2ea..76aaed4c8c 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "proto-plus >= 1.22.0, <2.0.0dev", "sqlparse >= 0.4.4", "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", + "deprecated >= 1.2.14", ] extras = { "tracing": [ diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index f3c5da1f46..bd49e478ba 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -20,6 +20,8 @@ from google.cloud import spanner_v1 from google.cloud._helpers import UTC + +from google.cloud.spanner_dbapi import Cursor from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.connection import Connection from google.cloud.spanner_dbapi.exceptions import ProgrammingError @@ -72,37 +74,11 @@ def dbapi_database(raw_database): def test_commit(shared_instance, dbapi_database): """Test committing a transaction with several statements.""" - want_row = ( - 1, - "updated-first-name", - "last-name", - "test.email_updated@domen.ru", - ) # connect to the test database conn = Connection(shared_instance, dbapi_database) cursor = conn.cursor() - # execute several DML statements within one transaction - cursor.execute( - """ -INSERT INTO contacts (contact_id, first_name, last_name, email) -VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - cursor.execute( - """ -UPDATE contacts -SET first_name = 'updated-first-name' -WHERE first_name = 'first-name' -""" - ) - cursor.execute( - """ -UPDATE contacts -SET email = 'test.email_updated@domen.ru' -WHERE email = 'test.email@domen.ru' -""" - ) + want_row = _execute_common_precommit_statements(cursor) conn.commit() # read the resulting data from the database @@ -116,6 +92,25 @@ def test_commit(shared_instance, dbapi_database): conn.close() +def test_commit_client_side(shared_instance, dbapi_database): + """Test committing a transaction with several statements.""" + # connect to the test database + conn = Connection(shared_instance, dbapi_database) + cursor = conn.cursor() + + want_row = _execute_common_precommit_statements(cursor) + cursor.execute("""COMMIT""") + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + cursor.close() + conn.close() + + assert got_rows == [want_row] + + def test_rollback(shared_instance, dbapi_database): """Test rollbacking a transaction with several statements.""" want_row = (2, "first-name", "last-name", "test.email@domen.ru") @@ -810,3 +805,33 @@ def test_dml_returning_delete(shared_instance, dbapi_database, autocommit): assert cur.fetchone() == (1, "first-name") assert cur.rowcount == 1 conn.commit() + + +def _execute_common_precommit_statements(cursor: Cursor): + # execute several DML statements within one transaction + cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cursor.execute( + """ + UPDATE contacts + SET first_name = 'updated-first-name' + WHERE first_name = 'first-name' + """ + ) + cursor.execute( + """ + UPDATE contacts + SET email = 'test.email_updated@domen.ru' + WHERE email = 'test.email@domen.ru' + """ + ) + return ( + 1, + "updated-first-name", + "last-name", + "test.email_updated@domen.ru", + ) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 46a093b109..972816f47a 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -14,10 +14,12 @@ """Cursor() class unit tests.""" -import mock +from unittest import mock import sys import unittest +from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, StatementType + class TestCursor(unittest.TestCase): INSTANCE = "test-instance" @@ -182,7 +184,6 @@ def test_execute_autocommit_off(self): self.assertIsInstance(cursor._itr, PeekIterator) def test_execute_insert_statement_autocommit_off(self): - from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.utils import PeekIterator @@ -192,54 +193,54 @@ def test_execute_insert_statement_autocommit_off(self): cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) cursor._checksum = ResultsChecksum() + sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value=parse_utils.STMT_UPDATING, + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.UPDATE, sql), ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", return_value=(mock.MagicMock(), ResultsChecksum()), ): - cursor.execute( - sql="INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" - ) + cursor.execute(sql) self.assertIsInstance(cursor._result_set, mock.MagicMock) self.assertIsInstance(cursor._itr, PeekIterator) def test_execute_statement(self): - from google.cloud.spanner_dbapi import parse_utils - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + sql = "sql" with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - side_effect=[parse_utils.STMT_DDL, parse_utils.STMT_UPDATING], - ) as mock_classify_stmt: - sql = "sql" + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + side_effect=[ + ParsedStatement(StatementType.DDL, sql), + ParsedStatement(StatementType.UPDATE, sql), + ], + ) as mockclassify_statement: with self.assertRaises(ValueError): cursor.execute(sql=sql) - mock_classify_stmt.assert_called_with(sql) - self.assertEqual(mock_classify_stmt.call_count, 2) + mockclassify_statement.assert_called_with(sql) + self.assertEqual(mockclassify_statement.call_count, 2) self.assertEqual(cursor.connection._ddl_statements, []) with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value=parse_utils.STMT_DDL, - ) as mock_classify_stmt: + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.DDL, sql), + ) as mockclassify_statement: sql = "sql" cursor.execute(sql=sql) - mock_classify_stmt.assert_called_with(sql) - self.assertEqual(mock_classify_stmt.call_count, 2) + mockclassify_statement.assert_called_with(sql) + self.assertEqual(mockclassify_statement.call_count, 2) self.assertEqual(cursor.connection._ddl_statements, [sql]) with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value=parse_utils.STMT_NON_UPDATING, + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.QUERY, sql), ): with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", - return_value=parse_utils.STMT_NON_UPDATING, + return_value=ParsedStatement(StatementType.QUERY, sql), ) as mock_handle_ddl: connection.autocommit = True sql = "sql" @@ -247,14 +248,15 @@ def test_execute_statement(self): mock_handle_ddl.assert_called_once_with(sql, None) with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value="other_statement", + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.UPDATE, sql), ): cursor.connection._database = mock_db = mock.MagicMock() mock_db.run_in_transaction = mock_run_in = mock.MagicMock() - sql = "sql" - cursor.execute(sql=sql) - mock_run_in.assert_called_once_with(cursor._do_execute_update, sql, None) + cursor.execute(sql="sql") + mock_run_in.assert_called_once_with( + cursor._do_execute_update, "sql WHERE 1=1", None + ) def test_execute_integrity_error(self): from google.api_core import exceptions @@ -264,21 +266,21 @@ def test_execute_integrity_error(self): cursor = self._make_one(connection) with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + "google.cloud.spanner_dbapi.parse_utils.classify_statement", side_effect=exceptions.AlreadyExists("message"), ): with self.assertRaises(IntegrityError): cursor.execute(sql="sql") with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + "google.cloud.spanner_dbapi.parse_utils.classify_statement", side_effect=exceptions.FailedPrecondition("message"), ): with self.assertRaises(IntegrityError): cursor.execute(sql="sql") with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + "google.cloud.spanner_dbapi.parse_utils.classify_statement", side_effect=exceptions.OutOfRange("message"), ): with self.assertRaises(IntegrityError): @@ -292,7 +294,7 @@ def test_execute_invalid_argument(self): cursor = self._make_one(connection) with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + "google.cloud.spanner_dbapi.parse_utils.classify_statement", side_effect=exceptions.InvalidArgument("message"), ): with self.assertRaises(ProgrammingError): @@ -306,7 +308,7 @@ def test_execute_internal_server_error(self): cursor = self._make_one(connection) with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + "google.cloud.spanner_dbapi.parse_utils.classify_statement", side_effect=exceptions.InternalServerError("message"), ): with self.assertRaises(OperationalError): @@ -336,6 +338,20 @@ def test_executemany_DLL(self, mock_client): with self.assertRaises(ProgrammingError): cursor.executemany("""DROP DATABASE database_name""", ()) + def test_executemany_client_statement(self): + from google.cloud.spanner_dbapi import connect, ProgrammingError + + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + + with self.assertRaises(ProgrammingError) as error: + cursor.executemany("""COMMIT TRANSACTION""", ()) + self.assertEqual( + str(error.exception), + "Executing the following operation: COMMIT TRANSACTION, with executemany() method is not allowed.", + ) + @mock.patch("google.cloud.spanner_v1.Client") def test_executemany(self, mock_client): from google.cloud.spanner_dbapi import connect diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 887f984c2c..162535349f 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -15,6 +15,7 @@ import sys import unittest +from google.cloud.spanner_dbapi.parsed_statement import StatementType from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1 import JsonObject @@ -24,45 +25,43 @@ class TestParseUtils(unittest.TestCase): skip_message = "Subtests are not supported in Python 2" def test_classify_stmt(self): - from google.cloud.spanner_dbapi.parse_utils import STMT_DDL - from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT - from google.cloud.spanner_dbapi.parse_utils import STMT_NON_UPDATING - from google.cloud.spanner_dbapi.parse_utils import STMT_UPDATING - from google.cloud.spanner_dbapi.parse_utils import classify_stmt + from google.cloud.spanner_dbapi.parse_utils import classify_statement cases = ( - ("SELECT 1", STMT_NON_UPDATING), - ("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING), - ("(SELECT s.SongName FROM Songs AS s)", STMT_NON_UPDATING), + ("SELECT 1", StatementType.QUERY), + ("SELECT s.SongName FROM Songs AS s", StatementType.QUERY), + ("(SELECT s.SongName FROM Songs AS s)", StatementType.QUERY), ( "WITH sq AS (SELECT SchoolID FROM Roster) SELECT * from sq", - STMT_NON_UPDATING, + StatementType.QUERY, ), ( "CREATE TABLE django_content_type (id STRING(64) NOT NULL, name STRING(100) " "NOT NULL, app_label STRING(100) NOT NULL, model STRING(100) NOT NULL) PRIMARY KEY(id)", - STMT_DDL, + StatementType.DDL, ), ( "CREATE INDEX SongsBySingerAlbumSongNameDesc ON " "Songs(SingerId, AlbumId, SongName DESC), INTERLEAVE IN Albums", - STMT_DDL, + StatementType.DDL, ), - ("CREATE INDEX SongsBySongName ON Songs(SongName)", STMT_DDL), + ("CREATE INDEX SongsBySongName ON Songs(SongName)", StatementType.DDL), ( "CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)", - STMT_DDL, + StatementType.DDL, ), - ("CREATE ROLE parent", STMT_DDL), - ("GRANT SELECT ON TABLE Singers TO ROLE parent", STMT_DDL), - ("REVOKE SELECT ON TABLE Singers TO ROLE parent", STMT_DDL), - ("GRANT ROLE parent TO ROLE child", STMT_DDL), - ("INSERT INTO table (col1) VALUES (1)", STMT_INSERT), - ("UPDATE table SET col1 = 1 WHERE col1 = NULL", STMT_UPDATING), + ("CREATE ROLE parent", StatementType.DDL), + ("commit", StatementType.CLIENT_SIDE), + (" commit TRANSACTION ", StatementType.CLIENT_SIDE), + ("GRANT SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL), + ("REVOKE SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL), + ("GRANT ROLE parent TO ROLE child", StatementType.DDL), + ("INSERT INTO table (col1) VALUES (1)", StatementType.INSERT), + ("UPDATE table SET col1 = 1 WHERE col1 = NULL", StatementType.UPDATE), ) for query, want_class in cases: - self.assertEqual(classify_stmt(query), want_class) + self.assertEqual(classify_statement(query).statement_type, want_class) @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner(self):