From a608eeb1c76778fa93370ce292f2775b9a49ba00 Mon Sep 17 00:00:00 2001
From: Jacob Lee
Date: Thu, 12 Sep 2024 13:12:00 -0700
Subject: [PATCH] fix(openai): Avoid thrown error on o1 stream calls (#6747)

---
 libs/langchain-openai/src/chat_models.ts          | 14 ++++++++++++++
 .../src/tests/chat_models.int.test.ts             | 11 +++++++++++
 2 files changed, 25 insertions(+)

diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts
index cf8cdfcd5e05..3a0cccf13d04 100644
--- a/libs/langchain-openai/src/chat_models.ts
+++ b/libs/langchain-openai/src/chat_models.ts
@@ -14,6 +14,7 @@ import {
   ToolMessageChunk,
   OpenAIToolCall,
   isAIMessage,
+  convertToChunk,
 } from "@langchain/core/messages";
 import {
   type ChatGeneration,
@@ -1185,6 +1186,19 @@ export class ChatOpenAI<
     options: this["ParsedCallOptions"],
     runManager?: CallbackManagerForLLMRun
   ): AsyncGenerator<ChatGenerationChunk> {
+    if (this.model.includes("o1-")) {
+      console.warn(
+        "[WARNING]: OpenAI o1 models do not yet support token-level streaming. Streaming will yield a single chunk."
+      );
+      const result = await this._generate(messages, options, runManager);
+      const messageChunk = convertToChunk(result.generations[0].message);
+      yield new ChatGenerationChunk({
+        message: messageChunk,
+        text:
+          typeof messageChunk.content === "string" ? messageChunk.content : "",
+      });
+      return;
+    }
     const messagesMapped: OpenAICompletionParam[] =
       convertMessagesToOpenAIParams(messages);
     const params = {
diff --git a/libs/langchain-openai/src/tests/chat_models.int.test.ts b/libs/langchain-openai/src/tests/chat_models.int.test.ts
index f51b5121a98c..4e0957d25ec2 100644
--- a/libs/langchain-openai/src/tests/chat_models.int.test.ts
+++ b/libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -940,3 +940,14 @@ test("populates ID field on AIMessage", async () => {
   expect(finalChunk?.id?.length).toBeGreaterThan(1);
   expect(finalChunk?.id?.startsWith("chatcmpl-")).toBe(true);
 });
+
+test("Test ChatOpenAI o1 stream method", async () => {
+  const model = new ChatOpenAI({ model: "o1-mini" });
+  const stream = await model.stream("Print hello world.");
+  const chunks = [];
+  for await (const chunk of stream) {
+    console.log(chunk);
+    chunks.push(chunk);
+  }
+  expect(chunks.length).toEqual(1);
+});
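
Usage sketch (illustrative, not part of the commit): a minimal example of the
consumer-facing behavior this patch enables, assuming a valid OPENAI_API_KEY in
the environment and access to the "o1-mini" model used in the test above. Before
the patch, calling .stream() on an o1 model surfaced an OpenAI API error, since
these models reject token-level streaming server-side; after it, the call logs a
warning and yields exactly one chunk containing the full completion.

    // usage-sketch.ts -- illustrative only, not part of the patch
    import { ChatOpenAI } from "@langchain/openai";

    async function main() {
      const model = new ChatOpenAI({ model: "o1-mini" });
      // After this patch, .stream() falls back to a single _generate call
      // internally and yields the whole response as one chunk.
      const stream = await model.stream("Print hello world.");
      for await (const chunk of stream) {
        console.log(chunk.content); // the entire response arrives in one chunk
      }
    }

    main();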