Skip to content

Commit 772dfb8

Browse files
davidbrochartpre-commit-ci[bot]blink1073
authored
Replace Tornado with AnyIO (#1079)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Silvester <[email protected]> Co-authored-by: Steven Silvester <[email protected]>
1 parent 830829f commit 772dfb8

30 files changed

+878
-831
lines changed

Diff for: .github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ jobs:
8484
- name: Run Linters
8585
run: |
8686
hatch run typing:test
87+
pipx run interrogate -vv . --fail-under 90
8788
hatch run lint:build
88-
pipx run interrogate -vv .
8989
pipx run doc8 --max-line-length=200
9090
9191
check_release:

Diff for: docs/api/ipykernel.inprocess.rst

+6
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ Submodules
4141
:show-inheritance:
4242

4343

44+
.. automodule:: ipykernel.inprocess.session
45+
:members:
46+
:undoc-members:
47+
:show-inheritance:
48+
49+
4450
.. automodule:: ipykernel.inprocess.socket
4551
:members:
4652
:undoc-members:

Diff for: examples/embedding/inprocess_terminal.py

+4-36
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""An in-process terminal example."""
22
import os
3-
import sys
43

5-
import tornado
4+
from anyio import run
65
from jupyter_console.ptshell import ZMQTerminalInteractiveShell
76

87
from ipykernel.inprocess.manager import InProcessKernelManager
@@ -13,46 +12,15 @@ def print_process_id():
1312
print("Process ID is:", os.getpid())
1413

1514

16-
def init_asyncio_patch():
17-
"""set default asyncio policy to be compatible with tornado
18-
Tornado 6 (at least) is not compatible with the default
19-
asyncio implementation on Windows
20-
Pick the older SelectorEventLoopPolicy on Windows
21-
if the known-incompatible default policy is in use.
22-
do this as early as possible to make it a low priority and overridable
23-
ref: https://github.com/tornadoweb/tornado/issues/2608
24-
FIXME: if/when tornado supports the defaults in asyncio,
25-
remove and bump tornado requirement for py38
26-
"""
27-
if (
28-
sys.platform.startswith("win")
29-
and sys.version_info >= (3, 8)
30-
and tornado.version_info < (6, 1)
31-
):
32-
import asyncio
33-
34-
try:
35-
from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
36-
except ImportError:
37-
pass
38-
# not affected
39-
else:
40-
if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
41-
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
42-
# fallback to the pre-3.8 default of Selector
43-
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
44-
45-
46-
def main():
15+
async def main():
4716
"""The main function."""
4817
print_process_id()
4918

5019
# Create an in-process kernel
5120
# >>> print_process_id()
5221
# will print the same process ID as the main process
53-
init_asyncio_patch()
5422
kernel_manager = InProcessKernelManager()
55-
kernel_manager.start_kernel()
23+
await kernel_manager.start_kernel()
5624
kernel = kernel_manager.kernel
5725
kernel.gui = "qt4"
5826
kernel.shell.push({"foo": 43, "print_process_id": print_process_id})
@@ -64,4 +32,4 @@ def main():
6432

6533

6634
if __name__ == "__main__":
67-
main()
35+
run(main)

Diff for: ipykernel/control.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""A thread for a control channel."""
2-
from threading import Thread
2+
from threading import Event, Thread
33

4-
from tornado.ioloop import IOLoop
4+
from anyio import create_task_group, run, to_thread
55

66
CONTROL_THREAD_NAME = "Control"
77

@@ -12,21 +12,29 @@ class ControlThread(Thread):
1212
def __init__(self, **kwargs):
1313
"""Initialize the thread."""
1414
Thread.__init__(self, name=CONTROL_THREAD_NAME, **kwargs)
15-
self.io_loop = IOLoop(make_current=False)
1615
self.pydev_do_not_trace = True
1716
self.is_pydev_daemon_thread = True
17+
self.__stop = Event()
18+
self._task = None
19+
20+
def set_task(self, task):
21+
self._task = task
1822

1923
def run(self):
2024
"""Run the thread."""
2125
self.name = CONTROL_THREAD_NAME
22-
try:
23-
self.io_loop.start()
24-
finally:
25-
self.io_loop.close()
26+
run(self._main)
27+
28+
async def _main(self):
29+
async with create_task_group() as tg:
30+
if self._task is not None:
31+
tg.start_soon(self._task)
32+
await to_thread.run_sync(self.__stop.wait)
33+
tg.cancel_scope.cancel()
2634

2735
def stop(self):
2836
"""Stop the thread.
2937
3038
This method is threadsafe.
3139
"""
32-
self.io_loop.add_callback(self.io_loop.stop)
40+
self.__stop.set()

Diff for: ipykernel/debugger.py

+40-26
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import re
44
import sys
55
import typing as t
6+
from math import inf
67
from pathlib import Path
78

89
import zmq
10+
from anyio import Event, create_memory_object_stream
911
from IPython.core.getipython import get_ipython
1012
from IPython.core.inputtransformer2 import leading_empty_lines
11-
from tornado.locks import Event
12-
from tornado.queues import Queue
1313
from zmq.utils import jsonapi
1414

1515
try:
@@ -117,7 +117,9 @@ def __init__(self, event_callback, log):
117117
self.tcp_buffer = ""
118118
self._reset_tcp_pos()
119119
self.event_callback = event_callback
120-
self.message_queue: Queue[t.Any] = Queue()
120+
self.message_send_stream, self.message_receive_stream = create_memory_object_stream[dict](
121+
max_buffer_size=inf
122+
)
121123
self.log = log
122124

123125
def _reset_tcp_pos(self):
@@ -136,7 +138,7 @@ def _put_message(self, raw_msg):
136138
else:
137139
self.log.debug("QUEUE - put message:")
138140
self.log.debug(msg)
139-
self.message_queue.put_nowait(msg)
141+
self.message_send_stream.send_nowait(msg)
140142

141143
def put_tcp_frame(self, frame):
142144
"""Put a tcp frame in the queue."""
@@ -187,25 +189,31 @@ def put_tcp_frame(self, frame):
187189

188190
async def get_message(self):
189191
"""Get a message from the queue."""
190-
return await self.message_queue.get()
192+
return await self.message_receive_stream.receive()
191193

192194

193195
class DebugpyClient:
194196
"""A client for debugpy."""
195197

196-
def __init__(self, log, debugpy_stream, event_callback):
198+
def __init__(self, log, debugpy_socket, event_callback):
197199
"""Initialize the client."""
198200
self.log = log
199-
self.debugpy_stream = debugpy_stream
201+
self.debugpy_socket = debugpy_socket
200202
self.event_callback = event_callback
201203
self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
202204
self.debugpy_host = "127.0.0.1"
203205
self.debugpy_port = -1
204206
self.routing_id = None
205207
self.wait_for_attach = True
206-
self.init_event = Event()
208+
self._init_event = None
207209
self.init_event_seq = -1
208210

211+
@property
212+
def init_event(self):
213+
if self._init_event is None:
214+
self._init_event = Event()
215+
return self._init_event
216+
209217
def _get_endpoint(self):
210218
host, port = self.get_host_port()
211219
return "tcp://" + host + ":" + str(port)
@@ -216,9 +224,9 @@ def _forward_event(self, msg):
216224
self.init_event_seq = msg["seq"]
217225
self.event_callback(msg)
218226

219-
def _send_request(self, msg):
227+
async def _send_request(self, msg):
220228
if self.routing_id is None:
221-
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
229+
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
222230
content = jsonapi.dumps(
223231
msg,
224232
default=json_default,
@@ -233,7 +241,7 @@ def _send_request(self, msg):
233241
self.log.debug("DEBUGPYCLIENT:")
234242
self.log.debug(self.routing_id)
235243
self.log.debug(buf)
236-
self.debugpy_stream.send_multipart((self.routing_id, buf))
244+
await self.debugpy_socket.send_multipart((self.routing_id, buf))
237245

238246
async def _wait_for_response(self):
239247
# Since events are never pushed to the message_queue
@@ -251,7 +259,7 @@ async def _handle_init_sequence(self):
251259
"seq": int(self.init_event_seq) + 1,
252260
"command": "configurationDone",
253261
}
254-
self._send_request(configurationDone)
262+
await self._send_request(configurationDone)
255263

256264
# 3] Waits for configurationDone response
257265
await self._wait_for_response()
@@ -262,7 +270,7 @@ async def _handle_init_sequence(self):
262270
def get_host_port(self):
263271
"""Get the host debugpy port."""
264272
if self.debugpy_port == -1:
265-
socket = self.debugpy_stream.socket
273+
socket = self.debugpy_socket
266274
socket.bind_to_random_port("tcp://" + self.debugpy_host)
267275
self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8")
268276
socket.unbind(self.endpoint)
@@ -272,14 +280,13 @@ def get_host_port(self):
272280

273281
def connect_tcp_socket(self):
274282
"""Connect to the tcp socket."""
275-
self.debugpy_stream.socket.connect(self._get_endpoint())
276-
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
283+
self.debugpy_socket.connect(self._get_endpoint())
284+
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
277285

278286
def disconnect_tcp_socket(self):
279287
"""Disconnect from the tcp socket."""
280-
self.debugpy_stream.socket.disconnect(self._get_endpoint())
288+
self.debugpy_socket.disconnect(self._get_endpoint())
281289
self.routing_id = None
282-
self.init_event = Event()
283290
self.init_event_seq = -1
284291
self.wait_for_attach = True
285292

@@ -289,7 +296,7 @@ def receive_dap_frame(self, frame):
289296

290297
async def send_dap_request(self, msg):
291298
"""Send a dap request."""
292-
self._send_request(msg)
299+
await self._send_request(msg)
293300
if self.wait_for_attach and msg["command"] == "attach":
294301
rep = await self._handle_init_sequence()
295302
self.wait_for_attach = False
@@ -325,17 +332,19 @@ class Debugger:
325332
]
326333

327334
def __init__(
328-
self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True
335+
self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True
329336
):
330337
"""Initialize the debugger."""
331338
self.log = log
332-
self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event)
339+
self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event)
333340
self.shell_socket = shell_socket
334341
self.session = session
335342
self.is_started = False
336343
self.event_callback = event_callback
337344
self.just_my_code = just_my_code
338-
self.stopped_queue: Queue[t.Any] = Queue()
345+
self.stopped_send_stream, self.stopped_receive_stream = create_memory_object_stream[dict](
346+
max_buffer_size=inf
347+
)
339348

340349
self.started_debug_handlers = {}
341350
for msg_type in Debugger.started_debug_msg_types:
@@ -360,7 +369,7 @@ def __init__(
360369
def _handle_event(self, msg):
361370
if msg["event"] == "stopped":
362371
if msg["body"]["allThreadsStopped"]:
363-
self.stopped_queue.put_nowait(msg)
372+
self.stopped_send_stream.send_nowait(msg)
364373
# Do not forward the event now, will be done in the handle_stopped_event
365374
return
366375
self.stopped_threads.add(msg["body"]["threadId"])
@@ -398,7 +407,7 @@ async def handle_stopped_event(self):
398407
"""Handle a stopped event."""
399408
# Wait for a stopped event message in the stopped queue
400409
# This message is used for triggering the 'threads' request
401-
event = await self.stopped_queue.get()
410+
event = await self.stopped_receive_stream.receive()
402411
req = {"seq": event["seq"] + 1, "type": "request", "command": "threads"}
403412
rep = await self._forward_message(req)
404413
for thread in rep["body"]["threads"]:
@@ -410,7 +419,7 @@ async def handle_stopped_event(self):
410419
def tcp_client(self):
411420
return self.debugpy_client
412421

413-
def start(self):
422+
async def start(self):
414423
"""Start the debugger."""
415424
if not self.debugpy_initialized:
416425
tmp_dir = get_tmp_directory()
@@ -428,7 +437,12 @@ def start(self):
428437
(self.shell_socket.getsockopt(ROUTING_ID)),
429438
)
430439

431-
ident, msg = self.session.recv(self.shell_socket, mode=0)
440+
msg = await self.shell_socket.recv_multipart()
441+
ident, msg = self.session.feed_identities(msg, copy=True)
442+
try:
443+
msg = self.session.deserialize(msg, content=True, copy=True)
444+
except Exception:
445+
self.log.error("Invalid message", exc_info=True) # noqa: G201
432446
self.debugpy_initialized = msg["content"]["status"] == "ok"
433447

434448
# Don't remove leading empty lines when debugging so the breakpoints are correctly positioned
@@ -714,7 +728,7 @@ async def process_request(self, message):
714728
if self.is_started:
715729
self.log.info("The debugger has already started")
716730
else:
717-
self.is_started = self.start()
731+
self.is_started = await self.start()
718732
if self.is_started:
719733
self.log.info("The debugger has started")
720734
else:

Diff for: ipykernel/eventloops.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -415,13 +415,12 @@ def loop_asyncio(kernel):
415415
loop._should_close = False # type:ignore[attr-defined]
416416

417417
# pause eventloop when there's an event on a zmq socket
418-
def process_stream_events(stream):
418+
def process_stream_events(socket):
419419
"""fall back to main loop when there's a socket event"""
420-
if stream.flush(limit=1):
421-
loop.stop()
420+
loop.stop()
422421

423-
notifier = partial(process_stream_events, kernel.shell_stream)
424-
loop.add_reader(kernel.shell_stream.getsockopt(zmq.FD), notifier)
422+
notifier = partial(process_stream_events, kernel.shell_socket)
423+
loop.add_reader(kernel.shell_socket.getsockopt(zmq.FD), notifier)
425424
loop.call_soon(notifier)
426425

427426
while True:

Diff for: ipykernel/inprocess/blocking.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ class BlockingInProcessKernelClient(InProcessKernelClient):
8080
iopub_channel_class = Type(BlockingInProcessChannel) # type:ignore[arg-type]
8181
stdin_channel_class = Type(BlockingInProcessStdInChannel) # type:ignore[arg-type]
8282

83-
def wait_for_ready(self):
83+
async def wait_for_ready(self):
8484
"""Wait for kernel info reply on shell channel."""
8585
while True:
86-
self.kernel_info()
86+
await self.kernel_info()
8787
try:
8888
msg = self.shell_channel.get_msg(block=True, timeout=1)
8989
except Empty:
@@ -103,6 +103,5 @@ def wait_for_ready(self):
103103
while True:
104104
try:
105105
msg = self.iopub_channel.get_msg(block=True, timeout=0.2)
106-
print(msg["msg_type"])
107106
except Empty:
108107
break

0 commit comments

Comments
 (0)