Skip to content

Commit 14ac9f1

Browse files
committed
Cleanup in BatchedSend
1 parent 09e62e0 commit 14ac9f1

File tree

6 files changed

+32
-52
lines changed

6 files changed

+32
-52
lines changed

distributed/batched.py

+11-18
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
import logging
33
from collections import deque
44

5-
from tornado import gen, locks
6-
from tornado.ioloop import IOLoop
7-
85
import dask
96
from dask.utils import parse_timedelta
107

@@ -37,14 +34,9 @@ class BatchedSend:
3734
['Hello,', 'world!']
3835
"""
3936

40-
# XXX why doesn't BatchedSend follow either the IOStream or Comm API?
41-
42-
def __init__(self, interval, loop=None, serializers=None):
43-
# XXX is the loop arg useful?
44-
self.loop = loop or IOLoop.current()
37+
def __init__(self, interval, serializers=None):
4538
self.interval = parse_timedelta(interval, default="ms")
46-
self.waker = locks.Event()
47-
self.stopped = locks.Event()
39+
self.waker = asyncio.Event()
4840
self.please_stop = False
4941
self.buffer = []
5042
self.comm = None
@@ -62,7 +54,6 @@ def start(self, comm):
6254
if self._background_task and not self._background_task.done():
6355
raise RuntimeError("Background task still running")
6456
self.please_stop = False
65-
self.stopped.clear()
6657
self.waker.set()
6758
self.next_deadline = None
6859
self.comm = comm
@@ -86,9 +77,12 @@ def __repr__(self):
8677
async def _background_send(self):
8778
while not self.please_stop:
8879
try:
89-
await self.waker.wait(self.next_deadline)
80+
timeout = None
81+
if self.next_deadline:
82+
timeout = self.next_deadline - time()
83+
await asyncio.wait_for(self.waker.wait(), timeout=timeout)
9084
self.waker.clear()
91-
except gen.TimeoutError:
85+
except asyncio.TimeoutError:
9286
pass
9387
if not self.buffer:
9488
# Nothing to send
@@ -100,6 +94,7 @@ async def _background_send(self):
10094
payload, self.buffer = self.buffer, []
10195
self.batch_count += 1
10296
self.next_deadline = time() + self.interval
97+
10398
try:
10499
nbytes = await self.comm.write(
105100
payload, serializers=self.serializers, on_error="raise"
@@ -154,13 +149,11 @@ def send(self, *msgs: dict) -> None:
154149
if self.comm and not self.comm.closed() and self.next_deadline is None:
155150
self.waker.set()
156151

157-
async def close(self, timeout=None):
158-
"""Flush existing messages and then close comm
159-
160-
If set, raises `tornado.util.TimeoutError` after a timeout.
161-
"""
152+
async def close(self):
153+
"""Flush existing messages and then close comm"""
162154
self.please_stop = True
163155
self.waker.set()
156+
164157
if self._background_task:
165158
await self._background_task
166159

distributed/client.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ async def _ensure_connected(self, timeout=None):
12871287
if msg[0].get("warning"):
12881288
warnings.warn(version_module.VersionMismatchWarning(msg[0]["warning"]))
12891289

1290-
bcomm = BatchedSend(interval="10ms", loop=self.loop)
1290+
bcomm = BatchedSend(interval="10ms")
12911291
bcomm.start(comm)
12921292
self.scheduler_comm = bcomm
12931293
if self._set_as_default:
@@ -1523,11 +1523,7 @@ async def _close(self, fast=False):
15231523
with suppress(asyncio.CancelledError, TimeoutError):
15241524
await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1)
15251525

1526-
if (
1527-
self.scheduler_comm
1528-
and self.scheduler_comm.comm
1529-
and not self.scheduler_comm.comm.closed()
1530-
):
1526+
if self.scheduler_comm:
15311527
await self.scheduler_comm.close()
15321528

15331529
for key in list(self.futures):

distributed/scheduler.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -3500,7 +3500,7 @@ async def close(self, fast=False, close_workers=False):
35003500
await future
35013501

35023502
for comm in self.client_comms.values():
3503-
comm.abort()
3503+
await comm.close()
35043504

35053505
await self.rpc.close()
35063506

@@ -3732,7 +3732,7 @@ async def add_worker(
37323732
# for key in keys: # TODO
37333733
# self.mark_key_in_memory(key, [address])
37343734

3735-
self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop)
3735+
self.stream_comms[address] = BatchedSend(interval="5ms")
37363736

37373737
if ws.nthreads > len(ws.processing):
37383738
self.idle[ws.address] = ws
@@ -4316,7 +4316,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True):
43164316

43174317
logger.info("Remove worker %s", ws)
43184318
if close:
4319-
with suppress(AttributeError, CommClosedError):
4319+
with suppress(AttributeError):
43204320
self.stream_comms[address].send({"op": "close", "report": False})
43214321

43224322
self.remove_resources(address)
@@ -4330,6 +4330,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True):
43304330
del self.host_info[host]
43314331

43324332
self.rpc.remove(address)
4333+
await self.stream_comms[address].close()
43334334
del self.stream_comms[address]
43344335
del self.aliases[ws.name]
43354336
self.idle.pop(ws.address, None)
@@ -4684,7 +4685,7 @@ async def add_client(
46844685
logger.exception(e)
46854686

46864687
try:
4687-
bcomm = BatchedSend(interval="2ms", loop=self.loop)
4688+
bcomm = BatchedSend(interval="2ms")
46884689
bcomm.start(comm)
46894690
self.client_comms[client] = bcomm
46904691
msg = {"op": "stream-start"}
@@ -5032,13 +5033,7 @@ def client_send(self, client, msg):
50325033
c = client_comms.get(client)
50335034
if c is None:
50345035
return
5035-
try:
5036-
c.send(msg)
5037-
except CommClosedError:
5038-
if self.status == Status.running:
5039-
logger.critical(
5040-
"Closed comm %r while trying to write %s", c, msg, exc_info=True
5041-
)
5036+
c.send(msg)
50425037

50435038
def send_all(self, client_msgs: dict, worker_msgs: dict):
50445039
"""Send messages to client and workers"""
@@ -5068,7 +5063,7 @@ def send_all(self, client_msgs: dict, worker_msgs: dict):
50685063
except KeyError:
50695064
# worker already gone
50705065
pass
5071-
except (CommClosedError, AttributeError):
5066+
except AttributeError:
50725067
self.loop.add_callback(
50735068
self.remove_worker,
50745069
address=worker,

distributed/tests/test_batched.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,8 @@ async def test_restart():
259259
b = BatchedSend(interval="2ms")
260260
b.start(comm)
261261
b.send(123)
262-
assert (123,) == await comm.read()
263-
b.abort()
262+
assert await comm.read() == (123,)
263+
await b.close()
264264
assert b.closed()
265265

266266
# We can buffer stuff even while it is closed
@@ -269,7 +269,7 @@ async def test_restart():
269269
new_comm = await connect(e.address)
270270
b.start(new_comm)
271271

272-
assert (345,) == await new_comm.read()
272+
assert await new_comm.read() == (345,)
273273
await b.close()
274274
assert new_comm.closed()
275275

@@ -285,5 +285,5 @@ async def test_restart_fails_if_still_running():
285285
b.start(comm)
286286

287287
b.send(123)
288-
assert (123,) == await comm.read()
288+
assert await comm.read() == (123,)
289289
await b.close()

distributed/tests/test_worker.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -2387,10 +2387,6 @@ async def test_hold_on_to_replicas(c, s, *workers):
23872387
await asyncio.sleep(0.01)
23882388

23892389

2390-
@pytest.mark.xfail(
2391-
WINDOWS and sys.version_info[:2] == (3, 8),
2392-
reason="https://github.com/dask/distributed/issues/5621",
2393-
)
23942390
@gen_cluster(client=True, nthreads=[("", 1), ("", 1)])
23952391
async def test_worker_reconnects_mid_compute(c, s, a, b):
23962392
"""Ensure that, if a worker disconnects while computing a result, the scheduler will
@@ -2436,7 +2432,8 @@ def fast_on_a(lock):
24362432

24372433
await s.stream_comms[a.address].close()
24382434

2439-
assert len(s.workers) == 1
2435+
while len(s.workers) == 1:
2436+
await asyncio.sleep(0.1)
24402437
a.heartbeat_active = False
24412438
await a.heartbeat()
24422439
assert len(s.workers) == 2
@@ -2508,7 +2505,8 @@ def fast_on_a(lock):
25082505
# The only way to get f3 to complete is for Worker A to reconnect.
25092506

25102507
f1.release()
2511-
assert len(s.workers) == 1
2508+
while len(s.workers) != 1:
2509+
await asyncio.sleep(0.01)
25122510
story = s.story(f1.key)
25132511
while len(story) == len(story_before):
25142512
story = s.story(f1.key)

distributed/worker.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
)
2626
from concurrent.futures import Executor
2727
from contextlib import suppress
28-
from datetime import timedelta
2928
from inspect import isawaitable
3029
from pickle import PicklingError
3130
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
@@ -775,7 +774,7 @@ def __init__(
775774
self.nthreads, thread_name_prefix="Dask-Default-Threads"
776775
)
777776

778-
self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
777+
self.batched_stream = BatchedSend(interval="2ms")
779778
self.name = name
780779
self.scheduler_delay = 0
781780
self.stream_comms = {}
@@ -1269,7 +1268,7 @@ async def heartbeat(self):
12691268
async def handle_scheduler(self, comm):
12701269
await self.handle_stream(comm, every_cycle=[self.ensure_communicating])
12711270

1272-
self.batched_stream.abort()
1271+
await self.batched_stream.close()
12731272
if self.reconnect and self.status in Status.ANY_RUNNING:
12741273
logger.info("Connection to scheduler broken. Reconnecting...")
12751274
self.loop.add_callback(self.heartbeat)
@@ -1583,8 +1582,7 @@ async def close(
15831582
self.batched_stream.send({"op": "close-stream"})
15841583

15851584
if self.batched_stream:
1586-
with suppress(TimeoutError):
1587-
await self.batched_stream.close(timedelta(seconds=timeout))
1585+
await self.batched_stream.close()
15881586

15891587
for executor in self.executors.values():
15901588
if executor is utils._offload_executor:
@@ -1653,7 +1651,7 @@ async def wait_until_closed(self):
16531651

16541652
def send_to_worker(self, address, msg):
16551653
if address not in self.stream_comms:
1656-
bcomm = BatchedSend(interval="1ms", loop=self.loop)
1654+
bcomm = BatchedSend(interval="1ms")
16571655
self.stream_comms[address] = bcomm
16581656

16591657
async def batched_send_connect():

0 commit comments

Comments
 (0)