13
13
import urllib .parse
14
14
from typing import Iterable , List , Optional , Union
15
15
16
+ import outcome
16
17
import trio
17
18
import trio .abc
18
19
from wsproto import ConnectionType , WSConnection
35
36
# pylint doesn't care about the version_info check, so need to ignore the warning
36
37
from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin
37
38
38
- _TRIO_MULTI_ERROR = tuple (map (int , trio .__version__ .split ('.' )[:2 ])) < (0 , 22 )
39
+ _IS_TRIO_MULTI_ERROR = tuple (map (int , trio .__version__ .split ('.' )[:2 ])) < (0 , 22 )
40
+
41
+ if _IS_TRIO_MULTI_ERROR :
42
+ _TRIO_EXC_GROUP_TYPE = trio .MultiError # type: ignore[attr-defined] # pylint: disable=no-member
43
+ else :
44
+ _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment
39
45
40
46
CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds
41
47
MESSAGE_QUEUE_SIZE = 1
44
50
logger = logging .getLogger ('trio-websocket' )
45
51
46
52
53
+ class TrioWebsocketInternalError (Exception ):
54
+ """Raised as a fallback when open_websocket is unable to unwind an exceptiongroup
55
+ into a single preferred exception. This should never happen, if it does then
56
+ underlying assumptions about the internal code are incorrect.
57
+ """
58
+
59
+
47
60
def _ignore_cancel (exc ):
48
61
return None if isinstance (exc , trio .Cancelled ) else exc
49
62
@@ -70,7 +83,7 @@ def __exit__(self, ty, value, tb):
70
83
if value is None or not self ._armed :
71
84
return False
72
85
73
- if _TRIO_MULTI_ERROR : # pragma: no cover
86
+ if _IS_TRIO_MULTI_ERROR : # pragma: no cover
74
87
filtered_exception = trio .MultiError .filter (_ignore_cancel , value ) # pylint: disable=no-member
75
88
elif isinstance (value , BaseExceptionGroup ): # pylint: disable=possibly-used-before-assignment
76
89
filtered_exception = value .subgroup (lambda exc : not isinstance (exc , trio .Cancelled ))
@@ -125,10 +138,33 @@ async def open_websocket(
125
138
client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
126
139
or server rejection (:exc:`ConnectionRejected`) during handshakes.
127
140
'''
128
- async with trio .open_nursery () as new_nursery :
141
+
142
+ # This context manager tries very very hard not to raise an exceptiongroup
143
+ # in order to be as transparent as possible for the end user.
144
+ # In the trivial case, this means that if user code inside the cm raises
145
+ # we make sure that it doesn't get wrapped.
146
+
147
+ # If opening the connection fails, then we will raise that exception. User
148
+ # code is never executed, so we will never have multiple exceptions.
149
+
150
+ # After opening the connection, we spawn _reader_task in the background and
151
+ # yield to user code. If only one of those raise a non-cancelled exception
152
+ # we will raise that non-cancelled exception.
153
+ # If we get multiple cancelled, we raise the user's cancelled.
154
+ # If both raise exceptions, we raise the user code's exception with the entire
155
+ # exception group as the __cause__.
156
+ # If we somehow get multiple exceptions, but no user exception, then we raise
157
+ # TrioWebsocketInternalError.
158
+
159
+ # If closing the connection fails, then that will be raised as the top
160
+ # exception in the last `finally`. If we encountered exceptions in user code
161
+ # or in reader task then they will be set as the `__cause__`.
162
+
163
+
164
+ async def _open_connection (nursery : trio .Nursery ) -> WebSocketConnection :
129
165
try :
130
166
with trio .fail_after (connect_timeout ):
131
- connection = await connect_websocket (new_nursery , host , port ,
167
+ return await connect_websocket (nursery , host , port ,
132
168
resource , use_ssl = use_ssl , subprotocols = subprotocols ,
133
169
extra_headers = extra_headers ,
134
170
message_queue_size = message_queue_size ,
@@ -137,14 +173,85 @@ async def open_websocket(
137
173
raise ConnectionTimeout from None
138
174
except OSError as e :
139
175
raise HandshakeError from e
176
+
177
+ async def _close_connection (connection : WebSocketConnection ) -> None :
140
178
try :
141
- yield connection
142
- finally :
143
- try :
144
- with trio .fail_after (disconnect_timeout ):
145
- await connection .aclose ()
146
- except trio .TooSlowError :
147
- raise DisconnectionTimeout from None
179
+ with trio .fail_after (disconnect_timeout ):
180
+ await connection .aclose ()
181
+ except trio .TooSlowError :
182
+ raise DisconnectionTimeout from None
183
+
184
+ connection : WebSocketConnection | None = None
185
+ close_result : outcome .Maybe [None ] | None = None
186
+ user_error = None
187
+
188
+ try :
189
+ async with trio .open_nursery () as new_nursery :
190
+ result = await outcome .acapture (_open_connection , new_nursery )
191
+
192
+ if isinstance (result , outcome .Value ):
193
+ connection = result .unwrap ()
194
+ try :
195
+ yield connection
196
+ except BaseException as e :
197
+ user_error = e
198
+ raise
199
+ finally :
200
+ close_result = await outcome .acapture (_close_connection , connection )
201
+ # This exception handler should only be entered if either:
202
+ # 1. The _reader_task started in connect_websocket raises
203
+ # 2. User code raises an exception
204
+ # I.e. open/close_connection are not included
205
+ except _TRIO_EXC_GROUP_TYPE as e :
206
+ # user_error, or exception bubbling up from _reader_task
207
+ if len (e .exceptions ) == 1 :
208
+ raise e .exceptions [0 ]
209
+
210
+ # contains at most 1 non-cancelled exceptions
211
+ exception_to_raise : BaseException | None = None
212
+ for sub_exc in e .exceptions :
213
+ if not isinstance (sub_exc , trio .Cancelled ):
214
+ if exception_to_raise is not None :
215
+ # multiple non-cancelled
216
+ break
217
+ exception_to_raise = sub_exc
218
+ else :
219
+ if exception_to_raise is None :
220
+ # all exceptions are cancelled
221
+ # prefer raising the one from the user, for traceback reasons
222
+ if user_error is not None :
223
+ # no reason to raise from e, just to include a bunch of extra
224
+ # cancelleds.
225
+ raise user_error # pylint: disable=raise-missing-from
226
+ # multiple internal Cancelled is not possible afaik
227
+ raise e .exceptions [0 ] # pragma: no cover # pylint: disable=raise-missing-from
228
+ raise exception_to_raise
229
+
230
+ # if we have any KeyboardInterrupt in the group, make sure to raise it.
231
+ for sub_exc in e .exceptions :
232
+ if isinstance (sub_exc , KeyboardInterrupt ):
233
+ raise sub_exc from e
234
+
235
+ # Both user code and internal code raised non-cancelled exceptions.
236
+ # We "hide" the internal exception(s) in the __cause__ and surface
237
+ # the user_error.
238
+ if user_error is not None :
239
+ raise user_error from e
240
+
241
+ raise TrioWebsocketInternalError (
242
+ "The trio-websocket API is not expected to raise multiple exceptions. "
243
+ "Please report this as a bug to "
244
+ "https://github.com/python-trio/trio-websocket"
245
+ ) from e # pragma: no cover
246
+
247
+ finally :
248
+ if close_result is not None :
249
+ close_result .unwrap ()
250
+
251
+
252
+ # error setting up, unwrap that exception
253
+ if connection is None :
254
+ result .unwrap ()
148
255
149
256
150
257
async def connect_websocket (nursery , host , port , resource , * , use_ssl ,
0 commit comments