diff --git a/CHANGES/10544.bugfix.rst b/CHANGES/10544.bugfix.rst new file mode 100644 index 00000000000..2ccf54fc0ab --- /dev/null +++ b/CHANGES/10544.bugfix.rst @@ -0,0 +1 @@ +Check for PONG before closing connection for missing PONG -- by :user:`mstegmaier`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 3004ee5cd18..c759e39eb1b 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -256,6 +256,7 @@ Matvey Tingaev Meet Mangukiya Meshya Michael Ihnatenko +Michael Stegmaier Michał Górny Mikhail Burshteyn Mikhail Kashkin diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 78c130179f5..d3d4401f1e1 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -81,6 +81,7 @@ class WebSocketResponse(StreamResponse): _heartbeat_cb: Optional[asyncio.TimerHandle] = None _pong_response_cb: Optional[asyncio.TimerHandle] = None _ping_task: Optional[asyncio.Task[None]] = None + _pong_not_received_task: Optional[asyncio.Task[None]] = None def __init__( self, @@ -185,12 +186,60 @@ def _ping_task_done(self, task: "asyncio.Task[None]") -> None: self._ping_task = None def _pong_not_received(self) -> None: + """Callback for when no PONG was received after self._pong_heartbeat seconds""" if self._req is not None and self._req.transport is not None: - self._handle_ping_pong_exception( - asyncio.TimeoutError( - f"No PONG received after {self._pong_heartbeat} seconds" + loop = self._loop + if ( + loop is not None and not self._waiting + ): # If self._waiting is set we already are in the receive loop and would have read the PONG if one was there + pong_not_received_task = loop.create_task( + self._pong_not_received_coro() ) + if not pong_not_received_task.done(): + self._pong_not_received_task = pong_not_received_task + pong_not_received_task.add_done_callback( + self._pong_not_received_done + ) + else: + self._pong_not_received_done(pong_not_received_task) + else: + self._handle_ping_pong_exception( + asyncio.TimeoutError( + f"No PONG received after {self._pong_heartbeat} seconds" + ) + ) + + async def _pong_not_received_coro(self) -> None: + """Coroutine to check for pending PONG when no PONG was received after self._pong_heartbeat seconds""" + reader = self._reader + assert reader is not None + try: + async with async_timeout.timeout(self._pong_heartbeat / 10.0): + msg = await reader.read() + self._reset_heartbeat() + if msg.type is not WSMsgType.PONG: + ws_logger.warning( + f"Received {msg} while waiting for PONG. It seems like you haven't called `receive` within {self._pong_heartbeat} seconds." + ) + return + except asyncio.TimeoutError: # We still did not receive a PONG + pass + except Exception as exc: + self._exception = exc + self._set_closing(WSCloseCode.ABNORMAL_CLOSURE) + await self.close() + return + self._handle_ping_pong_exception( + asyncio.TimeoutError( + f"No PONG received after {self._pong_heartbeat} seconds" ) + ) + + def _pong_not_received_done(self, task: "asyncio.Task[None]") -> None: + """Callback for when the pong not received task completes.""" + if not task.cancelled() and (exc := task.exception()): + self._handle_ping_pong_exception(exc) + self._pong_not_received_task = None def _handle_ping_pong_exception(self, exc: BaseException) -> None: """Handle exceptions raised during ping/pong processing."""