Skip to content

Commit

Permalink
feat: set the X-Server-Timeout header when timeout is set (googleapis…
Browse files Browse the repository at this point in the history
…#927)

Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery/issues/new/choose) before writing your code!  That way we can discuss the change, evaluate designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)

Fixes googleapis#919 🦕
  • Loading branch information
jimfulton authored and abdelmegahedgoogle committed Apr 17, 2023
1 parent 55913a0 commit 1b78c41
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 9 deletions.
27 changes: 26 additions & 1 deletion google/cloud/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@
# https://github.com/googleapis/python-bigquery/issues/781#issuecomment-883497414
_PYARROW_BAD_VERSIONS = frozenset([packaging.version.Version("2.0.0")])

TIMEOUT_HEADER = "X-Server-Timeout"


class Project(object):
"""Wrapper for resource describing a BigQuery project.
Expand Down Expand Up @@ -742,16 +744,26 @@ def create_table(
return self.get_table(table.reference, retry=retry)

def _call_api(
self, retry, span_name=None, span_attributes=None, job_ref=None, **kwargs
self,
retry,
span_name=None,
span_attributes=None,
job_ref=None,
headers: Optional[Dict[str, str]] = None,
**kwargs,
):
kwargs = _add_server_timeout_header(headers, kwargs)
call = functools.partial(self._connection.api_request, **kwargs)

if retry:
call = retry(call)

if span_name is not None:
with create_span(
name=span_name, attributes=span_attributes, client=self, job_ref=job_ref
):
return call()

return call()

def get_dataset(
Expand Down Expand Up @@ -4045,3 +4057,16 @@ def _get_upload_headers(user_agent):
"User-Agent": user_agent,
"content-type": "application/json",
}


def _add_server_timeout_header(headers: Optional[Dict[str, str]], kwargs):
timeout = kwargs.get("timeout")
if timeout is not None:
if headers is None:
headers = {}
headers[TIMEOUT_HEADER] = str(timeout)

if headers:
kwargs["headers"] = headers

return kwargs
19 changes: 19 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
import pytest

from .helpers import make_client
Expand All @@ -35,3 +36,21 @@ def DS_ID():
@pytest.fixture
def LOCATION():
yield "us-central"


def noop_add_server_timeout_header(headers, kwargs):
if headers:
kwargs["headers"] = headers
return kwargs


@pytest.fixture(autouse=True)
def disable_add_server_timeout_header(request):
if "enable_add_server_timeout_header" in request.keywords:
yield
else:
with mock.patch(
"google.cloud.bigquery.client._add_server_timeout_header",
noop_add_server_timeout_header,
):
yield
27 changes: 19 additions & 8 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,7 +1806,6 @@ def test_update_dataset(self):
"access": ACCESS,
},
path="/" + PATH,
headers=None,
timeout=7.5,
)
self.assertEqual(ds2.description, ds.description)
Expand Down Expand Up @@ -1850,7 +1849,6 @@ def test_update_dataset_w_custom_property(self):
method="PATCH",
data={"newAlphaProperty": "unreleased property"},
path=path,
headers=None,
timeout=DEFAULT_TIMEOUT,
)

Expand Down Expand Up @@ -1909,7 +1907,7 @@ def test_update_model(self):
"labels": {"x": "y"},
}
conn.api_request.assert_called_once_with(
method="PATCH", data=sent, path="/" + path, headers=None, timeout=7.5
method="PATCH", data=sent, path="/" + path, timeout=7.5
)
self.assertEqual(updated_model.model_id, model.model_id)
self.assertEqual(updated_model.description, model.description)
Expand Down Expand Up @@ -1982,7 +1980,6 @@ def test_update_routine(self):
method="PUT",
data=sent,
path="/projects/routines-project/datasets/test_routines/routines/updated_routine",
headers=None,
timeout=7.5,
)
self.assertEqual(actual_routine.arguments, routine.arguments)
Expand Down Expand Up @@ -2090,7 +2087,7 @@ def test_update_table(self):
"labels": {"x": "y"},
}
conn.api_request.assert_called_once_with(
method="PATCH", data=sent, path="/" + path, headers=None, timeout=7.5
method="PATCH", data=sent, path="/" + path, timeout=7.5
)
self.assertEqual(updated_table.description, table.description)
self.assertEqual(updated_table.friendly_name, table.friendly_name)
Expand Down Expand Up @@ -2140,7 +2137,6 @@ def test_update_table_w_custom_property(self):
method="PATCH",
path="/%s" % path,
data={"newAlphaProperty": "unreleased property"},
headers=None,
timeout=DEFAULT_TIMEOUT,
)
self.assertEqual(
Expand Down Expand Up @@ -2175,7 +2171,6 @@ def test_update_table_only_use_legacy_sql(self):
method="PATCH",
path="/%s" % path,
data={"view": {"useLegacySql": True}},
headers=None,
timeout=DEFAULT_TIMEOUT,
)
self.assertEqual(updated_table.view_use_legacy_sql, table.view_use_legacy_sql)
Expand Down Expand Up @@ -2273,7 +2268,6 @@ def test_update_table_w_query(self):
"expirationTime": str(_millis(exp_time)),
"schema": schema_resource,
},
headers=None,
timeout=DEFAULT_TIMEOUT,
)

Expand Down Expand Up @@ -8173,3 +8167,20 @@ def transmit_next_chunk(transport):

chunk_size = RU.call_args_list[0][0][1]
assert chunk_size == 100 * (1 << 20)


@pytest.mark.enable_add_server_timeout_header
@pytest.mark.parametrize("headers", [None, {}])
def test__call_api_add_server_timeout_w_timeout(client, headers):
client._connection = make_connection({})
client._call_api(None, method="GET", path="/", headers=headers, timeout=42)
client._connection.api_request.assert_called_with(
method="GET", path="/", timeout=42, headers={"X-Server-Timeout": "42"}
)


@pytest.mark.enable_add_server_timeout_header
def test__call_api_no_add_server_timeout_wo_timeout(client):
client._connection = make_connection({})
client._call_api(None, method="GET", path="/")
client._connection.api_request.assert_called_with(method="GET", path="/")

0 comments on commit 1b78c41

Please sign in to comment.