From 2be2a43be981935b0c267f750a78422032fc2d69 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Wed, 27 Nov 2024 10:06:43 -0800 Subject: [PATCH 1/4] fix(js/ai): Fixes use of namespaced tools in model calls. --- js/ai/src/generate/action.ts | 20 ++++++++-- js/ai/src/tool.ts | 10 ++++- js/ai/tests/generate/generate_test.ts | 50 ++++++++++++++++++++++++- js/genkit/src/genkit.ts | 2 +- js/testapps/flow-simple-ai/src/index.ts | 50 +++++++++++++++++++++---- 5 files changed, 116 insertions(+), 16 deletions(-) diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 8a4d017531..6ef7d43a23 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -15,6 +15,7 @@ */ import { + GenkitError, getStreamingCallback, runWithStreamingCallback, z, @@ -116,6 +117,19 @@ async function generate( const tools = await resolveTools(registry, rawRequest.tools); const resolvedFormat = await resolveFormat(registry, rawRequest.output); + // Create a lookup of tool names with namespaces stripped to original names + const toolMap = tools.reduce>((acc, tool) => { + const name = tool.__action.name; + const shortName = name.substring(name.lastIndexOf('/') + 1); + if (acc[shortName]) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Cannot provide two tools with the same name: '${name}' and '${acc[shortName]}'`, + }); + } + acc[shortName] = tool; + return acc; + }, {}); const request = await actionToGenerateRequest( rawRequest, @@ -184,9 +198,7 @@ async function generate( 'Tool request expected but not provided in tool request part' ); } - const tool = tools?.find( - (tool) => tool.__action.name === part.toolRequest?.name - ); + const tool = toolMap[part.toolRequest?.name]; if (!tool) { throw Error(`Tool ${part.toolRequest?.name} not found`); } @@ -238,7 +250,7 @@ async function actionToGenerateRequest( messages: options.messages, config: options.config, docs: options.docs, - tools: resolvedTools?.map((tool) => toToolDefinition(tool)) || [], + tools: resolvedTools?.map((tool) => toToolDefinition(tool, true)) || [], output: { ...(resolvedFormat?.config || {}), schema: toJsonSchema({ diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index bfeb37efd8..4b5888673d 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -134,10 +134,16 @@ export async function lookupToolByName( * Converts a tool action to a definition of the tool to be passed to a model. */ export function toToolDefinition( - tool: Action + tool: Action, + stripNamespace = false ): ToolDefinition { + let name = tool.__action.name; + if (stripNamespace) { + name = name.substring(name.lastIndexOf('/') + 1); + } + return { - name: tool.__action.name, + name, description: tool.__action.description || '', outputSchema: toJsonSchema({ schema: tool.__action.outputSchema ?? z.void(), diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 9c98ae6cc2..82d56f424d 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { z } from '@genkit-ai/core'; +import { PluginProvider, z } from '@genkit-ai/core'; import { Registry } from '@genkit-ai/core/registry'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; @@ -43,6 +43,23 @@ describe('toGenerateRequest', () => { } ); + const namespacedPlugin: PluginProvider = { + name: 'namespaced', + initializer: async () => {}, + }; + registry.registerPluginProvider('namespaced', namespacedPlugin); + + const namespacedTool = defineTool( + registry, + { + name: 'namespaced/add', + description: 'add two numbers together', + inputSchema: z.object({ a: z.number(), b: z.number() }), + outputSchema: z.number(), + }, + async ({ a, b }) => a + b + ); + const testCases = [ { should: 'translate a string prompt correctly', @@ -95,6 +112,37 @@ describe('toGenerateRequest', () => { output: {}, }, }, + { + should: 'strip namespaces from tools when passing to the model', + prompt: { + model: 'vertexai/gemini-1.0-pro', + tools: ['namespaced/add'], + prompt: 'Add 10 and 5.', + }, + expectedOutput: { + messages: [{ role: 'user', content: [{ text: 'Add 10 and 5.' }] }], + config: undefined, + docs: undefined, + tools: [ + { + description: 'add two numbers together', + inputSchema: { + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: true, + properties: { a: { type: 'number' }, b: { type: 'number' } }, + required: ['a', 'b'], + type: 'object', + }, + name: 'namespaced/add', + outputSchema: { + $schema: 'http://json-schema.org/draft-07/schema#', + type: 'number', + }, + }, + ], + output: {}, + }, + }, { should: 'translate a string prompt correctly with tools referenced by their action', diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 9ca8eb2cf3..fa05208c4f 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -463,7 +463,7 @@ export class Genkit { if (!response.tools && options.tools) { response.tools = ( await resolveTools(this.registry, options.tools) - ).map(toToolDefinition); + ).map((t) => toToolDefinition(t)); } if (!response.output && options.output) { response.output = { diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 57ad3b4045..ecc27cfe71 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -28,6 +28,7 @@ import { initializeApp } from 'firebase-admin/app'; import { getFirestore } from 'firebase-admin/firestore'; import { MessageSchema, genkit, run, z } from 'genkit'; import { logger } from 'genkit/logging'; +import { PluginProvider } from 'genkit/plugin'; import { Allow, parse } from 'partial-json'; logger.setLogLevel('debug'); @@ -53,6 +54,32 @@ const ai = genkit({ plugins: [googleAI(), vertexAI()], }); +const math: PluginProvider = { + name: 'math', + initializer: async () => { + ai.defineTool( + { + name: 'math/add', + description: 'add two numbers', + inputSchema: z.object({ a: z.number(), b: z.number() }), + outputSchema: z.number(), + }, + async ({ a, b }) => a + b + ); + + ai.defineTool( + { + name: 'math/subtract', + description: 'subtract two numbers', + inputSchema: z.object({ a: z.number(), b: z.number() }), + outputSchema: z.number(), + }, + async ({ a, b }) => a - b + ); + }, +}; +ai.registry.registerPluginProvider('math', math); + const app = initializeApp(); export const jokeFlow = ai.defineFlow( @@ -538,11 +565,18 @@ export const arrayStreamTester = ai.defineStreamingFlow( } ); -// async function main() { -// const { stream, output } = arrayStreamTester(); -// for await (const chunk of stream) { -// console.log(chunk); -// } -// console.log(await output); -// } -// main(); +ai.defineFlow( + { + name: 'math', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async (query) => { + const { text } = await ai.generate({ + model: gemini15Flash, + prompt: query, + tools: ['math/add', 'math/subtract'], + }); + return text; + } +); From 133444b09115a61cc167750eaf20000fe3646216 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Wed, 27 Nov 2024 10:32:48 -0800 Subject: [PATCH 2/4] fix test --- js/ai/src/generate.ts | 2 +- js/ai/tests/generate/generate_test.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index cb49c6d08e..2fae430d65 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -129,7 +129,7 @@ export async function toGenerateRequest( messages: injectInstructions(messages, instructions), config: options.config, docs: options.docs, - tools: tools?.map((tool) => toToolDefinition(tool)) || [], + tools: tools?.map((tool) => toToolDefinition(tool, true)) || [], output: { ...(resolvedFormat?.config || {}), schema: resolvedSchema, diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 82d56f424d..75ff72661e 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -133,7 +133,7 @@ describe('toGenerateRequest', () => { required: ['a', 'b'], type: 'object', }, - name: 'namespaced/add', + name: 'add', outputSchema: { $schema: 'http://json-schema.org/draft-07/schema#', type: 'number', From 6d3b0a1dcb214ccbc72fda92e171e5ae3a8a5b7a Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Wed, 27 Nov 2024 10:50:34 -0800 Subject: [PATCH 3/4] Adds metadata to ToolDefinition and originalName annotation. --- genkit-tools/common/src/types/model.ts | 4 ++++ js/ai/src/generate.ts | 2 +- js/ai/src/generate/action.ts | 2 +- js/ai/src/model.ts | 4 ++++ js/ai/src/tool.ts | 23 ++++++++++++++++------- js/ai/tests/generate/generate_test.ts | 3 ++- 6 files changed, 28 insertions(+), 10 deletions(-) diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index 547652c2f8..5ce711c178 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -133,6 +133,10 @@ export const ToolDefinitionSchema = z.object({ .record(z.any()) .describe('Valid JSON Schema describing the output of the tool.') .optional(), + metadata: z + .record(z.any()) + .describe('additional metadata for this tool definition') + .optional(), }); export type ToolDefinition = z.infer; diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 2fae430d65..e6d1bb2afb 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -129,7 +129,7 @@ export async function toGenerateRequest( messages: injectInstructions(messages, instructions), config: options.config, docs: options.docs, - tools: tools?.map((tool) => toToolDefinition(tool, true)) || [], + tools: tools?.map(toToolDefinition) || [], output: { ...(resolvedFormat?.config || {}), schema: resolvedSchema, diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 6ef7d43a23..48d6b6e9bb 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -250,7 +250,7 @@ async function actionToGenerateRequest( messages: options.messages, config: options.config, docs: options.docs, - tools: resolvedTools?.map((tool) => toToolDefinition(tool, true)) || [], + tools: resolvedTools?.map(toToolDefinition) || [], output: { ...(resolvedFormat?.config || {}), schema: toJsonSchema({ diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 18609fd8d1..c0eb0b9c8b 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -166,6 +166,10 @@ export const ToolDefinitionSchema = z.object({ .record(z.any()) .describe('Valid JSON Schema describing the output of the tool.') .nullish(), + metadata: z + .record(z.any()) + .describe('additional metadata for this tool definition') + .optional(), }); export type ToolDefinition = z.infer; diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index 4b5888673d..a56f9ad593 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -109,7 +109,10 @@ export async function resolveTools< } else if (typeof (ref as ExecutablePrompt).asTool === 'function') { return await (ref as ExecutablePrompt).asTool(); } else if (ref.name) { - return await lookupToolByName(registry, ref.name); + return await lookupToolByName( + registry, + (ref as ToolDefinition).metadata?.originalName || ref.name + ); } throw new Error('Tools must be strings, tool definitions, or actions.'); }) @@ -134,15 +137,15 @@ export async function lookupToolByName( * Converts a tool action to a definition of the tool to be passed to a model. */ export function toToolDefinition( - tool: Action, - stripNamespace = false + tool: Action ): ToolDefinition { - let name = tool.__action.name; - if (stripNamespace) { - name = name.substring(name.lastIndexOf('/') + 1); + const originalName = tool.__action.name; + let name = originalName; + if (originalName.includes('/')) { + name = originalName.substring(originalName.lastIndexOf('/') + 1); } - return { + const out: ToolDefinition = { name, description: tool.__action.description || '', outputSchema: toJsonSchema({ @@ -154,6 +157,12 @@ export function toToolDefinition( jsonSchema: tool.__action.inputJsonSchema, })!, }; + + if (originalName !== name) { + out.metadata = { originalName }; + } + + return out; } /** diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 75ff72661e..72c3a565ac 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -49,7 +49,7 @@ describe('toGenerateRequest', () => { }; registry.registerPluginProvider('namespaced', namespacedPlugin); - const namespacedTool = defineTool( + defineTool( registry, { name: 'namespaced/add', @@ -138,6 +138,7 @@ describe('toGenerateRequest', () => { $schema: 'http://json-schema.org/draft-07/schema#', type: 'number', }, + metadata: { originalName: 'namespaced/add' }, }, ], output: {}, From d2106789792c96ade46295da459df511ec080184 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Wed, 27 Nov 2024 11:04:25 -0800 Subject: [PATCH 4/4] Export schemas. --- genkit-tools/genkit-schema.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 437fe51b00..e1115db3f2 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -839,6 +839,11 @@ "type": "object", "additionalProperties": {}, "description": "Valid JSON Schema describing the output of the tool." + }, + "metadata": { + "type": "object", + "additionalProperties": {}, + "description": "additional metadata for this tool definition" } }, "required": [