diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts
index 5031cfb64..53523106a 100755
--- a/js/ai/src/generate.ts
+++ b/js/ai/src/generate.ts
@@ -16,6 +16,7 @@
 
 import {
   Action,
+  config as genkitConfig,
   GenkitError,
   runWithStreamingCallback,
   StreamingCallback,
@@ -431,7 +432,7 @@ export interface GenerateOptions<
   CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
 > {
   /** A model name (e.g. `vertexai/gemini-1.0-pro`) or reference. */
-  model: ModelArgument;
+  model?: ModelArgument;
   /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */
   prompt: string | Part | Part[];
   /** Retrieved documents to be used as context for this generation. */
@@ -479,9 +480,25 @@ const isValidCandidate = (
   });
 };
 
-async function resolveModel(
-  model: ModelAction | ModelReference | string
-): Promise<ModelAction> {
+async function resolveModel(options: GenerateOptions): Promise<ModelAction> {
+  let model = options.model;
+  if (!model) {
+    if (genkitConfig?.options?.defaultModel) {
+      model =
+        typeof genkitConfig.options.defaultModel.name === 'string'
+          ? genkitConfig.options.defaultModel.name
+          : genkitConfig.options.defaultModel.name.name;
+      if (
+        (!options.config || Object.keys(options.config).length === 0) &&
+        genkitConfig.options.defaultModel.config
+      ) {
+        // use configured global config
+        options.config = genkitConfig.options.defaultModel.config;
+      }
+    } else {
+      throw new Error('Unable to resolve model.');
+    }
+  }
   if (typeof model === 'string') {
     return (await lookupAction(`/model/${model}`)) as ModelAction;
   } else if (model.hasOwnProperty('info')) {
@@ -537,7 +554,7 @@ export async function generate<
 ): Promise<GenerateResponse<z.infer<O>>> {
   const resolvedOptions: GenerateOptions<O, CustomOptions> =
     await Promise.resolve(options);
-  const model = await resolveModel(resolvedOptions.model);
+  const model = await resolveModel(resolvedOptions);
   if (!model) {
     throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`);
   }
diff --git a/js/core/src/config.ts b/js/core/src/config.ts
index 1381af8bf..65391d374 100644
--- a/js/core/src/config.ts
+++ b/js/core/src/config.ts
@@ -42,6 +42,10 @@
   logLevel?: 'error' | 'warn' | 'info' | 'debug';
   promptDir?: string;
   telemetry?: TelemetryOptions;
+  defaultModel?: {
+    name: string | { name: string };
+    config?: Record<string, any>;
+  };
 }
 
 class Config {
diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts
index 12ff7600f..b0fc9c4a0 100644
--- a/js/plugins/dotprompt/src/prompt.ts
+++ b/js/plugins/dotprompt/src/prompt.ts
@@ -161,14 +161,6 @@ export class Dotprompt implements PromptMetadata {
   private _generateOptions(
     options: PromptGenerateOptions
   ): GenerateOptions {
-    if (!options.model && !this.model) {
-      throw new GenkitError({
-        source: 'Dotprompt',
-        message: 'Must supply `model` in prompt metadata or generate options.',
-        status: 'INVALID_ARGUMENT',
-      });
-    }
-
     const messages = this.renderMessages(options.input);
     return {
       model: options.model || this.model!,
diff --git a/js/samples/rag/src/index.ts b/js/samples/rag/src/index.ts
index ffc5b9882..890756b22 100644
--- a/js/samples/rag/src/index.ts
+++ b/js/samples/rag/src/index.ts
@@ -77,6 +77,12 @@
       },
     ]),
   ],
+  defaultModel: {
+    name: geminiPro,
+    config: {
+      temperature: 0.6,
+    },
+  },
   flowStateStore: 'firebase',
   traceStore: 'firebase',
   enableTracingAndMetrics: true,
diff --git a/js/samples/rag/src/prompt.ts b/js/samples/rag/src/prompt.ts
index fec34b3ab..c77c28130 100644
--- a/js/samples/rag/src/prompt.ts
+++ b/js/samples/rag/src/prompt.ts
@@ -15,7 +15,6 @@
  */
 
 import { defineDotprompt } from '@genkit-ai/dotprompt';
-import { geminiPro } from '@genkit-ai/vertexai';
 import * as z from 'zod';
 
 // Define a prompt that includes the retrieved context documents
@@ -23,7 +22,6 @@ import * as z from 'zod';
 export const augmentedPrompt = defineDotprompt(
   {
     name: 'augmentedPrompt',
-    model: geminiPro,
     input: z.object({
       context: z.array(z.string()),
       question: z.string(),