From 0cd439b6ff0770841b84f47a6a76ec3283fcd988 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Mon, 15 Jul 2024 22:15:21 +0300 Subject: [PATCH] Use existing `_tls_no_verify` option in CloudFetch downloader Signed-off-by: Levko Kravets --- src/databricks/sql/client.py | 2 ++ src/databricks/sql/cloudfetch/download_manager.py | 9 ++++++++- src/databricks/sql/cloudfetch/downloader.py | 7 ++++++- src/databricks/sql/thrift_backend.py | 6 +++++- src/databricks/sql/utils.py | 13 +++++++++++-- 5 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e56d22f6..084c42df 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -171,6 +171,8 @@ def read(self) -> Optional[OAuthToken]: # Which port to connect to # _skip_routing_headers: # Don't set routing headers if set to True (for use when connecting directly to server) + # _tls_no_verify + # Set to True (Boolean) to completely disable SSL verification. # _tls_verify_hostname # Set to False (Boolean) to disable SSL hostname verification, but check certificate. # _tls_trusted_ca_file diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 93b6f623..e30adcd6 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,5 +1,6 @@ import logging +from ssl import SSLContext from concurrent.futures import ThreadPoolExecutor, Future from typing import List, Union @@ -19,6 +20,7 @@ def __init__( links: List[TSparkArrowResultLink], max_download_threads: int, lz4_compressed: bool, + ssl_context: SSLContext, ): self._pending_links: List[TSparkArrowResultLink] = [] for link in links: @@ -36,6 +38,7 @@ def __init__( self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads) self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) + self._ssl_context = ssl_context def get_next_downloaded_file( self, next_row_offset: int @@ -89,7 +92,11 @@ def _schedule_downloads(self): logger.debug( "- start: {}, row count: {}".format(link.startRowOffset, link.rowCount) ) - handler = ResultSetDownloadHandler(self._downloadable_result_settings, link) + handler = ResultSetDownloadHandler( + settings=self._downloadable_result_settings, + link=link, + ssl_context=self._ssl_context, + ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 6663db7d..00ffecd0 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -3,6 +3,7 @@ import requests from requests.adapters import HTTPAdapter, Retry +from ssl import SSLContext, CERT_NONE import lz4.frame import time @@ -65,9 +66,11 @@ def __init__( self, settings: DownloadableResultSettings, link: TSparkArrowResultLink, + ssl_context: SSLContext, ): self.settings = settings self.link = link + self._ssl_context = ssl_context def run(self) -> DownloadedFile: """ @@ -92,12 +95,14 @@ def run(self) -> DownloadedFile: session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) + ssl_verify = self._ssl_context.verify_mode != CERT_NONE + try: # Get the file via HTTP request response = session.get( self.link.fileLink, timeout=self.settings.download_timeout, - verify=False, + verify=ssl_verify, ) response.raise_for_status() diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 79293e85..56412fce 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -184,6 +184,8 @@ def __init__( password=tls_client_cert_key_password, ) + self._ssl_context = ssl_context + self._auth_provider = auth_provider # Connector version 3 retry approach @@ -223,7 +225,7 @@ def __init__( self._transport = databricks.sql.auth.thrift_http_client.THttpClient( auth_provider=self._auth_provider, uri_or_host=uri, - ssl_context=ssl_context, + ssl_context=self._ssl_context, **additional_transport_args, # type: ignore ) @@ -774,6 +776,7 @@ def _results_message_to_execute_response(self, resp, operation_state): max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, + ssl_context=self._ssl_context, ) else: arrow_queue_opt = None @@ -1005,6 +1008,7 @@ def fetch_results( max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, + ssl_context=self._ssl_context, ) return queue, resp.hasMoreRows diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 4a770079..c22688bb 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -9,6 +9,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union import re +from ssl import SSLContext import lz4.frame import pyarrow @@ -47,6 +48,7 @@ def build_queue( t_row_set: TRowSet, arrow_schema_bytes: bytes, max_download_threads: int, + ssl_context: SSLContext, lz4_compressed: bool = True, description: Optional[List[List[Any]]] = None, ) -> ResultSetQueue: @@ -60,6 +62,7 @@ def build_queue( lz4_compressed (bool): Whether result data has been lz4 compressed. description (List[List[Any]]): Hive table schema description. max_download_threads (int): Maximum number of downloader thread pool threads. + ssl_context (SSLContext): SSLContext object for CloudFetchQueue Returns: ResultSetQueue @@ -82,12 +85,13 @@ def build_queue( return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.URL_BASED_SET: return CloudFetchQueue( - arrow_schema_bytes, + schema_bytes=arrow_schema_bytes, start_row_offset=t_row_set.startRowOffset, result_links=t_row_set.resultLinks, lz4_compressed=lz4_compressed, description=description, max_download_threads=max_download_threads, + ssl_context=ssl_context, ) else: raise AssertionError("Row set type is not valid") @@ -133,6 +137,7 @@ def __init__( self, schema_bytes, max_download_threads: int, + ssl_context: SSLContext, start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, @@ -155,6 +160,7 @@ def __init__( self.result_links = result_links self.lz4_compressed = lz4_compressed self.description = description + self._ssl_context = ssl_context logger.debug( "Initialize CloudFetch loader, row set start offset: {}, file list:".format( @@ -169,7 +175,10 @@ def __init__( ) ) self.download_manager = ResultFileDownloadManager( - result_links or [], self.max_download_threads, self.lz4_compressed + links=result_links or [], + max_download_threads=self.max_download_threads, + lz4_compressed=self.lz4_compressed, + ssl_context=self._ssl_context, ) self.table = self._create_next_table()