From f3b7b96b1ddaf7194253e3233b9124c73a19840a Mon Sep 17 00:00:00 2001 From: harrisonchu Date: Wed, 2 Oct 2024 19:53:07 -0400 Subject: [PATCH] feat(anthropic): streaming support (#990) --- .../pyproject.toml | 1 + .../instrumentation/anthropic/_stream.py | 483 ++++++++++++++++++ .../instrumentation/anthropic/_utils.py | 79 +++ .../instrumentation/anthropic/_with_span.py | 92 ++++ .../instrumentation/anthropic/_wrappers.py | 146 ++++-- ...mentation_async_completions_streaming.yaml | 27 + ...trumentation_async_messages_streaming.yaml | 78 +++ ...instrumentation_completions_streaming.yaml | 27 + ...ic_instrumentation_messages_streaming.yaml | 77 +++ ...ation_multiple_tool_calling_streaming.yaml | 230 +++++++++ .../anthropic/test_instrumentor.py | 323 ++++++++++++ 11 files changed, 1512 insertions(+), 51 deletions(-) create mode 100644 python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_stream.py create mode 100644 python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_utils.py create mode 100644 python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_with_span.py create mode 100644 python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_async_completions_streaming.yaml create mode 100644 python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_async_messages_streaming.yaml create mode 100644 python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_completions_streaming.yaml create mode 100644 python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_messages_streaming.yaml create mode 100644 python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_multiple_tool_calling_streaming.yaml diff --git a/python/instrumentation/openinference-instrumentation-anthropic/pyproject.toml b/python/instrumentation/openinference-instrumentation-anthropic/pyproject.toml index 7f6308031..f8a926148 100644 --- a/python/instrumentation/openinference-instrumentation-anthropic/pyproject.toml +++ b/python/instrumentation/openinference-instrumentation-anthropic/pyproject.toml @@ -41,6 +41,7 @@ instruments = [ test = [ "anthropic >= 0.25.0", "opentelemetry-sdk", + "pytest-asyncio", "pytest.recording" ] diff --git a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_stream.py b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_stream.py new file mode 100644 index 000000000..a05254853 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_stream.py @@ -0,0 +1,483 @@ +from collections import defaultdict +from copy import deepcopy +from typing import ( + Any, + AsyncIterator, + Callable, + DefaultDict, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from opentelemetry import trace as trace_api +from opentelemetry.util.types import AttributeValue +from wrapt import ObjectProxy + 
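+# Streaming types from the Anthropic SDK; the stream wrappers below use these
+# to discriminate raw stream events while accumulating a response.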
+from anthropic import Stream +from anthropic.types import ( + Completion, + RawMessageStreamEvent, +) +from anthropic.types.raw_content_block_delta_event import RawContentBlockDeltaEvent +from anthropic.types.raw_content_block_start_event import RawContentBlockStartEvent +from anthropic.types.raw_message_delta_event import RawMessageDeltaEvent +from anthropic.types.raw_message_start_event import RawMessageStartEvent +from anthropic.types.text_block import TextBlock +from anthropic.types.tool_use_block import ToolUseBlock +from openinference.instrumentation import safe_json_dumps +from openinference.instrumentation.anthropic._utils import ( + _as_output_attributes, + _finish_tracing, + _ValueAndType, +) +from openinference.instrumentation.anthropic._with_span import _WithSpan +from openinference.semconv.trace import ( + MessageAttributes, + OpenInferenceMimeTypeValues, + SpanAttributes, + ToolCallAttributes, +) + + +class _Stream(ObjectProxy): # type: ignore + __slots__ = ( + "_response_accumulator", + "_with_span", + "_is_finished", + ) + + def __init__( + self, + stream: Stream[Completion], + with_span: _WithSpan, + ) -> None: + super().__init__(stream) + self._response_accumulator = _ResponseAccumulator() + self._with_span = with_span + + def __iter__(self) -> Iterator[Completion]: + try: + for item in self.__wrapped__: + self._response_accumulator.process_chunk(item) + yield item + except Exception as exception: + status = trace_api.Status( + status_code=trace_api.StatusCode.ERROR, + description=f"{type(exception).__name__}: {exception}", + ) + self._with_span.record_exception(exception) + self._finish_tracing(status=status) + raise + # completed without exception + status = trace_api.Status( + status_code=trace_api.StatusCode.OK, + ) + self._finish_tracing(status=status) + + async def __aiter__(self) -> AsyncIterator[Completion]: + try: + async for item in self.__wrapped__: + self._response_accumulator.process_chunk(item) + yield item + except Exception as exception: + status = trace_api.Status( + status_code=trace_api.StatusCode.ERROR, + description=f"{type(exception).__name__}: {exception}", + ) + self._with_span.record_exception(exception) + self._finish_tracing(status=status) + raise + # completed without exception + status = trace_api.Status( + status_code=trace_api.StatusCode.OK, + ) + self._finish_tracing(status=status) + + def _finish_tracing( + self, + status: Optional[trace_api.Status] = None, + ) -> None: + _finish_tracing( + with_span=self._with_span, + has_attributes=_ResponseExtractor(response_accumulator=self._response_accumulator), + status=status, + ) + + +class _ResponseAccumulator: + __slots__ = ( + "_is_null", + "_values", + ) + + def __init__(self) -> None: + self._is_null = True + self._values = _ValuesAccumulator( + completion=_StringAccumulator(), + stop=_SimpleStringReplace(), + stop_reason=_SimpleStringReplace(), + ) + + def process_chunk(self, chunk: Completion) -> None: + self._is_null = False + values = chunk.model_dump(exclude_unset=True, warnings=False) + self._values += values + + def _result(self) -> Optional[Dict[str, Any]]: + if self._is_null: + return None + return dict(self._values) + + +class _ResponseExtractor: + __slots__ = ("_response_accumulator",) + + def __init__( + self, + response_accumulator: _ResponseAccumulator, + ) -> None: + self._response_accumulator = response_accumulator + + def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if not (result := self._response_accumulator._result()): + return + json_string = 
safe_json_dumps(result)
+        yield from _as_output_attributes(
+            _ValueAndType(json_string, OpenInferenceMimeTypeValues.JSON)
+        )
+
+    def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]:
+        if not (result := self._response_accumulator._result()):
+            return
+        if completion := result.get("completion", ""):
+            yield SpanAttributes.LLM_OUTPUT_MESSAGES, completion
+
+
+class _MessagesStream(ObjectProxy):  # type: ignore
+    __slots__ = (
+        "_response_accumulator",
+        "_with_span",
+        "_is_finished",
+    )
+
+    def __init__(
+        self,
+        stream: Stream[RawMessageStreamEvent],
+        with_span: _WithSpan,
+    ) -> None:
+        super().__init__(stream)
+        self._response_accumulator = _MessageResponseAccumulator()
+        self._with_span = with_span
+
+    def __iter__(self) -> Iterator[RawMessageStreamEvent]:
+        try:
+            for item in self.__wrapped__:
+                self._response_accumulator.process_chunk(item)
+                yield item
+        except Exception as exception:
+            status = trace_api.Status(
+                status_code=trace_api.StatusCode.ERROR,
+                description=f"{type(exception).__name__}: {exception}",
+            )
+            self._with_span.record_exception(exception)
+            self._finish_tracing(status=status)
+            raise
+        # completed without exception
+        status = trace_api.Status(
+            status_code=trace_api.StatusCode.OK,
+        )
+        self._finish_tracing(status=status)
+
+    async def __aiter__(self) -> AsyncIterator[RawMessageStreamEvent]:
+        try:
+            async for item in self.__wrapped__:
+                self._response_accumulator.process_chunk(item)
+                yield item
+        except Exception as exception:
+            status = trace_api.Status(
+                status_code=trace_api.StatusCode.ERROR,
+                description=f"{type(exception).__name__}: {exception}",
+            )
+            self._with_span.record_exception(exception)
+            self._finish_tracing(status=status)
+            raise
+        # completed without exception
+        status = trace_api.Status(
+            status_code=trace_api.StatusCode.OK,
+        )
+        self._finish_tracing(status=status)
+
+    def _finish_tracing(
+        self,
+        status: Optional[trace_api.Status] = None,
+    ) -> None:
+        _finish_tracing(
+            with_span=self._with_span,
+            has_attributes=_MessageResponseExtractor(
+                response_accumulator=self._response_accumulator
+            ),
+            status=status,
+        )
+
+
+class _MessageResponseAccumulator:
+    __slots__ = (
+        "_is_null",
+        "_values",
+        "_current_message_idx",
+        "_current_content_block_type",
+    )
+
+    def __init__(self) -> None:
+        self._is_null = True
+        self._current_message_idx = -1
+        self._current_content_block_type: Union[TextBlock, ToolUseBlock, None] = None
+        self._values = _ValuesAccumulator(
+            messages=_IndexedAccumulator(
+                lambda: _ValuesAccumulator(
+                    role=_SimpleStringReplace(),
+                    content=_IndexedAccumulator(
+                        lambda: _ValuesAccumulator(
+                            type=_SimpleStringReplace(),
+                            text=_StringAccumulator(),
+                            tool_name=_SimpleStringReplace(),
+                            tool_input=_StringAccumulator(),
+                        ),
+                    ),
+                    stop_reason=_SimpleStringReplace(),
+                    input_tokens=_SimpleStringReplace(),
+                    output_tokens=_SimpleStringReplace(),
+                ),
+            ),
+        )
+
+    def process_chunk(self, chunk: RawMessageStreamEvent) -> None:
+        self._is_null = False
+        if isinstance(chunk, RawMessageStartEvent):
+            self._current_message_idx += 1
+            value = {
+                "messages": {
+                    "index": str(self._current_message_idx),
+                    "role": chunk.message.role,
+                    "input_tokens": str(chunk.message.usage.input_tokens),
+                }
+            }
+            self._values += value
+        elif isinstance(chunk, RawContentBlockStartEvent):
+            self._current_content_block_type = chunk.content_block
+        elif isinstance(chunk, RawContentBlockDeltaEvent):
+            if isinstance(self._current_content_block_type, TextBlock):
+                value = {
+                    "messages": {
+                        "index": 
str(self._current_message_idx), + "content": { + "index": chunk.index, + "type": self._current_content_block_type.type, + "text": chunk.delta.text, # type: ignore + }, + } + } + self._values += value + elif isinstance(self._current_content_block_type, ToolUseBlock): + value = { + "messages": { + "index": str(self._current_message_idx), + "content": { + "index": chunk.index, + "type": self._current_content_block_type.type, + "tool_name": self._current_content_block_type.name, + "tool_input": chunk.delta.partial_json, # type: ignore + }, + } + } + self._values += value + elif isinstance(chunk, RawMessageDeltaEvent): + value = { + "messages": { + "index": str(self._current_message_idx), + "stop_reason": chunk.delta.stop_reason, + "output_tokens": str(chunk.usage.output_tokens), + } + } + self._values += value + + def _result(self) -> Optional[Dict[str, Any]]: + if self._is_null: + return None + return dict(self._values) + + +class _MessageResponseExtractor: + __slots__ = ("_response_accumulator",) + + def __init__( + self, + response_accumulator: _MessageResponseAccumulator, + ) -> None: + self._response_accumulator = response_accumulator + + def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if not (result := self._response_accumulator._result()): + return + json_string = safe_json_dumps(result) + yield from _as_output_attributes( + _ValueAndType(json_string, OpenInferenceMimeTypeValues.JSON) + ) + + def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if not (result := self._response_accumulator._result()): + return + messages = result.get("messages", []) + idx = 0 + total_completion_token_count = 0 + total_prompt_token_count = 0 + # TODO(harrison): figure out if we should always assume messages is 1. + # The current non streaming implementation assumes the same + for message in messages: + if role := message.get("role"): + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_ROLE}", + role, + ) + if output_tokens := message.get("output_tokens"): + total_completion_token_count += int(output_tokens) + if input_tokens := message.get("input_tokens"): + total_prompt_token_count += int(input_tokens) + + # TODO(harrison): figure out if we should always assume the first message + # will always be a message output generally this block feels really + # brittle to imitate the current non streaming implementation. + tool_idx = 0 + for content in message.get("content", []): + # this is the current assumption of the non streaming implementation. 
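+            # Text blocks are emitted as MESSAGE_CONTENT; tool_use blocks are
+            # emitted as MESSAGE_TOOL_CALLS entries (tool name plus the
+            # accumulated partial-JSON arguments).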
+ if (content_type := content.get("type")) == "text": + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_CONTENT}", + content.get("text", ""), + ) + elif content_type == "tool_use": + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_idx}.{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}", + content.get("tool_name", ""), + ) + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}.{MessageAttributes.MESSAGE_TOOL_CALLS}.{tool_idx}.{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", + content.get("tool_input", "{}"), + ) + tool_idx += 1 + idx += 1 + yield SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, total_completion_token_count + yield SpanAttributes.LLM_TOKEN_COUNT_PROMPT, total_prompt_token_count + yield ( + SpanAttributes.LLM_TOKEN_COUNT_TOTAL, + total_completion_token_count + total_prompt_token_count, + ) + + +class _ValuesAccumulator: + __slots__ = ("_values",) + + def __init__(self, **values: Any) -> None: + self._values: Dict[str, Any] = values + + def __iter__(self) -> Iterator[Tuple[str, Any]]: + for key, value in self._values.items(): + if value is None: + continue + if isinstance(value, _ValuesAccumulator): + if dict_value := dict(value): + yield key, dict_value + elif isinstance(value, _SimpleStringReplace): + if str_value := str(value): + yield key, str_value + elif isinstance(value, _StringAccumulator): + if str_value := str(value): + yield key, str_value + else: + yield key, value + + def __iadd__(self, values: Optional[Mapping[str, Any]]) -> "_ValuesAccumulator": + if not values: + return self + for key in self._values.keys(): + if (value := values.get(key)) is None: + continue + self_value = self._values[key] + if isinstance(self_value, _ValuesAccumulator): + if isinstance(value, Mapping): + self_value += value + elif isinstance(self_value, _StringAccumulator): + if isinstance(value, str): + self_value += value + elif isinstance(self_value, _SimpleStringReplace): + if isinstance(value, str): + self_value += value + elif isinstance(self_value, _IndexedAccumulator): + self_value += value + elif isinstance(self_value, List) and isinstance(value, Iterable): + self_value.extend(value) + else: + self._values[key] = value # replacement + for key in values.keys(): + if key in self._values or (value := values[key]) is None: + continue + value = deepcopy(value) + if isinstance(value, Mapping): + value = _ValuesAccumulator(**value) + self._values[key] = value # new entry + return self + + +class _StringAccumulator: + __slots__ = ("_fragments",) + + def __init__(self) -> None: + self._fragments: List[str] = [] + + def __str__(self) -> str: + return "".join(self._fragments) + + def __iadd__(self, value: Optional[str]) -> "_StringAccumulator": + if not value: + return self + self._fragments.append(value) + return self + + +class _IndexedAccumulator: + __slots__ = ("_indexed",) + + def __init__(self, factory: Callable[[], _ValuesAccumulator]) -> None: + self._indexed: DefaultDict[int, _ValuesAccumulator] = defaultdict(factory) + + def __iter__(self) -> Iterator[Dict[str, Any]]: + for _, values in sorted(self._indexed.items()): + yield dict(values) + + def __iadd__(self, values: Optional[Mapping[str, Any]]) -> "_IndexedAccumulator": + if not values or not hasattr(values, "get") or (index := values.get("index")) is None: + return self + self._indexed[index] += values + return self + + +class _SimpleStringReplace: + __slots__ = ("_str_val",) + + def __init__(self) -> None: + self._str_val: str = "" + + def 
__str__(self) -> str:
+        return self._str_val
+
+    def __iadd__(self, value: Optional[str]) -> "_SimpleStringReplace":
+        if not value:
+            return self
+        self._str_val = value
+        return self
diff --git a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_utils.py b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_utils.py
new file mode 100644
index 000000000..2a78b73a5
--- /dev/null
+++ b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_utils.py
@@ -0,0 +1,79 @@
+import logging
+from typing import Any, Iterator, NamedTuple, Optional, Protocol, Tuple
+
+from opentelemetry import trace as trace_api
+from opentelemetry.util.types import Attributes, AttributeValue
+
+from openinference.instrumentation import safe_json_dumps
+from openinference.instrumentation.anthropic._with_span import _WithSpan
+from openinference.semconv.trace import OpenInferenceMimeTypeValues, SpanAttributes
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+
+class _ValueAndType(NamedTuple):
+    value: str
+    type: OpenInferenceMimeTypeValues
+
+
+class _HasAttributes(Protocol):
+    def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: ...
+
+    def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: ...
+
+
+def _finish_tracing(
+    with_span: _WithSpan,
+    has_attributes: _HasAttributes,
+    status: Optional[trace_api.Status] = None,
+) -> None:
+    try:
+        attributes: Attributes = dict(has_attributes.get_attributes())
+    except Exception:
+        logger.exception("Failed to get attributes")
+        attributes = None
+    try:
+        extra_attributes: Attributes = dict(has_attributes.get_extra_attributes())
+    except Exception:
+        logger.exception("Failed to get extra attributes")
+        extra_attributes = None
+    try:
+        with_span.finish_tracing(
+            status=status,
+            attributes=attributes,
+            extra_attributes=extra_attributes,
+        )
+    except Exception:
+        logger.exception("Failed to finish tracing")
+        raise
+
+
+def _io_value_and_type(obj: Any) -> _ValueAndType:
+    try:
+        return _ValueAndType(safe_json_dumps(obj), OpenInferenceMimeTypeValues.JSON)
+    except Exception:
+        logger.exception("Failed to get input attributes from request parameters.")
+    return _ValueAndType(str(obj), OpenInferenceMimeTypeValues.TEXT)
+
+
+def _as_input_attributes(
+    value_and_type: Optional[_ValueAndType],
+) -> Iterator[Tuple[str, AttributeValue]]:
+    if not value_and_type:
+        return
+    yield SpanAttributes.INPUT_VALUE, value_and_type.value
+    # It's assumed to be TEXT by default, so we can skip to save one attribute.
+    if value_and_type.type is not OpenInferenceMimeTypeValues.TEXT:
+        yield SpanAttributes.INPUT_MIME_TYPE, value_and_type.type.value
+
+
+def _as_output_attributes(
+    value_and_type: Optional[_ValueAndType],
+) -> Iterator[Tuple[str, AttributeValue]]:
+    if not value_and_type:
+        return
+    yield SpanAttributes.OUTPUT_VALUE, value_and_type.value
+    # It's assumed to be TEXT by default, so we can skip to save one attribute. 
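+    # The streaming response extractors pass JSON here, so they also emit the mime type.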
+ if value_and_type.type is not OpenInferenceMimeTypeValues.TEXT: + yield SpanAttributes.OUTPUT_MIME_TYPE, value_and_type.type.value diff --git a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_with_span.py b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_with_span.py new file mode 100644 index 000000000..fb2c73ccc --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_with_span.py @@ -0,0 +1,92 @@ +import logging +from typing import Dict, Optional, Union + +from opentelemetry import trace as trace_api +from opentelemetry.util.types import ( + Attributes, + AttributeValue, +) + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class _WithSpan: + __slots__ = ( + "_span", + "_is_finished", + ) + + def __init__( + self, + span: trace_api.Span, + ) -> None: + self._span = span + try: + self._is_finished = not self._span.is_recording() + except Exception: + logger.exception("Failed to check if span is recording") + self._is_finished = True + + @property + def is_finished(self) -> bool: + return self._is_finished + + def set_attributes(self, attributes: Dict[str, AttributeValue]) -> None: + self._span.set_attributes(attributes) + + def record_exception(self, exception: Exception) -> None: + if self._is_finished: + return + try: + self._span.record_exception(exception) + except Exception: + logger.exception("Failed to record exception on span") + + def set_status(self, status: Union[trace_api.Status, trace_api.StatusCode]) -> None: + if self._is_finished: + return + try: + self._span.set_status(status=status) + except Exception: + logger.exception("Failed to set status on span") + + def add_event(self, name: str) -> None: + if self._is_finished: + return + try: + self._span.add_event(name) + except Exception: + logger.exception("Failed to add event to span") + + def finish_tracing( + self, + status: Optional[trace_api.Status] = None, + attributes: Attributes = None, + extra_attributes: Attributes = None, + ) -> None: + if self._is_finished: + return + for mapping in ( + attributes, + extra_attributes, + ): + if not mapping: + continue + for key, value in mapping.items(): + if value is None: + continue + try: + self._span.set_attribute(key, value) + except Exception: + logger.exception("Failed to set attribute on span") + if status is not None: + try: + self._span.set_status(status=status) + except Exception: + logger.exception("Failed to set status code on span") + try: + self._span.end() + except Exception: + logger.exception("Failed to end span") + self._is_finished = True diff --git a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py index f378be5e4..f13b95490 100644 --- a/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py +++ b/python/instrumentation/openinference-instrumentation-anthropic/src/openinference/instrumentation/anthropic/_wrappers.py @@ -1,10 +1,17 @@ from abc import ABC +from contextlib import contextmanager from typing import Any, Callable, Dict, Iterator, List, Mapping, Tuple import opentelemetry.context as context_api from opentelemetry import trace as trace_api +from opentelemetry.trace import 
INVALID_SPAN
 
 from openinference.instrumentation import get_attributes_from_context, safe_json_dumps
+from openinference.instrumentation.anthropic._stream import (
+    _MessagesStream,
+    _Stream,
+)
+from openinference.instrumentation.anthropic._with_span import _WithSpan
 from openinference.semconv.trace import (
     DocumentAttributes,
     EmbeddingAttributes,
@@ -25,6 +32,30 @@ def __init__(self, tracer: trace_api.Tracer, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self._tracer = tracer
 
+    @contextmanager
+    def _start_as_current_span(
+        self,
+        span_name: str,
+    ) -> Iterator[_WithSpan]:
+        # Because OTEL has a default limit of 128 attributes, we split our attributes into
+        # two tiers, where the addition of "extra_attributes" is deferred until the end
+        # and only after the "attributes" are added.
+        try:
+            span = self._tracer.start_span(
+                name=span_name, record_exception=False, set_status_on_exception=False
+            )
+        except Exception:
+            span = INVALID_SPAN
+        with trace_api.use_span(
+            span,
+            end_on_exit=False,
+            record_exception=False,
+            set_status_on_exception=False,
+        ) as span:
+            yield _WithSpan(
+                span=span,
+            )
+
 
 class _CompletionsWrapper(_WithTracer):
     """
@@ -32,6 +63,8 @@ class _CompletionsWrapper(_WithTracer):
     Captures all calls to the pipeline
     """
 
+    __slots__ = "_response_accumulator"
+
     def __call__(
         self,
         wrapped: Callable[..., Any],
@@ -47,10 +80,8 @@ def __call__(
         llm_invocation_parameters = _get_invocation_parameters(arguments)
 
         span_name = "Completions"
-        with self._tracer.start_as_current_span(
+        with self._start_as_current_span(
             span_name,
-            record_exception=False,
-            set_status_on_exception=False,
         ) as span:
             span.set_attributes(dict(get_attributes_from_context()))
 
@@ -71,14 +102,19 @@ def __call__(
                 span.record_exception(exception)
                 raise
-            span.set_status(trace_api.StatusCode.OK)
-            span.set_attributes(
-                {
-                    OUTPUT_VALUE: response.model_dump_json(),
-                    OUTPUT_MIME_TYPE: JSON,
-                }
-            )
-
-            return response
+            streaming = kwargs.get("stream", False)
+            if streaming:
+                return _Stream(response, span)
+            else:
+                span.set_status(trace_api.StatusCode.OK)
+                span.set_attributes(
+                    {
+                        OUTPUT_VALUE: response.model_dump_json(),
+                        OUTPUT_MIME_TYPE: JSON,
+                    }
+                )
+                span.finish_tracing()
+                return response
 
 
 class _AsyncCompletionsWrapper(_WithTracer):
@@ -102,10 +138,8 @@ async def __call__(
         invocation_parameters = _get_invocation_parameters(arguments)
 
         span_name = "AsyncCompletions"
-        with self._tracer.start_as_current_span(
+        with self._start_as_current_span(
             span_name,
-            record_exception=False,
-            set_status_on_exception=False,
        ) as span:
             span.set_attributes(dict(get_attributes_from_context()))
 
@@ -126,13 +160,19 @@ async def __call__(
                 span.record_exception(exception)
                 raise
-            span.set_status(trace_api.StatusCode.OK)
-            span.set_attributes(
-                {
-                    OUTPUT_VALUE: response.to_json(indent=None),
-                    OUTPUT_MIME_TYPE: JSON,
-                }
-            )
-            return response
+            streaming = kwargs.get("stream", False)
+            if streaming:
+                return _Stream(response, span)
+            else:
+                span.set_status(trace_api.StatusCode.OK)
+                span.set_attributes(
+                    {
+                        OUTPUT_VALUE: response.to_json(indent=None),
+                        OUTPUT_MIME_TYPE: JSON,
+                    }
+                )
+                span.finish_tracing()
+                return response
 
 
 class _MessagesWrapper(_WithTracer):
@@ -156,10 +196,8 @@ def __call__(
         invocation_parameters = _get_invocation_parameters(arguments)
 
         span_name = "Messages"
-        with self._tracer.start_as_current_span(
+        with self._start_as_current_span(
             span_name,
-            record_exception=False,
-            set_status_on_exception=False,
         ) as span: 
span.set_attributes(dict(get_attributes_from_context())) @@ -179,18 +217,22 @@ def __call__( span.set_status(trace_api.Status(trace_api.StatusCode.ERROR, str(exception))) span.record_exception(exception) raise - span.set_status(trace_api.StatusCode.OK) - span.set_attributes( - { - **dict(_get_output_messages(response)), - LLM_TOKEN_COUNT_PROMPT: response.usage.input_tokens, - LLM_TOKEN_COUNT_COMPLETION: response.usage.output_tokens, - OUTPUT_VALUE: response.model_dump_json(), - OUTPUT_MIME_TYPE: JSON, - } - ) - - return response + streaming = kwargs.get("stream", False) + if streaming: + return _MessagesStream(response, span) + else: + span.set_status(trace_api.StatusCode.OK) + span.set_attributes( + { + **dict(_get_output_messages(response)), + LLM_TOKEN_COUNT_PROMPT: response.usage.input_tokens, + LLM_TOKEN_COUNT_COMPLETION: response.usage.output_tokens, + OUTPUT_VALUE: response.model_dump_json(), + OUTPUT_MIME_TYPE: JSON, + } + ) + span.finish_tracing() + return response class _AsyncMessagesWrapper(_WithTracer): @@ -214,10 +256,8 @@ async def __call__( invocation_parameters = _get_invocation_parameters(arguments) span_name = "AsyncMessages" - with self._tracer.start_as_current_span( + with self._start_as_current_span( span_name, - record_exception=False, - set_status_on_exception=False, ) as span: span.set_attributes(dict(get_attributes_from_context())) @@ -237,18 +277,22 @@ async def __call__( span.set_status(trace_api.Status(trace_api.StatusCode.ERROR, str(exception))) span.record_exception(exception) raise - span.set_status(trace_api.StatusCode.OK) - span.set_attributes( - { - **dict(_get_output_messages(response)), - LLM_TOKEN_COUNT_PROMPT: response.usage.input_tokens, - LLM_TOKEN_COUNT_COMPLETION: response.usage.output_tokens, - OUTPUT_VALUE: response.model_dump_json(), - OUTPUT_MIME_TYPE: JSON, - } - ) - - return response + streaming = kwargs.get("stream", False) + if streaming: + return _MessagesStream(response, span) + else: + span.set_status(trace_api.StatusCode.OK) + span.set_attributes( + { + **dict(_get_output_messages(response)), + LLM_TOKEN_COUNT_PROMPT: response.usage.input_tokens, + LLM_TOKEN_COUNT_COMPLETION: response.usage.output_tokens, + OUTPUT_VALUE: response.model_dump_json(), + OUTPUT_MIME_TYPE: JSON, + } + ) + span.finish_tracing() + return response def _get_llm_model(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]: diff --git a/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_async_completions_streaming.yaml b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_async_completions_streaming.yaml new file mode 100644 index 000000000..de02d5d0d --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_async_completions_streaming.yaml @@ -0,0 +1,27 @@ +interactions: +- request: + body: '{"max_tokens_to_sample": 1000, "model": "claude-2.1", "prompt": "\n\nHuman: + why is the sky blue? respond in five words or less. 
\n\nAssistant:", "stream": + true}' + headers: {} + method: POST + uri: https://api.anthropic.com/v1/complete + response: + body: + string: "event: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\",\"completion\":\" + Light\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\",\"completion\":\" + sc\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\" + \ }\r\n\r\nevent: ping\r\ndata: {\"type\": \"ping\"}\r\n\r\nevent: + completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\",\"completion\":\"at\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\",\"completion\":\"ters\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\",\"completion\":\" + blue\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\",\"completion\":\".\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\",\"completion\":\"\",\"stop_reason\":\"stop_sequence\",\"model\":\"claude-2.1\",\"stop\":\"\\n\\nHuman:\",\"log_id\":\"compl_01Ho8r6LNPQ9EVEAh3vpiUnQ\" + \ }\r\n\r\n" + headers: {} + status: + code: 200 + message: OK +version: 1 diff --git a/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_async_messages_streaming.yaml b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_async_messages_streaming.yaml new file mode 100644 index 000000000..6e7d45147 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_async_messages_streaming.yaml @@ -0,0 +1,78 @@ +interactions: +- request: + body: '{"max_tokens": 1024, "messages": [{"role": "user", "content": "Why is the + sky blue? 
Answer in 5 words or less"}], "model": "claude-2.1", "stream": true}' + headers: {} + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: 'event: message_start + + data: {"type":"message_start","message":{"id":"msg_014xFxzXoZb9xNu5KEuriZ8N","type":"message","role":"assistant","model":"claude-2.1","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":21,"output_tokens":1}} } + + + event: content_block_start + + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + + + event: ping + + data: {"type": "ping"} + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Light"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + sc"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"at"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ters"} + } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + blue"}} + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."} } + + + event: content_block_stop + + data: {"type":"content_block_stop","index":0 } + + + event: message_delta + + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":10} } + + + event: message_stop + + data: {"type":"message_stop" } + + + ' + headers: {} + status: + code: 200 + message: OK +version: 1 diff --git a/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_completions_streaming.yaml b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_completions_streaming.yaml new file mode 100644 index 000000000..32d00be01 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_completions_streaming.yaml @@ -0,0 +1,27 @@ +interactions: +- request: + body: '{"max_tokens_to_sample": 1000, "model": "claude-2.1", "prompt": "\n\nHuman: + why is the sky blue? respond in five words or less. 
\n\nAssistant:", "stream": + true}' + headers: {} + method: POST + uri: https://api.anthropic.com/v1/complete + response: + body: + string: "event: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\",\"completion\":\" + Light\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\",\"completion\":\" + sc\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\" + \ }\r\n\r\nevent: ping\r\ndata: {\"type\": \"ping\"}\r\n\r\nevent: completion\r\ndata: + {\"type\":\"completion\",\"id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\",\"completion\":\"at\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\",\"completion\":\"ters\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\",\"completion\":\" + blue\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\"}\r\n\r\nevent: + completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\",\"completion\":\".\",\"stop_reason\":null,\"model\":\"claude-2.1\",\"stop\":null,\"log_id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\" + \ }\r\n\r\nevent: completion\r\ndata: {\"type\":\"completion\",\"id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\",\"completion\":\"\",\"stop_reason\":\"stop_sequence\",\"model\":\"claude-2.1\",\"stop\":\"\\n\\nHuman:\",\"log_id\":\"compl_015dfgyiT7JLszAiMbGMtgeG\" + \ }\r\n\r\n" + headers: {} + status: + code: 200 + message: OK +version: 1 diff --git a/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_messages_streaming.yaml b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_messages_streaming.yaml new file mode 100644 index 000000000..c3b6c7236 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_messages_streaming.yaml @@ -0,0 +1,77 @@ +interactions: +- request: + body: '{"max_tokens": 1024, "messages": [{"role": "user", "content": "Why is the + sky blue? 
Answer in 5 words or less"}], "model": "claude-2.1", "stream": true}' + headers: {} + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: 'event: message_start + + data: {"type":"message_start","message":{"id":"msg_011i9EJybqHR6a7pmHLohGnu","type":"message","role":"assistant","model":"claude-2.1","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":21,"output_tokens":1}} } + + + event: content_block_start + + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + + event: ping + + data: {"type": "ping"} + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Light"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + sc"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"at"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ters"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + blue"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."} } + + + event: content_block_stop + + data: {"type":"content_block_stop","index":0 } + + + event: message_delta + + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":10} } + + + event: message_stop + + data: {"type":"message_stop" } + + + ' + headers: {} + status: + code: 200 + message: OK +version: 1 diff --git a/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_multiple_tool_calling_streaming.yaml b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_multiple_tool_calling_streaming.yaml new file mode 100644 index 000000000..05625e3a0 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/cassettes/test_instrumentor/test_anthropic_instrumentation_multiple_tool_calling_streaming.yaml @@ -0,0 +1,230 @@ +interactions: +- request: + body: '{"max_tokens": 1024, "messages": [{"role": "user", "content": "What is + the weather like right now in New York? Also what time is it there? Use necessary + tools simultaneously."}], "model": "claude-3-5-sonnet-20240620", "stream": true, + "tools": [{"name": "get_weather", "description": "Get the current weather in + a given location", "input_schema": {"type": "object", "properties": {"location": + {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": + "The unit of temperature, either ''celsius'' or ''fahrenheit''"}}, "required": + ["location"]}}, {"name": "get_time", "description": "Get the current time in + a given time zone", "input_schema": {"type": "object", "properties": {"timezone": + {"type": "string", "description": "The IANA time zone name, e.g. 
America/Los_Angeles"}}, + "required": ["timezone"]}}]}' + headers: {} + method: POST + uri: https://api.anthropic.com/v1/messages + response: + body: + string: 'event: message_start + + data: {"type":"message_start","message":{"id":"msg_0133FqiHvU9EMrLkhkvgszYM","type":"message","role":"assistant","model":"claude-3-5-sonnet-20240620","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":518,"output_tokens":1}}} + + + event: content_block_start + + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + + event: ping + + data: {"type": "ping"} + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Certainly"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"! + I"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"''ll + use the"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + available"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + tools to get the current"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + weather in"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + New York and the current"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + time there. Let"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + me fetch"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + that"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + information for you using"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + the necessary"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" + tools simultaneously"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"."} } + + + event: content_block_stop + + data: {"type":"content_block_stop","index":0 } + + + event: content_block_start + + data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01Fbou3N2oQx1Y6yxThmYThV","name":"get_weather","input":{}} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""}} + + + event: content_block_delta + + data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"loc"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"ati"} + } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"on\": + \"Ne"} } + + + event: content_block_delta + + data: 
{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"w + Yo"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"rk, + NY"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"\""} + } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":", + \"unit\""} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":": + \"celsius\"}"} } + + + event: content_block_stop + + data: {"type":"content_block_stop","index":1 } + + + event: content_block_start + + data: {"type":"content_block_start","index":2,"content_block":{"type":"tool_use","id":"toolu_011M5HcDRLheEQBQy5QCSXw6","name":"get_time","input":{}} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":2,"delta":{"type":"input_json_delta","partial_json":""} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":2,"delta":{"type":"input_json_delta","partial_json":"{\"timezone\""} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":2,"delta":{"type":"input_json_delta","partial_json":": + \"America"} } + + + event: content_block_delta + + data: {"type":"content_block_delta","index":2,"delta":{"type":"input_json_delta","partial_json":"/New_York\"}"}} + + + event: content_block_stop + + data: {"type":"content_block_stop","index":2 } + + + event: message_delta + + data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":149} } + + + event: message_stop + + data: {"type":"message_stop" } + + + ' + headers: {} + status: + code: 200 + message: OK +version: 1 diff --git a/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/test_instrumentor.py index c1fe5ffc2..178cbf1e9 100644 --- a/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/test_instrumentor.py +++ b/python/instrumentation/openinference-instrumentation-anthropic/tests/openinference/anthropic/test_instrumentor.py @@ -113,6 +113,103 @@ def setup_anthropic_instrumentation( AnthropicInstrumentor().uninstrument() +@pytest.mark.vcr( + decode_compressed_response=True, + before_record_request=remove_all_vcr_request_headers, + before_record_response=remove_all_vcr_response_headers, +) +def test_anthropic_instrumentation_completions_streaming( + tracer_provider: TracerProvider, + in_memory_span_exporter: InMemorySpanExporter, + setup_anthropic_instrumentation: Any, +) -> None: + client = Anthropic(api_key="fake") + + prompt = ( + f"{anthropic.HUMAN_PROMPT}" + f" why is the sky blue? respond in five words or less." 
+ f" {anthropic.AI_PROMPT}" + ) + + stream = client.completions.create( + model="claude-2.1", + prompt=prompt, + max_tokens_to_sample=1000, + stream=True, + ) + for event in stream: + print(event.completion) + + spans = in_memory_span_exporter.get_finished_spans() + + assert spans[0].name == "Completions" + attributes = dict(spans[0].attributes or {}) + print(attributes) + + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == "LLM" + assert isinstance(attributes.pop(INPUT_VALUE), str) + assert attributes.pop(INPUT_MIME_TYPE) == JSON + assert isinstance(attributes.pop(OUTPUT_VALUE), str) + assert attributes.pop(OUTPUT_MIME_TYPE) == JSON + + assert attributes.pop(LLM_PROMPTS) == (prompt,) + assert attributes.pop(LLM_MODEL_NAME) == "claude-2.1" + assert isinstance(inv_params := attributes.pop(LLM_INVOCATION_PARAMETERS), str) + + invocation_params = {"model": "claude-2.1", "max_tokens_to_sample": 1000, "stream": True} + assert json.loads(inv_params) == invocation_params + assert attributes.pop(LLM_OUTPUT_MESSAGES) == " Light scatters blue." + + +@pytest.mark.asyncio +@pytest.mark.vcr( + decode_compressed_response=True, + before_record_request=remove_all_vcr_request_headers, + before_record_response=remove_all_vcr_response_headers, +) +async def test_anthropic_instrumentation_async_completions_streaming( + tracer_provider: TracerProvider, + in_memory_span_exporter: InMemorySpanExporter, + setup_anthropic_instrumentation: Any, +) -> None: + client = AsyncAnthropic(api_key="fake") + + prompt = ( + f"{anthropic.HUMAN_PROMPT}" + f" why is the sky blue? respond in five words or less." + f" {anthropic.AI_PROMPT}" + ) + + stream = await client.completions.create( + model="claude-2.1", + prompt=prompt, + max_tokens_to_sample=1000, + stream=True, + ) + async for event in stream: + print(event.completion) + + spans = in_memory_span_exporter.get_finished_spans() + + assert spans[0].name == "AsyncCompletions" + attributes = dict(spans[0].attributes or {}) + print(attributes) + + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == "LLM" + assert isinstance(attributes.pop(INPUT_VALUE), str) + assert attributes.pop(INPUT_MIME_TYPE) == JSON + assert isinstance(attributes.pop(OUTPUT_VALUE), str) + assert attributes.pop(OUTPUT_MIME_TYPE) == JSON + + assert attributes.pop(LLM_PROMPTS) == (prompt,) + assert attributes.pop(LLM_MODEL_NAME) == "claude-2.1" + assert isinstance(inv_params := attributes.pop(LLM_INVOCATION_PARAMETERS), str) + + invocation_params = {"model": "claude-2.1", "max_tokens_to_sample": 1000, "stream": True} + assert json.loads(inv_params) == invocation_params + assert attributes.pop(LLM_OUTPUT_MESSAGES) == " Light scatters blue." + + @pytest.mark.vcr( decode_compressed_response=True, before_record_request=remove_all_vcr_request_headers, @@ -210,6 +307,129 @@ def test_anthropic_instrumentation_messages( assert not attributes +@pytest.mark.vcr( + decode_compressed_response=True, + before_record_request=remove_all_vcr_request_headers, + before_record_response=remove_all_vcr_response_headers, +) +def test_anthropic_instrumentation_messages_streaming( + tracer_provider: TracerProvider, + in_memory_span_exporter: InMemorySpanExporter, + setup_anthropic_instrumentation: Any, +) -> None: + client = Anthropic(api_key="fake") + input_message = "Why is the sky blue? 
Answer in 5 words or less" + + invocation_params = {"max_tokens": 1024, "model": "claude-2.1", "stream": True} + + stream = client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": input_message, + } + ], + model="claude-2.1", + stream=True, + ) + + for event in stream: + print(event) + + spans = in_memory_span_exporter.get_finished_spans() + + assert spans[0].name == "Messages" + attributes = dict(spans[0].attributes or {}) + + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == "LLM" + assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_CONTENT}") == input_message + assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "user" + assert isinstance( + msg_content := attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}"), str + ) + assert "Light scatters blue." in msg_content + assert attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "assistant" + assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == 21 + assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == 10 + assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == 31 + + assert isinstance(attributes.pop(INPUT_VALUE), str) + assert attributes.pop(INPUT_MIME_TYPE) == JSON + # TODO(harrison): the output here doesn't look properly + # serialized but looks like openai, mistral accumulators do + # the same thing. need to look into why this might be wrong + assert isinstance(attributes.pop(OUTPUT_VALUE), str) + assert attributes.pop(OUTPUT_MIME_TYPE) == JSON + + assert attributes.pop(LLM_MODEL_NAME) == "claude-2.1" + assert isinstance(inv_params := attributes.pop(LLM_INVOCATION_PARAMETERS), str) + assert json.loads(inv_params) == invocation_params + assert not attributes + + +@pytest.mark.asyncio +@pytest.mark.vcr( + decode_compressed_response=True, + before_record_request=remove_all_vcr_request_headers, + before_record_response=remove_all_vcr_response_headers, +) +async def test_anthropic_instrumentation_async_messages_streaming( + tracer_provider: TracerProvider, + in_memory_span_exporter: InMemorySpanExporter, + setup_anthropic_instrumentation: Any, +) -> None: + client = AsyncAnthropic(api_key="fake") + input_message = "Why is the sky blue? Answer in 5 words or less" + + invocation_params = {"max_tokens": 1024, "model": "claude-2.1", "stream": True} + + stream = await client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": input_message, + } + ], + model="claude-2.1", + stream=True, + ) + + async for event in stream: + print(event) + + spans = in_memory_span_exporter.get_finished_spans() + + assert spans[0].name == "AsyncMessages" + attributes = dict(spans[0].attributes or {}) + + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == "LLM" + assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_CONTENT}") == input_message + assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "user" + assert isinstance( + msg_content := attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}"), str + ) + assert "Light scatters blue." in msg_content + assert attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "assistant" + assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == 21 + assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == 10 + assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == 31 + + assert isinstance(attributes.pop(INPUT_VALUE), str) + assert attributes.pop(INPUT_MIME_TYPE) == JSON + # TODO(harrison): the output here doesn't look properly + # serialized but looks like openai, mistral accumulators do + # the same thing. 
need to look into why this might be wrong + assert isinstance(attributes.pop(OUTPUT_VALUE), str) + assert attributes.pop(OUTPUT_MIME_TYPE) == JSON + + assert attributes.pop(LLM_MODEL_NAME) == "claude-2.1" + assert isinstance(inv_params := attributes.pop(LLM_INVOCATION_PARAMETERS), str) + assert json.loads(inv_params) == invocation_params + assert not attributes + + @pytest.mark.vcr( decode_compressed_response=True, before_record_request=remove_all_vcr_request_headers, @@ -408,6 +628,109 @@ def test_anthropic_instrumentation_multiple_tool_calling( assert not attributes +@pytest.mark.vcr( + decode_compressed_response=True, + before_record_request=remove_all_vcr_request_headers, + before_record_response=remove_all_vcr_response_headers, +) +def test_anthropic_instrumentation_multiple_tool_calling_streaming( + tracer_provider: TracerProvider, + in_memory_span_exporter: InMemorySpanExporter, + setup_anthropic_instrumentation: Any, +) -> None: + client = anthropic.Anthropic(api_key="fake") + + input_message = ( + "What is the weather like right now in New York?" + " Also what time is it there? Use necessary tools simultaneously." + ) + + stream = client.messages.create( + model="claude-3-5-sonnet-20240620", + max_tokens=1024, + tools=[ + { + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature," + " either 'celsius' or 'fahrenheit'", + }, + }, + "required": ["location"], + }, + }, + { + "name": "get_time", + "description": "Get the current time in a given time zone", + "input_schema": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The IANA time zone name, e.g. 
America/Los_Angeles",
+                        }
+                    },
+                    "required": ["timezone"],
+                },
+            },
+        ],
+        messages=[{"role": "user", "content": input_message}],
+        stream=True,
+    )
+    for event in stream:
+        print(event)
+
+    spans = in_memory_span_exporter.get_finished_spans()
+
+    assert spans[0].name == "Messages"
+    attributes = dict(spans[0].attributes or {})
+
+    assert isinstance(attributes.pop(LLM_MODEL_NAME), str)
+    assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_CONTENT}") == input_message
+    assert attributes.pop(f"{LLM_INPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "user"
+    assert isinstance(attributes.pop(LLM_INVOCATION_PARAMETERS), str)
+    assert isinstance(attributes.pop(INPUT_VALUE), str)
+    assert attributes.pop(INPUT_MIME_TYPE) == JSON
+    assert attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "assistant"
+    assert isinstance(attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_CONTENT}"), str)
+    assert (
+        attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.1.{TOOL_CALL_FUNCTION_NAME}")
+        == "get_time"
+    )
+    get_time_input_str = attributes.pop(
+        f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.1.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}"
+    )
+    assert json.loads(get_time_input_str) == {"timezone": "America/New_York"}  # type: ignore
+    assert (
+        attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_NAME}")
+        == "get_weather"
+    )
+    get_weather_input_str = attributes.pop(
+        f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}"
+    )
+    assert json.loads(get_weather_input_str) == {"location": "New York, NY", "unit": "celsius"}  # type: ignore
+    assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == 518
+    assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == 149
+    assert attributes.pop(LLM_TOKEN_COUNT_TOTAL) == 667
+    # TODO(harrison): the output here doesn't look properly
+    # serialized but looks like openai, mistral accumulators do
+    # the same thing. need to look into why this might be wrong
+    assert isinstance(attributes.pop(OUTPUT_VALUE), str)
+    assert attributes.pop(OUTPUT_MIME_TYPE) == "application/json"
+    assert attributes.pop(OPENINFERENCE_SPAN_KIND) == "LLM"
+    assert not attributes
+
+
 @pytest.mark.vcr(
     decode_compressed_response=True,
     before_record_request=remove_all_vcr_request_headers,
     before_record_response=remove_all_vcr_response_headers,