Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BatchRetryConfig allowing for mid-batch retry of failed objects #1010

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion integration/test_batch_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import weaviate
from integration.conftest import _sanitize_collection_name
from weaviate.collections.classes.batch import Shard
from weaviate.collections.classes.batch import Shard, BatchRetryConfig
from weaviate.collections.classes.config import (
Configure,
DataType,
Expand Down Expand Up @@ -535,3 +535,19 @@ def test_error_reset(client_factory: ClientFactory) -> None:
assert len(errs) == 1
assert errs[0].object_.properties is not None
assert errs[0].object_.properties["name"] == 1


def test_error_retrying(client_factory: ClientFactory) -> None:
    """Objects whose error message matches the configured substrings are retried.

    After exhausting ``max_retries`` the object must still surface in
    ``client.batch.failed_objects`` with its retry counter at the limit.
    """
    client, collection_name = client_factory()
    retry_config = BatchRetryConfig(
        retry_on_error_message_contains=[
            "invalid text property 'name' on class 'Test_error_retrying'"
        ]
    )
    with client.batch.dynamic(retry_config=retry_config) as batch:
        # First object has an int where a text property is expected -> fails;
        # second object is valid and must succeed.
        for name_value in (1, "correct"):
            batch.add_object(properties={"name": name_value}, collection=collection_name)

    failed = client.batch.failed_objects
    assert len(failed) == 1
    assert failed[0].object_.retry_count == retry_config.max_retries
3 changes: 2 additions & 1 deletion weaviate/classes/batch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from weaviate.collections.classes.batch import Shard
from weaviate.collections.classes.batch import BatchRetryConfig, Shard

__all__ = [
"BatchRetryConfig",
"Shard",
]
67 changes: 48 additions & 19 deletions weaviate/collections/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BatchObjectReturn,
BatchReferenceReturn,
Shard,
BatchRetryConfig,
)
from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.collections.classes.internal import (
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(
batch_mode: _BatchMode,
objects_: Optional[ObjectsBatchRequest] = None,
references: Optional[ReferencesBatchRequest] = None,
retry_config: Optional[BatchRetryConfig] = None,
) -> None:
self.__batch_objects = objects_ or ObjectsBatchRequest()
self.__batch_references = references or ReferencesBatchRequest()
Expand All @@ -184,6 +186,8 @@ def __init__(
self.__batching_mode: _BatchMode = batch_mode
self.__max_batch_size: int = 1000

self.__retry_config = retry_config

if isinstance(self.__batching_mode, _FixedSizeBatching):
self.__recommended_num_objects = self.__batching_mode.batch_size
self.__concurrent_requests = self.__batching_mode.concurrent_requests
Expand Down Expand Up @@ -446,7 +450,7 @@ async def __send_batch_async(

readded_uuids = set()
if readd_rate_limit:
readded_objects = []
readded_objects: List[int] = []
highest_retry_count = 0
for i, err in response_obj.errors.items():
if ("support@cohere.com" in err.message and "rate limit" in err.message) or (
Expand Down Expand Up @@ -482,30 +486,41 @@ async def __send_batch_async(

self.__batch_objects.prepend(readd_objects)

new_errors = {
i: err for i, err in response_obj.errors.items() if i not in readded_objects
}
response_obj = BatchObjectReturn(
uuids={
i: uid
for i, uid in response_obj.uuids.items()
if i not in readded_objects
},
errors=new_errors,
has_errors=len(new_errors) > 0,
all_responses=[
err
for i, err in enumerate(response_obj.all_responses)
if i not in readded_objects
],
elapsed_seconds=response_obj.elapsed_seconds,
)
response_obj = self.__alter_errors_after_retry(response_obj, readded_objects)
self.__time_stamp_last_request = (
time.time() + self.__fix_rate_batching_base_time * (highest_retry_count + 1)
) # skip a full minute to recover from the rate limit
self.__fix_rate_batching_base_time += (
1 # increase the base time as the current one is too low
)

if self.__retry_config is not None:
readded_objects = []
for i, err in response_obj.errors.items():
if any(
msg in err.message
for msg in self.__retry_config.retry_on_error_message_contains
):
if err.object_.retry_count > self.__retry_config.max_retries:
highest_retry_count = err.object_.retry_count

if err.object_.retry_count >= self.__retry_config.max_retries:
continue # too many retries, give up
err.object_.retry_count += 1
readded_objects.append(i)

readd_objects = [
err.object_ for i, err in response_obj.errors.items() if i in readded_objects
]
readded_uuids = readded_uuids.union({obj.uuid for obj in readd_objects})

await asyncio.sleep(
self.__retry_config.retry_wait_time
) # wait before adding the objects again
self.__batch_objects.prepend(readd_objects)

response_obj = self.__alter_errors_after_retry(response_obj, readded_objects)

self.__uuid_lookup_lock.acquire()
self.__uuid_lookup.difference_update(
obj.uuid for obj in objs if obj.uuid not in readded_uuids
Expand Down Expand Up @@ -541,6 +556,20 @@ async def __send_batch_async(
self.__active_requests -= 1
self.__active_requests_lock.release()

def __alter_errors_after_retry(
    self, response_obj: BatchObjectReturn, readded_objects: List[int]
) -> BatchObjectReturn:
    """Return a copy of *response_obj* with the re-queued objects stripped out.

    Objects whose indices appear in ``readded_objects`` have been prepended to
    the batch queue for another attempt, so their errors, uuids and raw
    responses must not be reported back to the caller yet.
    """
    readded = set(readded_objects)
    remaining_errors = {
        idx: error for idx, error in response_obj.errors.items() if idx not in readded
    }
    remaining_uuids = {
        idx: uuid_ for idx, uuid_ in response_obj.uuids.items() if idx not in readded
    }
    remaining_responses = [
        resp
        for idx, resp in enumerate(response_obj.all_responses)
        if idx not in readded
    ]
    return BatchObjectReturn(
        uuids=remaining_uuids,
        errors=remaining_errors,
        has_errors=bool(remaining_errors),
        all_responses=remaining_responses,
        elapsed_seconds=response_obj.elapsed_seconds,
    )

def flush(self) -> None:
"""Flush the batch queue and wait for all requests to be finished."""
# bg thread is sending objs+refs automatically, so simply wait for everything to be done
Expand Down
10 changes: 8 additions & 2 deletions weaviate/collections/batch/batch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
_DynamicBatching,
_BatchMode,
)
from weaviate.collections.classes.batch import BatchResult, ErrorObject, ErrorReference, Shard
from weaviate.collections.classes.batch import (
BatchRetryConfig,
BatchResult,
ErrorObject,
ErrorReference,
Shard,
)
from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.connect import ConnectionV4
from weaviate.util import _capitalize_first_letter, _decode_json_response_list
Expand All @@ -22,7 +28,7 @@ def __init__(
self._current_batch: Optional[_BatchBase] = None
# config options
self._batch_mode: _BatchMode = _DynamicBatching()

self._batch_retry_config: Optional[BatchRetryConfig] = None
self._batch_data = _BatchDataWrapper()

def wait_for_vector_indexing(
Expand Down
21 changes: 19 additions & 2 deletions weaviate/collections/batch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_BatchMode,
_ContextManagerWrapper,
)
from weaviate.collections.classes.batch import BatchRetryConfig
from weaviate.collections.classes.config import ConsistencyLevel
from weaviate.collections.classes.internal import ReferenceInput, ReferenceInputs
from weaviate.collections.classes.tenants import Tenant
Expand Down Expand Up @@ -117,11 +118,14 @@ def __create_batch_and_reset(self) -> _ContextManagerWrapper[_BatchClient]:
consistency_level=self._consistency_level,
results=self._batch_data,
batch_mode=self._batch_mode,
retry_config=self._batch_retry_config,
)
)

def dynamic(
self, consistency_level: Optional[ConsistencyLevel] = None
self,
consistency_level: Optional[ConsistencyLevel] = None,
retry_config: Optional[BatchRetryConfig] = None,
) -> _ContextManagerWrapper[_BatchClient]:
"""Configure dynamic batching.

Expand All @@ -130,8 +134,11 @@ def dynamic(
Arguments:
`consistency_level`
The consistency level to be used to send batches. If not provided, the default value is `None`.
`retry_config`
Configuration for retrying failed objects during the batching algorithm. If not provided, the default value is `None`.
"""
self._batch_mode: _BatchMode = _DynamicBatching()
self._batch_retry_config = retry_config
self._consistency_level = consistency_level
return self.__create_batch_and_reset()

Expand All @@ -140,6 +147,7 @@ def fixed_size(
batch_size: int = 100,
concurrent_requests: int = 2,
consistency_level: Optional[ConsistencyLevel] = None,
retry_config: Optional[BatchRetryConfig] = None,
) -> _ContextManagerWrapper[_BatchClient]:
"""Configure fixed size batches. Note that the default is dynamic batching.

Expand All @@ -153,14 +161,20 @@ def fixed_size(
made to Weaviate and not the speed of batch creation within Python.
`consistency_level`
The consistency level to be used to send batches. If not provided, the default value is `None`.
`retry_config`
Configuration for retrying failed objects during the batching algorithm. If not provided, the default value is `None`.

"""
self._batch_mode = _FixedSizeBatching(batch_size, concurrent_requests)
self._batch_retry_config = retry_config
self._consistency_level = consistency_level
return self.__create_batch_and_reset()

def rate_limit(
self, requests_per_minute: int, consistency_level: Optional[ConsistencyLevel] = None
self,
requests_per_minute: int,
consistency_level: Optional[ConsistencyLevel] = None,
retry_config: Optional[BatchRetryConfig] = None,
) -> _ContextManagerWrapper[_BatchClient]:
"""Configure batches with a rate limited vectorizer.

Expand All @@ -171,7 +185,10 @@ def rate_limit(
The number of requests that the vectorizer can process per minute.
`consistency_level`
The consistency level to be used to send batches. If not provided, the default value is `None`.
`retry_config`
Configuration for retrying failed objects during the batching algorithm. If not provided, the default value is `None`.
"""
self._batch_mode = _RateLimitedBatching(requests_per_minute)
self._batch_retry_config = retry_config
self._consistency_level = consistency_level
return self.__create_batch_and_reset()
8 changes: 8 additions & 0 deletions weaviate/collections/classes/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,11 @@ class DeleteManyReturn(Generic[T]):
matches: int
objects: T
successful: int


class BatchRetryConfig(BaseModel):
    """Configuration for retrying failed batch operations."""

    # Maximum number of times a single failed object is re-queued before it is
    # given up on and reported via the batch's failed objects.
    max_retries: int = 3
    # An object is retried only when its error message contains at least one of
    # these substrings; the default empty list means nothing is ever retried.
    retry_on_error_message_contains: List[str] = Field(default_factory=list)
    # Seconds to sleep before re-queueing matched objects (0 = retry immediately).
    retry_wait_time: int = 0
Loading