diff --git a/aioreactive/combine.py b/aioreactive/combine.py index f78896b..356b062 100644 --- a/aioreactive/combine.py +++ b/aioreactive/combine.py @@ -8,9 +8,9 @@ MailboxProcessor, Nothing, Option, - TailCallResult, Some, TailCall, + TailCallResult, match, pipe, tailrec_async, @@ -58,7 +58,7 @@ async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable: key=Key(0), ) - async def worker(inbox: MailboxProcessor[Msg]) -> None: + async def worker(inbox: MailboxProcessor[Msg[TSource]]) -> None: def obv(key: Key) -> AsyncObserver[TSource]: async def asend(value: TSource) -> None: await safe_obv.asend(value) @@ -71,19 +71,19 @@ async def aclose() -> None: return AsyncAnonymousObserver(asend, athrow, aclose) - async def update(msg: Msg, model: Model[TSource]) -> Model[TSource]: + async def update(msg: Msg[TSource], model: Model[TSource]) -> Model[TSource]: # log.debug("update: %s, model: %s", msg, model) - with match(msg) as m: - for xs in InnerObservableMsg.case(m): + with match(msg) as case: + for xs in case(InnerObservableMsg[TSource]): if max_concurrent == 0 or len(model.subscriptions) < max_concurrent: inner = await xs.subscribe_async(obv(model.key)) return model.replace( subscriptions=model.subscriptions.add(model.key, inner), key=Key(model.key + 1), ) - - return model.replace(queue=model.queue.append(xs)) - for key in InnerCompletedMsg.case(m): + lst = FrozenList.singleton(xs) + return model.replace(queue=model.queue.append(lst)) + for key in case(InnerCompletedMsg): subscriptions = model.subscriptions.remove(key) if len(model.queue): xs = model.queue[0] @@ -100,14 +100,14 @@ async def update(msg: Msg, model: Model[TSource]) -> Model[TSource]: if model.is_stopped: await safe_obv.aclose() return model.replace(subscriptions=map.empty) - while CompletedMsg.case(m): + while case(CompletedMsg): if not model.subscriptions: log.debug("merge_inner: closing!") await safe_obv.aclose() return model.replace(is_stopped=True) - while m.default(): + while case.default(): for dispose in model.subscriptions.values(): await dispose.dispose_async() @@ -179,7 +179,7 @@ def _combine_latest(source: AsyncObservable[TSource]) -> AsyncObservable[Tuple[T async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncDisposable: safe_obv, auto_detach = auto_detach_observer(aobv) - async def worker(inbox: MailboxProcessor[Msg]) -> None: + async def worker(inbox: MailboxProcessor[Msg[TSource]]) -> None: @tailrec_async async def message_loop( source_value: Option[TSource], other_value: Option[TOther] @@ -188,24 +188,24 @@ async def message_loop( async def get_value(n: Notification[Any]) -> Option[Any]: with match(n) as m: - for value in OnNext.case(m): + for value in case(OnNext[TSource]): return Some(value) - for err in OnError.case(m): + for err in case(OnError): await safe_obv.athrow(err) while m.default(): await safe_obv.aclose() return Nothing - m = match(cn) - for value in SourceMsg.case(m): - source_value = await get_value(value) - break + with match(cn) as case: + for value in case(SourceMsg[TSource]): + source_value = await get_value(value) + break - for value in OtherMsg.case(m): - other_value = await get_value(value) - break + for value in case(OtherMsg[TOther]): + other_value = await get_value(value) + break def binder(s: TSource) -> Option[Tuple[TSource, TOther]]: def mapper(o: TOther) -> Tuple[TSource, TOther]: @@ -260,20 +260,20 @@ def _with_latest_from(source: AsyncObservable[TSource]) -> AsyncObservable[Tuple async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncDisposable: safe_obv, auto_detach = auto_detach_observer(aobv) - async def worker(inbox: MailboxProcessor[Msg]) -> None: + async def worker(inbox: MailboxProcessor[Msg[TSource]]) -> None: @tailrec_async async def message_loop(latest: Option[TOther]) -> TailCallResult[NoReturn]: cn = await inbox.receive() async def get_value(n: Notification[Any]) -> Option[Any]: - with match(n) as m: - for value in OnNext.case(m): + with match(n) as case: + for value in case(OnNext[TSource]): return Some(value) - for err in OnError.case(m): + for err in case(OnError[TSource]): await safe_obv.athrow(err) - while m.default(): + while case.default(): await safe_obv.aclose() return Nothing diff --git a/aioreactive/create.py b/aioreactive/create.py index b1e5b3f..7f6ef60 100644 --- a/aioreactive/create.py +++ b/aioreactive/create.py @@ -1,7 +1,7 @@ import asyncio import logging from asyncio import Future -from typing import AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar from expression.core import TailCallResult, aiotools, tailrec_async from expression.core.fn import TailCall @@ -131,7 +131,7 @@ async def subscribe_async(_: AsyncObserver[TSource]) -> AsyncDisposable: return AsyncAnonymousObservable(subscribe_async) -def fail(error: Exception) -> AsyncObservable[TSource]: +def fail(error: Exception) -> AsyncObservable[Any]: """Returns the observable sequence that terminates exceptionally with the specified exception.""" diff --git a/aioreactive/filtering.py b/aioreactive/filtering.py index b47f186..376b1c9 100644 --- a/aioreactive/filtering.py +++ b/aioreactive/filtering.py @@ -4,8 +4,8 @@ from expression.core import ( MailboxProcessor, Option, - TailCallResult, TailCall, + TailCallResult, aiotools, compose, fst, @@ -126,8 +126,8 @@ async def message_loop(latest: Notification[TSource]) -> TailCallResult[NoReturn n = await inbox.receive() async def get_latest() -> Notification[TSource]: - with match(n) as m: - for x in OnNext.case(m): + with match(n) as case: + for x in case(OnNext[TSource]): if n == latest: break try: @@ -135,10 +135,10 @@ async def get_latest() -> Notification[TSource]: except Exception as ex: await safe_obv.athrow(ex) break - for err in OnError.case(m): + for err in case(OnError[TSource]): await safe_obv.athrow(err) break - while m.case(OnCompleted): + while case(OnCompleted): await safe_obv.aclose() break diff --git a/aioreactive/msg.py b/aioreactive/msg.py index d29ee4f..a4a3709 100644 --- a/aioreactive/msg.py +++ b/aioreactive/msg.py @@ -1,10 +1,10 @@ """Internal messages used by mailbox processors. Do not import or use. """ -from abc import abstractclassmethod +from abc import ABC from dataclasses import dataclass -from typing import Any, Generic, Iterable, NewType, TypeVar +from typing import Any, Iterable, NewType, Type, TypeVar, get_origin -from expression.core import Matcher +from expression.core import SupportsMatch from expression.system import AsyncDisposable from .notification import Notification @@ -16,123 +16,124 @@ Key = NewType("Key", int) -class Msg: - """Message base class. - - Contains overloads for pattern matching to avoid any type casting - later. - """ - - @abstractclassmethod - def case(cls, matcher: Matcher) -> Any: - raise NotImplementedError +class Msg(SupportsMatch[TSource], ABC): + """Message base class.""" @dataclass -class SourceMsg(Msg, Generic[TSource]): +class SourceMsg(Msg[Notification[TSource]], SupportsMatch[TSource]): value: Notification[TSource] - @classmethod - def case(cls, matcher: Matcher) -> Iterable[Notification[TSource]]: - """Helper to cast the match result to correct type.""" - return matcher.case(cls) - def __match__(self, pattern: Any) -> Iterable[Notification[TSource]]: - if isinstance(self, pattern): - return [self.value] + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [self.value] + except TypeError: + pass return [] @dataclass -class OtherMsg(Msg, Generic[TOther]): +class OtherMsg(Msg[Notification[TOther]], SupportsMatch[TOther]): value: Notification[TOther] - @classmethod - def case(cls, matcher: Matcher) -> Iterable[Notification[TOther]]: - """Helper to cast the match result to correct type.""" - - return matcher.case(cls) - def __match__(self, pattern: Any) -> Iterable[Notification[TOther]]: - if isinstance(self, pattern): - return [self.value] + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [self.value] + except TypeError: + pass return [] @dataclass -class DisposableMsg(Msg): +class DisposableMsg(Msg[AsyncDisposable], SupportsMatch[AsyncDisposable]): """Message containing a diposable.""" disposable: AsyncDisposable - @classmethod - def case(cls, matcher: Matcher) -> Iterable[Notification[AsyncDisposable]]: - """Helper to cast the match result to correct type.""" - - return matcher.case(cls) - def __match__(self, pattern: Any) -> Iterable[AsyncDisposable]: - if isinstance(self, pattern): - return [self.disposable] + try: + if isinstance(self, pattern): + return [self.disposable] + except TypeError: + pass return [] @dataclass -class InnerObservableMsg(Msg, Generic[TSource]): +class InnerObservableMsg(Msg[AsyncObservable[TSource]], SupportsMatch[AsyncObservable[TSource]]): """Message containing an inner observable.""" inner_observable: AsyncObservable[TSource] - @classmethod - def case(cls, matcher: Matcher) -> Iterable[AsyncObservable[TSource]]: - """Helper to cast the match result to correct type.""" - - return matcher.case(cls) - def __match__(self, pattern: Any) -> Iterable[AsyncObservable[TSource]]: - if isinstance(self, pattern): - return [self.inner_observable] + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [self.inner_observable] + except TypeError: + pass return [] @dataclass -class InnerCompletedMsg(Msg): +class InnerCompletedMsg(Msg[TSource]): """Message notifying that the inner observable completed.""" key: Key - @classmethod - def case(cls, matcher: Matcher) -> Iterable[Key]: - """Helper to cast the match result to correct type.""" - - return matcher.case(cls) - def __match__(self, pattern: Any) -> Iterable[Key]: - if isinstance(self, pattern): - return [self.key] + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [self.key] + except TypeError: + pass return [] -class CompletedMsg(Msg): +class CompletedMsg_(Msg[Any]): """Message notifying that the observable sequence completed.""" - @classmethod - def case(cls, matcher: Matcher) -> Iterable[bool]: - """Helper to cast the match result to correct type.""" + def __match__(self, pattern: Any) -> Iterable[bool]: + if self is pattern: + return [True] - return matcher.case(cls) + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [True] + except TypeError: + pass + + return [] -CompletedMsg_ = CompletedMsg() # Singleton +CompletedMsg = CompletedMsg_() # Singleton -class DisposeMsg(Msg): +class DisposeMsg_(Msg[None]): """Message notifying that the operator got disposed.""" - pass + def __match__(self, pattern: Any) -> Iterable[bool]: + + if self is pattern: + return [True] + + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [True] + except TypeError: + pass + + return [] -DisposeMsg_ = DisposeMsg() # Singleton +DisposeMsg = DisposeMsg_() # Singleton __all__ = ["Msg", "DisposeMsg", "CompletedMsg", "InnerCompletedMsg", "InnerObservableMsg", "DisposableMsg"] diff --git a/aioreactive/notification.py b/aioreactive/notification.py index 1d23b74..67560ad 100644 --- a/aioreactive/notification.py +++ b/aioreactive/notification.py @@ -1,8 +1,8 @@ -from abc import ABC, abstractclassmethod, abstractmethod +from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Awaitable, Callable, Generic, Iterable, TypeVar +from typing import Any, Awaitable, Callable, Generic, Iterable, TypeVar, get_origin -from expression.core import Matcher +from expression.core import SupportsMatch from .types import AsyncObserver @@ -16,7 +16,7 @@ class MsgKind(Enum): ON_COMPLETED = 3 -class Notification(ABC, Generic[TSource]): +class Notification(Generic[TSource], ABC): """Represents a message to a mailbox processor.""" def __init__(self, kind: MsgKind): @@ -35,17 +35,11 @@ async def accept( async def accept_observer(self, obv: AsyncObserver[TSource]) -> None: raise NotImplementedError - @abstractclassmethod - def case(cls, matcher: Matcher) -> Iterable[Any]: - """Helper to cast the match result to correct type.""" - - raise NotImplementedError - def __repr__(self) -> str: return str(self) -class OnNext(Notification[TSource]): +class OnNext(Notification[TSource], SupportsMatch[TSource]): """Represents an OnNext notification to an observer.""" def __init__(self, value: TSource): @@ -64,15 +58,13 @@ async def accept( async def accept_observer(self, obv: AsyncObserver[TSource]) -> None: await obv.asend(self.value) - @classmethod - def case(cls, matcher: Matcher) -> Iterable[TSource]: - """Helper to cast the match result to correct type.""" - - return matcher.case(cls) - def __match__(self, pattern: Any) -> Iterable[TSource]: - if isinstance(self, pattern): - return [self.value] + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [self.value] + except TypeError: + pass return [] def __eq__(self, other: Any) -> bool: @@ -84,7 +76,7 @@ def __str__(self) -> str: return f"OnNext({self.value})" -class OnError(Notification[TSource]): +class OnError(Notification[TSource], SupportsMatch[Exception]): """Represents an OnError notification to an observer.""" def __init__(self, exception: Exception): @@ -103,15 +95,13 @@ async def accept( async def accept_observer(self, obv: AsyncObserver[TSource]): await obv.athrow(self.exception) - @classmethod - def case(cls, matcher: Matcher) -> Iterable[Exception]: - """Helper to cast the match result to correct type.""" - - return matcher.case(cls) - def __match__(self, pattern: Any) -> Iterable[Exception]: - if isinstance(self, pattern): - return [self.exception] + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [self.exception] + except TypeError: + pass return [] def __eq__(self, other: Any) -> bool: @@ -123,7 +113,7 @@ def __str__(self) -> str: return f"OnError({self.exception})" -class _OnCompleted(Notification[TSource]): +class _OnCompleted(Notification[TSource], SupportsMatch[bool]): """Represents an OnCompleted notification to an observer. Note: Do not use. Use the singleton `OnCompleted` instance instead. @@ -145,11 +135,17 @@ async def accept( async def accept_observer(self, obv: AsyncObserver[TSource]): await obv.aclose() - @classmethod - def case(cls, matcher: Matcher) -> Iterable[bool]: - """Helper to cast the match result to correct type.""" + def __match__(self, pattern: Any) -> Iterable[bool]: + if self is pattern: + return [True] - return matcher.case(cls) + origin: Any = get_origin(pattern) + try: + if isinstance(self, origin or pattern): + return [True] + except TypeError: + pass + return [] def __eq__(self, other: Any) -> bool: if isinstance(other, _OnCompleted): diff --git a/aioreactive/observers.py b/aioreactive/observers.py index f10e98e..09dc610 100644 --- a/aioreactive/observers.py +++ b/aioreactive/observers.py @@ -200,7 +200,7 @@ def auto_detach_observer( cts = CancellationTokenSource() token = cts.token - async def worker(inbox: MailboxProcessor[Msg]): + async def worker(inbox: MailboxProcessor[Msg[TSource]]): @tailrec_async async def message_loop(disposables: List[AsyncDisposable]): if token.is_cancellation_requested: diff --git a/aioreactive/testing/observer.py b/aioreactive/testing/observer.py index 5912cfd..55dd9cf 100644 --- a/aioreactive/testing/observer.py +++ b/aioreactive/testing/observer.py @@ -1,5 +1,5 @@ import logging -from typing import Awaitable, Callable, List, Tuple, TypeVar, cast +from typing import Awaitable, Callable, List, Tuple, TypeVar from aioreactive import AsyncAwaitableObserver from aioreactive.notification import Notification, OnCompleted, OnError, OnNext @@ -33,7 +33,7 @@ def __init__( ): super().__init__(asend, athrow, aclose) - self._values = cast(List[Tuple[float, Notification[TSource]]], []) # FIXME: Pylance confusion? + self._values: List[Tuple[float, Notification[TSource]]] = [] self._send = asend self._throw = athrow diff --git a/aioreactive/timeshift.py b/aioreactive/timeshift.py index e0917bd..ff48c0e 100644 --- a/aioreactive/timeshift.py +++ b/aioreactive/timeshift.py @@ -4,16 +4,7 @@ from typing import Iterable, NoReturn, Tuple, TypeVar from expression.collections import seq -from expression.core import ( - MailboxProcessor, - TailCall, - aiotools, - match, - pipe, - tailrec_async, - snd, - TailCallResult, -) +from expression.core import MailboxProcessor, TailCall, TailCallResult, aiotools, match, pipe, snd, tailrec_async from expression.system import CancellationTokenSource from .combine import with_latest_from @@ -60,14 +51,14 @@ async def loop() -> TailCallResult[None]: await asyncio.sleep(seconds) async def matcher() -> None: - with match(ns) as m: - for x in OnNext.case(m): + with match(ns) as case: + for x in case(OnNext[TSource]): await aobv.asend(x) return - for err in OnError.case(m): + for err in case(OnError[TSource]): await aobv.athrow(err) return - while m.case(OnCompleted): + for x in case(OnCompleted): await aobv.aclose() return @@ -108,19 +99,19 @@ async def worker(inbox: MailboxProcessor[Tuple[Notification[TSource], int]]) -> async def message_loop(current_index: int) -> TailCallResult[NoReturn]: n, index = await inbox.receive() - with match(n) as m: + with match(n) as case: log.debug("debounce: %s, %d, %d", n, index, current_index) - for x in OnNext.case(m): + for x in case(OnNext[TSource]): if index == current_index: await safe_obv.asend(x) current_index = index elif index > current_index: current_index = index - for err in OnError.case(m): + for err in case(OnError[TSource]): await safe_obv.athrow(err) - while m.case(OnCompleted): + while case(OnCompleted): await safe_obv.aclose() return TailCall(current_index) diff --git a/aioreactive/transform.py b/aioreactive/transform.py index d25270d..a75a602 100644 --- a/aioreactive/transform.py +++ b/aioreactive/transform.py @@ -5,9 +5,9 @@ MailboxProcessor, Nothing, Option, - TailCallResult, Some, TailCall, + TailCallResult, compose, match, pipe, @@ -243,7 +243,7 @@ def switch_latest(source: AsyncObservable[AsyncObservable[TSource]]) -> AsyncObs async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable: safe_obv, auto_detach = auto_detach_observer(aobv) - def obv(mb: MailboxProcessor[Msg], id: int): + def obv(mb: MailboxProcessor[Msg[TSource]], id: int): async def asend(value: TSource) -> None: await safe_obv.asend(value) @@ -255,38 +255,31 @@ async def aclose() -> None: return AsyncAnonymousObserver(asend, athrow, aclose) - async def worker(inbox: MailboxProcessor[Msg]) -> None: + async def worker(inbox: MailboxProcessor[Msg[TSource]]) -> None: @tailrec_async async def message_loop( current: Option[AsyncDisposable], is_stopped: bool, current_id: int ) -> TailCallResult[None]: cmd = await inbox.receive() - with match(cmd) as m: - for xs in InnerObservableMsg.case(m): - next_id = current_id + 1 - for disp in current.to_list(): - await disp.dispose_async() - inner = await xs.subscribe_async(obv(inbox, next_id)) - current, current_id = Some(inner), next_id - break - for xs in InnerObservableMsg.case(m): + with match(cmd) as case: + for xs in case(InnerObservableMsg[TSource]): next_id = current_id + 1 for disp in current.to_list(): await disp.dispose_async() inner = await xs.subscribe_async(obv(inbox, next_id)) current, current_id = Some(inner), next_id break - for idx in InnerCompletedMsg.case(m): + for idx in case(InnerCompletedMsg): if is_stopped and idx == current_id: await safe_obv.aclose() current, is_stopped = Nothing, True break - while m.case(CompletedMsg): + while case(CompletedMsg): if current.is_none(): await safe_obv.aclose() break - while m.case(DisposeMsg): + while case(DisposeMsg): if current.is_some(): await current.value.dispose_async() current, is_stopped = Nothing, True @@ -358,7 +351,7 @@ def flat_map_latest(mapper: Callable[[TSource], AsyncObservable[TResult]]) -> St return compose(map(mapper), switch_latest) -def catch(handler: Callable[[Exception], AsyncObservable[TSource]]) -> Stream[TSource, TResult]: +def catch(handler: Callable[[Exception], AsyncObservable[TSource]]) -> Stream[TSource, TSource]: """Catch Exception. Returns an observable sequence containing the first sequence's diff --git a/test/test_single.py b/test/test_single.py index 818d3f1..508f061 100644 --- a/test/test_single.py +++ b/test/test_single.py @@ -121,7 +121,6 @@ async def test_unit_future_cancel(): obv = AsyncTestObserver() async with await xs.subscribe_async(obv): await asyncio.sleep(1) - print("cancelling") fut.cancel() with pytest.raises(asyncio.CancelledError): await obv