Skip to content

Commit cf1fe4a

Browse files
committed
add functions for creating ray with oauth proxy in front of the dashboard
Signed-off-by: Kevin <[email protected]>
1 parent 06a3a59 commit cf1fe4a

File tree

7 files changed

+540
-126
lines changed

7 files changed

+540
-126
lines changed

src/codeflare_sdk/cluster/auth.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import urllib3
2626
from ..utils.kube_api_helpers import _kube_api_error_handling
2727

28+
from typing import Optional
29+
2830
global api_client
2931
api_client = None
3032
global config_path
@@ -188,7 +190,7 @@ def config_check() -> str:
188190
return config_path
189191

190192

191-
def api_config_handler() -> str:
193+
def api_config_handler() -> Optional[client.ApiClient]:
192194
"""
193195
This function is used to load the api client if the user has logged in
194196
"""

src/codeflare_sdk/cluster/cluster.py

+83-14
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,20 @@
2121
from time import sleep
2222
from typing import List, Optional, Tuple, Dict
2323

24+
import openshift as oc
25+
from kubernetes import config
2426
from ray.job_submission import JobSubmissionClient
27+
import urllib3
2528

2629
from .auth import config_check, api_config_handler
2730
from ..utils import pretty_print
2831
from ..utils.generate_yaml import generate_appwrapper
2932
from ..utils.kube_api_helpers import _kube_api_error_handling
33+
from ..utils.openshift_oauth import (
34+
create_openshift_oauth_objects,
35+
delete_openshift_oauth_objects,
36+
download_tls_cert,
37+
)
3038
from .config import ClusterConfiguration
3139
from .model import (
3240
AppWrapper,
@@ -40,6 +48,8 @@
4048
import os
4149
import requests
4250

51+
from kubernetes import config
52+
4353

4454
class Cluster:
4555
"""
@@ -61,6 +71,39 @@ def __init__(self, config: ClusterConfiguration):
6171
self.config = config
6272
self.app_wrapper_yaml = self.create_app_wrapper()
6373
self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
74+
self._client = None
75+
76+
@property
77+
def _client_headers(self):
78+
k8_client = api_config_handler() or client.ApiClient()
79+
return {
80+
"Authorization": k8_client.configuration.get_api_key_with_prefix(
81+
"authorization"
82+
)
83+
}
84+
85+
@property
86+
def _client_verify_tls(self):
87+
return not self.config.openshift_oauth
88+
89+
@property
90+
def client(self):
91+
if self._client:
92+
return self._client
93+
if self.config.openshift_oauth:
94+
print(
95+
api_config_handler().configuration.get_api_key_with_prefix(
96+
"authorization"
97+
)
98+
)
99+
self._client = JobSubmissionClient(
100+
self.cluster_dashboard_uri(),
101+
headers=self._client_headers,
102+
verify=self._client_verify_tls,
103+
)
104+
else:
105+
self._client = JobSubmissionClient(self.cluster_dashboard_uri())
106+
return self._client
64107

65108
def evaluate_dispatch_priority(self):
66109
priority_class = self.config.dispatch_priority
@@ -147,6 +190,7 @@ def create_app_wrapper(self):
147190
image_pull_secrets=image_pull_secrets,
148191
dispatch_priority=dispatch_priority,
149192
priority_val=priority_val,
193+
openshift_oauth=self.config.openshift_oauth,
150194
)
151195

152196
# creates a new cluster with the provided or default spec
@@ -156,6 +200,11 @@ def up(self):
156200
the MCAD queue.
157201
"""
158202
namespace = self.config.namespace
203+
if self.config.openshift_oauth:
204+
create_openshift_oauth_objects(
205+
cluster_name=self.config.name, namespace=namespace
206+
)
207+
159208
try:
160209
config_check()
161210
api_instance = client.CustomObjectsApi(api_config_handler())
@@ -190,6 +239,11 @@ def down(self):
190239
except Exception as e: # pragma: no cover
191240
return _kube_api_error_handling(e)
192241

242+
if self.config.openshift_oauth:
243+
delete_openshift_oauth_objects(
244+
cluster_name=self.config.name, namespace=namespace
245+
)
246+
193247
def status(
194248
self, print_to_console: bool = True
195249
) -> Tuple[CodeFlareClusterStatus, bool]:
@@ -258,7 +312,16 @@ def status(
258312
return status, ready
259313

260314
def is_dashboard_ready(self) -> bool:
261-
response = requests.get(self.cluster_dashboard_uri(), timeout=5)
315+
try:
316+
response = requests.get(
317+
self.cluster_dashboard_uri(),
318+
headers=self._client_headers,
319+
timeout=5,
320+
verify=self._client_verify_tls,
321+
)
322+
except requests.exceptions.SSLError:
323+
# SSL exception occurs when oauth ingress has been created but cluster is not up
324+
return False
262325
if response.status_code == 200:
263326
return True
264327
else:
@@ -330,7 +393,13 @@ def cluster_dashboard_uri(self) -> str:
330393
return _kube_api_error_handling(e)
331394

332395
for route in routes["items"]:
333-
if route["metadata"]["name"] == f"ray-dashboard-{self.config.name}":
396+
if route["metadata"][
397+
"name"
398+
] == f"ray-dashboard-{self.config.name}" or route["metadata"][
399+
"name"
400+
].startswith(
401+
f"{self.config.name}-ingress"
402+
):
334403
protocol = "https" if route["spec"].get("tls") else "http"
335404
return f"{protocol}://{route['spec']['host']}"
336405
return "Dashboard route not available yet, have you run cluster.up()?"
@@ -339,30 +408,24 @@ def list_jobs(self) -> List:
339408
"""
340409
This method accesses the head ray node in your cluster and lists the running jobs.
341410
"""
342-
dashboard_route = self.cluster_dashboard_uri()
343-
client = JobSubmissionClient(dashboard_route)
344-
return client.list_jobs()
411+
return self.client.list_jobs()
345412

346413
def job_status(self, job_id: str) -> str:
347414
"""
348415
This method accesses the head ray node in your cluster and returns the job status for the provided job id.
349416
"""
350-
dashboard_route = self.cluster_dashboard_uri()
351-
client = JobSubmissionClient(dashboard_route)
352-
return client.get_job_status(job_id)
417+
return self.client.get_job_status(job_id)
353418

354419
def job_logs(self, job_id: str) -> str:
355420
"""
356421
This method accesses the head ray node in your cluster and returns the logs for the provided job id.
357422
"""
358-
dashboard_route = self.cluster_dashboard_uri()
359-
client = JobSubmissionClient(dashboard_route)
360-
return client.get_job_logs(job_id)
423+
return self.client.get_job_logs(job_id)
361424

362425
def torchx_config(
363426
self, working_dir: str = None, requirements: str = None
364427
) -> Dict[str, str]:
365-
dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}"
428+
dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host
366429
to_return = {
367430
"cluster_name": self.config.name,
368431
"dashboard_address": dashboard_address,
@@ -591,7 +654,7 @@ def _get_app_wrappers(
591654

592655

593656
def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
594-
if "status" in rc and "state" in rc["status"]:
657+
if "state" in rc["status"]:
595658
status = RayClusterStatus(rc["status"]["state"].lower())
596659
else:
597660
status = RayClusterStatus.UNKNOWN
@@ -606,7 +669,13 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
606669
)
607670
ray_route = None
608671
for route in routes["items"]:
609-
if route["metadata"]["name"] == f"ray-dashboard-{rc['metadata']['name']}":
672+
if route["metadata"][
673+
"name"
674+
] == f"ray-dashboard-{rc['metadata']['name']}" or route["metadata"][
675+
"name"
676+
].startswith(
677+
f"{rc['metadata']['name']}-ingress"
678+
):
610679
protocol = "https" if route["spec"].get("tls") else "http"
611680
ray_route = f"{protocol}://{route['spec']['host']}"
612681

src/codeflare_sdk/cluster/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,4 @@ class ClusterConfiguration:
5151
local_interactive: bool = False
5252
image_pull_secrets: list = field(default_factory=list)
5353
dispatch_priority: str = None
54+
openshift_oauth: bool = False # NOTE: to use the user must have permission to create a RoleBinding for system:auth-delegator

0 commit comments

Comments
 (0)