Skip to content

Commit

Permalink
fix(js/genkit): correctly handle function prompt output schema (#1395)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Nov 26, 2024
1 parent 7656997 commit 2b85cb8
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
8 changes: 7 additions & 1 deletion js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ import { BaseEvalDataPointSchema } from './evaluator.js';
import { logger } from './logging.js';
import { GenkitPlugin, genkitPlugin } from './plugin.js';
import { Registry } from './registry.js';
import { toJsonSchema } from './schema.js';
import { toToolDefinition } from './tool.js';

/**
Expand Down Expand Up @@ -465,7 +466,12 @@ export class Genkit {
).map(toToolDefinition);
}
if (!response.output && options.output) {
response.output = options.output;
response.output = {
schema: toJsonSchema({
schema: options.output.schema,
jsonSchema: options.output.jsonSchema,
}),
};
}
return response;
}
Expand Down
45 changes: 45 additions & 0 deletions js/genkit/tests/prompts_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -715,13 +715,15 @@ describe('definePrompt', () => {

describe('prompt', () => {
let ai: Genkit;
let pm: ProgrammableModel;

beforeEach(() => {
ai = genkit({
model: 'echoModel',
promptDir: './tests/prompts',
});
defineEchoModel(ai);
pm = defineProgrammableModel(ai);
});

it('loads from from the folder', async () => {
Expand Down Expand Up @@ -771,6 +773,49 @@ describe('prompt', () => {

assert.strictEqual(text, 'Echo: hi banana; config: {"temperature":11}');
});

it('passes in output options to the model', async () => {
const hi = ai.definePrompt(
{
name: 'hi',
model: 'programmableModel',
input: {
schema: z.object({
name: z.string(),
}),
},
output: {
schema: z.object({
message: z.string(),
}),
format: 'json',
},
},
async (input) => {
return {
messages: [{ role: 'user', content: [{ text: `hi ${input.name}` }] }],
config: {
temperature: 11,
},
};
}
);

pm.handleResponse = async (req, sc) => {
return {
message: {
role: 'model',
content: [{ text: '```json\n{"message": "hello"}\n```' }],
},
};
};

const { output } = await hi({
name: 'Pavel',
});

assert.deepStrictEqual(output, { message: 'hello' });
});
});

describe('asTool', () => {
Expand Down

0 comments on commit 2b85cb8

Please sign in to comment.