-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Expand file tree
/
Copy pathtest_base_protocol.py
More file actions
304 lines (236 loc) · 8.04 KB
/
test_base_protocol.py
File metadata and controls
304 lines (236 loc) · 8.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import asyncio
from contextlib import suppress
from unittest import mock
import pytest
from aiohttp.base_protocol import BaseProtocol
async def test_loop() -> None:
    """BaseProtocol stores exactly the loop passed to its constructor.

    The policy's current event loop is cleared before construction
    (presumably so the protocol cannot silently fall back to it); the
    protocol must still hold the passed-in loop object.
    """
    # Inside a coroutine, get_running_loop() is the documented way to get the loop.
    loop = asyncio.get_running_loop()
    asyncio.set_event_loop(None)
    pr = BaseProtocol(loop)
    assert pr._loop is loop
async def test_pause_writing() -> None:
    """pause_writing() sets both the private flag and the public property."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop)
    assert not pr._paused
    assert pr.writing_paused is False
    pr.pause_writing()
    assert pr._paused
    # mypy narrowed writing_paused to Literal[False] above, hence the ignore.
    assert pr.writing_paused is True  # type: ignore[unreachable]
async def test_pause_reading_no_transport() -> None:
    """pause_reading() without an attached transport leaves reading unpaused."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop)
    assert not pr._reading_paused
    pr.pause_reading()
    assert not pr._reading_paused
async def test_pause_reading_stub_transport() -> None:
    """pause_reading() records the paused state once a transport is attached.

    A bare ``asyncio.Transport`` base instance is used as the transport, so
    no real I/O is involved.
    """
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop)
    tr = asyncio.Transport()
    pr.transport = tr
    assert not pr._reading_paused
    pr.pause_reading()
    assert pr._reading_paused
async def test_resume_reading_no_transport() -> None:
    """resume_reading() without a transport does not clear the paused flag."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop)
    pr._reading_paused = True
    pr.resume_reading()
    assert pr._reading_paused
async def test_resume_reading_stub_transport() -> None:
    """resume_reading() clears the paused flag once a transport is attached."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop)
    tr = asyncio.Transport()
    pr.transport = tr
    pr._reading_paused = True
    pr.resume_reading()
    assert not pr._reading_paused
async def test_resume_writing_no_waiters() -> None:
    """pause_writing()/resume_writing() toggle the flag with no drain waiter."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    pr.pause_writing()
    assert pr._paused
    pr.resume_writing()
    assert not pr._paused
async def test_resume_writing_waiter_done() -> None:
    """resume_writing() only inspects an already-done drain waiter.

    The mock records every call made on it, so the exact ``mock_calls``
    comparison proves nothing but ``done()`` was invoked on the waiter.
    """
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    waiter = mock.Mock(done=mock.Mock(return_value=True))
    pr._drain_waiter = waiter
    pr._paused = True
    pr.resume_writing()
    assert not pr._paused
    assert waiter.mock_calls == [mock.call.done()]
async def test_connection_made() -> None:
    """connection_made() stores the given transport on the protocol."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    assert pr.transport is None
    pr.connection_made(tr)
    # Identity check: the protocol must keep the exact transport it was given,
    # not merely "some" non-None object.
    assert pr.transport is tr
async def test_connection_lost_not_paused() -> None:
    """connection_lost() drops the transport and the protocol reports disconnected."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    assert pr.connected
    pr.connection_lost(None)
    assert pr.transport is None
    assert not pr.connected
async def test_connection_lost_paused_without_waiter() -> None:
    """connection_lost() while writing is paused (no drain waiter) still disconnects."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    assert pr.connected
    pr.pause_writing()
    pr.connection_lost(None)
    assert pr.transport is None
    assert not pr.connected
async def test_connection_lost_waiter_done() -> None:
    """connection_lost() clears an already-done drain waiter without touching it.

    Only ``done()`` may be called on the waiter; the exact ``mock_calls``
    comparison fails if anything else (e.g. setting a result) is attempted.
    """
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    pr._paused = True
    waiter = mock.Mock(done=mock.Mock(return_value=True))
    pr._drain_waiter = waiter
    pr.connection_lost(None)
    assert pr._drain_waiter is None
    # mypy narrowed _drain_waiter to None above, hence the ignore.
    assert waiter.mock_calls == [mock.call.done()]  # type: ignore[unreachable]
async def test_drain_lost() -> None:
    """Draining after the connection is lost raises ConnectionResetError."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    pr.connection_lost(None)
    with pytest.raises(ConnectionResetError):
        await pr._drain_helper()
async def test_drain_not_paused() -> None:
    """Draining while writing is not paused returns without creating a waiter."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    assert pr._drain_waiter is None
    await pr._drain_helper()
    assert pr._drain_waiter is None
async def test_resume_drain_waited() -> None:
    """A drain blocked on paused writing completes once writing resumes."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    pr.pause_writing()

    t = loop.create_task(pr._drain_helper())
    # Yield once so the task starts and blocks on the drain waiter.
    await asyncio.sleep(0)
    assert pr._drain_waiter is not None

    pr.resume_writing()
    await t
    assert pr._drain_waiter is None
async def test_lost_drain_waited_ok() -> None:
    """A blocked drain finishes without error when the connection closes cleanly."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    pr.pause_writing()

    t = loop.create_task(pr._drain_helper())
    # Yield once so the task starts and blocks on the drain waiter.
    await asyncio.sleep(0)
    assert pr._drain_waiter is not None

    # A clean close (exc=None) must release the waiter without raising.
    pr.connection_lost(None)
    await t
    assert pr._drain_waiter is None
async def test_lost_drain_waited_exception() -> None:
    """A blocked drain raises ConnectionError chained to the connection_lost exception."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    pr.pause_writing()

    t = loop.create_task(pr._drain_helper())
    # Yield once so the task starts and blocks on the drain waiter.
    await asyncio.sleep(0)
    assert pr._drain_waiter is not None

    exc = RuntimeError()
    pr.connection_lost(exc)
    with pytest.raises(ConnectionError, match=r"^Connection lost$") as cm:
        await t
    # The original failure must be preserved as the explicit cause.
    assert cm.value.__cause__ is exc
    assert pr._drain_waiter is None
async def test_lost_drain_cancelled() -> None:
    """connection_lost() after cancelling a blocked drain clears the waiter."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    pr.pause_writing()

    # ``fut`` signals that the task has actually started running, so the
    # cancel below hits a task that is blocked inside _drain_helper().
    fut = loop.create_future()

    async def wait() -> None:
        fut.set_result(None)
        await pr._drain_helper()

    t = loop.create_task(wait())
    await fut
    t.cancel()
    assert pr._drain_waiter is not None

    pr.connection_lost(None)
    with suppress(asyncio.CancelledError):
        await t
    assert pr._drain_waiter is None
async def test_resume_drain_cancelled() -> None:
    """resume_writing() after cancelling a blocked drain clears the waiter."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    pr.pause_writing()

    # ``fut`` signals that the task has actually started running, so the
    # cancel below hits a task that is blocked inside _drain_helper().
    fut = loop.create_future()

    async def wait() -> None:
        fut.set_result(None)
        await pr._drain_helper()

    t = loop.create_task(wait())
    await fut
    t.cancel()
    assert pr._drain_waiter is not None

    pr.resume_writing()
    with suppress(asyncio.CancelledError):
        await t
    assert pr._drain_waiter is None
async def test_cancelled_drain_no_unhandled_future_warning() -> None:
    """Cancelling a task during backpressure must not leave an orphaned future.

    When the handler task is cancelled while awaiting _drain_helper and
    connection_lost fires with an exception afterward, the waiter should
    already be done (cancelled) so set_exception is skipped. No "Future
    exception was never retrieved" warning should appear.

    Regression test for https://github.com/aio-libs/aiohttp/issues/12281
    """
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    pr.pause_writing()

    # ``fut`` signals that the task has actually started running, so the
    # cancel below hits a task that is blocked inside _drain_helper().
    fut = loop.create_future()

    async def wait() -> None:
        fut.set_result(None)
        await pr._drain_helper()

    t = loop.create_task(wait())
    await fut
    t.cancel()
    with suppress(asyncio.CancelledError):
        await t

    # After cancellation the waiter should be done (cancelled), so
    # connection_lost with an exception must not call set_exception.
    assert pr._drain_waiter is not None
    waiter = pr._drain_waiter
    assert waiter.done(), "waiter must be cancelled when task is cancelled"

    # This previously left an orphaned future with an unhandled exception
    # because asyncio.shield kept the original waiter alive and uncancelled.
    exc = RuntimeError("connection died")
    pr.connection_lost(exc)
    assert pr._drain_waiter is None
    # Verify the waiter is cancelled, not set with an exception.
    assert waiter.cancelled()  # type: ignore[unreachable]
async def test_parallel_drain_race_condition() -> None:
    """Several concurrent drains block while paused and all finish on resume."""
    loop = asyncio.get_running_loop()
    pr = BaseProtocol(loop=loop)
    tr = mock.Mock()
    pr.connection_made(tr)
    pr.pause_writing()

    ts = [loop.create_task(pr._drain_helper()) for _ in range(5)]
    # asyncio.wait returns (done, pending); while writing is paused none of
    # the drains may complete within the timeout.
    done, _pending = await asyncio.wait(ts, timeout=0.5)
    assert not done, "All draining tasks must be pending"
    assert pr._drain_waiter is not None

    pr.resume_writing()
    await asyncio.gather(*ts)
    assert pr._drain_waiter is None