Skip to content

Commit

Permalink
Merge pull request #200 from aiokitchen/featuire/fix-contextvars
Browse files Browse the repository at this point in the history
Featuire/fix contextvars
  • Loading branch information
mosquito authored Mar 7, 2024
2 parents 12bfae8 + 16bae46 commit b4acf57
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 14 deletions.
8 changes: 4 additions & 4 deletions aiomisc/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import inspect
import logging
from asyncio import CancelledError, Event, Future, Lock, wait_for
from dataclasses import dataclass
from inspect import Parameter
from typing import (
Any, Awaitable, Callable, Iterable, List, NamedTuple, Optional, Union,
)
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Union

from .compat import EventLoopMixin
from .counters import Statistic
Expand All @@ -14,7 +13,8 @@
log = logging.getLogger(__name__)


class Arg(NamedTuple):
@dataclass(frozen=True)
class Arg:
value: Any
future: Future

Expand Down
12 changes: 10 additions & 2 deletions aiomisc/counters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import Counter
from dataclasses import dataclass
from typing import (
Any, Dict, FrozenSet, Generator, MutableMapping, MutableSet, NamedTuple,
Any, Dict, FrozenSet, Generator, Iterator, MutableMapping, MutableSet,
Optional, Set, Tuple, Type, Union,
)
from weakref import WeakSet
Expand Down Expand Up @@ -107,12 +108,19 @@ def __init__(self, name: Optional[str] = None) -> None:
self.__instances__.add(self)


class StatisticResult(NamedTuple):
@dataclass(frozen=True)
class StatisticResult:
kind: Type[AbstractStatistic]
name: Optional[str]
metric: str
value: Union[int, float]

def __iter__(self) -> Iterator:
yield self.kind
yield self.name
yield self.metric
yield self.value


# noinspection PyProtectedMember
def get_statistics(
Expand Down
6 changes: 4 additions & 2 deletions aiomisc/service/cron.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import logging
from asyncio import iscoroutinefunction
from typing import Any, Callable, NamedTuple, Optional, Set, Tuple, Type
from dataclasses import dataclass
from typing import Any, Callable, Optional, Set, Tuple, Type

from croniter import croniter

Expand All @@ -13,7 +14,8 @@
ExceptionsType = Tuple[Type[Exception], ...]


class StoreItem(NamedTuple):
@dataclass(frozen=True)
class StoreItem:
callback: CronCallback
spec: str
shield: bool
Expand Down
21 changes: 15 additions & 6 deletions aiomisc/thread_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import time
import warnings
from concurrent.futures import ThreadPoolExecutor as ThreadPoolExecutorBase
from dataclasses import dataclass
from functools import partial, wraps
from multiprocessing import cpu_count
from queue import SimpleQueue
from types import MappingProxyType
from typing import (
Any, Awaitable, Callable, Coroutine, Dict, FrozenSet, NamedTuple, Optional,
Set, Tuple, TypeVar,
Any, Awaitable, Callable, Coroutine, Dict, FrozenSet, Optional, Set, Tuple,
TypeVar,
)

from ._context_vars import EVENT_LOOP
Expand All @@ -31,6 +32,10 @@ def context_partial(
func: F, *args: Any,
**kwargs: Any,
) -> Any:
warnings.warn(
"context_partial has been deprecated and will be removed",
DeprecationWarning,
)
context = contextvars.copy_context()
return partial(context.run, func, *args, **kwargs)

Expand All @@ -39,12 +44,14 @@ class ThreadPoolException(RuntimeError):
pass


class WorkItemBase(NamedTuple):
@dataclass(frozen=True)
class WorkItemBase:
func: Callable[..., Any]
args: Tuple[Any, ...]
kwargs: Dict[str, Any]
future: asyncio.Future
loop: asyncio.AbstractEventLoop
context: contextvars.Context


class ThreadPoolStatistic(Statistic):
Expand Down Expand Up @@ -74,10 +81,11 @@ def __call__(self, statistic: ThreadPoolStatistic) -> None:
return

result, exception = None, None

delta = -time.monotonic()
try:
result = self.func(*self.args, **self.kwargs)
result = self.context.run(
self.func, *self.args, **self.kwargs,
)
statistic.success += 1
except BaseException as e:
statistic.error += 1
Expand Down Expand Up @@ -193,6 +201,7 @@ def submit( # type: ignore
args=args,
kwargs=kwargs,
future=future,
context=contextvars.copy_context(),
loop=loop,
),
)
Expand Down Expand Up @@ -230,7 +239,7 @@ def run_in_executor(
try:
loop = asyncio.get_running_loop()
return loop.run_in_executor(
executor, context_partial(func, *args, **kwargs),
executor, partial(func, *args, **kwargs),
)
except RuntimeError:
# In case the event loop is not running right now is
Expand Down

0 comments on commit b4acf57

Please sign in to comment.