diff --git a/package-lock.json b/package-lock.json index fa13589..a6320fe 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@oconva/qvikchat", - "version": "1.0.8", + "version": "1.0.9", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@oconva/qvikchat", - "version": "1.0.8", + "version": "1.0.9", "license": "MIT", "dependencies": { "@genkit-ai/ai": "^0.5.4", diff --git a/package.json b/package.json index 3eeed2b..29f4485 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@oconva/qvikchat", - "version": "1.0.8", + "version": "1.0.9", "repository": { "type": "git", "url": "https://github.com/oconva/qvikchat.git" diff --git a/src/rag/data-loaders/data-loaders.ts b/src/rag/data-loaders/data-loaders.ts index 5579365..e2ba99d 100644 --- a/src/rag/data-loaders/data-loaders.ts +++ b/src/rag/data-loaders/data-loaders.ts @@ -120,13 +120,19 @@ export function validateDataType(filePath: string): * For more data loaders, see: * @link https://js.langchain.com/v0.1/docs/integrations/document_loaders/ */ -export const getDocs = async ( - dataLoaderType: SupportedDataLoaderTypes, - path: string, - jsonLoaderKeysToInclude?: JSONLoaderKeysToInclude, - csvLoaderOptions?: CSVLoaderOptions, - pdfLoaderOptions?: PDFLoaderOptions -): Promise>[]> => { +export const getDocs = async ({ + dataLoaderType, + path, + jsonLoaderKeysToInclude, + csvLoaderOptions, + pdfLoaderOptions, +}: { + dataLoaderType: SupportedDataLoaderTypes; + path: string; + jsonLoaderKeysToInclude?: JSONLoaderKeysToInclude; + csvLoaderOptions?: CSVLoaderOptions; + pdfLoaderOptions?: PDFLoaderOptions; +}): Promise>[]> => { // store loader let loader; // infer loader to use based on dataLoaderType diff --git a/src/rag/data-retrievers/data-retrievers.ts b/src/rag/data-retrievers/data-retrievers.ts index 2b71584..30dcd89 100644 --- a/src/rag/data-retrievers/data-retrievers.ts +++ b/src/rag/data-retrievers/data-retrievers.ts @@ -39,6 +39,22 @@ export type RetrievalOptions = | Partial> | undefined; +/** + * Configurations for providing information about data to the retriever. + */ +export type RetrieverConfigDataOptions = + | { + filePath: string; + dataType?: SupportedDataLoaderTypes; + } + | { + docs: Document>[]; + dataType: SupportedDataLoaderTypes; + } + | { + splitDocs: Document>[]; + }; + /** * Represents the configuration for the retriever when generating embeddings. * @property {SupportedDataLoaderTypes} dataType - The type of data loader to use. @@ -57,10 +73,6 @@ export type RetrievalOptions = * @property {boolean} generateEmbeddings - Whether to generate embeddings. */ export type RetrieverConfigGeneratingEmbeddings = { - filePath: string; - dataType?: SupportedDataLoaderTypes; - docs?: Document>[]; - splitDocs?: Document>[]; jsonLoaderKeysToInclude?: JSONLoaderKeysToInclude; csvLoaderOptions?: CSVLoaderOptions; pdfLoaderOptions?: PDFLoaderOptions; @@ -71,7 +83,7 @@ export type RetrieverConfigGeneratingEmbeddings = { vectorStore?: VectorStore; embeddingModel?: EmbeddingsInterface; generateEmbeddings: true; -}; +} & RetrieverConfigDataOptions; /** * Represents the configuration for the retriever when not generating embeddings. @@ -144,6 +156,89 @@ export const defaultChunkingConfig: ChunkingConfig = { chunkOverlap: 200, }; +/** + * Method to split the documents based on the data type. + * @param config configuration for the retriever + * @returns array of documents + */ +export const getSplitDocs = async ( + config: RetrieverConfig +): Promise>[]> => { + // if splitDocs available, return them + if ('splitDocs' in config && config.splitDocs) { + return config.splitDocs; + } + + // if docs provided + if ('docs' in config && config.docs) { + // must provide dataType since it can't be inferred from file extension + if (!config.dataType) + throw new Error('Data type must be provided when docs are provided'); + + const {defaultDataSplitterType, defaultSplitterConfig} = + getAppropriateDataSplitter(config.dataType); + + // Split the provided documents into chunks using the data splitter + return await runDataSplitter({ + docs: config.docs, + dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType, + chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig, + splitterConfig: config.splitterConfig ?? defaultSplitterConfig, + }); + } + + // if file path provided + if ('filePath' in config && config.filePath) { + // if generating embeddings, file path must be provided + if (config.filePath === '') { + throw new Error( + 'Invalid file path. File path can not be an empty string.' + ); + } + + // store provided data type + let dataType: SupportedDataLoaderTypes | undefined = config.dataType; + + // if no data type provided, infer it from the file extension + if (!dataType) { + // validate the data type of the file + const result = validateDataType(config.filePath); + // check if the file type is supported + if (result.isSupported) { + dataType = result.dataType; + } else { + throw new Error( + `Unable to load data. Unsupported file type: ${result.unSupportedDataType}` + ); + } + } + + // get documents from the file path using a data loader + const docs = await getDocs({ + dataLoaderType: dataType, + path: config.filePath, + csvLoaderOptions: config.csvLoaderOptions, + jsonLoaderKeysToInclude: config.jsonLoaderKeysToInclude, + pdfLoaderOptions: config.pdfLoaderOptions, + }); + + const {defaultDataSplitterType, defaultSplitterConfig} = + getAppropriateDataSplitter(dataType); + + // Split the retrieved documents into chunks using the data splitter + return await runDataSplitter({ + docs, + dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType, + chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig, + splitterConfig: config.splitterConfig ?? defaultSplitterConfig, + }); + } + + throw new Error( + 'Invalid configuration for the retriever. Must provide docs, splitDocs or file path.' + ); +}; + /** * Method to ingest data, split it into chunks, generate embeddings and store them in a vector store. * If not generating embeddings, simply returns a runnable instance to retrieve docs as string. @@ -167,44 +262,7 @@ export const getDataRetriever = async ( .pipe(formatDocumentsAsString); } - // if generating embeddings, file path must be provided - if (!config.filePath || config.filePath === '') { - throw new Error('Invalid file path. File path must be provided'); - } - - // if data type not provided, infer the data type from file extension using the file path - if (!config.dataType) { - const result = validateDataType(config.filePath); - // check if the file type is supported - if (result.isSupported) { - console.log('/n/n------------------'); - console.log(`Data type: ${result.dataType}`); - config.dataType = result.dataType; - } else { - throw new Error( - `Unable to load data. Unsupported file type: ${result.unSupportedDataType}` - ); - } - } - try { - // Retrieve the documents from the specified file path - const docs: Document>[] = - config.docs ?? (await getDocs(config.dataType, config.filePath)); - - const {defaultDataSplitterType, defaultSplitterConfig} = - getAppropriateDataSplitter(config.dataType); - - // Split the retrieved documents into chunks using the data splitter - const splitDocs: Document>[] = - config.splitDocs ?? - (await runDataSplitter({ - docs, - dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType, - chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig, - splitterConfig: config.splitterConfig ?? defaultSplitterConfig, - })); - // embedding model - if not provided, use the default Google Generative AI Embeddings model const embeddings: EmbeddingsInterface = config.embeddingModel ?? @@ -214,6 +272,9 @@ export const getDataRetriever = async ( taskType: TaskType.RETRIEVAL_DOCUMENT, }); + // get split documents + const splitDocs = await getSplitDocs(config); + // create a vector store instance const vectorStore: VectorStore = config.vectorStore ?? new MemoryVectorStore(embeddings); diff --git a/src/tests/unit-tests/endpoint-rag.unit.test.ts b/src/tests/unit-tests/endpoint-rag.unit.test.ts index 25832d5..8277690 100644 --- a/src/tests/unit-tests/endpoint-rag.unit.test.ts +++ b/src/tests/unit-tests/endpoint-rag.unit.test.ts @@ -4,6 +4,7 @@ import { } from '../../endpoints/endpoints'; import {setupGenkit} from '../../genkit/genkit'; import {getDataRetriever} from '../../rag/data-retrievers/data-retrievers'; +import {CSVLoader} from '@langchain/community/document_loaders/fs/csv'; /** * Test suite for Chat Endpoint RAG Functionality. @@ -22,6 +23,7 @@ describe('Test - Endpoint RAG Tests', () => { // Set to true to run the test const Tests = { test_rag_works: true, + test_rag_works_providing_docs: true, }; // default test timeout @@ -77,4 +79,60 @@ describe('Test - Endpoint RAG Tests', () => { }, defaultTimeout ); + + if (Tests.test_rag_works_providing_docs) + test( + 'Test RAG works when providing docs', + async () => { + // configure data loader + const loader = new CSVLoader('src/tests/test-data/inventory-data.csv'); + // get documents + const docs = await loader.load(); + + // define chat endpoint + const endpoint = defineChatEndpoint({ + endpoint: 'test-chat-open-rag-docs', + topic: 'store inventory', + enableRAG: true, + retriever: await getDataRetriever({ + docs, + dataType: 'csv', // need to specify data type when providing docs + generateEmbeddings: true, + }), + }); + try { + // send test query + const response = await runEndpoint(endpoint, { + query: 'What is the price of Seagate ST1000DX002?', + }); + + // check response is valid and does not contain error + expect(response).toBeDefined(); + expect(response).not.toHaveProperty('error'); + + // confirm response type + if (typeof response === 'string') { + // should not be empty + expect(response.length).toBeGreaterThan(0); + // should contain 68.06 + expect(response).toContain('68.06'); + } else { + expect(response).toHaveProperty('response'); + if ('response' in response) { + // should not be empty + expect(response.response.length).toBeGreaterThan(0); + // should contain 68.06 + expect(response.response).toContain('68.06'); + } else { + throw new Error( + `Response field invalid. Response: ${JSON.stringify(response)}` + ); + } + } + } catch (error) { + throw new Error(`Error in test. Error: ${error}`); + } + }, + defaultTimeout + ); });