observability: PDML + some batch write spans (#1274)
* observability: PDML + some batch write spans

This change adds spans for Partitioned DML and makes updates for Batch.

Carved out from PR #1241.

* Add more system tests

* Account for lack of OpenTelemetry on Python 3.7

* Update tests

* Fix more test assertions

* Updates from code review

* Update tests with code review suggestions

* Remove return per code review nit
odeke-em authored Jan 10, 2025
1 parent 0887eb4 commit 592047f
Showing 11 changed files with 370 additions and 194 deletions.
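
Note: the spans added below are only emitted when tracing is configured. As a rough sketch (assuming the OpenTelemetry SDK plus the `observability_options` client argument described in this library's tracing docs; the instance and database IDs are placeholders):

```python
# Sketch: wire up OpenTelemetry so the spans added in this commit are
# exported; ConsoleSpanExporter just prints them for local inspection.
from google.cloud import spanner
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter

tracer_provider = TracerProvider()
tracer_provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(tracer_provider)

client = spanner.Client(
    observability_options=dict(tracer_provider=tracer_provider),
)
database = client.instance("my-instance").database("my-database")

# Should emit CloudSpanner.Database.execute_partitioned_pdml (added in
# the database.py diff below), with ExecuteStreamingSql nested under it.
database.execute_partitioned_dml(
    "UPDATE Singers SET marketing_budget = 0 WHERE marketing_budget IS NULL"
)
```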
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/batch.py
@@ -344,7 +344,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
         )
         observability_options = getattr(database, "observability_options", None)
         with trace_call(
-            "CloudSpanner.BatchWrite",
+            "CloudSpanner.batch_write",
             self._session,
             trace_attributes,
             observability_options=observability_options,
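
For reference, the renamed `CloudSpanner.batch_write` span sits on the mutation-group path; a hedged usage sketch, reusing the `database` from the snippet above (table and values are illustrative):

```python
# Sketch: groups.batch_write() ends up in the batch_write method patched
# above in batch.py, so the RPC should now be recorded under the
# CloudSpanner.batch_write span name.
with database.mutation_groups() as groups:
    group = groups.group()
    group.insert(
        table="Singers",  # placeholder table
        columns=("SingerId", "FirstName"),
        values=[(1, "Marc")],
    )
    for response in groups.batch_write():
        print(response.status)
```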
222 changes: 129 additions & 93 deletions google/cloud/spanner_v1/database.py
@@ -699,38 +699,43 @@ def execute_partitioned_dml(
         )

         def execute_pdml():
-            with SessionCheckout(self._pool) as session:
-                txn = api.begin_transaction(
-                    session=session.name, options=txn_options, metadata=metadata
-                )
+            with trace_call(
+                "CloudSpanner.Database.execute_partitioned_pdml",
+                observability_options=self.observability_options,
+            ) as span:
+                with SessionCheckout(self._pool) as session:
+                    add_span_event(span, "Starting BeginTransaction")
+                    txn = api.begin_transaction(
+                        session=session.name, options=txn_options, metadata=metadata
+                    )

-                txn_selector = TransactionSelector(id=txn.id)
+                    txn_selector = TransactionSelector(id=txn.id)

-                request = ExecuteSqlRequest(
-                    session=session.name,
-                    sql=dml,
-                    params=params_pb,
-                    param_types=param_types,
-                    query_options=query_options,
-                    request_options=request_options,
-                )
-                method = functools.partial(
-                    api.execute_streaming_sql,
-                    metadata=metadata,
-                )
+                    request = ExecuteSqlRequest(
+                        session=session.name,
+                        sql=dml,
+                        params=params_pb,
+                        param_types=param_types,
+                        query_options=query_options,
+                        request_options=request_options,
+                    )
+                    method = functools.partial(
+                        api.execute_streaming_sql,
+                        metadata=metadata,
+                    )

-                iterator = _restart_on_unavailable(
-                    method=method,
-                    trace_name="CloudSpanner.ExecuteStreamingSql",
-                    request=request,
-                    transaction_selector=txn_selector,
-                    observability_options=self.observability_options,
-                )
+                    iterator = _restart_on_unavailable(
+                        method=method,
+                        trace_name="CloudSpanner.ExecuteStreamingSql",
+                        request=request,
+                        transaction_selector=txn_selector,
+                        observability_options=self.observability_options,
+                    )

-                result_set = StreamedResultSet(iterator)
-                list(result_set)  # consume all partials
+                    result_set = StreamedResultSet(iterator)
+                    list(result_set)  # consume all partials

-                return result_set.stats.row_count_lower_bound
+                    return result_set.stats.row_count_lower_bound

         return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

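The hunks above and below all go through the library's internal `trace_call` helper from `_opentelemetry_tracing.py`. Roughly — a simplified sketch, not the actual implementation — it behaves like:

```python
# Simplified sketch of trace_call: a context manager that yields a span,
# or None when OpenTelemetry is not installed (e.g. on Python 3.7, which
# the commit message accounts for).
from contextlib import contextmanager

try:
    from opentelemetry import trace

    HAS_OPENTELEMETRY_INSTALLED = True
except ImportError:
    HAS_OPENTELEMETRY_INSTALLED = False


@contextmanager
def trace_call(name, session=None, extra_attributes=None, observability_options=None):
    if not HAS_OPENTELEMETRY_INSTALLED:
        yield None  # callers still run, just without a span
        return
    opts = observability_options or {}
    tracer = trace.get_tracer(__name__, tracer_provider=opts.get("tracer_provider"))
    attributes = dict(extra_attributes or {})
    if session is not None:
        attributes["session.id"] = getattr(session, "name", "")  # illustrative attribute
    with tracer.start_as_current_span(name, attributes=attributes) as span:
        yield span
```
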
@@ -1357,6 +1362,10 @@ def to_dict(self):
             "transaction_id": snapshot._transaction_id,
         }

+    @property
+    def observability_options(self):
+        return getattr(self._database, "observability_options", {})
+
     def _get_session(self):
         """Create session as needed.
@@ -1476,27 +1485,32 @@ def generate_read_batches(
             mappings of information used perform actual partitioned reads via
             :meth:`process_read_batch`.
         """
-        partitions = self._get_snapshot().partition_read(
-            table=table,
-            columns=columns,
-            keyset=keyset,
-            index=index,
-            partition_size_bytes=partition_size_bytes,
-            max_partitions=max_partitions,
-            retry=retry,
-            timeout=timeout,
-        )
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.generate_read_batches",
+            extra_attributes=dict(table=table, columns=columns),
+            observability_options=self.observability_options,
+        ):
+            partitions = self._get_snapshot().partition_read(
+                table=table,
+                columns=columns,
+                keyset=keyset,
+                index=index,
+                partition_size_bytes=partition_size_bytes,
+                max_partitions=max_partitions,
+                retry=retry,
+                timeout=timeout,
+            )

-        read_info = {
-            "table": table,
-            "columns": columns,
-            "keyset": keyset._to_dict(),
-            "index": index,
-            "data_boost_enabled": data_boost_enabled,
-            "directed_read_options": directed_read_options,
-        }
-        for partition in partitions:
-            yield {"partition": partition, "read": read_info.copy()}
+            read_info = {
+                "table": table,
+                "columns": columns,
+                "keyset": keyset._to_dict(),
+                "index": index,
+                "data_boost_enabled": data_boost_enabled,
+                "directed_read_options": directed_read_options,
+            }
+            for partition in partitions:
+                yield {"partition": partition, "read": read_info.copy()}

     def process_read_batch(
         self,
@@ -1522,12 +1536,17 @@ def process_read_batch(
         :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
-        kwargs = copy.deepcopy(batch["read"])
-        keyset_dict = kwargs.pop("keyset")
-        kwargs["keyset"] = KeySet._from_dict(keyset_dict)
-        return self._get_snapshot().read(
-            partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
-        )
+        observability_options = self.observability_options
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.process_read_batch",
+            observability_options=observability_options,
+        ):
+            kwargs = copy.deepcopy(batch["read"])
+            keyset_dict = kwargs.pop("keyset")
+            kwargs["keyset"] = KeySet._from_dict(keyset_dict)
+            return self._get_snapshot().read(
+                partition=batch["partition"], **kwargs, retry=retry, timeout=timeout
+            )

     def generate_query_batches(
         self,
@@ -1602,34 +1621,39 @@ def generate_query_batches(
             mappings of information used perform actual partitioned reads via
             :meth:`process_read_batch`.
         """
-        partitions = self._get_snapshot().partition_query(
-            sql=sql,
-            params=params,
-            param_types=param_types,
-            partition_size_bytes=partition_size_bytes,
-            max_partitions=max_partitions,
-            retry=retry,
-            timeout=timeout,
-        )
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.generate_query_batches",
+            extra_attributes=dict(sql=sql),
+            observability_options=self.observability_options,
+        ):
+            partitions = self._get_snapshot().partition_query(
+                sql=sql,
+                params=params,
+                param_types=param_types,
+                partition_size_bytes=partition_size_bytes,
+                max_partitions=max_partitions,
+                retry=retry,
+                timeout=timeout,
+            )

-        query_info = {
-            "sql": sql,
-            "data_boost_enabled": data_boost_enabled,
-            "directed_read_options": directed_read_options,
-        }
-        if params:
-            query_info["params"] = params
-            query_info["param_types"] = param_types
-
-        # Query-level options have higher precedence than client-level and
-        # environment-level options
-        default_query_options = self._database._instance._client._query_options
-        query_info["query_options"] = _merge_query_options(
-            default_query_options, query_options
-        )
+            query_info = {
+                "sql": sql,
+                "data_boost_enabled": data_boost_enabled,
+                "directed_read_options": directed_read_options,
+            }
+            if params:
+                query_info["params"] = params
+                query_info["param_types"] = param_types
+
+            # Query-level options have higher precedence than client-level and
+            # environment-level options
+            default_query_options = self._database._instance._client._query_options
+            query_info["query_options"] = _merge_query_options(
+                default_query_options, query_options
+            )

-        for partition in partitions:
-            yield {"partition": partition, "query": query_info}
+            for partition in partitions:
+                yield {"partition": partition, "query": query_info}

     def process_query_batch(
         self,
@@ -1654,9 +1678,16 @@ def process_query_batch(
         :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
-        return self._get_snapshot().execute_sql(
-            partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout
-        )
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.process_query_batch",
+            observability_options=self.observability_options,
+        ):
+            return self._get_snapshot().execute_sql(
+                partition=batch["partition"],
+                **batch["query"],
+                retry=retry,
+                timeout=timeout,
+            )

     def run_partitioned_query(
         self,
@@ -1711,18 +1742,23 @@ def run_partitioned_query(
         :rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet`
         :returns: a result set instance which can be used to consume rows.
         """
-        partitions = list(
-            self.generate_query_batches(
-                sql,
-                params,
-                param_types,
-                partition_size_bytes,
-                max_partitions,
-                query_options,
-                data_boost_enabled,
+        with trace_call(
+            f"CloudSpanner.{type(self).__name__}.run_partitioned_query",
+            extra_attributes=dict(sql=sql),
+            observability_options=self.observability_options,
+        ):
+            partitions = list(
+                self.generate_query_batches(
+                    sql,
+                    params,
+                    param_types,
+                    partition_size_bytes,
+                    max_partitions,
+                    query_options,
+                    data_boost_enabled,
+                )
             )
-        )
-        return MergedResultSet(self, partitions, 0)
+            return MergedResultSet(self, partitions, 0)

     def process(self, batch):
         """Process a single, partitioned query or read.
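
Taken together, the BatchSnapshot changes put a span around each stage of a partitioned query or read; a hedged sketch of the traced call sequence (placeholder schema, `database` as above):

```python
# Sketch: generate_query_batches and each process_query_batch call are
# now individually traced, with the class name baked into the span name,
# e.g. CloudSpanner.BatchSnapshot.generate_query_batches.
snapshot = database.batch_snapshot()
for batch in snapshot.generate_query_batches(
    "SELECT SingerId, FirstName FROM Singers"  # placeholder query
):
    for row in snapshot.process_query_batch(batch):
        print(row)
```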
12 changes: 12 additions & 0 deletions google/cloud/spanner_v1/merged_result_set.py
@@ -17,6 +17,8 @@
 from typing import Any, TYPE_CHECKING
 from threading import Lock, Event

+from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
+
 if TYPE_CHECKING:
     from google.cloud.spanner_v1.database import BatchSnapshot

@@ -37,6 +39,16 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set):
         self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

     def run(self):
+        observability_options = getattr(
+            self._batch_snapshot, "observability_options", {}
+        )
+        with trace_call(
+            "CloudSpanner.PartitionExecutor.run",
+            observability_options=observability_options,
+        ):
+            self.__run()
+
+    def __run(self):
         results = None
         try:
             results = self._batch_snapshot.process_query_batch(self._partition_id)
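
PartitionExecutor is the worker thread behind BatchSnapshot.run_partitioned_query, so with this change each partition's execution gets its own span; sketch (same placeholders as above):

```python
# Sketch: run_partitioned_query fans partitions out to PartitionExecutor
# threads; each should now emit a CloudSpanner.PartitionExecutor.run span
# alongside the run_partitioned_query span added in database.py.
snapshot = database.batch_snapshot()
results = snapshot.run_partitioned_query(
    "SELECT SingerId, FirstName FROM Singers"  # placeholder query
)
for row in results:
    print(row)
```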
29 changes: 9 additions & 20 deletions google/cloud/spanner_v1/pool.py
@@ -523,12 +523,11 @@ def bind(self, database):
         metadata.append(
             _metadata_with_leader_aware_routing(database._route_to_leader_enabled)
         )
-        created_session_count = 0
         self._database_role = self._database_role or self._database.database_role

         request = BatchCreateSessionsRequest(
             database=database.name,
-            session_count=self.size - created_session_count,
+            session_count=self.size,
             session_template=Session(creator_role=self.database_role),
         )

@@ -549,38 +548,28 @@ def bind(self, database):
                 span_event_attributes,
             )

-        if created_session_count >= self.size:
-            add_span_event(
-                current_span,
-                "Created no new sessions as sessionPool is full",
-                span_event_attributes,
-            )
-            return
-
         add_span_event(
             current_span,
             f"Creating {request.session_count} sessions",
             span_event_attributes,
         )

         observability_options = getattr(self._database, "observability_options", None)
         with trace_call(
             "CloudSpanner.PingingPool.BatchCreateSessions",
             observability_options=observability_options,
         ) as span:
             returned_session_count = 0
-            while created_session_count < self.size:
+            while returned_session_count < self.size:
                 resp = api.batch_create_sessions(
                     request=request,
                     metadata=metadata,
                 )

                 add_span_event(
                     span,
                     f"Created {len(resp.session)} sessions",
                 )

                 for session_pb in resp.session:
                     session = self._new_session()
+                    returned_session_count += 1
                     session._session_id = session_pb.name.split("/")[-1]
                     self.put(session)
-                    returned_session_count += 1
-
-                created_session_count += len(resp.session)

             add_span_event(
                 span,
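
Besides the new span, the bind() rewrite simplifies the fill loop: it now terminates on the count of sessions actually returned, dropping the separately tracked created_session_count and what reads as an unreachable "pool is full" early return. A toy model of the new control flow (stub names are illustrative, not library API):

```python
# Toy model: BatchCreateSessions may return fewer sessions per call than
# requested, so keep calling until at least `size` sessions are pooled.
def fill_pool(size, batch_create_sessions):
    pool = []
    returned_session_count = 0
    while returned_session_count < size:
        for session in batch_create_sessions(size):
            returned_session_count += 1
            pool.append(session)
    return pool

# Stub backend that caps each response at 3 sessions (illustrative).
make_batch = lambda n: [object() for _ in range(min(n, 3))]
assert len(fill_pool(10, make_batch)) >= 10
```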
