Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Typing for unittest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Apr 1, 2022
1 parent 815b94f commit d1dadf9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ exclude = (?x)
|tests/test_server.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/unittest.py
|tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py
|tests/util/caches/test_descriptors.py
Expand Down
59 changes: 40 additions & 19 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from typing import (
Any,
AnyStr,
Awaitable,
Callable,
ClassVar,
Dict,
Generic,
Iterable,
List,
Optional,
Expand All @@ -39,6 +41,7 @@
import canonicaljson
import signedjson.key
import unpaddedbase64
from typing_extensions import Protocol

from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
Expand Down Expand Up @@ -84,6 +87,17 @@
setupdb()
setup_logging()

TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)


class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type."""

@property
def value(self) -> _ExcType:
...


def around(target):
"""A CLOS-style 'around' modifier, which wraps the original method of the
Expand Down Expand Up @@ -520,30 +534,36 @@ async def run_bg_updates():

return hs

def pump(self, by=0.0):
def pump(self, by: float = 0.0) -> None:
"""
Pump the reactor enough that Deferreds will fire.
"""
self.reactor.pump([by] * 100)

def get_success(self, d, by=0.0):
deferred: Deferred[TV] = ensureDeferred(d)
def get_success(
self,
d: Awaitable[TV],
by: float = 0.0,
) -> TV:
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by)
return self.successResultOf(deferred)

def get_failure(self, d, exc):
def get_failure(
self, d: Awaitable[Any], exc: Type[_ExcType]
) -> _TypedFailure[_ExcType]:
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
deferred: Deferred[Any] = ensureDeferred(d)
deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
self.pump()
return self.failureResultOf(deferred, exc)

def get_success_or_raise(self, d, by=0.0):
def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Drive deferred to completion and return result or raise exception
on failure.
"""
deferred: Deferred[TV] = ensureDeferred(d)
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]

results: list = []
deferred.addBoth(results.append)
Expand Down Expand Up @@ -651,11 +671,11 @@ def register_appservice_user(

def login(
self,
username,
password,
device_id=None,
username: str,
password: str,
device_id: Optional[str] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
) -> str:
"""
Log in a user, and get an access token. Requires the Login API be
registered.
Expand All @@ -677,18 +697,22 @@ def login(
return access_token

def create_and_send_event(
self, room_id, user, soft_failed=False, prev_event_ids=None
):
self,
room_id: str,
user: UserID,
soft_failed: bool = False,
prev_event_ids: Optional[List[str]] = None,
) -> str:
"""
Create and send an event.
Args:
soft_failed (bool): Whether to create a soft failed event or not
prev_event_ids (list[str]|None): Explicitly set the prev events,
soft_failed: Whether to create a soft failed event or not
prev_event_ids: Explicitly set the prev events,
or if None just use the default
Returns:
str: The new event's ID.
The new event's ID.
"""
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
Expand Down Expand Up @@ -887,9 +911,6 @@ def decorator(func):
return decorator


TV = TypeVar("TV")


def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
"""A test decorator which will skip the decorated test unless a condition is set
Expand Down

0 comments on commit d1dadf9

Please sign in to comment.