-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdl_job.py
275 lines (236 loc) · 9.51 KB
/
dl_job.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import json
import logging
import os
import time
from abc import ABC, abstractmethod

from kubernetes import client

import settings.settings as settings
from replica import Replica
logger = logging.getLogger(os.path.basename(__file__))
class DLJob(ABC):
    """
    Abstract base class defining a distributed training job.

    Subclasses must set the ``job_type`` and ``replica_types`` class
    variables and implement :meth:`get_environment_variables`.
    Constructing an instance immediately cleans up any leftover pods and
    services from a previous job with the same name, then creates and
    reconciles the replicas.
    """
    # TODO: How to force redefinition of these two variables?
    job_type = None
    replica_types = list()
    # Default keyword arguments forwarded to every Replica (e.g. ports, volumes).
    container_properties = {}

    def __init__(self, name, spec):
        """
        :param name: Job name; used to label and select Kubernetes resources.
        :param spec: Replica specifications. NOTE(review): ``create_replicas``
            and ``number_of_replicas`` iterate ``self.spec`` directly and read
            ``replicaType``/``replicas``/``template`` keys, while
            ``validate_spec``/``get_replica_spec`` iterate
            ``self.spec['replicas']`` — the expected shape is ambiguous;
            confirm against callers.
        :raises Exception: if the subclass did not define ``job_type`` and a
            non-empty ``replica_types`` list.
        """
        # Check the class was properly subclassed.
        if self.job_type is None or \
                not isinstance(self.replica_types, list) or \
                len(self.replica_types) == 0:
            raise Exception("Need to subclass DLJob with custom type and define job_type class variable")
        self.job_name = name
        self.spec = spec
        self.replicas = list()
        # TODO: Initialize this in one single place
        self.core_v1_client = client.CoreV1Api()
        # Remove leftovers of any previous job with the same name to avoid
        # name collisions, then create fresh resources.
        self.clean_up_pods()
        self.clean_up_services()
        self.create_replicas()
        self.reconcile()

    def create_replicas(self):
        """Instantiate a :class:`Replica` for every requested replica in the spec."""
        logger.info("create replicas")
        for replica_spec in self.spec:
            # Create n replicas for this type of replica.
            for i in range(replica_spec['replicas']):
                # Copy the class-level defaults before adding 'env': assigning
                # into self.container_properties directly would mutate the
                # shared class dict across instances and subclasses.
                props = dict(self.container_properties)
                props['env'] = self.get_environment_variables(replica_spec, i)
                self.replicas.append(Replica(
                    uid=i,
                    replica_name=self.generate_replica_name(replica_spec["replicaType"].casefold(), self.job_name, i),
                    replica_type=replica_spec["replicaType"],
                    job_name=self.job_name,
                    template=replica_spec["template"],
                    # Arbitrary container-specific parameters (ports, volumes, env, ...).
                    **props
                ))

    def validate_spec(self):
        """
        Validate and normalize replica types in the spec.

        Invalid types are logged as errors (not raised); every
        ``replica_type`` is lower-cased in place via ``casefold``.
        """
        for replica_spec in self.spec['replicas']:
            if replica_spec['replica_type'].casefold() not in self.replica_types:
                logger.error(
                    f"Job {self.job_name} has invalid spec. {replica_spec['replica_type']} "
                    f"replica type not valid.")
            # Standardize string casing for later comparisons.
            replica_spec['replica_type'] = replica_spec['replica_type'].casefold()

    @abstractmethod
    def get_environment_variables(self, replica_spec, replica_index):
        """
        Return env variables to be set for a specific replica
        :param replica_spec: The spec of the replica to be crated
        :param replica_index: The uid of the replica to be created,
        based on how many replicas need to be created from a template
        (see the `replicas` field of the spec)
        :return: A dictionary of env variables
        """
        pass

    def reconcile(self):
        """Drive every replica toward its desired state."""
        logger.info(f"Reconcile Job {self.job_name}")
        for r in self.replicas:
            r.reconcile()

    def number_of_replicas(self, replica_type):
        """
        Return the replica count for ``replica_type`` (exact string match),
        or ``None`` (after logging an error) if not present in the spec.
        """
        for r in self.spec:
            if r["replicaType"] == replica_type:
                return r["replicas"]
        logger.error(f"Replica {replica_type} not found")

    def get_replica_spec(self, replica_type):
        """Return the spec entry for ``replica_type``, or ``None`` if absent."""
        for s in self.spec['replicas']:
            if s['replicaType'] == replica_type:
                return s

    def clean_up(self):
        """Tear down all replicas owned by this job."""
        for r in self.replicas:
            r.clean_up()

    def clean_up_pods(self):
        """
        Delete pods that match the selection job_name=self.job_name.
        This function is called every time a new job is created to delete any previous pod
        associated to the same job name. This avoid collisions and unpleasant behaviors.
        """
        api_instance = client.CoreV1Api()
        logger.info(f"Deleting pods matching {self.job_name} job name")
        selector = f"job_name={self.job_name}"
        pod_list = api_instance.list_namespaced_pod(settings.NAMESPACE, label_selector=selector)
        if len(pod_list.items) == 0:
            return
        logger.info(f"Deleting {len(pod_list.items)} pods.")
        pod_names = list()
        for pod in pod_list.items:
            logger.info(f"Deleting pod {pod.metadata.name}.")
            pod_names.append(pod.metadata.name)
            api_instance.delete_namespaced_pod(
                pod.metadata.name,
                settings.NAMESPACE,
                body=client.V1DeleteOptions())
        # Wait for the resources to be deleted; sleep between polls so we do
        # not busy-spin the CPU and hammer the API server.
        # NOTE(review): this waits on a `pod_name` label — assumes the pods
        # carry such a label; verify against the pod templates.
        while len(api_instance.list_namespaced_pod(settings.NAMESPACE,
                                                   label_selector=f"pod_name in ({', '.join(pod_names)})").items):
            time.sleep(1)

    def clean_up_services(self):
        """
        Delete services that match the selection job_name=self.job_name.
        This function is called every time a new job is created to delete any previous pod
        associated to the same job name. This avoid collisions and unpleasant behaviors.
        """
        api_instance = client.CoreV1Api()
        logger.info(f"Deleting services matching {self.job_name} job name")
        selector = f"job_name={self.job_name}"
        service_list = api_instance.list_namespaced_service(settings.NAMESPACE, label_selector=selector)
        if len(service_list.items) == 0:
            return
        logger.info(f"Deleting {len(service_list.items)} services.")
        service_names = list()
        for service in service_list.items:
            logger.info(f"Deleting service {service.metadata.name}.")
            service_names.append(service.metadata.name)
            api_instance.delete_namespaced_service(
                service.metadata.name,
                settings.NAMESPACE,
                body=client.V1DeleteOptions())
        # Wait for the resources to be deleted; sleep between polls so we do
        # not busy-spin the CPU and hammer the API server.
        # NOTE(review): this waits on a `service_name` label — assumes the
        # services carry such a label; verify against the service templates.
        while len(api_instance.list_namespaced_service(settings.NAMESPACE,
                                                       label_selector=f"service_name in ({', '.join(service_names)})").items):
            time.sleep(1)

    @staticmethod
    def generate_replica_name(*args):
        """Join all arguments with '-' and lower-case the result (casefold)."""
        return "-".join(map(str, list(args))).casefold()
class MXJob(DLJob):
    """Distributed MXNet training job: scheduler, server and worker replicas."""
    job_type = "MXJob"
    replica_types = ['scheduler', 'server', 'worker']
    container_properties = {
        "ports": [9000],
        "volumes": []
    }
    # Port the parameter-server root (scheduler) listens on.
    mx_port = 9000

    def __init__(self, name, spec):
        super().__init__(name, spec)

    def validate_spec(self):
        super().validate_spec()

    def get_environment_variables(self, replica_spec, replica_index, scheduler_ip=""):
        """
        Build the DMLC environment variables for one replica.

        :param replica_spec: Spec dict for the replica; its ``replicaType``
            becomes ``DMLC_ROLE`` (lower-cased).
        :param replica_index: Index of the replica (unused here, required by
            the base-class interface).
        :param scheduler_ip: IP of the scheduler; may be empty at creation
            time and patched later in :meth:`reconcile`.
        :return: Dict of env variables; all values are strings because the
            Kubernetes client does not auto-convert non-str values.
        """
        env = {
            "DMLC_ROLE": replica_spec['replicaType'].casefold(),
            # Patched in reconcile() once the scheduler IP is known.
            "DMLC_PS_ROOT_URI": scheduler_ip,
            # Bug fix: must be a string — the client does not auto-convert ints.
            "DMLC_PS_ROOT_PORT": str(self.mx_port),
            # Auto conversion to str not supported by client for safety reasons.
            # NOTE(review): lookups use upper-case "SERVER"/"WORKER" while
            # replica_types are lower-case — confirm the spec's casing.
            "DMLC_NUM_SERVER": str(self.number_of_replicas("SERVER")),
            "DMLC_NUM_WORKER": str(self.number_of_replicas("WORKER")),
            "PS_VERBOSE": "2"
        }
        return env

    def reconcile(self):
        """Reconcile replicas, propagating the scheduler IP to the others."""
        logger.info(f"MXJob: Reconcile Job {self.job_name}")
        scheduler_ip = ""
        # NOTE(review): workers/servers only get a non-empty root URI if the
        # scheduler replica appears before them in self.replicas — confirm
        # the ordering guarantee of create_replicas / the spec.
        for r in self.replicas:
            # TODO: Need to remove this once we solve the hostname resolve issue
            if r.replica_type != "SCHEDULER":
                r.container_params['env']['DMLC_PS_ROOT_URI'] = scheduler_ip
            r.reconcile()
            if r.replica_type == "SCHEDULER":
                scheduler_ip = r.scheduler_ip
class TFJob(DLJob):
    """Distributed TensorFlow training job: ps and worker replicas."""
    job_type = "TFJob"
    replica_types = ['ps', 'worker']
    container_properties = {
        "ports": [2222]
    }
    # Port every TF replica listens on.
    tf_port = 2222

    def __init__(self, name, spec):
        super().__init__(name, spec)

    def validate_spec(self):
        super().validate_spec()
        # TODO: Validate the number of identical replicaTypes = 1 (e.g. Just one PS specification)

    def get_cluster_config(self):
        """
        Build the TF_CONFIG ``cluster`` mapping with one unique address per
        replica of each role.

        :return: ``{'ps': [...], 'worker': [...]}`` with "host:port" entries.
        """
        # Bug fix: TF_CONFIG cluster entries must be "host:port" — the
        # original used '.' as the separator ("ps0.2222"), which TensorFlow
        # cannot parse as an address.
        # NOTE(review): number_of_replicas matches replicaType exactly, so
        # this assumes the spec uses lower-case "ps"/"worker" — confirm.
        cluster_config = {
            role: [
                f"{role}{i}:{self.tf_port}"
                for i in range(self.number_of_replicas(role))
            ]
            for role in ("ps", "worker")
        }
        return cluster_config

    def get_environment_variables(self, replica_spec, replica_index):
        """
        Build the TF_CONFIG env variable for one replica.

        :param replica_spec: Spec dict; its ``replicaType`` becomes the task type.
        :param replica_index: Index of this replica within its role.
        :return: Dict with a single ``TF_CONFIG`` key holding a JSON string.
        """
        tf_config = {
            'cluster': self.get_cluster_config(),
            'task': {
                "type": replica_spec['replicaType'],
                "index": replica_index
            }
        }
        env = {
            "TF_CONFIG": json.dumps(tf_config)
        }
        return env