fix(core): Ensure that cached flag in run extras is only set for cache hits #7566

Merged 1 commit on Jan 21, 2025
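Previously, whenever a cache was configured, the run extras were stamped with `cached: true` at start time, even when the lookup missed and the model was actually invoked. This change removes the flag from the start extras and instead threads a new optional `extraParams` argument through `handleLLMEnd` / `handleLLMError`, passing `{ cached: true }` only on the cache-hit path. `BaseTracer` merges those extras into `run.extra`, and `LangChainTracer` now includes `extra` in the run update it sends.

A minimal sketch of a consumer reading the flag from the new argument (the handler name and logging are illustrative, not part of this PR):

```ts
import { BaseCallbackHandler } from "@langchain/core/callbacks/base";
import type { LLMResult } from "@langchain/core/outputs";

// Illustrative handler: reports whether a finished LLM run was served from cache.
class CacheAwareHandler extends BaseCallbackHandler {
  name = "cache_aware_handler";

  async handleLLMEnd(
    _output: LLMResult,
    runId: string,
    _parentRunId?: string,
    _tags?: string[],
    extraParams?: Record<string, unknown>
  ): Promise<void> {
    // After this change, `cached` is only attached for cache hits.
    console.log(`run ${runId} cached:`, extraParams?.cached === true);
  }
}
```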
6 changes: 4 additions & 2 deletions langchain-core/src/callbacks/base.ts
@@ -97,7 +97,8 @@ abstract class BaseCallbackHandlerMethodsClass {
     err: Error,
     runId: string,
     parentRunId?: string,
-    tags?: string[]
+    tags?: string[],
+    extraParams?: Record<string, unknown>
   ): // eslint-disable-next-line @typescript-eslint/no-explicit-any
   Promise<any> | any;

@@ -108,7 +109,8 @@ abstract class BaseCallbackHandlerMethodsClass {
     output: LLMResult,
     runId: string,
     parentRunId?: string,
-    tags?: string[]
+    tags?: string[],
+    extraParams?: Record<string, unknown>
   ): // eslint-disable-next-line @typescript-eslint/no-explicit-any
   Promise<any> | any;
22 changes: 18 additions & 4 deletions langchain-core/src/callbacks/manager.ts
@@ -303,7 +303,13 @@ export class CallbackManagerForLLMRun
     );
   }

-  async handleLLMError(err: Error | unknown): Promise<void> {
+  async handleLLMError(
+    err: Error | unknown,
+    _runId?: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<void> {
     await Promise.all(
       this.handlers.map((handler) =>
         consumeCallback(async () => {
@@ -313,7 +319,8 @@ export class CallbackManagerForLLMRun
               err,
               this.runId,
               this._parentRunId,
-              this.tags
+              this.tags,
+              extraParams
             );
           } catch (err) {
             const logFunction = handler.raiseError
@@ -332,7 +339,13 @@ export class CallbackManagerForLLMRun
     );
   }

-  async handleLLMEnd(output: LLMResult): Promise<void> {
+  async handleLLMEnd(
+    output: LLMResult,
+    _runId?: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<void> {
     await Promise.all(
       this.handlers.map((handler) =>
         consumeCallback(async () => {
@@ -342,7 +355,8 @@ export class CallbackManagerForLLMRun
               output,
               this.runId,
               this._parentRunId,
-              this.tags
+              this.tags,
+              extraParams
             );
           } catch (err) {
             const logFunction = handler.raiseError
25 changes: 20 additions & 5 deletions langchain-core/src/language_models/chat_models.ts
@@ -550,7 +550,6 @@ export abstract class BaseChatModel<
       options: parsedOptions,
       invocation_params: this?.invocationParams(parsedOptions),
       batch_size: 1,
-      cached: true,
     };
     const runManagers = await callbackManager_?.handleChatModelStart(
       this.toJSON(),
@@ -619,12 +618,28 @@
         if (result.length) {
           await runManager?.handleLLMNewToken(result[0].text);
         }
-        return runManager?.handleLLMEnd({
-          generations: [result],
-        });
+        return runManager?.handleLLMEnd(
+          {
+            generations: [result],
+          },
+          undefined,
+          undefined,
+          undefined,
+          {
+            cached: true,
+          }
+        );
       } else {
         // status === "rejected"
-        await runManager?.handleLLMError(promiseResult.reason);
+        await runManager?.handleLLMError(
+          promiseResult.reason,
+          undefined,
+          undefined,
+          undefined,
+          {
+            cached: true,
+          }
+        );
         return Promise.reject(promiseResult.reason);
       }
     })
25 changes: 20 additions & 5 deletions langchain-core/src/language_models/llms.ts
@@ -374,7 +374,6 @@ export abstract class BaseLLM<
       options: parsedOptions,
       invocation_params: this?.invocationParams(parsedOptions),
       batch_size: prompts.length,
-      cached: true,
     };
     const runManagers = await callbackManager_?.handleLLMStart(
       this.toJSON(),
@@ -426,12 +425,28 @@
         if (result.length) {
           await runManager?.handleLLMNewToken(result[0].text);
         }
-        return runManager?.handleLLMEnd({
-          generations: [result],
-        });
+        return runManager?.handleLLMEnd(
+          {
+            generations: [result],
+          },
+          undefined,
+          undefined,
+          undefined,
+          {
+            cached: true,
+          }
+        );
       } else {
         // status === "rejected"
-        await runManager?.handleLLMError(promiseResult.reason);
+        await runManager?.handleLLMError(
+          promiseResult.reason,
+          undefined,
+          undefined,
+          undefined,
+          {
+            cached: true,
+          }
+        );
         return Promise.reject(promiseResult.reason);
       }
     })
15 changes: 13 additions & 2 deletions langchain-core/src/language_models/tests/chat_models.test.ts
@@ -7,6 +7,7 @@ import { FakeChatModel, FakeListChatModel } from "../../utils/testing/index.js";
 import { HumanMessage } from "../../messages/human.js";
 import { getBufferString } from "../../messages/utils.js";
 import { AIMessage } from "../../messages/ai.js";
+import { RunCollectorCallbackHandler } from "../../tracers/run_collector.js";

 test("Test ChatModel accepts array shorthand for messages", async () => {
   const model = new FakeChatModel({});
@@ -311,8 +312,13 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   const value = await model.cache.lookup(prompt, llmKey);
   expect(value).toBeNull();

+  const runCollector = new RunCollectorCallbackHandler();
+
   // Invoke model to trigger cache update
-  const eventStream = model.streamEvents([humanMessage], { version: "v2" });
+  const eventStream = model.streamEvents([humanMessage], {
+    version: "v2",
+    callbacks: [runCollector],
+  });

   expect(await model.cache.lookup(prompt, llmKey)).toBeDefined();

@@ -323,8 +329,12 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   expect(events.length).toEqual(2);
   expect(events[0].event).toEqual("on_chat_model_start");
   expect(events[1].event).toEqual("on_chat_model_end");
+  expect(runCollector.tracedRuns[0].extra?.cached).not.toBe(true);

-  const eventStream2 = model.streamEvents([humanMessage], { version: "v2" });
+  const eventStream2 = model.streamEvents([humanMessage], {
+    version: "v2",
+    callbacks: [runCollector],
+  });

   const events2 = [];
   for await (const event of eventStream2) {
@@ -333,6 +343,7 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   expect(events2.length).toEqual(2);
   expect(events2[0].event).toEqual("on_chat_model_start");
   expect(events2[1].event).toEqual("on_chat_model_end");
+  expect(runCollector.tracedRuns[1].extra?.cached).toBe(true);
 });

 test("Test ChatModel can emit a custom event", async () => {
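The same check works outside the test suite. A small sketch using the fake chat model and run collector from `@langchain/core` (prompt and response text are arbitrary, and the calls assume an async context):

```ts
import { FakeListChatModel } from "@langchain/core/utils/testing";
import { RunCollectorCallbackHandler } from "@langchain/core/tracers/run_collector";

// Two identical invocations against a cached model; only the second run
// should carry `cached: true` in its extras after this change.
const model = new FakeListChatModel({ responses: ["hello"], cache: true });
const runCollector = new RunCollectorCallbackHandler();

await model.invoke("Hi there", { callbacks: [runCollector] }); // cache miss
await model.invoke("Hi there", { callbacks: [runCollector] }); // cache hit

console.log(runCollector.tracedRuns[0].extra?.cached); // undefined (cache miss)
console.log(runCollector.tracedRuns[1].extra?.cached); // true (cache hit)
```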
15 changes: 13 additions & 2 deletions langchain-core/src/language_models/tests/llms.test.ts
@@ -3,6 +3,7 @@
 import { test, expect } from "@jest/globals";
 import { FakeLLM, FakeStreamingLLM } from "../../utils/testing/index.js";
 import { HumanMessagePromptTemplate } from "../../prompts/chat.js";
+import { RunCollectorCallbackHandler } from "../../tracers/run_collector.js";

 test("Test FakeLLM uses callbacks", async () => {
   const model = new FakeLLM({});
@@ -50,8 +51,13 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
     throw new Error("Cache not enabled");
   }

+  const runCollector = new RunCollectorCallbackHandler();
+
   // Invoke model to trigger cache update
-  const eventStream = model.streamEvents("Hello there!", { version: "v2" });
+  const eventStream = model.streamEvents("Hello there!", {
+    version: "v2",
+    callbacks: [runCollector],
+  });

   const events = [];
   for await (const event of eventStream) {
@@ -60,8 +66,12 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
   expect(events.length).toEqual(2);
   expect(events[0].event).toEqual("on_llm_start");
   expect(events[1].event).toEqual("on_llm_end");
+  expect(runCollector.tracedRuns[0].extra?.cached).not.toBe(true);

-  const eventStream2 = model.streamEvents("Hello there!", { version: "v2" });
+  const eventStream2 = model.streamEvents("Hello there!", {
+    version: "v2",
+    callbacks: [runCollector],
+  });

   const events2 = [];
   for await (const event of eventStream2) {
@@ -70,6 +80,7 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
   expect(events2.length).toEqual(2);
   expect(events2[0].event).toEqual("on_llm_start");
   expect(events2[1].event).toEqual("on_llm_end");
+  expect(runCollector.tracedRuns[1].extra?.cached).toBe(true);
 });

 test("Test FakeStreamingLLM works when streaming through a prompt", async () => {
18 changes: 16 additions & 2 deletions langchain-core/src/tracers/base.ts
@@ -298,7 +298,13 @@ export abstract class BaseTracer extends BaseCallbackHandler {
     return run;
   }

-  async handleLLMEnd(output: LLMResult, runId: string): Promise<Run> {
+  async handleLLMEnd(
+    output: LLMResult,
+    runId: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<Run> {
     const run = this.runMap.get(runId);
     if (!run || run?.run_type !== "llm") {
       throw new Error("No LLM run to end.");
@@ -309,12 +315,19 @@ export abstract class BaseTracer extends BaseCallbackHandler {
       name: "end",
       time: new Date(run.end_time).toISOString(),
     });
+    run.extra = { ...run.extra, ...extraParams };
     await this.onLLMEnd?.(run);
     await this._endTrace(run);
     return run;
   }

-  async handleLLMError(error: unknown, runId: string): Promise<Run> {
+  async handleLLMError(
+    error: unknown,
+    runId: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<Run> {
     const run = this.runMap.get(runId);
     if (!run || run?.run_type !== "llm") {
       throw new Error("No LLM run to end.");
@@ -325,6 +338,7 @@ export abstract class BaseTracer extends BaseCallbackHandler {
       name: "error",
       time: new Date(run.end_time).toISOString(),
     });
+    run.extra = { ...run.extra, ...extraParams };
     await this.onLLMError?.(run);
     await this._endTrace(run);
     return run;
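Because the extras are merged into `run.extra` before the `onLLMEnd` / `onLLMError` hooks fire, any `BaseTracer` subclass can read the flag off the finished run. A rough sketch, assuming the standard `BaseTracer` contract of a `name` and a `persistRun` implementation (the tracer itself is illustrative, not part of this PR):

```ts
import { BaseTracer, type Run } from "@langchain/core/tracers/base";

// Illustrative tracer: collects the IDs of LLM runs that were served from cache.
class CacheHitTracer extends BaseTracer {
  name = "cache_hit_tracer";
  cacheHitRunIds: string[] = [];

  protected async persistRun(_run: Run): Promise<void> {
    // No persistence needed for this sketch.
  }

  async onLLMEnd(run: Run): Promise<void> {
    // `extraParams` from handleLLMEnd has already been merged into run.extra here.
    if (run.extra?.cached === true) {
      this.cacheHitRunIds.push(run.id);
    }
  }
}
```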
1 change: 1 addition & 0 deletions langchain-core/src/tracers/tracer_langchain.ts
@@ -104,6 +104,7 @@ export class LangChainTracer
       trace_id: run.trace_id,
       dotted_order: run.dotted_order,
       parent_run_id: run.parent_run_id,
+      extra: run.extra,
     };
     await this.client.updateRun(run.id, runUpdate);
   }