Skip to content

Commit

Permalink
Fix for the issue where filePath was still required when providing `docs` in retriever config #33 (#34)
Browse files Browse the repository at this point in the history

* fix for #33 - refactored code for data retriever

* Bumped package version #33
  • Loading branch information
pranav-kural authored Jul 15, 2024
1 parent 9a58767 commit 987cbdb
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 52 deletions.
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@oconva/qvikchat",
"version": "1.0.8",
"version": "1.0.9",
"repository": {
"type": "git",
"url": "https://github.com/oconva/qvikchat.git"
Expand Down
20 changes: 13 additions & 7 deletions src/rag/data-loaders/data-loaders.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,19 @@ export function validateDataType(filePath: string):
* For more data loaders, see:
* @link https://js.langchain.com/v0.1/docs/integrations/document_loaders/
*/
export const getDocs = async (
dataLoaderType: SupportedDataLoaderTypes,
path: string,
jsonLoaderKeysToInclude?: JSONLoaderKeysToInclude,
csvLoaderOptions?: CSVLoaderOptions,
pdfLoaderOptions?: PDFLoaderOptions
): Promise<Document<Record<string, string>>[]> => {
export const getDocs = async ({
dataLoaderType,
path,
jsonLoaderKeysToInclude,
csvLoaderOptions,
pdfLoaderOptions,
}: {
dataLoaderType: SupportedDataLoaderTypes;
path: string;
jsonLoaderKeysToInclude?: JSONLoaderKeysToInclude;
csvLoaderOptions?: CSVLoaderOptions;
pdfLoaderOptions?: PDFLoaderOptions;
}): Promise<Document<Record<string, string>>[]> => {
// store loader
let loader;
// infer loader to use based on dataLoaderType
Expand Down
145 changes: 103 additions & 42 deletions src/rag/data-retrievers/data-retrievers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@ export type RetrievalOptions =
| Partial<VectorStoreRetrieverInput<any>>
| undefined;

/**
 * Configurations for providing information about data to the retriever.
 *
 * Exactly one of three mutually exclusive variants is supplied:
 * - `filePath`: load data from a file. `dataType` is optional here
 *   because it can be inferred from the file extension.
 * - `docs`: pre-loaded (unsplit) documents. `dataType` is required
 *   because there is no file extension to infer it from.
 * - `splitDocs`: documents already split into chunks; no loading or
 *   splitting is needed, so no data type is required.
 */
export type RetrieverConfigDataOptions =
  | {
      filePath: string;
      dataType?: SupportedDataLoaderTypes;
    }
  | {
      docs: Document<Record<string, string>>[];
      dataType: SupportedDataLoaderTypes;
    }
  | {
      splitDocs: Document<Record<string, unknown>>[];
    };

/**
* Represents the configuration for the retriever when generating embeddings.
* @property {SupportedDataLoaderTypes} dataType - The type of data loader to use.
Expand All @@ -57,10 +73,6 @@ export type RetrievalOptions =
* @property {boolean} generateEmbeddings - Whether to generate embeddings.
*/
export type RetrieverConfigGeneratingEmbeddings = {
filePath: string;
dataType?: SupportedDataLoaderTypes;
docs?: Document<Record<string, string>>[];
splitDocs?: Document<Record<string, unknown>>[];
jsonLoaderKeysToInclude?: JSONLoaderKeysToInclude;
csvLoaderOptions?: CSVLoaderOptions;
pdfLoaderOptions?: PDFLoaderOptions;
Expand All @@ -71,7 +83,7 @@ export type RetrieverConfigGeneratingEmbeddings = {
vectorStore?: VectorStore;
embeddingModel?: EmbeddingsInterface;
generateEmbeddings: true;
};
} & RetrieverConfigDataOptions;

/**
* Represents the configuration for the retriever when not generating embeddings.
Expand Down Expand Up @@ -144,6 +156,89 @@ export const defaultChunkingConfig: ChunkingConfig = {
chunkOverlap: 200,
};

/**
 * Resolves the retriever configuration into chunked documents ready for
 * embedding. Handles the three data-provision variants in priority order:
 * pre-split docs are returned as-is; pre-loaded docs are split; a file
 * path is loaded via the appropriate data loader and then split.
 *
 * @param config configuration for the retriever; must provide one of
 *   `splitDocs`, `docs` (with `dataType`), or `filePath`
 * @returns array of split (chunked) documents
 * @throws Error if `docs` are given without a `dataType`, if `filePath`
 *   is an empty string or has an unsupported extension, or if none of
 *   the three data options is present
 */
export const getSplitDocs = async (
  config: RetrieverConfig
): Promise<Document<Record<string, unknown>>[]> => {
  // if splitDocs available, return them — nothing to load or split
  if ('splitDocs' in config && config.splitDocs) {
    return config.splitDocs;
  }

  // if docs provided
  if ('docs' in config && config.docs) {
    // must provide dataType since it can't be inferred from file extension
    if (!config.dataType)
      throw new Error('Data type must be provided when docs are provided');

    const {defaultDataSplitterType, defaultSplitterConfig} =
      getAppropriateDataSplitter(config.dataType);

    // Split the provided documents into chunks using the data splitter
    return await runDataSplitter({
      docs: config.docs,
      dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType,
      chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig,
      splitterConfig: config.splitterConfig ?? defaultSplitterConfig,
    });
  }

  // if file path provided
  // NOTE: check for presence (not truthiness) so that an empty string
  // reaches the dedicated error below instead of the generic one at the end
  if ('filePath' in config && config.filePath !== undefined) {
    // if generating embeddings, file path must be provided
    if (config.filePath === '') {
      throw new Error(
        'Invalid file path. File path can not be an empty string.'
      );
    }

    // store provided data type
    let dataType: SupportedDataLoaderTypes | undefined = config.dataType;

    // if no data type provided, infer it from the file extension
    if (!dataType) {
      // validate the data type of the file
      const result = validateDataType(config.filePath);
      // check if the file type is supported
      if (result.isSupported) {
        dataType = result.dataType;
      } else {
        throw new Error(
          `Unable to load data. Unsupported file type: ${result.unSupportedDataType}`
        );
      }
    }

    // get documents from the file path using a data loader
    const docs = await getDocs({
      dataLoaderType: dataType,
      path: config.filePath,
      csvLoaderOptions: config.csvLoaderOptions,
      jsonLoaderKeysToInclude: config.jsonLoaderKeysToInclude,
      pdfLoaderOptions: config.pdfLoaderOptions,
    });

    const {defaultDataSplitterType, defaultSplitterConfig} =
      getAppropriateDataSplitter(dataType);

    // Split the retrieved documents into chunks using the data splitter
    return await runDataSplitter({
      docs,
      dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType,
      chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig,
      splitterConfig: config.splitterConfig ?? defaultSplitterConfig,
    });
  }

  throw new Error(
    'Invalid configuration for the retriever. Must provide docs, splitDocs or file path.'
  );
};

/**
* Method to ingest data, split it into chunks, generate embeddings and store them in a vector store.
* If not generating embeddings, simply returns a runnable instance to retrieve docs as string.
Expand All @@ -167,44 +262,7 @@ export const getDataRetriever = async (
.pipe(formatDocumentsAsString);
}

// if generating embeddings, file path must be provided
if (!config.filePath || config.filePath === '') {
throw new Error('Invalid file path. File path must be provided');
}

// if data type not provided, infer the data type from file extension using the file path
if (!config.dataType) {
const result = validateDataType(config.filePath);
// check if the file type is supported
if (result.isSupported) {
console.log('/n/n------------------');
console.log(`Data type: ${result.dataType}`);
config.dataType = result.dataType;
} else {
throw new Error(
`Unable to load data. Unsupported file type: ${result.unSupportedDataType}`
);
}
}

try {
// Retrieve the documents from the specified file path
const docs: Document<Record<string, string>>[] =
config.docs ?? (await getDocs(config.dataType, config.filePath));

const {defaultDataSplitterType, defaultSplitterConfig} =
getAppropriateDataSplitter(config.dataType);

// Split the retrieved documents into chunks using the data splitter
const splitDocs: Document<Record<string, unknown>>[] =
config.splitDocs ??
(await runDataSplitter({
docs,
dataSplitterType: config.dataSplitterType ?? defaultDataSplitterType,
chunkingConfig: config.chunkingConfig ?? defaultChunkingConfig,
splitterConfig: config.splitterConfig ?? defaultSplitterConfig,
}));

// embedding model - if not provided, use the default Google Generative AI Embeddings model
const embeddings: EmbeddingsInterface =
config.embeddingModel ??
Expand All @@ -214,6 +272,9 @@ export const getDataRetriever = async (
taskType: TaskType.RETRIEVAL_DOCUMENT,
});

// get split documents
const splitDocs = await getSplitDocs(config);

// create a vector store instance
const vectorStore: VectorStore =
config.vectorStore ?? new MemoryVectorStore(embeddings);
Expand Down
58 changes: 58 additions & 0 deletions src/tests/unit-tests/endpoint-rag.unit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
} from '../../endpoints/endpoints';
import {setupGenkit} from '../../genkit/genkit';
import {getDataRetriever} from '../../rag/data-retrievers/data-retrievers';
import {CSVLoader} from '@langchain/community/document_loaders/fs/csv';

/**
* Test suite for Chat Endpoint RAG Functionality.
Expand All @@ -22,6 +23,7 @@ describe('Test - Endpoint RAG Tests', () => {
// Set to true to run the test
const Tests = {
test_rag_works: true,
test_rag_works_providing_docs: true,
};

// default test timeout
Expand Down Expand Up @@ -77,4 +79,60 @@ describe('Test - Endpoint RAG Tests', () => {
},
defaultTimeout
);

if (Tests.test_rag_works_providing_docs)
test(
'Test RAG works when providing docs',
async () => {
// configure data loader
const loader = new CSVLoader('src/tests/test-data/inventory-data.csv');
// get documents
const docs = await loader.load();

// define chat endpoint
const endpoint = defineChatEndpoint({
endpoint: 'test-chat-open-rag-docs',
topic: 'store inventory',
enableRAG: true,
retriever: await getDataRetriever({
docs,
dataType: 'csv', // need to specify data type when providing docs
generateEmbeddings: true,
}),
});
try {
// send test query
const response = await runEndpoint(endpoint, {
query: 'What is the price of Seagate ST1000DX002?',
});

// check response is valid and does not contain error
expect(response).toBeDefined();
expect(response).not.toHaveProperty('error');

// confirm response type
if (typeof response === 'string') {
// should not be empty
expect(response.length).toBeGreaterThan(0);
// should contain 68.06
expect(response).toContain('68.06');
} else {
expect(response).toHaveProperty('response');
if ('response' in response) {
// should not be empty
expect(response.response.length).toBeGreaterThan(0);
// should contain 68.06
expect(response.response).toContain('68.06');
} else {
throw new Error(
`Response field invalid. Response: ${JSON.stringify(response)}`
);
}
}
} catch (error) {
throw new Error(`Error in test. Error: ${error}`);
}
},
defaultTimeout
);
});

0 comments on commit 987cbdb

Please sign in to comment.