Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-534004: Added database and schema to the queries related to temporary stage #1274

Merged
merged 6 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 60 additions & 44 deletions src/snowflake/connector/pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,25 @@ def chunk_helper(lst: T, n: int) -> Iterator[tuple[int, T]]:
yield int(i / n), lst[i : i + n]


def build_location_helper(
database: str | None, schema: str | None, name: str, quote_identifiers: bool
) -> str:
"""Helper to format table/stage/file format's location."""
if quote_identifiers:
location = (
(('"' + database + '".') if database else "")
+ (('"' + schema + '".') if schema else "")
+ ('"' + name + '"')
)
else:
location = (
(database + "." if database else "")
+ (schema + "." if schema else "")
+ name
)
return location


def write_pandas(
conn: SnowflakeConnection,
df: pandas.DataFrame,
Expand Down Expand Up @@ -131,9 +150,7 @@ def write_pandas(
compression_map = {"gzip": "auto", "snappy": "snappy"}
if compression not in compression_map.keys():
raise ProgrammingError(
"Invalid compression '{}', only acceptable values are: {}".format(
compression, compression_map.keys()
)
f"Invalid compression '{compression}', only acceptable values are: {compression_map.keys()}"
)

if create_temp_table:
Expand All @@ -151,31 +168,33 @@ def write_pandas(
"Unsupported table type. Expected table types: temp/temporary, transient"
)

if quote_identifiers:
location = (
(('"' + database + '".') if database else "")
+ (('"' + schema + '".') if schema else "")
+ ('"' + table_name + '"')
)
else:
location = (
(database + "." if database else "")
+ (schema + "." if schema else "")
+ (table_name)
)
table_location = build_location_helper(
database=database,
schema=schema,
name=table_name,
quote_identifiers=quote_identifiers,
)

if chunk_size is None:
chunk_size = len(df)
cursor = conn.cursor()
stage_name = None # Forward declaration
stage_location = None # Forward declaration
while True:
try:
stage_name = "".join(
random.choice(string.ascii_lowercase) for _ in range(5)
)
stage_location = build_location_helper(
database=database,
schema=schema,
name=stage_name,
quote_identifiers=quote_identifiers,
)

create_stage_sql = (
"create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
'"{stage_name}"'
).format(stage_name=stage_name)
f"create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
f"{stage_location}"
)
logger.debug(f"creating stage with '{create_stage_sql}'")
cursor.execute(create_stage_sql, _is_internal=True).fetchall()
break
Expand All @@ -192,10 +211,10 @@ def write_pandas(
# Upload parquet file
upload_sql = (
"PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
"'file://{path}' @\"{stage_name}\" PARALLEL={parallel}"
"'file://{path}' @{stage_location} PARALLEL={parallel}"
).format(
path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"),
stage_name=stage_name,
stage_location=stage_location,
parallel=parallel,
)
logger.debug(f"uploading files with '{upload_sql}'")
Expand All @@ -209,25 +228,29 @@ def write_pandas(

if overwrite:
if auto_create_table:
drop_table_sql = f"DROP TABLE IF EXISTS {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
drop_table_sql = f"DROP TABLE IF EXISTS {table_location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
logger.debug(f"dropping table with '{drop_table_sql}'")
cursor.execute(drop_table_sql, _is_internal=True)
else:
truncate_table_sql = f"TRUNCATE TABLE IF EXISTS {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
truncate_table_sql = f"TRUNCATE TABLE IF EXISTS {table_location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
logger.debug(f"truncating table with '{truncate_table_sql}'")
cursor.execute(truncate_table_sql, _is_internal=True)

if auto_create_table:
file_format_name = None
file_format_location = None
while True:
try:
file_format_name = (
'"'
+ "".join(random.choice(string.ascii_lowercase) for _ in range(5))
+ '"'
file_format_name = "".join(
random.choice(string.ascii_lowercase) for _ in range(5)
)
file_format_location = build_location_helper(
database=database,
schema=schema,
name=file_format_name,
quote_identifiers=quote_identifiers,
)
file_format_sql = (
f"CREATE FILE FORMAT {file_format_name} "
f"CREATE FILE FORMAT {file_format_location} "
f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
f"TYPE=PARQUET COMPRESSION={compression_map[compression]}"
)
Expand All @@ -238,7 +261,7 @@ def write_pandas(
if pe.msg.endswith("already exists."):
continue
raise
infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@\"{stage_name}\"', file_format=>'{file_format_name}'))"
infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@{stage_location}', file_format=>'{file_format_location}'))"
logger.debug(f"inferring schema with '{infer_schema_sql}'")
column_type_mapping = dict(
cursor.execute(infer_schema_sql, _is_internal=True).fetchall()
Expand All @@ -251,13 +274,13 @@ def write_pandas(
[f"{quote}{c}{quote} {column_type_mapping[c]}" for c in df.columns]
)
create_table_sql = (
f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {location} "
f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {table_location} "
f"({create_table_columns})"
f" /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
)
logger.debug(f"auto creating table with '{create_table_sql}'")
cursor.execute(create_table_sql, _is_internal=True)
drop_file_format_sql = f"DROP FILE FORMAT IF EXISTS {file_format_name}"
drop_file_format_sql = f"DROP FILE FORMAT IF EXISTS {file_format_location}"
logger.debug(f"dropping file format with '{drop_file_format_sql}'")
cursor.execute(drop_file_format_sql, _is_internal=True)

Expand All @@ -269,18 +292,11 @@ def write_pandas(
parquet_columns = "$1:" + ",$1:".join(df.columns)

copy_into_sql = (
"COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
"({columns}) "
'FROM (SELECT {parquet_columns} FROM @"{stage_name}") '
"FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}) "
"PURGE=TRUE ON_ERROR={on_error}"
).format(
location=location,
columns=columns,
parquet_columns=parquet_columns,
stage_name=stage_name,
compression=compression_map[compression],
on_error=on_error,
f"COPY INTO {table_location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
f"({columns}) "
f"FROM (SELECT {parquet_columns} FROM @{stage_location}) "
f"FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression_map[compression]}) "
f"PURGE=TRUE ON_ERROR={on_error}"
)
logger.debug(f"copying into with '{copy_into_sql}'")
copy_results = cursor.execute(copy_into_sql, _is_internal=True).fetchall()
Expand Down
119 changes: 87 additions & 32 deletions test/integ/pandas/test_pandas_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,20 +314,33 @@ def test_empty_dataframe_write_pandas(
), f"sucess: {success}, num_chunks: {num_chunks}, num_rows: {num_rows}"


@pytest.mark.parametrize("quote_identifiers", [True, False])
def test_location_building_db_schema(conn_cnx, quote_identifiers: bool):
"""This tests that write_pandas constructs location correctly with database, schema and table name."""
@pytest.mark.parametrize(
"database,schema,quote_identifiers,expected_location",
[
("database", "schema", True, '"database"."schema"."table"'),
("database", "schema", False, "database.schema.table"),
(None, "schema", True, '"schema"."table"'),
(None, "schema", False, "schema.table"),
(None, None, True, '"table"'),
(None, None, False, "table"),
],
)
def test_table_location_building(
conn_cnx,
database: str | None,
schema: str | None,
quote_identifiers: bool,
expected_location: str,
):
"""This tests that write_pandas constructs table location correctly with database, schema, and table name."""
from snowflake.connector.cursor import SnowflakeCursor

with conn_cnx() as cnx:

def mocked_execute(*args, **kwargs):
if len(args) >= 1 and args[0].startswith("COPY INTO"):
location = args[0].split(" ")[2]
if quote_identifiers:
assert location == '"database"."schema"."table"'
else:
assert location == "database.schema.table"
assert location == expected_location
cur = SnowflakeCursor(cnx)
cur._result = iter([])
return cur
Expand All @@ -340,29 +353,42 @@ def mocked_execute(*args, **kwargs):
cnx,
sf_connector_version_df.get(),
"table",
database="database",
schema="schema",
database=database,
schema=schema,
quote_identifiers=quote_identifiers,
)
assert m_execute.called and any(
map(lambda e: "COPY INTO" in str(e[0]), m_execute.call_args_list)
)


@pytest.mark.parametrize("quote_identifiers", [True, False])
def test_location_building_schema(conn_cnx, quote_identifiers: bool):
"""This tests that write_pandas constructs location correctly with schema and table name."""
@pytest.mark.parametrize(
"database,schema,quote_identifiers,expected_db_schema",
[
("database", "schema", True, '"database"."schema"'),
("database", "schema", False, "database.schema"),
(None, "schema", True, '"schema"'),
(None, "schema", False, "schema"),
(None, None, True, ""),
(None, None, False, ""),
],
)
def test_stage_location_building(
conn_cnx,
database: str | None,
schema: str | None,
quote_identifiers: bool,
expected_db_schema: str,
):
"""This tests that write_pandas constructs stage location correctly with database and schema."""
from snowflake.connector.cursor import SnowflakeCursor

with conn_cnx() as cnx:

def mocked_execute(*args, **kwargs):
if len(args) >= 1 and args[0].startswith("COPY INTO"):
location = args[0].split(" ")[2]
if quote_identifiers:
assert location == '"schema"."table"'
else:
assert location == "schema.table"
if len(args) >= 1 and args[0].startswith("create temporary stage"):
db_schema = ".".join(args[0].split(" ")[-1].split(".")[:-1])
assert db_schema == expected_db_schema
cur = SnowflakeCursor(cnx)
cur._result = iter([])
return cur
Expand All @@ -375,30 +401,53 @@ def mocked_execute(*args, **kwargs):
cnx,
sf_connector_version_df.get(),
"table",
schema="schema",
database=database,
schema=schema,
quote_identifiers=quote_identifiers,
)
assert m_execute.called and any(
map(lambda e: "COPY INTO" in str(e[0]), m_execute.call_args_list)
map(
lambda e: "create temporary stage" in str(e[0]),
m_execute.call_args_list,
)
)


@pytest.mark.parametrize("quote_identifiers", [True, False])
def test_location_building(conn_cnx, quote_identifiers: bool):
"""This tests that write_pandas constructs location correctly with schema and table name."""
@pytest.mark.parametrize(
"database,schema,quote_identifiers,expected_db_schema",
[
("database", "schema", True, '"database"."schema"'),
("database", "schema", False, "database.schema"),
(None, "schema", True, '"schema"'),
(None, "schema", False, "schema"),
(None, None, True, ""),
(None, None, False, ""),
],
)
def test_file_format_location_building(
conn_cnx,
database: str | None,
schema: str | None,
quote_identifiers: bool,
expected_db_schema: str,
):
"""This tests that write_pandas constructs file format location correctly with database and schema."""
from snowflake.connector.cursor import SnowflakeCursor

with conn_cnx() as cnx:

def mocked_execute(*args, **kwargs):
if len(args) >= 1 and args[0].startswith("COPY INTO"):
location = args[0].split(" ")[2]
if quote_identifiers:
assert location == '"teble.table"'
else:
assert location == "teble.table"
if len(args) >= 1 and args[0].startswith("CREATE FILE FORMAT"):
db_schema = ".".join(args[0].split(" ")[3].split(".")[:-1])
assert db_schema == expected_db_schema
cur = SnowflakeCursor(cnx)
cur._result = iter([])
if args[0].startswith("SELECT"):
cur._rownumber = 0
cur._result = iter(
[(col, "") for col in sf_connector_version_df.get().columns]
)
else:
cur._result = iter([])
return cur

with mock.patch(
Expand All @@ -408,11 +457,17 @@ def mocked_execute(*args, **kwargs):
success, nchunks, nrows, _ = write_pandas(
cnx,
sf_connector_version_df.get(),
"teble.table",
"table",
database=database,
schema=schema,
quote_identifiers=quote_identifiers,
auto_create_table=True,
)
assert m_execute.called and any(
map(lambda e: "COPY INTO" in str(e[0]), m_execute.call_args_list)
map(
lambda e: "CREATE FILE FORMAT" in str(e[0]),
m_execute.call_args_list,
)
)


Expand Down