feat(js): openai chat completion streaming support (#87)
mikeldking authored Jan 10, 2024
1 parent ec83fb8 commit 82c5d83
Showing 3 changed files with 114 additions and 42 deletions.
5 changes: 5 additions & 0 deletions js/.changeset/slimy-pears-boil.md
@@ -0,0 +1,5 @@
+---
+"@arizeai/openinference-instrumentation-openai": minor
+---
+
+Add streaming instrumentation for OpenAI Chat completion
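
For orientation, the sketch below shows what this enables on the consumer side: a chat completion requested with stream: true is now captured on a span just like a non-streaming call. It assumes a tracer provider is already configured; the registerInstrumentations setup and the model name are illustrative and not part of this diff.

import { registerInstrumentations } from "@opentelemetry/instrumentation";
import { OpenAIInstrumentation } from "@arizeai/openinference-instrumentation-openai";
import OpenAI from "openai";

// Register the instrumentation (in a real app this usually lives in a setup
// module that is loaded before "openai" is imported, so the SDK gets patched).
registerInstrumentations({
  instrumentations: [new OpenAIInstrumentation()],
});

const openai = new OpenAI();

async function main() {
  const stream = await openai.chat.completions.create({
    model: "gpt-3.5-turbo",
    messages: [{ role: "user", content: "Say this is a test" }],
    stream: true,
  });
  // Consuming the stream works as before; the instrumentation records the
  // accumulated output on the span once the stream is exhausted.
  for await (const chunk of stream) {
    process.stdout.write(chunk.choices[0]?.delta?.content ?? "");
  }
}

main();
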
@@ -13,6 +13,7 @@ import {
   SpanKind,
   Attributes,
   SpanStatusCode,
+  Span,
 } from "@opentelemetry/api";
 import { VERSION } from "./version";
 import {
@@ -119,23 +120,27 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
         },
       );
       const wrappedPromise = execPromise.then((result) => {
-        if (result) {
+        if (isChatCompletionResponse(result)) {
           // Record the results
           span.setAttributes({
             [SemanticConventions.OUTPUT_VALUE]: JSON.stringify(result),
             [SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.JSON,
             // Override the model from the value sent by the server
-            [SemanticConventions.LLM_MODEL_NAME]: isChatCompletionResponse(
-              result,
-            )
-              ? result.model
-              : body.model,
+            [SemanticConventions.LLM_MODEL_NAME]: result.model,
             ...getLLMOutputMessagesAttributes(result),
             ...getUsageAttributes(result),
           });
+          span.setStatus({ code: SpanStatusCode.OK });
+          span.end();
+        } else {
+          // This is a streaming response
+          // handle the chunks and add them to the span
+          // First split the stream via tee
+          const [leftStream, rightStream] = result.tee();
+          consumeChatCompletionStreamChunks(rightStream, span);
+          result = leftStream;
         }
-        span.setStatus({ code: SpanStatusCode.OK });
-        span.end();
+
         return result;
       });
       return context.bind(execContext, wrappedPromise);
@@ -254,21 +259,16 @@ function getLLMInputMessagesAttributes(
 /**
  * Get Usage attributes
  */
-function getUsageAttributes(
-  response: Stream<ChatCompletionChunk> | ChatCompletion,
-) {
-  if (Object.prototype.hasOwnProperty.call(response, "usage")) {
-    const completion = response as ChatCompletion;
-    if (completion.usage) {
-      return {
-        [SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]:
-          completion.usage.completion_tokens,
-        [SemanticConventions.LLM_TOKEN_COUNT_PROMPT]:
-          completion.usage.prompt_tokens,
-        [SemanticConventions.LLM_TOKEN_COUNT_TOTAL]:
-          completion.usage.total_tokens,
-      };
-    }
+function getUsageAttributes(completion: ChatCompletion): Attributes {
+  if (completion.usage) {
+    return {
+      [SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]:
+        completion.usage.completion_tokens,
+      [SemanticConventions.LLM_TOKEN_COUNT_PROMPT]:
+        completion.usage.prompt_tokens,
+      [SemanticConventions.LLM_TOKEN_COUNT_TOTAL]:
+        completion.usage.total_tokens,
+    };
   }
   return {};
 }
@@ -277,26 +277,21 @@ function getUsageAttributes(
  * Converts the result to LLM output attributes
  */
 function getLLMOutputMessagesAttributes(
-  response: Stream<ChatCompletionChunk> | ChatCompletion,
+  completion: ChatCompletion,
 ): Attributes {
-  // Handle chat completion
-  if (Object.prototype.hasOwnProperty.call(response, "choices")) {
-    const completion = response as ChatCompletion;
-    // Right now support just the first choice
-    const choice = completion.choices[0];
-    if (!choice) {
-      return {};
-    }
-    return [choice.message].reduce((acc, message, index) => {
-      const index_prefix = `${SemanticConventions.LLM_OUTPUT_MESSAGES}.${index}`;
-      acc[`${index_prefix}.${SemanticConventions.MESSAGE_CONTENT}`] = String(
-        message.content,
-      );
-      acc[`${index_prefix}.${SemanticConventions.MESSAGE_ROLE}`] = message.role;
-      return acc;
-    }, {} as Attributes);
+  // Right now support just the first choice
+  const choice = completion.choices[0];
+  if (!choice) {
+    return {};
   }
-  return {};
+  return [choice.message].reduce((acc, message, index) => {
+    const index_prefix = `${SemanticConventions.LLM_OUTPUT_MESSAGES}.${index}`;
+    acc[`${index_prefix}.${SemanticConventions.MESSAGE_CONTENT}`] = String(
+      message.content,
+    );
+    acc[`${index_prefix}.${SemanticConventions.MESSAGE_ROLE}`] = message.role;
+    return acc;
+  }, {} as Attributes);
 }
 
 /**
@@ -338,3 +333,27 @@ function getEmbeddingEmbeddingsAttributes(
     return acc;
   }, {} as Attributes);
 }
+
+/**
+ * Consumes the stream chunks and adds them to the span
+ */
+async function consumeChatCompletionStreamChunks(
+  stream: Stream<ChatCompletionChunk>,
+  span: Span,
+) {
+  let streamResponse = "";
+  for await (const chunk of stream) {
+    if (chunk.choices.length > 0 && chunk.choices[0].delta.content) {
+      streamResponse += chunk.choices[0].delta.content;
+    }
+  }
+  span.setAttributes({
+    [SemanticConventions.OUTPUT_VALUE]: streamResponse,
+    [SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.TEXT,
+    [`${SemanticConventions.LLM_OUTPUT_MESSAGES}.0.${SemanticConventions.MESSAGE_CONTENT}`]:
+      streamResponse,
+    [`${SemanticConventions.LLM_OUTPUT_MESSAGES}.0.${SemanticConventions.MESSAGE_ROLE}`]:
+      "assistant",
+  });
+  span.end();
+}
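
The streaming branch above hinges on the OpenAI SDK's Stream.tee(): the response stream is split into two independently consumable copies, one handed back to the caller and one drained by consumeChatCompletionStreamChunks to accumulate the text and end the span. A standalone sketch of that pattern follows; it is not the library code, and the helper name and the ChatCompletionChunk import path are assumptions.

import { Stream } from "openai/streaming";
import type { ChatCompletionChunk } from "openai/resources/chat/completions";

// Illustrative helper: hand one copy of the stream back to the caller while a
// background task drains the other copy and reports the accumulated text.
function observeStream(
  stream: Stream<ChatCompletionChunk>,
  onComplete: (fullText: string) => void,
): Stream<ChatCompletionChunk> {
  const [forCaller, forTelemetry] = stream.tee();
  void (async () => {
    let text = "";
    for await (const chunk of forTelemetry) {
      text += chunk.choices[0]?.delta?.content ?? "";
    }
    onComplete(text);
  })();
  return forCaller;
}

Because tee() buffers chunks for each branch, draining the telemetry copy does not steal chunks from the copy the caller is still reading.
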
@@ -11,6 +11,7 @@ const instrumentation = new OpenAIInstrumentation();
 instrumentation.disable();
 
 import * as OpenAI from "openai";
+import { Stream } from "openai/streaming";
 
 describe("OpenAIInstrumentation", () => {
   let openai: OpenAI.OpenAI;
@@ -27,7 +28,7 @@ describe("OpenAIInstrumentation", () => {
   beforeAll(() => {
     instrumentation.enable();
     openai = new OpenAI.OpenAI({
-      apiKey: `fake-api-key`,
+      apiKey: "fake-api-key",
     });
   });
   afterAll(() => {
@@ -136,4 +137,51 @@ describe("OpenAIInstrumentation", () => {
       }
     `);
   });
+  it("can handle streaming responses", async () => {
+    // Mock out the chat completions endpoint to return a stream of chunks
+    jest.spyOn(openai, "post").mockImplementation(
+      // @ts-expect-error the response type is not correct - this is just for testing
+      async (): Promise<unknown> => {
+        const iterator = () =>
+          (async function* () {
+            yield { choices: [{ delta: { content: "This is " } }] };
+            yield { choices: [{ delta: { content: "a test." } }] };
+            yield { choices: [{ delta: { finish_reason: "stop" } }] };
+          })();
+        const controller = new AbortController();
+        return new Stream(iterator, controller);
+      },
+    );
+    const stream = await openai.chat.completions.create({
+      messages: [{ role: "user", content: "Say this is a test" }],
+      model: "gpt-3.5-turbo",
+      stream: true,
+    });
+
+    let response = "";
+    for await (const chunk of stream) {
+      if (chunk.choices[0].delta.content)
+        response += chunk.choices[0].delta.content;
+    }
+    expect(response).toBe("This is a test.");
+    const spans = memoryExporter.getFinishedSpans();
+    expect(spans.length).toBe(1);
+    const span = spans[0];
+    expect(span.name).toBe("OpenAI Chat Completions");
+    expect(span.attributes).toMatchInlineSnapshot(`
+      {
+        "input.mime_type": "application/json",
+        "input.value": "{"messages":[{"role":"user","content":"Say this is a test"}],"model":"gpt-3.5-turbo","stream":true}",
+        "llm.input_messages.0.message.content": "Say this is a test",
+        "llm.input_messages.0.message.role": "user",
+        "llm.invocation_parameters": "{"model":"gpt-3.5-turbo","stream":true}",
+        "llm.model_name": "gpt-3.5-turbo",
+        "llm.output_messages.0.message.content": "This is a test.",
+        "llm.output_messages.0.message.role": "assistant",
+        "openinference.span.kind": "llm",
+        "output.mime_type": "text/plain",
+        "output.value": "This is a test.",
+      }
+    `);
+  });
 });
