Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(js): propagate context to sub actions, expose context in prompts #1663

Merged
merged 9 commits into from
Jan 27, 2025
Merged
6 changes: 3 additions & 3 deletions docs/auth.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ When running with the Genkit Development UI, you can pass the Auth object by
entering JSON in the "Auth JSON" tab: `{"uid": "abc-def"}`.

You can also retrieve the auth context for the flow at any time within the flow
by calling `getFlowAuth()`, including in functions invoked by the flow:
by calling `ai.currentContext()`, including in functions invoked by the flow:

```ts
import { genkit, z } from 'genkit';

const ai = genkit({ ... });;

async function readDatabase(uid: string) {
const auth = ai.getAuthContext();
const auth = ai.currentContext()?.auth;
if (auth?.admin) {
// Do something special if the user is an admin
} else {
Expand Down Expand Up @@ -153,7 +153,7 @@ export const selfSummaryFlow = onFlow(

When using the Firebase Auth plugin, `user` will be returned as a
[DecodedIdToken](https://firebase.google.com/docs/reference/admin/node/firebase-admin.auth.decodedidtoken).
You can always retrieve this object at any time via `getFlowAuth()` as noted
You can always retrieve this object at any time via `ai.currentContext()` as noted
above. When running this flow during development, you would pass the user object
in the same way:

Expand Down
22 changes: 16 additions & 6 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

import {
Action,
ActionContext,
GenkitError,
StreamingCallback,
runWithContext,
runWithStreamingCallback,
sentinelNoopStreamingCallback,
z,
Expand Down Expand Up @@ -113,8 +115,8 @@ export interface GenerateOptions<
* const interrupt = response.interrupts[0];
*
* const resumedResponse = await ai.generate({
* messages: response.messages,
* resume: myInterrupt.reply(interrupt, {note: "this is the reply data"}),
* messages: response.messages,
* resume: myInterrupt.reply(interrupt, {note: "this is the reply data"}),
* });
* ```
*/
Expand All @@ -133,6 +135,8 @@ export interface GenerateOptions<
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
/** Middleware to be used with this model call. */
use?: ModelMiddleware[];
/** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */
context?: ActionContext;
}

function applyResumeOption(
Expand Down Expand Up @@ -376,10 +380,16 @@ export async function generate<
registry,
stripNoop(resolvedOptions.onChunk ?? resolvedOptions.streamingCallback),
async () => {
const response = await generateHelper(registry, {
rawRequest: params,
middleware: resolvedOptions.use,
});
const generateFn = () =>
generateHelper(registry, {
rawRequest: params,
middleware: resolvedOptions.use,
});
const response = await runWithContext(
registry,
resolvedOptions.context,
generateFn
);
const request = await toGenerateRequest(registry, {
...resolvedOptions,
tools,
Expand Down
66 changes: 54 additions & 12 deletions js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import {
Action,
ActionAsyncParams,
ActionContext,
defineActionAsync,
GenkitError,
getContext,
JSONSchema7,
stripUndefinedProps,
z,
Expand Down Expand Up @@ -117,6 +119,7 @@ export interface PromptConfig<
tools?: ToolArgument[];
toolChoice?: ToolChoice;
use?: ModelMiddleware[];
context?: ActionContext;
}

/**
Expand Down Expand Up @@ -179,6 +182,7 @@ export type PartsResolver<I, S = any> = (
input: I,
options: {
state?: S;
context: ActionContext;
}
) => Part[] | Promise<string | Part | Part[]>;

Expand All @@ -187,12 +191,14 @@ export type MessagesResolver<I, S = any> = (
options: {
history?: MessageData[];
state?: S;
context: ActionContext;
}
) => MessageData[] | Promise<MessageData[]>;

export type DocsResolver<I, S = any> = (
input: I,
options: {
context: ActionContext;
state?: S;
}
) => DocumentData[] | Promise<DocumentData[]>;
Expand Down Expand Up @@ -250,7 +256,8 @@ function definePromptAsync<
input,
messages,
resolvedOptions,
promptCache
promptCache,
renderOptions
);
await renderMessages(
registry,
Expand All @@ -267,13 +274,15 @@ function definePromptAsync<
input,
messages,
resolvedOptions,
promptCache
promptCache,
renderOptions
);

let docs: DocumentData[] | undefined;
if (typeof resolvedOptions.docs === 'function') {
docs = await resolvedOptions.docs(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
});
} else {
docs = resolvedOptions.docs;
Expand All @@ -287,6 +296,7 @@ function definePromptAsync<
tools: resolvedOptions.tools,
returnToolRequests: resolvedOptions.returnToolRequests,
toolChoice: resolvedOptions.toolChoice,
context: resolvedOptions.context,
output: resolvedOptions.output,
use: resolvedOptions.use,
...stripUndefinedProps(renderOptions),
Expand Down Expand Up @@ -442,13 +452,17 @@ async function renderSystemPrompt<
input: z.infer<I>,
messages: MessageData[],
options: PromptConfig<I, O, CustomOptions>,
promptCache: PromptCache
promptCache: PromptCache,
renderOptions: PromptGenerateOptions<O, CustomOptions> | undefined
) {
if (typeof options.system === 'function') {
messages.push({
role: 'system',
content: normalizeParts(
await options.system(input, { state: session?.state })
await options.system(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
})
),
});
} else if (typeof options.system === 'string') {
Expand All @@ -458,7 +472,14 @@ async function renderSystemPrompt<
}
messages.push({
role: 'system',
content: await renderDotpromptToParts(promptCache.system, input, session),
content: await renderDotpromptToParts(
registry,
promptCache.system,
input,
session,
options,
renderOptions
),
});
} else if (options.system) {
messages.push({
Expand Down Expand Up @@ -486,6 +507,7 @@ async function renderMessages<
messages.push(
...(await options.messages(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
history: renderOptions?.messages,
}))
);
Expand All @@ -498,7 +520,10 @@ async function renderMessages<
}
const rendered = await promptCache.messages({
input,
context: { state: session?.state },
context: {
...(renderOptions?.context || getContext(registry)),
state: session?.state,
},
messages: renderOptions?.messages?.map((m) =>
Message.parseData(m)
) as DpMessage[],
Expand Down Expand Up @@ -528,13 +553,17 @@ async function renderUserPrompt<
input: z.infer<I>,
messages: MessageData[],
options: PromptConfig<I, O, CustomOptions>,
promptCache: PromptCache
promptCache: PromptCache,
renderOptions: PromptGenerateOptions<O, CustomOptions> | undefined
) {
if (typeof options.prompt === 'function') {
messages.push({
role: 'user',
content: normalizeParts(
await options.prompt(input, { state: session?.state })
await options.prompt(input, {
state: session?.state,
context: renderOptions?.context || getContext(registry) || {},
})
),
});
} else if (typeof options.prompt === 'string') {
Expand All @@ -545,9 +574,12 @@ async function renderUserPrompt<
messages.push({
role: 'user',
content: await renderDotpromptToParts(
registry,
promptCache.userPrompt,
input,
session
session,
options,
renderOptions
),
});
} else if (options.prompt) {
Expand Down Expand Up @@ -585,14 +617,24 @@ function normalizeParts(parts: string | Part | Part[]): Part[] {
return [parts as Part];
}

async function renderDotpromptToParts(
async function renderDotpromptToParts<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
promptFn: PromptFunction,
input: any,
session?: Session
session: Session | undefined,
options: PromptConfig<I, O, CustomOptions>,
renderOptions: PromptGenerateOptions<O, CustomOptions> | undefined
): Promise<Part[]> {
const renderred = await promptFn({
input,
context: { state: session?.state },
context: {
...(renderOptions?.context || getContext(registry)),
state: session?.state,
},
});
if (renderred.messages.length !== 1) {
throw new Error('parts tempate must produce only one message');
Expand Down
8 changes: 7 additions & 1 deletion js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import {
Action,
ActionContext,
defineAction,
JSONSchema7,
stripUndefinedProps,
Expand Down Expand Up @@ -188,6 +189,8 @@ export interface ToolFnOptions {
* getting interrupted (immediately) and tool request returned to the upstream caller.
*/
interrupt: (metadata?: Record<string, any>) => never;

context: ActionContext;
}

export type ToolFn<I extends z.ZodTypeAny, O extends z.ZodTypeAny> = (
Expand All @@ -212,9 +215,12 @@ export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
actionType: 'tool',
metadata: { ...(config.metadata || {}), type: 'tool' },
},
(i) =>
(i, { context }) =>
fn(i, {
interrupt: interruptTool,
context: {
...context,
},
})
);
(a as ToolAction<I, O>).reply = (interrupt, replyData, options) => {
Expand Down
Loading
Loading