diff --git a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py index 296b4aac..2816ef76 100644 --- a/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py +++ b/google/cloud/pubsublite/cloudpubsub/internal/single_partition_subscriber.py @@ -13,18 +13,19 @@ # limitations under the License. import asyncio -from typing import Callable, Union, List, Dict, NamedTuple -import queue +from typing import Callable, List, Dict, NamedTuple -from google.api_core.exceptions import FailedPrecondition, GoogleAPICallError +from google.api_core.exceptions import GoogleAPICallError from google.cloud.pubsub_v1.subscriber.message import Message from google.pubsub_v1 import PubsubMessage -from google.cloud.pubsublite.internal.wait_ignore_cancelled import wait_ignore_cancelled from google.cloud.pubsublite.internal.wire.permanent_failable import adapt_error -from google.cloud.pubsublite.internal import fast_serialize from google.cloud.pubsublite.types import FlowControlSettings from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker +from google.cloud.pubsublite.cloudpubsub.internal.wrapped_message import ( + AckId, + WrappedMessage, +) from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import ( @@ -36,7 +37,6 @@ SubscriberResetHandler, ) from google.cloud.pubsublite_v1 import FlowControlRequest, SequencedMessage -from google.cloud.pubsub_v1.subscriber._protocol import requests class _SizedMessage(NamedTuple): @@ -44,19 +44,6 @@ class _SizedMessage(NamedTuple): size_bytes: int -class _AckId(NamedTuple): - generation: int - offset: int - - def encode(self) -> str: - return fast_serialize.dump([self.generation, self.offset]) - - @staticmethod - def parse(payload: str) -> "_AckId": # pytype: disable=invalid-annotation - loaded = fast_serialize.load(payload) - return _AckId(generation=loaded[0], offset=loaded[1]) - - ResettableSubscriberFactory = Callable[[SubscriberResetHandler], Subscriber] @@ -69,10 +56,10 @@ class SinglePartitionSingleSubscriber( _nack_handler: NackHandler _transformer: MessageTransformer - _queue: queue.Queue _ack_generation_id: int - _messages_by_ack_id: Dict[str, _SizedMessage] - _looper_future: asyncio.Future + _messages_by_ack_id: Dict[AckId, _SizedMessage] + + _loop: asyncio.AbstractEventLoop def __init__( self, @@ -89,7 +76,6 @@ def __init__( self._nack_handler = nack_handler self._transformer = transformer - self._queue = queue.Queue() self._ack_generation_id = 0 self._messages_by_ack_id = {} @@ -104,19 +90,33 @@ def _wrap_message(self, message: SequencedMessage.meta.pb) -> Message: rewrapped._pb = message cps_message = self._transformer.transform(rewrapped) offset = message.cursor.offset - ack_id_str = _AckId(self._ack_generation_id, offset).encode() + ack_id = AckId(self._ack_generation_id, offset) self._ack_set_tracker.track(offset) - self._messages_by_ack_id[ack_id_str] = _SizedMessage( + self._messages_by_ack_id[ack_id] = _SizedMessage( cps_message, message.size_bytes ) - wrapped_message = Message( - cps_message._pb, - ack_id=ack_id_str, - delivery_attempt=0, - request_queue=self._queue, + wrapped_message = WrappedMessage( + pb=cps_message._pb, + ack_id=ack_id, + ack_handler=lambda id, ack: self._on_ack_threadsafe(id, ack), ) return wrapped_message + def _on_ack_threadsafe(self, ack_id: AckId, should_ack: bool) -> None: + """A function called when a message is acked, may happen from any thread.""" + if should_ack: + self._loop.call_soon_threadsafe(lambda: self._handle_ack(ack_id)) + return + try: + sized_message = self._messages_by_ack_id[ack_id] + # Call the threadsafe version on ack since the callback may be called from another thread. + self._nack_handler.on_nack( + sized_message.message, lambda: self._on_ack_threadsafe(ack_id, True) + ) + except Exception as e: + e2 = adapt_error(e) + self._loop.call_soon_threadsafe(lambda: self.fail(e2)) + async def read(self) -> List[Message]: try: latest_batch = await self.await_unless_failed(self._underlying.read()) @@ -126,78 +126,23 @@ async def read(self) -> List[Message]: self.fail(e) raise e - def _handle_ack(self, message: requests.AckRequest): + def _handle_ack(self, ack_id: AckId): flow_control = FlowControlRequest() flow_control._pb.allowed_messages = 1 - flow_control._pb.allowed_bytes = self._messages_by_ack_id[ - message.ack_id - ].size_bytes + flow_control._pb.allowed_bytes = self._messages_by_ack_id[ack_id].size_bytes self._underlying.allow_flow(flow_control) - del self._messages_by_ack_id[message.ack_id] + del self._messages_by_ack_id[ack_id] # Always refill flow control tokens, but do not commit offsets from outdated generations. - ack_id = _AckId.parse(message.ack_id) if ack_id.generation == self._ack_generation_id: try: self._ack_set_tracker.ack(ack_id.offset) except GoogleAPICallError as e: self.fail(e) - def _handle_nack(self, message: requests.NackRequest): - sized_message = self._messages_by_ack_id[message.ack_id] - try: - # Put the ack request back into the queue since the callback may be called from another thread. - self._nack_handler.on_nack( - sized_message.message, - lambda: self._queue.put( - requests.AckRequest( - ack_id=message.ack_id, - byte_size=0, # Ignored - time_to_ack=0, # Ignored - ordering_key="", # Ignored - ) - ), - ) - except GoogleAPICallError as e: - self.fail(e) - - async def _handle_queue_message( - self, - message: Union[ - requests.AckRequest, - requests.DropRequest, - requests.ModAckRequest, - requests.NackRequest, - ], - ): - if isinstance(message, requests.DropRequest) or isinstance( - message, requests.ModAckRequest - ): - self.fail( - FailedPrecondition( - "Called internal method of google.cloud.pubsub_v1.subscriber.message.Message " - f"Pub/Sub Lite does not support: {message}" - ) - ) - elif isinstance(message, requests.AckRequest): - self._handle_ack(message) - else: - self._handle_nack(message) - - async def _looper(self): - while True: - try: - # This is not an asyncio.Queue, and therefore we cannot do `await self._queue.get()`. - # A blocking wait would block the event loop, this needs to be a queue.Queue for - # compatibility with the Cloud Pub/Sub Message's requirements. - queue_message = self._queue.get_nowait() - await self._handle_queue_message(queue_message) - except queue.Empty: - await asyncio.sleep(0.1) - async def __aenter__(self): + self._loop = asyncio.get_event_loop() await self._ack_set_tracker.__aenter__() await self._underlying.__aenter__() - self._looper_future = asyncio.ensure_future(self._looper()) self._underlying.allow_flow( FlowControlRequest( allowed_messages=self._flow_control_settings.messages_outstanding, @@ -207,7 +152,5 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): - self._looper_future.cancel() - await wait_ignore_cancelled(self._looper_future) await self._underlying.__aexit__(exc_type, exc_value, traceback) await self._ack_set_tracker.__aexit__(exc_type, exc_value, traceback) diff --git a/google/cloud/pubsublite/cloudpubsub/internal/wrapped_message.py b/google/cloud/pubsublite/cloudpubsub/internal/wrapped_message.py new file mode 100644 index 00000000..5cfa5102 --- /dev/null +++ b/google/cloud/pubsublite/cloudpubsub/internal/wrapped_message.py @@ -0,0 +1,64 @@ +from concurrent import futures +import logging +from typing import NamedTuple, Callable + +from google.cloud.pubsub_v1.subscriber.message import Message +from google.pubsub_v1 import PubsubMessage +from google.cloud.pubsub_v1.subscriber.exceptions import AcknowledgeStatus + + +class AckId(NamedTuple): + generation: int + offset: int + + def encode(self) -> str: + return str(self.generation) + "," + str(self.offset) + + +_SUCCESS_FUTURE = futures.Future() +_SUCCESS_FUTURE.set_result(AcknowledgeStatus.SUCCESS) + + +class WrappedMessage(Message): + _id: AckId + _ack_handler: Callable[[AckId, bool], None] + + def __init__( + self, + pb: PubsubMessage.meta.pb, + ack_id: AckId, + ack_handler: Callable[[AckId, bool], None], + ): + super().__init__(pb, ack_id.encode(), 1, None) + self._id = ack_id + self._ack_handler = ack_handler + + def ack(self): + self._ack_handler(self._id, True) + + def ack_with_response(self) -> "futures.Future": + self._ack_handler(self._id, True) + return _SUCCESS_FUTURE + + def nack(self): + self._ack_handler(self._id, False) + + def nack_with_response(self) -> "futures.Future": + self._ack_handler(self._id, False) + return _SUCCESS_FUTURE + + def drop(self): + logging.warning( + "Likely incorrect call to drop() on Pub/Sub Lite message. Pub/Sub Lite does not support redelivery in this way." + ) + + def modify_ack_deadline(self, seconds: int): + logging.warning( + "Likely incorrect call to modify_ack_deadline() on Pub/Sub Lite message. Pub/Sub Lite does not support redelivery in this way." + ) + + def modify_ack_deadline_with_response(self, seconds: int) -> "futures.Future": + logging.warning( + "Likely incorrect call to modify_ack_deadline_with_response() on Pub/Sub Lite message. Pub/Sub Lite does not support redelivery in this way." + ) + return _SUCCESS_FUTURE diff --git a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py index c007913a..74728045 100644 --- a/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py +++ b/tests/unit/pubsublite/cloudpubsub/internal/single_partition_subscriber_test.py @@ -25,7 +25,7 @@ from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker from google.cloud.pubsublite.cloudpubsub.internal.single_partition_subscriber import ( SinglePartitionSingleSubscriber, - _AckId, + AckId, ) from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler @@ -48,7 +48,7 @@ def mock_async_context_manager(cm): def ack_id(generation, offset) -> str: - return _AckId(generation, offset).encode() + return AckId(generation, offset).encode() @pytest.fixture()