From e0fc2a41e9b09da301a96075956a44863ff827cc Mon Sep 17 00:00:00 2001
From: Jacob Lee
Date: Tue, 21 Jan 2025 10:01:22 -0800
Subject: [PATCH] fix(core): Ensure that cached flag in run extras is only set for cache hits (#7566)

---
 langchain-core/src/callbacks/base.ts          |  6 +++--
 langchain-core/src/callbacks/manager.ts       | 22 +++++++++++++---
 .../src/language_models/chat_models.ts        | 25 +++++++++++++++----
 langchain-core/src/language_models/llms.ts    | 25 +++++++++++++++----
 .../language_models/tests/chat_models.test.ts | 15 +++++++++--
 .../src/language_models/tests/llms.test.ts    | 15 +++++++++--
 langchain-core/src/tracers/base.ts            | 18 +++++++++++--
 .../src/tracers/tracer_langchain.ts           |  1 +
 8 files changed, 105 insertions(+), 22 deletions(-)

diff --git a/langchain-core/src/callbacks/base.ts b/langchain-core/src/callbacks/base.ts
index 923dde568093..2dad82102fb9 100644
--- a/langchain-core/src/callbacks/base.ts
+++ b/langchain-core/src/callbacks/base.ts
@@ -97,7 +97,8 @@ abstract class BaseCallbackHandlerMethodsClass {
     err: Error,
     runId: string,
     parentRunId?: string,
-    tags?: string[]
+    tags?: string[],
+    extraParams?: Record<string, unknown>
   ):
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     Promise<any> | any;
@@ -108,7 +109,8 @@ abstract class BaseCallbackHandlerMethodsClass {
     output: LLMResult,
     runId: string,
     parentRunId?: string,
-    tags?: string[]
+    tags?: string[],
+    extraParams?: Record<string, unknown>
   ):
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     Promise<any> | any;
diff --git a/langchain-core/src/callbacks/manager.ts b/langchain-core/src/callbacks/manager.ts
index 55ad0484e480..9ca2cabd39c8 100644
--- a/langchain-core/src/callbacks/manager.ts
+++ b/langchain-core/src/callbacks/manager.ts
@@ -303,7 +303,13 @@ export class CallbackManagerForLLMRun
     );
   }
 
-  async handleLLMError(err: Error | unknown): Promise<void> {
+  async handleLLMError(
+    err: Error | unknown,
+    _runId?: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<void> {
     await Promise.all(
       this.handlers.map((handler) =>
         consumeCallback(async () => {
@@ -313,7 +319,8 @@ export class CallbackManagerForLLMRun
                 err,
                 this.runId,
                 this._parentRunId,
-                this.tags
+                this.tags,
+                extraParams
               );
             } catch (err) {
               const logFunction = handler.raiseError
@@ -332,7 +339,13 @@ export class CallbackManagerForLLMRun
     );
   }
 
-  async handleLLMEnd(output: LLMResult): Promise<void> {
+  async handleLLMEnd(
+    output: LLMResult,
+    _runId?: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<void> {
     await Promise.all(
       this.handlers.map((handler) =>
         consumeCallback(async () => {
@@ -342,7 +355,8 @@ export class CallbackManagerForLLMRun
                 output,
                 this.runId,
                 this._parentRunId,
-                this.tags
+                this.tags,
+                extraParams
               );
             } catch (err) {
               const logFunction = handler.raiseError
diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts
index 36feee110abe..4e4fdb548c52 100644
--- a/langchain-core/src/language_models/chat_models.ts
+++ b/langchain-core/src/language_models/chat_models.ts
@@ -550,7 +550,6 @@ export abstract class BaseChatModel<
       options: parsedOptions,
       invocation_params: this?.invocationParams(parsedOptions),
       batch_size: 1,
-      cached: true,
     };
     const runManagers = await callbackManager_?.handleChatModelStart(
       this.toJSON(),
@@ -619,12 +618,28 @@ export abstract class BaseChatModel<
           if (result.length) {
            await runManager?.handleLLMNewToken(result[0].text);
           }
-          return runManager?.handleLLMEnd({
-            generations: [result],
-          });
+          return runManager?.handleLLMEnd(
+            {
+              generations: [result],
+            },
+            undefined,
+            undefined,
+            undefined,
+            {
+              cached: true,
+            }
+          );
         } else {
           // status === "rejected"
-          await runManager?.handleLLMError(promiseResult.reason);
+          await runManager?.handleLLMError(
+            promiseResult.reason,
+            undefined,
+            undefined,
+            undefined,
+            {
+              cached: true,
+            }
+          );
           return Promise.reject(promiseResult.reason);
         }
       })
diff --git a/langchain-core/src/language_models/llms.ts b/langchain-core/src/language_models/llms.ts
index 63e18cb9a0b3..fcb8228adf6a 100644
--- a/langchain-core/src/language_models/llms.ts
+++ b/langchain-core/src/language_models/llms.ts
@@ -374,7 +374,6 @@ export abstract class BaseLLM<
       options: parsedOptions,
       invocation_params: this?.invocationParams(parsedOptions),
       batch_size: prompts.length,
-      cached: true,
     };
     const runManagers = await callbackManager_?.handleLLMStart(
       this.toJSON(),
@@ -426,12 +425,28 @@ export abstract class BaseLLM<
           if (result.length) {
             await runManager?.handleLLMNewToken(result[0].text);
           }
-          return runManager?.handleLLMEnd({
-            generations: [result],
-          });
+          return runManager?.handleLLMEnd(
+            {
+              generations: [result],
+            },
+            undefined,
+            undefined,
+            undefined,
+            {
+              cached: true,
+            }
+          );
         } else {
           // status === "rejected"
-          await runManager?.handleLLMError(promiseResult.reason);
+          await runManager?.handleLLMError(
+            promiseResult.reason,
+            undefined,
+            undefined,
+            undefined,
+            {
+              cached: true,
+            }
+          );
           return Promise.reject(promiseResult.reason);
         }
       })
diff --git a/langchain-core/src/language_models/tests/chat_models.test.ts b/langchain-core/src/language_models/tests/chat_models.test.ts
index 8598d7aa6cd3..3a8e8fc432c3 100644
--- a/langchain-core/src/language_models/tests/chat_models.test.ts
+++ b/langchain-core/src/language_models/tests/chat_models.test.ts
@@ -7,6 +7,7 @@ import { FakeChatModel, FakeListChatModel } from "../../utils/testing/index.js";
 import { HumanMessage } from "../../messages/human.js";
 import { getBufferString } from "../../messages/utils.js";
 import { AIMessage } from "../../messages/ai.js";
+import { RunCollectorCallbackHandler } from "../../tracers/run_collector.js";
 
 test("Test ChatModel accepts array shorthand for messages", async () => {
   const model = new FakeChatModel({});
@@ -311,8 +312,13 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   const value = await model.cache.lookup(prompt, llmKey);
   expect(value).toBeNull();
 
+  const runCollector = new RunCollectorCallbackHandler();
+
   // Invoke model to trigger cache update
-  const eventStream = model.streamEvents([humanMessage], { version: "v2" });
+  const eventStream = model.streamEvents([humanMessage], {
+    version: "v2",
+    callbacks: [runCollector],
+  });
 
   expect(await model.cache.lookup(prompt, llmKey)).toBeDefined();
 
@@ -323,8 +329,12 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   expect(events.length).toEqual(2);
   expect(events[0].event).toEqual("on_chat_model_start");
   expect(events[1].event).toEqual("on_chat_model_end");
+  expect(runCollector.tracedRuns[0].extra?.cached).not.toBe(true);
 
-  const eventStream2 = model.streamEvents([humanMessage], { version: "v2" });
+  const eventStream2 = model.streamEvents([humanMessage], {
+    version: "v2",
+    callbacks: [runCollector],
+  });
 
   const events2 = [];
   for await (const event of eventStream2) {
@@ -333,6 +343,7 @@ test("Test ChatModel with cache does not start multiple chat model runs", async
   expect(events2.length).toEqual(2);
   expect(events2[0].event).toEqual("on_chat_model_start");
   expect(events2[1].event).toEqual("on_chat_model_end");
+  expect(runCollector.tracedRuns[1].extra?.cached).toBe(true);
 });
 
 test("Test ChatModel can emit a custom event", async () => {
diff --git a/langchain-core/src/language_models/tests/llms.test.ts b/langchain-core/src/language_models/tests/llms.test.ts
index f1cf453bfc75..d8b3e146d07b 100644
--- a/langchain-core/src/language_models/tests/llms.test.ts
+++ b/langchain-core/src/language_models/tests/llms.test.ts
@@ -3,6 +3,7 @@ import { test, expect } from "@jest/globals";
 
 import { FakeLLM, FakeStreamingLLM } from "../../utils/testing/index.js";
 import { HumanMessagePromptTemplate } from "../../prompts/chat.js";
+import { RunCollectorCallbackHandler } from "../../tracers/run_collector.js";
 
 test("Test FakeLLM uses callbacks", async () => {
   const model = new FakeLLM({});
@@ -50,8 +51,13 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
     throw new Error("Cache not enabled");
   }
 
+  const runCollector = new RunCollectorCallbackHandler();
+
   // Invoke model to trigger cache update
-  const eventStream = model.streamEvents("Hello there!", { version: "v2" });
+  const eventStream = model.streamEvents("Hello there!", {
+    version: "v2",
+    callbacks: [runCollector],
+  });
 
   const events = [];
   for await (const event of eventStream) {
@@ -60,8 +66,12 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
   expect(events.length).toEqual(2);
   expect(events[0].event).toEqual("on_llm_start");
   expect(events[1].event).toEqual("on_llm_end");
+  expect(runCollector.tracedRuns[0].extra?.cached).not.toBe(true);
 
-  const eventStream2 = model.streamEvents("Hello there!", { version: "v2" });
+  const eventStream2 = model.streamEvents("Hello there!", {
+    version: "v2",
+    callbacks: [runCollector],
+  });
 
   const events2 = [];
   for await (const event of eventStream2) {
@@ -70,6 +80,7 @@ test("Test LLM with cache does not start multiple LLM runs", async () => {
   expect(events2.length).toEqual(2);
   expect(events2[0].event).toEqual("on_llm_start");
   expect(events2[1].event).toEqual("on_llm_end");
+  expect(runCollector.tracedRuns[1].extra?.cached).toBe(true);
 });
 
 test("Test FakeStreamingLLM works when streaming through a prompt", async () => {
diff --git a/langchain-core/src/tracers/base.ts b/langchain-core/src/tracers/base.ts
index c9eb56821a6c..d839992efd78 100644
--- a/langchain-core/src/tracers/base.ts
+++ b/langchain-core/src/tracers/base.ts
@@ -298,7 +298,13 @@ export abstract class BaseTracer extends BaseCallbackHandler {
     return run;
   }
 
-  async handleLLMEnd(output: LLMResult, runId: string): Promise<Run> {
+  async handleLLMEnd(
+    output: LLMResult,
+    runId: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<Run> {
     const run = this.runMap.get(runId);
     if (!run || run?.run_type !== "llm") {
       throw new Error("No LLM run to end.");
@@ -309,12 +315,19 @@ export abstract class BaseTracer extends BaseCallbackHandler {
       name: "end",
       time: new Date(run.end_time).toISOString(),
     });
+    run.extra = { ...run.extra, ...extraParams };
     await this.onLLMEnd?.(run);
     await this._endTrace(run);
     return run;
   }
 
-  async handleLLMError(error: unknown, runId: string): Promise<Run> {
+  async handleLLMError(
+    error: unknown,
+    runId: string,
+    _parentRunId?: string,
+    _tags?: string[],
+    extraParams?: Record<string, unknown>
+  ): Promise<Run> {
     const run = this.runMap.get(runId);
     if (!run || run?.run_type !== "llm") {
       throw new Error("No LLM run to end.");
@@ -325,6 +338,7 @@ export abstract class BaseTracer extends BaseCallbackHandler {
       name: "error",
       time: new Date(run.end_time).toISOString(),
     });
+    run.extra = { ...run.extra, ...extraParams };
     await this.onLLMError?.(run);
     await this._endTrace(run);
     return run;
diff --git a/langchain-core/src/tracers/tracer_langchain.ts b/langchain-core/src/tracers/tracer_langchain.ts
index 8a58a8e8b119..71c6ff0fa0dc 100644
--- a/langchain-core/src/tracers/tracer_langchain.ts
+++ b/langchain-core/src/tracers/tracer_langchain.ts
@@ -104,6 +104,7 @@ export class LangChainTracer
       trace_id: run.trace_id,
       dotted_order: run.dotted_order,
       parent_run_id: run.parent_run_id,
+      extra: run.extra,
     };
     await this.client.updateRun(run.id, runUpdate);
   }