Skip to content

Commit

Permalink
Use fake clock for faster unit tests (#533)
Browse files Browse the repository at this point in the history
## Changes
Some unit tests of retry logic are slow because they sleep for a second
or two. This PR changes Config to accept a Clock, which ApiClient passes
to the retry logic to supply the time for retries and waiting. In unit
tests, we use a fake clock that doesn't tick until sleep() is called, and
then ticks by the amount provided.

## Tests
Existing unit tests.

- [ ] `make test` run locally
- [ ] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
mgyucht authored Feb 8, 2024
1 parent c20bbff commit 306ed7e
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 17 deletions.
49 changes: 49 additions & 0 deletions databricks/sdk/clock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import abc
import time


class Clock(abc.ABC):
    """
    Abstract source of time: reports the current time and pauses execution.
    """

    @abc.abstractmethod
    def time(self) -> float:
        """
        Report the current time as seconds elapsed since the Epoch.

        The value may carry a fractional part when the underlying clock
        offers sub-second resolution.

        :return: Seconds since the Epoch.
        """

    @abc.abstractmethod
    def sleep(self, seconds: float) -> None:
        """
        Pause execution for the requested duration.

        :param seconds: How long to sleep, in seconds; fractional values
            request sub-second delays.
        :return: None.
        """


class RealClock(Clock):
    """
    Clock implementation backed by the standard ``time`` module.
    """

    def time(self) -> float:
        """
        Read the current time from ``time.time()``.

        :return: Seconds since the Epoch, possibly with a fractional part
            if the system clock provides one.
        """
        return time.time()

    def sleep(self, seconds: float) -> None:
        """
        Block the caller by delegating to ``time.sleep()``.

        :param seconds: The duration to sleep in seconds; may be fractional
            for subsecond precision.
        :return: None.
        """
        time.sleep(seconds)
7 changes: 7 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import requests

from .azure import AzureEnvironment
from .clock import Clock, RealClock
from .credentials_provider import CredentialsProvider, DefaultCredentials
from .environments import (ALL_ENVS, DEFAULT_ENVIRONMENT, Cloud,
DatabricksEnvironment)
Expand Down Expand Up @@ -84,13 +85,15 @@ def __init__(self,
credentials_provider: CredentialsProvider = None,
product="unknown",
product_version="0.0.0",
clock: Clock = None,
**kwargs):
self._inner = {}
self._user_agent_other_info = []
self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials()
if 'databricks_environment' in kwargs:
self.databricks_environment = kwargs['databricks_environment']
del kwargs['databricks_environment']
self._clock = clock if clock is not None else RealClock()
try:
self._set_inner_config(kwargs)
self._load_from_env()
Expand Down Expand Up @@ -317,6 +320,10 @@ def sql_http_path(self) -> Optional[str]:
if self.warehouse_id:
return f'/sql/1.0/warehouses/{self.warehouse_id}'

@property
def clock(self) -> Clock:
return self._clock

@classmethod
def attributes(cls) -> Iterable[ConfigAttribute]:
""" Returns a list of Databricks SDK configuration metadata """
Expand Down
3 changes: 2 additions & 1 deletion databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def do(self,
headers = {}
headers['User-Agent'] = self._user_agent_base
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable)
is_retryable=self._is_retryable,
clock=self._cfg.clock)
return retryable(self._perform)(method,
path,
query=query,
Expand Down
14 changes: 9 additions & 5 deletions databricks/sdk/retries.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
import functools
import logging
import time
from datetime import timedelta
from random import random
from typing import Callable, Optional, Sequence, Type

from .clock import Clock, RealClock

logger = logging.getLogger(__name__)


def retried(*,
on: Sequence[Type[BaseException]] = None,
is_retryable: Callable[[BaseException], Optional[str]] = None,
timeout=timedelta(minutes=20)):
timeout=timedelta(minutes=20),
clock: Clock = None):
has_allowlist = on is not None
has_callback = is_retryable is not None
if not (has_allowlist or has_callback) or (has_allowlist and has_callback):
raise SyntaxError('either on=[Exception] or callback=lambda x: .. is required')
if clock is None:
clock = RealClock()

def decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
deadline = time.time() + timeout.total_seconds()
deadline = clock.time() + timeout.total_seconds()
attempt = 1
last_err = None
while time.time() < deadline:
while clock.time() < deadline:
try:
return func(*args, **kwargs)
except Exception as err:
Expand All @@ -50,7 +54,7 @@ def wrapper(*args, **kwargs):
raise err

logger.debug(f'Retrying: {retry_reason} (sleeping ~{sleep}s)')
time.sleep(sleep + random())
clock.sleep(sleep + random())
attempt += 1
raise TimeoutError(f'Timed out after {timeout}') from last_err

Expand Down
16 changes: 16 additions & 0 deletions tests/clock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from databricks.sdk.clock import Clock


class FakeClock(Clock):
    """
    A simple clock that can be used to mock time in tests.

    The reported time never advances on its own; it only moves forward by
    the amount passed to ``sleep()``, which returns immediately.
    """

    def __init__(self, start_time: float = 0.0):
        """
        :param start_time: The time initially reported by ``time()``.
        """
        # Despite the name, this holds the *current* fake time: sleep() adds to it.
        self._start_time = start_time

    def time(self) -> float:
        """
        :return: The current fake time (the start time plus all sleeps so far).
        """
        return self._start_time

    def sleep(self, seconds: float) -> None:
        """
        Advance the fake time by ``seconds`` without blocking.

        :param seconds: The duration to "sleep" in seconds; must be non-negative.
        :raises ValueError: If ``seconds`` is negative, mirroring ``time.sleep``,
            so a bad duration cannot silently move the fake time backwards.
        """
        if seconds < 0:
            raise ValueError('sleep length must be non-negative')
        self._start_time += seconds
13 changes: 7 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from databricks.sdk.service.iam import AccessControlRequest
from databricks.sdk.version import __version__

from .clock import FakeClock
from .conftest import noop_credentials


Expand Down Expand Up @@ -422,7 +423,7 @@ def inner(h: BaseHTTPRequestHandler):
requests.append(h.requestline)

with http_fixture_server(inner) as host:
api_client = ApiClient(Config(host=host, token='_'))
api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
res = api_client.do('GET', '/foo')
assert 'foo' in res

Expand All @@ -445,7 +446,7 @@ def inner(h: BaseHTTPRequestHandler):
requests.append(h.requestline)

with http_fixture_server(inner) as host:
api_client = ApiClient(Config(host=host, token='_'))
api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
res = api_client.do('GET', '/foo')
assert 'foo' in res

Expand All @@ -462,7 +463,7 @@ def inner(h: BaseHTTPRequestHandler):
requests.append(h.requestline)

with http_fixture_server(inner) as host:
api_client = ApiClient(Config(host=host, token='_', retry_timeout_seconds=1))
api_client = ApiClient(Config(host=host, token='_', retry_timeout_seconds=1, clock=FakeClock()))
with pytest.raises(TimeoutError):
api_client.do('GET', '/foo')

Expand All @@ -484,7 +485,7 @@ def inner(h: BaseHTTPRequestHandler):
requests.append(h.requestline)

with http_fixture_server(inner) as host:
api_client = ApiClient(Config(host=host, token='_'))
api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
res = api_client.do('GET', '/foo')
assert 'foo' in res

Expand All @@ -502,7 +503,7 @@ def inner(h: BaseHTTPRequestHandler):
requests.append(h.requestline)

with http_fixture_server(inner) as host:
api_client = ApiClient(Config(host=host, token='_'))
api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
with pytest.raises(DatabricksError):
api_client.do('GET', '/foo')

Expand All @@ -520,7 +521,7 @@ def inner(h: BaseHTTPRequestHandler):
requests.append(h.requestline)

with http_fixture_server(inner) as host:
api_client = ApiClient(Config(host=host, token='_'))
api_client = ApiClient(Config(host=host, token='_', clock=FakeClock()))
res = api_client.do('GET', '/foo')
assert 'foo' in res

Expand Down
11 changes: 6 additions & 5 deletions tests/test_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from databricks.sdk.errors import NotFound, ResourceDoesNotExist
from databricks.sdk.retries import retried
from tests.clock import FakeClock


def test_match_retry_condition_on_no_qualifier():
Expand All @@ -17,15 +18,15 @@ def foo():
def test_match_retry_condition_on_conflict():
with pytest.raises(SyntaxError):

@retried(on=[IOError], is_retryable=lambda _: 'always')
@retried(on=[IOError], is_retryable=lambda _: 'always', clock=FakeClock())
def foo():
return 1


def test_match_retry_always():
with pytest.raises(TimeoutError):

@retried(is_retryable=lambda _: 'always', timeout=timedelta(seconds=1))
@retried(is_retryable=lambda _: 'always', timeout=timedelta(seconds=1), clock=FakeClock())
def foo():
raise StopIteration()

Expand All @@ -35,7 +36,7 @@ def foo():
def test_match_on_errors():
with pytest.raises(TimeoutError):

@retried(on=[KeyError, AttributeError], timeout=timedelta(seconds=0.5))
@retried(on=[KeyError, AttributeError], timeout=timedelta(seconds=0.5), clock=FakeClock())
def foo():
raise KeyError(1)

Expand All @@ -45,7 +46,7 @@ def foo():
def test_match_on_subclass():
with pytest.raises(TimeoutError):

@retried(on=[NotFound], timeout=timedelta(seconds=0.5))
@retried(on=[NotFound], timeout=timedelta(seconds=0.5), clock=FakeClock())
def foo():
raise ResourceDoesNotExist(...)

Expand All @@ -55,7 +56,7 @@ def foo():
def test_propagates_outside_exception():
with pytest.raises(KeyError):

@retried(on=[AttributeError], timeout=timedelta(seconds=0.5))
@retried(on=[AttributeError], timeout=timedelta(seconds=0.5), clock=FakeClock())
def foo():
raise KeyError(1)

Expand Down

0 comments on commit 306ed7e

Please sign in to comment.