diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index b761c5316..a109cb3f9 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -2,28 +2,36 @@ from ..utils.pretty_print import RayCluster from ..utils import pretty_print import openshift as oc -from typing import List +from typing import List, Optional + class Cluster: def __init__(self, config: ClusterConfiguration): - pass + self.config = config + # creates a new cluser with the provided or default spec def up(self): pass def down(self, name): pass - def status(self, name): - pass - + def status(self, print_to_console=True): + cluster = _ray_cluster_status(self.config.name) + if cluster: + if print_to_console: + pretty_print.print_clusters([cluster]) + return cluster.status + else: + return None + def list_all_clusters(print_to_console=True): clusters = _get_ray_clusters() if print_to_console: pretty_print.print_clusters(clusters) return clusters - + # private methods @@ -33,18 +41,39 @@ def _get_appwrappers(namespace='default'): return app_wrappers +def _ray_cluster_status(name, namespace='default') -> Optional[RayCluster]: + + with oc.project(namespace), oc.timeout(10*60): + cluster = oc.selector(f'rayclusters/{name}').object() + + if cluster: + return _map_to_ray_cluster(cluster) + else: + return None + + + + def _get_ray_clusters(namespace='default') -> List[RayCluster]: list_of_clusters = [] + with oc.project(namespace), oc.timeout(10*60): ray_clusters = oc.selector('rayclusters').objects() - for cluster in ray_clusters: - cluster_model = cluster.model - list_of_clusters.append(RayCluster( - name=cluster.name(), status=cluster_model.status.state, - min_workers=cluster_model.spec.workerGroupSpecs[0].replicas, - max_workers=cluster_model.spec.workerGroupSpecs[0].replicas, - worker_mem_max=cluster_model.spec.workerGroupSpecs[0].template.spec.containers[0].resources.limits.memory, - worker_mem_min=cluster_model.spec.workerGroupSpecs[0].template.spec.containers[0].resources.requests.memory, - worker_cpu=cluster_model.spec.workerGroupSpecs[0].template.spec.containers[0].resources.limits.cpu, - worker_gpu=0)) + + for cluster in ray_clusters: + list_of_clusters.append(_map_to_ray_cluster(cluster)) return list_of_clusters + + +def _map_to_ray_cluster(cluster): + cluster_model = cluster.model + return RayCluster( + name=cluster.name(), status=cluster_model.status.state, + min_workers=cluster_model.spec.workerGroupSpecs[0].replicas, + max_workers=cluster_model.spec.workerGroupSpecs[0].replicas, + worker_mem_max=cluster_model.spec.workerGroupSpecs[ + 0].template.spec.containers[0].resources.limits.memory, + worker_mem_min=cluster_model.spec.workerGroupSpecs[ + 0].template.spec.containers[0].resources.requests.memory, + worker_cpu=cluster_model.spec.workerGroupSpecs[0].template.spec.containers[0].resources.limits.cpu, + worker_gpu=0) diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index f1bce9fa9..657c5ec87 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -2,5 +2,8 @@ @dataclass class ClusterConfiguration: - min_cpus: int - max_cpus: int \ No newline at end of file + name: str + min_cpus: int = 1 + max_cpus: int = 1 + min_worker: int = 0 + max_worker: int = 1 diff --git a/src/codeflare_sdk/utils/pretty_print.py b/src/codeflare_sdk/utils/pretty_print.py index 2a927110a..6d1afae8f 100644 --- a/src/codeflare_sdk/utils/pretty_print.py +++ b/src/codeflare_sdk/utils/pretty_print.py @@ -20,10 +20,16 @@ class RayCluster: worker_cpu: int worker_gpu: int +def _print_no_cluster_found(): + pass + def print_clusters(clusters:List[RayCluster], verbose=True): console = Console() - title_printed = False + #FIXME handle case where no clusters are found + if len(clusters) == 0: + _print_no_cluster_found() + return #exit early for cluster in clusters: status = "Active :white_heavy_check_mark:" if cluster.status.lower() == 'ready' else "InActive :x:" diff --git a/tests/test_clusters.py b/tests/test_clusters.py index 7f65410e8..3a7f4b88e 100644 --- a/tests/test_clusters.py +++ b/tests/test_clusters.py @@ -1,6 +1,13 @@ -from codeflare_sdk.cluster.cluster import _get_ray_clusters -from codeflare_sdk.utils.pretty_print import print_clusters +from codeflare_sdk.cluster.cluster import list_all_clusters +from codeflare_sdk.cluster.cluster import Cluster, ClusterConfiguration + +#for now these tests assume that the cluster was already created def test_list_clusters(): - clusters = _get_ray_clusters() - print_clusters(clusters) + clusters = list_all_clusters() + +def test_cluster_status(): + cluster = Cluster(ClusterConfiguration(name='raycluster-autoscaler')) + cluster.status() + +