Skip to content

Commit 28caf6c

Browse files
committed
create classes for submitting and watching DDP jobs
Signed-off-by: Kevin <[email protected]>
1 parent 433fd71 commit 28caf6c

File tree

3 files changed

+139
-1
lines changed

3 files changed

+139
-1
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from os import stat
2222
from time import sleep
23-
from typing import List, Optional, Tuple
23+
from typing import List, Optional, Tuple, Dict
2424

2525
import openshift as oc
2626
from ray.job_submission import JobSubmissionClient
@@ -45,6 +45,8 @@ class Cluster:
4545
Note that currently, the underlying implementation is a Ray cluster.
4646
"""
4747

48+
torchx_scheduler = "ray"
49+
4850
def __init__(self, config: ClusterConfiguration):
4951
"""
5052
Create the resource cluster object by passing in a ClusterConfiguration
@@ -272,6 +274,18 @@ def job_logs(self, job_id: str) -> str:
272274
client = JobSubmissionClient(dashboard_route)
273275
return client.get_job_logs(job_id)
274276

277+
def torchx_config(
    self, working_dir: str = None, requirements: str = None
) -> Dict[str, str]:
    """
    Build the TorchX scheduler configuration dict for this cluster.

    Args:
        working_dir: optional path to ship as the job's working directory.
        requirements: optional path to a pip requirements file for the job.

    Returns:
        A dict with ``cluster_name`` and ``dashboard_address`` (scheme
        stripped), plus ``working_dir``/``requirements`` when provided.
    """
    uri = self.cluster_dashboard_uri()
    # BUG FIX: the previous `uri.lstrip("http://")` treated its argument as a
    # *set* of characters (h, t, p, :, /) and kept stripping while the next
    # character was in that set — a host such as "proxy-host" would become
    # "roxy-host". Strip the scheme as a whole prefix instead.
    for scheme in ("http://", "https://"):
        if uri.startswith(scheme):
            uri = uri[len(scheme):]
            break
    to_return = {
        "cluster_name": self.config.name,
        "dashboard_address": uri,
    }
    if working_dir:
        to_return["working_dir"] = working_dir
    if requirements:
        to_return["requirements"] = requirements
    return to_return
288+
275289

276290
def get_current_namespace() -> str:
277291
"""

src/codeflare_sdk/job/__init__.py

Whitespace-only changes.

src/codeflare_sdk/job/jobs.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2023 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import abc
17+
from typing import TYPE_CHECKING, Optional, Dict, List
18+
from pathlib import Path
19+
20+
from torchx.components.dist import ddp
21+
from torchx.runner import get_runner
22+
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
23+
24+
if TYPE_CHECKING:
25+
from ..cluster.cluster import Cluster
26+
27+
# Module-level registry of every Job created during this session.
all_jobs: List["Job"] = []
# Single shared TorchX runner used to dry-run, schedule, and query jobs.
torchx_runner = get_runner()
29+
30+
class JobDefinition(metaclass=abc.ABCMeta):
    """Base interface for a job description that can be dry-run or submitted."""

    def _dry_run(self, cluster: "Cluster"):
        """Produce the scheduler request for *cluster*; base implementation does nothing."""

    def submit(self, cluster: "Cluster"):
        """Submit this definition to *cluster*; base implementation does nothing."""
36+
37+
38+
class Job(metaclass=abc.ABCMeta):
    """Base interface for a submitted job whose status and logs can be queried."""

    def status(self):
        """Report the job's current status; base implementation does nothing."""

    def logs(self):
        """Return the job's log output; base implementation does nothing."""
44+
45+
46+
class DDPJobDefinition(JobDefinition):
    """
    Definition of a distributed data-parallel (DDP) training job, built on the
    TorchX ``dist.ddp`` component.

    Exactly one of ``script`` (path to a training script) or ``m`` (a module
    name) must be given. Unset resource arguments fall back to the target
    cluster's configuration at dry-run time.
    """

    def __init__(
        self,
        script: Optional[str] = None,
        m: Optional[str] = None,
        script_args: Optional[List[str]] = None,
        name: Optional[str] = None,
        cpu: Optional[int] = None,
        gpu: Optional[int] = None,
        memMB: Optional[int] = None,
        h: Optional[str] = None,
        j: Optional[str] = None,
        env: Optional[Dict[str, str]] = None,
        max_retries: int = 0,
        mounts: Optional[List[str]] = None,
        rdzv_port: int = 29500,
        scheduler_args: Optional[Dict[str, str]] = None,
    ):
        """Store the job parameters, defaulting unset collections to empty ones."""
        # Exactly one entry point is required: a script path XOR a module name.
        if bool(script) == bool(m):
            raise ValueError(
                "Exactly one of the following arguments must be defined: [script, m]."
            )
        self.script = script
        self.m = m
        self.script_args: List[str] = [] if script_args is None else script_args
        self.name = name
        self.cpu = cpu
        self.gpu = gpu
        self.memMB = memMB
        self.h = h
        self.j = j
        self.env: Dict[str, str] = {} if env is None else env
        self.max_retries = max_retries
        self.mounts: List[str] = [] if mounts is None else mounts
        self.rdzv_port = rdzv_port
        self.scheduler_args: Dict[str, str] = (
            {} if scheduler_args is None else scheduler_args
        )

    def _dry_run(self, cluster: "Cluster"):
        """Build the TorchX dry-run info for this job against *cluster*."""
        # Default process layout: one process per GPU (at least one) on each
        # of the cluster's workers.
        default_j = f"{cluster.config.max_worker}x{max(cluster.config.gpu, 1)}"
        app = ddp(
            *self.script_args,
            script=self.script,
            m=self.m,
            name=self.name,
            h=self.h,
            cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus,
            gpu=self.gpu if self.gpu is not None else cluster.config.gpu,
            memMB=self.memMB
            if self.memMB is not None
            else cluster.config.max_memory * 1024,
            j=self.j if self.j is not None else default_j,
            env=self.env,
            max_retries=self.max_retries,
            rdzv_port=self.rdzv_port,
            mounts=self.mounts,
        )
        return torchx_runner.dryrun(
            app=app,
            scheduler=cluster.torchx_scheduler,
            cfg=cluster.torchx_config(**self.scheduler_args),
            workspace=f"file://{Path.cwd()}",
        )

    def submit(self, cluster: "Cluster") -> "Job":
        """Schedule this definition on *cluster* and return the resulting job handle."""
        return DDPJob(self, cluster)
107+
108+
109+
class DDPJob(Job):
    """Handle to a DDP job that was scheduled on a cluster through TorchX."""

    def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster"):
        """Schedule *job_definition* on *cluster* immediately and register the job."""
        self.job_definition = job_definition
        self.cluster = cluster
        # Scheduling happens eagerly at construction time via the shared runner.
        self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
        all_jobs.append(self)

    def status(self) -> str:
        """Return the TorchX status for this job's app handle."""
        return torchx_runner.status(self._app_handle)

    def logs(self) -> str:
        """Return the job's full log output joined into a single string."""
        return "".join(torchx_runner.log_lines(self._app_handle, None))

0 commit comments

Comments (0)