Skip to content

Commit

Permalink
fix(cursor, execute): bind param parsing for multiline comments, colo…
Browse files Browse the repository at this point in the history
…n character issue
  • Loading branch information
Brooke-white committed Jun 5, 2023
1 parent f9af45b commit df7fc1d
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 59 deletions.
6 changes: 5 additions & 1 deletion redshift_connector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import typing

from redshift_connector import plugin
from redshift_connector.config import DEFAULT_PROTOCOL_VERSION, ClientProtocolVersion
from redshift_connector.config import (
DEFAULT_PROTOCOL_VERSION,
ClientProtocolVersion,
DbApiParamstyle,
)
from redshift_connector.core import BINARY, Connection, Cursor
from redshift_connector.error import (
ArrayContentNotHomogenousError,
Expand Down
15 changes: 14 additions & 1 deletion redshift_connector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from calendar import timegm
from datetime import datetime as Datetime
from datetime import timezone as Timezone
from enum import IntEnum
from enum import Enum, IntEnum

FC_TEXT: int = 0
FC_BINARY: int = 1
Expand All @@ -29,6 +29,19 @@ def get_name(cls, i: int) -> str:

DEFAULT_PROTOCOL_VERSION: int = ClientProtocolVersion.BINARY.value


class DbApiParamstyle(Enum):
QMARK = "qmark"
NUMERIC = "numeric"
NAMED = "named"
FORMAT = "format"
PYFORMAT = "pyformat"

@classmethod
def list(cls) -> typing.List[int]:
return list(map(lambda p: p.value, cls)) # type: ignore


min_int2: int = -(2 ** 15)
max_int2: int = 2 ** 15
min_int4: int = -(2 ** 31)
Expand Down
39 changes: 25 additions & 14 deletions redshift_connector/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from redshift_connector.config import (
DEFAULT_PROTOCOL_VERSION,
ClientProtocolVersion,
DbApiParamstyle,
_client_encoding,
max_int2,
max_int4,
Expand Down Expand Up @@ -145,6 +146,7 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
INSIDE_ES: int = 3 # inside escaped single-quote string, E'...'
INSIDE_PN: int = 4 # inside parameter name eg. :name
INSIDE_CO: int = 5 # inside inline comment eg. --
INSIDE_MC: int = 6 # inside multiline comment eg. /*

in_quote_escape: bool = False
in_param_escape: bool = False
Expand Down Expand Up @@ -173,23 +175,27 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
output_query.append(c)
if prev_c == "-":
state = INSIDE_CO
elif style == "qmark" and c == "?":
elif c == "*":
output_query.append(c)
if prev_c == "/":
state = INSIDE_MC
elif style == DbApiParamstyle.QMARK.value and c == "?":
output_query.append(next(param_idx))
elif style == "numeric" and c == ":" and next_c not in ":=" and prev_c != ":":
elif style == DbApiParamstyle.NUMERIC.value and c == ":" and next_c not in ":=" and prev_c != ":":
# Treat : as beginning of parameter name if and only
# if it's the only : around
# Needed to properly process type conversions
# i.e. sum(x)::float
output_query.append("$")
elif style == "named" and c == ":" and next_c not in ":=" and prev_c != ":":
elif style == DbApiParamstyle.NAMED.value and c == ":" and next_c not in ":=" and prev_c != ":":
# Same logic for : as in numeric parameters
state = INSIDE_PN
placeholders.append("")
elif style == "pyformat" and c == "%" and next_c == "(":
elif style == DbApiParamstyle.PYFORMAT.value and c == "%" and next_c == "(":
state = INSIDE_PN
placeholders.append("")
elif style in ("format", "pyformat") and c == "%":
style = "format"
elif style in (DbApiParamstyle.FORMAT.value, DbApiParamstyle.PYFORMAT.value) and c == "%":
style = DbApiParamstyle.FORMAT.value
if in_param_escape:
in_param_escape = False
output_query.append(c)
Expand Down Expand Up @@ -227,7 +233,7 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
output_query.append(c)

elif state == INSIDE_PN:
if style == "named":
if style == DbApiParamstyle.NAMED.value:
placeholders[-1] += c
if next_c is None or (not next_c.isalnum() and next_c != "_"):
state = OUTSIDE
Expand All @@ -237,7 +243,7 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
del placeholders[-1]
except ValueError:
output_query.append("$" + str(len(placeholders)))
elif style == "pyformat":
elif style == DbApiParamstyle.PYFORMAT.value:
if prev_c == ")" and c == "s":
state = OUTSIDE
try:
Expand All @@ -250,17 +256,22 @@ def convert_paramstyle(style: str, query) -> typing.Tuple[str, typing.Any]:
pass
else:
placeholders[-1] += c
elif style == "format":
elif style == DbApiParamstyle.FORMAT.value:
state = OUTSIDE

elif state == INSIDE_CO:
output_query.append(c)
if c == "\n":
state = OUTSIDE

elif state == INSIDE_MC:
output_query.append(c)
if c == "/" and prev_c == "*":
state = OUTSIDE

prev_c = c

if style in ("numeric", "qmark", "format"):
if style in (DbApiParamstyle.NUMERIC.value, DbApiParamstyle.QMARK.value, DbApiParamstyle.FORMAT.value):

def make_args(vals):
return vals
Expand Down Expand Up @@ -466,7 +477,7 @@ def __init__(
self.notices: deque = deque(maxlen=100)
self.parameter_statuses: deque = deque(maxlen=100)
self.max_prepared_statements: int = int(max_prepared_statements)
self._run_cursor: Cursor = Cursor(self, paramstyle="named")
self._run_cursor: Cursor = Cursor(self, paramstyle=DbApiParamstyle.NAMED.value)
self._client_protocol_version: int = client_protocol_version
self._database = database
self.py_types = deepcopy(PY_TYPES)
Expand Down Expand Up @@ -1600,7 +1611,7 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
args: typing.Tuple[typing.Optional[typing.Tuple[str, typing.Any]], ...] = ()
# transforms user provided bind parameters to server friendly bind parameters
params: typing.Tuple[typing.Optional[typing.Tuple[int, int, typing.Callable]], ...] = ()

has_bind_parameters: bool = False if vals is None else True
# multi dimensional dictionary to store the data
# cache = self._caches[cursor.paramstyle][pid]
# cache = {'statement': {}, 'ps': {}}
Expand All @@ -1622,12 +1633,12 @@ def execute(self: "Connection", cursor: Cursor, operation: str, vals) -> None:
try:
statement, make_args = cache["statement"][operation]
except KeyError:
if vals:
if has_bind_parameters:
statement, make_args = cache["statement"][operation] = convert_paramstyle(cursor.paramstyle, operation)
else:
# use a no-op make_args in lieu of parsing the sql statement
statement, make_args = cache["statement"][operation] = operation, lambda p: ()
if vals:
if has_bind_parameters:
args = make_args(vals)
# change the args to the format that the DB will identify
# take reference from self.py_types
Expand Down
15 changes: 8 additions & 7 deletions redshift_connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import redshift_connector
from redshift_connector.config import (
ClientProtocolVersion,
DbApiParamstyle,
_client_encoding,
table_type_clauses,
)
Expand Down Expand Up @@ -355,7 +356,7 @@ def __has_valid_columns(self: "Cursor", table: str, columns: typing.List[str]) -
else:
param_list = [[split_table_name[0], c] for c in columns]
temp = self.paramstyle
self.paramstyle = "qmark"
self.paramstyle = DbApiParamstyle.QMARK.value
try:
for params in param_list:
self.execute(q, params)
Expand All @@ -376,7 +377,7 @@ def callproc(self, procname, parameters=None):
from redshift_connector.core import convert_paramstyle

try:
statement, make_args = convert_paramstyle("format", operation)
statement, make_args = convert_paramstyle(DbApiParamstyle.FORMAT.value, operation)
vals = make_args(args)
self.execute(statement, vals)

Expand Down Expand Up @@ -534,7 +535,7 @@ def __is_valid_table(self: "Cursor", table: str) -> bool:
q: str = "select 1 from information_schema.tables where table_name = ?"

temp = self.paramstyle
self.paramstyle = "qmark"
self.paramstyle = DbApiParamstyle.QMARK.value

try:
if len(split_table_name) == 2:
Expand Down Expand Up @@ -643,7 +644,7 @@ def get_procedures(
if len(query_args) > 0:
# temporarily use qmark paramstyle
temp = self.paramstyle
self.paramstyle = "qmark"
self.paramstyle = DbApiParamstyle.QMARK.value

try:
self.execute(sql, tuple(query_args))
Expand Down Expand Up @@ -721,7 +722,7 @@ def get_schemas(
if len(query_args) == 1:
# temporarily use qmark paramstyle
temp = self.paramstyle
self.paramstyle = "qmark"
self.paramstyle = DbApiParamstyle.QMARK.value
try:
self.execute(sql, tuple(query_args))
except:
Expand Down Expand Up @@ -774,7 +775,7 @@ def get_primary_keys(
if len(query_args) > 0:
# temporarily use qmark paramstyle
temp = self.paramstyle
self.paramstyle = "qmark"
self.paramstyle = DbApiParamstyle.QMARK.value
try:
self.execute(sql, tuple(query_args))
except:
Expand Down Expand Up @@ -855,7 +856,7 @@ def get_tables(

if len(sql_args) > 0:
temp = self.paramstyle
self.paramstyle = "qmark"
self.paramstyle = DbApiParamstyle.QMARK.value
try:
self.execute(sql, sql_args)
except:
Expand Down
20 changes: 10 additions & 10 deletions test/integration/test_dbapi20.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,15 @@ def _paraminsert(cur):
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (table_prefix))
assert cur.rowcount in (-1, 1)

if driver.paramstyle == "qmark":
if driver.paramstyle == redshift_connector.config.DbApiParamstyle.QMARK.value:
cur.execute("insert into %sbooze values (?)" % table_prefix, ("Cooper's",))
elif driver.paramstyle == "numeric":
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.NUMERIC.value:
cur.execute("insert into %sbooze values (:1)" % table_prefix, ("Cooper's",))
elif driver.paramstyle == "named":
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.NAMED.value:
cur.execute("insert into %sbooze values (:beer)" % table_prefix, {"beer": "Cooper's"})
elif driver.paramstyle == "format":
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.FORMAT.value:
cur.execute("insert into %sbooze values (%%s)" % table_prefix, ("Cooper's",))
elif driver.paramstyle == "pyformat":
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.PYFORMAT.value:
cur.execute("insert into %sbooze values (%%(beer)s)" % table_prefix, {"beer": "Cooper's"})
else:
assert False, "Invalid paramstyle"
Expand All @@ -212,15 +212,15 @@ def test_executemany(cursor):
execute_ddl_1(cursor)
largs: typing.List[typing.Tuple[str]] = [("Cooper's",), ("Boag's",)]
margs: typing.List[typing.Dict[str, str]] = [{"beer": "Cooper's"}, {"beer": "Boag's"}]
if driver.paramstyle == "qmark":
if driver.paramstyle == redshift_connector.config.DbApiParamstyle.QMARK.value:
cursor.executemany("insert into %sbooze values (?)" % table_prefix, largs)
elif driver.paramstyle == "numeric":
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.NUMERIC.value:
cursor.executemany("insert into %sbooze values (:1)" % table_prefix, largs)
elif driver.paramstyle == "named":
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.NAMED.value:
cursor.executemany("insert into %sbooze values (:beer)" % table_prefix, margs)
elif driver.paramstyle == "format":
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.FORMAT.value:
cursor.executemany("insert into %sbooze values (%%s)" % table_prefix, largs)
elif driver.paramstyle == "pyformat":
elif driver.paramstyle == redshift_connector.config.DbApiParamstyle.PYFORMAT.value:
cursor.executemany("insert into %sbooze values (%%(beer)s)" % (table_prefix), margs)
else:
assert False, "Unknown paramstyle"
Expand Down
4 changes: 3 additions & 1 deletion test/unit/test_dbapi20.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ def test_threadsafety():


def test_paramstyle():
from redshift_connector.config import DbApiParamstyle

try:
# Must exist
paramstyle: str = driver.paramstyle
# Must be a valid value
assert paramstyle in ("qmark", "numeric", "named", "format", "pyformat")
assert paramstyle in DbApiParamstyle.list()
except AttributeError:
assert False, "Driver doesn't define paramstyle"

Expand Down
Loading

0 comments on commit df7fc1d

Please sign in to comment.