Skip to content

Commit fdd1bb9

Browse files
committed
fix nng perf hostname error leading to deadlock
1 parent a7a57a6 commit fdd1bb9

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

.coveragerc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
 [run]
-concurrency = multiprocessing
+concurrency = multiprocessing,thread
 omit =
     ding/utils/slurm_helper.py
     ding/utils/file_helper.py
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
-from ding.framework.message_queue.perfs.perf_nng import nng_perf_main
 import multiprocessing as mp
 import pytest
+import socket
+import torch
+from ding.framework.message_queue.perfs.perf_nng import nng_perf_main


 @pytest.mark.benchmark
 # @pytest.mark.multiprocesstest
 def test_nng():
-    params = [
-        ("12960", None, "127.0.0.1", "learner", "0"), ("12961", "tcp://127.0.0.1:12960", "127.0.0.1", "collector", "1")
-    ]
-    ctx = mp.get_context("spawn")
-    with ctx.Pool(processes=2) as pool:
-        pool.starmap(nng_perf_main, params)
-        pool.close()
-        pool.join()
+    if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
+        address = socket.gethostbyname(socket.gethostname())
+        params = [
+            ("12960", None, address, "learner", "0"),
+            ("12961", "tcp://{}:12960".format(address), "127.0.0.1", "collector", "1")
+        ]
+        ctx = mp.get_context("spawn")
+        with ctx.Pool(processes=2) as pool:
+            pool.starmap(nng_perf_main, params)
+            pool.close()
+            pool.join()

ding/framework/message_queue/perfs/tests/test_perf_torchrpc_nccl.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
-from ding.framework.message_queue.perfs.perf_torchrpc_nccl import rpc_model_exchanger
-from ding.compatibility import torch_ge_1121
 import multiprocessing as mp
 import pytest
 import torch
 import platform
+import socket
+from ding.utils.system_helper import find_free_port
+from ding.framework.message_queue.perfs.perf_torchrpc_nccl import rpc_model_exchanger
+from ding.compatibility import torch_ge_1121


 @pytest.mark.benchmark
 @pytest.mark.cudatest
 # @pytest.mark.multiprocesstest
 def test_perf_torchrpc_nccl():
+    address = socket.gethostbyname(socket.gethostname())
+    init_method = "tcp://{}:{}".format(address, find_free_port(address))
     if platform.system().lower() != 'windows' and torch.cuda.is_available():
         if torch_ge_1121() and torch.cuda.device_count() >= 2:
-            params = [(0, "tcp://127.0.0.1:12387", False, True), (1, "tcp://127.0.0.1:12387", False, True)]
+            params = [(0, init_method, False, True), (1, init_method, False, True)]
             ctx = mp.get_context("spawn")
             with ctx.Pool(processes=2) as pool:
                 pool.starmap(rpc_model_exchanger, params)

ding/framework/message_queue/tests/test_torch_rpc.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
+import pytest
+import torch
+import platform
+import time
+import socket
+
 from ding.framework.message_queue.torch_rpc import DeviceMap, TORCHRPCMQ, DEFAULT_DEVICE_MAP_NUMS
 from torch.distributed import rpc
 from multiprocessing import Pool, get_context
 from ding.compatibility import torch_ge_1121
 from ditk import logging
-
-import pytest
-import torch
-import platform
-import time
+from ding.utils.system_helper import find_free_port

 mq = None
 recv_tensor_list = [None, None, None, None]
@@ -22,6 +24,7 @@ def torchrpc(rank):
     global mq
     global recv_tensor_list
     mq = None
+    address = socket.gethostbyname(socket.gethostname())
     recv_tensor_list = [None, None, None, None]
     logging.getLogger().setLevel(logging.DEBUG)
     name_list = ["A", "B", "C", "D"]
@@ -34,7 +37,7 @@ def torchrpc(rank):
     mq = TORCHRPCMQ(
         rpc_name=name_list[rank],
         global_rank=rank,
-        init_method="tcp://127.0.0.1:12398",
+        init_method="tcp://{}:12398".format(address),
         remote_parallel_entrance=remote_mq_entrance,
         attach_to=attach_to,
         async_rpc=False,
@@ -81,6 +84,7 @@ def torchrpc_cuda(rank):
     mq = None
     recv_tensor_list = [None, None, None, None]
     name_list = ["A", "B"]
+    address = socket.gethostbyname(socket.gethostname())
     logging.getLogger().setLevel(logging.DEBUG)

     if rank == 0:
@@ -96,7 +100,7 @@ def torchrpc_cuda(rank):
     mq = TORCHRPCMQ(
         rpc_name=name_list[rank],
         global_rank=rank,
-        init_method="tcp://127.0.0.1:12390",
+        init_method="tcp://{}:12390".format(address),
         remote_parallel_entrance=remote_mq_entrance,
         attach_to=attach_to,
         device_maps=device_map,

0 commit comments

Comments (0)