Skip to content

Commit

Permalink
Add Datadog propagator
Browse files Browse the repository at this point in the history
  • Loading branch information
majorgreys committed May 18, 2020
1 parent 5913d4a commit f0b691e
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import typing

from opentelemetry import propagators, trace
from opentelemetry.context import Context
from opentelemetry.trace.propagation import (
get_span_from_context,
set_span_in_context,
)
from opentelemetry.trace.propagation.httptextformat import (
Getter,
HTTPTextFormat,
HTTPTextFormatT,
Setter,
)
from opentelemetry.trace.propagation.tracecontexthttptextformat import (
TraceContextHTTPTextFormat,
)


class DatadogFormat(HTTPTextFormat):
"""Propagator for the Datadog HTTP header format.
"""

TRACE_ID_KEY = "x-datadog-trace-id"
PARENT_ID_KEY = "x-datadog-parent-id"
SAMPLING_PRIORITY_KEY = "x-datadog-sampling-priority"
ORIGIN_KEY = "x-datadog-origin"

def extract(
self,
get_from_carrier: Getter[HTTPTextFormatT],
carrier: HTTPTextFormatT,
context: typing.Optional[Context] = None,
) -> Context:
trace_id = (
extract_first_element(get_from_carrier(carrier, self.TRACE_ID_KEY))
)

span_id = (
extract_first_element(get_from_carrier(carrier, self.PARENT_ID_KEY))
)

# TODO: add sampling
# TODO: add origin

if trace_id is None or span_id is None:
return set_span_in_context(trace.INVALID_SPAN, context)

span_context = trace.SpanContext(
trace_id=int(trace_id), span_id=int(span_id), is_remote=True,
)

return set_span_in_context(trace.DefaultSpan(span_context), context)

def inject(
self,
set_in_carrier: Setter[HTTPTextFormatT],
carrier: HTTPTextFormatT,
context: typing.Optional[Context] = None,
) -> None:
span = get_span_from_context(context=context)
set_in_carrier(
carrier, self.TRACE_ID_KEY, format_trace_id(span.context.trace_id),
)
set_in_carrier(
carrier, self.PARENT_ID_KEY, format_span_id(span.context.span_id)
)


def format_trace_id(trace_id: int) -> str:
"""Format the trace id according to b3 specification."""
return str(trace_id & 0xFFFFFFFFFFFFFFFF)


def format_span_id(span_id: int) -> str:
"""Format the span id according to b3 specification."""
return str(span_id)


def extract_first_element(
items: typing.Iterable[HTTPTextFormatT],
) -> typing.Optional[HTTPTextFormatT]:
if items is None:
return None
return next(iter(items), None)
115 changes: 115 additions & 0 deletions ext/opentelemetry-ext-datadog/tests/test_datadog_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
import unittest

from opentelemetry import trace as trace_api
from opentelemetry.ext.datadog import propagator
from opentelemetry.sdk import trace as trace
from opentelemetry.trace.propagation import (
get_span_from_context,
set_span_in_context,
)

FORMAT = propagator.DatadogFormat()


def get_as_list(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []


class TestDatadogFormat(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.serialized_trace_id = propagator.format_trace_id(
trace.generate_trace_id()
)
cls.serialized_parent_id = propagator.format_span_id(
trace.generate_span_id()
)

def test_malformed_headers(self):
"""Test with no Datadog headers"""
malformed_trace_id_key = FORMAT.TRACE_ID_KEY + "-x"
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
context = get_span_from_context(
FORMAT.extract(
get_as_list,
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
},
)
).get_context()

self.assertNotEqual(context.trace_id, int(self.serialized_trace_id))
self.assertNotEqual(context.span_id, int(self.serialized_parent_id))
self.assertFalse(context.is_remote)

def test_missing_trace_id(self):
"""If a trace id is missing, populate an invalid trace id."""
carrier = {
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
}

ctx = FORMAT.extract(get_as_list, carrier)
span_context = get_span_from_context(ctx).get_context()
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)

def test_missing_parent_id(self):
"""If a parent id is missing, populate an invalid trace id."""
carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
}

ctx = FORMAT.extract(get_as_list, carrier)
span_context = get_span_from_context(ctx).get_context()
self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)

def test_context_propagation(self):
"""Test the propagation of Datadog headers."""
parent_context = get_span_from_context(
FORMAT.extract(
get_as_list,
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
},
)
).get_context()

self.assertEqual(parent_context.trace_id, int(self.serialized_trace_id))
self.assertEqual(parent_context.span_id, int(self.serialized_parent_id))
self.assertTrue(parent_context.is_remote)

child = trace.Span(
"child",
trace_api.SpanContext(
parent_context.trace_id,
trace.generate_span_id(),
is_remote=False,
trace_flags=parent_context.trace_flags,
trace_state=parent_context.trace_state,
),
parent=parent_context,
)

child_carrier = {}
child_context = set_span_in_context(child)
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)

self.assertEqual(child_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id)
self.assertEqual(child_carrier[FORMAT.PARENT_ID_KEY], str(child.context.span_id))

0 comments on commit f0b691e

Please sign in to comment.