Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Rewind seekable streams before retrying #821

Merged
merged 12 commits into from
Nov 15, 2024
54 changes: 42 additions & 12 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import logging
import urllib.parse
from datetime import timedelta
Expand Down Expand Up @@ -130,6 +131,14 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
flattened = dict(flatten_dict(with_fixed_bools))
return flattened

@staticmethod
def _is_seekable_stream(data) -> bool:
    """Return whether ``data`` is an open ``io.IOBase`` stream that can be rewound.

    :param data: the request payload; may be ``None``, a non-stream value
        (``str``, ``bytes``, ...) or a stream object.
    :return: ``False`` when ``data`` is ``None``, is not an ``io.IOBase``
        instance, or is a closed stream; otherwise whatever
        ``data.seekable()`` reports.
    """
    # isinstance() is False for None, so this also covers the "no data" case.
    if not isinstance(data, io.IOBase):
        return False
    # Standard streams raise ValueError from seekable() once closed. A closed
    # stream cannot be rewound, so answer False instead of raising.
    if data.closed:
        return False
    return data.seekable()

def do(self,
method: str,
url: str,
Expand All @@ -144,18 +153,27 @@ def do(self,
if headers is None:
headers = {}
headers['User-Agent'] = self._user_agent_base
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)
response = retryable(self._perform)(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

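# Only retry when there is no request data, or when the data is a seekable
# stream that we can rewind. Any other non-None data (a non-seekable stream,
# but also non-stream values such as bytes or str) is sent without retries.
# This is necessary to avoid bugs where the retry doesn't re-read already
# read data from the body.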
if data is not None and not self._is_seekable_stream(data):
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform
else:
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
renaudhartert-db marked this conversation as resolved.
Show resolved Hide resolved
is_retryable=self._is_retryable,
clock=self._clock)(self._perform)

response = call(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

resp = dict()
for header in response_headers if response_headers else []:
Expand Down Expand Up @@ -226,6 +244,12 @@ def _perform(self,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
# Keep track of the initial position of the stream so that we can rewind it if
# we need to retry the request.
renaudhartert-db marked this conversation as resolved.
Show resolved Hide resolved
initial_data_position = 0
if self._is_seekable_stream(data):
initial_data_position = data.tell()

response = self._session.request(method,
url,
params=self._fix_query_string(query),
Expand All @@ -237,9 +261,15 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)

error = self._error_parser.get_api_error(response)
if error is not None:
# If the request body is a seekable stream, rewind it so that it is ready
# to be read again in case of a retry.
if self._is_seekable_stream(data):
data.seek(initial_data_position)
renaudhartert-db marked this conversation as resolved.
Show resolved Hide resolved
raise error from None

return response

def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
Expand Down
138 changes: 138 additions & 0 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import random
from http.server import BaseHTTPRequestHandler
from typing import Iterator, List
Expand Down Expand Up @@ -314,3 +315,140 @@ def mock_iter_content(chunk_size):
assert received_data == test_data # all data was received correctly
assert len(content_chunks) == expected_chunks # correct number of chunks
assert all(len(c) <= chunk_size for c in content_chunks) # chunks don't exceed size


def test_is_seekable_stream():
    """Check `_is_seekable_stream` against non-streams, non-seekable and seekable streams."""
    base_client = _BaseClient()

    # Values that are not streams at all must be rejected.
    for not_a_stream in (None, "string data", b"binary data", ["list", "data"], 42):
        assert not base_client._is_seekable_stream(not_a_stream)

    # A stream that reports itself as non-seekable must be rejected.
    unseekable = io.BytesIO(b"test data")
    unseekable.seekable = lambda: False
    assert not base_client._is_seekable_stream(unseekable)

    # In-memory binary and text streams are seekable.
    for in_memory in (io.BytesIO(b"test data"), io.StringIO("test data")):
        assert base_client._is_seekable_stream(in_memory)

    # A file object opened from disk is seekable.
    with open(__file__, 'rb') as source_file:
        assert base_client._is_seekable_stream(source_file)

    # A custom io.IOBase subclass whose seekable() returns True is accepted.
    class CustomSeekableStream(io.IOBase):

        def seekable(self):
            return True

        def seek(self, offset, whence=0):
            return 0

        def tell(self):
            return 0

    assert base_client._is_seekable_stream(CustomSeekableStream())


def test_no_retry_on_non_seekable_stream():
    """A request whose body is a non-seekable stream must be sent exactly once."""
    bodies_seen = []

    def always_throttle(h: BaseHTTPRequestHandler):
        # Record the request body (if any), then always answer with a
        # response that triggers a retry.
        body_size = int(h.headers.get('Content-Length', 0))
        if body_size > 0:
            bodies_seen.append(h.rfile.read(body_size))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    payload = io.BytesIO(b"test data")
    payload.seekable = lambda: False  # make the stream appear non-seekable

    with http_fixture_server(always_throttle) as host:
        client = _BaseClient()

        # The error must surface right away instead of being retried.
        with pytest.raises(DatabricksError):
            client.do('POST', f'{host}/foo', data=payload)

        # Exactly one request was made (no retries), carrying the full body.
        assert bodies_seen == [b"test data"]


def test_perform_resets_seekable_stream_on_retry():
    """After a failed `_perform`, a seekable body is rewound to its pre-request offset."""
    bodies_seen = []

    def always_throttle(h: BaseHTTPRequestHandler):
        # Record the request body (if any), then always answer with a
        # response that triggers a retry.
        body_size = int(h.headers.get('Content-Length', 0))
        if body_size > 0:
            bodies_seen.append(h.rfile.read(body_size))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    payload = io.BytesIO(b"0123456789")  # seekable stream

    # Consume a prefix first, so that a rewind to the pre-request offset can
    # be told apart from a rewind to the very beginning of the stream.
    payload.read(4)
    assert payload.tell() == 4

    with http_fixture_server(always_throttle) as host:
        client = _BaseClient()

        # Every call must fail, and every failure must rewind the stream.
        for _ in range(3):
            with pytest.raises(DatabricksError):
                client._perform('POST', f'{host}/foo', data=payload)

        assert bodies_seen == [b"456789"] * 3

        # The stream is back at the offset it had before the requests.
        assert payload.tell() == 4


def test_perform_does_not_reset_nonseekable_stream_on_retry():
    """A failed `_perform` must leave a non-seekable body stream where the request left it."""
    received_data = []

    # Always respond with a response that triggers a retry.
    def inner(h: BaseHTTPRequestHandler):
        content_length = int(h.headers.get('Content-Length', 0))
        if content_length > 0:
            received_data.append(h.rfile.read(content_length))

        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    stream = io.BytesIO(b"0123456789")
    stream.seekable = lambda: False  # makes the stream appear non-seekable

    # Read some data from the stream first so that the pre-request offset (4)
    # differs from both the start of the stream (0) and EOF (10). This lets
    # the final position show unambiguously that no rewind took place.
    stream.read(4)
    assert stream.tell() == 4

    with http_fixture_server(inner) as host:
        client = _BaseClient()

        # Should fail without resetting the stream.
        with pytest.raises(DatabricksError):
            client._perform('POST', f'{host}/foo', data=stream)

        assert received_data == [b"456789"]

        # Verify stream was NOT reset to initial position.
        assert stream.tell() == 10  # EOF
Loading