Skip to content

Commit 43fef59

Browse files
committed
Refactor gather_dep (#6388)
1 parent 26c6d44 commit 43fef59

File tree

3 files changed

+260
-105
lines changed

3 files changed

+260
-105
lines changed

distributed/tests/test_worker_state_machine.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
import asyncio
4+
from collections.abc import Iterator
25
from contextlib import contextmanager
3-
from itertools import chain
46

57
import pytest
68

@@ -16,7 +18,6 @@
1618
ReleaseWorkerDataMsg,
1719
RescheduleEvent,
1820
RescheduleMsg,
19-
SendMessageToScheduler,
2021
StateMachineEvent,
2122
TaskState,
2223
TaskStateState,
@@ -103,14 +104,19 @@ def test_unique_task_heap():
103104
assert repr(heap) == "<UniqueTaskHeap: 0 items>"
104105

105106

107+
def traverse_subclasses(cls: type) -> Iterator[type]:
108+
yield cls
109+
for subcls in cls.__subclasses__():
110+
yield from traverse_subclasses(subcls)
111+
112+
106113
@pytest.mark.parametrize(
107114
"cls",
108-
chain(
109-
[UniqueTaskHeap],
110-
Instruction.__subclasses__(),
111-
SendMessageToScheduler.__subclasses__(),
112-
StateMachineEvent.__subclasses__(),
113-
),
115+
[
116+
UniqueTaskHeap,
117+
*traverse_subclasses(Instruction),
118+
*traverse_subclasses(StateMachineEvent),
119+
],
114120
)
115121
def test_slots(cls):
116122
params = [

distributed/worker.py

+188-96
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Collection,
2121
Container,
2222
Iterable,
23+
Iterator,
2324
Mapping,
2425
MutableMapping,
2526
)
@@ -117,7 +118,11 @@
117118
ExecuteSuccessEvent,
118119
FindMissingEvent,
119120
GatherDep,
121+
GatherDepBusyEvent,
120122
GatherDepDoneEvent,
123+
GatherDepErrorEvent,
124+
GatherDepNetworkFailureEvent,
125+
GatherDepSuccessEvent,
121126
Instructions,
122127
InvalidTransition,
123128
LongRunningMsg,
@@ -3256,7 +3261,6 @@ def _update_metrics_received_data(
32563261
self.counters["transfer-count"].add(len(data))
32573262
self.incoming_count += 1
32583263

3259-
@fail_hard
32603264
@log_errors
32613265
async def gather_dep(
32623266
self,
@@ -3282,13 +3286,6 @@ async def gather_dep(
32823286
if self.status not in WORKER_ANY_RUNNING:
32833287
return None
32843288

3285-
recommendations: Recs = {}
3286-
instructions: Instructions = []
3287-
response = {}
3288-
3289-
def done_event():
3290-
return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}")
3291-
32923289
try:
32933290
self.log.append(("request-dep", worker, to_gather, stimulus_id, time()))
32943291
logger.debug("Request %d keys from %s", len(to_gather), worker)
@@ -3299,42 +3296,32 @@ def done_event():
32993296
)
33003297
stop = time()
33013298
if response["status"] == "busy":
3302-
return done_event()
3299+
return GatherDepBusyEvent(
3300+
worker=worker, total_nbytes=total_nbytes, stimulus_id=stimulus_id
3301+
)
33033302

3304-
cause = self._get_cause(to_gather)
3305-
self._update_metrics_received_data(
3306-
start=start,
3307-
stop=stop,
3308-
data=response["data"],
3309-
cause=cause,
3303+
assert response["status"] == "OK"
3304+
if response["data"]:
3305+
cause = self._get_cause(response["data"])
3306+
self._update_metrics_received_data(
3307+
start=start,
3308+
stop=stop,
3309+
data=response["data"],
3310+
cause=cause,
3311+
worker=worker,
3312+
)
3313+
3314+
return GatherDepSuccessEvent(
33103315
worker=worker,
3316+
total_nbytes=total_nbytes,
3317+
data=response["data"],
3318+
stimulus_id=stimulus_id,
33113319
)
3312-
self.log.append(
3313-
("receive-dep", worker, set(response["data"]), stimulus_id, time())
3314-
)
3315-
return done_event()
33163320

33173321
except OSError:
3318-
logger.exception("Worker stream died during communication: %s", worker)
3319-
has_what = self.has_what.pop(worker)
3320-
self.data_needed_per_worker.pop(worker)
3321-
self.log.append(
3322-
("receive-dep-failed", worker, has_what, stimulus_id, time())
3322+
return GatherDepNetworkFailureEvent(
3323+
worker=worker, total_nbytes=total_nbytes, stimulus_id=stimulus_id
33233324
)
3324-
for d in has_what:
3325-
ts = self.tasks[d]
3326-
ts.who_has.remove(worker)
3327-
if not ts.who_has and ts.state in (
3328-
"fetch",
3329-
"flight",
3330-
"resumed",
3331-
"cancelled",
3332-
):
3333-
recommendations[ts] = "missing"
3334-
self.log.append(
3335-
("missing-who-has", worker, ts.key, stimulus_id, time())
3336-
)
3337-
return done_event()
33383325

33393326
except Exception as e:
33403327
logger.exception(e)
@@ -3343,61 +3330,15 @@ def done_event():
33433330

33443331
pdb.set_trace()
33453332
msg = error_message(e)
3346-
for k in self.in_flight_workers[worker]:
3347-
ts = self.tasks[k]
3348-
recommendations[ts] = tuple(msg.values())
3349-
return done_event()
3350-
3351-
finally:
3352-
self.comm_nbytes -= total_nbytes
3353-
busy = response.get("status", "") == "busy"
3354-
data = response.get("data", {})
3355-
3356-
if busy:
3357-
self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
3358-
# Avoid hammering the worker. If there are multiple replicas
3359-
# available, immediately try fetching from a different worker.
3360-
self.busy_workers.add(worker)
3361-
instructions.append(
3362-
RetryBusyWorkerLater(worker=worker, stimulus_id=stimulus_id)
3363-
)
3364-
3365-
refresh_who_has = set()
3366-
3367-
for d in self.in_flight_workers.pop(worker):
3368-
ts = self.tasks[d]
3369-
ts.done = True
3370-
if d in data:
3371-
recommendations[ts] = ("memory", data[d])
3372-
elif busy:
3373-
recommendations[ts] = "fetch"
3374-
if not ts.who_has - self.busy_workers:
3375-
refresh_who_has.add(ts.key)
3376-
elif ts not in recommendations:
3377-
ts.who_has.discard(worker)
3378-
self.has_what[worker].discard(ts.key)
3379-
self.log.append((d, "missing-dep", stimulus_id, time()))
3380-
instructions.append(
3381-
MissingDataMsg(
3382-
key=d,
3383-
errant_worker=worker,
3384-
stimulus_id=stimulus_id,
3385-
)
3386-
)
3387-
recommendations[ts] = "fetch"
3388-
3389-
if refresh_who_has:
3390-
# All workers that hold known replicas of our tasks are busy.
3391-
# Try querying the scheduler for unknown ones.
3392-
instructions.append(
3393-
RequestRefreshWhoHasMsg(
3394-
keys=list(refresh_who_has),
3395-
stimulus_id=f"gather-dep-busy-{time()}",
3396-
)
3397-
)
3398-
3399-
self.transitions(recommendations, stimulus_id=stimulus_id)
3400-
self._handle_instructions(instructions)
3333+
return GatherDepErrorEvent(
3334+
worker=worker,
3335+
total_nbytes=total_nbytes,
3336+
exception=msg["exception"],
3337+
traceback=msg["traceback"],
3338+
exception_text=msg["exception_text"],
3339+
traceback_text=msg["traceback_text"],
3340+
stimulus_id=stimulus_id,
3341+
)
34013342

34023343
async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
34033344
await asyncio.sleep(0.15)
@@ -3940,10 +3881,161 @@ def _(self, ev: UnpauseEvent) -> RecsInstrs:
39403881
self._ensure_communicating(stimulus_id=ev.stimulus_id),
39413882
)
39423883

3884+
def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]:
3885+
"""Common code for all subclasses of GatherDepDoneEvent"""
3886+
self.comm_nbytes -= ev.total_nbytes
3887+
for key in self.in_flight_workers.pop(ev.worker):
3888+
ts = self.tasks[key]
3889+
ts.done = True
3890+
yield ts
3891+
3892+
def _refetch_missing_data(
3893+
self, ev: GatherDepDoneEvent, tasks: Iterable[TaskState]
3894+
) -> RecsInstrs:
3895+
"""Helper of GatherDepDoneEvent subclass handlers"""
3896+
recommendations: Recs = {}
3897+
instructions: Instructions = []
3898+
3899+
for ts in tasks:
3900+
ts.who_has.discard(ev.worker)
3901+
self.has_what[ev.worker].discard(ts.key)
3902+
self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
3903+
instructions.append(
3904+
MissingDataMsg(
3905+
key=ts.key,
3906+
errant_worker=ev.worker,
3907+
stimulus_id=ev.stimulus_id,
3908+
)
3909+
)
3910+
recommendations[ts] = "fetch"
3911+
return recommendations, instructions
3912+
39433913
@handle_event.register
3944-
def _(self, ev: GatherDepDoneEvent) -> RecsInstrs:
3945-
"""Temporary hack - to be removed"""
3946-
return self._ensure_communicating(stimulus_id=ev.stimulus_id)
3914+
def _(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
3915+
"""gather_dep terminated successfully.
3916+
The response may contain less keys than the request.
3917+
"""
3918+
self.log.append(
3919+
("receive-dep", ev.worker, set(ev.data), ev.stimulus_id, time())
3920+
)
3921+
3922+
recommendations: Recs = {}
3923+
refetch = set()
3924+
for ts in self._gather_dep_done_common(ev):
3925+
if ts.key in ev.data:
3926+
recommendations[ts] = ("memory", ev.data[ts.key])
3927+
else:
3928+
refetch.add(ts)
3929+
3930+
smsg = EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)
3931+
return merge_recs_instructions(
3932+
(recommendations, [smsg]),
3933+
self._refetch_missing_data(ev, refetch),
3934+
)
3935+
3936+
@handle_event.register
3937+
def _(self, ev: GatherDepBusyEvent) -> RecsInstrs:
3938+
"""gather_dep terminated: remote worker is busy"""
3939+
self.log.append(
3940+
(
3941+
"busy-gather",
3942+
ev.worker,
3943+
set(self.in_flight_workers[ev.worker]),
3944+
ev.stimulus_id,
3945+
time(),
3946+
)
3947+
)
3948+
3949+
# Avoid hammering the worker. If there are multiple replicas
3950+
# available, immediately try fetching from a different worker.
3951+
self.busy_workers.add(ev.worker)
3952+
3953+
recommendations: Recs = {}
3954+
refresh_who_has = []
3955+
for ts in self._gather_dep_done_common(ev):
3956+
recommendations[ts] = "fetch"
3957+
if not ts.who_has - self.busy_workers:
3958+
refresh_who_has.append(ts.key)
3959+
3960+
instructions: Instructions = [
3961+
RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id),
3962+
EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id),
3963+
]
3964+
if refresh_who_has:
3965+
# All workers that hold known replicas of our tasks are busy.
3966+
# Try querying the scheduler for unknown ones.
3967+
instructions.append(
3968+
RequestRefreshWhoHasMsg(
3969+
keys=refresh_who_has,
3970+
stimulus_id=f"gather-dep-busy-{time()}",
3971+
)
3972+
)
3973+
3974+
return recommendations, instructions
3975+
3976+
@handle_event.register
3977+
def _(self, ev: GatherDepNetworkFailureEvent) -> RecsInstrs:
3978+
"""gather_dep terminated: network failure while trying to
3979+
communicate with remote worker
3980+
"""
3981+
logger.exception("Worker stream died during communication: %s", ev.worker)
3982+
3983+
# if state in (fetch, flight, resumed, cancelled):
3984+
# if ts.who_has is now empty:
3985+
# transition to missing; don't send data-missing
3986+
# elif ts in GatherDep.keys:
3987+
# transition to fetch; send data-missing
3988+
# else:
3989+
# don't transition
3990+
# elif ts in GatherDep.keys:
3991+
# transition to fetch; send data-missing
3992+
# else:
3993+
# don't transition
3994+
3995+
has_what = self.has_what.pop(ev.worker)
3996+
self.data_needed_per_worker.pop(ev.worker)
3997+
self.log.append(
3998+
("receive-dep-failed", ev.worker, has_what, ev.stimulus_id, time())
3999+
)
4000+
recommendations: Recs = {}
4001+
for d in has_what:
4002+
ts = self.tasks[d]
4003+
ts.who_has.remove(ev.worker)
4004+
if not ts.who_has and ts.state in (
4005+
"fetch",
4006+
"flight",
4007+
"resumed",
4008+
"cancelled",
4009+
):
4010+
recommendations[ts] = "missing"
4011+
self.log.append(
4012+
("missing-who-has", ev.worker, ts.key, ev.stimulus_id, time())
4013+
)
4014+
4015+
refetch_tasks = set(self._gather_dep_done_common(ev)) - recommendations.keys()
4016+
smsg = EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)
4017+
return merge_recs_instructions(
4018+
(recommendations, [smsg]),
4019+
self._refetch_missing_data(ev, refetch_tasks),
4020+
)
4021+
4022+
@handle_event.register
4023+
def _(self, ev: GatherDepErrorEvent) -> RecsInstrs:
4024+
"""gather_dep terminated: generic error raised (not a network failure);
4025+
e.g. data failed to deserialize.
4026+
"""
4027+
recommendations: Recs = {
4028+
ts: (
4029+
"error",
4030+
ev.exception,
4031+
ev.traceback,
4032+
ev.exception_text,
4033+
ev.traceback_text,
4034+
)
4035+
for ts in self._gather_dep_done_common(ev)
4036+
}
4037+
smsg = EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)
4038+
return recommendations, [smsg]
39474039

39484040
@handle_event.register
39494041
def _(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:

0 commit comments

Comments
 (0)