Refactoring OpenAI instrumentation
alizenhom committed Sep 5, 2024
1 parent e601f6d commit d04edad
Showing 7 changed files with 195 additions and 252 deletions.
@@ -26,8 +26,8 @@ classifiers = [
 ]
 dependencies = [
   "opentelemetry-api ~= 1.12",
-  "opentelemetry-instrumentation == 0.47b0",
-  "tiktoken>=0.1.1",
+  "opentelemetry-instrumentation == 0.48b0.dev",
+  "opentelemetry-semantic-conventions == 0.48b0.dev",
   "pydantic>=1.8"
 
 ]
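Note on the hunk above: the stable 0.47b0 pin and the tiktoken dependency are replaced by matched 0.48b0.dev pins on both opentelemetry-instrumentation and opentelemetry-semantic-conventions. A minimal sketch (not part of the commit) of checking at runtime that the two pins resolved to the same release train:

import importlib.metadata

# Illustrative only: both 0.48b0.dev packages are expected to come from the
# same release train, so their installed versions should match.
inst = importlib.metadata.version("opentelemetry-instrumentation")
semconv = importlib.metadata.version("opentelemetry-semantic-conventions")
assert inst == semconv, f"mismatched pins: {inst} != {semconv}"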
@@ -46,7 +46,7 @@
 from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
 from opentelemetry.instrumentation.openai.package import _instruments
 from opentelemetry.trace import get_tracer
-from wrapt import wrap_function_wrapper
+from wrapt import wrap_function_wrapper as _W
 from .patch import chat_completions_create
 
 
@@ -59,13 +59,10 @@ def _instrument(self, **kwargs):
         """Enable OpenAI instrumentation."""
         tracer_provider = kwargs.get("tracer_provider")
         tracer = get_tracer(__name__, "", tracer_provider)
-        version = importlib.metadata.version("openai")
-        wrap_function_wrapper(
-            "openai.resources.chat.completions",
-            "Completions.create",
-            chat_completions_create(
-                "openai.chat.completions.create", version, tracer
-            ),
+        _W(
+            module="openai.resources.chat.completions",
+            name="Completions.create",
+            wrapper=chat_completions_create(tracer),
         )
 
     def _uninstrument(self, **kwargs):
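For context on this file: wrapt's wrap_function_wrapper is now aliased as _W, and chat_completions_create receives only the tracer, dropping the version lookup and the span-name argument. A minimal usage sketch, assuming the BaseInstrumentor subclass in this file is exported as OpenAIInstrumentor (the class name is not visible in the hunk) and that OPENAI_API_KEY is set:

# Hypothetical usage sketch; class and model names are assumptions.
from openai import OpenAI
from opentelemetry.instrumentation.openai import OpenAIInstrumentor

OpenAIInstrumentor().instrument()  # patches Completions.create via wrapt

client = OpenAI()  # reads OPENAI_API_KEY from the environment
client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Say hello"}],
)
# The call above now runs inside the wrapper returned by
# chat_completions_create(tracer).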
@@ -12,54 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import json
-from typing import Optional, Union
 
 from opentelemetry import trace
 from opentelemetry.trace import SpanKind, Span
 from opentelemetry.trace.status import Status, StatusCode
 from opentelemetry.trace.propagation import set_span_in_context
-from openai import NOT_GIVEN
-from .span_attributes import LLMSpanAttributes, SpanAttributes
 
-from .utils import silently_fail, extract_content
+from .span_attributes import LLMSpanAttributes, SpanAttributes
+from opentelemetry.semconv._incubating.attributes import (
+    gen_ai_attributes as GenAIAttributes,
+)
+from .utils import (
+    silently_fail,
+    extract_content,
+    get_llm_request_attributes,
+    is_streaming,
+    set_span_attribute,
+    set_event_completion,
+    extract_tools_prompt,
+)
 from opentelemetry.trace import Tracer
 
 
-def chat_completions_create(original_method, version, tracer: Tracer):
+def chat_completions_create(tracer: Tracer):
     """Wrap the `create` method of the `ChatCompletion` class to trace it."""
 
     def traced_method(wrapped, instance, args, kwargs):
 
         llm_prompts = []
 
         for item in kwargs.get("messages", []):
-            tools = get_tool_calls(item)
-            if tools is not None:
-                tool_calls = []
-                for tool_call in tools:
-                    tool_call_dict = {
-                        "id": tool_call.id if hasattr(tool_call, "id") else "",
-                        "type": (
-                            tool_call.type
-                            if hasattr(tool_call, "type")
-                            else ""
-                        ),
-                    }
-                    if hasattr(tool_call, "function"):
-                        tool_call_dict["function"] = {
-                            "name": (
-                                tool_call.function.name
-                                if hasattr(tool_call.function, "name")
-                                else ""
-                            ),
-                            "arguments": (
-                                tool_call.function.arguments
-                                if hasattr(tool_call.function, "arguments")
-                                else ""
-                            ),
-                        }
-                    tool_calls.append(tool_call_dict)
-                llm_prompts.append(tool_calls)
-            else:
-                llm_prompts.append(item)
+            tools_prompt = extract_tools_prompt(item)
+            llm_prompts.append(tools_prompt if tools_prompt else item)
 
         span_attributes = {
             **get_llm_request_attributes(kwargs, prompts=llm_prompts),
@@ -74,7 +58,7 @@ def traced_method(wrapped, instance, args, kwargs):
             kind=SpanKind.CLIENT,
             context=set_span_in_context(trace.get_current_span()),
         )
-        _set_input_attributes(span, kwargs, attributes)
+        _set_input_attributes(span, attributes)
 
         try:
             result = wrapped(*args, **kwargs)
@@ -86,52 +70,31 @@ def traced_method(wrapped, instance, args, kwargs):
                     tool_calls=kwargs.get("tools") is not None,
                 )
             else:
-                _set_response_attributes(span, kwargs, result)
+                _set_response_attributes(span, result)
                 span.end()
             return result
 
         except Exception as error:
             span.set_status(Status(StatusCode.ERROR, str(error)))
             span.set_attribute("error.type", error.__class__.__name__)
             span.end()
             raise
 
     return traced_method
 
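The traced_method closure above follows wrapt's wrapper convention: it receives the original callable (wrapped), the bound instance, and the call's args/kwargs, and must delegate with wrapped(*args, **kwargs). A stripped-down sketch of that shape, with the span logic replaced by prints for illustration:

# Minimal illustration of the wrapt wrapper shape used by this patch;
# the body is a stand-in, not the instrumentation logic from the diff.
from wrapt import wrap_function_wrapper as _W

def make_wrapper():
    def traced_method(wrapped, instance, args, kwargs):
        print("before Completions.create")
        try:
            return wrapped(*args, **kwargs)  # delegate to the real method
        finally:
            print("after Completions.create")
    return traced_method

_W(
    module="openai.resources.chat.completions",
    name="Completions.create",
    wrapper=make_wrapper(),
)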

-def get_tool_calls(item):
-    if isinstance(item, dict):
-        return item.get("tool_calls")
-    else:
-        return getattr(item, "tool_calls", None)
-
-
 @silently_fail
-def _set_input_attributes(span, kwargs, attributes: LLMSpanAttributes):
-    tools = []
-
-    if (
-        kwargs.get("functions") is not None
-        and kwargs.get("functions") != NOT_GIVEN
-    ):
-        for function in kwargs.get("functions"):
-            tools.append(
-                json.dumps({"type": "function", "function": function})
-            )
-
-    if kwargs.get("tools") is not None and kwargs.get("tools") != NOT_GIVEN:
-        tools.append(json.dumps(kwargs.get("tools")))
-
-    if tools:
-        set_span_attribute(span, SpanAttributes.LLM_TOOLS, json.dumps(tools))
-
+def _set_input_attributes(span, attributes: LLMSpanAttributes):
     for field, value in attributes.model_dump(by_alias=True).items():
         set_span_attribute(span, field, value)
 
 
 @silently_fail
-def _set_response_attributes(span, kwargs, result):
-    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, result.model)
-    if hasattr(result, "choices") and result.choices is not None:
+def _set_response_attributes(span, result):
+    set_span_attribute(
+        span, GenAIAttributes.GEN_AI_RESPONSE_MODEL, result.model
+    )
+    if getattr(result, "choices", None):
         responses = [
             {
                 "role": (
@@ -154,120 +117,30 @@ def _set_response_attributes(span, kwargs, result):
         ]
         set_event_completion(span, responses)
 
-    if (
-        hasattr(result, "system_fingerprint")
-        and result.system_fingerprint is not None
-        and result.system_fingerprint != NOT_GIVEN
-    ):
+    if getattr(result, "system_fingerprint", None):
         set_span_attribute(
             span,
             SpanAttributes.LLM_SYSTEM_FINGERPRINT,
             result.system_fingerprint,
         )
-    # Get the usage
-    if hasattr(result, "usage") and result.usage is not None:
-        usage = result.usage
-        if usage is not None:
-            set_span_attribute(
-                span,
-                SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
-                result.usage.prompt_tokens,
-            )
-            set_span_attribute(
-                span,
-                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
-                result.usage.completion_tokens,
-            )
-            set_span_attribute(
-                span,
-                SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
-                result.usage.total_tokens,
-            )

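A recurring simplification in this file: hasattr(x, "y") and x.y is not None collapses to getattr(x, "y", None), which is falsy for a missing attribute and for None, and additionally for empty values such as "" or []. A toy illustration (not code from the diff):

class Result:
    system_fingerprint = None

r = Result()

# Old and new spellings agree for missing or None attributes...
old = hasattr(r, "system_fingerprint") and r.system_fingerprint is not None
new = bool(getattr(r, "system_fingerprint", None))
assert not old and not new

# ...but the getattr form also skips empty values, e.g. an empty string.
r.system_fingerprint = ""
assert not getattr(r, "system_fingerprint", None)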

-def set_event_prompt(span: Span, prompt):
-    span.add_event(
-        name=SpanAttributes.LLM_CONTENT_PROMPT,
-        attributes={
-            SpanAttributes.LLM_PROMPTS: prompt,
-        },
-    )
-
-
-def set_span_attributes(span: Span, attributes: dict):
-    for field, value in attributes.model_dump(by_alias=True).items():
-        set_span_attribute(span, field, value)
-
-
-def set_event_completion(span: Span, result_content):
-    span.add_event(
-        name=SpanAttributes.LLM_CONTENT_COMPLETION,
-        attributes={
-            SpanAttributes.LLM_COMPLETIONS: json.dumps(result_content),
-        },
-    )
-
-
-def set_span_attribute(span: Span, name, value):
-    if value is not None:
-        if value != "" or value != NOT_GIVEN:
-            if name == SpanAttributes.LLM_PROMPTS:
-                set_event_prompt(span, value)
-            else:
-                span.set_attribute(name, value)
-    return
-
-
-def is_streaming(kwargs):
-    return non_numerical_value_is_set(kwargs.get("stream"))
-
-
-def non_numerical_value_is_set(value: Optional[Union[bool, str]]):
-    return bool(value) and value != NOT_GIVEN
-
-
-def get_llm_request_attributes(
-    kwargs, prompts=None, model=None, operation_name="chat"
-):
-
-    user = kwargs.get("user")
-    if prompts is None:
-        prompts = (
-            [{"role": user or "user", "content": kwargs.get("prompt")}]
-            if "prompt" in kwargs
-            else None
-        )
-    top_k = (
-        kwargs.get("n")
-        or kwargs.get("k")
-        or kwargs.get("top_k")
-        or kwargs.get("top_n")
-    )
-
-    top_p = kwargs.get("p") or kwargs.get("top_p")
-    tools = kwargs.get("tools")
-    return {
-        SpanAttributes.LLM_OPERATION_NAME: operation_name,
-        SpanAttributes.LLM_REQUEST_MODEL: model or kwargs.get("model"),
-        SpanAttributes.LLM_IS_STREAMING: kwargs.get("stream"),
-        SpanAttributes.LLM_REQUEST_TEMPERATURE: kwargs.get("temperature"),
-        SpanAttributes.LLM_TOP_K: top_k,
-        SpanAttributes.LLM_PROMPTS: json.dumps(prompts) if prompts else None,
-        SpanAttributes.LLM_USER: user,
-        SpanAttributes.LLM_REQUEST_TOP_P: top_p,
-        SpanAttributes.LLM_REQUEST_MAX_TOKENS: kwargs.get("max_tokens"),
-        SpanAttributes.LLM_SYSTEM_FINGERPRINT: kwargs.get(
-            "system_fingerprint"
-        ),
-        SpanAttributes.LLM_PRESENCE_PENALTY: kwargs.get("presence_penalty"),
-        SpanAttributes.LLM_FREQUENCY_PENALTY: kwargs.get("frequency_penalty"),
-        SpanAttributes.LLM_REQUEST_SEED: kwargs.get("seed"),
-        SpanAttributes.LLM_TOOLS: json.dumps(tools) if tools else None,
-        SpanAttributes.LLM_TOOL_CHOICE: kwargs.get("tool_choice"),
-        SpanAttributes.LLM_REQUEST_LOGPROPS: kwargs.get("logprobs"),
-        SpanAttributes.LLM_REQUEST_LOGITBIAS: kwargs.get("logit_bias"),
-        SpanAttributes.LLM_REQUEST_TOP_LOGPROPS: kwargs.get("top_logprobs"),
-    }
+    # Get the usage
+    if getattr(result, "usage", None):
+        set_span_attribute(
+            span,
+            GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS,
+            result.usage.prompt_tokens,
+        )
+        set_span_attribute(
+            span,
+            GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS,
+            result.usage.completion_tokens,
+        )
+        set_span_attribute(
+            span,
+            "gen_ai.usage.total_tokens",
+            result.usage.total_tokens,
+        )


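The GenAIAttributes constants adopted above come from the incubating gen_ai semantic conventions. To my understanding they resolve to the dotted names shown below, and "gen_ai.usage.total_tokens" stays a string literal because the 0.48b0.dev conventions are assumed to define no constant for it; verify against the installed package:

from opentelemetry.semconv._incubating.attributes import (
    gen_ai_attributes as GenAIAttributes,
)

# Expected values, per my reading of the incubating conventions:
print(GenAIAttributes.GEN_AI_RESPONSE_MODEL)       # gen_ai.response.model
print(GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS)   # gen_ai.usage.input_tokens
print(GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS)  # gen_ai.usage.output_tokens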
 class StreamWrapper:
@@ -277,7 +150,7 @@ def __init__(
         self,
         stream,
         span,
-        prompt_tokens=None,
+        prompt_tokens=0,
         function_call=False,
         tool_calls=False,
     ):
@@ -299,17 +172,17 @@ def cleanup(self):
         if self._span_started:
             set_span_attribute(
                 self.span,
-                SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+                GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS,
                 self.prompt_tokens,
             )
             set_span_attribute(
                 self.span,
-                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+                GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS,
                 self.completion_tokens,
             )
             set_span_attribute(
                 self.span,
-                SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+                "gen_ai.usage.total_tokens",
                 self.prompt_tokens + self.completion_tokens,
             )
             set_event_completion(
@@ -346,14 +219,14 @@ def __next__(self):
             raise
 
     def process_chunk(self, chunk):
-        if hasattr(chunk, "model") and chunk.model is not None:
+        if getattr(chunk, "model", None):
             set_span_attribute(
                 self.span,
-                SpanAttributes.LLM_RESPONSE_MODEL,
+                GenAIAttributes.GEN_AI_RESPONSE_MODEL,
                 chunk.model,
             )
 
-        if hasattr(chunk, "choices") and chunk.choices is not None:
+        if getattr(chunk, "choices", None):
             content = []
             if not self.function_call and not self.tool_calls:
                 for choice in chunk.choices:
@@ -383,12 +256,12 @@ def process_chunk(self, chunk):
             if content:
                 self.result_content.append(content[0])
 
-        if hasattr(chunk, "text"):
+        if getattr(chunk, "text", None):
             content = [chunk.text]
 
             if content:
                 self.result_content.append(content[0])
 
-        if getattr(chunk, "usage"):
+        if getattr(chunk, "usage", None):
             self.completion_tokens = chunk.usage.completion_tokens
             self.prompt_tokens = chunk.usage.prompt_tokens
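For readers following the StreamWrapper changes: the class proxies the streaming iterator so that each chunk updates span state (process_chunk) and the span is finalized when the stream is exhausted (cleanup). A generic sketch of that iterator-proxy pattern, independent of the OpenTelemetry types in the diff:

# Generic iterator-proxy sketch; StreamWrapper applies the same shape.
class IteratorProxy:
    def __init__(self, inner, on_chunk, on_done):
        self._inner = iter(inner)
        self._on_chunk = on_chunk  # per chunk, e.g. process_chunk()
        self._on_done = on_done    # once at exhaustion, e.g. cleanup()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = next(self._inner)
        except StopIteration:
            self._on_done()
            raise
        self._on_chunk(chunk)
        return chunk

chunks = IteratorProxy(["a", "b"], on_chunk=print, on_done=lambda: print("done"))
assert list(chunks) == ["a", "b"]  # prints a, b, then done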