Skip to content

Commit 270cb82

Browse files
committed
Refactor gather_dep
1 parent 4488144 commit 270cb82

File tree

4 files changed

+286
-116
lines changed

4 files changed

+286
-116
lines changed

distributed/tests/test_utils_test.py

+24-11
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
from distributed import Client, Nanny, Scheduler, Worker, config, default_client
1818
from distributed.compatibility import WINDOWS
19-
from distributed.core import Server, rpc
19+
from distributed.core import Server, Status, rpc
2020
from distributed.metrics import time
2121
from distributed.utils import mp_context
2222
from distributed.utils_test import (
2323
_LockedCommPool,
2424
_UnhashableCallable,
2525
assert_story,
26+
captured_logger,
2627
check_process_leak,
2728
cluster,
2829
dump_cluster_state,
@@ -731,15 +732,27 @@ def test_raises_with_cause():
731732
raise RuntimeError("exception") from ValueError("cause")
732733

733734

734-
def test_worker_fail_hard(capsys):
735-
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
736-
async def test_fail_hard(c, s, a):
737-
with pytest.raises(Exception):
738-
await a.gather_dep(
739-
worker="abcd", to_gather=["x"], total_nbytes=0, stimulus_id="foo"
740-
)
735+
@gen_cluster(nthreads=[("", 1)])
736+
async def test_fail_hard(s, a):
737+
with captured_logger("distributed.worker") as logger:
738+
# Asynchronously kick off handle_acquire_replicas on the worker,
739+
# which will fail
740+
s.stream_comms[a.address].send(
741+
{
742+
"op": "acquire-replicas",
743+
"who_has": {"x": ["abcd"]},
744+
"stimulus_id": "foo",
745+
},
746+
)
747+
while a.status != Status.closed:
748+
await asyncio.sleep(0.01)
749+
750+
assert "missing port number in address 'abcd'" in logger.getvalue()
741751

742-
with pytest.raises(Exception) as info:
743-
test_fail_hard()
744752

745-
assert "abcd" in str(info.value)
753+
@gen_cluster(nthreads=[("", 1)])
754+
async def test_fail_hard_reraises(s, a):
755+
with pytest.raises(AttributeError):
756+
a.handle_stimulus(None)
757+
while a.status != Status.closed:
758+
await asyncio.sleep(0.01)

distributed/tests/test_worker_state_machine.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
import asyncio
4+
from collections.abc import Iterator
25
from contextlib import contextmanager
3-
from itertools import chain
46

57
import pytest
68

@@ -16,7 +18,6 @@
1618
ReleaseWorkerDataMsg,
1719
RescheduleEvent,
1820
RescheduleMsg,
19-
SendMessageToScheduler,
2021
StateMachineEvent,
2122
TaskState,
2223
TaskStateState,
@@ -103,14 +104,19 @@ def test_unique_task_heap():
103104
assert repr(heap) == "<UniqueTaskHeap: 0 items>"
104105

105106

107+
def traverse_subclasses(cls: type) -> Iterator[type]:
108+
yield cls
109+
for subcls in cls.__subclasses__():
110+
yield from traverse_subclasses(subcls)
111+
112+
106113
@pytest.mark.parametrize(
107114
"cls",
108-
chain(
109-
[UniqueTaskHeap],
110-
Instruction.__subclasses__(),
111-
SendMessageToScheduler.__subclasses__(),
112-
StateMachineEvent.__subclasses__(),
113-
),
115+
[
116+
UniqueTaskHeap,
117+
*traverse_subclasses(Instruction),
118+
*traverse_subclasses(StateMachineEvent),
119+
],
114120
)
115121
def test_slots(cls):
116122
params = [

0 commit comments

Comments
 (0)