Commit
work in progress
donotpush committed Jan 15, 2025
1 parent 902c49d commit 1efe565
Showing 3 changed files with 242 additions and 87 deletions.
26 changes: 22 additions & 4 deletions dlt/destinations/impl/databricks/configuration.py
@@ -5,7 +5,7 @@
from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration
from dlt.common.configuration.exceptions import ConfigurationValueError

from dlt.common import logger

DATABRICKS_APPLICATION_ID = "dltHub_dlt"

@@ -15,6 +15,7 @@ class DatabricksCredentials(CredentialsConfiguration):
catalog: str = None
server_hostname: str = None
http_path: str = None
is_token_from_context: bool = False
access_token: Optional[TSecretStrValue] = None
client_id: Optional[TSecretStrValue] = None
client_secret: Optional[TSecretStrValue] = None
@@ -37,11 +38,28 @@ class DatabricksCredentials(CredentialsConfiguration):

def on_resolved(self) -> None:
        if not ((self.client_id and self.client_secret) or self.access_token):
            # databricks authentication: attempt to use the notebook context token
            try:
                from databricks.sdk import WorkspaceClient

                w = WorkspaceClient()
                dbutils = w.dbutils
                self.access_token = (
                    dbutils.notebook.entry_point.getDbutils()
                    .notebook()
                    .getContext()
                    .apiToken()
                    .getOrElse(None)
                )
            except Exception:
                # outside a Databricks runtime the SDK cannot resolve a context token;
                # fall through to the configuration error below
                self.access_token = None

            if not self.access_token:
                raise ConfigurationValueError(
                    "No valid authentication method detected. Provide either 'client_id' and"
                    " 'client_secret' for OAuth, or 'access_token' for token-based authentication."
                )

            self.is_token_from_context = True
            logger.info("Authenticating to Databricks using the user's Notebook API token.")

def to_connector_params(self) -> Dict[str, Any]:
conn_params = dict(
catalog=self.catalog,
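For context, a minimal usage sketch of the notebook-token fallback above (not part of the diff; the pipeline name, catalog, host and warehouse path are hypothetical, and the inline credentials mapping assumes dlt's usual destination factory API). Run inside a Databricks notebook with no access_token, client_id or client_secret configured, on_resolved() picks up the notebook's API token from the dbutils context:

import dlt

pipeline = dlt.pipeline(
    pipeline_name="databricks_notebook_demo",  # hypothetical name
    destination=dlt.destinations.databricks(
        credentials={
            "catalog": "main",  # assumed catalog
            "server_hostname": "adb-1234567890123456.7.azuredatabricks.net",  # assumed host
            "http_path": "/sql/1.0/warehouses/abc1234567890def",  # assumed warehouse path
            # no access_token / client_id / client_secret: the context token is used
        }
    ),
    dataset_name="notebook_demo_data",
)

load_info = pipeline.run([{"id": 1}, {"id": 2}], table_name="items")
print(load_info)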
274 changes: 191 additions & 83 deletions dlt/destinations/impl/databricks/databricks.py
@@ -35,7 +35,7 @@
from dlt.destinations.sql_jobs import SqlMergeFollowupJob
from dlt.destinations.job_impl import ReferenceFollowupJobRequest
from dlt.destinations.utils import is_compression_disabled

from dlt.common import logger

SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS

@@ -50,126 +50,234 @@ def __init__(
self._staging_config = staging_config
self._job_client: "DatabricksClient" = None

self._sql_client = None
self._workspace_client = None
self._created_volume = None

def run(self) -> None:
self._sql_client = self._job_client.sql_client

qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name)

# Decide if this is a local file or a staged file
is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path)
if is_local_file and self._job_client.config.credentials.is_token_from_context:
# Handle local file by uploading to a temporary volume on Databricks
from_clause, file_name = self._handle_local_file_upload(self._file_path)
credentials_clause = ""
orig_bucket_path = None # not used for local file
else:
# Handle staged file
from_clause, credentials_clause, file_name, orig_bucket_path = (
self._handle_staged_file()
)

# Determine the source format and any additional format options
source_format, format_options_clause, skip_load = self._determine_source_format(
file_name, orig_bucket_path
)

if skip_load:
# If the file is empty or otherwise un-loadable, exit early
self._cleanup_volume() # in case we created a volume
return

# Build and execute the COPY INTO statement
statement = self._build_copy_into_statement(
qualified_table_name,
from_clause,
credentials_clause,
source_format,
format_options_clause,
)

self._sql_client.execute_sql(statement)

self._cleanup_volume()

def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog
import time
import io

w = WorkspaceClient(
host=self._job_client.config.credentials.server_hostname,
token=self._job_client.config.credentials.access_token,
)
self._workspace_client = w

# Create a temporary volume
volume_name = "_dlt_temp_load_volume"
# created_volume = w.volumes.create(
# catalog_name=self._sql_client.database_name,
# schema_name=self._sql_client.dataset_name,
# name=volume_name,
# volume_type=catalog.VolumeType.MANAGED,
# )
# self._created_volume = created_volume # store to delete later

qualified_volume_name = (
f"{self._sql_client.database_name}.{self._sql_client.dataset_name}.{volume_name}"
)
self._sql_client.execute_sql(f"""
CREATE VOLUME IF NOT EXISTS {qualified_volume_name}
""")

        logger.info(f"Databricks volume created: {qualified_volume_name}")

# Compute volume paths
volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{volume_name}"
volume_folder = f"file_{time.time_ns()}"
volume_folder_path = f"{volume_path}/{volume_folder}"

file_name = FileStorage.get_file_name_from_file_path(local_file_path)
volume_file_path = f"{volume_folder_path}/{file_name}"

# Upload the file
with open(local_file_path, "rb") as f:
file_bytes = f.read()
binary_data = io.BytesIO(file_bytes)
w.files.upload(volume_file_path, binary_data, overwrite=True)

# Return the FROM clause and file name
from_clause = f"FROM '{volume_folder_path}'"

return from_clause, file_name
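    # Illustrative example with hypothetical values: for catalog "main", dataset "my_dataset"
    # and local file "items.1234.parquet", the upload above lands at
    #   /Volumes/main/my_dataset/_dlt_temp_load_volume/file_1736899200000000000/items.1234.parquet
    # and the method returns
    #   ("FROM '/Volumes/main/my_dataset/_dlt_temp_load_volume/file_1736899200000000000'", "items.1234.parquet")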

def _handle_staged_file(self) -> tuple[str, str, str, str]:
        bucket_path = orig_bucket_path = (
            ReferenceFollowupJobRequest.resolve_reference(self._file_path)
            if ReferenceFollowupJobRequest.is_reference_job(self._file_path)
            else ""
        )
        if not bucket_path:
            raise LoadJobTerminalException(
                self._file_path,
                "Cannot load from local file. Databricks does not support loading from local"
                " files. Configure staging with an s3, azure or google storage bucket.",
            )

        # Extract filename
        file_name = FileStorage.get_file_name_from_file_path(bucket_path)

        staging_credentials = self._staging_config.credentials
        bucket_url = urlparse(bucket_path)
        bucket_scheme = bucket_url.scheme

        # Validate the storage scheme
        if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS:
            raise LoadJobTerminalException(
                self._file_path,
                f"Databricks cannot load data from staging bucket {bucket_path}. "
                "Only s3, azure and gcs buckets are supported. "
                "Please note that gcs buckets are supported only via named credential.",
            )

        credentials_clause = ""
        # External location vs named credentials vs explicit keys
        if self._job_client.config.is_staging_external_location:
            # Skip the credentials clause for an external location
            # https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location
            pass
        elif self._job_client.config.staging_credentials_name:
            # Named credentials
            credentials_clause = (
                f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )"
            )
        else:
            # Referencing staged files via a bucket URL requires explicit credentials
            if bucket_scheme == "s3":
                assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults)
                s3_creds = staging_credentials.to_session_credentials()
                credentials_clause = f"""WITH(CREDENTIAL(
                AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}',
                AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}',
                AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}'
                ))
                """
            elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
                assert isinstance(
                    staging_credentials, AzureCredentialsWithoutDefaults
                ), "AzureCredentialsWithoutDefaults required to pass explicit credential"
                # Explicit azure credentials are needed to load from a bucket without a named stage
                credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))"""
            else:
                raise LoadJobTerminalException(
                    self._file_path,
                    "You need to use Databricks named credential to use google storage."
                    " Passing explicit Google credentials is not supported by Databricks.",
                )

        if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
            assert isinstance(
                staging_credentials,
                (
                    AzureCredentialsWithoutDefaults,
                    AzureServicePrincipalCredentialsWithoutDefaults,
                ),
            )
            bucket_path = self.ensure_databricks_abfss_url(
                bucket_path,
                staging_credentials.azure_storage_account_name,
                staging_credentials.azure_account_host,
            )

        # Always add the FROM clause
        from_clause = f"FROM '{bucket_path}'"

return from_clause, credentials_clause, file_name, orig_bucket_path

def _determine_source_format(
self, file_name: str, orig_bucket_path: str
) -> tuple[str, str, bool]:
if file_name.endswith(".parquet"):
# Only parquet is supported
return "PARQUET", "", False

elif file_name.endswith(".jsonl"):
if not is_compression_disabled():
raise LoadJobTerminalException(
                    self._file_path,
"Databricks loader does not support gzip compressed JSON files. "
"Please disable compression in the data writer configuration:"
" https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression",
)
            # Databricks can load uncompressed JSON
            format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')"

            # Databricks fails to load empty JSON files, so check the staged file size first
            fs, _ = fsspec_from_config(self._staging_config)
            if orig_bucket_path is not None:
                file_size = fs.size(orig_bucket_path)
                if file_size == 0:
                    return "JSON", format_options_clause, True

            return "JSON", format_options_clause, False

raise LoadJobTerminalException(
self._file_path, "Databricks loader only supports .parquet or .jsonl file extensions."
)

def _build_copy_into_statement(
self,
qualified_table_name: str,
from_clause: str,
credentials_clause: str,
source_format: str,
format_options_clause: str,
) -> str:
return f"""COPY INTO {qualified_table_name}
{from_clause}
{credentials_clause}
FILEFORMAT = {source_format}
            {format_options_clause}
            """

def _cleanup_volume(self) -> None:
        # Drop the temporary volume only when it was created via the workspace client
        # (that path is currently disabled); volumes created with CREATE VOLUME IF NOT EXISTS
        # are left in place for reuse by subsequent loads.
        if self._workspace_client and self._created_volume:
            self._workspace_client.volumes.delete(name=self._created_volume.full_name)
            logger.info(f"Deleted temporary volume [{self._created_volume.full_name}]")

@staticmethod
def ensure_databricks_abfss_url(
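To make the generated SQL concrete, a small illustrative sketch of the statements that _build_copy_into_statement renders (catalog, schema, table, credential and staging paths are hypothetical): the first corresponds to a parquet file staged in S3 with a named credential, the second to a local parquet file uploaded to the temporary volume, where no credentials clause is needed.

# Illustrative only; identifiers and paths below are hypothetical.
staged_example = """COPY INTO main.my_dataset.items
    FROM 's3://my-staging-bucket/my_dataset/items/1736899200.1234.parquet'
    WITH(CREDENTIAL my_named_credential )
    FILEFORMAT = PARQUET
"""

volume_example = """COPY INTO main.my_dataset.items
    FROM '/Volumes/main/my_dataset/_dlt_temp_load_volume/file_1736899200000000000'
    FILEFORMAT = PARQUET
"""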