diff --git a/dlt/common/configuration/providers/vault.py b/dlt/common/configuration/providers/vault.py index 0ed8842d55..b8181a3e41 100644 --- a/dlt/common/configuration/providers/vault.py +++ b/dlt/common/configuration/providers/vault.py @@ -53,6 +53,9 @@ def get_value( value, _ = super().get_value(key, hint, pipeline_name, *sections) if value is None: # only secrets hints are handled + # TODO: we need to refine how we filter out non-secrets + # at the least we should load known fragments for fields + # that are part of a secret (i.e. coming from Credentials) if self.only_secrets and not is_secret_hint(hint) and hint is not AnyType: return None, full_key diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index 97bcfd315e..dd9502baa6 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -196,13 +196,25 @@ def _resolve_config_fields( if key in config.__hint_resolvers__: # Type hint for this field is created dynamically hint = config.__hint_resolvers__[key](config) + # check if hint is optional + is_optional = is_optional_type(hint) # get default and explicit values default_value = getattr(config, key, None) explicit_none = False + explicit_value = None + current_value = None traces: List[LookupTrace] = [] + def _set_field() -> None: + # collect unresolved fields + # NOTE: we hide B023 here because the function is called only within a loop + if not is_optional and current_value is None: # noqa + unresolved_fields[key] = traces # noqa + # set resolved value in config + if default_value != current_value: # noqa + setattr(config, key, current_value) # noqa + if explicit_values: - explicit_value = None if key in explicit_values: # allow None to be passed in explicit values # so we are able to reset defaults like in regular function calls @@ -211,14 +223,15 @@ # detect dlt.config and dlt.secrets and force injection if isinstance(explicit_value, ConfigValueSentinel): explicit_value = None - else: - if is_hint_not_resolvable(hint): - # for final fields default value is like explicit - explicit_value = default_value - else: - explicit_value = None - current_value = None + if is_hint_not_resolvable(hint): + # do not resolve non-resolvable hints, but allow explicit values to be passed + if not explicit_none: + current_value = default_value if explicit_value is None else explicit_value + traces = [LookupTrace("ExplicitValues", None, key, current_value)] + _set_field() + continue + # explicit none skips resolution if not explicit_none: # if hint is union of configurations, any of them must be resolved @@ -276,16 +289,7 @@ # set the trace for explicit none traces = [LookupTrace("ExplicitValues", None, key, None)] - # check if hint optional - is_optional = is_optional_type(hint) - # collect unresolved fields - if not is_optional and current_value is None: - unresolved_fields[key] = traces - # set resolved value in config - if default_value != current_value: - if not is_hint_not_resolvable(hint) or explicit_value is not None or explicit_none: - # ignore final types - setattr(config, key, current_value) + _set_field() # Check for dynamic hint resolvers which have no corresponding fields unmatched_hint_resolvers: List[str] = [] diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index c0cf9d1962..cdebb16688 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py
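For reference, the resolver change above keeps non-resolvable hints (`Final[...]`, `Annotated[..., NotResolved()]`) out of provider lookups, but a non-optional field without a default or explicit value is now reported as missing. A minimal sketch of the resulting behavior, mirroring the assertions in `tests/common/configuration/test_configuration.py` (import paths assumed from that part of the code base):

```py
from typing import Final

from dlt.common.configuration import resolve
from dlt.common.configuration.exceptions import ConfigFieldMissingException
from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec


@configspec
class FinalConfiguration2(BaseConfiguration):
    # Final fields are never looked up in config providers
    pipeline_name: Final[str] = None


# without a default or an explicit value the field is now reported as missing
try:
    resolve.resolve_configuration(FinalConfiguration2())
except ConfigFieldMissingException:
    pass

# explicit values are still accepted for non-resolvable fields
c = resolve.resolve_configuration(FinalConfiguration2(), explicit_value={"pipeline_name": "exp"})
assert c.pipeline_name == "exp"
```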
@@ -1,11 +1,12 @@ import dataclasses from typing import ClassVar, Final, Optional, Any, Dict, List +from dlt.common import logger from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.client import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.exceptions import ConfigurationValueError - +from dlt.common.utils import digest128 DATABRICKS_APPLICATION_ID = "dltHub_dlt" @@ -13,8 +14,8 @@ @configspec class DatabricksCredentials(CredentialsConfiguration): catalog: str = None - server_hostname: str = None - http_path: str = None + server_hostname: Optional[str] = None + http_path: Optional[str] = None access_token: Optional[TSecretStrValue] = None client_id: Optional[TSecretStrValue] = None client_secret: Optional[TSecretStrValue] = None @@ -37,10 +38,57 @@ class DatabricksCredentials(CredentialsConfiguration): def on_resolved(self) -> None: if not ((self.client_id and self.client_secret) or self.access_token): - raise ConfigurationValueError( - "No valid authentication method detected. Provide either 'client_id' and" - " 'client_secret' for OAuth, or 'access_token' for token-based authentication." - ) + try: + # attempt context authentication + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + self.access_token = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None) # type: ignore[union-attr] + except Exception: + self.access_token = None + + try: + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + self.access_token = w.config.authenticate # type: ignore[assignment] + logger.info(f"Will attempt to use default auth of type {w.config.auth_type}") + except Exception: + pass + + if not self.access_token: + raise ConfigurationValueError( + "Authentication failed: No valid authentication method detected. " + "Provide either 'client_id' and 'client_secret' for OAuth authentication, " + "or 'access_token' for token-based authentication." + ) + + if not self.server_hostname or not self.http_path: + try: + # attempt to fetch warehouse details + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + # warehouse ID may be present in an env variable + if w.config.warehouse_id: + warehouse = w.warehouses.get(w.config.warehouse_id) + else: + # for some reason list of warehouses has different type than a single one 🤯 + warehouse = list(w.warehouses.list())[0] # type: ignore[assignment] + logger.info( + f"Will attempt to use warehouse {warehouse.id} to get sql connection params" + ) + self.server_hostname = self.server_hostname or warehouse.odbc_params.hostname + self.http_path = self.http_path or warehouse.odbc_params.path + except Exception: + pass + + for param in ("catalog", "server_hostname", "http_path"): + if not getattr(self, param): + raise ConfigurationValueError( + f"Configuration error: Missing required parameter '{param}'. " + "Please provide it in the configuration." 
+ ) def to_connector_params(self) -> Dict[str, Any]: conn_params = dict( @@ -60,6 +108,9 @@ def to_connector_params(self) -> Dict[str, Any]: return conn_params + def __str__(self) -> str: + return f"databricks://{self.server_hostname}{self.http_path}/{self.catalog}" + @configspec class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): @@ -69,10 +120,20 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration) "If set, credentials with given name will be used in copy command" is_staging_external_location: bool = False """If true, the temporary credentials are not propagated to the COPY command""" + staging_volume_name: Optional[str] = None + """Name of the Databricks managed volume for temporary storage, e.g., ... Defaults to '_dlt_staging_load_volume' if not set.""" + keep_staged_files: Optional[bool] = True + """Whether to keep the files in the internal (volume) stage""" def __str__(self) -> str: """Return displayable destination location""" - if self.staging_config: - return str(self.staging_config.credentials) + if self.credentials: + return str(self.credentials) else: - return "[no staging set]" + return "" + + def fingerprint(self) -> str: + """Returns a fingerprint of the host part of the connection string""" + if self.credentials and self.credentials.server_hostname: + return digest128(self.credentials.server_hostname) + return "" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index b58f3c6af2..0b1a8b295b 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,3 +1,4 @@ +import os from typing import Optional, Sequence, List, cast from urllib.parse import urlparse, urlunparse @@ -28,6 +29,7 @@ from dlt.common.schema import TColumnSchema, Schema from dlt.common.schema.typing import TColumnType from dlt.common.storages import FilesystemConfiguration, fsspec_from_config +from dlt.common.utils import uniq_id from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration @@ -36,7 +38,6 @@ from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.utils import is_compression_disabled - SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS @@ -54,122 +55,210 @@ def run(self) -> None: self._sql_client = self._job_client.sql_client qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) - staging_credentials = self._staging_config.credentials - # extract and prepare some vars + + # decide if this is a local file or a staged file + is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path) + if is_local_file: + # conn parameter staging_allowed_local_path must be set to use 'PUT/REMOVE volume_path' SQL statement + self._sql_client.native_connection.thrift_backend.staging_allowed_local_path = ( + os.path.dirname(self._file_path) + ) + # load the local file by uploading it to a temporary volume on Databricks + from_clause, file_name, volume_path, volume_file_path = self._handle_local_file_upload( + self._file_path + ) + credentials_clause = "" + orig_bucket_path = None # not used for local file + else: + # staged file + from_clause, credentials_clause, file_name, orig_bucket_path = ( + self._handle_staged_file() + ) + + # decide on source format, file_name
will either be a local file or a bucket path + source_format, format_options_clause, skip_load = self._determine_source_format( + file_name, orig_bucket_path + ) + + if skip_load: + # If the file is empty or otherwise un-loadable, exit early + return + + statement = self._build_copy_into_statement( + qualified_table_name, + from_clause, + credentials_clause, + source_format, + format_options_clause, + ) + + self._sql_client.execute_sql(statement) + + if is_local_file and not self._job_client.config.keep_staged_files: + self._handle_staged_file_remove(volume_path, volume_file_path) + + def _handle_staged_file_remove(self, volume_path: str, volume_file_path: str) -> None: + self._sql_client.execute_sql(f"REMOVE '{volume_file_path}'") + self._sql_client.execute_sql(f"REMOVE '{volume_path}'") + + def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str, str, str]: + file_name = FileStorage.get_file_name_from_file_path(local_file_path) + volume_file_name = file_name + if file_name.startswith(("_", ".")): + volume_file_name = ( + "valid" + file_name + ) # databricks loading fails when file_name starts with - or . + + volume_catalog = self._sql_client.database_name + volume_database = self._sql_client.dataset_name + volume_name = "_dlt_staging_load_volume" + + fully_qualified_volume_name = f"{volume_catalog}.{volume_database}.{volume_name}" + if self._job_client.config.staging_volume_name: + fully_qualified_volume_name = self._job_client.config.staging_volume_name + volume_catalog, volume_database, volume_name = fully_qualified_volume_name.split(".") + else: + # create staging volume named _dlt_staging_load_volume + self._sql_client.execute_sql(f""" + CREATE VOLUME IF NOT EXISTS {fully_qualified_volume_name} + """) + + volume_path = f"/Volumes/{volume_catalog}/{volume_database}/{volume_name}/{uniq_id()}" + volume_file_path = f"{volume_path}/{volume_file_name}" + + self._sql_client.execute_sql(f"PUT '{local_file_path}' INTO '{volume_file_path}' OVERWRITE") + + from_clause = f"FROM '{volume_path}'" + + return from_clause, file_name, volume_path, volume_file_path + + def _handle_staged_file(self) -> tuple[str, str, str, str]: bucket_path = orig_bucket_path = ( ReferenceFollowupJobRequest.resolve_reference(self._file_path) if ReferenceFollowupJobRequest.is_reference_job(self._file_path) else "" ) - file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) - if bucket_path - else self._file_name - ) - from_clause = "" - credentials_clause = "" - format_options_clause = "" - if bucket_path: - bucket_url = urlparse(bucket_path) - bucket_scheme = bucket_url.scheme + if not bucket_path: + raise LoadJobTerminalException( + self._file_path, + "Cannot load from local file. Databricks does not support loading from local files." + " Configure staging with an s3, azure or google storage bucket.", + ) - if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS: - raise LoadJobTerminalException( - self._file_path, - f"Databricks cannot load data from staging bucket {bucket_path}. Only s3, azure" - " and gcs buckets are supported. 
Please note that gcs buckets are supported" - " only via named credential", - ) + file_name = FileStorage.get_file_name_from_file_path(bucket_path) - if self._job_client.config.is_staging_external_location: - # just skip the credentials clause for external location - # https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location - pass - elif self._job_client.config.staging_credentials_name: - # add named credentials - credentials_clause = ( - f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )" - ) - else: - # referencing an staged files via a bucket URL requires explicit AWS credentials - if bucket_scheme == "s3": - assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults) - s3_creds = staging_credentials.to_session_credentials() - credentials_clause = f"""WITH(CREDENTIAL( + staging_credentials = self._staging_config.credentials + bucket_url = urlparse(bucket_path) + bucket_scheme = bucket_url.scheme + + if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS: + raise LoadJobTerminalException( + self._file_path, + f"Databricks cannot load data from staging bucket {bucket_path}. " + "Only s3, azure and gcs buckets are supported. " + "Please note that gcs buckets are supported only via named credential.", + ) + + credentials_clause = "" + + if self._job_client.config.is_staging_external_location: + # just skip the credentials clause for external location + # https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location + pass + elif self._job_client.config.staging_credentials_name: + # add named credentials + credentials_clause = ( + f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )" + ) + else: + # referencing an staged files via a bucket URL requires explicit AWS credentials + if bucket_scheme == "s3": + assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults) + s3_creds = staging_credentials.to_session_credentials() + credentials_clause = f"""WITH(CREDENTIAL( AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}', AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}', - AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}' - )) - """ - elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: - assert isinstance( - staging_credentials, AzureCredentialsWithoutDefaults - ), "AzureCredentialsWithoutDefaults required to pass explicit credential" - # Explicit azure credentials are needed to load from bucket without a named stage - credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" - bucket_path = self.ensure_databricks_abfss_url( - bucket_path, - staging_credentials.azure_storage_account_name, - staging_credentials.azure_account_host, - ) - else: - raise LoadJobTerminalException( - self._file_path, - "You need to use Databricks named credential to use google storage." 
- " Passing explicit Google credentials is not supported by Databricks.", - ) - - if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: + ))""" + elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: assert isinstance( - staging_credentials, - ( - AzureCredentialsWithoutDefaults, - AzureServicePrincipalCredentialsWithoutDefaults, - ), - ) + staging_credentials, AzureCredentialsWithoutDefaults + ), "AzureCredentialsWithoutDefaults required to pass explicit credential" + # Explicit azure credentials are needed to load from bucket without a named stage + credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" bucket_path = self.ensure_databricks_abfss_url( bucket_path, staging_credentials.azure_storage_account_name, staging_credentials.azure_account_host, ) + else: + raise LoadJobTerminalException( + self._file_path, + "You need to use Databricks named credential to use google storage." + " Passing explicit Google credentials is not supported by Databricks.", + ) - # always add FROM clause - from_clause = f"FROM '{bucket_path}'" - else: - raise LoadJobTerminalException( - self._file_path, - "Cannot load from local file. Databricks does not support loading from local files." - " Configure staging with an s3, azure or google storage bucket.", + if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: + assert isinstance( + staging_credentials, + (AzureCredentialsWithoutDefaults, AzureServicePrincipalCredentialsWithoutDefaults), + ) + bucket_path = self.ensure_databricks_abfss_url( + bucket_path, + staging_credentials.azure_storage_account_name, + staging_credentials.azure_account_host, ) - # decide on source format, stage_file_path will either be a local file or a bucket path + # always add FROM clause + from_clause = f"FROM '{bucket_path}'" + + return from_clause, credentials_clause, file_name, orig_bucket_path + + def _determine_source_format( + self, file_name: str, orig_bucket_path: str + ) -> tuple[str, str, bool]: if file_name.endswith(".parquet"): - source_format = "PARQUET" # Only parquet is supported + return "PARQUET", "", False + elif file_name.endswith(".jsonl"): if not is_compression_disabled(): raise LoadJobTerminalException( self._file_path, - "Databricks loader does not support gzip compressed JSON files. Please disable" - " compression in the data writer configuration:" + "Databricks loader does not support gzip compressed JSON files. " + "Please disable compression in the data writer configuration:" " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) - source_format = "JSON" + format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" - # Databricks fails when trying to load empty json files, so we have to check the file size + + # check for an empty JSON file fs, _ = fsspec_from_config(self._staging_config) - file_size = fs.size(orig_bucket_path) - if file_size == 0: # Empty file, do nothing - return + if orig_bucket_path is not None: + file_size = fs.size(orig_bucket_path) + if file_size == 0: + return "JSON", format_options_clause, True + + return "JSON", format_options_clause, False + + raise LoadJobTerminalException( + self._file_path, "Databricks loader only supports .parquet or .jsonl file extensions." 
+ ) - statement = f"""COPY INTO {qualified_table_name} + def _build_copy_into_statement( + self, + qualified_table_name: str, + from_clause: str, + credentials_clause: str, + source_format: str, + format_options_clause: str, + ) -> str: + return f"""COPY INTO {qualified_table_name} {from_clause} {credentials_clause} FILEFORMAT = {source_format} {format_options_clause} - """ - self._sql_client.execute_sql(statement) + """ @staticmethod def ensure_databricks_abfss_url( diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index c36ef08c0b..da60bb1a8c 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -107,8 +107,8 @@ class databricks(Destination[DatabricksClientConfiguration, "DatabricksClient"]) def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = None - caps.supported_loader_file_formats = [] + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["jsonl", "parquet"] caps.preferred_staging_file_format = "parquet" caps.supported_staging_file_formats = ["jsonl", "parquet"] caps.supported_table_formats = ["delta"] @@ -154,6 +154,7 @@ def __init__( staging_credentials_name: t.Optional[str] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, + staging_volume_name: t.Optional[str] = None, **kwargs: t.Any, ) -> None: """Configure the Databricks destination to use in a pipeline. @@ -173,6 +174,7 @@ staging_credentials_name=staging_credentials_name, destination_name=destination_name, environment=environment, + staging_volume_name=staging_volume_name, **kwargs, ) diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index d07973d087..510b013d6a 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -88,7 +88,11 @@ def open_connection(self) -> DatabricksSqlConnection: if self.credentials.client_id and self.credentials.client_secret: conn_params["credentials_provider"] = self._get_oauth_credentials + elif callable(self.credentials.access_token): + # this is w.config.authenticate (a callable authenticator) + conn_params["credentials_provider"] = lambda: self.credentials.access_token else: + # this is a plain access token conn_params["access_token"] = self.credentials.access_token self._conn = databricks_lib.connect( diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index a28a42f761..d970378cce 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -240,6 +240,13 @@ You can find other options for specifying credentials in the [Authentication sec See [Staging support](#staging-support) for authentication options when `dlt` copies files from buckets. +### Using default credentials +If none of the auth methods above is configured, `dlt` attempts to get authorization from the Databricks workspace context. The context may +come, for example, from a Notebook (runtime) or via the standard set of env variables that the Databricks Python SDK recognizes (e.g., **DATABRICKS_TOKEN** or **DATABRICKS_HOST**). + +`dlt` is able to set `server_hostname` and `http_path` from the available warehouses. We use the default warehouse id (**DATABRICKS_WAREHOUSE_ID**) +if it is set (via env variable), or the first warehouse on the list.
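A minimal sketch of the default-credentials path described above, run outside a notebook; the host, token, warehouse id and catalog below are placeholders, and any auth method recognized by the Databricks SDK would work:

```py
import os

import dlt
from dlt.destinations import databricks

# placeholder values; in a Databricks Notebook none of these are needed
os.environ["DATABRICKS_HOST"] = "my-workspace.cloud.databricks.com"
os.environ["DATABRICKS_TOKEN"] = "dapi-example-token"
# optional: pin a SQL warehouse, otherwise the first warehouse on the list is used
os.environ["DATABRICKS_WAREHOUSE_ID"] = "0123456789abcdef"

# only the catalog is given explicitly; server_hostname, http_path and the
# access token are taken from the workspace context / env variables
bricks = databricks(credentials={"catalog": "my_catalog"})

pipeline = dlt.pipeline(
    pipeline_name="default_auth_example",
    dataset_name="default_auth_data",
    destination=bricks,
)
```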
+ ## Write disposition All write dispositions are supported. @@ -260,6 +267,56 @@ The JSONL format has some limitations when used with Databricks: 2. The following data types are not supported when using the JSONL format with `databricks`: `decimal`, `json`, `date`, `binary`. Use `parquet` if your data contains these types. 3. The `bigint` data type with precision is not supported with the JSONL format. +## Direct Load (Databricks Managed Volumes) + +`dlt` now supports **Direct Load**, enabling pipelines to run seamlessly from **Databricks Notebooks** without external staging. When executed in a Databricks Notebook, `dlt` uses the notebook context for configuration if not explicitly provided. + +Direct Load also works **outside Databricks**, requiring explicit configuration of `server_hostname`, `http_path`, `catalog`, and authentication (`client_id`/`client_secret` for OAuth or `access_token` for token-based authentication). + +The example below demonstrates how to load data directly from a **Databricks Notebook**. Simply specify the **Databricks catalog** and optionally a **fully qualified volume name** (recommended for production) – the remaining configuration comes from the notebook context: + +```py +import dlt +from dlt.destinations import databricks +from dlt.sources.rest_api import rest_api_source + +# Fully qualified Databricks managed volume (recommended for production) +# - dlt assumes the named volume already exists +staging_volume_name = "dlt_ci.dlt_tests_shared.static_volume" + +bricks = databricks(credentials={"catalog": "dlt_ci"}, staging_volume_name=staging_volume_name) + +pokemon_source = rest_api_source( + { + "client": {"base_url": "https://pokeapi.co/api/v2/"}, + "resource_defaults": {"endpoint": {"params": {"limit": 1000}}}, + "resources": ["pokemon"], + } +) + +pipeline = dlt.pipeline( + pipeline_name="rest_api_example", + dataset_name="rest_api_data", + destination=bricks, +) + +load_info = pipeline.run(pokemon_source) +print(load_info) +print(pipeline.dataset().pokemon.df()) +``` + +- If **no** *staging_volume_name* **is provided**, dlt creates a **default volume** automatically. +- **For production**, explicitly setting *staging_volume_name* is recommended. +- The volume is used as a **temporary location** to store files before loading. + +:::tip +You can delete staged files **immediately** after loading by setting the following config option: +```toml
[destination.databricks] +keep_staged_files = false +``` +::: + ## Staging support Databricks supports both Amazon S3, Azure Blob Storage, and Google Cloud Storage as staging locations. `dlt` will upload files in Parquet format to the staging location and will instruct Databricks to load data from there.
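Complementing the Direct Load section above, the staging options can also be passed straight to the factory; the volume name below is a placeholder and is assumed to exist already:

```py
import dlt
from dlt.destinations import databricks

bricks = databricks(
    staging_volume_name="my_catalog.my_schema.my_staging_volume",  # placeholder
    keep_staged_files=False,  # staged files are REMOVEd from the volume after COPY INTO
)

pipeline = dlt.pipeline(
    pipeline_name="direct_load_example",
    dataset_name="direct_load_data",
    destination=bricks,
)
print(pipeline.run([1, 2, 3], table_name="digits"))
```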
diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 00a28d652e..8e8618f90f 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -397,13 +397,13 @@ class FinalConfiguration(BaseConfiguration): class FinalConfiguration2(BaseConfiguration): pipeline_name: Final[str] = None - c2 = resolve.resolve_configuration(FinalConfiguration2()) - assert dict(c2) == {"pipeline_name": None} + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(FinalConfiguration2()) c2 = resolve.resolve_configuration( FinalConfiguration2(), explicit_value={"pipeline_name": "exp"} ) - assert c.pipeline_name == "exp" + assert c2.pipeline_name == "exp" with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration(FinalConfiguration2(), explicit_value={"pipeline_name": None}) @@ -435,13 +435,13 @@ class NotResolvedConfiguration(BaseConfiguration): class NotResolvedConfiguration2(BaseConfiguration): pipeline_name: Annotated[str, NotResolved()] = None - c2 = resolve.resolve_configuration(NotResolvedConfiguration2()) - assert dict(c2) == {"pipeline_name": None} + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(NotResolvedConfiguration2()) c2 = resolve.resolve_configuration( NotResolvedConfiguration2(), explicit_value={"pipeline_name": "exp"} ) - assert c.pipeline_name == "exp" + assert c2.pipeline_name == "exp" with pytest.raises(ConfigFieldMissingException): resolve.resolve_configuration( NotResolvedConfiguration2(), explicit_value={"pipeline_name": None} diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index 8b3beed2b3..cc98c47d33 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -1,8 +1,12 @@ import pytest import os +from dlt.common.schema.schema import Schema +from dlt.common.utils import digest128 + pytest.importorskip("databricks") +import dlt from dlt.common.exceptions import TerminalValueError from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob @@ -12,6 +16,7 @@ from dlt.destinations.impl.databricks.configuration import ( DatabricksClientConfiguration, DATABRICKS_APPLICATION_ID, + DatabricksCredentials, ) # mark all tests as essential, do not remove @@ -43,6 +48,11 @@ def test_databricks_credentials_to_connector_params(): assert params["_socket_timeout"] == credentials.socket_timeout assert params["_user_agent_entry"] == DATABRICKS_APPLICATION_ID + displayable_location = str(credentials) + assert displayable_location.startswith( + "databricks://my-databricks.example.com/sql/1.0/warehouses/asdfe/my-catalog" + ) + def test_databricks_configuration() -> None: bricks = databricks() @@ -90,9 +100,110 @@ def test_databricks_abfss_converter() -> None: def test_databricks_auth_invalid() -> None: - with pytest.raises(ConfigurationValueError, match="No valid authentication method detected.*"): + with pytest.raises(ConfigurationValueError, match="Authentication failed:*"): os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = "" bricks = databricks() bricks.configuration(None, accept_partial=True) + + +def 
test_databricks_missing_config_catalog() -> None: + with pytest.raises( + ConfigurationValueError, match="Configuration error: Missing required parameter 'catalog'*" + ): + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CATALOG"] = "" + bricks = databricks() + bricks.configuration(None, accept_partial=True) + + +def test_databricks_missing_config_http_path() -> None: + with pytest.raises( + ConfigurationValueError, + match="Configuration error: Missing required parameter 'http_path'*", + ): + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__HTTP_PATH"] = "" + bricks = databricks() + bricks.configuration(None, accept_partial=True) + + +def test_databricks_missing_config_server_hostname() -> None: + with pytest.raises( + ConfigurationValueError, + match="Configuration error: Missing required parameter 'server_hostname'*", + ): + os.environ["DESTINATION__DATABRICKS__CREDENTIALS__SERVER_HOSTNAME"] = "" + bricks = databricks() + bricks.configuration(None, accept_partial=True) + + +@pytest.mark.parametrize("auth_type", ("pat", "oauth2")) +def test_default_credentials(auth_type: str) -> None: + # create minimal default env + os.environ["DATABRICKS_HOST"] = dlt.secrets[ + "destination.databricks.credentials.server_hostname" + ] + if auth_type == "pat": + os.environ["DATABRICKS_TOKEN"] = dlt.secrets[ + "destination.databricks.credentials.access_token" + ] + else: + os.environ["DATABRICKS_CLIENT_ID"] = dlt.secrets[ + "destination.databricks.credentials.client_id" + ] + os.environ["DATABRICKS_CLIENT_SECRET"] = dlt.secrets[ + "destination.databricks.credentials.client_secret" + ] + + # will not pick up the credentials from "destination.databricks" + config = resolve_configuration( + DatabricksClientConfiguration( + credentials=DatabricksCredentials(catalog="dlt_ci") + )._bind_dataset_name(dataset_name="my-dataset-1234") + ) + # we pass authenticator that will be used to make connection, that's why callable + assert callable(config.credentials.access_token) + # taken from a warehouse + assert isinstance(config.credentials.http_path, str) + + bricks = databricks(credentials=config.credentials) + # "my-dataset-1234" not present (we check SQL execution) + with bricks.client(Schema("schema"), config) as client: + assert not client.is_storage_initialized() + + # check fingerprint not default + assert config.fingerprint() != digest128("") + + +def test_oauth2_credentials() -> None: + dlt.secrets["destination.databricks.credentials.access_token"] = "" + # we must prime the "destinations" for google secret manager config provider + # because it retrieves catalog as first element and it is not secret. 
and vault providers + # are secret only + dlt.secrets.get("destination.credentials") + config = resolve_configuration( + DatabricksClientConfiguration()._bind_dataset_name(dataset_name="my-dataset-1234-oauth"), + sections=("destination", "databricks"), + ) + assert config.credentials.access_token == "" + # will resolve to oauth token + bricks = databricks(credentials=config.credentials) + # "my-dataset-1234-oauth" not present (we check SQL execution) + with bricks.client(Schema("schema"), config) as client: + assert not client.is_storage_initialized() + + +def test_default_warehouse() -> None: + os.environ["DATABRICKS_TOKEN"] = dlt.secrets["destination.databricks.credentials.access_token"] + os.environ["DATABRICKS_HOST"] = dlt.secrets[ + "destination.databricks.credentials.server_hostname" + ] + # will force this warehouse + os.environ["DATABRICKS_WAREHOUSE_ID"] = "588dbd71bd802f4d" + + config = resolve_configuration( + DatabricksClientConfiguration( + credentials=DatabricksCredentials(catalog="dlt_ci") + )._bind_dataset_name(dataset_name="my-dataset-1234") + ) + assert config.credentials.http_path == "/sql/1.0/warehouses/588dbd71bd802f4d" diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 078dce3a7f..41791059e5 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -1,6 +1,9 @@ import pytest import os +from pytest_mock import MockerFixture +import dlt + from dlt.common.utils import uniq_id from dlt.destinations import databricks from tests.load.utils import ( @@ -19,7 +22,7 @@ @pytest.mark.parametrize( "destination_config", destinations_configs( - default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) ), ids=lambda x: x.name, ) @@ -105,7 +108,7 @@ def test_databricks_external_location(destination_config: DestinationTestConfigu @pytest.mark.parametrize( "destination_config", destinations_configs( - default_sql_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) ), ids=lambda x: x.name, ) @@ -154,19 +157,29 @@ def test_databricks_gcs_external_location(destination_config: DestinationTestCon @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, subset=("databricks",)), + destinations_configs( + default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + ), ids=lambda x: x.name, ) def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = "" - bricks = databricks() + + from dlt.destinations import databricks, filesystem + from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob + + abfss_bucket_url = DatabricksLoadJob.ensure_databricks_abfss_url(AZ_BUCKET, "dltdata") + stage = filesystem(abfss_bucket_url) + + bricks = databricks(is_staging_external_location=False) config = bricks.configuration(None, accept_partial=True) + assert config.credentials.client_id and config.credentials.client_secret assert not config.credentials.access_token dataset_name = "test_databricks_oauth" + uniq_id() pipeline = destination_config.setup_pipeline( - "test_databricks_oauth", dataset_name=dataset_name, destination=bricks + "test_databricks_oauth", dataset_name=dataset_name, destination=bricks, staging=stage ) 
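For reference, a small sketch of the new `__str__` and `fingerprint()` helpers exercised by these tests, using the same placeholder credentials as the connector-params test and assuming the configspec constructor accepts the declared fields as keyword arguments:

```py
from dlt.common.utils import digest128
from dlt.destinations.impl.databricks.configuration import (
    DatabricksClientConfiguration,
    DatabricksCredentials,
)

# placeholder credentials mirroring test_databricks_credentials_to_connector_params
creds = DatabricksCredentials(
    catalog="my-catalog",
    server_hostname="my-databricks.example.com",
    http_path="/sql/1.0/warehouses/asdfe",
    access_token="dapi-example",
)

# displayable location rendered by the new __str__
assert str(creds) == "databricks://my-databricks.example.com/sql/1.0/warehouses/asdfe/my-catalog"

# fingerprint() hashes only the host part of the connection
config = DatabricksClientConfiguration(credentials=creds)._bind_dataset_name(
    dataset_name="my_dataset"
)
assert config.fingerprint() == digest128("my-databricks.example.com")
```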
info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) @@ -179,20 +192,29 @@ def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, subset=("databricks",)), + destinations_configs( + default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",) + ), ids=lambda x: x.name, ) def test_databricks_auth_token(destination_config: DestinationTestConfiguration) -> None: os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = "" os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = "" - bricks = databricks() + + from dlt.destinations import databricks, filesystem + from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob + + abfss_bucket_url = DatabricksLoadJob.ensure_databricks_abfss_url(AZ_BUCKET, "dltdata") + stage = filesystem(abfss_bucket_url) + + bricks = databricks(is_staging_external_location=False) config = bricks.configuration(None, accept_partial=True) assert config.credentials.access_token assert not (config.credentials.client_secret and config.credentials.client_id) dataset_name = "test_databricks_token" + uniq_id() pipeline = destination_config.setup_pipeline( - "test_databricks_token", dataset_name=dataset_name, destination=bricks + "test_databricks_token", dataset_name=dataset_name, destination=bricks, staging=stage ) info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) @@ -201,3 +223,58 @@ def test_databricks_auth_token(destination_config: DestinationTestConfiguration) with pipeline.sql_client() as client: rows = client.execute_sql(f"select * from {dataset_name}.digits") assert len(rows) == 3 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=("databricks",)), + ids=lambda x: x.name, +) +def test_databricks_direct_load(destination_config: DestinationTestConfiguration) -> None: + dataset_name = "test_databricks_direct_load" + uniq_id() + pipeline = destination_config.setup_pipeline( + "test_databricks_direct_load", dataset_name=dataset_name + ) + assert pipeline.staging is None + + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + assert info.has_failed_jobs is False + + with pipeline.sql_client() as client: + rows = client.execute_sql(f"select * from {dataset_name}.digits") + assert len(rows) == 3 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=("databricks",)), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("keep_staged_files", (True, False)) +def test_databricks_direct_load_with_custom_staging_volume_name_and_file_removal( + destination_config: DestinationTestConfiguration, + keep_staged_files: bool, + mocker: MockerFixture, +) -> None: + from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob + + remove_spy = mocker.spy(DatabricksLoadJob, "_handle_staged_file_remove") + custom_staging_volume_name = "dlt_ci.dlt_tests_shared.static_volume" + bricks = databricks( + staging_volume_name=custom_staging_volume_name, keep_staged_files=keep_staged_files + ) + + dataset_name = "test_databricks_direct_load" + uniq_id() + pipeline = destination_config.setup_pipeline( + "test_databricks_direct_load", dataset_name=dataset_name, destination=bricks + ) + + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) + assert info.has_failed_jobs is 
False + print(info) + + assert remove_spy.call_count == (0 if keep_staged_files else 2) + + with pipeline.sql_client() as client: + rows = client.execute_sql(f"select * from {dataset_name}.digits") + assert len(rows) == 3 diff --git a/tests/load/utils.py b/tests/load/utils.py index 5e61292825..5601fb7272 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -327,8 +327,7 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration(destination_type=destination) for destination in SQL_DESTINATIONS - if destination - not in ("athena", "synapse", "databricks", "dremio", "clickhouse", "sqlalchemy") + if destination not in ("athena", "synapse", "dremio", "clickhouse", "sqlalchemy") ] destination_configs += [ DestinationTestConfiguration(destination_type="duckdb", file_format="parquet"), @@ -364,14 +363,6 @@ destination_type="clickhouse", file_format="jsonl", supports_dbt=False ) ] - destination_configs += [ - DestinationTestConfiguration( - destination_type="databricks", - file_format="parquet", - bucket_url=AZ_BUCKET, - extra_info="az-authorization", - ) - ] destination_configs += [ DestinationTestConfiguration( @@ -463,6 +454,13 @@ bucket_url=AZ_BUCKET, extra_info="az-authorization", ), + DestinationTestConfiguration( + destination_type="databricks", + staging="filesystem", + file_format="parquet", + bucket_url=AZ_BUCKET, + extra_info="az-authorization", + ), DestinationTestConfiguration( destination_type="databricks", staging="filesystem", @@ -660,7 +658,9 @@ destination_configs = [ conf for conf in destination_configs - if conf.destination_type != "filesystem" or conf.bucket_url in bucket_subset + # filter by bucket when (1) filesystem OR (2) specific set of destinations requested + if (conf.destination_type != "filesystem" and not subset) + or conf.bucket_url in bucket_subset ] if exclude: destination_configs = [
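A short sketch of how the adjusted matrix selects Databricks configurations, assuming the helper names exported by `tests.load.utils` (`destinations_configs`, `AZ_BUCKET`) and the `name`, `staging` and `bucket_url` attributes of `DestinationTestConfiguration`:

```py
from tests.load.utils import AZ_BUCKET, destinations_configs

# direct load: plain SQL config, no staging attached
direct = destinations_configs(default_sql_configs=True, subset=("databricks",))

# staged load: filesystem staging restricted to the Azure bucket
staged = destinations_configs(
    default_staging_configs=True, bucket_subset=(AZ_BUCKET,), subset=("databricks",)
)

for conf in direct + staged:
    print(conf.name, conf.staging, conf.bucket_url)
```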