diff --git a/package.json b/package.json index a22b171..cf99bc4 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@oconva/qvikchat", - "version": "2.0.0-alpha.2", + "version": "2.0.0-alpha.3", "repository": { "type": "git", "url": "https://github.com/oconva/qvikchat.git" diff --git a/src/agents/chat-agent.ts b/src/agents/chat-agent.ts index 4deb7b9..4182895 100644 --- a/src/agents/chat-agent.ts +++ b/src/agents/chat-agent.ts @@ -17,6 +17,7 @@ import {MessageData} from '@genkit-ai/ai/model'; import {ToolArgument} from '@genkit-ai/ai/tool'; import {Dotprompt} from '@genkit-ai/dotprompt'; import {PromptOutputSchema} from '../prompts/prompts'; +import {dallE3} from 'genkitx-openai'; /** * Represents the type of chat agent. @@ -45,7 +46,6 @@ export type AgentTypeConfig = * @property systemPrompt - The system prompt for the chat agent. * @property chatPrompt - The chat prompt for the chat agent. * @property tools - Tools for the chat agent. - * @property model - The supported model to use for chat completion. * @property modelConfig - The model configuration. * @property responseOutputSchema - The output schema for the response. */ @@ -53,7 +53,6 @@ export type ChatAgentConfig = { systemPrompt?: Dotprompt; chatPrompt?: Dotprompt; tools?: ToolArgument[]; - model?: SupportedModels; modelConfig?: ModelConfig; responseOutputSchema?: OutputSchemaType; } & AgentTypeConfig; @@ -108,7 +107,6 @@ export type GenerateResponseHistoryProps = * @property enableChatHistory - Indicates whether to use chat history. * @property chatHistoryStore - The chat history store. * @property tools - The tool arguments. - * @property model - The supported model. * @property modelConfig - The model configuration. * @property systemPrompt - The system prompt. * @property chatPrompt - The chat prompt. @@ -118,7 +116,6 @@ export type GenerateResponseProps = { context?: string; chatId?: string; tools?: ToolArgument[]; - model?: SupportedModels; modelConfig?: ModelConfig; systemPrompt?: Dotprompt; chatPrompt?: Dotprompt; @@ -149,11 +146,6 @@ export interface ChatAgentMethods { generateResponse: ( props: GenerateResponseProps ) => Promise; - - /** - * Method to get model name that the chat agent is using. - */ - getModelName(): string; } /** @@ -166,7 +158,6 @@ export interface ChatAgentInterface export type GenerateSystemPromptResponseParams = { agentType?: ChatAgentType; prompt: Dotprompt; - model?: string; modelConfig?: ModelConfig; query?: string; context?: string; @@ -183,7 +174,6 @@ export class ChatAgent implements ChatAgentInterface { systemPrompt?: Dotprompt; chatPrompt?: Dotprompt; tools?: ToolArgument[]; - private modelName: string; modelConfig?: ModelConfig; responseOutputSchema?: OutputSchemaType; @@ -196,8 +186,7 @@ export class ChatAgent implements ChatAgentInterface { * @param enableChatHistory - Indicates whether to use chat history. * @param chatHistoryStore - The chat history store. * @param tools - Tools for the chat agent. - * @param model - The supported model. If not provided, will use the default model (e.g. Gemini 1.5 Flash). - * @param modelConfig - The model configuration. + * @param modelConfig - The model configuration. If not provided, will use the default model (e.g. Gemini 1.5 Flash). */ constructor(config: ChatAgentConfig = {}) { this.agentType = config.agentType ?? defaultChatAgentConfig.agentType; @@ -207,9 +196,6 @@ export class ChatAgent implements ChatAgentInterface { this.systemPrompt = config.systemPrompt; this.chatPrompt = config.chatPrompt; this.tools = config.tools; - this.modelName = config.model - ? SupportedModelNames[config.model] - : SupportedModelNames[defaultChatAgentConfig.model]; this.modelConfig = config.modelConfig; this.responseOutputSchema = config.responseOutputSchema; } @@ -323,7 +309,6 @@ export class ChatAgent implements ChatAgentInterface { private static generateSystemPromptResponse({ agentType, prompt, - model, modelConfig, query, context, @@ -333,8 +318,9 @@ export class ChatAgent implements ChatAgentInterface { // generate the response const res = prompt.generate({ // if undefined, will use model defined in the dotprompt - model: model, - config: modelConfig, + model: + SupportedModelNames[modelConfig?.name ?? defaultChatAgentConfig.model], + config: {...modelConfig}, input: ChatAgent.getFormattedInput({agentType, query, context, topic}), tools: tools, }); @@ -350,20 +336,17 @@ export class ChatAgent implements ChatAgentInterface { static getPromptOutputSchema( responseOutputSchema?: OutputSchemaType ): PromptOutputSchema { - if (!responseOutputSchema || responseOutputSchema.responseType === 'text') { + if (!responseOutputSchema || responseOutputSchema.format === 'text') { return {format: 'text'}; - } else if (responseOutputSchema.responseType === 'json') { + } else if (responseOutputSchema.format === 'json') { return { format: 'json', schema: responseOutputSchema.schema, - jsonSchema: responseOutputSchema.jsonSchema, }; - } else if (responseOutputSchema.responseType === 'media') { + } else if (responseOutputSchema.format === 'media') { return {format: 'media'}; } else { - throw new Error( - `Invalid response type ${responseOutputSchema.responseType}` - ); + throw new Error(`Invalid response type ${responseOutputSchema.format}`); } } @@ -382,6 +365,29 @@ export class ChatAgent implements ChatAgentInterface { async generateResponse( params: GenerateResponseProps ): Promise { + // if the model being used is Dall-E3 (e.g., for image generation) + // simply return the response + if ( + params.modelConfig?.name === 'dallE3' || // if model provided in params is Dall-E3 + (!params.modelConfig?.name && this.modelConfig?.name === 'dallE3') // if model not provided in params and default model is Dall-E3 + ) { + // configurations for Dall-E3 model + const dallEConfig = params.modelConfig ?? this.modelConfig; + // return response + return { + res: await generate({ + model: dallE3, + config: { + ...dallEConfig, + }, + prompt: params.query, + tools: params.tools, + output: { + format: 'media', + }, + }), + }; + } // System prompt to use // In order of priority: systemPrompt provided as argument to generateResponse, this.systemPrompt, default system prompt const prompt = @@ -400,9 +406,6 @@ export class ChatAgent implements ChatAgentInterface { res: await ChatAgent.generateSystemPromptResponse({ agentType: this.agentType, prompt, - model: params.model - ? SupportedModelNames[params.model] - : this.modelName, modelConfig: params.modelConfig ?? this.modelConfig, query: params.query, context: params.context, @@ -420,9 +423,6 @@ export class ChatAgent implements ChatAgentInterface { const res = await ChatAgent.generateSystemPromptResponse({ agentType: this.agentType, prompt, - model: params.model - ? SupportedModelNames[params.model] - : this.modelName, modelConfig: params.modelConfig ?? this.modelConfig, query: params.query, context: params.context, @@ -448,7 +448,16 @@ export class ChatAgent implements ChatAgentInterface { throw new Error(`No data found for chat ID ${params.chatId}.`); // generate response for given query (will use chat prompt and any provided chat history, context and tools) const res = await generate({ - model: params.model ?? this.modelName, + model: + SupportedModelNames[ + params.modelConfig?.name ?? + this.modelConfig?.name ?? + defaultChatAgentConfig.model + ], + config: { + ...this.modelConfig, + ...params.modelConfig, + }, prompt: params.query, history: chatHistory, context: params.context @@ -467,11 +476,4 @@ export class ChatAgent implements ChatAgentInterface { res, }; } - - /** - * Method to get model name that the chat agent is using. - */ - getModelName() { - return this.modelName; - } } diff --git a/src/endpoints/endpoints.ts b/src/endpoints/endpoints.ts index 258e6cf..831a8bf 100644 --- a/src/endpoints/endpoints.ts +++ b/src/endpoints/endpoints.ts @@ -23,12 +23,7 @@ import {getDataRetriever} from '../rag/data-retrievers/data-retrievers'; import {ChatHistoryStore} from '../history/chat-history-store'; import {Dotprompt} from '@genkit-ai/dotprompt'; import {ToolArgument} from '@genkit-ai/ai/tool'; -import { - ModelConfig, - OutputSchema, - OutputSchemaType, - SupportedModels, -} from '../models/models'; +import {ModelConfig, OutputSchema, OutputSchemaType} from '../models/models'; import {getSystemPromptText} from '../prompts/system-prompts'; type ChatHistoryParams = @@ -84,14 +79,6 @@ type ChatAgentTypeParams = topic: string; }; -type EndpointChatAgentConfig = { - systemPrompt?: Dotprompt; - chatPrompt?: Dotprompt; - tools?: ToolArgument[]; - model?: SupportedModels; - modelConfig?: ModelConfig; -}; - type VerboseDetails = { usage: GenerationUsage; request?: GenerateRequest; @@ -100,7 +87,10 @@ type VerboseDetails = { export type DefineChatEndpointConfig = { endpoint: string; - chatAgentConfig?: EndpointChatAgentConfig; + systemPrompt?: Dotprompt; + chatPrompt?: Dotprompt; + tools?: ToolArgument[]; + modelConfig?: ModelConfig; verbose?: boolean; outputSchema?: OutputSchemaType; } & ChatAgentTypeParams & @@ -142,10 +132,10 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => z.object({ response: !config.outputSchema || - !config.outputSchema.responseType || - config.outputSchema?.responseType === 'text' + !config.outputSchema.format || + config.outputSchema?.format === 'text' ? z.string() - : config.outputSchema.responseType === 'media' + : config.outputSchema.format === 'media' ? z.object({ contentType: z.string(), url: z.string(), @@ -196,8 +186,8 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => if (query === '') return {response: 'How can I help you today?'}; // set default response type - if (!config.outputSchema || !config.outputSchema.responseType) { - config.outputSchema = {responseType: 'text'}; + if (!config.outputSchema || !config.outputSchema.format) { + config.outputSchema = {format: 'text'}; } // set output schema @@ -206,9 +196,17 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => outputSchema = config.outputSchema; } - // store chat agent + // store chat agent (will be initialized based on the provided agent type or RAG) let chatAgent: ChatAgent; + // shared chat agent configurations + const sharedChatAgentConfig = { + systemPrompt: config.systemPrompt, + chatPrompt: config.chatPrompt, + tools: config.tools, + modelConfig: config.modelConfig, + }; + // Initialize chat agent based on the provided type if (!config.enableRAG) { // check if topic is provided @@ -226,12 +224,12 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => agentType: 'close-ended', topic: config.topic, responseOutputSchema: outputSchema, - ...config.chatAgentConfig, + ...sharedChatAgentConfig, }) : new ChatAgent({ agentType: 'open-ended', responseOutputSchema: outputSchema, - ...config.chatAgentConfig, + ...sharedChatAgentConfig, }); } // If RAG is enabled @@ -241,7 +239,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => agentType: 'rag', topic: config.topic, responseOutputSchema: outputSchema, - ...config.chatAgentConfig, + ...sharedChatAgentConfig, }); } @@ -304,7 +302,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => // also check if response type matches the expected response type if ( cachedQuery.response && - cachedQuery.responseType === outputSchema.responseType + cachedQuery.responseType === outputSchema.format ) { // increment cache hits config.cacheStore.incrementCacheHits(queryHash); @@ -314,7 +312,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => let cachedModelResponse: MessageData; // if expected response type is "text" and cached response type is "text" if ( - outputSchema.responseType === 'text' && + outputSchema.format === 'text' && cachedQuery.responseType === 'text' ) { cachedModelResponse = { @@ -324,7 +322,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => } // else if expected response type is "json" and cached response type is "json" else if ( - outputSchema.responseType === 'json' && + outputSchema.format === 'json' && cachedQuery.responseType === 'json' ) { cachedModelResponse = { @@ -334,7 +332,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => } // else if expected response type is "media" and cached response type is "media" else if ( - outputSchema.responseType === 'media' && + outputSchema.format === 'media' && cachedQuery.responseType === 'media' ) { cachedModelResponse = { @@ -426,7 +424,7 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => // remember to add the query with context config.cacheStore.addQuery( queryWithContext, - outputSchema.responseType ?? 'text', // default to text + outputSchema.format ?? 'text', // default to text queryHash ); } @@ -485,14 +483,14 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => // Not supported for media response type if (config.enableCache && config.cacheStore && cacheThresholdReached) { // cache response based on response type - if (outputSchema.responseType === 'json') { + if (outputSchema.format === 'json') { config.cacheStore.cacheResponse(queryHash, { responseType: 'json', response: JSON.stringify(response.res.output()), }); } // if media - else if (outputSchema.responseType === 'media') { + else if (outputSchema.format === 'media') { const mediaContent = response.res.media(); // if we have valid data if (mediaContent?.contentType && mediaContent?.url) { @@ -516,9 +514,9 @@ export const defineChatEndpoint = (config: DefineChatEndpointConfig) => // return response based on response type let res; - if (outputSchema.responseType === 'json') { + if (outputSchema.format === 'json') { res = response.res.output(); - } else if (outputSchema.responseType === 'media') { + } else if (outputSchema.format === 'media') { const mediaContent = response.res.media(); // if we have valid data if (mediaContent?.contentType && mediaContent?.url) { diff --git a/src/index.ts b/src/index.ts index d8c04c5..cb0ff5c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,4 +1 @@ -console.warn( - `[WARNING]: The root QvikChat entrypoint is empty. Please use a specific entrypoint instead.` -); export {}; diff --git a/src/models/models.ts b/src/models/models.ts index e58643e..db23030 100644 --- a/src/models/models.ts +++ b/src/models/models.ts @@ -12,24 +12,50 @@ import { gpt4Vision, dallE3, } from 'genkitx-openai'; +import {DallE3ConfigSchema} from 'genkitx-openai/lib/dalle'; +import {OpenAiConfigSchema} from 'genkitx-openai/lib/gpt'; import {z} from 'zod'; /** - * Names of supported models. + * List of supported Gemini models. */ -export const SupportedModelNames = { +export const GeminiModels = { gemini10Pro: geminiPro.name, gemini15Pro: gemini15Pro.name, gemini15Flash: gemini15Flash.name, geminiProVision: geminiProVision.name, +} as const; + +/** + * List of supported OpenAI models. + */ +export const OpenAIModels = { gpt35Turbo: gpt35Turbo.name, gpt4o: gpt4o.name, gpt4Turbo: gpt4Turbo.name, gpt4Vision: gpt4Vision.name, gpt4: gpt4.name, +} as const; + +/** + * List of supported DALL-E models. + */ +export const DallEModels = { dallE3: dallE3.name, } as const; +/** + * List of all supported model names. + */ +export const SupportedModelNames = { + ...GeminiModels, + ...OpenAIModels, + ...DallEModels, +}; + +/** + * Get names of all supported models. + */ export const getSupportedModelNames = () => Object.values(SupportedModelNames); /** @@ -40,7 +66,7 @@ export type SupportedModels = keyof typeof SupportedModelNames; /** * Supported configuration options for a model */ -export type ModelConfig = { +export type GeminiModelConfig = { version?: string | undefined; temperature?: number | undefined; maxOutputTokens?: number | undefined; @@ -64,20 +90,33 @@ export type ModelConfig = { | undefined; }; +/** + * Configuration options for a model. + */ +export type ModelConfig = + | ({ + name: keyof typeof GeminiModels; + } & GeminiModelConfig) + | ({ + name: keyof typeof OpenAIModels; + } & z.infer) + | ({ + name: keyof typeof DallEModels; + } & z.infer); + /** * Output schema for model responses. */ export const OutputSchema = z.union([ z.object({ - responseType: z.literal('text').optional(), + format: z.literal('text').optional(), }), z.object({ - responseType: z.literal('json').optional(), - schema: z.any().optional(), - jsonSchema: z.any().optional(), + format: z.literal('json').optional(), + schema: z.custom().optional(), }), z.object({ - responseType: z.literal('media').optional(), + format: z.literal('media').optional(), contentType: z.string(), }), ]); diff --git a/src/tests/unit-tests/endpoint-agent.unit.test.ts b/src/tests/unit-tests/endpoint-agent.unit.test.ts index 402a825..2b783b2 100644 --- a/src/tests/unit-tests/endpoint-agent.unit.test.ts +++ b/src/tests/unit-tests/endpoint-agent.unit.test.ts @@ -1,3 +1,4 @@ +import {z} from 'zod'; import { defineChatEndpoint, getChatEndpointRunner, @@ -21,7 +22,11 @@ describe('Test - Chat Endpoint Agent Config Tests', () => { // Set to true to run the test const Tests = { define_chat_endpoint: true, + define_chat_endpoint_with_dall_e_model: true, + define_chat_endpoint_with_json_output_schema: true, confirm_response_generation: true, + confirm_json_response_generation: true, + confirm_dall_e_image_generation: true, }; // default test timeout @@ -31,8 +36,33 @@ describe('Test - Chat Endpoint Agent Config Tests', () => { test('Define chat endpoint with chat agent config', () => { const endpoint = defineChatEndpoint({ endpoint: 'test-chat-agent', - chatAgentConfig: { - model: 'gemini10Pro', + modelConfig: { + name: 'gemini15Flash', + }, + }); + expect(endpoint).toBeDefined(); + }); + + if (Tests.define_chat_endpoint_with_dall_e_model) + test('Define chat endpoint with chat agent config and DALL-E model', () => { + const endpoint = defineChatEndpoint({ + endpoint: 'test-chat-agent-dall-e', + modelConfig: { + name: 'dallE3', + }, + }); + expect(endpoint).toBeDefined(); + }); + + if (Tests.define_chat_endpoint_with_json_output_schema) + test('Define chat endpoint with chat agent config and json output schema', () => { + const endpoint = defineChatEndpoint({ + endpoint: 'test-chat-agent-json-output-schema', + outputSchema: { + format: 'json', + schema: z.object({ + answer: z.string(), + }), }, }); expect(endpoint).toBeDefined(); @@ -44,12 +74,6 @@ describe('Test - Chat Endpoint Agent Config Tests', () => { async () => { const endpoint = defineChatEndpoint({ endpoint: 'test-chat-agent-response', - chatAgentConfig: { - model: 'gemini10Pro', - modelConfig: { - temperature: 0.9, - }, - }, }); const response = await runEndpoint(endpoint, { query: @@ -79,4 +103,101 @@ describe('Test - Chat Endpoint Agent Config Tests', () => { }, defaultTimeout ); + + if (Tests.confirm_json_response_generation) + test( + 'Confirm JSON response generation for endpoint with chat agent config', + async () => { + const endpoint = defineChatEndpoint({ + endpoint: 'test-chat-agent-json-response', + outputSchema: { + format: 'json', + schema: z.object({ + answer: z.string(), + }), + }, + }); + const response = await runEndpoint(endpoint, { + query: + 'Answer in one sentence: What is Firebase Firestore? Must contain the word "Firestore" in your response.', + }); + expect(response).toBeDefined(); + + // should not contain error + if ('error' in response) { + throw new Error( + `Error in response. Response: ${JSON.stringify(response)}` + ); + } + + // should not be empty + expect(response.response).toBeDefined(); + + // parse response + const parsedResponse = response.response as { + answer: string; + }; + + // should have answer field + if (!('answer' in parsedResponse)) { + throw new Error( + `Invalid response object. Missing 'answer' field. Response: ${JSON.stringify(response)}` + ); + } + + // response should contain the word "Firestore" + expect(parsedResponse.answer.toLowerCase()).toContain('firestore'); + }, + defaultTimeout + ); + + if (Tests.confirm_dall_e_image_generation) + test.skip( + 'Confirm DALL-E image generation for endpoint with chat agent config', + async () => { + const endpoint = defineChatEndpoint({ + endpoint: 'test-chat-agent-dall-e-image', + modelConfig: { + name: 'dallE3', + response_format: 'url', + }, + outputSchema: { + format: 'media', + contentType: 'image/png', + }, + }); + const response = await runEndpoint(endpoint, { + query: 'Generate an image of a cat.', + }); + expect(response).toBeDefined(); + + // should not contain error + if ('error' in response) { + throw new Error( + `Error in response. Response: ${JSON.stringify(response)}` + ); + } + + // should not be empty + expect(response.response).toBeDefined(); + + // parse response + const parsedResponse = response.response as { + url: string; + }; + + // should have URL field + if (!('url' in parsedResponse)) { + throw new Error( + `Invalid response object. Missing 'url' field. Response: ${JSON.stringify(response)}` + ); + } + + // should have URL + expect(parsedResponse.url).toBeDefined(); + // valid URL + expect(parsedResponse.url).toMatch(/^(http|https):\/\/[^ "]+$/); + }, + defaultTimeout * 30 + ); });