diff --git a/libs/langchain-azure-cosmosdb/src/caches/caches_mongodb.ts b/libs/langchain-azure-cosmosdb/src/caches/caches_mongodb.ts
new file mode 100644
index 000000000000..b33589a4095d
--- /dev/null
+++ b/libs/langchain-azure-cosmosdb/src/caches/caches_mongodb.ts
@@ -0,0 +1,178 @@
+import {
+  BaseCache,
+  deserializeStoredGeneration,
+  getCacheKey,
+  serializeGeneration,
+} from "@langchain/core/caches";
+import { Generation } from "@langchain/core/outputs";
+import { Document } from "@langchain/core/documents";
+import { EmbeddingsInterface } from "@langchain/core/embeddings";
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import { MongoClient } from "mongodb";
+import {
+  AzureCosmosDBMongoDBConfig,
+  AzureCosmosDBMongoDBVectorStore,
+  AzureCosmosDBMongoDBSimilarityType,
+} from "../azure_cosmosdb_mongodb.js";
+
+/**
+ * Represents a Semantic Cache that uses a CosmosDB MongoDB backend as the underlying
+ * storage system.
+ *
+ * @example
+ * ```typescript
+ * const embeddings = new OpenAIEmbeddings();
+ * const cache = new AzureCosmosDBMongoDBSemanticCache(embeddings, {
+ *   client: new MongoClient("<connection string>"),
+ * });
+ * const model = new ChatOpenAI({ cache });
+ *
+ * // Invoke the model to perform an action
+ * const response = await model.invoke("Do something random!");
+ * console.log(response);
+ * ```
+ */
+export class AzureCosmosDBMongoDBSemanticCache extends BaseCache {
+  private embeddings: EmbeddingsInterface;
+
+  private config: AzureCosmosDBMongoDBConfig;
+
+  private similarityScoreThreshold: number;
+
+  private cacheDict: { [key: string]: AzureCosmosDBMongoDBVectorStore } = {};
+
+  private readonly client: MongoClient | undefined;
+
+  private vectorDistanceFunction: string;
+
+  constructor(
+    embeddings: EmbeddingsInterface,
+    dbConfig: AzureCosmosDBMongoDBConfig,
+    similarityScoreThreshold: number = 0.6
+  ) {
+    super();
+
+    const connectionString =
+      dbConfig.connectionString ??
+      getEnvironmentVariable("AZURE_COSMOSDB_MONGODB_CONNECTION_STRING");
+
+    if (!dbConfig.client && !connectionString) {
+      throw new Error(
+        "AzureCosmosDBMongoDBSemanticCache client or connection string must be set."
+      );
+    }
+
+    if (!dbConfig.client) {
+      this.client = new MongoClient(connectionString!, {
+        appName: "langchainjs",
+      });
+    } else {
+      this.client = dbConfig.client;
+    }
+
+    this.config = {
+      ...dbConfig,
+      client: this.client,
+      collectionName: dbConfig.collectionName ?? "semanticCacheContainer",
+    };
+
+    this.similarityScoreThreshold = similarityScoreThreshold;
+    this.embeddings = embeddings;
+    this.vectorDistanceFunction =
+      dbConfig?.indexOptions?.similarity ??
+      AzureCosmosDBMongoDBSimilarityType.COS;
+  }
+
+  private getLlmCache(llmKey: string) {
+    const key = getCacheKey(llmKey);
+    if (!this.cacheDict[key]) {
+      this.cacheDict[key] = new AzureCosmosDBMongoDBVectorStore(
+        this.embeddings,
+        this.config
+      );
+    }
+    return this.cacheDict[key];
+  }
+
+  /**
+   * Retrieves data from the cache.
+   *
+   * @param prompt The prompt for lookup.
+   * @param llmKey The LLM key used to construct the cache key.
+   * @returns An array of Generations if found, null otherwise.
+   */
+  async lookup(prompt: string, llmKey: string): Promise<Generation[] | null> {
+    const llmCache = this.getLlmCache(llmKey);
+
+    const queryEmbedding = await this.embeddings.embedQuery(prompt);
+    const results = await llmCache.similaritySearchVectorWithScore(
+      queryEmbedding,
+      1,
+      this.config.indexOptions?.indexType
+    );
+    if (!results.length) return null;
+
+    const generations = results
+      .flatMap(([document, score]) => {
+        const isSimilar =
+          (this.vectorDistanceFunction ===
+            AzureCosmosDBMongoDBSimilarityType.L2 &&
+            score <= this.similarityScoreThreshold) ||
+          (this.vectorDistanceFunction !==
+            AzureCosmosDBMongoDBSimilarityType.L2 &&
+            score >= this.similarityScoreThreshold);
+
+        if (!isSimilar) return undefined;
+
+        return document.metadata.return_value.map((gen: string) =>
+          deserializeStoredGeneration(JSON.parse(gen))
+        );
+      })
+      .filter((gen) => gen !== undefined);
+
+    return generations.length > 0 ? generations : null;
+  }
+
+  /**
+   * Updates the cache with new data.
+   *
+   * @param prompt The prompt for update.
+   * @param llmKey The LLM key used to construct the cache key.
+   * @param returnValue The value to be stored in the cache.
+   */
+  public async update(
+    prompt: string,
+    llmKey: string,
+    returnValue: Generation[]
+  ): Promise<void> {
+    const serializedGenerations = returnValue.map((generation) =>
+      JSON.stringify(serializeGeneration(generation))
+    );
+
+    const llmCache = this.getLlmCache(llmKey);
+
+    const metadata = {
+      llm_string: llmKey,
+      prompt,
+      return_value: serializedGenerations,
+    };
+
+    const doc = new Document({
+      pageContent: prompt,
+      metadata,
+    });
+
+    await llmCache.addDocuments([doc]);
+  }
+
+  /**
+   * Deletes the semantic cache for a given llmKey.
+   * @param llmKey
+   */
+  public async clear(llmKey: string) {
+    const key = getCacheKey(llmKey);
+    if (this.cacheDict[key]) {
+      await this.cacheDict[key].delete();
+    }
+  }
+}
diff --git a/libs/langchain-azure-cosmosdb/src/caches.ts b/libs/langchain-azure-cosmosdb/src/caches/caches_nosql.ts
similarity index 99%
rename from libs/langchain-azure-cosmosdb/src/caches.ts
rename to libs/langchain-azure-cosmosdb/src/caches/caches_nosql.ts
index da7619c5ff96..e110f848702b 100644
--- a/libs/langchain-azure-cosmosdb/src/caches.ts
+++ b/libs/langchain-azure-cosmosdb/src/caches/caches_nosql.ts
@@ -13,7 +13,7 @@ import { getEnvironmentVariable } from "@langchain/core/utils/env";
 import {
   AzureCosmosDBNoSQLConfig,
   AzureCosmosDBNoSQLVectorStore,
-} from "./azure_cosmosdb_nosql.js";
+} from "../azure_cosmosdb_nosql.js";
 
 const USER_AGENT_SUFFIX = "langchainjs-cdbnosql-semanticcache-javascript";
 const DEFAULT_CONTAINER_NAME = "semanticCacheContainer";
diff --git a/libs/langchain-azure-cosmosdb/src/index.ts b/libs/langchain-azure-cosmosdb/src/index.ts
index 883710c842ff..2778d2c55417 100644
--- a/libs/langchain-azure-cosmosdb/src/index.ts
+++ b/libs/langchain-azure-cosmosdb/src/index.ts
@@ -1,5 +1,6 @@
 export * from "./azure_cosmosdb_mongodb.js";
 export * from "./azure_cosmosdb_nosql.js";
-export * from "./caches.js";
+export * from "./caches/caches_nosql.js";
+export * from "./caches/caches_mongodb.js";
 export * from "./chat_histories/nosql.js";
 export * from "./chat_histories/mongodb.js";
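For reviewers, a minimal usage sketch of the new cache (illustrative only, not part of the diff). It assumes the package entry point `@langchain/azure-cosmosdb`, an OpenAI embeddings model, and `AZURE_COSMOSDB_MONGODB_CONNECTION_STRING` being set in the environment; the database name, index type, and dimensions are placeholder values borrowed from the integration tests further down.

```typescript
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import {
  AzureCosmosDBMongoDBSemanticCache,
  AzureCosmosDBMongoDBSimilarityType,
} from "@langchain/azure-cosmosdb";

// The connection string falls back to AZURE_COSMOSDB_MONGODB_CONNECTION_STRING
// when neither `client` nor `connectionString` is provided (see the constructor above).
const cache = new AzureCosmosDBMongoDBSemanticCache(
  new OpenAIEmbeddings(),
  {
    databaseName: "langchain", // placeholder database name
    collectionName: "semanticCacheContainer", // default used by the cache
    indexOptions: {
      indexType: "ivf", // as exercised in the integration tests
      similarity: AzureCosmosDBMongoDBSimilarityType.COS,
      dimensions: 1536, // must match the embedding model's output size
    },
  },
  0.6 // similarityScoreThreshold (the default)
);

const model = new ChatOpenAI({ cache });

// The first call populates the cache; a semantically similar prompt later
// can be answered from the cached generations instead of a fresh LLM call.
const response = await model.invoke("What is semantic caching?");
console.log(response);
```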
diff --git a/libs/langchain-azure-cosmosdb/src/tests/caches/caches_mongodb.int.test.ts b/libs/langchain-azure-cosmosdb/src/tests/caches/caches_mongodb.int.test.ts
new file mode 100644
index 000000000000..7902bdc06e79
--- /dev/null
+++ b/libs/langchain-azure-cosmosdb/src/tests/caches/caches_mongodb.int.test.ts
@@ -0,0 +1,136 @@
+/* eslint-disable no-nested-ternary */
+/* eslint-disable @typescript-eslint/no-explicit-any */
+/* eslint-disable no-process-env */
+import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
+import { MongoClient } from "mongodb";
+import { AzureCosmosDBMongoDBSemanticCache } from "../../caches/caches_mongodb.js";
+import {
+  AzureCosmosDBMongoDBIndexOptions,
+  AzureCosmosDBMongoDBSimilarityType,
+} from "../../azure_cosmosdb_mongodb.js";
+
+const DATABASE_NAME = "langchain";
+const COLLECTION_NAME = "test";
+
+async function initializeCache(
+  indexType: any,
+  distanceFunction: any,
+  similarityThreshold: number = 0.6
+): Promise<AzureCosmosDBMongoDBSemanticCache> {
+  const embeddingModel = new OpenAIEmbeddings();
+  const testEmbedding = await embeddingModel.embedDocuments(["sample text"]);
+  const dimension = testEmbedding[0].length;
+
+  const indexOptions: AzureCosmosDBMongoDBIndexOptions = {
+    indexType,
+    // eslint-disable-next-line no-nested-ternary
+    similarity:
+      distanceFunction === "cosine"
+        ? AzureCosmosDBMongoDBSimilarityType.COS
+        : distanceFunction === "euclidean"
+        ? AzureCosmosDBMongoDBSimilarityType.L2
+        : AzureCosmosDBMongoDBSimilarityType.IP,
+    dimensions: dimension,
+  };
+
+  let cache: AzureCosmosDBMongoDBSemanticCache;
+
+  const connectionString =
+    process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING;
+  if (connectionString) {
+    cache = new AzureCosmosDBMongoDBSemanticCache(
+      embeddingModel,
+      {
+        databaseName: DATABASE_NAME,
+        collectionName: COLLECTION_NAME,
+        connectionString,
+        indexOptions,
+      },
+      similarityThreshold
+    );
+  } else {
+    throw new Error(
+      "Please set the environment variable AZURE_COSMOSDB_MONGODB_CONNECTION_STRING"
+    );
+  }
+
+  return cache;
+}
+
+describe("AzureCosmosDBMongoDBSemanticCache", () => {
+  beforeEach(async () => {
+    const connectionString =
+      process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING;
+    const client = new MongoClient(connectionString!);
+
+    try {
+      await client.db(DATABASE_NAME).collection(COLLECTION_NAME).drop();
+    } catch (error) {
+      // Ignore the error if the collection does not exist yet.
+    }
+  });
+
+  it("should store and retrieve cache using cosine similarity with ivf index", async () => {
+    const cache = await initializeCache("ivf", "cosine");
+    const model = new ChatOpenAI({ cache });
+    const llmString = JSON.stringify(model._identifyingParams);
+    await cache.update("foo", llmString, [{ text: "fizz" }]);
+
+    let cacheOutput = await cache.lookup("foo", llmString);
+    expect(cacheOutput).toEqual([{ text: "fizz" }]);
+
+    cacheOutput = await cache.lookup("bar", llmString);
+    expect(cacheOutput).toEqual(null);
+
+    await cache.clear(llmString);
+  });
+
+  it("should store and retrieve cache using euclidean similarity with hnsw index", async () => {
+    const cache = await initializeCache("hnsw", "euclidean");
+    const model = new ChatOpenAI({ cache });
+    const llmString = JSON.stringify(model._identifyingParams);
+    await cache.update("foo", llmString, [{ text: "fizz" }]);
+
+    let cacheOutput = await cache.lookup("foo", llmString);
+    expect(cacheOutput).toEqual([{ text: "fizz" }]);
+
+    cacheOutput = await cache.lookup("bar", llmString);
+    expect(cacheOutput).toEqual(null);
+
+    await cache.clear(llmString);
+  });
+
+  it("should return null if similarity score is below threshold (cosine similarity with ivf index)", async () => {
+    const cache = await initializeCache("ivf", "cosine", 0.8);
+    const model = new ChatOpenAI({ cache });
+    const llmString = JSON.stringify(model._identifyingParams);
+    await cache.update("foo", llmString, [{ text: "fizz" }]);
+
+    const cacheOutput = await cache.lookup("foo", llmString);
+    expect(cacheOutput).toEqual([{ text: "fizz" }]);
+
+    const resultBelowThreshold = await cache.lookup("bar", llmString);
+    expect(resultBelowThreshold).toEqual(null);
+
+    await cache.clear(llmString);
+  });
+
+  it("should handle a variety of cache updates and lookups", async () => {
+    const cache = await initializeCache("ivf", "cosine", 0.7);
+    const model = new ChatOpenAI({ cache });
+    const llmString = JSON.stringify(model._identifyingParams);
+
+    await cache.update("test1", llmString, [{ text: "response 1" }]);
+    await cache.update("test2", llmString, [{ text: "response 2" }]);
+
+    let cacheOutput = await cache.lookup("test1", llmString);
+    expect(cacheOutput).toEqual([{ text: "response 1" }]);
+
+    cacheOutput = await cache.lookup("test2", llmString);
+    expect(cacheOutput).toEqual([{ text: "response 2" }]);
+
+    cacheOutput = await cache.lookup("test3", llmString);
+    expect(cacheOutput).toEqual(null);
+
+    await cache.clear(llmString);
+  });
+});
diff --git a/libs/langchain-azure-cosmosdb/src/tests/caches/caches_mongodb.test.ts b/libs/langchain-azure-cosmosdb/src/tests/caches/caches_mongodb.test.ts
new file mode 100644
index 000000000000..4bd9cf0ce996
--- /dev/null
+++ b/libs/langchain-azure-cosmosdb/src/tests/caches/caches_mongodb.test.ts
@@ -0,0 +1,72 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
+// eslint-disable-next-line import/no-extraneous-dependencies
+import { jest } from "@jest/globals";
+import { FakeEmbeddings, FakeLLM } from "@langchain/core/utils/testing";
+import { Document } from "@langchain/core/documents";
+import { MongoClient } from "mongodb";
+import { AzureCosmosDBMongoDBSemanticCache } from "../../index.js";
+
+const createMockClient = () => ({
+  db: jest.fn().mockReturnValue({
+    collectionName: "documents",
+    collection: jest.fn().mockReturnValue({
+      listIndexes: jest.fn().mockReturnValue({
+        toArray: jest.fn().mockReturnValue([
+          {
+            name: "vectorSearchIndex",
+          },
+        ]),
+      }),
+      findOne: jest.fn().mockReturnValue({
+        metadata: {
+          return_value: ['{"text": "fizz"}'],
+        },
+        similarityScore: 0.8,
+      }),
+      insertMany: jest.fn().mockImplementation((docs: any) => ({
+        insertedIds: docs.map((_: any, i: any) => `id${i}`),
+      })),
+      aggregate: jest.fn().mockReturnValue({
+        map: jest.fn().mockReturnValue({
+          toArray: jest.fn().mockReturnValue([
+            [
+              new Document({
+                pageContent: "test",
+                metadata: { return_value: ['{"text": "fizz"}'] },
+              }),
+              0.8,
+            ],
+          ]),
+        }),
+      }),
+    }),
+    command: jest.fn(),
+  }),
+  connect: jest.fn(),
+  close: jest.fn(),
+});
+
+describe("AzureCosmosDBMongoDBSemanticCache", () => {
+  it("should store, retrieve, and clear cache in MongoDB", async () => {
+    const mockClient = createMockClient() as any;
+    const embeddings = new FakeEmbeddings();
+    const cache = new AzureCosmosDBMongoDBSemanticCache(
+      embeddings,
+      {
+        client: mockClient as MongoClient,
+      },
+      0.8
+    );
+
+    expect(cache).toBeDefined();
+
+    const llm = new FakeLLM({});
+    const llmString = JSON.stringify(llm._identifyingParams());
+
+    await cache.update("foo", llmString, [{ text: "fizz" }]);
+    expect(mockClient.db().collection().insertMany).toHaveBeenCalled();
+
+    const result = await cache.lookup("foo", llmString);
+    expect(result).toEqual([{ text: "fizz" }]);
+  });
+});
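One detail worth calling out when picking `similarityScoreThreshold` for the tests above: the hit check in `lookup` flips direction depending on the distance function, since Euclidean (L2) scores are better when smaller, while cosine and inner-product scores are better when larger. Below is a hypothetical standalone predicate (not part of the diff) that mirrors that check, using the exported `AzureCosmosDBMongoDBSimilarityType` values.

```typescript
import { AzureCosmosDBMongoDBSimilarityType } from "@langchain/azure-cosmosdb";

// Mirrors the isSimilar check inside lookup():
// - L2 (euclidean): a smaller score means closer, so a hit requires score <= threshold
// - COS / IP: a larger score means closer, so a hit requires score >= threshold
function isCacheHit(
  similarity: string,
  score: number,
  threshold: number
): boolean {
  return similarity === AzureCosmosDBMongoDBSimilarityType.L2
    ? score <= threshold
    : score >= threshold;
}

// With the 0.8 cosine threshold from the third integration test:
console.log(isCacheHit(AzureCosmosDBMongoDBSimilarityType.COS, 0.92, 0.8)); // true  -> cached generations returned
console.log(isCacheHit(AzureCosmosDBMongoDBSimilarityType.COS, 0.55, 0.8)); // false -> lookup returns null
```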
diff --git a/libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts b/libs/langchain-azure-cosmosdb/src/tests/caches/caches_nosql.int.test.ts
similarity index 99%
rename from libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts
rename to libs/langchain-azure-cosmosdb/src/tests/caches/caches_nosql.int.test.ts
index c7acb92f7c86..4c301469bbcf 100644
--- a/libs/langchain-azure-cosmosdb/src/tests/caches.int.test.ts
+++ b/libs/langchain-azure-cosmosdb/src/tests/caches/caches_nosql.int.test.ts
@@ -9,7 +9,7 @@ import {
 } from "@azure/cosmos";
 import { DefaultAzureCredential } from "@azure/identity";
 import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
-import { AzureCosmosDBNoSQLSemanticCache } from "../caches.js";
+import { AzureCosmosDBNoSQLSemanticCache } from "../../caches/caches_nosql.js";
 
 const DATABASE_NAME = "langchainTestCacheDB";
 const CONTAINER_NAME = "testContainer";
diff --git a/libs/langchain-azure-cosmosdb/src/tests/caches.test.ts b/libs/langchain-azure-cosmosdb/src/tests/caches/caches_nosql.test.ts
similarity index 96%
rename from libs/langchain-azure-cosmosdb/src/tests/caches.test.ts
rename to libs/langchain-azure-cosmosdb/src/tests/caches/caches_nosql.test.ts
index 9de3f507acc0..3a7a253f22bc 100644
--- a/libs/langchain-azure-cosmosdb/src/tests/caches.test.ts
+++ b/libs/langchain-azure-cosmosdb/src/tests/caches/caches_nosql.test.ts
@@ -1,7 +1,7 @@
 /* eslint-disable @typescript-eslint/no-explicit-any */
 import { jest } from "@jest/globals";
 import { FakeEmbeddings, FakeLLM } from "@langchain/core/utils/testing";
-import { AzureCosmosDBNoSQLSemanticCache } from "../index.js";
+import { AzureCosmosDBNoSQLSemanticCache } from "../../index.js";
 
 // Create the mock Cosmos DB client
 const createMockClient = () => {