diff --git a/databricks/sdk/clock.py b/databricks/sdk/clock.py new file mode 100644 index 000000000..6eef7ee49 --- /dev/null +++ b/databricks/sdk/clock.py @@ -0,0 +1,49 @@ +import abc +import time + + +class Clock(metaclass=abc.ABCMeta): + + @abc.abstractmethod + def time(self) -> float: + """ + Return the current time in seconds since the Epoch. + Fractions of a second may be present if the system clock provides them. + + :return: The current time in seconds since the Epoch. + """ + + @abc.abstractmethod + def sleep(self, seconds: float) -> None: + """ + Delay execution for a given number of seconds. The argument may be + a floating point number for subsecond precision. + + :param seconds: The duration to sleep in seconds. + :return: + """ + + +class RealClock(Clock): + """ + A real clock that uses the ``time`` module to get the current time and sleep. + """ + + def time(self) -> float: + """ + Return the current time in seconds since the Epoch. + Fractions of a second may be present if the system clock provides them. + + :return: The current time in seconds since the Epoch. + """ + return time.time() + + def sleep(self, seconds: float) -> None: + """ + Delay execution for a given number of seconds. The argument may be + a floating point number for subsecond precision. + + :param seconds: The duration to sleep in seconds. + :return: + """ + time.sleep(seconds) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 0dadfc927..6f653e16a 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -11,6 +11,7 @@ import requests from .azure import AzureEnvironment +from .clock import Clock, RealClock from .credentials_provider import CredentialsProvider, DefaultCredentials from .environments import (ALL_ENVS, DEFAULT_ENVIRONMENT, Cloud, DatabricksEnvironment) @@ -84,6 +85,7 @@ def __init__(self, credentials_provider: CredentialsProvider = None, product="unknown", product_version="0.0.0", + clock: Clock = None, **kwargs): self._inner = {} self._user_agent_other_info = [] @@ -91,6 +93,7 @@ def __init__(self, if 'databricks_environment' in kwargs: self.databricks_environment = kwargs['databricks_environment'] del kwargs['databricks_environment'] + self._clock = clock if clock is not None else RealClock() try: self._set_inner_config(kwargs) self._load_from_env() @@ -317,6 +320,10 @@ def sql_http_path(self) -> Optional[str]: if self.warehouse_id: return f'/sql/1.0/warehouses/{self.warehouse_id}' + @property + def clock(self) -> Clock: + return self._clock + @classmethod def attributes(cls) -> Iterable[ConfigAttribute]: """ Returns a list of Databricks SDK configuration metadata """ diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 2b7442708..9b56d0f6d 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -123,7 +123,8 @@ def do(self, headers = {} headers['User-Agent'] = self._user_agent_base retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), - is_retryable=self._is_retryable) + is_retryable=self._is_retryable, + clock=self._cfg.clock) return retryable(self._perform)(method, path, query=query, diff --git a/databricks/sdk/retries.py b/databricks/sdk/retries.py index a91467c4a..b98c54281 100644 --- a/databricks/sdk/retries.py +++ b/databricks/sdk/retries.py @@ -1,30 +1,34 @@ import functools import logging -import time from datetime import timedelta from random import random from typing import Callable, Optional, Sequence, Type +from .clock import Clock, RealClock + logger = logging.getLogger(__name__) def retried(*, on: Sequence[Type[BaseException]] = None, is_retryable: Callable[[BaseException], Optional[str]] = None, - timeout=timedelta(minutes=20)): + timeout=timedelta(minutes=20), + clock: Clock = None): has_allowlist = on is not None has_callback = is_retryable is not None if not (has_allowlist or has_callback) or (has_allowlist and has_callback): raise SyntaxError('either on=[Exception] or callback=lambda x: .. is required') + if clock is None: + clock = RealClock() def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - deadline = time.time() + timeout.total_seconds() + deadline = clock.time() + timeout.total_seconds() attempt = 1 last_err = None - while time.time() < deadline: + while clock.time() < deadline: try: return func(*args, **kwargs) except Exception as err: @@ -50,7 +54,7 @@ def wrapper(*args, **kwargs): raise err logger.debug(f'Retrying: {retry_reason} (sleeping ~{sleep}s)') - time.sleep(sleep + random()) + clock.sleep(sleep + random()) attempt += 1 raise TimeoutError(f'Timed out after {timeout}') from last_err diff --git a/tests/clock.py b/tests/clock.py new file mode 100644 index 000000000..bb0595d10 --- /dev/null +++ b/tests/clock.py @@ -0,0 +1,16 @@ +from databricks.sdk.clock import Clock + + +class FakeClock(Clock): + """ + A simple clock that can be used to mock time in tests. + """ + + def __init__(self, start_time: float = 0.0): + self._start_time = start_time + + def time(self) -> float: + return self._start_time + + def sleep(self, seconds: float) -> None: + self._start_time += seconds diff --git a/tests/test_core.py b/tests/test_core.py index ca2eaac31..aed659b3b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -26,6 +26,7 @@ from databricks.sdk.service.iam import AccessControlRequest from databricks.sdk.version import __version__ +from .clock import FakeClock from .conftest import noop_credentials @@ -422,7 +423,7 @@ def inner(h: BaseHTTPRequestHandler): requests.append(h.requestline) with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_')) + api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) res = api_client.do('GET', '/foo') assert 'foo' in res @@ -445,7 +446,7 @@ def inner(h: BaseHTTPRequestHandler): requests.append(h.requestline) with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_')) + api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) res = api_client.do('GET', '/foo') assert 'foo' in res @@ -462,7 +463,7 @@ def inner(h: BaseHTTPRequestHandler): requests.append(h.requestline) with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', retry_timeout_seconds=1)) + api_client = ApiClient(Config(host=host, token='_', retry_timeout_seconds=1, clock=FakeClock())) with pytest.raises(TimeoutError): api_client.do('GET', '/foo') @@ -484,7 +485,7 @@ def inner(h: BaseHTTPRequestHandler): requests.append(h.requestline) with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_')) + api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) res = api_client.do('GET', '/foo') assert 'foo' in res @@ -502,7 +503,7 @@ def inner(h: BaseHTTPRequestHandler): requests.append(h.requestline) with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_')) + api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) with pytest.raises(DatabricksError): api_client.do('GET', '/foo') @@ -520,7 +521,7 @@ def inner(h: BaseHTTPRequestHandler): requests.append(h.requestline) with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_')) + api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) res = api_client.do('GET', '/foo') assert 'foo' in res diff --git a/tests/test_retries.py b/tests/test_retries.py index 734c9ccc8..65dfa4d70 100644 --- a/tests/test_retries.py +++ b/tests/test_retries.py @@ -4,6 +4,7 @@ from databricks.sdk.errors import NotFound, ResourceDoesNotExist from databricks.sdk.retries import retried +from tests.clock import FakeClock def test_match_retry_condition_on_no_qualifier(): @@ -17,7 +18,7 @@ def foo(): def test_match_retry_condition_on_conflict(): with pytest.raises(SyntaxError): - @retried(on=[IOError], is_retryable=lambda _: 'always') + @retried(on=[IOError], is_retryable=lambda _: 'always', clock=FakeClock()) def foo(): return 1 @@ -25,7 +26,7 @@ def foo(): def test_match_retry_always(): with pytest.raises(TimeoutError): - @retried(is_retryable=lambda _: 'always', timeout=timedelta(seconds=1)) + @retried(is_retryable=lambda _: 'always', timeout=timedelta(seconds=1), clock=FakeClock()) def foo(): raise StopIteration() @@ -35,7 +36,7 @@ def foo(): def test_match_on_errors(): with pytest.raises(TimeoutError): - @retried(on=[KeyError, AttributeError], timeout=timedelta(seconds=0.5)) + @retried(on=[KeyError, AttributeError], timeout=timedelta(seconds=0.5), clock=FakeClock()) def foo(): raise KeyError(1) @@ -45,7 +46,7 @@ def foo(): def test_match_on_subclass(): with pytest.raises(TimeoutError): - @retried(on=[NotFound], timeout=timedelta(seconds=0.5)) + @retried(on=[NotFound], timeout=timedelta(seconds=0.5), clock=FakeClock()) def foo(): raise ResourceDoesNotExist(...) @@ -55,7 +56,7 @@ def foo(): def test_propagates_outside_exception(): with pytest.raises(KeyError): - @retried(on=[AttributeError], timeout=timedelta(seconds=0.5)) + @retried(on=[AttributeError], timeout=timedelta(seconds=0.5), clock=FakeClock()) def foo(): raise KeyError(1)