From 78df19be90a26e8b814ffea803b157836921e6f4 Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Tue, 2 Jul 2024 16:18:28 -0700 Subject: [PATCH] Review comments --- chromadb/test/property/test_embeddings.py | 169 +++++++++++------- chromadb/test/property/test_persist.py | 3 +- rust/worker/src/blockstore/arrow/blockfile.rs | 146 ++++++++++----- .../src/blockstore/arrow/sparse_index.rs | 63 +++---- .../src/execution/operators/hnsw_knn.rs | 2 + rust/worker/src/lib.rs | 1 + rust/worker/src/segment/metadata_segment.rs | 2 + 7 files changed, 233 insertions(+), 153 deletions(-) diff --git a/chromadb/test/property/test_embeddings.py b/chromadb/test/property/test_embeddings.py index 58639159774..b68d7d9dbac 100644 --- a/chromadb/test/property/test_embeddings.py +++ b/chromadb/test/property/test_embeddings.py @@ -70,7 +70,7 @@ class EmbeddingStateMachineStates: collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") -class EmbeddingStateMachine(RuleBasedStateMachine): +class EmbeddingStateMachineBase(RuleBasedStateMachine): collection: Collection embedding_ids: Bundle[ID] = Bundle("embedding_ids") @@ -94,37 +94,6 @@ def initialize(self, collection: strategies.Collection): self.record_set_state = strategies.StateMachineRecordSet( ids=[], metadatas=[], documents=[], embeddings=[] ) - if self.__class__.__name__ == "EmbeddingStateMachine": - print("[test_embeddings] Reset") - self.log_operation_count = 0 - self.collection_version = self.collection.get_model()["version"] - - @precondition( - lambda self: not NOT_CLUSTER_ONLY - and self.log_operation_count > 10 - and self.__class__.__name__ == "EmbeddingStateMachine" - ) - @rule() - def wait_for_compaction(self) -> None: - current_version = get_collection_version(self.api, self.collection.name) - assert current_version >= self.collection_version - # This means that there was a compaction from the last time this was - # invoked. Ok to start all over again. - if current_version > self.collection_version: - print( - "[test_embeddings][wait_for_compaction] collection version has changed, so reset to 0" - ) - self.collection_version = current_version - # This is fine even if the log has some records right now - self.log_operation_count = 0 - else: - print("[test_embeddings][wait_for_compaction] wait for version to increase") - new_version = wait_for_version_increase( - self.api, self.collection.name, current_version, additional_time=240 - ) - # Everything got compacted. 
- self.log_operation_count = 0 - self.collection_version = new_version @rule( target=embedding_ids, @@ -160,29 +129,12 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID if normalized_record_set["embeddings"] else None, } - # TODO(Sanket): Why is this the full list and not only the non-overlapping ones self.collection.add(**normalized_record_set) - print( - "[test_embeddings][add] Intersection ids ", - normalized_record_set["ids"], - " len ", - len(normalized_record_set["ids"]), - ) - if self.__class__.__name__ == "EmbeddingStateMachine": - self.log_operation_count += len(normalized_record_set["ids"]) self._upsert_embeddings(cast(strategies.RecordSet, filtered_record_set)) return multiple(*filtered_record_set["ids"]) else: self.collection.add(**normalized_record_set) - print( - "[test_embeddings][add] Non Intersection ids ", - normalized_record_set["ids"], - " len ", - len(normalized_record_set["ids"]), - ) - if self.__class__.__name__ == "EmbeddingStateMachine": - self.log_operation_count += len(normalized_record_set["ids"]) self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set)) return multiple(*normalized_record_set["ids"]) @@ -193,9 +145,6 @@ def delete_by_ids(self, ids: IDs) -> None: indices_to_remove = [self.record_set_state["ids"].index(id) for id in ids] self.collection.delete(ids=ids) - print("[test_embeddings][delete] ids ", ids, " len ", len(ids)) - if self.__class__.__name__ == "EmbeddingStateMachine": - self.log_operation_count += len(ids) self._remove_embeddings(set(indices_to_remove)) # Removing the precondition causes the tests to frequently fail as "unsatisfiable" @@ -214,14 +163,6 @@ def update_embeddings(self, record_set: strategies.RecordSet) -> None: self.on_state_change(EmbeddingStateMachineStates.update_embeddings) self.collection.update(**record_set) - print( - "[test_embeddings][update] ids ", - record_set["ids"], - " len ", - len(invariants.wrap(record_set["ids"])), - ) - if self.__class__.__name__ == "EmbeddingStateMachine": - self.log_operation_count += len(invariants.wrap(record_set["ids"])) self._upsert_embeddings(record_set) # Using a value < 3 causes more retries and lowers the number of valid samples @@ -239,14 +180,6 @@ def upsert_embeddings(self, record_set: strategies.RecordSet) -> None: self.on_state_change(EmbeddingStateMachineStates.upsert_embeddings) self.collection.upsert(**record_set) - print( - "[test_embeddings][upsert] ids ", - record_set["ids"], - " len ", - len(invariants.wrap(record_set["ids"])), - ) - if self.__class__.__name__ == "EmbeddingStateMachine": - self.log_operation_count += len(invariants.wrap(record_set["ids"])) self._upsert_embeddings(record_set) @invariant() @@ -362,6 +295,106 @@ def on_state_change(self, new_state: str) -> None: pass +class EmbeddingStateMachine(EmbeddingStateMachineBase): + def __init__(self, api: ServerAPI): + super().__init__(api) + + @initialize(collection=collection_st) # type: ignore + def initialize(self, collection: strategies.Collection): + super().initialize(collection) + print("[test_embeddings] Reset") + self.log_operation_count = 0 + self.collection_version = self.collection.get_model()["version"] + + @precondition(lambda self: not NOT_CLUSTER_ONLY and self.log_operation_count > 10) + @rule() + def wait_for_compaction(self) -> None: + current_version = get_collection_version(self.api, self.collection.name) + assert current_version >= self.collection_version + # This means that there was a compaction from the last time this was + # invoked. 
Ok to start all over again.
+        if current_version > self.collection_version:
+            print(
+                "[test_embeddings][wait_for_compaction] collection version has changed, so reset to 0"
+            )
+            self.collection_version = current_version
+            # This is fine even if the log has some records right now
+            self.log_operation_count = 0
+        else:
+            print("[test_embeddings][wait_for_compaction] wait for version to increase")
+            new_version = wait_for_version_increase(
+                self.api, self.collection.name, current_version, additional_time=240
+            )
+            # Everything got compacted.
+            self.log_operation_count = 0
+            self.collection_version = new_version
+
+    @rule(
+        target=embedding_ids,
+        record_set=strategies.recordsets(collection_st),
+    )
+    def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID]:
+        res = super().add_embeddings(record_set)
+        normalized_record_set: strategies.NormalizedRecordSet = invariants.wrap_all(
+            record_set
+        )
+        print(
+            "[test_embeddings][add] ids ",
+            normalized_record_set["ids"],
+            " len ",
+            len(normalized_record_set["ids"]),
+        )
+        self.log_operation_count += len(normalized_record_set["ids"])
+        return res
+
+    @rule(ids=st.lists(consumes(embedding_ids), min_size=1))
+    def delete_by_ids(self, ids: IDs) -> None:
+        super().delete_by_ids(ids)
+        print("[test_embeddings][delete] ids ", ids, " len ", len(ids))
+        self.log_operation_count += len(ids)
+
+    # Removing the precondition causes the tests to frequently fail as "unsatisfiable"
+    # Using a value < 5 causes retries and lowers the number of valid samples
+    @precondition(lambda self: len(self.record_set_state["ids"]) >= 5)
+    @rule(
+        record_set=strategies.recordsets(
+            collection_strategy=collection_st,
+            id_strategy=embedding_ids,
+            min_size=1,
+            max_size=5,
+        ),
+    )
+    def update_embeddings(self, record_set: strategies.RecordSet) -> None:
+        super().update_embeddings(record_set)
+        print(
+            "[test_embeddings][update] ids ",
+            record_set["ids"],
+            " len ",
+            len(invariants.wrap(record_set["ids"])),
+        )
+        self.log_operation_count += len(invariants.wrap(record_set["ids"]))
+
+    # Using a value < 3 causes more retries and lowers the number of valid samples
+    @precondition(lambda self: len(self.record_set_state["ids"]) >= 3)
+    @rule(
+        record_set=strategies.recordsets(
+            collection_strategy=collection_st,
+            id_strategy=st.one_of(embedding_ids, strategies.safe_text),
+            min_size=1,
+            max_size=5,
+        )
+    )
+    def upsert_embeddings(self, record_set: strategies.RecordSet) -> None:
+        super().upsert_embeddings(record_set)
+        print(
+            "[test_embeddings][upsert] ids ",
+            record_set["ids"],
+            " len ",
+            len(invariants.wrap(record_set["ids"])),
+        )
+        self.log_operation_count += len(invariants.wrap(record_set["ids"]))
+
 
 def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: ServerAPI) -> None:
     caplog.set_level(logging.ERROR)
     run_state_machine_as_test(
diff --git a/chromadb/test/property/test_persist.py b/chromadb/test/property/test_persist.py
index c2f1b275abe..0b7b0d7716e 100644
--- a/chromadb/test/property/test_persist.py
+++ b/chromadb/test/property/test_persist.py
@@ -15,6 +15,7 @@
     EmbeddingStateMachineStates,
     collection_st as embedding_collection_st,
     trace,
+    EmbeddingStateMachineBase,
 )
 from hypothesis.stateful import (
     run_state_machine_as_test,
@@ -168,7 +169,7 @@ class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates):
 MIN_STATE_CHANGES_BEFORE_PERSIST = 5
 
 
-class PersistEmbeddingsStateMachine(EmbeddingStateMachine):
+class PersistEmbeddingsStateMachine(EmbeddingStateMachineBase):
    def __init__(self, api: ClientAPI, settings: Settings):
         self.api = api
         self.settings = settings
diff --git a/rust/worker/src/blockstore/arrow/blockfile.rs b/rust/worker/src/blockstore/arrow/blockfile.rs
index df31741dee7..b7e0b7b258c 100644
--- a/rust/worker/src/blockstore/arrow/blockfile.rs
+++ b/rust/worker/src/blockstore/arrow/blockfile.rs
@@ -97,41 +97,18 @@ impl ArrowBlockfileWriter {
     ) -> Result<ArrowBlockfileFlusher, Box<dyn ChromaError>> {
         let mut delta_ids = HashSet::new();
         for delta in self.block_deltas.lock().values() {
+            let mut removed = false;
             // Skip empty blocks. Also, remove from sparse index.
             if delta.len() == 0 {
                 tracing::info!("Delta with id {:?} is empty", delta.id);
-                self.sparse_index.remove_block(&delta.id);
-                continue;
+                removed = self.sparse_index.remove_block(&delta.id);
             }
-            // TODO: might these error?
-            self.block_manager.commit::<K, V>(delta);
-            delta_ids.insert(delta.id);
-        }
-        // We commit and flush an empty dummy block if the blockfile is empty.
-        // It can happen that other indexes of the segment are not empty. In this case,
-        // our segment open() logic breaks down since we only handle either
-        // all indexes initialized or none at all but not other combinations.
-        // We could argue that we should fix the readers to handle these cases
-        // but this is simpler, easier and less error prone to do.
-        if self.sparse_index.len() == 0 {
-            if !delta_ids.is_empty() {
-                panic!("Invariant violation. Expected delta ids to be empty");
+            if !removed {
+                // TODO: might these error?
+                self.block_manager.commit::<K, V>(delta);
+                delta_ids.insert(delta.id);
             }
-            // dummy block.
-            tracing::info!("Adding dummy block since index is empty");
-            let initial_block = self.block_manager.create::<K, V>();
-            self.sparse_index.add_initial_block(initial_block.id);
-            self.block_manager.commit::<K, V>(&initial_block);
-            delta_ids.insert(initial_block.id);
         }
-        // It can happen that the sparse index does not contain
-        // the start key after this sequence of operations,
-        // for e.g. consider the following:
-        // sparse_index: {start_key: block_id1, some_key: block_id2, some_other_key: block_id3}
-        // If we delete block_id1 from the sparse index then it becomes
-        // {some_key: block_id2, some_other_key: block_id3}
-        // This should be changed to {start_key: block_id2, some_other_key: block_id3}
-        self.sparse_index.correct_start_key();
 
         // Should be non-empty.
         self.sparse_index_manager.commit(self.sparse_index.clone());
@@ -182,18 +159,11 @@ impl ArrowBlockfileWriter {
             let new_delta = self.block_manager.fork::<K, V>(&block.id);
             let new_id = new_delta.id;
             // Blocks can be empty.
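+            // NOTE: with the simplified replace_block, the forked delta keeps the
+            // old block's start key even when it is empty; empty deltas are
+            // dropped from the sparse index later, at commit time (see remove_block).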
-            if new_delta.len() == 0 {
-                self.sparse_index
-                    .replace_lone_block(target_block_id, new_delta.id);
-            } else {
-                self.sparse_index.replace_block(
-                    target_block_id,
-                    new_delta.id,
-                    new_delta
-                        .get_min_key()
-                        .expect("Block should never be empty when forked"),
-                );
-            }
+            self.sparse_index
+                .replace_block(target_block_id, new_delta.id);
             {
                 let mut deltas = self.block_deltas.lock();
                 deltas.insert(new_id, new_delta.clone());
@@ -243,13 +210,8 @@ impl ArrowBlockfileWriter {
             let block = self.block_manager.get(&target_block_id).await.unwrap();
             let new_delta = self.block_manager.fork::<K, V>(&block.id);
             let new_id = new_delta.id;
-            self.sparse_index.replace_block(
-                target_block_id,
-                new_delta.id,
-                new_delta
-                    .get_min_key()
-                    .expect("Block should never be empty when forked"),
-            );
+            self.sparse_index
+                .replace_block(target_block_id, new_delta.id);
             {
                 let mut deltas = self.block_deltas.lock();
                 deltas.insert(new_id, new_delta.clone());
@@ -576,7 +538,11 @@ impl<'me, K: ArrowReadableKey<'me> + Into<KeyWrapper>, V: ArrowReadableValue<'me>>
 #[cfg(test)]
 mod tests {
     use crate::{
-        blockstore::arrow::{config::TEST_MAX_BLOCK_SIZE_BYTES, provider::ArrowBlockfileProvider},
+        blockstore::{
+            arrow::{config::TEST_MAX_BLOCK_SIZE_BYTES, provider::ArrowBlockfileProvider},
+            BlockfileError,
+        },
+        log::config::{self, GrpcLogConfig},
         segment::DataRecord,
         storage::{local::LocalStorage, Storage},
         types::MetadataValue,
@@ -1222,4 +1188,86 @@ mod tests {
             assert_eq!(res.2, expected_value);
         }
     }
+
+    #[tokio::test]
+    async fn test_first_block_removal() {
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
+        let blockfile_provider = ArrowBlockfileProvider::new(storage);
+        let writer = blockfile_provider.create::<&str, &Int32Array>().unwrap();
+        let id_1 = writer.id();
+
+        // Add the larger keys first, then the smaller ones.
+        let n = 1200;
+        for i in n..n * 2 {
+            let key = format!("{:04}", i);
+            let value = Int32Array::from(vec![i]);
+            writer.set("key", key.as_str(), &value).await.unwrap();
+        }
+        for i in 0..n {
+            let key = format!("{:04}", i);
+            let value = Int32Array::from(vec![i]);
+            writer.set("key", key.as_str(), &value).await.unwrap();
+        }
+        writer.commit::<&str, &Int32Array>().unwrap();
+        // Create another writer.
+        let writer = blockfile_provider
+            .fork::<&str, &Int32Array>(&id_1)
+            .await
+            .expect("BlockfileWriter fork unsuccessful");
+        // Delete everything but the last 10 keys.
+        let delete_end = n * 2 - 10;
+        for i in 0..delete_end {
+            let key = format!("{:04}", i);
+            writer
+                .delete::<&str, &Int32Array>("key", key.as_str())
+                .await
+                .expect("Delete failed");
+        }
+        let flusher = writer.commit::<&str, &Int32Array>().unwrap();
+        let id_2 = flusher.id();
+
+        let reader = blockfile_provider
+            .open::<&str, Int32Array>(&id_2)
+            .await
+            .unwrap();
+
+        for i in 0..delete_end {
+            let key = format!("{:04}", i);
+            assert_eq!(reader.contains("key", &key).await, false);
+        }
+
+        for i in delete_end..n * 2 {
+            let key = format!("{:04}", i);
+            let value = reader.get("key", &key).await.unwrap();
+            assert_eq!(value.values(), &[i]);
+        }
+
+        let writer = blockfile_provider
+            .fork::<&str, &Int32Array>(&id_1)
+            .await
+            .expect("BlockfileWriter fork unsuccessful");
+        // Fork the original blockfile again and overwrite the previously deleted keys.
+        for i in 0..delete_end {
+            let key = format!("{:04}", i);
+            let value = Int32Array::from(vec![i]);
+            writer
+                .set::<&str, &Int32Array>("key", key.as_str(), &value)
+                .await
+                .expect("Set failed");
+        }
+        let flusher = writer.commit::<&str, &Int32Array>().unwrap();
+        let id_3 = flusher.id();
+
+        let reader = blockfile_provider
+            .open::<&str, Int32Array>(&id_3)
+            .await
+            .unwrap();
+
+        for i in 0..n * 2 {
+            let key = format!("{:04}", i);
+            let value = reader.get("key", &key).await.unwrap();
+            assert_eq!(value.values(), &[i]);
+        }
+    }
 }
diff --git a/rust/worker/src/blockstore/arrow/sparse_index.rs b/rust/worker/src/blockstore/arrow/sparse_index.rs
index c94fa5c8b6b..fe28514dbbc 100644
--- a/rust/worker/src/blockstore/arrow/sparse_index.rs
+++ b/rust/worker/src/blockstore/arrow/sparse_index.rs
@@ -328,43 +328,17 @@ impl SparseIndex {
             .insert(block_id, SparseIndexDelimiter::Key(start_key));
     }
 
-    pub(super) fn replace_block(
-        &self,
-        old_block_id: Uuid,
-        new_block_id: Uuid,
-        new_start_key: CompositeKey,
-    ) {
-        let mut forward = self.forward.lock();
-        let mut reverse = self.reverse.lock();
-        if let Some(old_start_key) = reverse.remove(&old_block_id) {
-            forward.remove(&old_start_key);
-            if old_start_key == SparseIndexDelimiter::Start {
-                forward.insert(SparseIndexDelimiter::Start, new_block_id);
-                reverse.insert(new_block_id, SparseIndexDelimiter::Start);
-            } else {
-                forward.insert(
-                    SparseIndexDelimiter::Key(new_start_key.clone()),
-                    new_block_id,
-                );
-                reverse.insert(new_block_id, SparseIndexDelimiter::Key(new_start_key));
-            }
-        }
-    }
-
-    pub(super) fn replace_lone_block(&self, old_block_id: Uuid, new_block_id: Uuid) {
+    pub(super) fn replace_block(&self, old_block_id: Uuid, new_block_id: Uuid) {
         let mut forward = self.forward.lock();
         let mut reverse = self.reverse.lock();
         if let Some(old_start_key) = reverse.remove(&old_block_id) {
             forward.remove(&old_start_key);
-            if old_start_key != SparseIndexDelimiter::Start {
-                panic!("Invariant violation. The lone block should have SparseIndexDelimiter::Start as start key");
-            }
-            forward.insert(SparseIndexDelimiter::Start, new_block_id);
-            reverse.insert(new_block_id, SparseIndexDelimiter::Start);
+            forward.insert(old_start_key.clone(), new_block_id);
+            reverse.insert(new_block_id, old_start_key);
         }
     }
 
-    pub(super) fn correct_start_key(&self) {
+    fn correct_start_key(&self) {
         if self.len() == 0 {
             return;
         }
@@ -388,12 +362,31 @@ impl SparseIndex {
         }
     }
 
-    pub(super) fn remove_block(&self, block_id: &Uuid) {
-        let mut forward = self.forward.lock();
-        let mut reverse = self.reverse.lock();
-        if let Some(start_key) = reverse.remove(block_id) {
-            forward.remove(&start_key);
+    pub(super) fn remove_block(&self, block_id: &Uuid) -> bool {
+        // We never remove the last block even if it is empty: an empty
+        // blockfile commits and flushes a single empty block. Other indexes
+        // of the segment may be non-empty, and our segment open() logic only
+        // handles either all indexes initialized or none at all, not other
+        // combinations. We could fix the readers to handle these cases, but
+        // this is simpler, easier and less error prone.
+        let mut removed = false;
+        if self.len() > 1 {
+            let mut forward = self.forward.lock();
+            let mut reverse = self.reverse.lock();
+            if let Some(start_key) = reverse.remove(block_id) {
+                forward.remove(&start_key);
+            }
+            removed = true;
         }
+        // It can happen that the sparse index no longer contains
+        // the start key after this removal;
+        // e.g., consider the following:
+        // sparse_index: {start_key: block_id1, some_key: block_id2, some_other_key: block_id3}
+        // If we delete block_id1 from the sparse index then it becomes
+        // {some_key: block_id2, some_other_key: block_id3}
+        // This should be changed to {start_key: block_id2, some_other_key: block_id3}
+        self.correct_start_key();
+        removed
     }
 
     pub(super) fn len(&self) -> usize {
diff --git a/rust/worker/src/execution/operators/hnsw_knn.rs b/rust/worker/src/execution/operators/hnsw_knn.rs
index e78eff3229c..c937bdf6314 100644
--- a/rust/worker/src/execution/operators/hnsw_knn.rs
+++ b/rust/worker/src/execution/operators/hnsw_knn.rs
@@ -68,6 +68,8 @@ impl HnswKnnOperator {
         let mut disallowed_ids = Vec::new();
         for item in logs.iter() {
             let log = item.0;
+            // Even when an update on the log did not touch the embedding itself,
+            // we brute-force the record from the log; the HNSW index could also serve it.
             if log.final_operation == Operation::Delete
                 || log.final_operation == Operation::Update
             {
                 let offset_id = record_segment_reader
diff --git a/rust/worker/src/lib.rs b/rust/worker/src/lib.rs
index 074dee97aa7..76f22264e9b 100644
--- a/rust/worker/src/lib.rs
+++ b/rust/worker/src/lib.rs
@@ -22,6 +22,7 @@ use memberlist::MemberlistProvider;
 use tokio::select;
 use tokio::signal::unix::{signal, SignalKind};
 
+use ::tracing::Span;
 const CONFIG_PATH_ENV_VAR: &str = "CONFIG_PATH";
diff --git a/rust/worker/src/segment/metadata_segment.rs b/rust/worker/src/segment/metadata_segment.rs
index dbde970cbb1..1d1fa6c4639 100644
--- a/rust/worker/src/segment/metadata_segment.rs
+++ b/rust/worker/src/segment/metadata_segment.rs
@@ -600,6 +600,8 @@ impl<'log_records> SegmentWriter<'log_records> for MetadataSegmentWriter<'_> {
                         return Err(ApplyMaterializedLogError::FTSDocumentDeleteError);
                     }
                 },
+                // The record being deleted might not have a document;
+                // that is fine and should not be an error.
                 None => {}
             };
         }
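
For reference, a minimal, self-contained sketch of the sparse-index bookkeeping this patch introduces (remove_block refusing to drop the last block, plus the start-key correction). It is illustrative only: it models the index with a plain BTreeMap/HashMap, a toy Delim enum standing in for SparseIndexDelimiter, u32 ids standing in for block Uuids, and &mut self in place of the Mutex-guarded interior mutability the real code uses.

// Toy model (illustrative only) of the sparse-index bookkeeping above.
use std::collections::{BTreeMap, HashMap};

// Stands in for SparseIndexDelimiter: the first block has no start key.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Delim {
    Start,
    Key(String),
}

struct ToySparseIndex {
    forward: BTreeMap<Delim, u32>, // start key -> block id
    reverse: HashMap<u32, Delim>,  // block id -> start key
}

impl ToySparseIndex {
    fn len(&self) -> usize {
        self.forward.len()
    }

    // Mirrors the patched remove_block: never drop the last block, then
    // repair the Start delimiter. Returns false only for the last block.
    fn remove_block(&mut self, block_id: u32) -> bool {
        let mut removed = false;
        if self.len() > 1 {
            if let Some(start_key) = self.reverse.remove(&block_id) {
                self.forward.remove(&start_key);
            }
            removed = true;
        }
        self.correct_start_key();
        removed
    }

    // If the first remaining entry has a concrete key (because the old
    // first block was removed), rewrite it as the Start delimiter.
    fn correct_start_key(&mut self) {
        if self.len() == 0 {
            return;
        }
        let first = self.forward.keys().next().cloned().unwrap();
        if first == Delim::Start {
            return;
        }
        let block_id = self.forward.remove(&first).unwrap();
        self.forward.insert(Delim::Start, block_id);
        self.reverse.insert(block_id, Delim::Start);
    }
}

fn main() {
    let mut idx = ToySparseIndex {
        forward: BTreeMap::new(),
        reverse: HashMap::new(),
    };
    for (delim, id) in [
        (Delim::Start, 1),
        (Delim::Key("k2".into()), 2),
        (Delim::Key("k3".into()), 3),
    ] {
        idx.forward.insert(delim.clone(), id);
        idx.reverse.insert(id, delim);
    }
    // Removing the first block promotes the next block to Start.
    assert!(idx.remove_block(1));
    assert_eq!(idx.forward[&Delim::Start], 2);
    // The last block is never removed, so an empty blockfile still
    // flushes one (empty) block.
    assert!(idx.remove_block(3));
    assert!(!idx.remove_block(2));
    assert_eq!(idx.len(), 1);
}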