Merge pull request #21 from philippgille/reorganize-code
Reorganize code
philippgille authored Feb 18, 2024
2 parents 38898e0 + ca027d5 commit cf08535
Showing 7 changed files with 123 additions and 128 deletions.
55 changes: 55 additions & 0 deletions collection.go
@@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"slices"
"sort"
"sync"
)

@@ -63,6 +65,59 @@ func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddin
return c.add(ctx, ids, documents, embeddings, metadatas, concurrency)
}

// Query performs a nearest neighbors query on the collection.
//
// - queryText: The text to search for.
// - nResults: The number of results to return. Must be > 0.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
if len(c.documents) == 0 {
return nil, nil
}

if nResults <= 0 {
return nil, errors.New("nResults must be > 0")
}

// Validate whereDocument operators
for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return nil, errors.New("unsupported operator")
}
}

// Filter docs by metadata and content
filteredDocs := filterDocs(c.documents, where, whereDocument)

// No need to continue if the filters got rid of all documents
if len(filteredDocs) == 0 {
return nil, nil
}

queryVectors, err := c.embed(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
}

// For the remaining documents, calculate cosine similarity.
res, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
if err != nil {
return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
}

// Sort by similarity
sort.Slice(res, func(i, j int) bool {
// The `less` function would usually use `<`, but we want to sort descending.
return res[i].Similarity > res[j].Similarity
})

// Return the top nResults
return res[:nResults], nil
}

func (c *Collection) add(ctx context.Context, ids []string, documents []string, embeddings [][]float32, metadatas []map[string]string, concurrency int) error {
if len(ids) == 0 || len(documents) == 0 {
return errors.New("ids and documents must not be empty")
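
For context, here is a minimal usage sketch of the Query method that was moved into collection.go above. Only the Query signature and the Result.Similarity field are taken from this diff; the topMatches helper, the query text, and the choice of 3 results are illustrative, and obtaining a populated *chromem.Collection is assumed to happen elsewhere.

package example

import (
	"context"
	"fmt"

	chromem "github.com/philippgille/chromem-go"
)

// topMatches is a hypothetical helper, not part of the library. It runs a
// nearest neighbors query against an already populated collection and prints
// the ranked results.
func topMatches(ctx context.Context, c *chromem.Collection, query string) error {
	// No metadata filter (where) and no content filter (whereDocument);
	// ask for the 3 most similar documents.
	results, err := c.Query(ctx, query, 3, nil, nil)
	if err != nil {
		return err
	}
	for i, res := range results {
		// Result has more fields, but only Similarity is shown in this diff.
		fmt.Printf("%d. similarity=%.4f\n", i+1, res.Similarity)
	}
	return nil
}
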
68 changes: 68 additions & 0 deletions embed_compat.go
@@ -0,0 +1,68 @@
package chromem

const (
baseURLMistral = "https://api.mistral.ai/v1"
// Currently there's only one. Let's turn this into a pseudo-enum as soon as there are more.
embeddingModelMistral = "mistral-embed"
)

// NewEmbeddingFuncMistral returns a function that creates embeddings for a document
// using the Mistral API.
func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
// The Mistral API docs don't mention the `encoding_format` as optional,
// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
}

const baseURLJina = "https://api.jina.ai/v1"

type EmbeddingModelJina string

const (
EmbeddingModelJina2BaseEN EmbeddingModelJina = "jina-embeddings-v2-base-en"
EmbeddingModelJina2BaseDE EmbeddingModelJina = "jina-embeddings-v2-base-de"
EmbeddingModelJina2BaseCode EmbeddingModelJina = "jina-embeddings-v2-base-code"
EmbeddingModelJina2BaseZH EmbeddingModelJina = "jina-embeddings-v2-base-zh"
)

// NewEmbeddingFuncJina returns a function that creates embeddings for a document
// using the Jina API.
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
}

const baseURLMixedbread = "https://api.mixedbread.ai"

type EmbeddingModelMixedbread string

const (
EmbeddingModelMixedbreadUAELargeV1 EmbeddingModelMixedbread = "UAE-Large-V1"
EmbeddingModelMixedbreadBGELargeENV15 EmbeddingModelMixedbread = "bge-large-en-v1.5"
EmbeddingModelMixedbreadGTELarge EmbeddingModelMixedbread = "gte-large"
EmbeddingModelMixedbreadE5LargeV2 EmbeddingModelMixedbread = "e5-large-v2"
EmbeddingModelMixedbreadMultilingualE5Large EmbeddingModelMixedbread = "multilingual-e5-large"
EmbeddingModelMixedbreadMultilingualE5Base EmbeddingModelMixedbread = "multilingual-e5-base"
EmbeddingModelMixedbreadAllMiniLML6V2 EmbeddingModelMixedbread = "all-MiniLM-L6-v2"
EmbeddingModelMixedbreadGTELargeZh EmbeddingModelMixedbread = "gte-large-zh"
)

// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a document
// using the mixedbread.ai API.
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
}

const baseURLLocalAI = "http://localhost:8080/v1"

// NewEmbeddingFuncLocalAI returns a function that creates embeddings for a document
// using the LocalAI API.
// You can start a LocalAI instance like this:
//
// docker run -it -p 127.0.0.1:8080:8080 localai/localai:v2.7.0-ffmpeg-core bert-cpp
//
// And then call this constructor with model "bert-cpp-minilm-v6".
// But other embedding models are supported as well. See the LocalAI documentation
// for details.
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
}
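
Since all of these constructors delegate to NewEmbeddingFuncOpenAICompat and return the same chromem.EmbeddingFunc, the providers are interchangeable. Below is a hedged sketch of picking one at runtime; it assumes EmbeddingFunc has the shape func(ctx context.Context, text string) ([]float32, error), which is not shown in this diff, and the environment variable name and fallback model are illustrative.

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	chromem "github.com/philippgille/chromem-go"
)

func main() {
	ctx := context.Background()

	// Pick a provider at runtime; the API key env var is illustrative.
	var embed chromem.EmbeddingFunc
	if key := os.Getenv("MISTRAL_API_KEY"); key != "" {
		embed = chromem.NewEmbeddingFuncMistral(key)
	} else {
		// Assumes a LocalAI instance is running as described in the doc
		// comment above, serving the "bert-cpp-minilm-v6" model.
		embed = chromem.NewEmbeddingFuncLocalAI("bert-cpp-minilm-v6")
	}

	// EmbeddingFunc is assumed to be func(ctx, text) ([]float32, error).
	vec, err := embed(ctx, "The sky is blue because of Rayleigh scattering.")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("got an embedding with %d dimensions\n", len(vec))
}
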
18 changes: 0 additions & 18 deletions embed_jina.go

This file was deleted.

16 changes: 0 additions & 16 deletions embed_localai.go

This file was deleted.

15 changes: 0 additions & 15 deletions embed_mistral.go

This file was deleted.

22 changes: 0 additions & 22 deletions embed_mixedbread.go

This file was deleted.

57 changes: 0 additions & 57 deletions query.go
@@ -2,11 +2,7 @@ package chromem

import (
"context"
"errors"
"fmt"
"runtime"
"slices"
"sort"
"strings"
"sync"
)
@@ -26,59 +22,6 @@ type Result struct {
Similarity float32
}

// Performs a nearest neighbors query on a collection specified by UUID.
//
// - queryText: The text to search for.
// - nResults: The number of results to return. Must be > 0.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
if len(c.documents) == 0 {
return nil, nil
}

if nResults <= 0 {
return nil, errors.New("nResults must be > 0")
}

// Validate whereDocument operators
for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return nil, errors.New("unsupported operator")
}
}

// Filter docs by metadata and content
filteredDocs := filterDocs(c.documents, where, whereDocument)

// No need to continue if the filters got rid of all documents
if len(filteredDocs) == 0 {
return nil, nil
}

queryVectors, err := c.embed(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
}

// For the remaining documents, calculate cosine similarity.
res, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
if err != nil {
return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
}

// Sort by similarity
sort.Slice(res, func(i, j int) bool {
// The `less` function would usually use `<`, but we want to sort descending.
return res[i].Similarity > res[j].Similarity
})

// Return the top nResults
return res[:nResults], nil
}

// filterDocs filters a map of documents by metadata and content.
// It does this concurrently.
func filterDocs(docs map[string]*document, where, whereDocument map[string]string) []*document {
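
The body of filterDocs is collapsed in this diff, so here is a self-contained sketch of the concurrency pattern its comment describes, not the library's actual implementation: the doc struct, the filterConcurrently name, and the plain substring content filter are all stand-ins.

package example

import (
	"runtime"
	"strings"
	"sync"
)

// doc is a stand-in for chromem's internal document type, which is not shown
// in this diff; the fields here are illustrative only.
type doc struct {
	ID       string
	Metadata map[string]string
	Content  string
}

// filterConcurrently fans the documents out to a bounded number of goroutines,
// keeps those that match all metadata filters and an optional substring
// content filter, and collects the survivors behind a mutex.
func filterConcurrently(docs map[string]*doc, where map[string]string, contains string) []*doc {
	var (
		mu       sync.Mutex
		filtered []*doc
		wg       sync.WaitGroup
		sem      = make(chan struct{}, runtime.NumCPU()) // bound concurrency
	)
	for _, d := range docs {
		wg.Add(1)
		sem <- struct{}{}
		go func(d *doc) {
			defer wg.Done()
			defer func() { <-sem }()
			for k, v := range where {
				if d.Metadata[k] != v {
					return
				}
			}
			if contains != "" && !strings.Contains(d.Content, contains) {
				return
			}
			mu.Lock()
			filtered = append(filtered, d)
			mu.Unlock()
		}(d)
	}
	wg.Wait()
	return filtered
}
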
