Skip to content

Commit c8f6909

Browse files
committed
create classes for submitting and watching DDP jobs
Signed-off-by: Kevin <[email protected]>
1 parent 433fd71 commit c8f6909

File tree

3 files changed

+137
-1
lines changed

3 files changed

+137
-1
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from os import stat
2222
from time import sleep
23-
from typing import List, Optional, Tuple
23+
from typing import List, Optional, Tuple, Dict
2424

2525
import openshift as oc
2626
from ray.job_submission import JobSubmissionClient
@@ -45,6 +45,8 @@ class Cluster:
4545
Note that currently, the underlying implementation is a Ray cluster.
4646
"""
4747

48+
torchx_scheduler = "ray"
49+
4850
def __init__(self, config: ClusterConfiguration):
4951
"""
5052
Create the resource cluster object by passing in a ClusterConfiguration
@@ -272,6 +274,18 @@ def job_logs(self, job_id: str) -> str:
272274
client = JobSubmissionClient(dashboard_route)
273275
return client.get_job_logs(job_id)
274276

277+
def torchx_config(
    self, working_dir: str = None, requirements: str = None
) -> Dict[str, str]:
    """
    Build the TorchX scheduler ``cfg`` dict for submitting jobs to this cluster.

    Parameters
    ----------
    working_dir : str, optional
        Included under ``"working_dir"`` only when truthy.
    requirements : str, optional
        Included under ``"requirements"`` only when truthy.

    Returns
    -------
    Dict[str, str]
        Always contains ``"cluster_name"`` and ``"dashboard_address"``.
    """
    dashboard_address = self.cluster_dashboard_uri(self.config.namespace)
    # BUG FIX: str.lstrip("http://") strips any leading run of the characters
    # {h, t, p, :, /} rather than the literal prefix — e.g. a host named
    # "test-host" would lose its leading "t". Remove the scheme prefix exactly.
    if dashboard_address.startswith("http://"):
        dashboard_address = dashboard_address[len("http://"):]
    to_return = {
        "cluster_name": self.config.name,
        "dashboard_address": dashboard_address,
    }
    if working_dir:
        to_return["working_dir"] = working_dir
    if requirements:
        to_return["requirements"] = requirements
    return to_return
288+
275289

276290
def get_current_namespace() -> str:
277291
"""

src/codeflare_sdk/job/__init__.py

Whitespace-only changes.

src/codeflare_sdk/job/jobs.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2023 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import abc
17+
from typing import TYPE_CHECKING, Optional, Dict, List
18+
19+
from torchx.components.dist import ddp
20+
from torchx.runner import get_runner
21+
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
22+
23+
if TYPE_CHECKING:
24+
from ..cluster.cluster import Cluster
25+
26+
# Module-level registry of every Job created in this process, so callers can
# enumerate previously submitted jobs.
all_jobs: List["Job"] = []
# Single shared TorchX runner used for all dryrun/schedule/status/log calls.
torchx_runner = get_runner()
28+
29+
class JobDefinition(metaclass=abc.ABCMeta):
    """Abstract description of a job that can be dry-run against or submitted to a cluster."""

    def _dry_run(self, cluster: "Cluster"):
        """Produce scheduler dry-run information for *cluster*; subclasses override."""

    def submit(self, cluster: "Cluster"):
        """Submit this definition to *cluster*; subclasses override."""
35+
36+
37+
class Job(metaclass=abc.ABCMeta):
    """Abstract handle to a submitted job, exposing status and log access."""

    def status(self):
        """Return the job's current status; subclasses override."""

    def logs(self):
        """Return the job's logs; subclasses override."""
43+
44+
45+
class DDPJobDefinition(JobDefinition):
    """
    Definition of a TorchX DDP (distributed data-parallel) job.

    Any resource argument left as ``None`` falls back to the target cluster's
    configuration when the job is dry-run or submitted.
    """

    def __init__(
        self,
        script: Optional[str] = None,
        m: Optional[str] = None,
        script_args: Optional[List[str]] = None,
        name: Optional[str] = None,
        cpu: Optional[int] = None,
        gpu: Optional[int] = None,
        memMB: Optional[int] = None,
        h: Optional[str] = None,
        j: Optional[str] = None,
        env: Optional[Dict[str, str]] = None,
        max_retries: int = 0,
        mounts: Optional[List[str]] = None,
        rdzv_port: int = 29500,
        scheduler_args: Optional[Dict[str, str]] = None,
    ):
        """
        Store the DDP job parameters.

        Raises
        ------
        ValueError
            If not exactly one of ``script`` / ``m`` is provided.
        """
        # BUG FIX: the original used `!=` (logical XOR), which raised on the
        # *valid* exactly-one case and silently accepted both-or-neither.
        # Raise when the two truthinesses agree (both set or both unset).
        if bool(script) == bool(m):
            raise ValueError("Exactly one of the following arguments must be defined: [script, m].")
        self.script = script
        self.m = m
        self.script_args: List[str] = script_args if script_args is not None else []
        self.name = name
        self.cpu = cpu
        self.gpu = gpu
        self.memMB = memMB
        self.h = h
        self.j = j
        self.env: Dict[str, str] = env if env is not None else dict()
        self.max_retries = max_retries
        self.mounts: List[str] = mounts if mounts is not None else []
        self.rdzv_port = rdzv_port
        self.scheduler_args: Dict[str, str] = (
            scheduler_args if scheduler_args is not None else dict()
        )

    def _dry_run(self, cluster: "Cluster"):
        """
        Build the TorchX DDP app for *cluster* and return the scheduler
        dry-run info (consumed by ``torchx_runner.schedule`` in DDPJob).
        """
        # Default nproc-per-node: one process per GPU, at least one.
        j = f"{cluster.config.max_worker}x{max(cluster.config.gpu, 1)}"
        # BUG FIX: the original discarded the dryrun result, but
        # DDPJob.__init__ passes _dry_run(cluster) straight to schedule();
        # it must be returned.
        return torchx_runner.dryrun(
            app=ddp(
                *self.script_args,
                script=self.script,
                m=self.m,
                name=self.name,
                h=self.h,
                cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus,
                gpu=self.gpu if self.gpu is not None else cluster.config.gpu,
                memMB=self.memMB
                if self.memMB is not None
                else cluster.config.max_memory * 1024,
                j=self.j if self.j is not None else j,
                env=self.env,
                max_retries=self.max_retries,
                # BUG FIX: was misspelled `rdvz_port`, which ddp() rejects.
                rdzv_port=self.rdzv_port,
                mounts=self.mounts,
            ),  # BUG FIX: comma was missing here, a SyntaxError.
            scheduler=cluster.torchx_scheduler,
            # BUG FIX: bare `scheduler_args` was an undefined name here
            # (it is an __init__ parameter); use the stored attribute.
            cfg=cluster.torchx_config(**self.scheduler_args),
        )

    def submit(self, cluster: "Cluster") -> "Job":
        """Schedule this definition on *cluster* and return the live job handle."""
        return DDPJob(self, cluster)
102+
103+
def submit(self, cluster: "Cluster") -> "Job":
104+
return DDPJob(self, cluster)
105+
106+
107+
class DDPJob(Job):
    """Live handle to a DDP job scheduled on a cluster via the shared TorchX runner."""

    def __init__(
        self,
        job_definition: "DDPJobDefinition",
        cluster: "Cluster"
    ):
        # Keep the definition and target cluster for later inspection.
        self.job_definition = job_definition
        self.cluster = cluster
        # Schedule the job from the definition's dry-run info.
        # NOTE(review): this requires _dry_run to return the AppDryRunInfo —
        # the version of _dry_run in this commit returns None; verify.
        self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
        # Register in the module-level job registry.
        all_jobs.append(self)

    def status(self) -> str:
        # Delegates to the shared runner using the handle from schedule().
        return torchx_runner.status(self._app_handle)

    def logs(self) -> str:
        # log_lines(handle, None) yields lines for all replicas of the app;
        # join them into one string.
        return "".join(torchx_runner.log_lines(self._app_handle, None))

0 commit comments

Comments (0)