
Commit 6621bb5

Implement CenteredClip in averager
1 parent: b84f62b

5 files changed: +126 -11

Diff for: hivemind/averaging/accumulators.py

+107
@@ -0,0 +1,107 @@
import dataclasses
from abc import ABC
from typing import Callable, Optional

import torch


class AccumulatorBase(ABC):
    def accumulate_part(self, tensor: torch.Tensor, weight: float) -> None:
        ...

    def reduce(self) -> torch.Tensor:
        ...


AccumulatorFactory = Callable[[torch.Size, int], AccumulatorBase]


class MeanAccumulator(AccumulatorBase):
    def __init__(self, part_shape: torch.Size, _n_peers: int):
        self._accumulator = torch.zeros(part_shape)
        self._denominator = 0.0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._accumulator.add_(tensor_part, alpha=weight)
        self._denominator += weight

    def reduce(self) -> torch.Tensor:
        return self._accumulator.div_(self._denominator)

class CenteredClipAccumulator(AccumulatorBase):
    def __init__(self, part_shape: torch.Size, n_peers: int, **kwargs):
        self._kwargs = kwargs

        # Store each peer's part until reduce() is called (torch.Size is a tuple, so unpack it)
        self._tensors = torch.empty((n_peers, *part_shape))
        self._weights = torch.empty(n_peers)
        self._index = 0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._tensors[self._index] = tensor_part
        self._weights[self._index] = weight
        self._index += 1

    def reduce(self) -> torch.Tensor:
        clipped = centered_clip(self._tensors, self._weights, **self._kwargs)
        return clipped.result

@dataclasses.dataclass(frozen=True)
class CenteredClipResult:
    result: torch.Tensor
    n_clipped: torch.Tensor
    last_step_delta: torch.Tensor


def centered_clip(input_tensors: torch.Tensor, weights: torch.Tensor,
                  tau: float = 1.0, n_iters: int = 20, stop_delta: Optional[float] = None) -> CenteredClipResult:
    """
    Optimized implementation of CenteredClip from [Karimireddy, 2021].
    Intended to be used in a decentralized fashion as in [Gorbunov, 2021].

    :param stop_delta: Stop iterations early if the ``L_inf`` norm of the last step is less than ``stop_delta``.
        Note: if this option is used, the step norm calculations may increase the time per iteration by ~25%.

    References:

    [Karimireddy, 2021] Karimireddy, Sai Praneeth, Lie He, and Martin Jaggi. "Learning from History for Byzantine
        Robust Optimization." International Conference on Machine Learning. PMLR, 2021.

    [Gorbunov, 2021] Gorbunov, Eduard, Alexander Borzunov, Michael Diskin, and Max Ryabinin.
        "Secure Distributed Training at Scale." arXiv preprint arXiv:2106.11257 (2021).
    """

    with torch.no_grad():
        n_peers = input_tensors.shape[0]
        result_shape = input_tensors.shape[1:]

        input_tensors = input_tensors.flatten(start_dim=1)
        weights /= weights.sum()

        # This finds medians faster than torch.median() and torch.quantile(q=0.5),
        # see https://github.com/pytorch/pytorch/issues/51450
        sorted_tensors = input_tensors.sort(dim=0).values
        result = sorted_tensors[n_peers // 2].clone()
        delta = None

        diff = torch.sub(input_tensors, result, out=sorted_tensors)  # Reuse memory from `sorted_tensors`
        for _ in range(n_iters):
            norms = diff.norm(dim=1)
            coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)

            if stop_delta is not None:
                # Snapshot the current `diff[0]` into `result` to reuse its memory
                prev_diff = result
                prev_diff[...] = diff[0]

            # We only need to update `diff` (not `result`) between iterations
            diff.addmm_(-coeffs.repeat(n_peers, 1), diff)

            if stop_delta is not None:
                delta = prev_diff.sub_(diff[0]).max()
                if delta < stop_delta:
                    break
        torch.sub(input_tensors[0], diff[0], out=result)

    return CenteredClipResult(result=result.reshape(result_shape),
                              n_clipped=(tau < norms).sum(),
                              last_step_delta=delta)
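
For orientation (not part of the diff): each CenteredClip iteration moves the current estimate v by sum_i w_i * min(1, tau / ||x_i - v||) * (x_i - v), so peers whose parts lie farther than tau from the estimate contribute a clipped step. Below is a minimal sketch of calling the new routines directly; the tensor sizes, tau, and the outlier construction are illustrative choices, not values from the commit.

    import torch

    from hivemind.averaging.accumulators import CenteredClipAccumulator, centered_clip

    n_peers, part_shape = 8, torch.Size([16, 16])

    # Simulated tensor parts from 8 peers; the last peer sends an extreme outlier
    parts = [torch.randn(part_shape) for _ in range(n_peers - 1)] + [torch.full(part_shape, 1000.0)]
    weights = torch.ones(n_peers)

    # Direct call: centered_clip normalizes `weights` in place, hence the clone
    clipped = centered_clip(torch.stack(parts), weights.clone(), tau=1.0, n_iters=20)
    print(clipped.result.shape)    # torch.Size([16, 16])
    print(int(clipped.n_clipped))  # how many peers were clipped on the last iteration

    # The same computation through the accumulator interface used by TensorPartReducer
    accumulator = CenteredClipAccumulator(part_shape, n_peers, tau=1.0)
    for part, weight in zip(parts, weights):
        accumulator.accumulate_part(part, float(weight))
    robust_average = accumulator.reduce()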

Diff for: hivemind/averaging/allreduce.py

+4-1
@@ -4,6 +4,7 @@
 
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
@@ -58,6 +59,7 @@ def __init__(
         tensors: Sequence[torch.Tensor],
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
+        accumulator_factory: AccumulatorFactory,
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
@@ -97,7 +99,8 @@ def __init__(
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
+            weights=self.sender_weights,
+            accumulator_factory=accumulator_factory,
         )
 
     def __repr__(self):

Diff for: hivemind/averaging/averager.py

+3
@@ -15,6 +15,7 @@
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory, MeanAccumulator
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
@@ -112,6 +113,7 @@ def __init__(
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        accumulator_factory: AccumulatorFactory = MeanAccumulator,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -170,6 +172,7 @@ def __init__(
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            accumulator_factory=accumulator_factory,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
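
The averager keeps MeanAccumulator as the default, so existing behaviour is unchanged; accumulator_factory only needs to be a callable matching AccumulatorFactory, i.e. (part_shape, n_peers) -> AccumulatorBase. A sketch of plugging in the robust accumulator follows; binding tau via functools.partial is an illustration, not something this commit does.

    from functools import partial

    import torch

    from hivemind.averaging.accumulators import CenteredClipAccumulator

    # Bind CenteredClip hyperparameters up front so the partial still matches
    # AccumulatorFactory = Callable[[torch.Size, int], AccumulatorBase]
    robust_factory = partial(CenteredClipAccumulator, tau=0.5, n_iters=20)

    # TensorPartReducer calls the factory like this for every tensor part it reduces
    accumulator = robust_factory(torch.Size([1024]), 16)

    # When constructing the averager, the factory is passed as a keyword argument:
    #     DecentralizedAverager(..., accumulator_factory=robust_factory)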

Diff for: hivemind/averaging/partition.py

+9-9
@@ -8,6 +8,7 @@
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
@@ -171,16 +172,17 @@ class TensorPartReducer:
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int,
+                 *, weights: Optional[Sequence[float]], accumulator_factory: AccumulatorFactory):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.weights = tuple(weights or (1 for _ in range(num_senders)))
         assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
         assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator = None  # this will contain the sum of current tensor part from group peers
-        self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
+        self.accumulator_factory = accumulator_factory
+        self.accumulator = None
         self.finished = asyncio.Event()
         self.reset_accumulators()
 
@@ -194,8 +196,7 @@ def reset_accumulators(self):
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
-        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
-        self.denominator = 0.0
+        self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)
 
     async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
@@ -211,21 +212,20 @@ async def accumulate_part(self, sender_index: int, part_index: int, tensor_part:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
         self.current_part_accumulated_from += 1
 
         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            current_part_future.set_result(self.accumulator.reduce())
            self.reset_accumulators()
         return await current_part_future
 
     def finalize(self):
         if not self.finished.is_set():
             if hasattr(self, "current_part_future"):
                 self.current_part_future.cancel()
-                del self.accumulator
+                self.accumulator = None
             self.finished.set()
 
     def __del__(self):
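
With this change the reducer no longer keeps its own running sum and denominator; MeanAccumulator reproduces the old weighted-mean behaviour behind the new interface. A quick self-contained check (the shapes and weights below are arbitrary):

    import torch

    from hivemind.averaging.accumulators import MeanAccumulator

    parts = [torch.ones(3), torch.full((3,), 4.0)]
    weights = [1.0, 2.0]

    accumulator = MeanAccumulator(torch.Size([3]), len(parts))
    for part, weight in zip(parts, weights):
        accumulator.accumulate_part(part, weight)

    # Weighted mean: (1.0 * 1 + 2.0 * 4) / (1.0 + 2.0) == 3.0, same as the removed add_/div_ logic
    assert torch.allclose(accumulator.reduce(), torch.full((3,), 3.0))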

Diff for: tests/test_allreduce.py

+3-1
@@ -7,6 +7,7 @@
 import torch
 
 from hivemind import Quantile8BitQuantization, aenumerate
+from hivemind.averaging.accumulators import MeanAccumulator
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor
@@ -119,7 +120,7 @@ async def wait_synchronously():
 @pytest.mark.asyncio
 async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
     tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
-    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders, weights=None, accumulator_factory=MeanAccumulator)
 
     local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)] for j in range(num_senders)]
 
@@ -196,6 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
+            accumulator_factory=MeanAccumulator,
             modes=peer_modes,
             weights=averaging_weights,
             part_size_bytes=part_size_bytes,
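
The updated tests keep pinning the reducer and the all-reduce runner to MeanAccumulator, so they only cover the mean path. A hypothetical follow-up (not in this commit) could exercise CenteredClip by swapping the factory in test_reducer, for example:

    from functools import partial

    from hivemind.averaging.accumulators import CenteredClipAccumulator

    # Drop-in replacement for the `reducer = ...` line inside test_reducer above;
    # tau is chosen arbitrarily, and the mean-based assertions may need loosening.
    reducer = TensorPartReducer(
        tensor_part_shapes,
        num_senders,
        weights=None,
        accumulator_factory=partial(CenteredClipAccumulator, tau=10.0),
    )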
