Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable SSL verification for CloudFetch links #414

Merged
merged 3 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def read(self) -> Optional[OAuthToken]:
# Which port to connect to
# _skip_routing_headers:
# Don't set routing headers if set to True (for use when connecting directly to server)
# _tls_no_verify
# Set to True (Boolean) to completely disable SSL verification.
# _tls_verify_hostname
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
# _tls_trusted_ca_file
Expand Down
9 changes: 8 additions & 1 deletion src/databricks/sql/cloudfetch/download_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from ssl import SSLContext
from concurrent.futures import ThreadPoolExecutor, Future
from typing import List, Union

Expand All @@ -19,6 +20,7 @@ def __init__(
links: List[TSparkArrowResultLink],
max_download_threads: int,
lz4_compressed: bool,
ssl_context: SSLContext,
):
self._pending_links: List[TSparkArrowResultLink] = []
for link in links:
Expand All @@ -36,6 +38,7 @@ def __init__(
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self._ssl_context = ssl_context

def get_next_downloaded_file(
self, next_row_offset: int
Expand Down Expand Up @@ -89,7 +92,11 @@ def _schedule_downloads(self):
logger.debug(
"- start: {}, row count: {}".format(link.startRowOffset, link.rowCount)
)
handler = ResultSetDownloadHandler(self._downloadable_result_settings, link)
handler = ResultSetDownloadHandler(
settings=self._downloadable_result_settings,
link=link,
ssl_context=self._ssl_context,
)
task = self._thread_pool.submit(handler.run)
self._download_tasks.append(task)

Expand Down
9 changes: 8 additions & 1 deletion src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import requests
from requests.adapters import HTTPAdapter, Retry
from ssl import SSLContext, CERT_NONE
import lz4.frame
import time

Expand Down Expand Up @@ -65,9 +66,11 @@ def __init__(
self,
settings: DownloadableResultSettings,
link: TSparkArrowResultLink,
ssl_context: SSLContext,
):
self.settings = settings
self.link = link
self._ssl_context = ssl_context

def run(self) -> DownloadedFile:
"""
Expand All @@ -92,10 +95,14 @@ def run(self) -> DownloadedFile:
session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
session.mount("https://", HTTPAdapter(max_retries=retryPolicy))

ssl_verify = self._ssl_context.verify_mode != CERT_NONE

try:
# Get the file via HTTP request
response = session.get(
self.link.fileLink, timeout=self.settings.download_timeout
self.link.fileLink,
timeout=self.settings.download_timeout,
verify=ssl_verify,
)
response.raise_for_status()

Expand Down
6 changes: 5 additions & 1 deletion src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def __init__(
password=tls_client_cert_key_password,
)

self._ssl_context = ssl_context

self._auth_provider = auth_provider

# Connector version 3 retry approach
Expand Down Expand Up @@ -223,7 +225,7 @@ def __init__(
self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
auth_provider=self._auth_provider,
uri_or_host=uri,
ssl_context=ssl_context,
ssl_context=self._ssl_context,
**additional_transport_args, # type: ignore
)

Expand Down Expand Up @@ -774,6 +776,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_context=self._ssl_context,
)
else:
arrow_queue_opt = None
Expand Down Expand Up @@ -1005,6 +1008,7 @@ def fetch_results(
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_context=self._ssl_context,
)

return queue, resp.hasMoreRows
Expand Down
13 changes: 11 additions & 2 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import re
from ssl import SSLContext

import lz4.frame
import pyarrow
Expand Down Expand Up @@ -47,6 +48,7 @@ def build_queue(
t_row_set: TRowSet,
arrow_schema_bytes: bytes,
max_download_threads: int,
ssl_context: SSLContext,
lz4_compressed: bool = True,
description: Optional[List[List[Any]]] = None,
) -> ResultSetQueue:
Expand All @@ -60,6 +62,7 @@ def build_queue(
lz4_compressed (bool): Whether result data has been lz4 compressed.
description (List[List[Any]]): Hive table schema description.
max_download_threads (int): Maximum number of downloader thread pool threads.
ssl_context (SSLContext): SSL context forwarded to CloudFetchQueue; its verify_mode determines whether CloudFetch downloads verify certificates.

Returns:
ResultSetQueue
Expand All @@ -82,12 +85,13 @@ def build_queue(
return ArrowQueue(converted_arrow_table, n_valid_rows)
elif row_set_type == TSparkRowSetType.URL_BASED_SET:
return CloudFetchQueue(
arrow_schema_bytes,
schema_bytes=arrow_schema_bytes,
start_row_offset=t_row_set.startRowOffset,
result_links=t_row_set.resultLinks,
lz4_compressed=lz4_compressed,
description=description,
max_download_threads=max_download_threads,
ssl_context=ssl_context,
)
else:
raise AssertionError("Row set type is not valid")
Expand Down Expand Up @@ -133,6 +137,7 @@ def __init__(
self,
schema_bytes,
max_download_threads: int,
ssl_context: SSLContext,
start_row_offset: int = 0,
result_links: Optional[List[TSparkArrowResultLink]] = None,
lz4_compressed: bool = True,
Expand All @@ -155,6 +160,7 @@ def __init__(
self.result_links = result_links
self.lz4_compressed = lz4_compressed
self.description = description
self._ssl_context = ssl_context

logger.debug(
"Initialize CloudFetch loader, row set start offset: {}, file list:".format(
Expand All @@ -169,7 +175,10 @@ def __init__(
)
)
self.download_manager = ResultFileDownloadManager(
result_links or [], self.max_download_threads, self.lz4_compressed
links=result_links or [],
max_download_threads=self.max_download_threads,
lz4_compressed=self.lz4_compressed,
ssl_context=self._ssl_context,
)

self.table = self._create_next_table()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def test_cancel_command_calls_the_backend(self):
mock_op_handle = Mock()
cursor.active_op_handle = mock_op_handle
cursor.cancel()
self.assertTrue(mock_thrift_backend.cancel_command.called_with(mock_op_handle))
mock_thrift_backend.cancel_command.assert_called_with(mock_op_handle)

@patch("databricks.sql.client.logger")
def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command(
Expand Down
Loading
Loading