From 6133c0349c744745a825ffd22081ed86d1abb035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 17 Feb 2024 20:31:03 +0100 Subject: [PATCH 1/2] Add DB.DeleteCollection() --- db.go | 8 ++++++++ db_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/db.go b/db.go index 1848193..6e5d532 100644 --- a/db.go +++ b/db.go @@ -100,3 +100,11 @@ func (c *DB) GetCollection(name string) *Collection { embed: orig.embed, } } + +// DeleteCollection deletes the collection with the given name. +// If the collection doesn't exist, this is a no-op. +func (c *DB) DeleteCollection(name string) { + c.collectionsLock.Lock() + defer c.collectionsLock.Unlock() + delete(c.collections, name) +} diff --git a/db_test.go b/db_test.go index 8a6c420..c989311 100644 --- a/db_test.go +++ b/db_test.go @@ -78,3 +78,27 @@ func TestDB_GetCollection(t *testing.T) { // TODO: Same for the EmbeddingFunc // TODO: Check documents map being a copy as soon as we have access to it } + +func TestDB_DeleteCollection(t *testing.T) { + // Values in the collection + name := "test" + metadata := map[string]string{"foo": "bar"} + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{-0.1, 0.1, 0.2}, nil + } + + // Create initial collection + db := chromem.NewDB() + // We ignore the return value. CreateCollection is tested elsewhere. + _ = db.CreateCollection(name, metadata, embeddingFunc) + + // Delete collection + db.DeleteCollection(name) + + // Check expectations + // We don't have access to the documents field, but we can rely on DB.ListCollections() + // because it's tested elsewhere. + if len(db.ListCollections()) != 0 { + t.Error("expected 0 collections, got", len(db.ListCollections())) + } +} From 85b5a30709890f68d0acad6dc2d2d6c08010d981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 17 Feb 2024 20:31:20 +0100 Subject: [PATCH 2/2] Improve Godoc for GetCollection() --- db.go | 1 + 1 file changed, 1 insertion(+) diff --git a/db.go b/db.go index 6e5d532..ab92228 100644 --- a/db.go +++ b/db.go @@ -70,6 +70,7 @@ func (c *DB) ListCollections() map[string]*Collection { // Regarding the EmbeddingFunc it's the original. So if it closes over some state, // this state is shared. But usually an EmbeddingFunc just closes over an API key // or HTTP client, which are safe to share. +// If the collection doesn't exist, this returns nil. func (c *DB) GetCollection(name string) *Collection { c.collectionsLock.RLock() defer c.collectionsLock.RUnlock()