Skip to content

Commit f5fd6d7

Browse files
authored
Support strict_exception_groups=True (#188)
Fixes #132 and #187 * Changes `open_websocket` to only raise a single exception, even when running under `strict_exception_groups=True` * [ ] Should maybe introduce special handling for `KeyboardInterrupt`s * If multiple non-Cancelled-exceptions are encountered, then it will raise `TrioWebSocketInternalError` with the exceptiongroup as its `__cause__`. This should only be possible if the background task and the user context both raise exceptions. This would previously raise a `MultiError` with both Exceptions. * other alternatives could include throwing out the exception from the background task, raising an ExceptionGroup with both errors, or trying to do something fancy with `__cause__` or `__context__`. * `WebSocketServer.run` and `WebSocketServer._handle_connection` are the other two places that opens a nursery. I've opted not to change these, since I don't think user code should expect any special exceptions from it, and it seems less obscure that it might contain an internal nursery. * [ ] Update docstrings to mention existence of internal nursery.
1 parent 3294748 commit f5fd6d7

File tree

6 files changed

+238
-19
lines changed

6 files changed

+238
-19
lines changed

Diff for: .github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ jobs:
7474
strategy:
7575
matrix:
7676
os: [ubuntu-latest]
77-
python-version: ['3.12']
77+
python-version: ['3.13-dev']
7878
steps:
7979
- uses: actions/checkout@v3
8080
- name: Setup Python

Diff for: requirements-dev-full.txt

+2-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ build==1.2.1
2121
# via pip-tools
2222
certifi==2024.6.2
2323
# via requests
24-
cffi==1.16.0
24+
cffi==1.17.0
2525
# via cryptography
2626
charset-normalizer==3.3.2
2727
# via requests
@@ -194,9 +194,8 @@ tomli==2.0.1
194194
# pytest
195195
tomlkit==0.12.5
196196
# via pylint
197-
trio==0.24.0
197+
trio==0.25.1
198198
# via
199-
# -r requirements-dev.in
200199
# pytest-trio
201200
# trio-websocket (setup.py)
202201
trustme==1.1.0

Diff for: requirements-dev.in

-1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,4 @@ pip-tools>=5.5.0
44
pytest>=4.6
55
pytest-cov
66
pytest-trio>=0.5.0
7-
trio<0.25
87
trustme

Diff for: requirements-dev.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,8 @@ tomli==2.0.1
7171
# coverage
7272
# pip-tools
7373
# pytest
74-
trio==0.24.0
74+
trio==0.25.1
7575
# via
76-
# -r requirements-dev.in
7776
# pytest-trio
7877
# trio-websocket (setup.py)
7978
trustme==1.1.0

Diff for: tests/test_connection.py

+116-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
from __future__ import annotations
3333

3434
from functools import partial, wraps
35+
import re
3536
import ssl
37+
import sys
3638
from unittest.mock import patch
3739

3840
import attr
@@ -48,6 +50,13 @@
4850
except ImportError:
4951
from trio.hazmat import current_task # type: ignore # pylint: disable=ungrouped-imports
5052

53+
54+
# only available on trio>=0.25, we don't use it when testing lower versions
55+
try:
56+
from trio.testing import RaisesGroup
57+
except ImportError:
58+
pass
59+
5160
from trio_websocket import (
5261
connect_websocket,
5362
connect_websocket_url,
@@ -60,12 +69,18 @@
6069
open_websocket,
6170
open_websocket_url,
6271
serve_websocket,
72+
WebSocketConnection,
6373
WebSocketServer,
6474
WebSocketRequest,
6575
wrap_client_stream,
6676
wrap_server_stream
6777
)
6878

79+
from trio_websocket._impl import _TRIO_EXC_GROUP_TYPE
80+
81+
if sys.version_info < (3, 11):
82+
from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin
83+
6984
WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.')))
7085

7186
HOST = '127.0.0.1'
@@ -428,6 +443,92 @@ async def handler(request):
428443
assert header_value == b'My test header'
429444

430445

446+
447+
448+
@fail_after(5)
449+
async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock):
450+
"""_reader_task._handle_ping_event triggers KeyboardInterrupt.
451+
user code also raises exception.
452+
Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup
453+
"""
454+
async def ki_raising_ping_handler(*args, **kwargs) -> None:
455+
print("raising ki")
456+
raise KeyboardInterrupt
457+
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler)
458+
async def handler(request):
459+
server_ws = await request.accept()
460+
await server_ws.ping(b"a")
461+
462+
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
463+
with pytest.raises(KeyboardInterrupt) as exc_info:
464+
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
465+
with trio.fail_after(1) as cs:
466+
cs.shield = True
467+
await trio.sleep(2)
468+
469+
e_cause = exc_info.value.__cause__
470+
assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE)
471+
assert any(isinstance(e, trio.TooSlowError) for e in e_cause.exceptions)
472+
473+
@fail_after(5)
474+
async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock):
475+
"""_reader_task._handle_ping_event triggers ValueError.
476+
user code also raises exception.
477+
internal exception is in __cause__ exceptiongroup and user exc is delivered
478+
"""
479+
my_value_error = ValueError()
480+
async def raising_ping_event(*args, **kwargs) -> None:
481+
raise my_value_error
482+
483+
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event)
484+
async def handler(request):
485+
server_ws = await request.accept()
486+
await server_ws.ping(b"a")
487+
488+
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
489+
with pytest.raises(trio.TooSlowError) as exc_info:
490+
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
491+
with trio.fail_after(1) as cs:
492+
cs.shield = True
493+
await trio.sleep(2)
494+
495+
e_cause = exc_info.value.__cause__
496+
assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE)
497+
assert my_value_error in e_cause.exceptions
498+
499+
@fail_after(5)
500+
async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock):
501+
"""Both user code and _reader_task raise Cancellation.
502+
Check that open_websocket reraises the one from user code for traceback reasons.
503+
"""
504+
505+
506+
async def sleeping_ping_event(*args, **kwargs) -> None:
507+
await trio.sleep_forever()
508+
509+
# We monkeypatch WebSocketConnection._handle_ping_event to ensure it will actually
510+
# raise Cancelled upon being cancelled. For some reason it doesn't otherwise.
511+
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", sleeping_ping_event)
512+
async def handler(request):
513+
server_ws = await request.accept()
514+
await server_ws.ping(b"a")
515+
user_cancelled = None
516+
517+
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
518+
with trio.move_on_after(2):
519+
with pytest.raises(trio.Cancelled) as exc_info:
520+
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
521+
try:
522+
await trio.sleep_forever()
523+
except trio.Cancelled as e:
524+
user_cancelled = e
525+
raise
526+
assert exc_info.value is user_cancelled
527+
528+
def _trio_default_non_strict_exception_groups() -> bool:
529+
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
530+
return int(trio.__version__[2:4]) < 25
531+
431532
@fail_after(1)
432533
async def test_handshake_exception_before_accept() -> None:
433534
''' In #107, a request handler that throws an exception before finishing the
@@ -436,14 +537,28 @@ async def test_handshake_exception_before_accept() -> None:
436537
async def handler(request):
437538
raise ValueError()
438539

439-
with pytest.raises(ValueError):
540+
# pylint fails to resolve that BaseExceptionGroup will always be available
541+
with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment
440542
async with trio.open_nursery() as nursery:
441543
server = await nursery.start(serve_websocket, handler, HOST, 0,
442544
None)
443545
async with open_websocket(HOST, server.port, RESOURCE,
444546
use_ssl=False):
445547
pass
446548

549+
if _trio_default_non_strict_exception_groups():
550+
assert isinstance(exc.value, ValueError)
551+
else:
552+
# there's 4 levels of nurseries opened, leading to 4 nested groups:
553+
# 1. this test
554+
# 2. WebSocketServer.run
555+
# 3. trio.serve_listeners
556+
# 4. WebSocketServer._handle_connection
557+
assert RaisesGroup(
558+
RaisesGroup(
559+
RaisesGroup(
560+
RaisesGroup(ValueError)))).matches(exc.value)
561+
447562

448563
@fail_after(1)
449564
async def test_reject_handshake(nursery):

Diff for: trio_websocket/_impl.py

+118-11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import urllib.parse
1414
from typing import Iterable, List, Optional, Union
1515

16+
import outcome
1617
import trio
1718
import trio.abc
1819
from wsproto import ConnectionType, WSConnection
@@ -35,7 +36,12 @@
3536
# pylint doesn't care about the version_info check, so need to ignore the warning
3637
from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin
3738

38-
_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22)
39+
_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22)
40+
41+
if _IS_TRIO_MULTI_ERROR:
42+
_TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member
43+
else:
44+
_TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment
3945

4046
CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds
4147
MESSAGE_QUEUE_SIZE = 1
@@ -44,6 +50,13 @@
4450
logger = logging.getLogger('trio-websocket')
4551

4652

53+
class TrioWebsocketInternalError(Exception):
54+
"""Raised as a fallback when open_websocket is unable to unwind an exceptiongroup
55+
into a single preferred exception. This should never happen, if it does then
56+
underlying assumptions about the internal code are incorrect.
57+
"""
58+
59+
4760
def _ignore_cancel(exc):
4861
return None if isinstance(exc, trio.Cancelled) else exc
4962

@@ -70,7 +83,7 @@ def __exit__(self, ty, value, tb):
7083
if value is None or not self._armed:
7184
return False
7285

73-
if _TRIO_MULTI_ERROR: # pragma: no cover
86+
if _IS_TRIO_MULTI_ERROR: # pragma: no cover
7487
filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member
7588
elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment
7689
filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled))
@@ -125,10 +138,33 @@ async def open_websocket(
125138
client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
126139
or server rejection (:exc:`ConnectionRejected`) during handshakes.
127140
'''
128-
async with trio.open_nursery() as new_nursery:
141+
142+
# This context manager tries very very hard not to raise an exceptiongroup
143+
# in order to be as transparent as possible for the end user.
144+
# In the trivial case, this means that if user code inside the cm raises
145+
# we make sure that it doesn't get wrapped.
146+
147+
# If opening the connection fails, then we will raise that exception. User
148+
# code is never executed, so we will never have multiple exceptions.
149+
150+
# After opening the connection, we spawn _reader_task in the background and
151+
# yield to user code. If only one of those raise a non-cancelled exception
152+
# we will raise that non-cancelled exception.
153+
# If we get multiple cancelled, we raise the user's cancelled.
154+
# If both raise exceptions, we raise the user code's exception with the entire
155+
# exception group as the __cause__.
156+
# If we somehow get multiple exceptions, but no user exception, then we raise
157+
# TrioWebsocketInternalError.
158+
159+
# If closing the connection fails, then that will be raised as the top
160+
# exception in the last `finally`. If we encountered exceptions in user code
161+
# or in reader task then they will be set as the `__cause__`.
162+
163+
164+
async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection:
129165
try:
130166
with trio.fail_after(connect_timeout):
131-
connection = await connect_websocket(new_nursery, host, port,
167+
return await connect_websocket(nursery, host, port,
132168
resource, use_ssl=use_ssl, subprotocols=subprotocols,
133169
extra_headers=extra_headers,
134170
message_queue_size=message_queue_size,
@@ -137,14 +173,85 @@ async def open_websocket(
137173
raise ConnectionTimeout from None
138174
except OSError as e:
139175
raise HandshakeError from e
176+
177+
async def _close_connection(connection: WebSocketConnection) -> None:
140178
try:
141-
yield connection
142-
finally:
143-
try:
144-
with trio.fail_after(disconnect_timeout):
145-
await connection.aclose()
146-
except trio.TooSlowError:
147-
raise DisconnectionTimeout from None
179+
with trio.fail_after(disconnect_timeout):
180+
await connection.aclose()
181+
except trio.TooSlowError:
182+
raise DisconnectionTimeout from None
183+
184+
connection: WebSocketConnection|None=None
185+
close_result: outcome.Maybe[None] | None = None
186+
user_error = None
187+
188+
try:
189+
async with trio.open_nursery() as new_nursery:
190+
result = await outcome.acapture(_open_connection, new_nursery)
191+
192+
if isinstance(result, outcome.Value):
193+
connection = result.unwrap()
194+
try:
195+
yield connection
196+
except BaseException as e:
197+
user_error = e
198+
raise
199+
finally:
200+
close_result = await outcome.acapture(_close_connection, connection)
201+
# This exception handler should only be entered if either:
202+
# 1. The _reader_task started in connect_websocket raises
203+
# 2. User code raises an exception
204+
# I.e. open/close_connection are not included
205+
except _TRIO_EXC_GROUP_TYPE as e:
206+
# user_error, or exception bubbling up from _reader_task
207+
if len(e.exceptions) == 1:
208+
raise e.exceptions[0]
209+
210+
# contains at most 1 non-cancelled exceptions
211+
exception_to_raise: BaseException|None = None
212+
for sub_exc in e.exceptions:
213+
if not isinstance(sub_exc, trio.Cancelled):
214+
if exception_to_raise is not None:
215+
# multiple non-cancelled
216+
break
217+
exception_to_raise = sub_exc
218+
else:
219+
if exception_to_raise is None:
220+
# all exceptions are cancelled
221+
# prefer raising the one from the user, for traceback reasons
222+
if user_error is not None:
223+
# no reason to raise from e, just to include a bunch of extra
224+
# cancelleds.
225+
raise user_error # pylint: disable=raise-missing-from
226+
# multiple internal Cancelled is not possible afaik
227+
raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from
228+
raise exception_to_raise
229+
230+
# if we have any KeyboardInterrupt in the group, make sure to raise it.
231+
for sub_exc in e.exceptions:
232+
if isinstance(sub_exc, KeyboardInterrupt):
233+
raise sub_exc from e
234+
235+
# Both user code and internal code raised non-cancelled exceptions.
236+
# We "hide" the internal exception(s) in the __cause__ and surface
237+
# the user_error.
238+
if user_error is not None:
239+
raise user_error from e
240+
241+
raise TrioWebsocketInternalError(
242+
"The trio-websocket API is not expected to raise multiple exceptions. "
243+
"Please report this as a bug to "
244+
"https://github.com/python-trio/trio-websocket"
245+
) from e # pragma: no cover
246+
247+
finally:
248+
if close_result is not None:
249+
close_result.unwrap()
250+
251+
252+
# error setting up, unwrap that exception
253+
if connection is None:
254+
result.unwrap()
148255

149256

150257
async def connect_websocket(nursery, host, port, resource, *, use_ssl,

0 commit comments

Comments
 (0)