Skip to content

Commit

Permalink
add cloud trace propagator (#819)
Browse files Browse the repository at this point in the history
Adding initial cloud trace propagator

Co-authored-by: Aaron Abbott <aaronabbott@google.com>
Co-authored-by: Diego Hurtado <ocelotl@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 17, 2020
1 parent a284367 commit 9d74918
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
from opentelemetry.sdk.trace import Event
from opentelemetry.sdk.trace.export import Span, SpanExporter, SpanExportResult
from opentelemetry.sdk.util import BoundedDict
from opentelemetry.trace.span import (
get_hexadecimal_span_id,
get_hexadecimal_trace_id,
)
from opentelemetry.util import types

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -123,15 +127,15 @@ def _translate_to_cloud_trace(

for span in spans:
ctx = span.get_context()
trace_id = _get_hexadecimal_trace_id(ctx.trace_id)
span_id = _get_hexadecimal_span_id(ctx.span_id)
trace_id = get_hexadecimal_trace_id(ctx.trace_id)
span_id = get_hexadecimal_span_id(ctx.span_id)
span_name = "projects/{}/traces/{}/spans/{}".format(
self.project_id, trace_id, span_id
)

parent_id = None
if span.parent:
parent_id = _get_hexadecimal_span_id(span.parent.span_id)
parent_id = get_hexadecimal_span_id(span.parent.span_id)

start_time = _get_time_from_ns(span.start_time)
end_time = _get_time_from_ns(span.end_time)
Expand Down Expand Up @@ -169,14 +173,6 @@ def shutdown(self):
pass


def _get_hexadecimal_trace_id(trace_id: int) -> str:
return "{:032x}".format(trace_id)


def _get_hexadecimal_span_id(span_id: int) -> str:
return "{:016x}".format(span_id)


def _get_time_from_ns(nanoseconds: int) -> Dict:
"""Given epoch nanoseconds, split into epoch milliseconds and remaining
nanoseconds"""
Expand Down Expand Up @@ -234,8 +230,8 @@ def _extract_links(links: Sequence[trace_api.Link]) -> ProtoSpan.Links:
"Link has more then %s attributes, some will be truncated",
MAX_LINK_ATTRS,
)
trace_id = _get_hexadecimal_trace_id(link.context.trace_id)
span_id = _get_hexadecimal_span_id(link.context.span_id)
trace_id = get_hexadecimal_trace_id(link.context.trace_id)
span_id = get_hexadecimal_span_id(link.context.span_id)
extracted_links.append(
{
"trace_id": trace_id,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re
import typing

import opentelemetry.trace as trace
from opentelemetry.context.context import Context
from opentelemetry.trace.propagation import httptextformat
from opentelemetry.trace.span import (
SpanContext,
TraceFlags,
get_hexadecimal_trace_id,
)

_TRACE_CONTEXT_HEADER_NAME = "X-Cloud-Trace-Context"
_TRACE_CONTEXT_HEADER_FORMAT = r"(?P<trace_id>[0-9a-f]{32})\/(?P<span_id>[\d]{1,20});o=(?P<trace_flags>\d+)"
_TRACE_CONTEXT_HEADER_RE = re.compile(_TRACE_CONTEXT_HEADER_FORMAT)


class CloudTraceFormatPropagator(httptextformat.HTTPTextFormat):
"""This class is for injecting into a carrier the SpanContext in Google
Cloud format, or extracting the SpanContext from a carrier using Google
Cloud format.
"""

def extract(
self,
get_from_carrier: httptextformat.Getter[
httptextformat.HTTPTextFormatT
],
carrier: httptextformat.HTTPTextFormatT,
context: typing.Optional[Context] = None,
) -> Context:
header = get_from_carrier(carrier, _TRACE_CONTEXT_HEADER_NAME)

if not header:
return trace.set_span_in_context(trace.INVALID_SPAN, context)

match = re.fullmatch(_TRACE_CONTEXT_HEADER_RE, header[0])
if match is None:
return trace.set_span_in_context(trace.INVALID_SPAN, context)

trace_id = match.group("trace_id")
span_id = match.group("span_id")
trace_options = match.group("trace_flags")

if trace_id == "0" * 32 or int(span_id) == 0:
return trace.set_span_in_context(trace.INVALID_SPAN, context)

span_context = SpanContext(
trace_id=int(trace_id, 16),
span_id=int(span_id),
is_remote=True,
trace_flags=TraceFlags(trace_options),
)
return trace.set_span_in_context(
trace.DefaultSpan(span_context), context
)

def inject(
self,
set_in_carrier: httptextformat.Setter[httptextformat.HTTPTextFormatT],
carrier: httptextformat.HTTPTextFormatT,
context: typing.Optional[Context] = None,
) -> None:
span = trace.get_current_span(context)
if span is None:
return
span_context = span.get_context()
if span_context == trace.INVALID_SPAN_CONTEXT:
return

header = "{}/{};o={}".format(
get_hexadecimal_trace_id(span_context.trace_id),
span_context.span_id,
int(span_context.trace_flags.sampled),
)
set_in_carrier(carrier, _TRACE_CONTEXT_HEADER_NAME, header)
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import typing
import unittest

import opentelemetry.trace as trace
from opentelemetry.context import get_current
from opentelemetry.exporter.cloud_trace.cloud_trace_propagator import (
_TRACE_CONTEXT_HEADER_NAME,
CloudTraceFormatPropagator,
)
from opentelemetry.trace.span import (
INVALID_SPAN_ID,
INVALID_TRACE_ID,
SpanContext,
TraceFlags,
get_hexadecimal_trace_id,
)


def get_dict_value(dict_object: typing.Dict[str, str], key: str) -> str:
return dict_object.get(key, "")


class TestCloudTraceFormatPropagator(unittest.TestCase):
def setUp(self):
self.propagator = CloudTraceFormatPropagator()
self.valid_trace_id = 281017822499060589596062859815111849546
self.valid_span_id = 17725314949316355921
self.too_long_id = 111111111111111111111111111111111111111111111

def _extract(self, header_value):
"""Test helper"""
header = {_TRACE_CONTEXT_HEADER_NAME: [header_value]}
new_context = self.propagator.extract(get_dict_value, header)
return trace.get_current_span(new_context).get_context()

def _inject(self, span=None):
"""Test helper"""
ctx = get_current()
if span is not None:
ctx = trace.set_span_in_context(span, ctx)
output = {}
self.propagator.inject(dict.__setitem__, output, context=ctx)
return output.get(_TRACE_CONTEXT_HEADER_NAME)

def test_no_context_header(self):
header = {}
new_context = self.propagator.extract(get_dict_value, header)
self.assertEqual(
trace.get_current_span(new_context).get_context(),
trace.INVALID_SPAN.get_context(),
)

def test_empty_context_header(self):
header = ""
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

def test_valid_header(self):
header = "{}/{};o=1".format(
get_hexadecimal_trace_id(self.valid_trace_id), self.valid_span_id
)
new_span_context = self._extract(header)
self.assertEqual(new_span_context.trace_id, self.valid_trace_id)
self.assertEqual(new_span_context.span_id, self.valid_span_id)
self.assertEqual(new_span_context.trace_flags, TraceFlags(1))
self.assertTrue(new_span_context.is_remote)

header = "{}/{};o=10".format(
get_hexadecimal_trace_id(self.valid_trace_id), self.valid_span_id
)
new_span_context = self._extract(header)
self.assertEqual(new_span_context.trace_id, self.valid_trace_id)
self.assertEqual(new_span_context.span_id, self.valid_span_id)
self.assertEqual(new_span_context.trace_flags, TraceFlags(10))
self.assertTrue(new_span_context.is_remote)

header = "{}/{};o=0".format(
get_hexadecimal_trace_id(self.valid_trace_id), self.valid_span_id
)
new_span_context = self._extract(header)
self.assertEqual(new_span_context.trace_id, self.valid_trace_id)
self.assertEqual(new_span_context.span_id, self.valid_span_id)
self.assertEqual(new_span_context.trace_flags, TraceFlags(0))
self.assertTrue(new_span_context.is_remote)

header = "{}/{};o=0".format(
get_hexadecimal_trace_id(self.valid_trace_id), 345
)
new_span_context = self._extract(header)
self.assertEqual(new_span_context.trace_id, self.valid_trace_id)
self.assertEqual(new_span_context.span_id, 345)
self.assertEqual(new_span_context.trace_flags, TraceFlags(0))
self.assertTrue(new_span_context.is_remote)

def test_invalid_header_format(self):
header = "invalid_header"
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/{};o=".format(
get_hexadecimal_trace_id(self.valid_trace_id), self.valid_span_id
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "extra_chars/{}/{};o=1".format(
get_hexadecimal_trace_id(self.valid_trace_id), self.valid_span_id
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/{}extra_chars;o=1".format(
get_hexadecimal_trace_id(self.valid_trace_id), self.valid_span_id
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/{};o=1extra_chars".format(
get_hexadecimal_trace_id(self.valid_trace_id), self.valid_span_id
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/;o=1".format(
get_hexadecimal_trace_id(self.valid_trace_id)
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "/{};o=1".format(self.valid_span_id)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/{};o={}".format("123", "34", "4")
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

def test_invalid_trace_id(self):
header = "{}/{};o={}".format(INVALID_TRACE_ID, self.valid_span_id, 1)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)
header = "{}/{};o={}".format("0" * 32, self.valid_span_id, 1)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "0/{};o={}".format(self.valid_span_id, 1)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "234/{};o={}".format(self.valid_span_id, 1)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/{};o={}".format(self.too_long_id, self.valid_span_id, 1)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

def test_invalid_span_id(self):
header = "{}/{};o={}".format(
get_hexadecimal_trace_id(self.valid_trace_id), INVALID_SPAN_ID, 1
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/{};o={}".format(
get_hexadecimal_trace_id(self.valid_trace_id), "0" * 16, 1
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/{};o={}".format(
get_hexadecimal_trace_id(self.valid_trace_id), "0", 1
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

header = "{}/{};o={}".format(
get_hexadecimal_trace_id(self.valid_trace_id), self.too_long_id, 1
)
self.assertEqual(
self._extract(header), trace.INVALID_SPAN.get_context()
)

def test_inject_with_no_context(self):
output = self._inject()
self.assertIsNone(output)

def test_inject_with_invalid_context(self):
output = self._inject(trace.INVALID_SPAN)
self.assertIsNone(output)

def test_inject_with_valid_context(self):
span_context = SpanContext(
trace_id=self.valid_trace_id,
span_id=self.valid_span_id,
is_remote=True,
trace_flags=TraceFlags(1),
)
output = self._inject(trace.DefaultSpan(span_context))
self.assertEqual(
output,
"{}/{};o={}".format(
get_hexadecimal_trace_id(self.valid_trace_id),
self.valid_span_id,
1,
),
)
8 changes: 8 additions & 0 deletions opentelemetry-api/src/opentelemetry/trace/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,11 @@ def format_trace_id(trace_id: int) -> str:

def format_span_id(span_id: int) -> str:
return "0x{:016x}".format(span_id)


def get_hexadecimal_trace_id(trace_id: int) -> str:
return "{:032x}".format(trace_id)


def get_hexadecimal_span_id(span_id: int) -> str:
return "{:016x}".format(span_id)

0 comments on commit 9d74918

Please sign in to comment.