Skip to content

Commit 99afd96

Browse files
committed
Expose more job options to the user as optional settings
Signed-off-by: Kevin <[email protected]>
1 parent c132606 commit 99afd96

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

src/codeflare_sdk/jobs/config.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,20 @@
1919
"""
2020

2121
from dataclasses import dataclass, field
22-
from ..cluster.cluster import Cluster
22+
from typing import Optional, Dict
2323

24-
25-
@dataclass
24+
@dataclass(slots=True)
2625
class JobConfiguration:
2726
"""
2827
This dataclass is used to specify resource requirements and other details, and
2928
is passed in as an argument when creating a Job object.
3029
"""
3130

32-
name: str = None
33-
script: str = None
34-
requirements: str = None
35-
scheduler: str = "ray"
31+
name: Optional[str] = None
32+
script: Optional[str] = None
33+
m: Optional[str] = None
34+
h: Optional[str] = None # custom resource types
35+
env: Optional[Dict[str, str]] = None
36+
working_dir: Optional[str] = None
37+
requirements: Optional[str] = None
38+
debug: bool = False

src/codeflare_sdk/jobs/jobs.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,12 @@
3535
all_jobs: List["Job"] = []
3636
torchx_runner: Runner = get_runner()
3737

38-
torchx_runner.run(app, scheduler, cfg, workspace)
39-
40-
torchx_runner.run_component(component, component_args, scheduler, cfg, workspace)
41-
4238
class JobDefinition(metaclass=abc.ABCMeta):
4339
"""
4440
A job definition to be submitted to a generic backend cluster.
4541
"""
4642

47-
def submit(self, cluster: Cluster):
43+
def submit(self, cluster: "Cluster"):
4844
"""
4945
Method for creating a job on a specific cluster
5046
"""
@@ -76,7 +72,7 @@ def __init__(self, config: JobConfiguration):
7672
"""
7773
self.config = config
7874

79-
def submit(self, cluster: Cluster) -> "TorchXRayJob":
75+
def submit(self, cluster: "Cluster") -> "TorchXRayJob":
8076
"""
8177
Submit the job definition to a specific cluster, resulting in a Job object.
8278
"""
@@ -87,48 +83,53 @@ class TorchXRayJob(Job):
8783
"""
8884
Active submission of a dist.ddp job to a Ray cluster which can be used to get logs and status.
8985
"""
90-
def __init__(self, job_definition: TorchXJobDefinition, cluster: Cluster, *script_args):
86+
def __init__(self, job_definition: TorchXJobDefinition, cluster: "Cluster", *script_args):
9187
"""
92-
TODO
88+
Creates job which maximizes resource usage on the passed cluster.
9389
"""
9490
self.job_definition: TorchXJobDefinition = job_definition
95-
self.cluster: Cluster = cluster
91+
self.cluster: "Cluster" = cluster
9692
j = f"{cluster.config.max_worker}x{max(cluster.config.gpu, 1)}" # # of proc. = # of gpus
97-
# TODO: allow user to override resource allocation for job
9893
_app_handle: AppHandle = torchx_runner.run(
9994
app=ddp(
10095
*script_args,
10196
script = job_definition.config.script,
102-
m=None, # python module to run (might be worth exposing)
103-
name = job_definition.config.name,
97+
m=job_definition.config.m,
98+
name=job_definition.config.name,
10499
h=None, # for custom resource types
105-
cpu=cluster.config.max_cpus, # TODO: get from cluster config
100+
cpu=cluster.config.max_cpus,
106101
gpu = cluster.config.gpu,
107102
memMB = 1024 * cluster.config.max_memory, # cluster memory is in GB
108103
j=j,
109-
env=None, # TODO: should definitely expose Dict[str, str]
110-
max_retries = 0, # TODO: maybe expose
111-
mounts=None, # TODO: maybe expose
112-
debug=False # TODO: expose
104+
env=job_definition.config.env,
105+
debug=job_definition.config.debug # TODO: expose
106+
# max_retries=0, # default
107+
# mounts=None, # default
113108
),
114-
scheduler="ray", cfg="fo",
109+
scheduler="ray", # can be determined by type of cluster if more are introduced
110+
cfg={
111+
"cluster_name": self.cluster.config.name,
112+
"dashboard_address": self.cluster.cluster_dashboard_uri(self.cluster.config.namespace).split(":")[0],
113+
"working_dir": job_definition.config.working_dir,
114+
"requirements": job_definition.config.requirements,
115+
},
115116
)
116117

117118
_, _, self.job_id = parse_app_handle(_app_handle)
118119
all_jobs.append(self)
119120

120-
def status(self):
121+
def status(self) -> str:
121122
"""
122-
TODO
123+
Get running job status.
123124
"""
124-
dashboard_route = self.cluster.cluster_dashboard_uri()
125+
dashboard_route = self.cluster.cluster_dashboard_uri(namespace=self.cluster.config.namespace)
125126
client = JobSubmissionClient(dashboard_route)
126127
return client.get_job_status(self.job_id)
127128

128-
def logs(self):
129+
def logs(self) -> str:
129130
"""
130-
TODO
131+
Get job logs.
131132
"""
132-
dashboard_route = self.cluster_dashboard_uri(namespace=self.config.namespace)
133+
dashboard_route = self.cluster_dashboard_uri(namespace=self.cluster.config.namespace)
133134
client = JobSubmissionClient(dashboard_route)
134-
return client.get_job_logs(job_id)
135+
return client.get_job_logs(self.job_id)

0 commit comments

Comments
 (0)