@@ -2,12 +2,14 @@
 import multiprocessing as mp
 import time
 from functools import partial
+from typing import List
 
 import numpy as np
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from multiaddr import Multiaddr
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
@@ -227,8 +229,10 @@ def test_progress_tracker():
     finished_evt = mp.Event()
     emas = mp.Array(ctypes.c_double, 5)
 
-    def run_worker(index: int, batch_size: int, period: float, **kwargs):
-        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
+    root_maddrs = dht_root.get_visible_maddrs()
+
+    def run_worker(index: int, batch_size: int, step_time: float, initial_peers: List[Multiaddr]):
+        dht = hivemind.DHT(initial_peers=initial_peers, start=True)
         tracker = ProgressTracker(
             dht,
             prefix,
@@ -238,18 +242,17 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
             default_refresh_period=0.2,
             max_refresh_period=0.5,
             private_key=RSAPrivateKey(),
-            **kwargs,
         )
+        with tracker.pause_updates():
+            barrier.wait()
+            if index == 4:
+                delayed_start_evt.wait()
 
-        barrier.wait()
-        if index == 4:
-            delayed_start_evt.wait()
-
-        local_epoch = 2 if index == 4 else 0
-        samples_accumulated = 0
+            local_epoch = 2 if index == 4 else 0
+            samples_accumulated = 0
 
         while True:
-            time.sleep(period)
+            time.sleep(step_time)
             if finished_evt.is_set():
                 break
 
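For context, `pause_updates()` is a context manager on hivemind's ProgressTracker that temporarily stops the tracker from reporting local progress, so a worker can sit at the barrier (or wait for the delayed-start event) without advertising stale progress to its peers. A minimal usage sketch, assuming hivemind's public ProgressTracker API and an invented prefix:

# Minimal sketch only, not part of the diff above.
import hivemind
from hivemind.optim.progress_tracker import ProgressTracker

dht = hivemind.DHT(start=True)
tracker = ProgressTracker(dht, "demo_run", target_batch_size=32, start=True)

with tracker.pause_updates():
    # No progress is reported while this block runs, e.g. during
    # barriers, state downloads, or delayed starts.
    local_epoch, samples_accumulated = 0, 0

# Reporting resumes once the context exits.
tracker.report_local_progress(local_epoch, samples_accumulated)
tracker.shutdown()
dht.shutdown()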
@@ -270,10 +273,10 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
         dht.shutdown()
 
     workers = [
-        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
-        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
-        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
-        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
+        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, step_time=0.6, initial_peers=root_maddrs)),
+        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, step_time=0.5, initial_peers=root_maddrs)),
+        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, step_time=0.2, initial_peers=root_maddrs)),
+        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, step_time=0.2, initial_peers=root_maddrs)),
     ]
     for worker in workers:
         worker.start()
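Side note on the pattern above: the parent resolves the root DHT's visible multiaddrs once (`root_maddrs`) and hands them to each child process as plain picklable data, rather than each worker reaching into the parent's `dht_root` handle after the fork. A self-contained sketch of that pattern, with names invented for illustration:

# Sketch of the parent-resolves, child-connects pattern used above.
import multiprocessing as mp
import hivemind

def child(initial_peers):
    # Each child builds its own DHT node from the pickled multiaddrs.
    dht = hivemind.DHT(initial_peers=initial_peers, start=True)
    dht.shutdown()

if __name__ == "__main__":
    root = hivemind.DHT(start=True)
    maddrs = root.get_visible_maddrs()  # resolved once, in the parent
    worker = mp.Process(target=child, args=(maddrs,))
    worker.start()
    worker.join()
    root.shutdown()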
@@ -336,7 +339,7 @@ def run_worker(index: int, batch_size: int, period: float, **kwargs):
         (False, True, True, True, True),
         (False, True, True, False, True),
         (True, False, False, False, False),
-        (True, True, False, False, False,),
+        (True, True, False, False, False),
     ],
     # fmt: on
 )
@@ -359,6 +362,8 @@ def test_optimizer(
 def _test_optimizer(
     num_peers: int = 1,
     num_clients: int = 0,
+    default_batch_size: int = 4,
+    default_batch_time: float = 0.1,
     target_batch_size: int = 32,
     total_epochs: int = 3,
     use_local_updates: bool = False,
@@ -422,20 +427,21 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
 
             prev_time = time.perf_counter()
 
-        time.sleep(1.0)
         optimizer.shutdown()
         return optimizer
 
     peers = []
 
     for index in range(num_peers):
+        peer_batch_size = default_batch_size + index
+        peer_batch_time = default_batch_time + 0.01 * index
         peers.append(
             mp.Process(
                 target=run_trainer,
                 name=f"trainer-{index}",
                 kwargs=dict(
-                    batch_size=4 + index,
-                    batch_time=0.3 + 0.2 * index,
+                    batch_size=peer_batch_size,
+                    batch_time=peer_batch_time,
                     client_mode=(index >= num_peers - num_clients),
                 ),
             )
@@ -451,7 +457,12 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
     assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
     expected_samples_accumulated = target_batch_size * total_epochs
     assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
-    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
+    expected_performance = default_batch_size / default_batch_time
+    assert (
+        expected_performance * 0.8
+        <= optimizer.tracker.performance_ema.samples_per_second
+        <= expected_performance * 1.2
+    )
 
     assert not optimizer.state_averager.is_alive()
     assert not optimizer.tracker.is_alive()
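The rewritten throughput assertion simply checks the EMA against peer 0's nominal rate with ±20% slack: `default_batch_size / default_batch_time` = 4 / 0.1 = 40 samples/s, so the accepted band is 32 to 48 samples/s. The same check in isolation, with a hypothetical EMA reading standing in for `performance_ema.samples_per_second`:

# Worked example of the bound, using the new defaults from this diff.
default_batch_size, default_batch_time = 4, 0.1
expected_performance = default_batch_size / default_batch_time  # 40.0 samples/s

measured = 37.5  # hypothetical performance_ema.samples_per_second reading
assert expected_performance * 0.8 <= measured <= expected_performance * 1.2  # 32.0 .. 48.0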