forked from frequenz-floss/frequenz-channels-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_broadcast.py
275 lines (199 loc) · 7.78 KB
/
test_broadcast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# License: MIT
# Copyright © 2022 Frequenz Energy-as-a-Service GmbH
"""Tests for the Broadcast implementation."""
import asyncio
import pytest
from frequenz.channels import (
Broadcast,
ChannelClosedError,
Receiver,
ReceiverInvalidatedError,
ReceiverStoppedError,
Sender,
SenderError,
)
async def test_broadcast() -> None:
    """Check that every receiver gets every message from every sender.

    Several senders each push the values ``1..num_receivers`` into one
    broadcast channel that has several receivers attached.  Each receiver
    accumulates what it receives into its own slot of a tracker list, and
    the grand total is compared against the analytically expected sum.
    """
    bcast: Broadcast[int] = Broadcast("meter_5")

    num_receivers = 5
    num_senders = 5

    # One sender emits 1 + 2 + ... + num_receivers; that is repeated for
    # every sender and delivered to every receiver.
    per_sender_sum = num_receivers * (num_receivers + 1) / 2
    expected_sum = num_senders * num_receivers * per_sender_sum

    # A list of `num_receivers` running totals, where slot `i` will get
    # incremented by the values receiver `i` receives.  Once the run
    # finishes, their sum must equal `expected_sum`.
    recv_trackers = [0] * num_receivers

    async def send_msg(chan: Sender[int]) -> None:
        # Send the values 1..num_receivers, i.e. `num_receivers` messages.
        for value in range(1, num_receivers + 1):
            await chan.send(value)

    async def update_tracker_on_receive(receiver_id: int, recv: Receiver[int]) -> None:
        # Keep receiving until the channel is closed, adding every message
        # to this receiver's slot in `recv_trackers`.
        while True:
            try:
                msg = await recv.receive()
            except ReceiverStoppedError as err:
                # The stop must refer to this receiver and be caused by the
                # channel having been closed.
                assert err.receiver is recv
                assert isinstance(err.__cause__, ChannelClosedError)
                return
            else:
                recv_trackers[receiver_id] += msg

    # Receivers are created (and their coroutines handed to `gather`) before
    # any sender is created.
    receiver_coros = [
        update_tracker_on_receive(idx, bcast.new_receiver())
        for idx in range(num_receivers)
    ]
    receivers_runs = asyncio.gather(*receiver_coros)

    sender_coros = [send_msg(bcast.new_sender()) for _ in range(num_senders)]
    await asyncio.gather(*sender_coros)

    # Closing the channel is what lets the receiver loops terminate.
    await bcast.close()
    await receivers_runs

    for tracked in recv_trackers:
        # ensure all receivers have got messages
        assert tracked > 0
    assert sum(recv_trackers) == expected_sum
async def test_broadcast_none_values() -> None:
    """Check that `None` can travel through a broadcast channel as a value.

    A `None` message is sent between two integer messages and must come out
    of the receiver as `None`, in order, like any other value.
    """
    bcast: Broadcast[int | None] = Broadcast("any_channel")

    sender = bcast.new_sender()
    receiver = bcast.new_receiver()

    await sender.send(5)
    first = await receiver.receive()
    assert first == 5

    await sender.send(None)
    second = await receiver.receive()
    assert second is None

    await sender.send(10)
    third = await receiver.receive()
    assert third == 10
async def test_broadcast_after_close() -> None:
    """Check that a closed channel rejects both sending and receiving.

    After `close()`, sending must raise `SenderError`, and receiving must
    raise `ReceiverStoppedError` whose cause is a `ChannelClosedError`
    pointing at the closed channel.
    """
    bcast: Broadcast[int] = Broadcast("meter_5")

    receiver = bcast.new_receiver()
    sender = bcast.new_sender()

    await bcast.close()

    with pytest.raises(SenderError):
        await sender.send(5)

    with pytest.raises(ReceiverStoppedError) as excinfo:
        await receiver.receive()

    stop_error = excinfo.value
    assert stop_error.receiver is receiver

    cause = stop_error.__cause__
    assert isinstance(cause, ChannelClosedError)
    assert cause.channel is bcast
async def test_broadcast_overflow() -> None:
    """Check that messages sent to a full receiver are dropped.

    Two receivers with different buffer sizes listen to the same channel.
    Messages are sent in two batches of `big_recv_size`, with both receivers
    drained after each batch.  The big receiver is expected to see every
    message; the small one is expected to see only the last
    `small_recv_size` messages of each batch.
    """
    bcast: Broadcast[int] = Broadcast("meter_5")

    big_recv_size = 10
    small_recv_size = big_recv_size // 2
    sender = bcast.new_sender()

    big_receiver = bcast.new_receiver("named-recv", big_recv_size)
    small_receiver = bcast.new_receiver(None, small_recv_size)

    async def drain_receivers() -> tuple[int, int]:
        # Empty both receivers and return the sums of what each one held,
        # as `(big, small)`.
        drained_big = 0
        while len(big_receiver) > 0:
            msg = await big_receiver.receive()
            assert msg is not None
            drained_big += msg

        drained_small = 0
        while len(small_receiver) > 0:
            msg = await small_receiver.receive()
            assert msg is not None
            drained_small += msg

        return (drained_big, drained_small)

    # we send `big_recv_size` messages first, then drain the receivers, then
    # send `big_recv_size` messages again. Then get the sum.
    total_messages = 2 * big_recv_size
    big_sum = 0
    small_sum = 0
    for value in range(1, total_messages + 1):
        await sender.send(value)
        if value % big_recv_size != 0:
            continue
        # A full batch has been sent: collect what each receiver kept.
        big, small = await drain_receivers()
        big_sum += big
        small_sum += small

    # The big receiver never overflows, so it sees 1..total_messages.
    expected_big_sum = total_messages * (total_messages + 1) / 2
    assert big_sum == expected_big_sum

    # The small receiver keeps only the tail of each batch, so small_sum is
    # the sum of `small_recv_size+1 .. big_recv_size` (first batch) plus the
    # sum of `big_recv_size+small_recv_size+1 .. 2*big_recv_size` (second
    # batch).
    first_batch_tail = small_recv_size * (small_recv_size + big_recv_size + 1) / 2
    second_batch_tail = (
        small_recv_size
        * (2 * big_recv_size + (small_recv_size + big_recv_size + 1))
        / 2
    )
    assert small_sum == first_batch_tail + second_batch_tail
async def test_broadcast_resend_latest() -> None:
    """Check that `resend_latest=True` hands the latest value to new receivers.

    A receiver created after ten values (0..9) have been sent must first
    receive the most recent one (9) and then whatever is sent afterwards,
    while the pre-existing receiver still starts from the first value.
    """
    bcast: Broadcast[int] = Broadcast("new_recv_test", resend_latest=True)

    sender = bcast.new_sender()
    old_recv = bcast.new_receiver()
    for value in range(10):
        await sender.send(value)

    # Created only after the first ten values went out.
    new_recv = bcast.new_receiver()

    await sender.send(100)

    oldest_for_old = await old_recv.receive()
    assert oldest_for_old == 0

    resent = await new_recv.receive()
    assert resent == 9
    following = await new_recv.receive()
    assert following == 100
async def test_broadcast_no_resend_latest() -> None:
    """Check that `resend_latest=False` gives new receivers no old values.

    A receiver created after ten values (0..9) have been sent must see only
    the value sent after its creation, while the pre-existing receiver
    still starts from the first value.
    """
    bcast: Broadcast[int] = Broadcast("new_recv_test", resend_latest=False)

    sender = bcast.new_sender()
    old_recv = bcast.new_receiver()
    for value in range(10):
        await sender.send(value)

    # Created only after the first ten values went out.
    new_recv = bcast.new_receiver()

    await sender.send(100)

    oldest_for_old = await old_recv.receive()
    assert oldest_for_old == 0

    first_for_new = await new_recv.receive()
    assert first_for_new == 100
async def test_broadcast_peek() -> None:
    """Check peeking at the latest value of a broadcast channel.

    Once a receiver has been turned into a peekable, receiving on the
    original receiver must raise `ReceiverInvalidatedError`.  The peekable
    must report `None` before anything is sent and after the channel is
    closed, and the most recently sent value in between.
    """
    bcast: Broadcast[int] = Broadcast("peek-test")
    receiver = bcast.new_receiver()
    peekable = receiver.into_peekable()
    sender = bcast.new_sender()

    # The original receiver is no longer usable after the conversion.
    with pytest.raises(ReceiverInvalidatedError):
        await receiver.receive()

    # Nothing has been sent yet.
    assert peekable.peek() is None

    for value in range(10):
        await sender.send(value)

    # Peeking twice in a row must give the same, latest value.
    assert peekable.peek() == 9
    assert peekable.peek() == 9

    await sender.send(20)
    assert peekable.peek() == 20

    await bcast.close()
    assert peekable.peek() is None
async def test_broadcast_async_iterator() -> None:
    """Check that a broadcast receiver can be consumed with `async for`.

    A background task sends the values 0..9 and then closes the channel;
    iterating over the receiver must yield exactly those values in order
    and then terminate.
    """
    bcast: Broadcast[int] = Broadcast("iter_test")

    sender = bcast.new_sender()
    receiver = bcast.new_receiver()

    async def send_values() -> None:
        # Send 0..9, then close the channel so the iteration can finish.
        for value in range(10):
            await sender.send(value)
        await bcast.close()

    sender_task = asyncio.create_task(send_values())

    received = [item async for item in receiver]

    assert received == list(range(10))

    await sender_task
async def test_broadcast_map() -> None:
    """Check that `map` is applied to every incoming message.

    An integer receiver is mapped to a boolean receiver with a
    "greater than 10" predicate; two sent integers must arrive as the
    corresponding booleans, in order.
    """
    chan = Broadcast[int]("input-chan")
    sender = chan.new_sender()

    def is_above_ten(num: int) -> bool:
        return num > 10

    # transform int receiver into bool receiver.
    receiver: Receiver[bool] = chan.new_receiver().map(is_above_ten)

    await sender.send(8)
    await sender.send(12)

    first = await receiver.receive()
    assert first is False
    second = await receiver.receive()
    assert second is True
async def test_broadcast_receiver_drop() -> None:
    """Check that a deleted receiver is removed from the channel.

    With two receivers attached, the channel's private `_receivers`
    collection must hold two entries.  After one receiver is deleted and
    another message is sent, it must hold one.
    """
    chan = Broadcast[int]("input-chan")
    sender = chan.new_sender()

    receiver1 = chan.new_receiver()
    receiver2 = chan.new_receiver()

    await sender.send(10)

    assert await receiver1.receive() == 10
    assert await receiver2.receive() == 10

    # pylint: disable=protected-access
    assert len(chan._receivers) == 2

    del receiver2

    # The count is re-checked only after another message has been sent.
    await sender.send(20)

    assert len(chan._receivers) == 1
    # pylint: enable=protected-access