Skip to content

Commit

Permalink
Review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Jul 2, 2024
1 parent c40f394 commit 3ecfdfa
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 153 deletions.
169 changes: 101 additions & 68 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class EmbeddingStateMachineStates:
collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")


class EmbeddingStateMachine(RuleBasedStateMachine):
class EmbeddingStateMachineBase(RuleBasedStateMachine):
collection: Collection
embedding_ids: Bundle[ID] = Bundle("embedding_ids")

Expand All @@ -94,37 +94,6 @@ def initialize(self, collection: strategies.Collection):
self.record_set_state = strategies.StateMachineRecordSet(
ids=[], metadatas=[], documents=[], embeddings=[]
)
if self.__class__.__name__ == "EmbeddingStateMachine":
print("[test_embeddings] Reset")
self.log_operation_count = 0
self.collection_version = self.collection.get_model()["version"]

@precondition(
lambda self: not NOT_CLUSTER_ONLY
and self.log_operation_count > 10
and self.__class__.__name__ == "EmbeddingStateMachine"
)
@rule()
def wait_for_compaction(self) -> None:
current_version = get_collection_version(self.api, self.collection.name)
assert current_version >= self.collection_version
# This means that there was a compaction from the last time this was
# invoked. Ok to start all over again.
if current_version > self.collection_version:
print(
"[test_embeddings][wait_for_compaction] collection version has changed, so reset to 0"
)
self.collection_version = current_version
# This is fine even if the log has some records right now
self.log_operation_count = 0
else:
print("[test_embeddings][wait_for_compaction] wait for version to increase")
new_version = wait_for_version_increase(
self.api, self.collection.name, current_version, additional_time=240
)
# Everything got compacted.
self.log_operation_count = 0
self.collection_version = new_version

@rule(
target=embedding_ids,
Expand Down Expand Up @@ -160,29 +129,12 @@ def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID
if normalized_record_set["embeddings"]
else None,
}
# TODO(Sanket): Why is this the full list and not only the non-overlapping ones
self.collection.add(**normalized_record_set)
print(
"[test_embeddings][add] Intersection ids ",
normalized_record_set["ids"],
" len ",
len(normalized_record_set["ids"]),
)
if self.__class__.__name__ == "EmbeddingStateMachine":
self.log_operation_count += len(normalized_record_set["ids"])
self._upsert_embeddings(cast(strategies.RecordSet, filtered_record_set))
return multiple(*filtered_record_set["ids"])

else:
self.collection.add(**normalized_record_set)
print(
"[test_embeddings][add] Non Intersection ids ",
normalized_record_set["ids"],
" len ",
len(normalized_record_set["ids"]),
)
if self.__class__.__name__ == "EmbeddingStateMachine":
self.log_operation_count += len(normalized_record_set["ids"])
self._upsert_embeddings(cast(strategies.RecordSet, normalized_record_set))
return multiple(*normalized_record_set["ids"])

Expand All @@ -193,9 +145,6 @@ def delete_by_ids(self, ids: IDs) -> None:
indices_to_remove = [self.record_set_state["ids"].index(id) for id in ids]

self.collection.delete(ids=ids)
print("[test_embeddings][delete] ids ", ids, " len ", len(ids))
if self.__class__.__name__ == "EmbeddingStateMachine":
self.log_operation_count += len(ids)
self._remove_embeddings(set(indices_to_remove))

# Removing the precondition causes the tests to frequently fail as "unsatisfiable"
Expand All @@ -214,14 +163,6 @@ def update_embeddings(self, record_set: strategies.RecordSet) -> None:
self.on_state_change(EmbeddingStateMachineStates.update_embeddings)

self.collection.update(**record_set)
print(
"[test_embeddings][update] ids ",
record_set["ids"],
" len ",
len(invariants.wrap(record_set["ids"])),
)
if self.__class__.__name__ == "EmbeddingStateMachine":
self.log_operation_count += len(invariants.wrap(record_set["ids"]))
self._upsert_embeddings(record_set)

# Using a value < 3 causes more retries and lowers the number of valid samples
Expand All @@ -239,14 +180,6 @@ def upsert_embeddings(self, record_set: strategies.RecordSet) -> None:
self.on_state_change(EmbeddingStateMachineStates.upsert_embeddings)

self.collection.upsert(**record_set)
print(
"[test_embeddings][upsert] ids ",
record_set["ids"],
" len ",
len(invariants.wrap(record_set["ids"])),
)
if self.__class__.__name__ == "EmbeddingStateMachine":
self.log_operation_count += len(invariants.wrap(record_set["ids"]))
self._upsert_embeddings(record_set)

@invariant()
Expand Down Expand Up @@ -362,6 +295,106 @@ def on_state_change(self, new_state: str) -> None:
pass


class EmbeddingStateMachine(EmbeddingStateMachineBase):
def __init__(self, api: ServerAPI):
super().__init__(api)

@initialize(collection=collection_st) # type: ignore
def initialize(self, collection: strategies.Collection):
super().initialize(collection)
print("[test_embeddings] Reset")
self.log_operation_count = 0
self.collection_version = self.collection.get_model()["version"]

@precondition(lambda self: not NOT_CLUSTER_ONLY and self.log_operation_count > 10)
@rule()
def wait_for_compaction(self) -> None:
current_version = get_collection_version(self.api, self.collection.name)
assert current_version >= self.collection_version
# This means that there was a compaction from the last time this was
# invoked. Ok to start all over again.
if current_version > self.collection_version:
print(
"[test_embeddings][wait_for_compaction] collection version has changed, so reset to 0"
)
self.collection_version = current_version
# This is fine even if the log has some records right now
self.log_operation_count = 0
else:
print("[test_embeddings][wait_for_compaction] wait for version to increase")
new_version = wait_for_version_increase(
self.api, self.collection.name, current_version, additional_time=240
)
# Everything got compacted.
self.log_operation_count = 0
self.collection_version = new_version

@rule(
target=embedding_ids,
record_set=strategies.recordsets(collection_st),
)
def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID]:
res = super().add_embeddings(record_set)
normalized_record_set: strategies.NormalizedRecordSet = invariants.wrap_all(
record_set
)
print(
"[test_embeddings][add] Non Intersection ids ",
normalized_record_set["ids"],
" len ",
len(normalized_record_set["ids"]),
)
self.log_operation_count += len(normalized_record_set["ids"])
return res

@rule(ids=st.lists(consumes(embedding_ids), min_size=1))
def delete_by_ids(self, ids: IDs) -> None:
super().delete_by_ids(ids)
print("[test_embeddings][delete] ids ", ids, " len ", len(ids))
self.log_operation_count += len(ids)

# Removing the precondition causes the tests to frequently fail as "unsatisfiable"
# Using a value < 5 causes retries and lowers the number of valid samples
@precondition(lambda self: len(self.record_set_state["ids"]) >= 5)
@rule(
record_set=strategies.recordsets(
collection_strategy=collection_st,
id_strategy=embedding_ids,
min_size=1,
max_size=5,
),
)
def update_embeddings(self, record_set: strategies.RecordSet) -> None:
super().update_embeddings(record_set)
print(
"[test_embeddings][update] ids ",
record_set["ids"],
" len ",
len(invariants.wrap(record_set["ids"])),
)
self.log_operation_count += len(invariants.wrap(record_set["ids"]))

# Using a value < 3 causes more retries and lowers the number of valid samples
@precondition(lambda self: len(self.record_set_state["ids"]) >= 3)
@rule(
record_set=strategies.recordsets(
collection_strategy=collection_st,
id_strategy=st.one_of(embedding_ids, strategies.safe_text),
min_size=1,
max_size=5,
)
)
def upsert_embeddings(self, record_set: strategies.RecordSet) -> None:
super().upsert_embeddings(record_set)
print(
"[test_embeddings][upsert] ids ",
record_set["ids"],
" len ",
len(invariants.wrap(record_set["ids"])),
)
self.log_operation_count += len(invariants.wrap(record_set["ids"]))


def test_embeddings_state(caplog: pytest.LogCaptureFixture, api: ServerAPI) -> None:
caplog.set_level(logging.ERROR)
run_state_machine_as_test(
Expand Down
3 changes: 2 additions & 1 deletion chromadb/test/property/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
EmbeddingStateMachineStates,
collection_st as embedding_collection_st,
trace,
EmbeddingStateMachineBase,
)
from hypothesis.stateful import (
run_state_machine_as_test,
Expand Down Expand Up @@ -168,7 +169,7 @@ class PersistEmbeddingsStateMachineStates(EmbeddingStateMachineStates):
MIN_STATE_CHANGES_BEFORE_PERSIST = 5


class PersistEmbeddingsStateMachine(EmbeddingStateMachine):
class PersistEmbeddingsStateMachine(EmbeddingStateMachineBase):
def __init__(self, api: ClientAPI, settings: Settings):
self.api = api
self.settings = settings
Expand Down
Loading

0 comments on commit 3ecfdfa

Please sign in to comment.