 from pathlib import Path

 from torchx.components.dist import ddp
-from torchx.runner import get_runner
+from torchx.runner import get_runner, Runner
+from torchx.schedulers.ray_scheduler import RayScheduler
 from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo

+from ray.job_submission import JobSubmissionClient
+
+import openshift as oc
+
 if TYPE_CHECKING:
     from ..cluster.cluster import Cluster
 from ..cluster.cluster import get_current_namespace
+from ..utils.openshift_oauth import download_tls_cert

 all_jobs: List["Job"] = []
-torchx_runner = get_runner()


 class JobDefinition(metaclass=abc.ABCMeta):
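The import changes replace the process-global `torchx_runner` with per-job runners that can be bound to a specific Ray cluster: `Runner` and `RayScheduler` let a TorchX runner submit through a Ray `JobSubmissionClient`, while the `openshift` and `download_tls_cert` imports presumably support authenticated access to that cluster. A minimal sketch of building such a Ray-bound runner by hand, assuming a hypothetical dashboard URL (in the SDK the client is expected to come from the `Cluster` object as `cluster.client`):

from ray.job_submission import JobSubmissionClient
from torchx.runner import get_runner
from torchx.schedulers.ray_scheduler import RayScheduler

# Hypothetical dashboard address; normally supplied by the Cluster object.
client = JobSubmissionClient("http://ray-dashboard.example.com:8265")

# Per-job runner whose "ray" scheduler submits through this specific client,
# mirroring what _dry_run does in the hunk below.
runner = get_runner(ray_client=client)
runner._scheduler_instances["ray"] = RayScheduler(
    session_name=runner._name, ray_client=client
)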
@@ -92,30 +97,37 @@ def __init__(

     def _dry_run(self, cluster: "Cluster"):
         j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}"  # # of proc. = # of gpus
-        return torchx_runner.dryrun(
-            app=ddp(
-                *self.script_args,
-                script=self.script,
-                m=self.m,
-                name=self.name,
-                h=self.h,
-                cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus,
-                gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus,
-                memMB=self.memMB
-                if self.memMB is not None
-                else cluster.config.max_memory * 1024,
-                j=self.j if self.j is not None else j,
-                env=self.env,
-                max_retries=self.max_retries,
-                rdzv_port=self.rdzv_port,
-                rdzv_backend=self.rdzv_backend
-                if self.rdzv_backend is not None
-                else "static",
-                mounts=self.mounts,
+        runner = get_runner(ray_client=cluster.client)
+        runner._scheduler_instances["ray"] = RayScheduler(
+            session_name=runner._name, ray_client=cluster.client
+        )
+        return (
+            runner.dryrun(
+                app=ddp(
+                    *self.script_args,
+                    script=self.script,
+                    m=self.m,
+                    name=self.name,
+                    h=self.h,
+                    cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus,
+                    gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus,
+                    memMB=self.memMB
+                    if self.memMB is not None
+                    else cluster.config.max_memory * 1024,
+                    j=self.j if self.j is not None else j,
+                    env=self.env,
+                    max_retries=self.max_retries,
+                    rdzv_port=self.rdzv_port,
+                    rdzv_backend=self.rdzv_backend
+                    if self.rdzv_backend is not None
+                    else "static",
+                    mounts=self.mounts,
+                ),
+                scheduler=cluster.torchx_scheduler,
+                cfg=cluster.torchx_config(**self.scheduler_args),
+                workspace=self.workspace,
             ),
-            scheduler=cluster.torchx_scheduler,
-            cfg=cluster.torchx_config(**self.scheduler_args),
-            workspace=self.workspace,
+            runner,
         )

     def _missing_spec(self, spec: str):
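`_dry_run` now returns the `AppDryRunInfo` together with the runner that produced it, so the same Ray-bound runner is used to schedule the app and to query it afterwards. A short sketch of how the pair is consumed (this mirrors `DDPJob.__init__` further down; `job_def` and `cluster` are hypothetical instances):

# job_def: a DDPJobDefinition, cluster: a Cluster -- both hypothetical here.
dryrun_info, runner = job_def._dry_run(cluster)  # (AppDryRunInfo, Runner)
app_handle = runner.schedule(dryrun_info)        # submit to the Ray cluster
print(runner.status(app_handle))                 # query through the same runner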
@@ -125,41 +137,47 @@ def _dry_run_no_cluster(self):
         if self.scheduler_args is not None:
             if self.scheduler_args.get("namespace") is None:
                 self.scheduler_args["namespace"] = get_current_namespace()
-        return torchx_runner.dryrun(
-            app=ddp(
-                *self.script_args,
-                script=self.script,
-                m=self.m,
-                name=self.name if self.name is not None else self._missing_spec("name"),
-                h=self.h,
-                cpu=self.cpu
-                if self.cpu is not None
-                else self._missing_spec("cpu (# cpus per worker)"),
-                gpu=self.gpu
-                if self.gpu is not None
-                else self._missing_spec("gpu (# gpus per worker)"),
-                memMB=self.memMB
-                if self.memMB is not None
-                else self._missing_spec("memMB (memory in MB)"),
-                j=self.j
-                if self.j is not None
-                else self._missing_spec(
-                    "j (`workers`x`procs`)"
-                ),  # # of proc. = # of gpus,
-                env=self.env,  # should this still exist?
-                max_retries=self.max_retries,
-                rdzv_port=self.rdzv_port,  # should this still exist?
-                rdzv_backend=self.rdzv_backend
-                if self.rdzv_backend is not None
-                else "c10d",
-                mounts=self.mounts,
-                image=self.image
-                if self.image is not None
-                else self._missing_spec("image"),
+        runner = get_runner()
+        return (
+            runner.dryrun(
+                app=ddp(
+                    *self.script_args,
+                    script=self.script,
+                    m=self.m,
+                    name=self.name
+                    if self.name is not None
+                    else self._missing_spec("name"),
+                    h=self.h,
+                    cpu=self.cpu
+                    if self.cpu is not None
+                    else self._missing_spec("cpu (# cpus per worker)"),
+                    gpu=self.gpu
+                    if self.gpu is not None
+                    else self._missing_spec("gpu (# gpus per worker)"),
+                    memMB=self.memMB
+                    if self.memMB is not None
+                    else self._missing_spec("memMB (memory in MB)"),
+                    j=self.j
+                    if self.j is not None
+                    else self._missing_spec(
+                        "j (`workers`x`procs`)"
+                    ),  # # of proc. = # of gpus,
+                    env=self.env,  # should this still exist?
+                    max_retries=self.max_retries,
+                    rdzv_port=self.rdzv_port,  # should this still exist?
+                    rdzv_backend=self.rdzv_backend
+                    if self.rdzv_backend is not None
+                    else "c10d",
+                    mounts=self.mounts,
+                    image=self.image
+                    if self.image is not None
+                    else self._missing_spec("image"),
+                ),
+                scheduler="kubernetes_mcad",
+                cfg=self.scheduler_args,
+                workspace="",
             ),
-            scheduler="kubernetes_mcad",
-            cfg=self.scheduler_args,
-            workspace="",
+            runner,
         )

     def submit(self, cluster: "Cluster" = None) -> "Job":
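Without a cluster, `_dry_run_no_cluster` likewise builds its own runner but targets the `kubernetes_mcad` scheduler with `self.scheduler_args`, and the resource fields and image must be supplied explicitly, otherwise `_missing_spec` is invoked for the missing value. A hedged example of a definition that satisfies this path, assuming these fields are accepted as keyword arguments and using made-up values:

# All values hypothetical; each omitted field below would trigger _missing_spec.
job_def = DDPJobDefinition(
    name="mnist-ddp",
    script="train.py",
    image="quay.io/example/train:latest",
    cpu=4,
    gpu=1,
    memMB=8 * 1024,
    j="2x1",  # workers x procs per worker (# of proc. = # of gpus)
)
job = job_def.submit()  # no cluster -> _dry_run_no_cluster() -> kubernetes_mcad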
@@ -171,18 +189,20 @@ def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster" = None
         self.job_definition = job_definition
         self.cluster = cluster
         if self.cluster:
-            self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
+            definition, runner = job_definition._dry_run(cluster)
+            self._app_handle = runner.schedule(definition)
+            self._runner = runner
         else:
-            self._app_handle = torchx_runner.schedule(
-                job_definition._dry_run_no_cluster()
-            )
+            definition, runner = job_definition._dry_run_no_cluster()
+            self._app_handle = runner.schedule(definition)
+            self._runner = runner
         all_jobs.append(self)

     def status(self) -> str:
-        return torchx_runner.status(self._app_handle)
+        return self._runner.status(self._app_handle)

     def logs(self) -> str:
-        return "".join(torchx_runner.log_lines(self._app_handle, None))
+        return "".join(self._runner.log_lines(self._app_handle, None))

     def cancel(self):
-        torchx_runner.cancel(self._app_handle)
+        self._runner.cancel(self._app_handle)
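Because every job now keeps the runner it was scheduled with in `self._runner`, status, logs, and cancellation no longer depend on the removed module-level `torchx_runner`. A sketch of the resulting lifecycle against an existing cluster (names hypothetical):

# cluster: an existing, started Cluster object; job_def as defined above.
job = job_def.submit(cluster)  # Job scheduled through the cluster-bound runner

print(job.status())            # self._runner.status(app_handle)
print(job.logs())              # "".join(self._runner.log_lines(app_handle, None))
job.cancel()                   # self._runner.cancel(app_handle)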