feat(core,openai): Add support for disable_streaming, set for o1 #7503

Merged: 1 commit, merged on Jan 11, 2025
18 changes: 16 additions & 2 deletions langchain-core/src/language_models/chat_models.ts
@@ -74,7 +74,18 @@ export type SerializedLLM = {
 /**
  * Represents the parameters for a base chat model.
  */
-export type BaseChatModelParams = BaseLanguageModelParams;
+export type BaseChatModelParams = BaseLanguageModelParams & {
+  /**
+   * Whether to disable streaming.
+   *
+   * If streaming is bypassed, then `stream()` will defer to
+   * `invoke()`.
+   *
+   * - If true, will always bypass streaming case.
+   * - If false (default), will always use streaming case if available.
+   */
+  disableStreaming?: boolean;
+};
 
 /**
  * Represents the call options for a base chat model.
@@ -152,6 +163,8 @@ export abstract class BaseChatModel<
   // Only ever instantiated in main LangChain
   lc_namespace = ["langchain", "chat_models", this._llmType()];
 
+  disableStreaming = false;
+
   constructor(fields: BaseChatModelParams) {
     super(fields);
   }
@@ -220,7 +233,8 @@
     // Subclass check required to avoid double callbacks with default implementation
     if (
       this._streamResponseChunks ===
-      BaseChatModel.prototype._streamResponseChunks
+        BaseChatModel.prototype._streamResponseChunks ||
+      this.disableStreaming
     ) {
       yield this.invoke(input, options);
     } else {
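
For context on the public API effect (not part of the diff): with this change, `stream()` silently degrades to a single `invoke()` call whenever `disableStreaming` is true. A minimal usage sketch, toggling the flag directly on the instance since this hunk only adds the class field (model name and prompt are placeholders):

import { ChatOpenAI } from "@langchain/openai";

const model = new ChatOpenAI({ model: "gpt-4o-mini" });
// Force the non-streaming path: stream() will defer to invoke().
model.disableStreaming = true;

const stream = await model.stream("Say hello.");
let numChunks = 0;
for await (const chunk of stream) {
  // With streaming disabled, the iterator yields exactly one chunk
  // containing the full response.
  numChunks += 1;
  console.log(chunk.content);
}
console.log(numChunks); // expected: 1
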
4 changes: 4 additions & 0 deletions libs/langchain-openai/src/chat_models.ts
@@ -1223,6 +1223,10 @@ export class ChatOpenAI<
       this.streamUsage = false;
     }
 
+    if (this.model === "o1") {
+      this.disableStreaming = true;
+    }
+
     this.streaming = fields?.streaming ?? false;
     this.streamUsage = fields?.streamUsage ?? this.streamUsage;
 
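
Illustrative caller-side effect (not part of the diff): constructing `ChatOpenAI` with `model: "o1"` now sets `disableStreaming` automatically, so existing `stream()` call sites keep working against o1, which did not support streaming at the time; they simply receive the whole response as one chunk. The prompt below is a placeholder:

import { ChatOpenAI } from "@langchain/openai";

const o1 = new ChatOpenAI({ model: "o1" });
// No special-casing needed: stream() still returns an async iterator,
// but internally the request goes through invoke() and yields one chunk.
const stream = await o1.stream("Write a haiku about type systems.");
for await (const chunk of stream) {
  console.log(chunk.content);
}
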
29 changes: 28 additions & 1 deletion libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -1165,7 +1165,7 @@ describe("Audio output", () => {
   });
 });
 
-test("Can stream o1 requests", async () => {
+test("Can stream o1-mini requests", async () => {
   const model = new ChatOpenAI({
     model: "o1-mini",
   });
@@ -1192,6 +1192,33 @@ test("Can stream o1 requests", async () => {
   expect(numChunks).toBeGreaterThan(3);
 });
 
+test("Doesn't stream o1 requests", async () => {
+  const model = new ChatOpenAI({
+    model: "o1",
+  });
+  const stream = await model.stream(
+    "Write me a very simple hello world program in Python. Ensure it is wrapped in a function called 'hello_world' and has descriptive comments."
+  );
+  let finalMsg: AIMessageChunk | undefined;
+  let numChunks = 0;
+  for await (const chunk of stream) {
+    finalMsg = finalMsg ? concat(finalMsg, chunk) : chunk;
+    numChunks += 1;
+  }
+
+  expect(finalMsg).toBeTruthy();
+  if (!finalMsg) {
+    throw new Error("No final message found");
+  }
+  if (typeof finalMsg.content === "string") {
+    expect(finalMsg.content.length).toBeGreaterThan(10);
+  } else {
+    expect(finalMsg.content.length).toBeGreaterThanOrEqual(1);
+  }
+
+  expect(numChunks).toBe(1);
+});
+
 test("Allows developer messages with o1", async () => {
   const model = new ChatOpenAI({
     model: "o1",