Skip to content

Commit

Permalink
[PECO-1857] Use SSL options with HTTPS connection pool (#425)
Browse files Browse the repository at this point in the history
* [PECO-1857] Use SSL options with HTTPS connection pool

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Some cleanup

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Resolve circular dependencies

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Update existing tests

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Fix MyPy issues

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Fix `_tls_no_verify` handling

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

* Add tests

Signed-off-by: Levko Kravets <levko.ne@gmail.com>

---------

Signed-off-by: Levko Kravets <levko.ne@gmail.com>
  • Loading branch information
kravets-levko authored Aug 22, 2024
1 parent 2d2b3c1 commit 1f8cf73
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 159 deletions.
41 changes: 25 additions & 16 deletions src/databricks/sql/auth/thrift_http_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import base64
import logging
import urllib.parse
from typing import Dict, Union
from typing import Dict, Union, Optional

import six
import thrift

logger = logging.getLogger(__name__)

import ssl
import warnings
from http.client import HTTPResponse
Expand All @@ -16,6 +14,9 @@
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
from urllib3.util import make_headers
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)


class THttpClient(thrift.transport.THttpClient.THttpClient):
Expand All @@ -25,13 +26,12 @@ def __init__(
uri_or_host,
port=None,
path=None,
cafile=None,
cert_file=None,
key_file=None,
ssl_context=None,
ssl_options: Optional[SSLOptions] = None,
max_connections: int = 1,
retry_policy: Union[DatabricksRetryPolicy, int] = 0,
):
self._ssl_options = ssl_options

if port is not None:
warnings.warn(
"Please use the THttpClient('http{s}://host:port/path') constructor",
Expand All @@ -48,13 +48,11 @@ def __init__(
self.scheme = parsed.scheme
assert self.scheme in ("http", "https")
if self.scheme == "https":
self.certfile = cert_file
self.keyfile = key_file
self.context = (
ssl.create_default_context(cafile=cafile)
if (cafile and not ssl_context)
else ssl_context
)
if self._ssl_options is not None:
# TODO: Not sure if those options are used anywhere - need to double-check
self.certfile = self._ssl_options.tls_client_cert_file
self.keyfile = self._ssl_options.tls_client_cert_key_file
self.context = self._ssl_options.create_ssl_context()
self.port = parsed.port
self.host = parsed.hostname
self.path = parsed.path
Expand Down Expand Up @@ -109,12 +107,23 @@ def startRetryTimer(self):
def open(self):

# self.__pool replaces the self.__http used by the original THttpClient
_pool_kwargs = {"maxsize": self.max_connections}

if self.scheme == "http":
pool_class = HTTPConnectionPool
elif self.scheme == "https":
pool_class = HTTPSConnectionPool

_pool_kwargs = {"maxsize": self.max_connections}
_pool_kwargs.update(
{
"cert_reqs": ssl.CERT_REQUIRED
if self._ssl_options.tls_verify
else ssl.CERT_NONE,
"ca_certs": self._ssl_options.tls_trusted_ca_file,
"cert_file": self._ssl_options.tls_client_cert_file,
"key_file": self._ssl_options.tls_client_cert_key_file,
"key_password": self._ssl_options.tls_client_cert_key_password,
}
)

if self.using_proxy():
proxy_manager = ProxyManager(
Expand Down
18 changes: 16 additions & 2 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)


from databricks.sql.types import Row
from databricks.sql.types import Row, SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence

Expand Down Expand Up @@ -178,8 +178,9 @@ def read(self) -> Optional[OAuthToken]:
# _tls_trusted_ca_file
# Set to the path of the file containing trusted CA certificates for server certificate
# verification. If not provide, uses system truststore.
# _tls_client_cert_file, _tls_client_cert_key_file
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
# Set client SSL certificate.
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
# _retry_stop_after_attempts_count
# The maximum number of attempts during a request retry sequence (defaults to 24)
# _socket_timeout
Expand Down Expand Up @@ -220,12 +221,25 @@ def read(self) -> Optional[OAuthToken]:

base_headers = [("User-Agent", useragent_header)]

self._ssl_options = SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=not kwargs.get(
"_tls_no_verify", False
), # by default - verify cert and host
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.thrift_backend = ThriftBackend(
self.host,
self.port,
http_path,
(http_headers or []) + base_headers,
auth_provider,
ssl_options=self._ssl_options,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
**kwargs,
)
Expand Down
9 changes: 5 additions & 4 deletions src/databricks/sql/cloudfetch/download_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging

from ssl import SSLContext
from concurrent.futures import ThreadPoolExecutor, Future
from typing import List, Union

Expand All @@ -9,6 +8,8 @@
DownloadableResultSettings,
DownloadedFile,
)
from databricks.sql.types import SSLOptions

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)
Expand All @@ -20,7 +21,7 @@ def __init__(
links: List[TSparkArrowResultLink],
max_download_threads: int,
lz4_compressed: bool,
ssl_context: SSLContext,
ssl_options: SSLOptions,
):
self._pending_links: List[TSparkArrowResultLink] = []
for link in links:
Expand All @@ -38,7 +39,7 @@ def __init__(
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)

self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self._ssl_context = ssl_context
self._ssl_options = ssl_options

def get_next_downloaded_file(
self, next_row_offset: int
Expand Down Expand Up @@ -95,7 +96,7 @@ def _schedule_downloads(self):
handler = ResultSetDownloadHandler(
settings=self._downloadable_result_settings,
link=link,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)
task = self._thread_pool.submit(handler.run)
self._download_tasks.append(task)
Expand Down
12 changes: 5 additions & 7 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

import requests
from requests.adapters import HTTPAdapter, Retry
from ssl import SSLContext, CERT_NONE
import lz4.frame
import time

from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

from databricks.sql.exc import Error
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,11 +65,11 @@ def __init__(
self,
settings: DownloadableResultSettings,
link: TSparkArrowResultLink,
ssl_context: SSLContext,
ssl_options: SSLOptions,
):
self.settings = settings
self.link = link
self._ssl_context = ssl_context
self._ssl_options = ssl_options

def run(self) -> DownloadedFile:
"""
Expand All @@ -95,14 +94,13 @@ def run(self) -> DownloadedFile:
session.mount("http://", HTTPAdapter(max_retries=retryPolicy))
session.mount("https://", HTTPAdapter(max_retries=retryPolicy))

ssl_verify = self._ssl_context.verify_mode != CERT_NONE

try:
# Get the file via HTTP request
response = session.get(
self.link.fileLink,
timeout=self.settings.download_timeout,
verify=ssl_verify,
verify=self._ssl_options.tls_verify,
# TODO: Pass cert from `self._ssl_options`
)
response.raise_for_status()

Expand Down
43 changes: 6 additions & 37 deletions src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
import uuid
import threading
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
from typing import List, Union

import pyarrow
Expand Down Expand Up @@ -36,6 +35,7 @@
convert_decimals_in_arrow_table,
convert_column_based_set_to_arrow_table,
)
from databricks.sql.types import SSLOptions

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(
http_path: str,
http_headers,
auth_provider: AuthProvider,
ssl_options: SSLOptions,
staging_allowed_local_path: Union[None, str, List[str]] = None,
**kwargs,
):
Expand All @@ -93,16 +94,6 @@ def __init__(
# Tag to add to User-Agent header. For use by partners.
# _username, _password
# Username and password Basic authentication (no official support)
# _tls_no_verify
# Set to True (Boolean) to completely disable SSL verification.
# _tls_verify_hostname
# Set to False (Boolean) to disable SSL hostname verification, but check certificate.
# _tls_trusted_ca_file
# Set to the path of the file containing trusted CA certificates for server certificate
# verification. If not provide, uses system truststore.
# _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password
# Set client SSL certificate.
# See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain
# _connection_uri
# Overrides server_hostname and http_path.
# RETRY/ATTEMPT POLICY
Expand Down Expand Up @@ -162,29 +153,7 @@ def __init__(
# Cloud fetch
self.max_download_threads = kwargs.get("max_download_threads", 10)

# Configure tls context
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
if kwargs.get("_tls_no_verify") is True:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_NONE
elif kwargs.get("_tls_verify_hostname") is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_REQUIRED
else:
ssl_context.check_hostname = True
ssl_context.verify_mode = CERT_REQUIRED

tls_client_cert_file = kwargs.get("_tls_client_cert_file")
tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file")
tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password")
if tls_client_cert_file:
ssl_context.load_cert_chain(
certfile=tls_client_cert_file,
keyfile=tls_client_cert_key_file,
password=tls_client_cert_key_password,
)

self._ssl_context = ssl_context
self._ssl_options = ssl_options

self._auth_provider = auth_provider

Expand Down Expand Up @@ -225,7 +194,7 @@ def __init__(
self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
auth_provider=self._auth_provider,
uri_or_host=uri,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
**additional_transport_args, # type: ignore
)

Expand Down Expand Up @@ -776,7 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)
else:
arrow_queue_opt = None
Expand Down Expand Up @@ -1008,7 +977,7 @@ def fetch_results(
max_download_threads=self.max_download_threads,
lz4_compressed=lz4_compressed,
description=description,
ssl_context=self._ssl_context,
ssl_options=self._ssl_options,
)

return queue, resp.hasMoreRows
Expand Down
48 changes: 48 additions & 0 deletions src/databricks/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,54 @@
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
import datetime
import decimal
from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context


class SSLOptions:
tls_verify: bool
tls_verify_hostname: bool
tls_trusted_ca_file: Optional[str]
tls_client_cert_file: Optional[str]
tls_client_cert_key_file: Optional[str]
tls_client_cert_key_password: Optional[str]

def __init__(
self,
tls_verify: bool = True,
tls_verify_hostname: bool = True,
tls_trusted_ca_file: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
tls_client_cert_key_file: Optional[str] = None,
tls_client_cert_key_password: Optional[str] = None,
):
self.tls_verify = tls_verify
self.tls_verify_hostname = tls_verify_hostname
self.tls_trusted_ca_file = tls_trusted_ca_file
self.tls_client_cert_file = tls_client_cert_file
self.tls_client_cert_key_file = tls_client_cert_key_file
self.tls_client_cert_key_password = tls_client_cert_key_password

def create_ssl_context(self) -> SSLContext:
ssl_context = create_default_context(cafile=self.tls_trusted_ca_file)

if self.tls_verify is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_NONE
elif self.tls_verify_hostname is False:
ssl_context.check_hostname = False
ssl_context.verify_mode = CERT_REQUIRED
else:
ssl_context.check_hostname = True
ssl_context.verify_mode = CERT_REQUIRED

if self.tls_client_cert_file:
ssl_context.load_cert_chain(
certfile=self.tls_client_cert_file,
keyfile=self.tls_client_cert_key_file,
password=self.tls_client_cert_key_password,
)

return ssl_context


class Row(tuple):
Expand Down
Loading

0 comments on commit 1f8cf73

Please sign in to comment.