diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 6f3997069e..e80ddc97ee 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -117,7 +117,10 @@ def trace_call(name, session=None, extra_attributes=None, observability_options= # invoke .record_exception on our own else we shall have 2 exceptions. raise else: - if (not span._status) or span._status.status_code == StatusCode.UNSET: + # All spans still have set_status available even if for example + # NonRecordingSpan doesn't have "_status". + absent_span_status = getattr(span, "_status", None) is None + if absent_span_status or span._status.status_code == StatusCode.UNSET: # OpenTelemetry-Python only allows a status change # if the current code is UNSET or ERROR. At the end # of the generator's consumption, only set it to OK diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index dc28644d6c..f9edbe96fa 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -583,7 +583,7 @@ def _get_streamed_result_set( iterator = _restart_on_unavailable( restart, request, - f"CloudSpanner.{type(self).__name__}.execute_streaming_sql", + f"CloudSpanner.{type(self).__name__}.execute_sql", self._session, trace_attributes, transaction=self, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index a8aef7f470..cc59789248 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -242,39 +242,7 @@ def commit( :returns: timestamp of the committed changes. :raises ValueError: if there are no mutations to commit. """ - self._check_state() - if self._transaction_id is None and len(self._mutations) > 0: - self.begin() - elif self._transaction_id is None and len(self._mutations) == 0: - raise ValueError("Transaction is not begun") - database = self._session._database - api = database.spanner_api - metadata = _metadata_with_prefix(database.name) - if database._route_to_leader_enabled: - metadata.append( - _metadata_with_leader_aware_routing(database._route_to_leader_enabled) - ) - - if request_options is None: - request_options = RequestOptions() - elif type(request_options) is dict: - request_options = RequestOptions(request_options) - if self.transaction_tag is not None: - request_options.transaction_tag = self.transaction_tag - - # Request tags are not supported for commit requests. - request_options.request_tag = None - - request = CommitRequest( - session=self._session.name, - mutations=self._mutations, - transaction_id=self._transaction_id, - return_commit_stats=return_commit_stats, - max_commit_delay=max_commit_delay, - request_options=request_options, - ) - trace_attributes = {"num_mutations": len(self._mutations)} observability_options = getattr(database, "observability_options", None) with trace_call( @@ -283,6 +251,40 @@ def commit( trace_attributes, observability_options, ) as span: + self._check_state() + if self._transaction_id is None and len(self._mutations) > 0: + self.begin() + elif self._transaction_id is None and len(self._mutations) == 0: + raise ValueError("Transaction is not begun") + + api = database.spanner_api + metadata = _metadata_with_prefix(database.name) + if database._route_to_leader_enabled: + metadata.append( + _metadata_with_leader_aware_routing( + database._route_to_leader_enabled + ) + ) + + if request_options is None: + request_options = RequestOptions() + elif type(request_options) is dict: + request_options = RequestOptions(request_options) + if self.transaction_tag is not None: + request_options.transaction_tag = self.transaction_tag + + # Request tags are not supported for commit requests. + request_options.request_tag = None + + request = CommitRequest( + session=self._session.name, + mutations=self._mutations, + transaction_id=self._transaction_id, + return_commit_stats=return_commit_stats, + max_commit_delay=max_commit_delay, + request_options=request_options, + ) + add_span_event(span, "Starting Commit") method = functools.partial( diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index a91955496f..d40b34f800 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -111,7 +111,7 @@ def test_propagation(enable_extended_tracing): gotNames = [span.name for span in from_inject_spans] wantNames = [ "CloudSpanner.CreateSession", - "CloudSpanner.Snapshot.execute_streaming_sql", + "CloudSpanner.Snapshot.execute_sql", ] assert gotNames == wantNames @@ -239,8 +239,8 @@ def select_in_txn(txn): ("CloudSpanner.Database.run_in_transaction", codes.OK, None), ("CloudSpanner.CreateSession", codes.OK, None), ("CloudSpanner.Session.run_in_transaction", codes.OK, None), - ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), - ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), + ("CloudSpanner.Transaction.execute_sql", codes.OK, None), + ("CloudSpanner.Transaction.execute_sql", codes.OK, None), ("CloudSpanner.Transaction.commit", codes.OK, None), ] assert got_statuses == want_statuses @@ -273,6 +273,116 @@ def finished_spans_statuses(trace_exporter): return got_statuses, got_events +@pytest.mark.skipif( + not _helpers.USE_EMULATOR, + reason="Emulator needed to run this tests", +) +@pytest.mark.skipif( + not HAS_OTEL_INSTALLED, + reason="Tracing requires OpenTelemetry", +) +def test_transaction_update_implicit_begin_nested_inside_commit(): + # Tests to ensure that transaction.commit() without a began transaction + # has transaction.begin() inlined and nested under the commit span. + from google.auth.credentials import AnonymousCredentials + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.sampling import ALWAYS_ON + + PROJECT = _helpers.EMULATOR_PROJECT + CONFIGURATION_NAME = "config-name" + INSTANCE_ID = _helpers.INSTANCE_ID + DISPLAY_NAME = "display-name" + DATABASE_ID = _helpers.unique_id("temp_db") + NODE_COUNT = 5 + LABELS = {"test": "true"} + + def tx_update(txn): + txn.insert( + "Singers", + columns=["SingerId", "FirstName"], + values=[["1", "Bryan"], ["2", "Slash"]], + ) + + tracer_provider = TracerProvider(sampler=ALWAYS_ON) + trace_exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter)) + observability_options = dict( + tracer_provider=tracer_provider, + enable_extended_tracing=True, + ) + + client = Client( + project=PROJECT, + observability_options=observability_options, + credentials=AnonymousCredentials(), + ) + + instance = client.instance( + INSTANCE_ID, + CONFIGURATION_NAME, + display_name=DISPLAY_NAME, + node_count=NODE_COUNT, + labels=LABELS, + ) + + try: + instance.create() + except Exception: + pass + + db = instance.database(DATABASE_ID) + try: + db._ddl_statements = [ + """CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX), + FullName STRING(2048) AS ( + ARRAY_TO_STRING([FirstName, LastName], " ") + ) STORED + ) PRIMARY KEY (SingerId)""", + """CREATE TABLE Albums ( + SingerId INT64 NOT NULL, + AlbumId INT64 NOT NULL, + AlbumTitle STRING(MAX), + MarketingBudget INT64, + ) PRIMARY KEY (SingerId, AlbumId), + INTERLEAVE IN PARENT Singers ON DELETE CASCADE""", + ] + db.create() + except Exception: + pass + + try: + db.run_in_transaction(tx_update) + except Exception: + pass + + span_list = trace_exporter.get_finished_spans() + # Sort the spans by their start time in the hierarchy. + span_list = sorted(span_list, key=lambda span: span.start_time) + got_span_names = [span.name for span in span_list] + want_span_names = [ + "CloudSpanner.Database.run_in_transaction", + "CloudSpanner.CreateSession", + "CloudSpanner.Session.run_in_transaction", + "CloudSpanner.Transaction.commit", + "CloudSpanner.Transaction.begin", + ] + + assert got_span_names == want_span_names + + # Our object is to ensure that .begin() is a child of .commit() + span_tx_begin = span_list[-1] + span_tx_commit = span_list[-2] + assert span_tx_begin.parent.span_id == span_tx_commit.context.span_id + + @pytest.mark.skipif( not _helpers.USE_EMULATOR, reason="Emulator needed to run this test", diff --git a/tests/unit/test__opentelemetry_tracing.py b/tests/unit/test__opentelemetry_tracing.py index 1150ce7778..884928a279 100644 --- a/tests/unit/test__opentelemetry_tracing.py +++ b/tests/unit/test__opentelemetry_tracing.py @@ -159,7 +159,7 @@ def test_trace_codeless_error(self): span = span_list[0] self.assertEqual(span.status.status_code, StatusCode.ERROR) - def test_trace_call_terminal_span_status(self): + def test_trace_call_terminal_span_status_ALWAYS_ON_sampler(self): # Verify that we don't unconditionally set the terminal span status to # SpanStatus.OK per https://github.com/googleapis/python-spanner/issues/1246 from opentelemetry.sdk.trace.export import SimpleSpanProcessor @@ -195,3 +195,32 @@ def test_trace_call_terminal_span_status(self): ("VerifyTerminalSpanStatus", StatusCode.ERROR, "Our error exhibit"), ] assert got_statuses == want_statuses + + def test_trace_call_terminal_span_status_ALWAYS_OFF_sampler(self): + # Verify that we get the correct status even when using the ALWAYS_OFF + # sampler which produces the NonRecordingSpan per + # https://github.com/googleapis/python-spanner/issues/1286 + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.sampling import ALWAYS_OFF + + tracer_provider = TracerProvider(sampler=ALWAYS_OFF) + trace_exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter)) + observability_options = dict(tracer_provider=tracer_provider) + + session = _make_session() + used_span = None + with _opentelemetry_tracing.trace_call( + "VerifyWithNonRecordingSpan", + session, + observability_options=observability_options, + ) as span: + used_span = span + + assert type(used_span).__name__ == "NonRecordingSpan" + span_list = list(trace_exporter.get_finished_spans()) + assert span_list == [] diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 099bd31bea..02cc35e017 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -868,7 +868,7 @@ def test_execute_sql_other_error(self): self.assertEqual(derived._execute_sql_count, 1) self.assertSpanAttributes( - "CloudSpanner._Derived.execute_streaming_sql", + "CloudSpanner._Derived.execute_sql", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) @@ -1024,7 +1024,7 @@ def _execute_sql_helper( self.assertEqual(derived._execute_sql_count, sql_count + 1) self.assertSpanAttributes( - "CloudSpanner._Derived.execute_streaming_sql", + "CloudSpanner._Derived.execute_sql", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d3d7035854..9707632421 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -22,6 +22,7 @@ from google.api_core import gapic_v1 from tests._helpers import ( + HAS_OPENTELEMETRY_INSTALLED, OpenTelemetryBase, StatusCode, enrich_with_otel_scope, @@ -226,7 +227,7 @@ def test_rollback_not_begun(self): transaction.rollback() self.assertTrue(transaction.rolled_back) - # Since there was no transaction to be rolled back, rollbacl rpc is not called. + # Since there was no transaction to be rolled back, rollback rpc is not called. api.rollback.assert_not_called() self.assertNoSpans() @@ -309,7 +310,27 @@ def test_commit_not_begun(self): with self.assertRaises(ValueError): transaction.commit() - self.assertNoSpans() + if not HAS_OPENTELEMETRY_INSTALLED: + return + + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Transaction.commit"] + assert got_span_names == want_span_names + + got_span_events_statuses = self.finished_spans_events_statuses() + want_span_events_statuses = [ + ( + "exception", + { + "exception.type": "ValueError", + "exception.message": "Transaction is not begun", + "exception.stacktrace": "EPHEMERAL", + "exception.escaped": "False", + }, + ) + ] + assert got_span_events_statuses == want_span_events_statuses def test_commit_already_committed(self): session = _Session() @@ -319,7 +340,27 @@ def test_commit_already_committed(self): with self.assertRaises(ValueError): transaction.commit() - self.assertNoSpans() + if not HAS_OPENTELEMETRY_INSTALLED: + return + + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Transaction.commit"] + assert got_span_names == want_span_names + + got_span_events_statuses = self.finished_spans_events_statuses() + want_span_events_statuses = [ + ( + "exception", + { + "exception.type": "ValueError", + "exception.message": "Transaction is already committed", + "exception.stacktrace": "EPHEMERAL", + "exception.escaped": "False", + }, + ) + ] + assert got_span_events_statuses == want_span_events_statuses def test_commit_already_rolled_back(self): session = _Session() @@ -329,7 +370,27 @@ def test_commit_already_rolled_back(self): with self.assertRaises(ValueError): transaction.commit() - self.assertNoSpans() + if not HAS_OPENTELEMETRY_INSTALLED: + return + + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Transaction.commit"] + assert got_span_names == want_span_names + + got_span_events_statuses = self.finished_spans_events_statuses() + want_span_events_statuses = [ + ( + "exception", + { + "exception.type": "ValueError", + "exception.message": "Transaction is already rolled back", + "exception.stacktrace": "EPHEMERAL", + "exception.escaped": "False", + }, + ) + ] + assert got_span_events_statuses == want_span_events_statuses def test_commit_w_other_error(self): database = _Database() @@ -435,6 +496,18 @@ def _commit_helper( ), ) + if not HAS_OPENTELEMETRY_INSTALLED: + return + + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Transaction.commit"] + assert got_span_names == want_span_names + + got_span_events_statuses = self.finished_spans_events_statuses() + want_span_events_statuses = [("Starting Commit", {}), ("Commit Done", {})] + assert got_span_events_statuses == want_span_events_statuses + def test_commit_no_mutations(self): self._commit_helper(mutate=False) @@ -586,6 +659,13 @@ def _execute_update_helper( ) self.assertEqual(transaction._execute_sql_count, count + 1) + want_span_attributes = dict(TestTransaction.BASE_ATTRIBUTES) + want_span_attributes["db.statement"] = DML_QUERY_WITH_PARAM + self.assertSpanAttributes( + "CloudSpanner.Transaction.execute_update", + status=StatusCode.OK, + attributes=want_span_attributes, + ) def test_execute_update_new_transaction(self): self._execute_update_helper()