diff --git a/js/.changeset/slimy-pears-boil.md b/js/.changeset/slimy-pears-boil.md
new file mode 100644
index 000000000..a7e32619f
--- /dev/null
+++ b/js/.changeset/slimy-pears-boil.md
@@ -0,0 +1,5 @@
+---
+"@arizeai/openinference-instrumentation-openai": minor
+---
+
+Add streaming instrumentation for OpenAI chat completions
diff --git a/js/packages/openinference-instrumentation-openai/src/instrumentation.ts b/js/packages/openinference-instrumentation-openai/src/instrumentation.ts
index 088693da0..ac4d63e39 100644
--- a/js/packages/openinference-instrumentation-openai/src/instrumentation.ts
+++ b/js/packages/openinference-instrumentation-openai/src/instrumentation.ts
@@ -13,6 +13,7 @@ import {
   SpanKind,
   Attributes,
   SpanStatusCode,
+  Span,
 } from "@opentelemetry/api";
 import { VERSION } from "./version";
 import {
@@ -119,23 +120,27 @@ export class OpenAIInstrumentation extends InstrumentationBase {
         },
       );
       const wrappedPromise = execPromise.then((result) => {
-        if (result) {
+        if (isChatCompletionResponse(result)) {
           // Record the results
           span.setAttributes({
             [SemanticConventions.OUTPUT_VALUE]: JSON.stringify(result),
             [SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.JSON,
             // Override the model from the value sent by the server
-            [SemanticConventions.LLM_MODEL_NAME]: isChatCompletionResponse(
-              result,
-            )
-              ? result.model
-              : body.model,
+            [SemanticConventions.LLM_MODEL_NAME]: result.model,
             ...getLLMOutputMessagesAttributes(result),
             ...getUsageAttributes(result),
           });
+          span.setStatus({ code: SpanStatusCode.OK });
+          span.end();
+        } else {
+          // This is a streaming response
+          // handle the chunks and add them to the span
+          // First split the stream via tee
+          const [leftStream, rightStream] = result.tee();
+          consumeChatCompletionStreamChunks(rightStream, span);
+          result = leftStream;
         }
-        span.setStatus({ code: SpanStatusCode.OK });
-        span.end();
+        return result;
       });
       return context.bind(execContext, wrappedPromise);
@@ -254,21 +259,16 @@ function getLLMInputMessagesAttributes(
 /**
  * Get Usage attributes
  */
-function getUsageAttributes(
-  response: Stream<ChatCompletionChunk> | ChatCompletion,
-) {
-  if (Object.prototype.hasOwnProperty.call(response, "usage")) {
-    const completion = response as ChatCompletion;
-    if (completion.usage) {
-      return {
-        [SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]:
-          completion.usage.completion_tokens,
-        [SemanticConventions.LLM_TOKEN_COUNT_PROMPT]:
-          completion.usage.prompt_tokens,
-        [SemanticConventions.LLM_TOKEN_COUNT_TOTAL]:
-          completion.usage.total_tokens,
-      };
-    }
+function getUsageAttributes(completion: ChatCompletion): Attributes {
+  if (completion.usage) {
+    return {
+      [SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]:
+        completion.usage.completion_tokens,
+      [SemanticConventions.LLM_TOKEN_COUNT_PROMPT]:
+        completion.usage.prompt_tokens,
+      [SemanticConventions.LLM_TOKEN_COUNT_TOTAL]:
+        completion.usage.total_tokens,
+    };
   }
   return {};
 }
@@ -277,26 +277,21 @@ function getUsageAttributes(
  * Converts the result to LLM output attributes
  */
 function getLLMOutputMessagesAttributes(
-  response: Stream<ChatCompletionChunk> | ChatCompletion,
+  completion: ChatCompletion,
 ): Attributes {
-  // Handle chat completion
-  if (Object.prototype.hasOwnProperty.call(response, "choices")) {
-    const completion = response as ChatCompletion;
-    // Right now support just the first choice
-    const choice = completion.choices[0];
-    if (!choice) {
-      return {};
-    }
-    return [choice.message].reduce((acc, message, index) => {
-      const index_prefix = `${SemanticConventions.LLM_OUTPUT_MESSAGES}.${index}`;
-      acc[`${index_prefix}.${SemanticConventions.MESSAGE_CONTENT}`] = String(
-        message.content,
-      );
-      acc[`${index_prefix}.${SemanticConventions.MESSAGE_ROLE}`] = message.role;
-      return acc;
-    }, {} as Attributes);
+  // Right now support just the first choice
+  const choice = completion.choices[0];
+  if (!choice) {
+    return {};
   }
-  return {};
+  return [choice.message].reduce((acc, message, index) => {
+    const index_prefix = `${SemanticConventions.LLM_OUTPUT_MESSAGES}.${index}`;
+    acc[`${index_prefix}.${SemanticConventions.MESSAGE_CONTENT}`] = String(
+      message.content,
+    );
+    acc[`${index_prefix}.${SemanticConventions.MESSAGE_ROLE}`] = message.role;
+    return acc;
+  }, {} as Attributes);
 }
 
 /**
@@ -338,3 +333,27 @@ function getEmbeddingEmbeddingsAttributes(
     return acc;
   }, {} as Attributes);
 }
+
+/**
+ * Consumes the stream chunks and adds them to the span
+ */
+async function consumeChatCompletionStreamChunks(
+  stream: Stream<ChatCompletionChunk>,
+  span: Span,
+) {
+  let streamResponse = "";
+  for await (const chunk of stream) {
+    if (chunk.choices.length > 0 && chunk.choices[0].delta.content) {
+      streamResponse += chunk.choices[0].delta.content;
+    }
+  }
+  span.setAttributes({
+    [SemanticConventions.OUTPUT_VALUE]: streamResponse,
+    [SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.TEXT,
+    [`${SemanticConventions.LLM_OUTPUT_MESSAGES}.0.${SemanticConventions.MESSAGE_CONTENT}`]:
+      streamResponse,
+    [`${SemanticConventions.LLM_OUTPUT_MESSAGES}.0.${SemanticConventions.MESSAGE_ROLE}`]:
+      "assistant",
+  });
+  span.end();
+}
diff --git a/js/packages/openinference-instrumentation-openai/test/openai.test.ts b/js/packages/openinference-instrumentation-openai/test/openai.test.ts
index a768fee77..12d796257
--- a/js/packages/openinference-instrumentation-openai/test/openai.test.ts
+++ b/js/packages/openinference-instrumentation-openai/test/openai.test.ts
@@ -11,6 +11,7 @@ const instrumentation = new OpenAIInstrumentation();
 instrumentation.disable();
 
 import * as OpenAI from "openai";
+import { Stream } from "openai/streaming";
 
 describe("OpenAIInstrumentation", () => {
   let openai: OpenAI.OpenAI;
@@ -27,7 +28,7 @@ describe("OpenAIInstrumentation", () => {
   beforeAll(() => {
     instrumentation.enable();
     openai = new OpenAI.OpenAI({
-      apiKey: `fake-api-key`,
+      apiKey: "fake-api-key",
     });
   });
   afterAll(() => {
@@ -136,4 +137,51 @@ describe("OpenAIInstrumentation", () => {
       }
     `);
   });
+  it("can handle streaming responses", async () => {
+    // Mock out the chat completions endpoint to return a stream of chunks
+    jest.spyOn(openai, "post").mockImplementation(
+      // @ts-expect-error the response type is not correct - this is just for testing
+      async (): Promise<unknown> => {
+        const iterator = () =>
+          (async function* () {
+            yield { choices: [{ delta: { content: "This is " } }] };
+            yield { choices: [{ delta: { content: "a test." } }] };
+            yield { choices: [{ delta: { finish_reason: "stop" } }] };
+          })();
+        const controller = new AbortController();
+        return new Stream(iterator, controller);
+      },
+    );
+    const stream = await openai.chat.completions.create({
+      messages: [{ role: "user", content: "Say this is a test" }],
+      model: "gpt-3.5-turbo",
+      stream: true,
+    });
+
+    let response = "";
+    for await (const chunk of stream) {
+      if (chunk.choices[0].delta.content)
+        response += chunk.choices[0].delta.content;
+    }
+    expect(response).toBe("This is a test.");
+    const spans = memoryExporter.getFinishedSpans();
+    expect(spans.length).toBe(1);
+    const span = spans[0];
+    expect(span.name).toBe("OpenAI Chat Completions");
+    expect(span.attributes).toMatchInlineSnapshot(`
+      {
+        "input.mime_type": "application/json",
+        "input.value": "{"messages":[{"role":"user","content":"Say this is a test"}],"model":"gpt-3.5-turbo","stream":true}",
+        "llm.input_messages.0.message.content": "Say this is a test",
+        "llm.input_messages.0.message.role": "user",
+        "llm.invocation_parameters": "{"model":"gpt-3.5-turbo","stream":true}",
+        "llm.model_name": "gpt-3.5-turbo",
+        "llm.output_messages.0.message.content": "This is a test.",
+        "llm.output_messages.0.message.role": "assistant",
+        "openinference.span.kind": "llm",
+        "output.mime_type": "text/plain",
+        "output.value": "This is a test.",
+      }
+    `);
+  });
 });
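
Usage note (not part of the diff): a minimal sketch of how the new streaming instrumentation behaves from an application's point of view. The OpenTelemetry setup (NodeTracerProvider, ConsoleSpanExporter), the OPENAI_API_KEY environment variable, and the main() wrapper are illustrative assumptions, not part of this change; what the diff guarantees is that the span for a stream:true call is ended only after the teed stream has been fully consumed.

import { NodeTracerProvider } from "@opentelemetry/sdk-trace-node";
import {
  SimpleSpanProcessor,
  ConsoleSpanExporter,
} from "@opentelemetry/sdk-trace-base";
import { registerInstrumentations } from "@opentelemetry/instrumentation";
import { OpenAIInstrumentation } from "@arizeai/openinference-instrumentation-openai";
import OpenAI from "openai";

// Illustrative setup: any tracer provider/exporter works. Instrumentation is
// typically registered before the openai module is loaded (e.g. in a setup
// file) so that chat.completions.create gets patched.
const provider = new NodeTracerProvider();
provider.addSpanProcessor(new SimpleSpanProcessor(new ConsoleSpanExporter()));
provider.register();
registerInstrumentations({ instrumentations: [new OpenAIInstrumentation()] });

async function main() {
  const openai = new OpenAI(); // reads OPENAI_API_KEY from the environment
  const stream = await openai.chat.completions.create({
    model: "gpt-3.5-turbo",
    messages: [{ role: "user", content: "Say this is a test" }],
    stream: true,
  });
  // The caller iterates one branch of the teed stream; the instrumentation
  // consumes the other branch and ends the span once the stream is exhausted.
  for await (const chunk of stream) {
    process.stdout.write(chunk.choices[0]?.delta?.content ?? "");
  }
}

main();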