Skip to content

Commit bec8de3

Browse files
committed
expose more options to the user optionally
Signed-off-by: Kevin <[email protected]>
1 parent c132606 commit bec8de3

File tree

2 files changed

+65
-48
lines changed

src/codeflare_sdk/jobs/config.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
"""
2020

2121
from dataclasses import dataclass, field
22-
from ..cluster.cluster import Cluster
23-
22+
from typing import Optional, Dict
2423

2524
@dataclass
2625
class JobConfiguration:
@@ -29,7 +28,10 @@ class JobConfiguration:
2928
is passed in as an argument when creating a Job object.
3029
"""
3130

32-
name: str = None
33-
script: str = None
34-
requirements: str = None
35-
scheduler: str = "ray"
31+
name: Optional[str] = None
32+
script: Optional[str] = None
33+
m: Optional[str] = None
34+
h: Optional[str] = None # custom resource types
35+
env: Optional[Dict[str, str]] = None
36+
working_dir: Optional[str] = None
37+
requirements: Optional[str] = None

src/codeflare_sdk/jobs/jobs.py

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020

2121
import abc
2222
from typing import List
23+
from pathlib import Path
2324

2425
from ray.job_submission import JobSubmissionClient
2526
from torchx.components.dist import ddp
2627
from torchx.runner import get_runner, Runner
27-
from torchx.specs import AppHandle, parse_app_handle
28+
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
2829

2930
from .config import JobConfiguration
3031

@@ -35,16 +36,19 @@
3536
all_jobs: List["Job"] = []
3637
torchx_runner: Runner = get_runner()
3738

38-
torchx_runner.run(app, scheduler, cfg, workspace)
39-
40-
torchx_runner.run_component(component, component_args, scheduler, cfg, workspace)
41-
4239
class JobDefinition(metaclass=abc.ABCMeta):
4340
"""
4441
A job definition to be submitted to a generic backend cluster.
4542
"""
4643

47-
def submit(self, cluster: Cluster):
44+
def _dry_run(self, cluster) -> str:
45+
"""
46+
Create job definition, but do not submit.
47+
48+
The primary purpose of this function is to facilitate unit testing.
49+
"""
50+
51+
def submit(self, cluster: "Cluster"):
4852
"""
4953
Method for creating a job on a specific cluster
5054
"""
@@ -76,59 +80,70 @@ def __init__(self, config: JobConfiguration):
7680
"""
7781
self.config = config
7882

79-
def submit(self, cluster: Cluster) -> "TorchXRayJob":
80-
"""
81-
Submit the job definition to a specific cluster, resulting in a Job object.
83+
def _dry_run(self, cluster: "Cluster", *script_args) -> AppDryRunInfo:
8284
"""
83-
return TorchXRayJob(self, cluster)
84-
85+
Create job definition, but do not submit.
8586
86-
class TorchXRayJob(Job):
87-
"""
88-
Active submission of a dist.ddp job to a Ray cluster which can be used to get logs and status.
89-
"""
90-
def __init__(self, job_definition: TorchXJobDefinition, cluster: Cluster, *script_args):
87+
The primary purpose of this function is to facilitate unit testing.
9188
"""
92-
TODO
93-
"""
94-
self.job_definition: TorchXJobDefinition = job_definition
95-
self.cluster: Cluster = cluster
9689
j = f"{cluster.config.max_worker}x{max(cluster.config.gpu, 1)}" # # of proc. = # of gpus
97-
# TODO: allow user to override resource allocation for job
98-
_app_handle: AppHandle = torchx_runner.run(
90+
dashboard_address = f"{cluster.cluster_dashboard_uri(cluster.config.namespace).lstrip('http://')}:8265"
91+
return torchx_runner.dryrun(
9992
app=ddp(
10093
*script_args,
101-
script = job_definition.config.script,
102-
m=None, # python module to run (might be worth exposing)
103-
name = job_definition.config.name,
94+
script = self.config.script,
95+
m=self.config.m,
96+
name=self.config.name,
10497
h=None, # for custom resource types
105-
cpu=cluster.config.max_cpus, # TODO: get from cluster config
98+
cpu=cluster.config.max_cpus,
10699
gpu = cluster.config.gpu,
107100
memMB = 1024 * cluster.config.max_memory, # cluster memory is in GB
108101
j=j,
109-
env=None, # TODO: should definitely expose Dict[str, str]
110-
max_retries = 0, # TODO: maybe expose
111-
mounts=None, # TODO: maybe expose
112-
debug=False # TODO: expose
102+
env=self.config.env,
103+
# max_retries=0, # default
104+
# mounts=None, # default
113105
),
114-
scheduler="ray", cfg="fo",
106+
scheduler="ray", # can be determined by type of cluster if more are introduced
107+
cfg={
108+
"cluster_name": cluster.config.name,
109+
"dashboard_address": "localhost:8265", # dashboard_address,
110+
"working_dir": self.config.working_dir,
111+
"requirements": self.config.requirements,
112+
},
113+
workspace=f"file://{Path.cwd()}"
115114
)
116115

117-
_, _, self.job_id = parse_app_handle(_app_handle)
116+
def submit(self, cluster: "Cluster") -> "TorchXRayJob":
117+
"""
118+
Submit the job definition to a specific cluster, resulting in a Job object.
119+
"""
120+
return TorchXRayJob(self, cluster)
121+
122+
123+
class TorchXRayJob(Job):
124+
"""
125+
Active submission of a dist.ddp job to a Ray cluster which can be used to get logs and status.
126+
"""
127+
def __init__(self, job_definition: TorchXJobDefinition, cluster: "Cluster", *script_args):
128+
"""
129+
Creates job which maximizes resource usage on the passed cluster.
130+
"""
131+
self.job_definition: TorchXJobDefinition = job_definition
132+
self.cluster: "Cluster" = cluster
133+
# dashboard_address = f"{self.cluster.cluster_dashboard_uri(self.cluster.config.namespace).lstrip('http://')}:8265"
134+
self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster, *script_args))
135+
_, _, self.job_id = parse_app_handle(self._app_handle)
136+
# self.job_id = self.job_id.lstrip(f"{dashboard_address}-")
118137
all_jobs.append(self)
119138

120-
def status(self):
139+
def status(self) -> str:
121140
"""
122-
TODO
141+
Get running job status.
123142
"""
124-
dashboard_route = self.cluster.cluster_dashboard_uri()
125-
client = JobSubmissionClient(dashboard_route)
126-
return client.get_job_status(self.job_id)
143+
return torchx_runner.status(self._app_handle)
127144

128-
def logs(self):
145+
def logs(self) -> str:
129146
"""
130-
TODO
147+
Get job logs.
131148
"""
132-
dashboard_route = self.cluster_dashboard_uri(namespace=self.config.namespace)
133-
client = JobSubmissionClient(dashboard_route)
134-
return client.get_job_logs(job_id)
149+
return "".join(torchx_runner.log_lines(self._app_handle, None))

0 commit comments

Comments (0)