Commit

fix(core): Ensure that cached flag in run extras is only set for cache hits (#7566)
jacoblee93 authored Jan 21, 2025
1 parent b6007bb commit e0fc2a4
Showing 8 changed files with 105 additions and 22 deletions.
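For handler authors, the practical effect of this change is that the `cached` flag now arrives through a new optional `extraParams` argument on `handleLLMEnd`/`handleLLMError`, and it is only populated on actual cache hits. A minimal sketch (not part of this commit; it assumes the public `@langchain/core/callbacks/base` and `@langchain/core/outputs` entrypoints, and `CacheAwareHandler` is a hypothetical name):

```ts
import { BaseCallbackHandler } from "@langchain/core/callbacks/base";
import type { LLMResult } from "@langchain/core/outputs";

// Hypothetical handler used only to illustrate the new callback signature.
class CacheAwareHandler extends BaseCallbackHandler {
  name = "cache_aware_handler";

  async handleLLMEnd(
    _output: LLMResult,
    runId: string,
    _parentRunId?: string,
    _tags?: string[],
    extraParams?: Record<string, unknown>
  ) {
    // After this fix, `cached` is present (and true) only when the result
    // was served from the model's cache, not on every run.
    if (extraParams?.cached === true) {
      console.log(`Run ${runId} was resolved from cache`);
    }
  }
}
```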
6 changes: 4 additions & 2 deletions langchain-core/src/callbacks/base.ts
@@ -97,7 +97,8 @@ abstract class BaseCallbackHandlerMethodsClass {
     err: Error,
     runId: string,
     parentRunId?: string,
-    tags?: string[]
+    tags?: string[],
+    extraParams?: Record<string, unknown>
   ): // eslint-disable-next-line @typescript-eslint/no-explicit-any
   Promise<any> | any;

@@ -108,7 +109,8 @@ abstract class BaseCallbackHandlerMethodsClass {
     output: LLMResult,
     runId: string,
     parentRunId?: string,
-    tags?: string[]
+    tags?: string[],
+    extraParams?: Record<string, unknown>
   ): // eslint-disable-next-line @typescript-eslint/no-explicit-any
   Promise<any> | any;

22 changes: 18 additions & 4 deletions langchain-core/src/callbacks/manager.ts
@@ -303,7 +303,13 @@ export class CallbackManagerForLLMRun
     );
   }

-  async handleLLMError(err: Error | unknown): Promise<void> {
+  async handleLLMError(
+    err: Error | unknown,
+    _runId?: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<void> {
     await Promise.all(
       this.handlers.map((handler) =>
         consumeCallback(async () => {
@@ -313,7 +319,8 @@
                 err,
                 this.runId,
                 this._parentRunId,
-                this.tags
+                this.tags,
+                extraParams
               );
             } catch (err) {
               const logFunction = handler.raiseError
@@ -332,7 +339,13 @@
     );
   }

-  async handleLLMEnd(output: LLMResult): Promise<void> {
+  async handleLLMEnd(
+    output: LLMResult,
+    _runId?: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<void> {
     await Promise.all(
       this.handlers.map((handler) =>
         consumeCallback(async () => {
@@ -342,7 +355,8 @@
                 output,
                 this.runId,
                 this._parentRunId,
-                this.tags
+                this.tags,
+                extraParams
               );
             } catch (err) {
               const logFunction = handler.raiseError
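The per-run manager above simply fans `extraParams` out to every registered handler; the positional run id, parent run id, and tags are ignored because the manager already carries its own. A hedged sketch of the call pattern the model classes below use (`endCachedRun` is a hypothetical helper, not code from this commit):

```ts
import type { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import type { LLMResult } from "@langchain/core/outputs";

// Hypothetical helper showing how a cache hit is reported after this change.
async function endCachedRun(
  runManager: CallbackManagerForLLMRun | undefined,
  output: LLMResult
) {
  // The three `undefined`s fill the unused _runId/_parentRunId/_tags slots;
  // only `extraParams` is meaningful when calling a per-run manager.
  return runManager?.handleLLMEnd(output, undefined, undefined, undefined, {
    cached: true,
  });
}
```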
25 changes: 20 additions & 5 deletions langchain-core/src/language_models/chat_models.ts
@@ -550,7 +550,6 @@ export abstract class BaseChatModel<
       options: parsedOptions,
       invocation_params: this?.invocationParams(parsedOptions),
       batch_size: 1,
-      cached: true,
     };
     const runManagers = await callbackManager_?.handleChatModelStart(
       this.toJSON(),
@@ -619,12 +618,28 @@
           if (result.length) {
             await runManager?.handleLLMNewToken(result[0].text);
           }
-          return runManager?.handleLLMEnd({
-            generations: [result],
-          });
+          return runManager?.handleLLMEnd(
+            {
+              generations: [result],
+            },
+            undefined,
+            undefined,
+            undefined,
+            {
+              cached: true,
+            }
+          );
         } else {
           // status === "rejected"
-          await runManager?.handleLLMError(promiseResult.reason);
+          await runManager?.handleLLMError(
+            promiseResult.reason,
+            undefined,
+            undefined,
+            undefined,
+            {
+              cached: true,
+            }
+          );
           return Promise.reject(promiseResult.reason);
         }
       })
25 changes: 20 additions & 5 deletions langchain-core/src/language_models/llms.ts
@@ -374,7 +374,6 @@ export abstract class BaseLLM<
       options: parsedOptions,
       invocation_params: this?.invocationParams(parsedOptions),
       batch_size: prompts.length,
-      cached: true,
     };
     const runManagers = await callbackManager_?.handleLLMStart(
       this.toJSON(),
@@ -426,12 +425,28 @@
           if (result.length) {
             await runManager?.handleLLMNewToken(result[0].text);
           }
-          return runManager?.handleLLMEnd({
-            generations: [result],
-          });
+          return runManager?.handleLLMEnd(
+            {
+              generations: [result],
+            },
+            undefined,
+            undefined,
+            undefined,
+            {
+              cached: true,
+            }
+          );
         } else {
           // status === "rejected"
-          await runManager?.handleLLMError(promiseResult.reason);
+          await runManager?.handleLLMError(
+            promiseResult.reason,
+            undefined,
+            undefined,
+            undefined,
+            {
+              cached: true,
+            }
+          );
           return Promise.reject(promiseResult.reason);
         }
       })
15 changes: 13 additions & 2 deletions langchain-core/src/language_models/tests/chat_models.test.ts
@@ -7,6 +7,7 @@ import { FakeChatModel, FakeListChatModel } from "../../utils/testing/index.js";
 import { HumanMessage } from "../../messages/human.js";
 import { getBufferString } from "../../messages/utils.js";
 import { AIMessage } from "../../messages/ai.js";
+import { RunCollectorCallbackHandler } from "../../tracers/run_collector.js";

 test("Test ChatModel accepts array shorthand for messages", async () => {
   const model = new FakeChatModel({});
@@ -311,8 +312,13 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   const value = await model.cache.lookup(prompt, llmKey);
   expect(value).toBeNull();

+  const runCollector = new RunCollectorCallbackHandler();
+
   // Invoke model to trigger cache update
-  const eventStream = model.streamEvents([humanMessage], { version: "v2" });
+  const eventStream = model.streamEvents([humanMessage], {
+    version: "v2",
+    callbacks: [runCollector],
+  });

   expect(await model.cache.lookup(prompt, llmKey)).toBeDefined();

@@ -323,8 +329,12 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   expect(events.length).toEqual(2);
   expect(events[0].event).toEqual("on_chat_model_start");
   expect(events[1].event).toEqual("on_chat_model_end");
+  expect(runCollector.tracedRuns[0].extra?.cached).not.toBe(true);

-  const eventStream2 = model.streamEvents([humanMessage], { version: "v2" });
+  const eventStream2 = model.streamEvents([humanMessage], {
+    version: "v2",
+    callbacks: [runCollector],
+  });

   const events2 = [];
   for await (const event of eventStream2) {
@@ -333,6 +343,7 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   expect(events2.length).toEqual(2);
   expect(events2[0].event).toEqual("on_chat_model_start");
   expect(events2[1].event).toEqual("on_chat_model_end");
+  expect(runCollector.tracedRuns[1].extra?.cached).toBe(true);
 });

 test("Test ChatModel can emit a custom event", async () => {
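Outside of `streamEvents`, the same behavior can be observed with a plain `invoke` call. A hedged end-to-end sketch using the same testing fakes as the test above (assumes an ES module context for top-level `await`; the expected values follow from this commit's behavior):

```ts
import { FakeListChatModel } from "@langchain/core/utils/testing";
import { InMemoryCache } from "@langchain/core/caches";
import { RunCollectorCallbackHandler } from "@langchain/core/tracers/run_collector";

const runCollector = new RunCollectorCallbackHandler();
const model = new FakeListChatModel({
  responses: ["Hello!"],
  cache: new InMemoryCache(),
});

await model.invoke("Hi", { callbacks: [runCollector] }); // miss: populates the cache
await model.invoke("Hi", { callbacks: [runCollector] }); // hit: served from the cache

console.log(runCollector.tracedRuns[0].extra?.cached); // expected: undefined (not a cache hit)
console.log(runCollector.tracedRuns[1].extra?.cached); // expected: true
```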
15 changes: 13 additions & 2 deletions langchain-core/src/language_models/tests/llms.test.ts
@@ -3,6 +3,7 @@
 import { test, expect } from "@jest/globals";
 import { FakeLLM, FakeStreamingLLM } from "../../utils/testing/index.js";
 import { HumanMessagePromptTemplate } from "../../prompts/chat.js";
+import { RunCollectorCallbackHandler } from "../../tracers/run_collector.js";

 test("Test FakeLLM uses callbacks", async () => {
   const model = new FakeLLM({});
@@ -50,8 +51,13 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
     throw new Error("Cache not enabled");
   }

+  const runCollector = new RunCollectorCallbackHandler();
+
   // Invoke model to trigger cache update
-  const eventStream = model.streamEvents("Hello there!", { version: "v2" });
+  const eventStream = model.streamEvents("Hello there!", {
+    version: "v2",
+    callbacks: [runCollector],
+  });

   const events = [];
   for await (const event of eventStream) {
@@ -60,8 +66,12 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
   expect(events.length).toEqual(2);
   expect(events[0].event).toEqual("on_llm_start");
   expect(events[1].event).toEqual("on_llm_end");
+  expect(runCollector.tracedRuns[0].extra?.cached).not.toBe(true);

-  const eventStream2 = model.streamEvents("Hello there!", { version: "v2" });
+  const eventStream2 = model.streamEvents("Hello there!", {
+    version: "v2",
+    callbacks: [runCollector],
+  });

   const events2 = [];
   for await (const event of eventStream2) {
@@ -70,6 +80,7 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
   expect(events2.length).toEqual(2);
   expect(events2[0].event).toEqual("on_llm_start");
   expect(events2[1].event).toEqual("on_llm_end");
+  expect(runCollector.tracedRuns[1].extra?.cached).toBe(true);
 });

 test("Test FakeStreamingLLM works when streaming through a prompt", async () => {
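The plain-LLM path behaves the same way; a mirror of the chat-model sketch above using `FakeLLM` (again a hedged illustration, not part of this commit):

```ts
import { FakeLLM } from "@langchain/core/utils/testing";
import { InMemoryCache } from "@langchain/core/caches";
import { RunCollectorCallbackHandler } from "@langchain/core/tracers/run_collector";

const runCollector = new RunCollectorCallbackHandler();
const llm = new FakeLLM({ cache: new InMemoryCache() });

await llm.invoke("Hello there!", { callbacks: [runCollector] }); // cache miss
await llm.invoke("Hello there!", { callbacks: [runCollector] }); // cache hit

// Only the second traced run should carry the cached flag.
console.log(runCollector.tracedRuns.map((run) => run.extra?.cached));
// expected: [ undefined, true ]
```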
18 changes: 16 additions & 2 deletions langchain-core/src/tracers/base.ts
@@ -298,7 +298,13 @@ export abstract class BaseTracer extends BaseCallbackHandler {
     return run;
   }

-  async handleLLMEnd(output: LLMResult, runId: string): Promise<Run> {
+  async handleLLMEnd(
+    output: LLMResult,
+    runId: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<Run> {
     const run = this.runMap.get(runId);
     if (!run || run?.run_type !== "llm") {
       throw new Error("No LLM run to end.");
@@ -309,12 +315,19 @@
       name: "end",
       time: new Date(run.end_time).toISOString(),
     });
+    run.extra = { ...run.extra, ...extraParams };
     await this.onLLMEnd?.(run);
     await this._endTrace(run);
     return run;
   }

-  async handleLLMError(error: unknown, runId: string): Promise<Run> {
+  async handleLLMError(
+    error: unknown,
+    runId: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<Run> {
     const run = this.runMap.get(runId);
     if (!run || run?.run_type !== "llm") {
       throw new Error("No LLM run to end.");
@@ -325,6 +338,7 @@
       name: "error",
       time: new Date(run.end_time).toISOString(),
     });
+    run.extra = { ...run.extra, ...extraParams };
     await this.onLLMError?.(run);
     await this._endTrace(run);
     return run;
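Because `BaseTracer` now spreads `extraParams` into `run.extra` before calling its `onLLMEnd`/`onLLMError` hooks, custom tracers see the flag without any start-time bookkeeping. A minimal sketch (the `CacheLoggingTracer` name and behavior are illustrative assumptions):

```ts
import { BaseTracer, type Run } from "@langchain/core/tracers/base";

class CacheLoggingTracer extends BaseTracer {
  name = "cache_logging_tracer";

  // BaseTracer requires persistRun; this sketch simply discards the run.
  protected async persistRun(_run: Run): Promise<void> {}

  async onLLMEnd(run: Run): Promise<void> {
    // `run.extra.cached` is merged in by handleLLMEnd only on cache hits.
    if (run.extra?.cached === true) {
      console.log(`LLM run ${run.id} was served from cache`);
    }
  }
}
```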
1 change: 1 addition & 0 deletions langchain-core/src/tracers/tracer_langchain.ts
@@ -104,6 +104,7 @@ export class LangChainTracer
       trace_id: run.trace_id,
       dotted_order: run.dotted_order,
       parent_run_id: run.parent_run_id,
+      extra: run.extra,
     };
     await this.client.updateRun(run.id, runUpdate);
   }
