Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(js/ai): refactored constrained generation into middleware, simplified json format #1612

Merged
merged 10 commits into from
Jan 27, 2025
6 changes: 5 additions & 1 deletion genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
@@ -115,9 +115,13 @@ export const ModelInfoSchema = z.object({
/** Model can accept messages with role "system". */
systemRole: z.boolean().optional(),
/** Model can output this type of data. */
output: z.array(OutputFormatSchema).optional(),
output: z.array(z.string()).optional(),
/** Model supports output in these content types. */
contentType: z.array(z.string()).optional(),
/** Model can natively support document-based context grounding. */
context: z.boolean().optional(),
/** Model can natively support constrained generation. */
constrained: z.boolean().optional(),
})
.optional(),
});
11 changes: 10 additions & 1 deletion genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
@@ -679,11 +679,20 @@
"output": {
"type": "array",
"items": {
"$ref": "#/$defs/GenerateRequest/properties/output/properties/format"
"type": "string"
}
},
"contentType": {
"type": "array",
"items": {
"type": "string"
}
},
"context": {
"type": "boolean"
},
"constrained": {
"type": "boolean"
}
},
"additionalProperties": false
15 changes: 1 addition & 14 deletions js/ai/src/formats/json.ts
Original file line number Diff line number Diff line change
@@ -24,18 +24,7 @@ export const jsonFormatter: Formatter<unknown, unknown> = {
contentType: 'application/json',
constrained: true,
},
handler: (schema) => {
let instructions: string | undefined;

if (schema) {
instructions = `Output should be in JSON format and conform to the following schema:

\`\`\`
${JSON.stringify(schema)}
\`\`\`
`;
}

handler: () => {
return {
parseChunk: (chunk) => {
return extractJson(chunk.accumulatedText);
@@ -44,8 +33,6 @@ ${JSON.stringify(schema)}
parseMessage: (message) => {
return extractJson(message.text);
},

instructions,
};
},
};
11 changes: 1 addition & 10 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
@@ -346,21 +346,12 @@ export async function generate<
jsonSchema: resolvedOptions.output?.jsonSchema,
});

// If a schema is set but the format is not explicitly set, default to `json` format.
if (resolvedOptions.output?.schema && !resolvedOptions.output?.format) {
resolvedOptions.output.format = 'json';
}
const resolvedFormat = await resolveFormat(registry, resolvedOptions.output);
const instructions = resolveInstructions(
resolvedFormat,
resolvedSchema,
resolvedOptions?.output?.instructions
);

const params: z.infer<typeof GenerateUtilParamSchema> = {
model: resolvedModel.modelAction.__action.name,
docs: resolvedOptions.docs,
messages: injectInstructions(messages, instructions),
messages: messages,
tools,
toolChoice: resolvedOptions.toolChoice,
config: {
35 changes: 34 additions & 1 deletion js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
@@ -26,7 +26,11 @@ import { toJsonSchema } from '@genkit-ai/core/schema';
import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing';
import * as clc from 'colorette';
import { DocumentDataSchema } from '../document.js';
import { resolveFormat } from '../formats/index.js';
import {
injectInstructions,
resolveFormat,
resolveInstructions,
} from '../formats/index.js';
import { Formatter } from '../formats/types.js';
import {
GenerateResponse,
@@ -148,10 +152,39 @@ async function generate(

const tools = await resolveTools(registry, options.rawRequest.tools);

const resolvedSchema = toJsonSchema({
jsonSchema: options.rawRequest.output?.jsonSchema,
});

// If a schema is set but the format is not explicitly set, default to `json` format.
if (
options.rawRequest.output?.jsonSchema &&
!options.rawRequest.output?.format
) {
options.rawRequest.output.format = 'json';
}
const resolvedFormat = await resolveFormat(
registry,
options.rawRequest.output
);
const instructions = resolveInstructions(
resolvedFormat,
resolvedSchema,
options.rawRequest?.output?.instructions
);
if (resolvedFormat) {
options.rawRequest.messages = injectInstructions(
options.rawRequest.messages,
instructions
);
options.rawRequest.output = {
// use output config from the format
...resolvedFormat.config,
// if anything is set explicitly, use that
...options.rawRequest.output,
};
}

// Create a lookup of tool names with namespaces stripped to original names
const toolMap = tools.reduce<Record<string, ToolAction>>((acc, tool) => {
const name = tool.__action.name;
12 changes: 10 additions & 2 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
@@ -27,7 +27,12 @@ import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { performance } from 'node:perf_hooks';
import { DocumentDataSchema } from './document.js';
import { augmentWithContext, validateSupport } from './model/middleware.js';
import {
augmentWithContext,
simulateConstrainedGeneration,
validateSupport,
} from './model/middleware.js';
export { simulateConstrainedGeneration };

//
// IMPORTANT: Please keep type definitions in sync with
@@ -204,6 +209,8 @@ export const ModelInfoSchema = z.object({
contentType: z.array(z.string()).optional(),
/** Model can natively support document-based context grounding. */
context: z.boolean().optional(),
/** Model can natively support constrained generation. */
constrained: z.boolean().optional(),
/** Model supports controlling tool choice, e.g. forced tool calling. */
toolChoice: z.boolean().optional(),
})
@@ -478,7 +485,8 @@ export function defineModel<
validateSupport(options),
];
if (!options?.supports?.context) middleware.push(augmentWithContext());
// middleware.push(conformOutput(registry));
if (!options?.supports?.constrained)
middleware.push(simulateConstrainedGeneration());
const act = defineAction(
registry,
{
46 changes: 45 additions & 1 deletion js/ai/src/model/middleware.ts
Original file line number Diff line number Diff line change
@@ -15,14 +15,14 @@
*/

import { Document } from '../document.js';
import { injectInstructions } from '../formats/index.js';
import type {
MediaPart,
MessageData,
ModelInfo,
ModelMiddleware,
Part,
} from '../model.js';

/**
* Preprocess a GenerateRequest to download referenced http(s) media URLs and
* inline them as data URIs.
@@ -234,3 +234,47 @@ export function augmentWithContext(
return next(req);
};
}

/**
 * Options accepted by `simulateConstrainedGeneration`.
 */
export interface SimulatedConstrainedGenerationOptions {
/**
 * Optional function that renders the requested output schema into the
 * instruction text injected into the request messages. When omitted, the
 * middleware falls back to its built-in JSON-schema instruction template.
 */
instructionsRenderer?: (schema: Record<string, any>) => string;
}

/**
 * Default instructions renderer used when constrained generation is
 * simulated: asks for JSON output conforming to `schema`, embedding the schema
 * (serialized with `JSON.stringify`) inside a fenced block.
 */
const DEFAULT_CONSTRAINED_GENERATION_INSTRUCTIONS = (
  schema: Record<string, any>
) => `Output should be in JSON format and conform to the following schema:

\`\`\`
${JSON.stringify(schema)}
\`\`\`
`;

/**
 * Model middleware that simulates constrained generation by injecting
 * generation instructions into the request messages (via
 * `injectInstructions`).
 *
 * The middleware only acts when the request asks for constrained output
 * (`req.output.constrained`) and carries an output schema; every other request
 * is forwarded to `next` untouched. The incoming request object is not
 * mutated: a shallow copy with updated `messages` and `output` is forwarded.
 *
 * @param options Optional settings; `instructionsRenderer` overrides the
 *   default instruction template.
 * @returns A middleware that forwards the (possibly rewritten) request to
 *   `next` and returns its result.
 */
export function simulateConstrainedGeneration(
  options?: SimulatedConstrainedGenerationOptions
): ModelMiddleware {
  return (req, next) => {
    const output = req.output;

    // Nothing to simulate unless constrained output with a schema was requested.
    if (!output?.constrained || !output.schema) {
      return next(req);
    }

    const renderInstructions =
      options?.instructionsRenderer ??
      DEFAULT_CONSTRAINED_GENERATION_INSTRUCTIONS;
    const instructions = renderInstructions(output.schema);

    return next({
      ...req,
      messages: injectInstructions(req.messages, instructions),
      output: {
        ...output,
        // we're simulating it, so to the underlying model it's unconstrained.
        constrained: false,
        schema: undefined,
      },
    });
  };
}
69 changes: 66 additions & 3 deletions js/ai/tests/formats/json_test.ts
Original file line number Diff line number Diff line change
@@ -14,14 +14,30 @@
* limitations under the License.
*/

import * as assert from 'assert';
import { describe, it } from 'node:test';
import { z } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { configureFormats } from '../../src/formats/index.js';
import { jsonFormatter } from '../../src/formats/json.js';
import { GenerateResponseChunk } from '../../src/generate.js';
import { GenerateResponseChunk, generateStream } from '../../src/generate.js';
import { Message } from '../../src/message.js';
import { GenerateResponseChunkData, MessageData } from '../../src/model.js';
import {
ProgrammableModel,
defineProgrammableModel,
runAsync,
} from '../helpers.js';

describe('jsonFormat', () => {
let registry: Registry;
let pm: ProgrammableModel;

beforeEach(() => {
registry = new Registry();
pm = defineProgrammableModel(registry);
});

const streamingTests = [
{
desc: 'parses complete JSON object',
@@ -123,3 +139,50 @@ describe('jsonFormat', () => {
});
}
});

// End-to-end check of the `json` format: a programmable model streams a fenced
// JSON document in fragments and the parsed output is verified per chunk and
// for the final response.
describe('jsonFormat e2e', () => {
  let registry: Registry;

  beforeEach(() => {
    registry = new Registry();
    configureFormats(registry);
  });

  it('injects the instructions into the request', async () => {
    const model = defineProgrammableModel(registry);

    // Text fragments emitted through the streaming callback, in order.
    const streamedFragments = ['```\n{', '"foo": "b', 'ar"', '}\n```"'];

    model.handleResponse = async (_req, sc) => {
      for (const text of streamedFragments) {
        await runAsync(() => sc?.({ content: [{ text }] }));
      }
      return await runAsync(() => ({
        message: {
          role: 'model',
          content: [{ text: '```\n{"foo": "bar"}\n```' }],
        },
      }));
    };

    const { response, stream } = await generateStream(registry, {
      model: 'programmableModel',
      prompt: 'generate json',
      output: {
        format: 'json',
        schema: z.object({
          foo: z.string(),
        }),
      },
    });

    // Collect the parsed output of every streamed chunk.
    const parsedChunks: any = [];
    for await (const chunk of stream) {
      parsedChunks.push(chunk.output);
    }

    const finalResponse = await response;
    assert.deepEqual(finalResponse.output, { foo: 'bar' });
    assert.deepStrictEqual(parsedChunks, [
      {},
      { foo: 'b' },
      { foo: 'bar' },
      { foo: 'bar' },
    ]);
  });
});
60 changes: 60 additions & 0 deletions js/ai/tests/helpers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { StreamingCallback } from '@genkit-ai/core';
import { Registry } from '@genkit-ai/core/registry';
import {
GenerateRequest,
GenerateResponseChunkData,
GenerateResponseData,
ModelAction,
ModelInfo,
defineModel,
} from '../src/model';

/**
 * Runs `fn` from a `setTimeout(…, 0)` callback and resolves with its return
 * value, so the call completes asynchronously rather than inline.
 *
 * If `fn` throws, the returned promise rejects with the thrown value. Without
 * that guard the error would escape the timer callback as an uncaught
 * exception and the promise would never settle.
 *
 * @param fn Function to invoke on the timer callback.
 * @returns A promise for the value returned by `fn`.
 */
export async function runAsync<O>(fn: () => O): Promise<O> {
  return new Promise<O>((resolve, reject) => {
    setTimeout(() => {
      try {
        resolve(fn());
      } catch (err) {
        // Surface failures through the promise instead of the timer callback.
        reject(err);
      }
    }, 0);
  });
}

/**
 * A model action, used by tests, whose response behavior is supplied by the
 * caller through `handleResponse`.
 */
export type ProgrammableModel = ModelAction & {
/**
 * Handler assigned by the test. The model created by
 * `defineProgrammableModel` passes each request and the optional streaming
 * callback to it and returns its result.
 */
handleResponse: (
req: GenerateRequest,
streamingCallback?: StreamingCallback<GenerateResponseChunkData>
) => Promise<GenerateResponseData>;

/**
 * Copy of the most recent request, recorded by `defineProgrammableModel` via
 * a JSON round-trip before the request is handed to `handleResponse`.
 */
lastRequest?: GenerateRequest;
};

/**
 * Registers a model named `programmableModel` in `registry` whose behavior is
 * driven by the `handleResponse` function assigned to the returned object.
 *
 * On every call the model stores a JSON round-tripped copy of the request in
 * `lastRequest`, then delegates to `handleResponse`.
 *
 * @param registry Registry the model is defined in.
 * @param info Optional model info spread into the model definition; `name` is
 *   always set to `programmableModel`.
 * @returns The defined model action, typed as a `ProgrammableModel`.
 */
export function defineProgrammableModel(
  registry: Registry,
  info?: ModelInfo
): ProgrammableModel {
  const programmable = defineModel(
    registry,
    {
      ...info,
      name: 'programmableModel',
    },
    async (req, sendChunk) => {
      // Snapshot the request so later mutations don't affect what tests inspect.
      const snapshot = JSON.parse(JSON.stringify(req));
      programmable.lastRequest = snapshot;
      return programmable.handleResponse(req, sendChunk);
    }
  ) as ProgrammableModel;

  return programmable;
}
Loading