
Commit b84f62b

borzunov and mryab authored
Make log handlers configurable, shorten entries (#378)
1. Fix bugs: make `get_logger()` idempotent and don't trim the actual logger name.
2. Allow a developer to choose where the default hivemind log handler is enabled (in hivemind, in the root logger, or nowhere).
3. Enable the `in_root_logger` mode in `examples/albert`, so that all messages (from `__main__`, `transformers`, and `hivemind` itself) consistently follow the hivemind style.
4. Change some log messages to improve their presentation.

Co-authored-by: Max Ryabinin <[email protected]>
1 parent fb3f57b commit b84f62b
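As a quick orientation before the diffs, here is a minimal usage sketch of the API this commit introduces (it mirrors what the `examples/albert` scripts below now do); the module-level `logger` variable and the log message are just examples:

from hivemind.utils.logging import get_logger, use_hivemind_log_handler

# Attach the hivemind handler to the root logger, so messages from __main__,
# transformers, and hivemind itself all follow the hivemind log format.
use_hivemind_log_handler("in_root_logger")

# get_logger() is now idempotent: repeated calls reuse the already-initialized
# handler, and the logger name is no longer trimmed.
logger = get_logger(__name__)
logger.info("Hello from the hivemind log handler")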

File tree: 5 files changed, +116 -50 lines

  benchmarks/benchmark_averaging.py
  examples/albert/run_trainer.py
  examples/albert/run_training_monitor.py
  hivemind/optim/collaborative.py
  hivemind/utils/logging.py

Diff for: benchmarks/benchmark_averaging.py (+1 -1)

@@ -80,7 +80,7 @@ def run_averager(index):
             with lock_stats:
                 successful_steps += int(success)
                 total_steps += 1
-            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
+            logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
         logger.info(f"Averager {index}: done.")

     threads = []

Diff for: examples/albert/run_trainer.py (+11 -21)

@@ -1,6 +1,5 @@
 #!/usr/bin/env python

-import logging
 import os
 import pickle
 from dataclasses import asdict
@@ -18,32 +17,22 @@
 from transformers.trainer_utils import is_main_process

 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler

 import utils
 from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments

-logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()

+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

-def setup_logging(training_args):
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
-    )

-    # Log on each process the small summary:
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+def setup_transformers_logging(process_rank: int):
+    if is_main_process(process_rank):
         transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
-        logger.info("Training/evaluation parameters %s", training_args)
+        transformers.utils.logging.disable_default_handler()
+        transformers.utils.logging.enable_propagation()


 def get_model(training_args, config, tokenizer):
@@ -149,7 +138,7 @@ def on_step_end(
                 loss=self.loss,
                 mini_steps=self.steps,
             )
-            logger.info(f"Step {self.collaborative_optimizer.local_step}")
+            logger.info(f"Step #{self.collaborative_optimizer.local_step}")
             logger.info(f"Your current contribution: {self.total_samples_processed} samples")
             logger.info(f"Performance: {samples_per_second} samples per second.")
             if self.steps:
@@ -220,7 +209,8 @@ def main():
     if len(collaboration_args.initial_peers) == 0:
         raise ValueError("Please specify at least one network endpoint in initial peers.")

-    setup_logging(training_args)
+    setup_transformers_logging(training_args.local_rank)
+    logger.info(f"Training/evaluation parameters:\n{training_args}")

     # Set seed before initializing model.
     set_seed(training_args.seed)
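To make the trainer-side change concrete, here is a hedged sketch of the resulting call order in `main()`; `training_args` stands in for the parsed `AlbertTrainingArguments`, and the comments describe the intended effect of the diff above:

use_hivemind_log_handler("in_root_logger")            # hivemind handler now lives on the root logger
logger = get_logger()                                 # root logger, formatted in the hivemind style

setup_transformers_logging(training_args.local_rank)  # transformers drops its own handler and
                                                      # propagates its records to the root logger
logger.info(f"Training/evaluation parameters:\n{training_args}")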

Diff for: examples/albert/run_training_monitor.py (+4 -3)

@@ -1,6 +1,5 @@
 #!/usr/bin/env python

-import logging
 import time
 from dataclasses import asdict, dataclass, field
 from ipaddress import ip_address
@@ -13,11 +12,13 @@
 from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser

 import hivemind
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler

 import utils
 from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments

-logger = logging.getLogger(__name__)
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger()


 @dataclass
@@ -139,7 +140,7 @@ def upload_checkpoint(self, current_loss):
         self.model.push_to_hub(
             repo_name=self.repo_path,
             repo_url=self.repo_url,
-            commit_message=f"Step {current_step}, loss {current_loss:.3f}",
+            commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
         )
         logger.info("Finished uploading to Model Hub")


Diff for: hivemind/optim/collaborative.py (+10 -10)

@@ -153,7 +153,7 @@ def __init__(
         self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
         self.last_step_time = None

-        self.collaboration_state = self.fetch_collaboration_state()
+        self.collaboration_state = self._fetch_state()
         self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
         self.lock_local_progress, self.should_report_progress = Lock(), Event()
         self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
@@ -237,8 +237,8 @@ def step(self, batch_size: Optional[int] = None, **kwargs):
         if not self.collaboration_state.ready_for_step:
             return

-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()

         if not self.is_synchronized:
@@ -288,8 +288,8 @@ def step_aux(self, **kwargs):
         if not self.collaboration_state.ready_for_step:
             return

-        logger.log(self.status_loglevel, f"Beginning global optimizer step {self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self.fetch_collaboration_state()
+        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
+        self.collaboration_state = self._fetch_state()
         self.collaboration_state_updated.set()

         with self.lock_collaboration_state:
@@ -392,9 +392,9 @@ def check_collaboration_state_periodically(self):
                 continue  # if state was updated externally, reset timer

             with self.lock_collaboration_state:
-                self.collaboration_state = self.fetch_collaboration_state()
+                self.collaboration_state = self._fetch_state()

-    def fetch_collaboration_state(self) -> CollaborationState:
+    def _fetch_state(self) -> CollaborationState:
         """Read performance statistics reported by peers, estimate progress towards next batch"""
         response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         current_time = get_dht_time()
@@ -452,9 +452,9 @@ def fetch_collaboration_state(self) -> CollaborationState:
         )
         logger.log(
             self.status_loglevel,
-            f"Collaboration accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers; ETA {estimated_time_to_next_step:.2f} seconds "
-            f"(refresh in {time_to_next_fetch:.2f}s.)",
+            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
+            f"{num_peers} peers for step #{global_optimizer_step}. "
+            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return CollaborationState(
             global_optimizer_step,

Diff for: hivemind/utils/logging.py (+90 -15)

@@ -1,6 +1,11 @@
 import logging
 import os
 import sys
+import threading
+from enum import Enum
+from typing import Optional, Union
+
+logging.addLevelName(logging.WARNING, "WARN")

 loglevel = os.getenv("LOGLEVEL", "INFO")

@@ -11,6 +16,17 @@
 use_colors = sys.stderr.isatty()


+class HandlerMode(Enum):
+    NOWHERE = 0
+    IN_HIVEMIND = 1
+    IN_ROOT_LOGGER = 2
+
+
+_init_lock = threading.RLock()
+_current_mode = HandlerMode.IN_HIVEMIND
+_default_handler = None
+
+
 class TextStyle:
     """
     ANSI escape codes. Details: https://en.wikipedia.org/wiki/ANSI_escape_code#Colors
@@ -60,23 +76,82 @@ def format(self, record: logging.LogRecord) -> str:
         return super().format(record)


-def get_logger(module_name: str) -> logging.Logger:
-    # trim package name
-    name_without_prefix = ".".join(module_name.split(".")[1:])
+def _initialize_if_necessary():
+    global _current_mode, _default_handler

-    logging.addLevelName(logging.WARNING, "WARN")
-    formatter = CustomFormatter(
-        fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
-        style="{",
-        datefmt="%b %d %H:%M:%S",
-    )
-    handler = logging.StreamHandler()
-    handler.setFormatter(formatter)
-    logger = logging.getLogger(name_without_prefix)
-    logger.setLevel(loglevel)
-    logger.addHandler(handler)
+    with _init_lock:
+        if _default_handler is not None:
+            return
+
+        formatter = CustomFormatter(
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
+            style="{",
+            datefmt="%b %d %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler()
+        _default_handler.setFormatter(formatter)
+
+        _enable_default_handler("hivemind")
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """
+    Same as ``logging.getLogger()`` but ensures that the default log handler is initialized.
+    """
+
+    _initialize_if_necessary()
+    return logging.getLogger(name)
+
+
+def _enable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.addHandler(_default_handler)
     logger.propagate = False
-    return logger
+    logger.setLevel(loglevel)
+
+
+def _disable_default_handler(name: str) -> None:
+    logger = get_logger(name)
+    logger.removeHandler(_default_handler)
+    logger.propagate = True
+    logger.setLevel(logging.NOTSET)
+
+
+def use_hivemind_log_handler(where: Union[HandlerMode, str]) -> None:
+    """
+    Choose loggers where the default hivemind log handler is applied. Options for the ``where`` argument are:
+
+    * "in_hivemind" (default): Use the hivemind log handler in the loggers of the ``hivemind`` package.
+      Don't propagate their messages to the root logger.
+    * "nowhere": Don't use the hivemind log handler anywhere.
+      Propagate the ``hivemind`` messages to the root logger.
+    * "in_root_logger": Use the hivemind log handler in the root logger
+      (that is, in all application loggers until they disable propagation to the root logger).
+      Propagate the ``hivemind`` messages to the root logger.
+
+    The options may be defined as strings (case-insensitive) or values from the HandlerMode enum.
+    """
+
+    global _current_mode
+
+    if isinstance(where, str):
+        # We allow `where` to be a string, so a developer does not have to import the enum for one usage
+        where = HandlerMode[where.upper()]
+
+    if where == _current_mode:
+        return
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _disable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _disable_default_handler(None)
+
+    _current_mode = where
+
+    if _current_mode == HandlerMode.IN_HIVEMIND:
+        _enable_default_handler("hivemind")
+    elif _current_mode == HandlerMode.IN_ROOT_LOGGER:
+        _enable_default_handler(None)


 def golog_level_to_python(level: str) -> int:
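Finally, a short sketch of how the new handler modes behave from an application's point of view; the "my_app" logger name is an arbitrary example, and the comments summarize the docstring above:

import logging

from hivemind.utils.logging import HandlerMode, get_logger, use_hivemind_log_handler

# Default mode ("in_hivemind"): the handler is attached to the "hivemind" logger only,
# and hivemind records do not propagate to the root logger.
hivemind_logger = get_logger("hivemind.optim.collaborative")

# Strings are case-insensitive aliases for HandlerMode values.
use_hivemind_log_handler("in_root_logger")    # same as use_hivemind_log_handler(HandlerMode.IN_ROOT_LOGGER)
app_logger = logging.getLogger("my_app")      # arbitrary application logger
app_logger.info("Formatted by the hivemind handler via the root logger")

use_hivemind_log_handler("nowhere")           # detach the handler everywhere; hivemind records
                                              # now propagate to whatever the application configures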
