Skip to content

Commit

Permalink
Add type annotation on test_websockets.py (#1880)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Mar 5, 2023
1 parent 2a94a96 commit 4a5f3dd
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 41 deletions.
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ follow_imports = silent
# NOTE: If you are seing this, feel free to create a PR to cover the below files.
exclude = (?x)
^(
|tests/protocols/test_websocket.py
|tests/supervisors/test_reload.py
|tests/protocols/test_http.py
|tests/test_auto_detection.py
Expand Down
86 changes: 46 additions & 40 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
import websockets.client
import websockets.exceptions
from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
from websockets.typing import Subprotocol

from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol

ONLY_WEBSOCKETS_PROTOCOL = [WebSocketProtocol]
WS_PROTOCOLS = [WSProtocol, WebSocketProtocol]
except ImportError: # pragma: nocover
websockets = None
WebSocketProtocol = None
ClientPerMessageDeflateFactory = None

ONLY_WEBSOCKETPROTOCOL = [p for p in [WebSocketProtocol] if p is not None]
ONLY_WS_PROTOCOL = [p for p in [WSProtocol] if p is not None]
WS_PROTOCOLS = [p for p in [WSProtocol, WebSocketProtocol] if p is not None]
pytestmark = pytest.mark.skipif(
websockets is None, reason="This test needs the websockets module"
)
ONLY_WEBSOCKETS_PROTOCOL = []
WS_PROTOCOLS = [WSProtocol]

pytestmark = pytest.mark.skip(reason="This test needs the websockets module")

ONLY_WS_PROTOCOL = [WSProtocol]


class WebSocketResponse:
Expand Down Expand Up @@ -99,7 +99,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return websocket.open

config = Config(
Expand All @@ -126,7 +126,9 @@ async def websocket_connect(self, message):

async def open_connection(url):
extension_factories = [ClientPerMessageDeflateFactory()]
async with websockets.connect(url, extensions=extension_factories) as websocket:
async with websockets.client.connect(
url, extensions=extension_factories
) as websocket:
return [extension.name for extension in websocket.extensions]

config = Config(
Expand Down Expand Up @@ -155,7 +157,9 @@ async def open_connection(url):
# enable per-message deflate on the client, so that we can check the server
# won't support it when it's disabled.
extension_factories = [ClientPerMessageDeflateFactory()]
async with websockets.connect(url, extensions=extension_factories) as websocket:
async with websockets.client.connect(
url, extensions=extension_factories
) as websocket:
return [extension.name for extension in websocket.extensions]

config = Config(
Expand Down Expand Up @@ -183,7 +187,7 @@ async def websocket_connect(self, message):

async def open_connection(url):
try:
await websockets.connect(url)
await websockets.client.connect(url)
except websockets.exceptions.InvalidHandshake:
return False
return True # pragma: no cover
Expand Down Expand Up @@ -213,7 +217,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(
async with websockets.client.connect(
url, extra_headers=[("username", "abraão")]
) as websocket:
return websocket.open
Expand Down Expand Up @@ -241,7 +245,7 @@ async def websocket_connect(self, message):
)

async def open_connection(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return websocket.response_headers

config = Config(
Expand Down Expand Up @@ -271,7 +275,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return websocket.open

config = Config(
Expand All @@ -298,7 +302,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.send", "text": "123"})

async def get_data(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return await websocket.recv()

config = Config(
Expand All @@ -325,7 +329,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.send", "bytes": b"123"})

async def get_data(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return await websocket.recv()

config = Config(
Expand Down Expand Up @@ -353,7 +357,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.close"})

async def get_data(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
data = await websocket.recv()
is_open = True
try:
Expand Down Expand Up @@ -390,7 +394,7 @@ async def websocket_receive(self, message):
await self.send({"type": "websocket.send", "text": _text})

async def send_text(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
await websocket.send("abc")
return await websocket.recv()

Expand Down Expand Up @@ -421,7 +425,7 @@ async def websocket_receive(self, message):
await self.send({"type": "websocket.send", "bytes": _bytes})

async def send_text(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
await websocket.send(b"abc")
return await websocket.recv()

Expand Down Expand Up @@ -452,7 +456,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.send", "text": "123"})

async def get_data(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
data = await websocket.recv()
is_open = True
try:
Expand Down Expand Up @@ -484,7 +488,7 @@ async def app(app, receive, send):
pass

async def connect(url):
await websockets.connect(url)
await websockets.client.connect(url)

config = Config(
app=app,
Expand All @@ -509,7 +513,7 @@ async def app(scope, receive, send):
await send({"type": "websocket.send", "text": "123"})

async def connect(url):
await websockets.connect(url)
await websockets.client.connect(url)

config = Config(
app=app,
Expand All @@ -535,7 +539,7 @@ async def app(scope, receive, send):
await send({"type": "websocket.accept"})

async def connect(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
_ = await websocket.recv()

config = Config(
Expand Down Expand Up @@ -567,7 +571,7 @@ async def app(scope, receive, send):
return 123

async def connect(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
_ = await websocket.recv()

config = Config(
Expand Down Expand Up @@ -614,7 +618,7 @@ async def app(scope, receive, send):
break

async def websocket_session(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
await websocket.ping()
await websocket.send("abc")
await websocket.recv()
Expand Down Expand Up @@ -648,7 +652,7 @@ async def app(scope, receive, send):
break

async def websocket_session(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
await websocket.ping()
await websocket.send("abc")

Expand Down Expand Up @@ -814,8 +818,8 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.accept", "subprotocol": subprotocol})

async def get_subprotocol(url):
async with websockets.connect(
url, subprotocols=["proto1", "proto2"]
async with websockets.client.connect(
url, subprotocols=[Subprotocol("proto1"), Subprotocol("proto2")]
) as websocket:
return websocket.subprotocol

Expand All @@ -838,7 +842,7 @@ async def get_subprotocol(url):


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", ONLY_WEBSOCKETPROTOCOL)
@pytest.mark.parametrize("ws_protocol_cls", ONLY_WEBSOCKETS_PROTOCOL)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
@pytest.mark.parametrize(
"client_size_sent, server_size_max, expected_result",
Expand Down Expand Up @@ -872,7 +876,9 @@ async def websocket_receive(self, message):
await self.send({"type": "websocket.send", "bytes": _bytes})

async def send_text(url):
async with websockets.connect(url, max_size=client_size_sent) as websocket:
async with websockets.client.connect(
url, max_size=client_size_sent
) as websocket:
await websocket.send(b"\x01" * client_size_sent)
return await websocket.recv()

Expand All @@ -889,7 +895,7 @@ async def send_text(url):
data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}")
assert data == b"\x01" * client_size_sent
else:
with pytest.raises(websockets.ConnectionClosedError) as e:
with pytest.raises(websockets.exceptions.ConnectionClosedError) as e:
data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}")
assert e.value.code == expected_result

Expand Down Expand Up @@ -918,7 +924,7 @@ async def app(scope, receive, send):

async def websocket_session(url):
try:
async with websockets.connect(url):
async with websockets.client.connect(url):
pass # pragma: no cover
except Exception:
pass
Expand Down Expand Up @@ -959,7 +965,7 @@ async def websocket_receive(self, message):
frames.append(message.get("bytes"))

async def send_text(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
await websocket.send(b"abc")
await websocket.send(b"abc")
await websocket.send(b"abc")
Expand Down Expand Up @@ -989,7 +995,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return websocket.response_headers

config = Config(
Expand All @@ -1015,7 +1021,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return websocket.response_headers

config = Config(
Expand All @@ -1040,7 +1046,7 @@ async def websocket_connect(self, message):
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return websocket.response_headers

config = Config(
Expand Down Expand Up @@ -1075,7 +1081,7 @@ async def websocket_connect(self, message):
)

async def open_connection(url):
async with websockets.connect(url) as websocket:
async with websockets.client.connect(url) as websocket:
return websocket.response_headers

config = Config(
Expand Down

0 comments on commit 4a5f3dd

Please sign in to comment.