Skip to content

Commit a605836

Browse files
committed
Optimize and harden additional tests
1 parent ba85383 commit a605836

File tree

4 files changed

+52
-29
lines changed

4 files changed

+52
-29
lines changed

tests/test_cli_scripts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def test_dht_connection_successful():
10-
dht_refresh_period = 1
10+
dht_refresh_period = 3
1111

1212
cloned_env = os.environ.copy()
1313
# overriding the loglevel to prevent debug print statements

tests/test_moe.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -316,9 +316,9 @@ def test_client_anomaly_detection():
316316
server.shutdown()
317317

318318

319-
def _measure_coro_running_time(n_coros, elapsed_fut, counter):
319+
def _measure_coro_running_time(n_coros, elapsed_fut, counter, coroutine_time):
320320
async def coro():
321-
await asyncio.sleep(0.1)
321+
await asyncio.sleep(coroutine_time)
322322
counter.value += 1
323323

324324
try:
@@ -337,20 +337,21 @@ async def coro():
337337

338338

339339
@pytest.mark.forked
340-
def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10):
340+
def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10, coroutine_time=0.1):
341341
processes = []
342342
counter = mp.Value(ctypes.c_int64)
343343
for i in range(n_processes):
344344
elapsed_fut = MPFuture()
345345
factory = threading.Thread if i % 2 == 0 else mp.Process # Test both threads and processes
346346

347-
proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter))
347+
proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter, coroutine_time))
348348
proc.start()
349349
processes.append((proc, elapsed_fut))
350350

351351
for proc, elapsed_fut in processes:
352352
# Ensure that the coroutines were run concurrently, not sequentially
353-
assert elapsed_fut.result() < 0.2
353+
expected_time = coroutine_time * 3 # from non-blocking calls + blocking call + some overhead
354+
assert elapsed_fut.result() < expected_time
354355
proc.join()
355356

356357
assert counter.value == n_processes * n_coros # Ensure all coroutines have finished

tests/test_optimizer.py

+30-19
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import multiprocessing as mp
33
import time
44
from functools import partial
5+
from typing import List
56

67
import numpy as np
78
import pytest
89
import torch
910
import torch.nn as nn
1011
import torch.nn.functional as F
12+
from multiaddr import Multiaddr
1113

1214
import hivemind
1315
from hivemind.averaging.control import AveragingStage
@@ -227,8 +229,10 @@ def test_progress_tracker():
227229
finished_evt = mp.Event()
228230
emas = mp.Array(ctypes.c_double, 5)
229231

230-
def run_worker(index: int, batch_size: int, period: float, **kwargs):
231-
dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
232+
root_maddrs = dht_root.get_visible_maddrs()
233+
234+
def run_worker(index: int, batch_size: int, step_time: float, initial_peers: List[Multiaddr]):
235+
dht = hivemind.DHT(initial_peers=initial_peers, start=True)
232236
tracker = ProgressTracker(
233237
dht,
234238
prefix,
@@ -238,18 +242,17 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
238242
default_refresh_period=0.2,
239243
max_refresh_period=0.5,
240244
private_key=RSAPrivateKey(),
241-
**kwargs,
242245
)
246+
with tracker.pause_updates():
247+
barrier.wait()
248+
if index == 4:
249+
delayed_start_evt.wait()
243250

244-
barrier.wait()
245-
if index == 4:
246-
delayed_start_evt.wait()
247-
248-
local_epoch = 2 if index == 4 else 0
249-
samples_accumulated = 0
251+
local_epoch = 2 if index == 4 else 0
252+
samples_accumulated = 0
250253

251254
while True:
252-
time.sleep(period)
255+
time.sleep(step_time)
253256
if finished_evt.is_set():
254257
break
255258

@@ -270,10 +273,10 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
270273
dht.shutdown()
271274

272275
workers = [
273-
mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
274-
mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
275-
mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
276-
mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
276+
mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, step_time=0.6, initial_peers=root_maddrs)),
277+
mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, step_time=0.5, initial_peers=root_maddrs)),
278+
mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, step_time=0.2, initial_peers=root_maddrs)),
279+
mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, step_time=0.2, initial_peers=root_maddrs)),
277280
]
278281
for worker in workers:
279282
worker.start()
@@ -336,7 +339,7 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
336339
(False, True, True, True, True),
337340
(False, True, True, False, True),
338341
(True, False, False, False, False),
339-
(True, True, False, False, False,),
342+
(True, True, False, False, False),
340343
],
341344
# fmt: on
342345
)
@@ -359,6 +362,8 @@ def test_optimizer(
359362
def _test_optimizer(
360363
num_peers: int = 1,
361364
num_clients: int = 0,
365+
default_batch_size: int = 4,
366+
default_batch_time: int = 0.1,
362367
target_batch_size: int = 32,
363368
total_epochs: int = 3,
364369
use_local_updates: bool = False,
@@ -422,20 +427,21 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
422427

423428
prev_time = time.perf_counter()
424429

425-
time.sleep(1.0)
426430
optimizer.shutdown()
427431
return optimizer
428432

429433
peers = []
430434

431435
for index in range(num_peers):
436+
peer_batch_size = default_batch_size + index
437+
peer_batch_time = default_batch_time + 0.01 * index
432438
peers.append(
433439
mp.Process(
434440
target=run_trainer,
435441
name=f"trainer-{index}",
436442
kwargs=dict(
437-
batch_size=4 + index,
438-
batch_time=0.3 + 0.2 * index,
443+
batch_size=peer_batch_size,
444+
batch_time=peer_batch_time,
439445
client_mode=(index >= num_peers - num_clients),
440446
),
441447
)
@@ -451,7 +457,12 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
451457
assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
452458
expected_samples_accumulated = target_batch_size * total_epochs
453459
assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
454-
assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
460+
expected_performance = default_batch_size / default_batch_time
461+
assert (
462+
expected_performance * 0.8
463+
<= optimizer.tracker.performance_ema.samples_per_second
464+
<= expected_performance * 1.2
465+
)
455466

456467
assert not optimizer.state_averager.is_alive()
457468
assert not optimizer.tracker.is_alive()

tests/test_util_modules.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import random
55
import time
66
from concurrent.futures import ThreadPoolExecutor, as_completed
7+
from threading import Event
78

89
import numpy as np
910
import pytest
@@ -266,9 +267,10 @@ def _check_result_and_set(future):
266267
with pytest.raises(RuntimeError):
267268
future1.add_done_callback(lambda future: (1, 2, 3))
268269

270+
events[0].wait()
269271
assert future1.done() and not future1.cancelled()
270272
assert future2.done() and future2.cancelled()
271-
for i in 0, 1, 4:
273+
for i in 1, 4:
272274
events[i].wait(1)
273275
assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
274276
assert not events[3].is_set()
@@ -557,16 +559,25 @@ def test_performance_ema_threadsafe(
557559
bias_power: float = 0.7,
558560
tolerance: float = 0.05,
559561
):
560-
def run_task(ema):
561-
task_size = random.randint(1, 4)
562+
def run_task(ema, start_event, task_size):
563+
start_event.wait()
562564
with ema.update_threadsafe(task_size):
563565
time.sleep(task_size * interval * (0.9 + 0.2 * random.random()))
564566
return task_size
565567

566568
with ThreadPoolExecutor(max_workers) as pool:
567569
ema = PerformanceEMA(alpha=alpha)
570+
start_event = Event()
568571
start_time = time.perf_counter()
569-
futures = [pool.submit(run_task, ema) for _ in range(num_updates)]
572+
573+
futures = []
574+
for _ in range(num_updates):
575+
task_size = random.randint(1, 4)
576+
future = pool.submit(run_task, ema, start_event, task_size)
577+
futures.append(future)
578+
579+
ema.reset_timer()
580+
start_event.set()
570581
total_size = sum(future.result() for future in as_completed(futures))
571582
end_time = time.perf_counter()
572583
target = total_size / (end_time - start_time)

0 commit comments

Comments
 (0)