Skip to content

Commit

Permalink
core[minor]: Allow for additional fields to be passed with prompt tem…
Browse files Browse the repository at this point in the history
…plates (#6559)

* core[minor]: Allow for additional fields to be passed with prompt templates

* add support for image prompt templates
  • Loading branch information
bracesproul authored Aug 16, 2024
1 parent a29b2d6 commit 3fc7125
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 8 deletions.
39 changes: 33 additions & 6 deletions langchain-core/src/prompts/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
coerceMessageLikeToMessage,
isBaseMessage,
MessageContent,
MessageContentComplex,
} from "../messages/index.js";
import {
type ChatPromptValueInterface,
Expand Down Expand Up @@ -493,7 +494,14 @@ class _StringImageMessagePromptTemplate<
} else if (typeof item.text === "string") {
text = item.text ?? "";
}
prompt.push(PromptTemplate.fromTemplate(text, additionalOptions));

const options = {
...additionalOptions,
...(typeof item !== "string"
? { additionalContentFields: item }
: {}),
};
prompt.push(PromptTemplate.fromTemplate(text, options));
} else if (typeof item === "object" && "image_url" in item) {
let imgTemplate = item.image_url ?? "";
let imgTemplateObject: ImagePromptTemplate<InputValues>;
Expand Down Expand Up @@ -526,6 +534,7 @@ class _StringImageMessagePromptTemplate<
template: imgTemplate,
inputVariables,
templateFormat: additionalOptions?.templateFormat,
additionalContentFields: item,
});
} else if (typeof imgTemplate === "object") {
if ("url" in imgTemplate) {
Expand All @@ -546,6 +555,7 @@ class _StringImageMessagePromptTemplate<
template: imgTemplate,
inputVariables,
templateFormat: additionalOptions?.templateFormat,
additionalContentFields: item,
});
} else {
throw new Error("Invalid image template");
Expand Down Expand Up @@ -583,17 +593,34 @@ class _StringImageMessagePromptTemplate<
const formatted = await prompt.format(
inputs as TypedPromptInputValues<RunInput>
);
content.push({ type: "text", text: formatted });
let additionalContentFields: MessageContentComplex | undefined;
if ("additionalContentFields" in prompt) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
additionalContentFields = prompt.additionalContentFields as any;
}
content.push({
...additionalContentFields,
type: "text",
text: formatted,
});
/** @TODO replace this */
// eslint-disable-next-line no-instanceof/no-instanceof
} else if (prompt instanceof ImagePromptTemplate) {
const formatted = await prompt.format(
inputs as TypedPromptInputValues<RunInput>
);
content.push({ type: "image_url", image_url: formatted });
let additionalContentFields: MessageContentComplex | undefined;
if ("additionalContentFields" in prompt) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
additionalContentFields = prompt.additionalContentFields as any;
}
content.push({
...additionalContentFields,
type: "image_url",
image_url: formatted,
});
}
}

return this.createMessage(content);
}
}
Expand Down Expand Up @@ -769,9 +796,9 @@ function _coerceMessagePromptTemplateLike<
// Assuming message.content is an array of complex objects, transform it.
templateData = message.content.map((item) => {
if ("text" in item) {
return { text: item.text };
return { ...item, text: item.text };
} else if ("image_url" in item) {
return { image_url: item.image_url };
return { ...item, image_url: item.image_url };
} else {
return item;
}
Expand Down
19 changes: 18 additions & 1 deletion langchain-core/src/prompts/image.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { MessageContent } from "../messages/index.js";
import { MessageContent, MessageContentComplex } from "../messages/index.js";
import { ImagePromptValue, ImageContent } from "../prompt_values.js";
import type { InputValues, PartialValues } from "../utils/types/index.js";
import {
Expand Down Expand Up @@ -40,6 +40,14 @@ export interface ImagePromptTemplateInput<
* @defaultValue `true`
*/
validateTemplate?: boolean;

/**
* Additional fields which should be included inside
* the message content array if using a complex message
* content.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
additionalContentFields?: MessageContentComplex;
}

/**
Expand All @@ -63,11 +71,20 @@ export class ImagePromptTemplate<

validateTemplate = true;

/**
* Additional fields which should be included inside
* the message content array if using a complex message
* content.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
additionalContentFields?: MessageContentComplex;

constructor(input: ImagePromptTemplateInput<RunInput, PartialVariableName>) {
super(input);
this.template = input.template;
this.templateFormat = input.templateFormat ?? this.templateFormat;
this.validateTemplate = input.validateTemplate ?? this.validateTemplate;
this.additionalContentFields = input.additionalContentFields;

if (this.validateTemplate) {
let totalInputVariables: string[] = this.inputVariables;
Expand Down
19 changes: 18 additions & 1 deletion langchain-core/src/prompts/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
} from "./template.js";
import type { SerializedPromptTemplate } from "./serde.js";
import type { InputValues, PartialValues } from "../utils/types/index.js";
import { MessageContent } from "../messages/index.js";
import { MessageContent, MessageContentComplex } from "../messages/index.js";

/**
* Inputs to create a {@link PromptTemplate}
Expand Down Expand Up @@ -43,6 +43,14 @@ export interface PromptTemplateInput<
* @defaultValue `true`
*/
validateTemplate?: boolean;

/**
* Additional fields which should be included inside
* the message content array if using a complex message
* content.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
additionalContentFields?: MessageContentComplex;
}

type NonAlphanumeric =
Expand Down Expand Up @@ -120,6 +128,14 @@ export class PromptTemplate<

validateTemplate = true;

/**
* Additional fields which should be included inside
* the message content array if using a complex message
* content.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
additionalContentFields?: MessageContentComplex;

constructor(input: PromptTemplateInput<RunInput, PartialVariableName>) {
super(input);
// If input is mustache and validateTemplate is not defined, set it to false
Expand Down Expand Up @@ -251,6 +267,7 @@ export class PromptTemplate<
names.add(node.name);
}
});

return new PromptTemplate({
// Rely on extracted types
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down
72 changes: 72 additions & 0 deletions langchain-core/src/prompts/tests/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,75 @@ test("Multi-modal, multi part chat prompt works with instances of BaseMessage",
});
expect(messages).toMatchSnapshot();
});

test("Format complex messages and keep additional fields", async () => {
const examplePrompt = ChatPromptTemplate.fromMessages([
[
"human",
[
{
type: "text",
text: "{input}",
cache_control: { type: "ephemeral" },
},
],
],
[
"ai",
[
{
type: "text",
text: "{output}",
cache_control: { type: "ephemeral" },
},
],
],
]);
const formatted = await examplePrompt.formatMessages({
input: "hello",
output: "ciao",
});

expect(formatted).toHaveLength(2);

expect(formatted[0]._getType()).toBe("human");
expect(formatted[0].content[0]).toHaveProperty("cache_control");
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((formatted[0].content[0] as any).cache_control).toEqual({
type: "ephemeral",
});

expect(formatted[1]._getType()).toBe("ai");
expect(formatted[1].content[0]).toHaveProperty("cache_control");
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((formatted[1].content[0] as any).cache_control).toEqual({
type: "ephemeral",
});
});

test("Format image content messages and keep additional fields", async () => {
const examplePrompt = ChatPromptTemplate.fromMessages([
[
"human",
[
{
type: "image_url",
image_url: "{image_url}",
cache_control: { type: "ephemeral" },
},
],
],
]);
const formatted = await examplePrompt.formatMessages({
image_url: "image_url",
});

expect(formatted).toHaveLength(1);

expect(formatted[0]._getType()).toBe("human");
expect(formatted[0].content[0]).toHaveProperty("cache_control");
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((formatted[0].content[0] as any).cache_control).toEqual({
type: "ephemeral",
});
});

0 comments on commit 3fc7125

Please sign in to comment.