2143 adds snowflake hints
rudolfix committed Dec 17, 2024
1 parent fca69c7 commit 19f6cf2
Showing 16 changed files with 224 additions and 123 deletions.
5 changes: 1 addition & 4 deletions dlt/destinations/impl/bigquery/bigquery.py
@@ -401,10 +401,7 @@ def _get_info_schema_columns_query(
return query, folded_table_names

def _get_column_def_sql(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(column["name"])
column_def_sql = (
f"{name} {self.type_mapper.to_destination_type(column, table)} {self._gen_not_null(column.get('nullable', True))}"
)
column_def_sql = super()._get_column_def_sql(column, table)
if column.get(ROUND_HALF_EVEN_HINT, False):
column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')"
if column.get(ROUND_HALF_AWAY_FROM_ZERO_HINT, False):
5 changes: 2 additions & 3 deletions dlt/destinations/impl/clickhouse/clickhouse.py
@@ -292,11 +292,10 @@ def _get_table_update_sql(

return sql

@staticmethod
def _gen_not_null(v: bool) -> str:
def _gen_not_null(self, v: bool) -> str:
# ClickHouse fields are not nullable by default.
# We use the `Nullable` modifier instead of NULL / NOT NULL modifiers to cater for ALTER statement.
pass
return ""

def _from_db_type(
self, ch_t: str, precision: Optional[int], scale: Optional[int]
6 changes: 0 additions & 6 deletions dlt/destinations/impl/databricks/databricks.py
@@ -264,12 +264,6 @@ def _from_db_type(
) -> TColumnType:
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def _get_storage_table_query_columns(self) -> List[str]:
fields = super()._get_storage_table_query_columns()
fields[2] = ( # Override because this is the only way to get data type with precision
6 changes: 0 additions & 6 deletions dlt/destinations/impl/dremio/dremio.py
@@ -151,12 +151,6 @@ def _from_db_type(
) -> TColumnType:
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def _create_merge_followup_jobs(
self, table_chain: Sequence[PreparedTableSchema]
) -> List[FollowupJobRequest]:
11 changes: 0 additions & 11 deletions dlt/destinations/impl/duckdb/duck.py
@@ -74,17 +74,6 @@ def create_load_job(
job = DuckDbCopyJob(file_path)
return job

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_str = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
return (
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
)

def _from_db_type(
self, pq_t: str, precision: Optional[int], scale: Optional[int]
) -> TColumnType:
6 changes: 1 addition & 5 deletions dlt/destinations/impl/mssql/mssql.py
@@ -115,11 +115,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = Non
else:
db_type = self.type_mapper.to_destination_type(c, table)

hints_str = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
hints_str = self._get_column_hints_sql(c)
column_name = self.sql_client.escape_column_name(c["name"])
return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}"

12 changes: 0 additions & 12 deletions dlt/destinations/impl/postgres/postgres.py
@@ -161,18 +161,6 @@ def create_load_job(
job = PostgresCsvCopyJob(file_path)
return job

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_ = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
nullability = self._gen_not_null(c.get("nullable", True))
column_type = self.type_mapper.to_destination_type(c, table)

return f"{column_name} {column_type} {hints_} {nullability}"

def _create_replace_followup_jobs(
self, table_chain: Sequence[PreparedTableSchema]
) -> List[FollowupJobRequest]:
12 changes: 1 addition & 11 deletions dlt/destinations/impl/redshift/redshift.py
@@ -153,6 +153,7 @@ def __init__(
capabilities,
)
super().__init__(schema, config, sql_client)
self.active_hints = HINT_TO_REDSHIFT_ATTR
self.sql_client = sql_client
self.config: RedshiftClientConfiguration = config
self.type_mapper = self.capabilities.get_type_mapper()
@@ -162,17 +163,6 @@ def _create_merge_followup_jobs(
) -> List[FollowupJobRequest]:
return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)]

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_str = " ".join(
HINT_TO_REDSHIFT_ATTR.get(h, "")
for h in HINT_TO_REDSHIFT_ATTR.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
return (
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
)

def create_load_job(
self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
) -> LoadJob:
18 changes: 18 additions & 0 deletions dlt/destinations/impl/snowflake/configuration.py
@@ -138,6 +138,24 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration)
query_tag: Optional[str] = None
"""A tag with placeholders to tag sessions executing jobs"""

create_indexes: bool = False
"""Whether UNIQUE or PRIMARY KEY constrains should be created"""

def __init__(
self,
*,
credentials: SnowflakeCredentials = None,
create_indexes: bool = False,
destination_name: str = None,
environment: str = None,
) -> None:
super().__init__(
credentials=credentials,
destination_name=destination_name,
environment=environment,
)
self.create_indexes = create_indexes

def fingerprint(self) -> str:
"""Returns a fingerprint of host part of a connection string"""
if self.credentials and self.credentials.host:
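For illustration, a minimal sketch of the new flag on the configuration object (an assumption-laden example: credentials are omitted here and would normally be resolved from config/secrets before the client connects):

```py
# Sketch only: exercises the keyword-only create_indexes flag defined above.
from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration

config = SnowflakeClientConfiguration(create_indexes=True)
assert config.create_indexes is True
```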
45 changes: 36 additions & 9 deletions dlt/destinations/impl/snowflake/snowflake.py
@@ -1,6 +1,7 @@
from typing import Optional, Sequence, List
from typing import Optional, Sequence, List, Dict, Set
from urllib.parse import urlparse, urlunparse

from dlt.common import logger
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
@@ -15,20 +16,24 @@
AwsCredentialsWithoutDefaults,
AzureCredentialsWithoutDefaults,
)
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema
from dlt.common.schema.typing import TColumnType
from dlt.common.schema import TColumnSchema, Schema, TColumnHint
from dlt.common.schema.typing import TColumnType, TTableSchema

from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
from dlt.common.typing import TLoaderFileFormat
from dlt.common.utils import uniq_id
from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
from dlt.destinations.exceptions import LoadJobTerminalException

from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
from dlt.destinations.job_impl import ReferenceFollowupJobRequest

SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE"}


class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs):
def __init__(
@@ -238,6 +243,7 @@ def __init__(
self.config: SnowflakeClientConfiguration = config
self.sql_client: SnowflakeSqlClient = sql_client # type: ignore
self.type_mapper = self.capabilities.get_type_mapper()
self.active_hints = SUPPORTED_HINTS if self.config.create_indexes else {}

def create_load_job(
self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
@@ -264,6 +270,33 @@ def _make_add_column_sql(
"ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)
]

def _get_constraints_sql(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> str:
# "primary_key": "PRIMARY KEY"
if self.config.create_indexes:
partial: TTableSchema = {
"name": table_name,
"columns": {c["name"]: c for c in new_columns},
}
# Add PK constraint if pk_columns exist
pk_columns = get_columns_names_with_prop(partial, "primary_key")
if pk_columns:
if generate_alter:
logger.warning(
f"PRIMARY KEY on {table_name} constraint cannot be added in ALTER TABLE and"
" is ignored"
)
else:
pk_constraint_name = list(
self._norm_and_escape_columns(f"PK_{table_name}_{uniq_id(4)}")
)[0]
quoted_pk_cols = ", ".join(
self.sql_client.escape_column_name(col) for col in pk_columns
)
return f",\nCONSTRAINT {pk_constraint_name} PRIMARY KEY ({quoted_pk_cols})"
return ""

def _get_table_update_sql(
self,
table_name: str,
@@ -287,11 +320,5 @@ def _from_db_type(
) -> TColumnType:
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool:
return self.config.truncate_tables_on_staging_destination_before_load
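To make the new `_get_constraints_sql` concrete, here is a small, self-contained sketch (not the library's code; identifier quoting and the random suffix are simplified assumptions) of the clause appended to `CREATE TABLE` when `create_indexes` is enabled and primary-key columns exist:

```py
# Simplified illustration of the PRIMARY KEY clause built by _get_constraints_sql.
from secrets import token_hex
from typing import List


def quote_ident(name: str) -> str:
    # crude stand-in for SnowflakeSqlClient.escape_column_name
    return '"' + name.upper() + '"'


def pk_constraint_sql(table_name: str, pk_columns: List[str]) -> str:
    constraint_name = quote_ident(f"PK_{table_name}_{token_hex(2)}")
    cols = ", ".join(quote_ident(c) for c in pk_columns)
    return f",\nCONSTRAINT {constraint_name} PRIMARY KEY ({cols})"


ddl = (
    'CREATE TABLE orders (\n"ORDER_ID" NUMBER NOT NULL,\n"CUSTOMER_ID" NUMBER NOT NULL'
    + pk_constraint_sql("orders", ["order_id", "customer_id"])
    + "\n)"
)
print(ddl)
```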
39 changes: 28 additions & 11 deletions dlt/destinations/job_client_impl.py
@@ -7,28 +7,26 @@
from typing import (
Any,
ClassVar,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Iterable,
Iterator,
Generator,
)
import zlib
import re
from contextlib import contextmanager
from contextlib import suppress

from dlt.common import pendulum, logger
from dlt.common.destination.capabilities import DataTypeMapper
from dlt.common.json import json
from dlt.common.schema.typing import (
C_DLT_LOAD_ID,
COLUMN_HINTS,
TColumnType,
TColumnSchemaBase,
TTableFormat,
)
from dlt.common.schema.utils import (
get_inherited_table_hint,
@@ -40,11 +38,11 @@
from dlt.common.storages import FileStorage
from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName
from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
from dlt.common.schema import TColumnHint
from dlt.common.destination.reference import (
PreparedTableSchema,
StateInfo,
StorageSchemaInfo,
SupportsReadableDataset,
WithStateSync,
DestinationClientConfiguration,
DestinationClientDwhConfiguration,
@@ -55,9 +53,7 @@
JobClientBase,
HasFollowupJobs,
CredentialsConfiguration,
SupportsReadableRelation,
)
from dlt.destinations.dataset import ReadableDBAPIDataset

from dlt.destinations.exceptions import DatabaseUndefinedRelation
from dlt.destinations.job_impl import (
@@ -154,6 +150,8 @@ def __init__(
self.state_table_columns = ", ".join(
sql_client.escape_column_name(col) for col in state_table_["columns"]
)
self.active_hints: Dict[TColumnHint, str] = {}
self.type_mapper: DataTypeMapper = None
super().__init__(schema, config, sql_client.capabilities)
self.sql_client = sql_client
assert isinstance(config, DestinationClientDwhConfiguration)
@@ -569,6 +567,7 @@ def _get_table_update_sql(
# build CREATE
sql = self._make_create_table(qualified_name, table) + " (\n"
sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns])
sql += self._get_constraints_sql(table_name, new_columns, generate_alter)
sql += ")"
sql_result.append(sql)
else:
@@ -582,8 +581,16 @@
sql_result.extend(
[sql_base + col_statement for col_statement in add_column_statements]
)
constraints_sql = self._get_constraints_sql(table_name, new_columns, generate_alter)
if constraints_sql:
sql_result.append(constraints_sql)
return sql_result

def _get_constraints_sql(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> str:
return ""

def _check_table_update_hints(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> None:
@@ -613,12 +620,22 @@ def _check_table_update_hints(
" existing tables."
)

@abstractmethod
def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
pass
hints_ = self._get_column_hints_sql(c)
column_name = self.sql_client.escape_column_name(c["name"])
nullability = self._gen_not_null(c.get("nullable", True))
column_type = self.type_mapper.to_destination_type(c, table)

return f"{column_name} {column_type} {hints_} {nullability}"

def _get_column_hints_sql(self, c: TColumnSchema) -> str:
return " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True # use ColumnPropInfos to get default value
)

@staticmethod
def _gen_not_null(nullable: bool) -> str:
def _gen_not_null(self, nullable: bool) -> str:
return "NOT NULL" if not nullable else ""

def _create_table_update(
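The net effect of the refactor above: the base class now renders column definitions from each destination's `active_hints` mapping. A standalone sketch of that composition (simplified quoting and types, not the actual implementation):

```py
# Standalone sketch of how _get_column_hints_sql / _get_column_def_sql compose
# a column definition from a per-destination hint mapping.
from typing import Any, Dict

active_hints: Dict[str, str] = {"primary_key": "PRIMARY KEY", "unique": "UNIQUE"}


def column_hints_sql(column: Dict[str, Any]) -> str:
    # emit the DDL fragment of every hint that is set to True on the column
    return " ".join(ddl for hint, ddl in active_hints.items() if column.get(hint) is True)


def column_def_sql(column: Dict[str, Any], destination_type: str) -> str:
    # name, mapped type, hint fragments, then nullability - mirroring the base class order
    name = '"' + str(column["name"]) + '"'
    nullability = "" if column.get("nullable", True) else "NOT NULL"
    parts = (name, destination_type, column_hints_sql(column), nullability)
    return " ".join(p for p in parts if p)


print(column_def_sql({"name": "id", "primary_key": True, "nullable": False}, "BIGINT"))
# prints: "id" BIGINT PRIMARY KEY NOT NULL
```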
9 changes: 9 additions & 0 deletions docs/website/docs/dlt-ecosystem/destinations/snowflake.md
@@ -200,6 +200,12 @@ Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and
## Supported column hints
Snowflake supports the following [column hints](../../general-usage/schema#tables-and-columns):
* `cluster` - Creates clustering column(s). Multiple columns per table are supported, but only when a new table is created.
* `unique` - Creates a UNIQUE constraint hint on a Snowflake column; can be applied to multiple columns. ([optional](#additional-destination-options))
* `primary_key` - Creates a PRIMARY KEY on the selected column(s); the key may be compound. ([optional](#additional-destination-options))

`unique` and `primary_key` constraints are not enforced, and `dlt` does not instruct Snowflake to `RELY` on them during query planning.
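For example, the hints can be declared on a resource like this (an illustrative snippet; the resource and column names are made up):

```py
import dlt

@dlt.resource(primary_key="id", columns={"email": {"unique": True}})
def users():
    yield {"id": 1, "email": "alice@example.com"}
```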


## Table and column identifiers
Snowflake supports both case-sensitive and case-insensitive identifiers. All unquoted and uppercase identifiers resolve case-insensitively in SQL statements. Case-insensitive [naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) like the default **snake_case** will generate case-insensitive identifiers. Case-sensitive (like **sql_cs_v1**) will generate
@@ -308,13 +314,16 @@ pipeline = dlt.pipeline(
## Additional destination options

You can define your own stage to PUT files and disable the removal of the staged files after loading.
You can also opt in to [creating indexes](#supported-column-hints).

```toml
[destination.snowflake]
# Use an existing named stage instead of the default. Default uses the implicit table stage per table
stage_name="DLT_STAGE"
# Whether to keep or delete the staged files after COPY INTO succeeds
keep_staged_files=true
# Add UNIQUE and PRIMARY KEY hints to tables
create_indexes=true
```
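The same option can be set in code when creating the destination (a sketch, assuming the `snowflake` destination factory forwards `create_indexes` to its configuration):

```py
import dlt

pipeline = dlt.pipeline(
    pipeline_name="snowflake_hints_demo",
    destination=dlt.destinations.snowflake(create_indexes=True),
    dataset_name="demo_data",
)
```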

### Setting up CSV format