diff --git a/collection.go b/collection.go
index 01c0351..2c3686c 100644
--- a/collection.go
+++ b/collection.go
@@ -4,6 +4,8 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"slices"
+	"sort"
 	"sync"
 )
 
@@ -63,6 +65,62 @@ func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddin
 	return c.add(ctx, ids, documents, embeddings, metadatas, concurrency)
 }
 
+// Query performs a nearest neighbors query on the collection.
+//
+//   - queryText: The text to search for.
+//   - nResults: The number of results to return. Must be > 0.
+//   - where: Conditional filtering on metadata. Optional.
+//   - whereDocument: Conditional filtering on documents. Optional.
+func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
+	c.documentsLock.RLock()
+	defer c.documentsLock.RUnlock()
+	if len(c.documents) == 0 {
+		return nil, nil
+	}
+
+	if nResults <= 0 {
+		return nil, errors.New("nResults must be > 0")
+	}
+
+	// Validate whereDocument operators
+	for k := range whereDocument {
+		if !slices.Contains(supportedFilters, k) {
+			return nil, errors.New("unsupported operator")
+		}
+	}
+
+	// Filter docs by metadata and content
+	filteredDocs := filterDocs(c.documents, where, whereDocument)
+
+	// No need to continue if the filters got rid of all documents
+	if len(filteredDocs) == 0 {
+		return nil, nil
+	}
+
+	queryVectors, err := c.embed(ctx, queryText)
+	if err != nil {
+		return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
+	}
+
+	// For the remaining documents, calculate cosine similarity.
+	res, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
+	if err != nil {
+		return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
+	}
+
+	// Sort by similarity
+	sort.Slice(res, func(i, j int) bool {
+		// The `less` function would usually use `<`, but we want to sort descending.
+		return res[i].Similarity > res[j].Similarity
+	})
+
+	// Return the top nResults, capping at len(res) to avoid an out-of-range panic
+	if len(res) > nResults {
+		res = res[:nResults]
+	}
+	return res, nil
+}
+
 func (c *Collection) add(ctx context.Context, ids []string, documents []string, embeddings [][]float32, metadatas []map[string]string, concurrency int) error {
 	if len(ids) == 0 || len(documents) == 0 {
 		return errors.New("ids and documents must not be empty")
diff --git a/embed_compat.go b/embed_compat.go
new file mode 100644
index 0000000..f82e091
--- /dev/null
+++ b/embed_compat.go
@@ -0,0 +1,68 @@
+package chromem
+
+const (
+	baseURLMistral = "https://api.mistral.ai/v1"
+	// Currently there's only one. Let's turn this into a pseudo-enum as soon as there are more.
+	embeddingModelMistral = "mistral-embed"
+)
+
+// NewEmbeddingFuncMistral returns a function that creates embeddings for a document
+// using the Mistral API.
+func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
+	// The Mistral API docs don't mention the `encoding_format` as optional,
+	// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
+	return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
+}
+
+const baseURLJina = "https://api.jina.ai/v1"
+
+type EmbeddingModelJina string
+
+const (
+	EmbeddingModelJina2BaseEN   EmbeddingModelJina = "jina-embeddings-v2-base-en"
+	EmbeddingModelJina2BaseDE   EmbeddingModelJina = "jina-embeddings-v2-base-de"
+	EmbeddingModelJina2BaseCode EmbeddingModelJina = "jina-embeddings-v2-base-code"
+	EmbeddingModelJina2BaseZH   EmbeddingModelJina = "jina-embeddings-v2-base-zh"
+)
+
+// NewEmbeddingFuncJina returns a function that creates embeddings for a document
+// using the Jina API.
+func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
+	return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
+}
+
+const baseURLMixedbread = "https://api.mixedbread.ai"
+
+type EmbeddingModelMixedbread string
+
+const (
+	EmbeddingModelMixedbreadUAELargeV1          EmbeddingModelMixedbread = "UAE-Large-V1"
+	EmbeddingModelMixedbreadBGELargeENV15       EmbeddingModelMixedbread = "bge-large-en-v1.5"
+	EmbeddingModelMixedbreadGTELarge            EmbeddingModelMixedbread = "gte-large"
+	EmbeddingModelMixedbreadE5LargeV2           EmbeddingModelMixedbread = "e5-large-v2"
+	EmbeddingModelMixedbreadMultilingualE5Large EmbeddingModelMixedbread = "multilingual-e5-large"
+	EmbeddingModelMixedbreadMultilingualE5Base  EmbeddingModelMixedbread = "multilingual-e5-base"
+	EmbeddingModelMixedbreadAllMiniLML6V2       EmbeddingModelMixedbread = "all-MiniLM-L6-v2"
+	EmbeddingModelMixedbreadGTELargeZh          EmbeddingModelMixedbread = "gte-large-zh"
+)
+
+// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a document
+// using the mixedbread.ai API.
+func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
+	return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
+}
+
+const baseURLLocalAI = "http://localhost:8080/v1"
+
+// NewEmbeddingFuncLocalAI returns a function that creates embeddings for a document
+// using the LocalAI API.
+// You can start a LocalAI instance like this:
+//
+//	docker run -it -p 127.0.0.1:8080:8080 localai/localai:v2.7.0-ffmpeg-core bert-cpp
+//
+// And then call this constructor with model "bert-cpp-minilm-v6".
+// But other embedding models are supported as well. See the LocalAI documentation
+// for details.
+func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
+	return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
+}
diff --git a/embed_jina.go b/embed_jina.go
deleted file mode 100644
index de77479..0000000
--- a/embed_jina.go
+++ /dev/null
@@ -1,18 +0,0 @@
-package chromem
-
-const baseURLJina = "https://api.jina.ai/v1"
-
-type EmbeddingModelJina string
-
-const (
-	EmbeddingModelJina2BaseEN   EmbeddingModelJina = "jina-embeddings-v2-base-en"
-	EmbeddingModelJina2BaseDE   EmbeddingModelJina = "jina-embeddings-v2-base-de"
-	EmbeddingModelJina2BaseCode EmbeddingModelJina = "jina-embeddings-v2-base-code"
-	EmbeddingModelJina2BaseZH   EmbeddingModelJina = "jina-embeddings-v2-base-zh"
-)
-
-// NewEmbeddingFuncJina returns a function that creates embeddings for a document
-// using the Jina API.
-func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
-}
diff --git a/embed_localai.go b/embed_localai.go
deleted file mode 100644
index bce28d8..0000000
--- a/embed_localai.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package chromem
-
-const baseURLLocalAI = "http://localhost:8080/v1"
-
-// NewEmbeddingFuncLocalAI returns a function that creates embeddings for a document
-// using the LocalAI API.
-// You can start a LocalAI instance like this:
-//
-//	docker run -it -p 127.0.0.1:8080:8080 localai/localai:v2.7.0-ffmpeg-core bert-cpp
-//
-// And then call this constructor with model "bert-cpp-minilm-v6".
-// But other embedding models are supported as well. See the LocalAI documentation
-// for details.
-func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
-}
diff --git a/embed_mistral.go b/embed_mistral.go
deleted file mode 100644
index 15cae49..0000000
--- a/embed_mistral.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package chromem
-
-const (
-	baseURLMistral = "https://api.mistral.ai/v1"
-	// Currently there's only one. Let's turn this into a pseudo-enum as soon as there are more.
-	embeddingModelMistral = "mistral-embed"
-)
-
-// NewEmbeddingFuncMistral returns a function that creates embeddings for a document
-// using the Mistral API.
-func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
-	// The Mistral API docs don't mention the `encoding_format` as optional,
-	// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
-	return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
-}
diff --git a/embed_mixedbread.go b/embed_mixedbread.go
deleted file mode 100644
index 147b4aa..0000000
--- a/embed_mixedbread.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package chromem
-
-const baseURLMixedbread = "https://api.mixedbread.ai"
-
-type EmbeddingModelMixedbread string
-
-const (
-	EmbeddingModelMixedbreadUAELargeV1          EmbeddingModelMixedbread = "UAE-Large-V1"
-	EmbeddingModelMixedbreadBGELargeENV15       EmbeddingModelMixedbread = "bge-large-en-v1.5"
-	EmbeddingModelMixedbreadGTELarge            EmbeddingModelMixedbread = "gte-large"
-	EmbeddingModelMixedbreadE5LargeV2           EmbeddingModelMixedbread = "e5-large-v2"
-	EmbeddingModelMixedbreadMultilingualE5Large EmbeddingModelMixedbread = "multilingual-e5-large"
-	EmbeddingModelMixedbreadMultilingualE5Base  EmbeddingModelMixedbread = "multilingual-e5-base"
-	EmbeddingModelMixedbreadAllMiniLML6V2       EmbeddingModelMixedbread = "all-MiniLM-L6-v2"
-	EmbeddingModelMixedbreadGTELargeZh          EmbeddingModelMixedbread = "gte-large-zh"
-)
-
-// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a document
-// using the mixedbread.ai API.
-func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
-}
diff --git a/query.go b/query.go
index 2285bab..317a729 100644
--- a/query.go
+++ b/query.go
@@ -2,11 +2,7 @@ package chromem
 
 import (
 	"context"
-	"errors"
-	"fmt"
 	"runtime"
-	"slices"
-	"sort"
 	"strings"
 	"sync"
 )
@@ -26,59 +22,6 @@ type Result struct {
 	Similarity float32
 }
 
-// Performs a nearest neighbors query on a collection specified by UUID.
-//
-//   - queryText: The text to search for.
-//   - nResults: The number of results to return. Must be > 0.
-//   - where: Conditional filtering on metadata. Optional.
-//   - whereDocument: Conditional filtering on documents. Optional.
-func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
-	c.documentsLock.RLock()
-	defer c.documentsLock.RUnlock()
-	if len(c.documents) == 0 {
-		return nil, nil
-	}
-
-	if nResults <= 0 {
-		return nil, errors.New("nResults must be > 0")
-	}
-
-	// Validate whereDocument operators
-	for k := range whereDocument {
-		if !slices.Contains(supportedFilters, k) {
-			return nil, errors.New("unsupported operator")
-		}
-	}
-
-	// Filter docs by metadata and content
-	filteredDocs := filterDocs(c.documents, where, whereDocument)
-
-	// No need to continue if the filters got rid of all documents
-	if len(filteredDocs) == 0 {
-		return nil, nil
-	}
-
-	queryVectors, err := c.embed(ctx, queryText)
-	if err != nil {
-		return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
-	}
-
-	// For the remaining documents, calculate cosine similarity.
-	res, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
-	if err != nil {
-		return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
-	}
-
-	// Sort by similarity
-	sort.Slice(res, func(i, j int) bool {
-		// The `less` function would usually use `<`, but we want to sort descending.
-		return res[i].Similarity > res[j].Similarity
-	})
-
-	// Return the top nResults
-	return res[:nResults], nil
-}
-
 // filterDocs filters a map of documents by metadata and content.
 // It does this concurrently.
 func filterDocs(docs map[string]*document, where, whereDocument map[string]string) []*document {
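
Usage sketch for the moved `Query` method together with one of the consolidated embedding constructors. This is a minimal, hedged example: the `NewDB`/`CreateCollection` constructors, the `Result.ID` field, and the `"$contains"` filter key are assumed from the package's public API rather than shown in this patch, and documents are presumed to have been added already.

```go
package main

import (
	"context"
	"fmt"
	"os"

	chromem "github.com/philippgille/chromem-go" // assumed import path
)

func main() {
	ctx := context.Background()

	// Any of the constructors consolidated into embed_compat.go works here;
	// Mistral is just one example.
	embed := chromem.NewEmbeddingFuncMistral(os.Getenv("MISTRAL_API_KEY"))

	db := chromem.NewDB()
	c, err := db.CreateCollection("knowledge-base", nil, embed)
	if err != nil {
		panic(err)
	}

	// ... add documents to the collection here ...

	// Top 3 most similar documents whose content contains "cucumber".
	res, err := c.Query(ctx, "Which vegetables are green?", 3, nil,
		map[string]string{"$contains": "cucumber"})
	if err != nil {
		panic(err)
	}
	for _, r := range res {
		fmt.Printf("%s: %.4f\n", r.ID, r.Similarity)
	}
}
```

With the cap added in the collection.go hunk above, asking for 3 results when only 2 documents survive the filters returns 2 results instead of panicking.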
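
Since the Mistral, Jina, mixedbread, and LocalAI constructors all delegate to `NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model)`, any other OpenAI-compatible embeddings endpoint can be targeted directly without adding a new file. The base URL and model name below are hypothetical placeholders, not tested defaults:

```go
// Point the generic OpenAI-compatible constructor at any server that
// implements the OpenAI embeddings API (hypothetical endpoint and model).
embed := chromem.NewEmbeddingFuncOpenAICompat("http://localhost:11434/v1", "", "some-embedding-model")
```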