forked from learning-at-home/hivemind
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path utils.py
28 lines (19 loc) · 956 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import Dict, List, Tuple
from pydantic.v1 import BaseModel, StrictFloat, confloat, conint
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.validation import RecordValidatorBase
from hivemind.utils.logging import get_logger
# Module-level logger obtained from hivemind's ``get_logger``, named after this module.
# NOTE(review): nothing in the code shown here uses it; presumably kept for importers — confirm before removing.
logger = get_logger(__name__)
# Shared constrained type: an ``int`` that must be >= 0 and is never coerced
# from other types (pydantic v1 ``strict=True``).
_NonNegativeStrictInt = conint(ge=0, strict=True)


class LocalMetrics(BaseModel):
    """Metrics reported by one participant; the value type of ``MetricSchema.metrics``.

    All fields use pydantic v1 strict validation, so a value of the wrong Python
    type is rejected rather than converted (e.g. ``"3"`` is not accepted for an
    int field, and an ``int`` is not accepted for a strict float field).

    NOTE(review): the meaning of each counter below is inferred from its name only;
    confirm against the code that fills these values in.
    """

    # Presumably the participant's current optimizer step; must be >= 0.
    step: _NonNegativeStrictInt
    # Throughput figure; a strict float that must be >= 0.0.
    samples_per_second: confloat(ge=0.0, strict=True)
    # Presumably samples gathered since the last step; must be >= 0.
    samples_accumulated: _NonNegativeStrictInt
    # Any strict float; no range constraint is applied.
    loss: StrictFloat
    # Presumably micro-batches processed since the last step; must be >= 0.
    mini_steps: _NonNegativeStrictInt
# Maps a DHT subkey to the metrics stored under it.
# NOTE(review): ``BytesWithPublicKey`` presumably requires the subkey to embed the
# writer's public key (hivemind convention) — confirm in hivemind.dht.schema.
_SignedMetricsMap = Dict[BytesWithPublicKey, LocalMetrics]


class MetricSchema(BaseModel):
    """Pydantic schema checked by the ``SchemaValidator`` that ``make_validators`` builds.

    It declares a single field, ``metrics``: a dictionary whose keys are
    ``BytesWithPublicKey`` values and whose values are ``LocalMetrics`` records.
    """

    metrics: _SignedMetricsMap
def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
    """Create the DHT record validators for the run identified by ``run_id``.

    :param run_id: passed as ``prefix`` to the ``SchemaValidator`` that checks
        records against ``MetricSchema``.
    :returns: a pair ``(validators, public_key)``. ``validators`` is a list holding
        the schema validator followed by a new ``RSASignatureValidator`` (built with
        default arguments). ``public_key`` is that signature validator's
        ``local_public_key``.
    """
    # Build the signature validator first, matching the original construction order.
    signer = RSASignatureValidator()
    schema_check = SchemaValidator(MetricSchema, prefix=run_id)
    # NOTE(review): list order is kept as-is; check whether consumers depend on it.
    return [schema_check, signer], signer.local_public_key