Skip to content

Commit

Permalink
Merge branch 'main' into support-inline-begin-in-mock-server
Browse files Browse the repository at this point in the history
  • Loading branch information
olavloite authored Dec 20, 2024
2 parents d3ce683 + f2483e1 commit 3dff625
Show file tree
Hide file tree
Showing 19 changed files with 765 additions and 183 deletions.
44 changes: 44 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import math
import time
import base64
import threading

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
Expand All @@ -30,6 +31,7 @@
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
Expand Down Expand Up @@ -525,3 +527,45 @@ def _metadata_with_leader_aware_routing(value, **kw):
List[Tuple[str, str]]: RPC metadata with leader aware routing header
"""
return ("x-goog-spanner-route-to-leader", str(value).lower())


class AtomicCounter:
def __init__(self, start_value=0):
self.__lock = threading.Lock()
self.__value = start_value

@property
def value(self):
with self.__lock:
return self.__value

def increment(self, n=1):
with self.__lock:
self.__value += n
return self.__value

def __iadd__(self, n):
"""
Defines the inplace += operator result.
"""
with self.__lock:
self.__value += n
return self

def __add__(self, n):
"""
Defines the result of invoking: value = AtomicCounter + addable
"""
with self.__lock:
n += self.__value
return n

def __radd__(self, n):
"""
Defines the result of invoking: value = addable + AtomicCounter
"""
return self.__add__(n)


def _metadata_with_request_id(*args, **kwargs):
return with_request_id(*args, **kwargs)
12 changes: 8 additions & 4 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def get_tracer(tracer_provider=None):


@contextmanager
def trace_call(name, session, extra_attributes=None, observability_options=None):
def trace_call(name, session=None, extra_attributes=None, observability_options=None):
if session:
session._last_use_time = datetime.now()

if not HAS_OPENTELEMETRY_INSTALLED or not session:
if not (HAS_OPENTELEMETRY_INSTALLED and name):
# Empty context manager. Users will have to check if the generated value is None or a span
yield None
return
Expand All @@ -72,20 +72,24 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
# on by default.
enable_extended_tracing = True

db_name = ""
if session and getattr(session, "_database", None):
db_name = session._database.name

if isinstance(observability_options, dict): # Avoid false positives with mock.Mock
tracer_provider = observability_options.get("tracer_provider", None)
enable_extended_tracing = observability_options.get(
"enable_extended_tracing", enable_extended_tracing
)
db_name = observability_options.get("db_name", db_name)

tracer = get_tracer(tracer_provider)

# Set base attributes that we know for every trace created
db = session._database
attributes = {
"db.type": "spanner",
"db.url": SpannerClient.DEFAULT_ENDPOINT,
"db.instance": "" if not db else db.name,
"db.instance": db_name,
"net.host.name": SpannerClient.DEFAULT_ENDPOINT,
OTEL_SCOPE_NAME: TRACER_NAME,
OTEL_SCOPE_VERSION: TRACER_VERSION,
Expand Down
12 changes: 11 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def insert(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def update(self, table, columns, values):
"""Update one or more existing table rows.
Expand All @@ -84,6 +86,8 @@ def update(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(update=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def insert_or_update(self, table, columns, values):
"""Insert/update one or more table rows.
Expand All @@ -100,6 +104,8 @@ def insert_or_update(self, table, columns, values):
self._mutations.append(
Mutation(insert_or_update=_make_write_pb(table, columns, values))
)
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def replace(self, table, columns, values):
"""Replace one or more table rows.
Expand All @@ -114,6 +120,8 @@ def replace(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def delete(self, table, keyset):
"""Delete one or more table rows.
Expand All @@ -126,6 +134,8 @@ def delete(self, table, keyset):
"""
delete = Mutation.Delete(table=table, key_set=keyset._to_pb())
self._mutations.append(Mutation(delete=delete))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269


class Batch(_BatchBase):
Expand Down Expand Up @@ -207,7 +217,7 @@ def commit(
)
observability_options = getattr(database, "observability_options", None)
with trace_call(
"CloudSpanner.Commit",
f"CloudSpanner.{type(self).__name__}.commit",
self._session,
trace_attributes,
observability_options=observability_options,
Expand Down
42 changes: 27 additions & 15 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from google.cloud.spanner_v1._opentelemetry_tracing import (
add_span_event,
get_current_span,
trace_call,
)


Expand Down Expand Up @@ -720,6 +721,7 @@ def execute_pdml():

iterator = _restart_on_unavailable(
method=method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
Expand Down Expand Up @@ -881,20 +883,25 @@ def run_in_transaction(self, func, *args, **kw):
:raises Exception:
reraises any non-ABORT exceptions raised by ``func``.
"""
# Sanity check: Is there a transaction already running?
# If there is, then raise a red flag. Otherwise, mark that this one
# is running.
if getattr(self._local, "transaction_running", False):
raise RuntimeError("Spanner does not support nested transactions.")
self._local.transaction_running = True

# Check out a session and run the function in a transaction; once
# done, flip the sanity check bit back.
try:
with SessionCheckout(self._pool) as session:
return session.run_in_transaction(func, *args, **kw)
finally:
self._local.transaction_running = False
observability_options = getattr(self, "observability_options", None)
with trace_call(
"CloudSpanner.Database.run_in_transaction",
observability_options=observability_options,
):
# Sanity check: Is there a transaction already running?
# If there is, then raise a red flag. Otherwise, mark that this one
# is running.
if getattr(self._local, "transaction_running", False):
raise RuntimeError("Spanner does not support nested transactions.")
self._local.transaction_running = True

# Check out a session and run the function in a transaction; once
# done, flip the sanity check bit back.
try:
with SessionCheckout(self._pool) as session:
return session.run_in_transaction(func, *args, **kw)
finally:
self._local.transaction_running = False

def restore(self, source):
"""Restore from a backup to this database.
Expand Down Expand Up @@ -1120,7 +1127,12 @@ def observability_options(self):
if not (self._instance and self._instance._client):
return None

return getattr(self._instance._client, "observability_options", None)
opts = getattr(self._instance._client, "observability_options", None)
if not opts:
opts = dict()

opts["db_name"] = self.name
return opts


class BatchCheckout(object):
Expand Down
90 changes: 54 additions & 36 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from google.cloud.spanner_v1._opentelemetry_tracing import (
add_span_event,
get_current_span,
trace_call,
)
from warnings import warn

Expand Down Expand Up @@ -237,29 +238,41 @@ def bind(self, database):
session_template=Session(creator_role=self.database_role),
)

returned_session_count = 0
while not self._sessions.full():
request.session_count = requested_session_count - self._sessions.qsize()
observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.FixedPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
returned_session_count = 0
while not self._sessions.full():
request.session_count = requested_session_count - self._sessions.qsize()
add_span_event(
span,
f"Creating {request.session_count} sessions",
span_event_attributes,
)
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)

add_span_event(
span,
"Created sessions",
dict(count=len(resp.session)),
)

for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self._sessions.put(session)
returned_session_count += 1

add_span_event(
span,
f"Creating {request.session_count} sessions",
f"Requested for {requested_session_count} sessions, returned {returned_session_count}",
span_event_attributes,
)
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)
for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self._sessions.put(session)
returned_session_count += 1

add_span_event(
span,
f"Requested for {requested_session_count} sessions, returned {returned_session_count}",
span_event_attributes,
)

def get(self, timeout=None):
"""Check a session out from the pool.
Expand Down Expand Up @@ -550,25 +563,30 @@ def bind(self, database):
span_event_attributes,
)

returned_session_count = 0
while created_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)
for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self.put(session)
returned_session_count += 1
observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.PingingPool.BatchCreateSessions",
observability_options=observability_options,
) as span:
returned_session_count = 0
while created_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
)
for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
self.put(session)
returned_session_count += 1

created_session_count += len(resp.session)
created_session_count += len(resp.session)

add_span_event(
current_span,
f"Requested for {requested_session_count} sessions, return {returned_session_count}",
span_event_attributes,
)
add_span_event(
span,
f"Requested for {requested_session_count} sessions, returned {returned_session_count}",
span_event_attributes,
)

def get(self, timeout=None):
"""Check a session out from the pool.
Expand Down
42 changes: 42 additions & 0 deletions google/cloud/spanner_v1/request_id_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

REQ_ID_VERSION = 1 # The version of the x-goog-spanner-request-id spec.
REQ_ID_HEADER_KEY = "x-goog-spanner-request-id"


def generate_rand_uint64():
b = os.urandom(8)
return (
b[7] & 0xFF
| (b[6] & 0xFF) << 8
| (b[5] & 0xFF) << 16
| (b[4] & 0xFF) << 24
| (b[3] & 0xFF) << 32
| (b[2] & 0xFF) << 36
| (b[1] & 0xFF) << 48
| (b[0] & 0xFF) << 56
)


REQ_RAND_PROCESS_ID = generate_rand_uint64()


def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]):
req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}"
all_metadata = other_metadata.copy()
all_metadata.append((REQ_ID_HEADER_KEY, req_id))
return all_metadata
Loading

0 comments on commit 3dff625

Please sign in to comment.