diff --git a/google/cloud/bigtable_v2/services/bigtable/async_client.py b/google/cloud/bigtable_v2/services/bigtable/async_client.py index 12432dda7..1ed7a4740 100644 --- a/google/cloud/bigtable_v2/services/bigtable/async_client.py +++ b/google/cloud/bigtable_v2/services/bigtable/async_client.py @@ -340,11 +340,13 @@ def read_rows( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -441,11 +443,13 @@ def sample_row_keys( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -563,11 +567,13 @@ async def mutate_row( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -679,11 +685,13 @@ def mutate_rows( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -838,11 +846,13 @@ async def check_and_mutate_row( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -936,9 +946,11 @@ async def ping_and_warm( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -1062,11 +1074,13 @@ async def read_modify_write_row( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -1172,11 +1186,13 @@ def generate_initial_change_stream_partitions( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -1274,11 +1290,13 @@ def read_change_stream( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() @@ -1377,11 +1395,13 @@ def execute_query( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("instance_name", request.instance_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("instance_name", request.instance_name),) + ), + ) # Validate the universe domain. self._client._validate_universe_domain() diff --git a/google/cloud/bigtable_v2/services/bigtable/client.py b/google/cloud/bigtable_v2/services/bigtable/client.py index 0937c90fe..4a3f19ce6 100644 --- a/google/cloud/bigtable_v2/services/bigtable/client.py +++ b/google/cloud/bigtable_v2/services/bigtable/client.py @@ -817,9 +817,9 @@ def read_rows( ) if header_params: - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(header_params), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (gapic_v1.routing_header.to_grpc_metadata(header_params),) # Validate the universe domain. self._validate_universe_domain() @@ -933,9 +933,9 @@ def sample_row_keys( ) if header_params: - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(header_params), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (gapic_v1.routing_header.to_grpc_metadata(header_params),) # Validate the universe domain. self._validate_universe_domain() @@ -1070,9 +1070,9 @@ def mutate_row( ) if header_params: - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(header_params), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (gapic_v1.routing_header.to_grpc_metadata(header_params),) # Validate the universe domain. self._validate_universe_domain() @@ -1201,9 +1201,9 @@ def mutate_rows( ) if header_params: - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(header_params), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (gapic_v1.routing_header.to_grpc_metadata(header_params),) # Validate the universe domain. self._validate_universe_domain() @@ -1375,9 +1375,9 @@ def check_and_mutate_row( ) if header_params: - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(header_params), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (gapic_v1.routing_header.to_grpc_metadata(header_params),) # Validate the universe domain. self._validate_universe_domain() @@ -1477,9 +1477,9 @@ def ping_and_warm( header_params["app_profile_id"] = request.app_profile_id if header_params: - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(header_params), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (gapic_v1.routing_header.to_grpc_metadata(header_params),) # Validate the universe domain. self._validate_universe_domain() @@ -1620,9 +1620,9 @@ def read_modify_write_row( ) if header_params: - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(header_params), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (gapic_v1.routing_header.to_grpc_metadata(header_params),) # Validate the universe domain. self._validate_universe_domain() @@ -1725,11 +1725,13 @@ def generate_initial_change_stream_partitions( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._validate_universe_domain() @@ -1824,11 +1826,13 @@ def read_change_stream( # Certain fields should be provided within the metadata header; # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("table_name", request.table_name),) - ), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += ( + gapic_v1.routing_header.to_grpc_metadata( + (("table_name", request.table_name),) + ), + ) # Validate the universe domain. self._validate_universe_domain() @@ -1933,9 +1937,9 @@ def execute_query( header_params["app_profile_id"] = request.app_profile_id if header_params: - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(header_params), - ) + metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (gapic_v1.routing_header.to_grpc_metadata(header_params),) # Validate the universe domain. self._validate_universe_domain() diff --git a/owlbot.py b/owlbot.py index 84aa3d61b..090f7ee93 100644 --- a/owlbot.py +++ b/owlbot.py @@ -143,6 +143,17 @@ def insert(file, before_line, insert_line, after_line, escape=None): escape='"' ) +# ---------------------------------------------------------------------------- +# Patch duplicate routing header: https://github.com/googleapis/gapic-generator-python/issues/2078 +# ---------------------------------------------------------------------------- +for file in ["client.py", "async_client.py"]: + s.replace( + f"google/cloud/bigtable_v2/services/bigtable/{file}", + "metadata \= tuple\(metadata\) \+ \(", + """metadata = tuple(metadata) + if all(m[0] != gapic_v1.routing_header.ROUTING_METADATA_KEY for m in metadata): + metadata += (""" + ) # ---------------------------------------------------------------------------- # Samples templates diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 9ebc403ce..6c49ca0da 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -1277,7 +1277,7 @@ async def test_customizable_retryable_errors( ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), ("row_exists", (b"row_key",), "read_rows"), ("sample_row_keys", (), "sample_row_keys"), - ("mutate_row", (b"row_key", [mock.Mock()]), "mutate_row"), + ("mutate_row", (b"row_key", [mutations.DeleteAllFromRow()]), "mutate_row"), ( "bulk_mutate_rows", ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), @@ -1286,7 +1286,7 @@ async def test_customizable_retryable_errors( ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), ( "read_modify_write_row", - (b"row_key", mock.Mock()), + (b"row_key", IncrementRule("f", "q")), "read_modify_write_row", ), ], @@ -1298,31 +1298,34 @@ async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_ from google.cloud.bigtable.data import TableAsync profile = "profile" if include_app_profile else None - with mock.patch( - f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock() - ) as gapic_mock: - gapic_mock.side_effect = RuntimeError("stop early") - async with _make_client() as client: - table = TableAsync(client, "instance-id", "table-id", profile) - try: - test_fn = table.__getattribute__(fn_name) - maybe_stream = await test_fn(*fn_args) - [i async for i in maybe_stream] - except Exception: - # we expect an exception from attempting to call the mock - pass - kwargs = gapic_mock.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata + client = _make_client() + # create mock for rpc stub + transport_mock = mock.MagicMock() + rpc_mock = mock.AsyncMock() + transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock + client._gapic_client._client._transport = transport_mock + client._gapic_client._client._is_universe_domain_valid = True + table = TableAsync(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = await test_fn(*fn_args) + [i async for i in maybe_stream] + except Exception: + # we expect an exception from attempting to call the mock + pass + assert rpc_mock.call_count == 1 + kwargs = rpc_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + # expect single metadata entry + assert len(metadata) == 1 + # expect x-goog-request-params tag + assert metadata[0][0] == "x-goog-request-params" + routing_str = metadata[0][1] + assert "table_name=" + table.table_name in routing_str + if include_app_profile: + assert "app_profile_id=profile" in routing_str + else: + assert "app_profile_id=" not in routing_str class TestReadRows: