
Commit c064fec

add functions for creating ray with oauth proxy in front of the dashboard
Signed-off-by: Kevin <[email protected]>
1 parent: c2013ba

6 files changed: +441 -82 lines changed


src/codeflare_sdk/cluster/cluster.py

+69 -13

@@ -21,12 +21,20 @@
 from time import sleep
 from typing import List, Optional, Tuple, Dict

+import openshift as oc
+from kubernetes import config
 from ray.job_submission import JobSubmissionClient
+import urllib3

 from .auth import config_check, api_config_handler
 from ..utils import pretty_print
 from ..utils.generate_yaml import generate_appwrapper
 from ..utils.kube_api_helpers import _kube_api_error_handling
+from ..utils.openshift_oauth import (
+    create_openshift_oauth_objects,
+    delete_openshift_oauth_objects,
+    download_tls_cert,
+)
 from .config import ClusterConfiguration
 from .model import (
     AppWrapper,
@@ -41,6 +49,9 @@
 import requests


+k8_client = config.new_client_from_config()
+
+
 class Cluster:
     """
     An object for requesting, bringing up, and taking down resources.
@@ -61,6 +72,25 @@ def __init__(self, config: ClusterConfiguration):
         self.config = config
         self.app_wrapper_yaml = self.create_app_wrapper()
         self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
+        self._client = None
+
+    @property
+    def client(self):
+        if self._client:
+            return self._client
+        if self.config.openshift_oauth:
+            self._client = JobSubmissionClient(
+                self.cluster_dashboard_uri(),
+                headers={
+                    "Authorization": k8_client.configuration.auth_settings()[
+                        "BearerToken"
+                    ]["value"]
+                },
+                verify=False,
+            )
+        else:
+            self._client = JobSubmissionClient(self.cluster_dashboard_uri())
+        return self._client

     def evaluate_dispatch_priority(self):
         priority_class = self.config.dispatch_priority
@@ -141,6 +171,7 @@ def create_app_wrapper(self):
             image_pull_secrets=image_pull_secrets,
             dispatch_priority=dispatch_priority,
             priority_val=priority_val,
+            openshift_oauth=self.config.openshift_oauth,
         )

     # creates a new cluster with the provided or default spec
@@ -150,6 +181,11 @@ def up(self):
         the MCAD queue.
         """
         namespace = self.config.namespace
+        if self.config.openshift_oauth:
+            create_openshift_oauth_objects(
+                cluster_name=self.config.name, namespace=namespace
+            )
+
         try:
             config_check()
             api_instance = client.CustomObjectsApi(api_config_handler())
@@ -184,6 +220,11 @@ def down(self):
         except Exception as e:  # pragma: no cover
             return _kube_api_error_handling(e)

+        if self.config.openshift_oauth:
+            delete_openshift_oauth_objects(
+                cluster_name=self.config.name, namespace=namespace
+            )
+
     def status(
         self, print_to_console: bool = True
     ) -> Tuple[CodeFlareClusterStatus, bool]:
@@ -252,7 +293,16 @@ def status(
         return status, ready

     def is_dashboard_ready(self) -> bool:
-        response = requests.get(self.cluster_dashboard_uri(), timeout=5)
+        try:
+            response = requests.get(
+                self.cluster_dashboard_uri(),
+                headers=self.client._headers,
+                timeout=5,
+                verify=self.client._verify,
+            )
+        except requests.exceptions.SSLError:
+            # SSL exception occurs when oauth ingress has been created but cluster is not up
+            return False
         if response.status_code == 200:
             return True
         else:
@@ -311,7 +361,13 @@ def cluster_dashboard_uri(self) -> str:
             return _kube_api_error_handling(e)

         for route in routes["items"]:
-            if route["metadata"]["name"] == f"ray-dashboard-{self.config.name}":
+            if route["metadata"][
+                "name"
+            ] == f"ray-dashboard-{self.config.name}" or route["metadata"][
+                "name"
+            ].startswith(
+                f"{self.config.name}-ingress"
+            ):
                 protocol = "https" if route["spec"].get("tls") else "http"
                 return f"{protocol}://{route['spec']['host']}"
         return "Dashboard route not available yet, have you run cluster.up()?"
@@ -320,30 +376,24 @@ def list_jobs(self) -> List:
         """
         This method accesses the head ray node in your cluster and lists the running jobs.
         """
-        dashboard_route = self.cluster_dashboard_uri()
-        client = JobSubmissionClient(dashboard_route)
-        return client.list_jobs()
+        return self.client.list_jobs()

     def job_status(self, job_id: str) -> str:
         """
         This method accesses the head ray node in your cluster and returns the job status for the provided job id.
         """
-        dashboard_route = self.cluster_dashboard_uri()
-        client = JobSubmissionClient(dashboard_route)
-        return client.get_job_status(job_id)
+        return self.client.get_job_status(job_id)

     def job_logs(self, job_id: str) -> str:
         """
         This method accesses the head ray node in your cluster and returns the logs for the provided job id.
         """
-        dashboard_route = self.cluster_dashboard_uri()
-        client = JobSubmissionClient(dashboard_route)
-        return client.get_job_logs(job_id)
+        return self.client.get_job_logs(job_id)

     def torchx_config(
         self, working_dir: str = None, requirements: str = None
     ) -> Dict[str, str]:
-        dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}"
+        dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host
         to_return = {
             "cluster_name": self.config.name,
             "dashboard_address": dashboard_address,
@@ -587,7 +637,13 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
     )
     ray_route = None
     for route in routes["items"]:
-        if route["metadata"]["name"] == f"ray-dashboard-{rc['metadata']['name']}":
+        if route["metadata"][
+            "name"
+        ] == f"ray-dashboard-{rc['metadata']['name']}" or route["metadata"][
+            "name"
+        ].startswith(
+            f"{rc['metadata']['name']}-ingress"
+        ):
             protocol = "https" if route["spec"].get("tls") else "http"
             ray_route = f"{protocol}://{route['spec']['host']}"

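With this change the SDK reuses a single, lazily constructed JobSubmissionClient through the new Cluster.client property, attaching the bearer-token header and setting verify=False when openshift_oauth is enabled. A minimal sketch of driving the client-backed helpers from user code, assuming an already-created `cluster` object; the entrypoint string is illustrative only:

    # Sketch only: `cluster` is an existing Cluster; the entrypoint is illustrative.
    if cluster.is_dashboard_ready():
        print(cluster.cluster_dashboard_uri())  # ray-dashboard-<name> route or <name>-ingress host
        print(cluster.list_jobs())              # now routed through cluster.client
        # The underlying Ray JobSubmissionClient is also reachable directly:
        submission_id = cluster.client.submit_job(entrypoint="python train.py")
        print(cluster.job_status(submission_id))
        print(cluster.job_logs(submission_id))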
src/codeflare_sdk/cluster/config.py

+1 -0

@@ -48,3 +48,4 @@ class ClusterConfiguration:
     local_interactive: bool = False
     image_pull_secrets: list = field(default_factory=list)
     dispatch_priority: str = None
+    openshift_oauth: bool = False  # NOTE: to use the user must have permission to create ClusterRoleBindings

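The flag is opt-in and defaults to False; enabling it makes Cluster.up() create the OAuth proxy objects and Cluster.down() delete them again. A hedged sketch of turning it on, assuming the usual name/namespace fields of ClusterConfiguration and a kubeconfig whose user can create ClusterRoleBindings; the cluster name and namespace below are placeholders:

    from codeflare_sdk.cluster.cluster import Cluster
    from codeflare_sdk.cluster.config import ClusterConfiguration

    # Placeholder name/namespace; openshift_oauth=True requires ClusterRoleBinding permissions.
    cluster = Cluster(
        ClusterConfiguration(
            name="raytest",
            namespace="default",
            openshift_oauth=True,
        )
    )
    cluster.up()    # submits the AppWrapper and creates the OAuth proxy objects
    # ... work with the cluster ...
    cluster.down()  # removes the AppWrapper and deletes the OAuth proxy objects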
src/codeflare_sdk/job/jobs.py

+86 -66

@@ -18,15 +18,20 @@
 from pathlib import Path

 from torchx.components.dist import ddp
-from torchx.runner import get_runner
+from torchx.runner import get_runner, Runner
+from torchx.schedulers.ray_scheduler import RayScheduler
 from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo

+from ray.job_submission import JobSubmissionClient
+
+import openshift as oc
+
 if TYPE_CHECKING:
     from ..cluster.cluster import Cluster
     from ..cluster.cluster import get_current_namespace
+    from ..utils.openshift_oauth import download_tls_cert

 all_jobs: List["Job"] = []
-torchx_runner = get_runner()


 class JobDefinition(metaclass=abc.ABCMeta):
@@ -92,30 +97,37 @@ def __init__(

     def _dry_run(self, cluster: "Cluster"):
         j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}"  # # of proc. = # of gpus
-        return torchx_runner.dryrun(
-            app=ddp(
-                *self.script_args,
-                script=self.script,
-                m=self.m,
-                name=self.name,
-                h=self.h,
-                cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus,
-                gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus,
-                memMB=self.memMB
-                if self.memMB is not None
-                else cluster.config.max_memory * 1024,
-                j=self.j if self.j is not None else j,
-                env=self.env,
-                max_retries=self.max_retries,
-                rdzv_port=self.rdzv_port,
-                rdzv_backend=self.rdzv_backend
-                if self.rdzv_backend is not None
-                else "static",
-                mounts=self.mounts,
+        runner = get_runner(ray_client=cluster.client)
+        runner._scheduler_instances["ray"] = RayScheduler(
+            session_name=runner._name, ray_client=cluster.client
+        )
+        return (
+            runner.dryrun(
+                app=ddp(
+                    *self.script_args,
+                    script=self.script,
+                    m=self.m,
+                    name=self.name,
+                    h=self.h,
+                    cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus,
+                    gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus,
+                    memMB=self.memMB
+                    if self.memMB is not None
+                    else cluster.config.max_memory * 1024,
+                    j=self.j if self.j is not None else j,
+                    env=self.env,
+                    max_retries=self.max_retries,
+                    rdzv_port=self.rdzv_port,
+                    rdzv_backend=self.rdzv_backend
+                    if self.rdzv_backend is not None
+                    else "static",
+                    mounts=self.mounts,
+                ),
+                scheduler=cluster.torchx_scheduler,
+                cfg=cluster.torchx_config(**self.scheduler_args),
+                workspace=self.workspace,
             ),
-            scheduler=cluster.torchx_scheduler,
-            cfg=cluster.torchx_config(**self.scheduler_args),
-            workspace=self.workspace,
+            runner,
         )

     def _missing_spec(self, spec: str):
@@ -125,41 +137,47 @@ def _dry_run_no_cluster(self):
         if self.scheduler_args is not None:
             if self.scheduler_args.get("namespace") is None:
                 self.scheduler_args["namespace"] = get_current_namespace()
-        return torchx_runner.dryrun(
-            app=ddp(
-                *self.script_args,
-                script=self.script,
-                m=self.m,
-                name=self.name if self.name is not None else self._missing_spec("name"),
-                h=self.h,
-                cpu=self.cpu
-                if self.cpu is not None
-                else self._missing_spec("cpu (# cpus per worker)"),
-                gpu=self.gpu
-                if self.gpu is not None
-                else self._missing_spec("gpu (# gpus per worker)"),
-                memMB=self.memMB
-                if self.memMB is not None
-                else self._missing_spec("memMB (memory in MB)"),
-                j=self.j
-                if self.j is not None
-                else self._missing_spec(
-                    "j (`workers`x`procs`)"
-                ),  # # of proc. = # of gpus,
-                env=self.env,  # should this still exist?
-                max_retries=self.max_retries,
-                rdzv_port=self.rdzv_port,  # should this still exist?
-                rdzv_backend=self.rdzv_backend
-                if self.rdzv_backend is not None
-                else "c10d",
-                mounts=self.mounts,
-                image=self.image
-                if self.image is not None
-                else self._missing_spec("image"),
+        runner = get_runner()
+        return (
+            runner.dryrun(
+                app=ddp(
+                    *self.script_args,
+                    script=self.script,
+                    m=self.m,
+                    name=self.name
+                    if self.name is not None
+                    else self._missing_spec("name"),
+                    h=self.h,
+                    cpu=self.cpu
+                    if self.cpu is not None
+                    else self._missing_spec("cpu (# cpus per worker)"),
+                    gpu=self.gpu
+                    if self.gpu is not None
+                    else self._missing_spec("gpu (# gpus per worker)"),
+                    memMB=self.memMB
+                    if self.memMB is not None
+                    else self._missing_spec("memMB (memory in MB)"),
+                    j=self.j
+                    if self.j is not None
+                    else self._missing_spec(
+                        "j (`workers`x`procs`)"
+                    ),  # # of proc. = # of gpus,
+                    env=self.env,  # should this still exist?
+                    max_retries=self.max_retries,
+                    rdzv_port=self.rdzv_port,  # should this still exist?
+                    rdzv_backend=self.rdzv_backend
+                    if self.rdzv_backend is not None
+                    else "c10d",
+                    mounts=self.mounts,
+                    image=self.image
+                    if self.image is not None
+                    else self._missing_spec("image"),
+                ),
+                scheduler="kubernetes_mcad",
+                cfg=self.scheduler_args,
+                workspace="",
             ),
-            scheduler="kubernetes_mcad",
-            cfg=self.scheduler_args,
-            workspace="",
+            runner,
         )

     def submit(self, cluster: "Cluster" = None) -> "Job":
@@ -171,18 +189,20 @@ def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster" = None
         self.job_definition = job_definition
         self.cluster = cluster
         if self.cluster:
-            self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
+            definition, runner = job_definition._dry_run(cluster)
+            self._app_handle = runner.schedule(definition)
+            self._runner = runner
         else:
-            self._app_handle = torchx_runner.schedule(
-                job_definition._dry_run_no_cluster()
-            )
+            definition, runner = job_definition._dry_run_no_cluster()
+            self._app_handle = runner.schedule(definition)
+            self._runner = runner
         all_jobs.append(self)

     def status(self) -> str:
-        return torchx_runner.status(self._app_handle)
+        return self._runner.status(self._app_handle)

     def logs(self) -> str:
-        return "".join(torchx_runner.log_lines(self._app_handle, None))
+        return "".join(self._runner.log_lines(self._app_handle, None))

     def cancel(self):
-        torchx_runner.cancel(self._app_handle)
+        self._runner.cancel(self._app_handle)

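Since the torchx runner is now created per dry run, with the Ray scheduler bound to cluster.client, each DDPJob keeps its own _runner and routes status/logs/cancel through it. A sketch of submitting a job against such a cluster, assuming DDPJobDefinition accepts these constructor keywords; the name, script, and scheduler_args values are illustrative:

    from codeflare_sdk.job.jobs import DDPJobDefinition

    # Illustrative arguments; unset resources fall back to the cluster's config in _dry_run.
    job_def = DDPJobDefinition(
        name="mnist",
        script="train.py",
        scheduler_args={"requirements": "requirements.txt"},
    )
    job = job_def.submit(cluster)  # schedules via the per-job runner returned by _dry_run
    print(job.status())            # self._runner.status(self._app_handle)
    print(job.logs())              # joined self._runner.log_lines(...)
    job.cancel()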