20
20
Collection ,
21
21
Container ,
22
22
Iterable ,
23
+ Iterator ,
23
24
Mapping ,
24
25
MutableMapping ,
25
26
)
117
118
ExecuteSuccessEvent ,
118
119
FindMissingEvent ,
119
120
GatherDep ,
121
+ GatherDepBusyEvent ,
120
122
GatherDepDoneEvent ,
123
+ GatherDepErrorEvent ,
124
+ GatherDepNetworkFailureEvent ,
125
+ GatherDepSuccessEvent ,
121
126
Instructions ,
122
127
InvalidTransition ,
123
128
LongRunningMsg ,
@@ -3256,7 +3261,6 @@ def _update_metrics_received_data(
3256
3261
self .counters ["transfer-count" ].add (len (data ))
3257
3262
self .incoming_count += 1
3258
3263
3259
- @fail_hard
3260
3264
@log_errors
3261
3265
async def gather_dep (
3262
3266
self ,
@@ -3282,13 +3286,6 @@ async def gather_dep(
3282
3286
if self .status not in WORKER_ANY_RUNNING :
3283
3287
return None
3284
3288
3285
- recommendations : Recs = {}
3286
- instructions : Instructions = []
3287
- response = {}
3288
-
3289
- def done_event ():
3290
- return GatherDepDoneEvent (stimulus_id = f"gather-dep-done-{ time ()} " )
3291
-
3292
3289
try :
3293
3290
self .log .append (("request-dep" , worker , to_gather , stimulus_id , time ()))
3294
3291
logger .debug ("Request %d keys from %s" , len (to_gather ), worker )
@@ -3299,42 +3296,32 @@ def done_event():
3299
3296
)
3300
3297
stop = time ()
3301
3298
if response ["status" ] == "busy" :
3302
- return done_event ()
3299
+ return GatherDepBusyEvent (
3300
+ worker = worker , total_nbytes = total_nbytes , stimulus_id = stimulus_id
3301
+ )
3303
3302
3304
- cause = self ._get_cause (to_gather )
3305
- self ._update_metrics_received_data (
3306
- start = start ,
3307
- stop = stop ,
3308
- data = response ["data" ],
3309
- cause = cause ,
3303
+ assert response ["status" ] == "OK"
3304
+ if response ["data" ]:
3305
+ cause = self ._get_cause (response ["data" ])
3306
+ self ._update_metrics_received_data (
3307
+ start = start ,
3308
+ stop = stop ,
3309
+ data = response ["data" ],
3310
+ cause = cause ,
3311
+ worker = worker ,
3312
+ )
3313
+
3314
+ return GatherDepSuccessEvent (
3310
3315
worker = worker ,
3316
+ total_nbytes = total_nbytes ,
3317
+ data = response ["data" ],
3318
+ stimulus_id = stimulus_id ,
3311
3319
)
3312
- self .log .append (
3313
- ("receive-dep" , worker , set (response ["data" ]), stimulus_id , time ())
3314
- )
3315
- return done_event ()
3316
3320
3317
3321
except OSError :
3318
- logger .exception ("Worker stream died during communication: %s" , worker )
3319
- has_what = self .has_what .pop (worker )
3320
- self .data_needed_per_worker .pop (worker )
3321
- self .log .append (
3322
- ("receive-dep-failed" , worker , has_what , stimulus_id , time ())
3322
+ return GatherDepNetworkFailureEvent (
3323
+ worker = worker , total_nbytes = total_nbytes , stimulus_id = stimulus_id
3323
3324
)
3324
- for d in has_what :
3325
- ts = self .tasks [d ]
3326
- ts .who_has .remove (worker )
3327
- if not ts .who_has and ts .state in (
3328
- "fetch" ,
3329
- "flight" ,
3330
- "resumed" ,
3331
- "cancelled" ,
3332
- ):
3333
- recommendations [ts ] = "missing"
3334
- self .log .append (
3335
- ("missing-who-has" , worker , ts .key , stimulus_id , time ())
3336
- )
3337
- return done_event ()
3338
3325
3339
3326
except Exception as e :
3340
3327
logger .exception (e )
@@ -3343,61 +3330,15 @@ def done_event():
3343
3330
3344
3331
pdb .set_trace ()
3345
3332
msg = error_message (e )
3346
- for k in self .in_flight_workers [worker ]:
3347
- ts = self .tasks [k ]
3348
- recommendations [ts ] = tuple (msg .values ())
3349
- return done_event ()
3350
-
3351
- finally :
3352
- self .comm_nbytes -= total_nbytes
3353
- busy = response .get ("status" , "" ) == "busy"
3354
- data = response .get ("data" , {})
3355
-
3356
- if busy :
3357
- self .log .append (("busy-gather" , worker , to_gather , stimulus_id , time ()))
3358
- # Avoid hammering the worker. If there are multiple replicas
3359
- # available, immediately try fetching from a different worker.
3360
- self .busy_workers .add (worker )
3361
- instructions .append (
3362
- RetryBusyWorkerLater (worker = worker , stimulus_id = stimulus_id )
3363
- )
3364
-
3365
- refresh_who_has = set ()
3366
-
3367
- for d in self .in_flight_workers .pop (worker ):
3368
- ts = self .tasks [d ]
3369
- ts .done = True
3370
- if d in data :
3371
- recommendations [ts ] = ("memory" , data [d ])
3372
- elif busy :
3373
- recommendations [ts ] = "fetch"
3374
- if not ts .who_has - self .busy_workers :
3375
- refresh_who_has .add (ts .key )
3376
- elif ts not in recommendations :
3377
- ts .who_has .discard (worker )
3378
- self .has_what [worker ].discard (ts .key )
3379
- self .log .append ((d , "missing-dep" , stimulus_id , time ()))
3380
- instructions .append (
3381
- MissingDataMsg (
3382
- key = d ,
3383
- errant_worker = worker ,
3384
- stimulus_id = stimulus_id ,
3385
- )
3386
- )
3387
- recommendations [ts ] = "fetch"
3388
-
3389
- if refresh_who_has :
3390
- # All workers that hold known replicas of our tasks are busy.
3391
- # Try querying the scheduler for unknown ones.
3392
- instructions .append (
3393
- RequestRefreshWhoHasMsg (
3394
- keys = list (refresh_who_has ),
3395
- stimulus_id = f"gather-dep-busy-{ time ()} " ,
3396
- )
3397
- )
3398
-
3399
- self .transitions (recommendations , stimulus_id = stimulus_id )
3400
- self ._handle_instructions (instructions )
3333
+ return GatherDepErrorEvent (
3334
+ worker = worker ,
3335
+ total_nbytes = total_nbytes ,
3336
+ exception = msg ["exception" ],
3337
+ traceback = msg ["traceback" ],
3338
+ exception_text = msg ["exception_text" ],
3339
+ traceback_text = msg ["traceback_text" ],
3340
+ stimulus_id = stimulus_id ,
3341
+ )
3401
3342
3402
3343
async def retry_busy_worker_later (self , worker : str ) -> StateMachineEvent | None :
3403
3344
await asyncio .sleep (0.15 )
@@ -3940,10 +3881,161 @@ def _(self, ev: UnpauseEvent) -> RecsInstrs:
3940
3881
self ._ensure_communicating (stimulus_id = ev .stimulus_id ),
3941
3882
)
3942
3883
3884
def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]:
    """Shared bookkeeping for the handlers of all GatherDepDoneEvent subclasses.

    Subtracts ``ev.total_nbytes`` from ``self.comm_nbytes``, removes
    ``ev.worker`` from ``self.in_flight_workers`` and yields the task of every
    key that was in flight from that worker, after setting ``done`` on it.

    This is a generator: none of the above happens until the caller starts
    iterating, and a task is only marked as done once it has been yielded.
    """
    self.comm_nbytes -= ev.total_nbytes
    in_flight_keys = self.in_flight_workers.pop(ev.worker)
    for task in (self.tasks[key] for key in in_flight_keys):
        task.done = True
        yield task
3891
+
3892
def _refetch_missing_data(
    self, ev: GatherDepDoneEvent, tasks: Iterable[TaskState]
) -> RecsInstrs:
    """Helper of the GatherDepDoneEvent subclass handlers.

    For every task in ``tasks``, forget that ``ev.worker`` holds a replica,
    log a "missing-dep" entry, queue a MissingDataMsg naming ``ev.worker`` as
    the errant worker, and recommend a transition to "fetch".

    Returns
    -------
    Tuple of (recommendations, instructions)
    """
    errant_worker = ev.worker
    stimulus_id = ev.stimulus_id
    recs: Recs = {}
    instrs: Instructions = []

    for task in tasks:
        task.who_has.discard(errant_worker)
        self.has_what[errant_worker].discard(task.key)
        self.log.append((task.key, "missing-dep", stimulus_id, time()))
        missing_msg = MissingDataMsg(
            key=task.key,
            errant_worker=errant_worker,
            stimulus_id=stimulus_id,
        )
        instrs.append(missing_msg)
        recs[task] = "fetch"

    return recs, instrs
3912
+
3943
3913
@handle_event.register
def _(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
    """gather_dep terminated successfully.
    The response may contain fewer keys than the request.
    """
    received = ev.data
    self.log.append(
        ("receive-dep", ev.worker, set(received), ev.stimulus_id, time())
    )

    # Exhaust the shared generator first so that every in-flight task is
    # marked as done, then split the tasks by whether their data arrived.
    done_tasks = list(self._gather_dep_done_common(ev))
    recs: Recs = {
        ts: ("memory", received[ts.key]) for ts in done_tasks if ts.key in received
    }
    absent = {ts for ts in done_tasks if ts.key not in received}

    return merge_recs_instructions(
        (recs, [EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)]),
        self._refetch_missing_data(ev, absent),
    )
3935
+
3936
@handle_event.register
def _(self, ev: GatherDepBusyEvent) -> RecsInstrs:
    """gather_dep terminated: remote worker is busy"""
    busy_worker = ev.worker
    # Snapshot the requested keys before the common helper pops them.
    requested = set(self.in_flight_workers[busy_worker])
    self.log.append(
        ("busy-gather", busy_worker, requested, ev.stimulus_id, time())
    )

    # Avoid hammering the worker. If there are multiple replicas
    # available, immediately try fetching from a different worker.
    self.busy_workers.add(busy_worker)

    done_tasks = list(self._gather_dep_done_common(ev))
    recs: Recs = {ts: "fetch" for ts in done_tasks}
    # Keys whose every known replica lives on a busy worker
    only_on_busy = [
        ts.key for ts in done_tasks if not ts.who_has - self.busy_workers
    ]

    instrs: Instructions = [
        RetryBusyWorkerLater(worker=busy_worker, stimulus_id=ev.stimulus_id),
        EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id),
    ]
    if only_on_busy:
        # All workers that hold known replicas of our tasks are busy.
        # Try querying the scheduler for unknown ones.
        refresh_msg = RequestRefreshWhoHasMsg(
            keys=only_on_busy,
            stimulus_id=f"gather-dep-busy-{time()}",
        )
        instrs.append(refresh_msg)

    return recs, instrs
3975
+
3976
@handle_event.register
def _(self, ev: GatherDepNetworkFailureEvent) -> RecsInstrs:
    """gather_dep terminated: network failure while trying to
    communicate with remote worker

    Forgets everything known to be held by ``ev.worker``. Tasks that are left
    without any known replica while in state fetch, flight, resumed or
    cancelled are recommended to "missing"; the remaining tasks that were in
    flight from ``ev.worker`` are sent back to "fetch" through
    ``_refetch_missing_data``.

    Returns
    -------
    Tuple of (recommendations, instructions)
    """
    # This handler runs outside of any ``except`` block, so there is no
    # active exception for logger.exception() to attach; log a plain error.
    logger.error("Worker stream died during communication: %s", ev.worker)

    # Decision table for every task that ev.worker was known to hold:
    #
    # if state in (fetch, flight, resumed, cancelled):
    #     if ts.who_has is now empty:
    #         transition to missing; don't send data-missing
    #     elif ts in GatherDep.keys:
    #         transition to fetch; send data-missing
    #     else:
    #         don't transition
    # elif ts in GatherDep.keys:
    #     transition to fetch; send data-missing
    # else:
    #     don't transition

    has_what = self.has_what.pop(ev.worker)
    self.data_needed_per_worker.pop(ev.worker)
    self.log.append(
        ("receive-dep-failed", ev.worker, has_what, ev.stimulus_id, time())
    )
    recommendations: Recs = {}
    for d in has_what:
        ts = self.tasks[d]
        ts.who_has.remove(ev.worker)
        if not ts.who_has and ts.state in (
            "fetch",
            "flight",
            "resumed",
            "cancelled",
        ):
            recommendations[ts] = "missing"
            self.log.append(
                ("missing-who-has", ev.worker, ts.key, ev.stimulus_id, time())
            )

    # Tasks already recommended to "missing" must not also be refetched.
    # NOTE(review): _refetch_missing_data indexes self.has_what[ev.worker]
    # after the pop above; presumably has_what is a defaultdict - confirm.
    refetch_tasks = set(self._gather_dep_done_common(ev)) - recommendations.keys()
    smsg = EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)
    return merge_recs_instructions(
        (recommendations, [smsg]),
        self._refetch_missing_data(ev, refetch_tasks),
    )
4021
+
4022
@handle_event.register
def _(self, ev: GatherDepErrorEvent) -> RecsInstrs:
    """gather_dep terminated: generic error raised (not a network failure);
    e.g. data failed to deserialize.

    Every task that was in flight from ``ev.worker`` is recommended to
    "error", carrying the exception details stored on the event.
    """
    error_rec = (
        "error",
        ev.exception,
        ev.traceback,
        ev.exception_text,
        ev.traceback_text,
    )
    recs: Recs = {}
    for task in self._gather_dep_done_common(ev):
        recs[task] = error_rec

    return recs, [EnsureCommunicatingAfterTransitions(stimulus_id=ev.stimulus_id)]
3947
4039
3948
4040
@handle_event .register
3949
4041
def _ (self , ev : RetryBusyWorkerEvent ) -> RecsInstrs :
0 commit comments