diff --git a/langchain-core/src/prompts/chat.ts b/langchain-core/src/prompts/chat.ts index 0fcb0330489b..0d32610fbef0 100644 --- a/langchain-core/src/prompts/chat.ts +++ b/langchain-core/src/prompts/chat.ts @@ -12,6 +12,7 @@ import { coerceMessageLikeToMessage, isBaseMessage, MessageContent, + MessageContentComplex, } from "../messages/index.js"; import { type ChatPromptValueInterface, @@ -493,7 +494,14 @@ class _StringImageMessagePromptTemplate< } else if (typeof item.text === "string") { text = item.text ?? ""; } - prompt.push(PromptTemplate.fromTemplate(text, additionalOptions)); + + const options = { + ...additionalOptions, + ...(typeof item !== "string" + ? { additionalContentFields: item } + : {}), + }; + prompt.push(PromptTemplate.fromTemplate(text, options)); } else if (typeof item === "object" && "image_url" in item) { let imgTemplate = item.image_url ?? ""; let imgTemplateObject: ImagePromptTemplate; @@ -526,6 +534,7 @@ class _StringImageMessagePromptTemplate< template: imgTemplate, inputVariables, templateFormat: additionalOptions?.templateFormat, + additionalContentFields: item, }); } else if (typeof imgTemplate === "object") { if ("url" in imgTemplate) { @@ -546,6 +555,7 @@ class _StringImageMessagePromptTemplate< template: imgTemplate, inputVariables, templateFormat: additionalOptions?.templateFormat, + additionalContentFields: item, }); } else { throw new Error("Invalid image template"); @@ -583,17 +593,34 @@ class _StringImageMessagePromptTemplate< const formatted = await prompt.format( inputs as TypedPromptInputValues ); - content.push({ type: "text", text: formatted }); + let additionalContentFields: MessageContentComplex | undefined; + if ("additionalContentFields" in prompt) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + additionalContentFields = prompt.additionalContentFields as any; + } + content.push({ + ...additionalContentFields, + type: "text", + text: formatted, + }); /** @TODO replace this */ // eslint-disable-next-line no-instanceof/no-instanceof } else if (prompt instanceof ImagePromptTemplate) { const formatted = await prompt.format( inputs as TypedPromptInputValues ); - content.push({ type: "image_url", image_url: formatted }); + let additionalContentFields: MessageContentComplex | undefined; + if ("additionalContentFields" in prompt) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + additionalContentFields = prompt.additionalContentFields as any; + } + content.push({ + ...additionalContentFields, + type: "image_url", + image_url: formatted, + }); } } - return this.createMessage(content); } } @@ -769,9 +796,9 @@ function _coerceMessagePromptTemplateLike< // Assuming message.content is an array of complex objects, transform it. templateData = message.content.map((item) => { if ("text" in item) { - return { text: item.text }; + return { ...item, text: item.text }; } else if ("image_url" in item) { - return { image_url: item.image_url }; + return { ...item, image_url: item.image_url }; } else { return item; } diff --git a/langchain-core/src/prompts/image.ts b/langchain-core/src/prompts/image.ts index 397b36bb0867..9f92f5c1630e 100644 --- a/langchain-core/src/prompts/image.ts +++ b/langchain-core/src/prompts/image.ts @@ -1,4 +1,4 @@ -import { MessageContent } from "../messages/index.js"; +import { MessageContent, MessageContentComplex } from "../messages/index.js"; import { ImagePromptValue, ImageContent } from "../prompt_values.js"; import type { InputValues, PartialValues } from "../utils/types/index.js"; import { @@ -40,6 +40,14 @@ export interface ImagePromptTemplateInput< * @defaultValue `true` */ validateTemplate?: boolean; + + /** + * Additional fields which should be included inside + * the message content array if using a complex message + * content. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + additionalContentFields?: MessageContentComplex; } /** @@ -63,11 +71,20 @@ export class ImagePromptTemplate< validateTemplate = true; + /** + * Additional fields which should be included inside + * the message content array if using a complex message + * content. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + additionalContentFields?: MessageContentComplex; + constructor(input: ImagePromptTemplateInput) { super(input); this.template = input.template; this.templateFormat = input.templateFormat ?? this.templateFormat; this.validateTemplate = input.validateTemplate ?? this.validateTemplate; + this.additionalContentFields = input.additionalContentFields; if (this.validateTemplate) { let totalInputVariables: string[] = this.inputVariables; diff --git a/langchain-core/src/prompts/prompt.ts b/langchain-core/src/prompts/prompt.ts index 58d3982b6852..998715895141 100644 --- a/langchain-core/src/prompts/prompt.ts +++ b/langchain-core/src/prompts/prompt.ts @@ -14,7 +14,7 @@ import { } from "./template.js"; import type { SerializedPromptTemplate } from "./serde.js"; import type { InputValues, PartialValues } from "../utils/types/index.js"; -import { MessageContent } from "../messages/index.js"; +import { MessageContent, MessageContentComplex } from "../messages/index.js"; /** * Inputs to create a {@link PromptTemplate} @@ -43,6 +43,14 @@ export interface PromptTemplateInput< * @defaultValue `true` */ validateTemplate?: boolean; + + /** + * Additional fields which should be included inside + * the message content array if using a complex message + * content. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + additionalContentFields?: MessageContentComplex; } type NonAlphanumeric = @@ -120,6 +128,14 @@ export class PromptTemplate< validateTemplate = true; + /** + * Additional fields which should be included inside + * the message content array if using a complex message + * content. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + additionalContentFields?: MessageContentComplex; + constructor(input: PromptTemplateInput) { super(input); // If input is mustache and validateTemplate is not defined, set it to false @@ -251,6 +267,7 @@ export class PromptTemplate< names.add(node.name); } }); + return new PromptTemplate({ // Rely on extracted types // eslint-disable-next-line @typescript-eslint/no-explicit-any diff --git a/langchain-core/src/prompts/tests/chat.test.ts b/langchain-core/src/prompts/tests/chat.test.ts index 3f5125861a73..b3fc38958ddc 100644 --- a/langchain-core/src/prompts/tests/chat.test.ts +++ b/langchain-core/src/prompts/tests/chat.test.ts @@ -622,3 +622,75 @@ test("Multi-modal, multi part chat prompt works with instances of BaseMessage", }); expect(messages).toMatchSnapshot(); }); + +test("Format complex messages and keep additional fields", async () => { + const examplePrompt = ChatPromptTemplate.fromMessages([ + [ + "human", + [ + { + type: "text", + text: "{input}", + cache_control: { type: "ephemeral" }, + }, + ], + ], + [ + "ai", + [ + { + type: "text", + text: "{output}", + cache_control: { type: "ephemeral" }, + }, + ], + ], + ]); + const formatted = await examplePrompt.formatMessages({ + input: "hello", + output: "ciao", + }); + + expect(formatted).toHaveLength(2); + + expect(formatted[0]._getType()).toBe("human"); + expect(formatted[0].content[0]).toHaveProperty("cache_control"); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + expect((formatted[0].content[0] as any).cache_control).toEqual({ + type: "ephemeral", + }); + + expect(formatted[1]._getType()).toBe("ai"); + expect(formatted[1].content[0]).toHaveProperty("cache_control"); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + expect((formatted[1].content[0] as any).cache_control).toEqual({ + type: "ephemeral", + }); +}); + +test("Format image content messages and keep additional fields", async () => { + const examplePrompt = ChatPromptTemplate.fromMessages([ + [ + "human", + [ + { + type: "image_url", + image_url: "{image_url}", + cache_control: { type: "ephemeral" }, + }, + ], + ], + ]); + const formatted = await examplePrompt.formatMessages({ + image_url: "image_url", + }); + + expect(formatted).toHaveLength(1); + + expect(formatted[0]._getType()).toBe("human"); + expect(formatted[0].content[0]).toHaveProperty("cache_control"); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + expect((formatted[0].content[0] as any).cache_control).toEqual({ + type: "ephemeral", + }); +});