from ..utils import pretty_print
from ..utils.generate_yaml import (
generate_appwrapper,
+ head_worker_gpu_count_from_cluster,
)
from ..utils.kube_api_helpers import _kube_api_error_handling
from ..utils.generate_yaml import is_openshift_cluster
@@ -135,16 +136,6 @@
Module codeflare_sdk.cluster.cluster
)
return self._job_submission_client
- def validate_image_config(self):
- """
- Validates that the image configuration is not empty.
-
- :param image: The image string to validate
- :raises ValueError: If the image is not specified
- """
- if self.config.image == "" or self.config.image == None:
- raise ValueError("Image must be specified in the ClusterConfiguration")
-
def create_app_wrapper(self):
"""
Called upon cluster object creation, creates an AppWrapper yaml based on
@@ -160,51 +151,7 @@
if print_to_console:
# overriding the number of gpus with requested
- cluster.worker_gpu = self.config.num_gpus
+ _, cluster.worker_gpu = head_worker_gpu_count_from_cluster(self)
pretty_print.print_cluster_status(cluster)
elif print_to_console:
if status == CodeFlareClusterStatus.UNKNOWN:
@@ -488,6 +435,29 @@
worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
0
]["resources"]["limits"]["cpu"],
- worker_gpu=0, # hard to detect currently how many gpus, can override it with what the user asked for
+ worker_extended_resources=worker_extended_resources,
namespace=rc["metadata"]["namespace"],
head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
"resources"
@@ -925,9 +902,7 @@
)
return self._job_submission_client
- def validate_image_config(self):
- """
- Validates that the image configuration is not empty.
-
- :param image: The image string to validate
- :raises ValueError: If the image is not specified
- """
- if self.config.image == "" or self.config.image == None:
- raise ValueError("Image must be specified in the ClusterConfiguration")
-
def create_app_wrapper(self):
"""
Called upon cluster object creation, creates an AppWrapper yaml based on
@@ -1206,51 +1171,7 @@
if print_to_console:
# overriding the number of gpus with requested
- cluster.worker_gpu = self.config.num_gpus
+ _, cluster.worker_gpu = head_worker_gpu_count_from_cluster(self)
pretty_print.print_cluster_status(cluster)
elif print_to_console:
if status == CodeFlareClusterStatus.UNKNOWN:
@@ -1534,6 +1455,29 @@
if print_to_console:
# overriding the number of gpus with requested
- cluster.worker_gpu = self.config.num_gpus
+ _, cluster.worker_gpu = head_worker_gpu_count_from_cluster(self)
pretty_print.print_cluster_status(cluster)
elif print_to_console:
if status == CodeFlareClusterStatus.UNKNOWN:
@@ -2124,28 +2028,6 @@
Methods
return _kube_api_error_handling(e)
-
-def validate_image_config(self)
-
-
-
Validates that the image configuration is not empty.
-
:param image: The image string to validate
-:raises ValueError: If the image is not specified
-
-
-Expand source code
-
-
def validate_image_config(self):
- """
- Validates that the image configuration is not empty.
-
- :param image: The image string to validate
- :raises ValueError: If the image is not specified
- """
- if self.config.image == "" or self.config.image == None:
- raise ValueError("Image must be specified in the ClusterConfiguration")
Cluster object.
"""
-from dataclasses import dataclass, field
import pathlib
-import typing
+import warnings
+from dataclasses import dataclass, field, fields
+from typing import Dict, List, Optional, Union, get_args, get_origin
dir = pathlib.Path(__file__).parent.parent.resolve()
+# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
+DEFAULT_RESOURCE_MAPPING = {
+ "nvidia.com/gpu": "GPU",
+ "intel.com/gpu": "GPU",
+ "amd.com/gpu": "GPU",
+ "aws.amazon.com/neuroncore": "neuron_cores",
+ "google.com/tpu": "TPU",
+ "habana.ai/gaudi": "HPU",
+ "huawei.com/Ascend910": "NPU",
+ "huawei.com/Ascend310": "NPU",
+}
+
@dataclass
class ClusterConfiguration:
"""
This dataclass is used to specify resource requirements and other details, and
is passed in as an argument when creating a Cluster object.
+
+ Attributes:
+ - name: The name of the cluster.
+ - namespace: The namespace in which the cluster should be created.
+ - head_info: A list of strings containing information about the head node.
+ - head_cpus: The number of CPUs to allocate to the head node.
+ - head_memory: The amount of memory to allocate to the head node.
+ - head_gpus: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
+ - head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
+ - machine_types: A list of machine types to use for the cluster.
+ - min_cpus: The minimum number of CPUs to allocate to each worker.
+ - max_cpus: The maximum number of CPUs to allocate to each worker.
+ - num_workers: The number of workers to create.
+ - min_memory: The minimum amount of memory to allocate to each worker.
+ - max_memory: The maximum amount of memory to allocate to each worker.
+ - num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
+ - template: The path to the template file to use for the cluster.
+ - appwrapper: A boolean indicating whether to use an AppWrapper.
+ - envs: A dictionary of environment variables to set for the cluster.
+ - image: The image to use for the cluster.
+ - image_pull_secrets: A list of image pull secrets to use for the cluster.
+ - write_to_file: A boolean indicating whether to write the cluster configuration to a file.
+ - verify_tls: A boolean indicating whether to verify TLS when connecting to the cluster.
+ - labels: A dictionary of labels to apply to the cluster.
+ - worker_extended_resource_requests: A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
+ - extended_resource_mapping: A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
+ - overwrite_default_resource_mapping: A boolean indicating whether to overwrite the default resource mapping.
"""
name: str
- namespace: str = None
- head_info: list = field(default_factory=list)
- head_cpus: typing.Union[int, str] = 2
- head_memory: typing.Union[int, str] = 8
- head_gpus: int = 0
- machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
- min_cpus: typing.Union[int, str] = 1
- max_cpus: typing.Union[int, str] = 1
+ namespace: Optional[str] = None
+ head_info: List[str] = field(default_factory=list)
+ head_cpus: Union[int, str] = 2
+ head_memory: Union[int, str] = 8
+ head_gpus: Optional[int] = None # Deprecating
+ head_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
+ machine_types: List[str] = field(
+ default_factory=list
+ ) # ["m4.xlarge", "g4dn.xlarge"]
+ worker_cpu_requests: Union[int, str] = 1
+ worker_cpu_limits: Union[int, str] = 1
+ min_cpus: Optional[Union[int, str]] = None # Deprecating
+ max_cpus: Optional[Union[int, str]] = None # Deprecating
num_workers: int = 1
- min_memory: typing.Union[int, str] = 2
- max_memory: typing.Union[int, str] = 2
- num_gpus: int = 0
+ worker_memory_requests: Union[int, str] = 2
+ worker_memory_limits: Union[int, str] = 2
+ min_memory: Optional[Union[int, str]] = None # Deprecating
+ max_memory: Optional[Union[int, str]] = None # Deprecating
+ num_gpus: Optional[int] = None # Deprecating
template: str = f"{dir}/templates/base-template.yaml"
appwrapper: bool = False
- envs: dict = field(default_factory=dict)
+ envs: Dict[str, str] = field(default_factory=dict)
image: str = ""
- image_pull_secrets: list = field(default_factory=list)
+ image_pull_secrets: List[str] = field(default_factory=list)
write_to_file: bool = False
verify_tls: bool = True
- labels: dict = field(default_factory=dict)
+ labels: Dict[str, str] = field(default_factory=dict)
+ worker_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
+ extended_resource_mapping: Dict[str, str] = field(default_factory=dict)
+ overwrite_default_resource_mapping: bool = False
+ local_queue: Optional[str] = None
def __post_init__(self):
if not self.verify_tls:
print(
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
)
+
+ self._validate_types()
self._memory_to_string()
self._str_mem_no_unit_add_GB()
+ self._memory_to_resource()
+ self._cpu_to_resource()
+ self._gpu_to_resource()
+ self._combine_extended_resource_mapping()
+ self._validate_extended_resource_requests(self.head_extended_resource_requests)
+ self._validate_extended_resource_requests(
+ self.worker_extended_resource_requests
+ )
+
+ def _combine_extended_resource_mapping(self):
+ if overwritten := set(self.extended_resource_mapping.keys()).intersection(
+ DEFAULT_RESOURCE_MAPPING.keys()
+ ):
+ if self.overwrite_default_resource_mapping:
+ warnings.warn(
+ f"Overwriting default resource mapping for {overwritten}",
+ UserWarning,
+ )
+ else:
+ raise ValueError(
+ f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
+ )
+ self.extended_resource_mapping = {
+ **DEFAULT_RESOURCE_MAPPING,
+ **self.extended_resource_mapping,
+ }
+
+ def _validate_extended_resource_requests(self, extended_resources: Dict[str, int]):
+ for k in extended_resources.keys():
+ if k not in self.extended_resource_mapping.keys():
+ raise ValueError(
+ f"extended resource '{k}' not found in extended_resource_mapping, available resources are {list(self.extended_resource_mapping.keys())}, to add more supported resources use extended_resource_mapping. i.e. extended_resource_mapping = {{'{k}': 'FOO_BAR'}}"
+ )
+
+ def _gpu_to_resource(self):
+ if self.head_gpus:
+ warnings.warn(
+ f"head_gpus is being deprecated, replacing with head_extended_resource_requests['nvidia.com/gpu'] = {self.head_gpus}"
+ )
+ if "nvidia.com/gpu" in self.head_extended_resource_requests:
+ raise ValueError(
+ "nvidia.com/gpu already exists in head_extended_resource_requests"
+ )
+ self.head_extended_resource_requests["nvidia.com/gpu"] = self.head_gpus
+ if self.num_gpus:
+ warnings.warn(
+ f"num_gpus is being deprecated, replacing with worker_extended_resource_requests['nvidia.com/gpu'] = {self.num_gpus}"
+ )
+ if "nvidia.com/gpu" in self.worker_extended_resource_requests:
+ raise ValueError(
+ "nvidia.com/gpu already exists in worker_extended_resource_requests"
+ )
+ self.worker_extended_resource_requests["nvidia.com/gpu"] = self.num_gpus
def _str_mem_no_unit_add_GB(self):
if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
self.head_memory = f"{self.head_memory}G"
- if isinstance(self.min_memory, str) and self.min_memory.isdecimal():
- self.min_memory = f"{self.min_memory}G"
- if isinstance(self.max_memory, str) and self.max_memory.isdecimal():
- self.max_memory = f"{self.max_memory}G"
+ if (
+ isinstance(self.worker_memory_requests, str)
+ and self.worker_memory_requests.isdecimal()
+ ):
+ self.worker_memory_requests = f"{self.worker_memory_requests}G"
+ if (
+ isinstance(self.worker_memory_limits, str)
+ and self.worker_memory_limits.isdecimal()
+ ):
+ self.worker_memory_limits = f"{self.worker_memory_limits}G"
def _memory_to_string(self):
if isinstance(self.head_memory, int):
self.head_memory = f"{self.head_memory}G"
- if isinstance(self.min_memory, int):
- self.min_memory = f"{self.min_memory}G"
- if isinstance(self.max_memory, int):
- self.max_memory = f"{self.max_memory}G"
+ if isinstance(self.worker_memory_requests, int):
+ self.worker_memory_requests = f"{self.worker_memory_requests}G"
+ if isinstance(self.worker_memory_limits, int):
+ self.worker_memory_limits = f"{self.worker_memory_limits}G"
+
+ def _cpu_to_resource(self):
+ if self.min_cpus:
+ warnings.warn("min_cpus is being deprecated, use worker_cpu_requests")
+ self.worker_cpu_requests = self.min_cpus
+ if self.max_cpus:
+ warnings.warn("max_cpus is being deprecated, use worker_cpu_limits")
+ self.worker_cpu_limits = self.max_cpus
+
+ def _memory_to_resource(self):
+ if self.min_memory:
+ warnings.warn("min_memory is being deprecated, use worker_memory_requests")
+ self.worker_memory_requests = f"{self.min_memory}G"
+ if self.max_memory:
+ warnings.warn("max_memory is being deprecated, use worker_memory_limits")
+ self.worker_memory_limits = f"{self.max_memory}G"
+
+ def _validate_types(self):
+ """Validate the types of all fields in the ClusterConfiguration dataclass."""
+ for field_info in fields(self):
+ value = getattr(self, field_info.name)
+ expected_type = field_info.type
+ if not self._is_type(value, expected_type):
+ raise TypeError(
+ f"'{field_info.name}' should be of type {expected_type}"
+ )
+
+ @staticmethod
+ def _is_type(value, expected_type):
+ """Check if the value matches the expected type."""
+
+ def check_type(value, expected_type):
+ origin_type = get_origin(expected_type)
+ args = get_args(expected_type)
+ if origin_type is Union:
+ return any(check_type(value, union_type) for union_type in args)
+ if origin_type is list:
+ return all(check_type(elem, args[0]) for elem in value)
+ if origin_type is dict:
+ return all(
+ check_type(k, args[0]) and check_type(v, args[1])
+ for k, v in value.items()
+ )
+ if origin_type is tuple:
+ return all(check_type(elem, etype) for elem, etype in zip(value, args))
+ return isinstance(value, expected_type)
- local_queue: str = None
+ return check_type(value, expected_type)
@@ -124,11 +282,37 @@
This dataclass is used to specify resource requirements and other details, and
-is passed in as an argument when creating a Cluster object.
+is passed in as an argument when creating a Cluster object.
+
Attributes:
+- name: The name of the cluster.
+- namespace: The namespace in which the cluster should be created.
+- head_info: A list of strings containing information about the head node.
+- head_cpus: The number of CPUs to allocate to the head node.
+- head_memory: The amount of memory to allocate to the head node.
+- head_gpus: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
+- head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
+- machine_types: A list of machine types to use for the cluster.
+- min_cpus: The minimum number of CPUs to allocate to each worker.
+- max_cpus: The maximum number of CPUs to allocate to each worker.
+- num_workers: The number of workers to create.
+- min_memory: The minimum amount of memory to allocate to each worker.
+- max_memory: The maximum amount of memory to allocate to each worker.
+- num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
+- template: The path to the template file to use for the cluster.
+- appwrapper: A boolean indicating whether to use an AppWrapper.
+- envs: A dictionary of environment variables to set for the cluster.
+- image: The image to use for the cluster.
+- image_pull_secrets: A list of image pull secrets to use for the cluster.
+- write_to_file: A boolean indicating whether to write the cluster configuration to a file.
+- verify_tls: A boolean indicating whether to verify TLS when connecting to the cluster.
+- labels: A dictionary of labels to apply to the cluster.
+- worker_extended_resource_requests: A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
+- extended_resource_mapping: A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
+- overwrite_default_resource_mapping: A boolean indicating whether to overwrite the default resource mapping.
Expand source code
@@ -138,55 +322,200 @@
Classes
"""
This dataclass is used to specify resource requirements and other details, and
is passed in as an argument when creating a Cluster object.
+
+ Attributes:
+ - name: The name of the cluster.
+ - namespace: The namespace in which the cluster should be created.
+ - head_info: A list of strings containing information about the head node.
+ - head_cpus: The number of CPUs to allocate to the head node.
+ - head_memory: The amount of memory to allocate to the head node.
+ - head_gpus: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
+ - head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
+ - machine_types: A list of machine types to use for the cluster.
+ - min_cpus: The minimum number of CPUs to allocate to each worker.
+ - max_cpus: The maximum number of CPUs to allocate to each worker.
+ - num_workers: The number of workers to create.
+ - min_memory: The minimum amount of memory to allocate to each worker.
+ - max_memory: The maximum amount of memory to allocate to each worker.
+ - num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
+ - template: The path to the template file to use for the cluster.
+ - appwrapper: A boolean indicating whether to use an AppWrapper.
+ - envs: A dictionary of environment variables to set for the cluster.
+ - image: The image to use for the cluster.
+ - image_pull_secrets: A list of image pull secrets to use for the cluster.
+ - write_to_file: A boolean indicating whether to write the cluster configuration to a file.
+ - verify_tls: A boolean indicating whether to verify TLS when connecting to the cluster.
+ - labels: A dictionary of labels to apply to the cluster.
+ - worker_extended_resource_requests: A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
+ - extended_resource_mapping: A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
+ - overwrite_default_resource_mapping: A boolean indicating whether to overwrite the default resource mapping.
"""
name: str
- namespace: str = None
- head_info: list = field(default_factory=list)
- head_cpus: typing.Union[int, str] = 2
- head_memory: typing.Union[int, str] = 8
- head_gpus: int = 0
- machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
- min_cpus: typing.Union[int, str] = 1
- max_cpus: typing.Union[int, str] = 1
+ namespace: Optional[str] = None
+ head_info: List[str] = field(default_factory=list)
+ head_cpus: Union[int, str] = 2
+ head_memory: Union[int, str] = 8
+ head_gpus: Optional[int] = None # Deprecating
+ head_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
+ machine_types: List[str] = field(
+ default_factory=list
+ ) # ["m4.xlarge", "g4dn.xlarge"]
+ worker_cpu_requests: Union[int, str] = 1
+ worker_cpu_limits: Union[int, str] = 1
+ min_cpus: Optional[Union[int, str]] = None # Deprecating
+ max_cpus: Optional[Union[int, str]] = None # Deprecating
num_workers: int = 1
- min_memory: typing.Union[int, str] = 2
- max_memory: typing.Union[int, str] = 2
- num_gpus: int = 0
+ worker_memory_requests: Union[int, str] = 2
+ worker_memory_limits: Union[int, str] = 2
+ min_memory: Optional[Union[int, str]] = None # Deprecating
+ max_memory: Optional[Union[int, str]] = None # Deprecating
+ num_gpus: Optional[int] = None # Deprecating
template: str = f"{dir}/templates/base-template.yaml"
appwrapper: bool = False
- envs: dict = field(default_factory=dict)
+ envs: Dict[str, str] = field(default_factory=dict)
image: str = ""
- image_pull_secrets: list = field(default_factory=list)
+ image_pull_secrets: List[str] = field(default_factory=list)
write_to_file: bool = False
verify_tls: bool = True
- labels: dict = field(default_factory=dict)
+ labels: Dict[str, str] = field(default_factory=dict)
+ worker_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
+ extended_resource_mapping: Dict[str, str] = field(default_factory=dict)
+ overwrite_default_resource_mapping: bool = False
+ local_queue: Optional[str] = None
def __post_init__(self):
if not self.verify_tls:
print(
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
)
+
+ self._validate_types()
self._memory_to_string()
self._str_mem_no_unit_add_GB()
+ self._memory_to_resource()
+ self._cpu_to_resource()
+ self._gpu_to_resource()
+ self._combine_extended_resource_mapping()
+ self._validate_extended_resource_requests(self.head_extended_resource_requests)
+ self._validate_extended_resource_requests(
+ self.worker_extended_resource_requests
+ )
+
+ def _combine_extended_resource_mapping(self):
+ if overwritten := set(self.extended_resource_mapping.keys()).intersection(
+ DEFAULT_RESOURCE_MAPPING.keys()
+ ):
+ if self.overwrite_default_resource_mapping:
+ warnings.warn(
+ f"Overwriting default resource mapping for {overwritten}",
+ UserWarning,
+ )
+ else:
+ raise ValueError(
+ f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
+ )
+ self.extended_resource_mapping = {
+ **DEFAULT_RESOURCE_MAPPING,
+ **self.extended_resource_mapping,
+ }
+
+ def _validate_extended_resource_requests(self, extended_resources: Dict[str, int]):
+ for k in extended_resources.keys():
+ if k not in self.extended_resource_mapping.keys():
+ raise ValueError(
+ f"extended resource '{k}' not found in extended_resource_mapping, available resources are {list(self.extended_resource_mapping.keys())}, to add more supported resources use extended_resource_mapping. i.e. extended_resource_mapping = {{'{k}': 'FOO_BAR'}}"
+ )
+
+ def _gpu_to_resource(self):
+ if self.head_gpus:
+ warnings.warn(
+ f"head_gpus is being deprecated, replacing with head_extended_resource_requests['nvidia.com/gpu'] = {self.head_gpus}"
+ )
+ if "nvidia.com/gpu" in self.head_extended_resource_requests:
+ raise ValueError(
+ "nvidia.com/gpu already exists in head_extended_resource_requests"
+ )
+ self.head_extended_resource_requests["nvidia.com/gpu"] = self.head_gpus
+ if self.num_gpus:
+ warnings.warn(
+ f"num_gpus is being deprecated, replacing with worker_extended_resource_requests['nvidia.com/gpu'] = {self.num_gpus}"
+ )
+ if "nvidia.com/gpu" in self.worker_extended_resource_requests:
+ raise ValueError(
+ "nvidia.com/gpu already exists in worker_extended_resource_requests"
+ )
+ self.worker_extended_resource_requests["nvidia.com/gpu"] = self.num_gpus
def _str_mem_no_unit_add_GB(self):
if isinstance(self.head_memory, str) and self.head_memory.isdecimal():
self.head_memory = f"{self.head_memory}G"
- if isinstance(self.min_memory, str) and self.min_memory.isdecimal():
- self.min_memory = f"{self.min_memory}G"
- if isinstance(self.max_memory, str) and self.max_memory.isdecimal():
- self.max_memory = f"{self.max_memory}G"
+ if (
+ isinstance(self.worker_memory_requests, str)
+ and self.worker_memory_requests.isdecimal()
+ ):
+ self.worker_memory_requests = f"{self.worker_memory_requests}G"
+ if (
+ isinstance(self.worker_memory_limits, str)
+ and self.worker_memory_limits.isdecimal()
+ ):
+ self.worker_memory_limits = f"{self.worker_memory_limits}G"
def _memory_to_string(self):
if isinstance(self.head_memory, int):
self.head_memory = f"{self.head_memory}G"
- if isinstance(self.min_memory, int):
- self.min_memory = f"{self.min_memory}G"
- if isinstance(self.max_memory, int):
- self.max_memory = f"{self.max_memory}G"
+ if isinstance(self.worker_memory_requests, int):
+ self.worker_memory_requests = f"{self.worker_memory_requests}G"
+ if isinstance(self.worker_memory_limits, int):
+ self.worker_memory_limits = f"{self.worker_memory_limits}G"
+
+ def _cpu_to_resource(self):
+ if self.min_cpus:
+ warnings.warn("min_cpus is being deprecated, use worker_cpu_requests")
+ self.worker_cpu_requests = self.min_cpus
+ if self.max_cpus:
+ warnings.warn("max_cpus is being deprecated, use worker_cpu_limits")
+ self.worker_cpu_limits = self.max_cpus
+
+ def _memory_to_resource(self):
+ if self.min_memory:
+ warnings.warn("min_memory is being deprecated, use worker_memory_requests")
+ self.worker_memory_requests = f"{self.min_memory}G"
+ if self.max_memory:
+ warnings.warn("max_memory is being deprecated, use worker_memory_limits")
+ self.worker_memory_limits = f"{self.max_memory}G"
+
+ def _validate_types(self):
+ """Validate the types of all fields in the ClusterConfiguration dataclass."""
+ for field_info in fields(self):
+ value = getattr(self, field_info.name)
+ expected_type = field_info.type
+ if not self._is_type(value, expected_type):
+ raise TypeError(
+ f"'{field_info.name}' should be of type {expected_type}"
+ )
+
+ @staticmethod
+ def _is_type(value, expected_type):
+ """Check if the value matches the expected type."""
+
+ def check_type(value, expected_type):
+ origin_type = get_origin(expected_type)
+ args = get_args(expected_type)
+ if origin_type is Union:
+ return any(check_type(value, union_type) for union_type in args)
+ if origin_type is list:
+ return all(check_type(elem, args[0]) for elem in value)
+ if origin_type is dict:
+ return all(
+ check_type(k, args[0]) and check_type(v, args[1])
+ for k, v in value.items()
+ )
+ if origin_type is tuple:
+ return all(check_type(elem, etype) for elem, etype in zip(value, args))
+ return isinstance(value, expected_type)
- local_queue: str = None
+ return check_type(value, expected_type)
Class variables
@@ -194,7 +523,11 @@
Class variables
-
var envs : dict
+
var envs : Dict[str, str]
+
+
+
+
var extended_resource_mapping : Dict[str, str]
@@ -202,11 +535,15 @@
Class variables
-
var head_gpus : int
+
var head_extended_resource_requests : Dict[str, int]
+
+
+
+
var head_gpus : Optional[int]
-
var head_info : list
+
var head_info : List[str]
@@ -218,35 +555,35 @@
Class variables
-
var image_pull_secrets : list
+
var image_pull_secrets : List[str]
-
var labels : dict
+
var labels : Dict[str, str]
-
var local_queue : str
+
var local_queue : Optional[str]
-
var machine_types : list
+
var machine_types : List[str]
-
var max_cpus : Union[int, str]
+
var max_cpus : Union[int, str, ForwardRef(None)]
-
var max_memory : Union[int, str]
+
var max_memory : Union[int, str, ForwardRef(None)]
-
var min_cpus : Union[int, str]
+
var min_cpus : Union[int, str, ForwardRef(None)]
-
var min_memory : Union[int, str]
+
var min_memory : Union[int, str, ForwardRef(None)]
@@ -254,11 +591,11 @@
Class variables
-
var namespace : str
+
var namespace : Optional[str]
-
var num_gpus : int
+
var num_gpus : Optional[int]
@@ -266,6 +603,10 @@
Class variables
+
var overwrite_default_resource_mapping : bool
+
+
+
var template : str
@@ -274,6 +615,26 @@
Class variables
+
var worker_cpu_limits : Union[int, str]
+
+
+
+
var worker_cpu_requests : Union[int, str]
+
+
+
+
var worker_extended_resource_requests : Dict[str, int]
dataclasses to store information for Ray clusters and AppWrappers.
"""
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from enum import Enum
+import typing
class RayClusterStatus(Enum):
@@ -107,14 +108,14 @@
Module codeflare_sdk.cluster.model
status: RayClusterStatus
head_cpus: int
head_mem: str
- head_gpu: int
workers: int
worker_mem_min: str
worker_mem_max: str
worker_cpu: int
- worker_gpu: int
namespace: str
dashboard: str
+ worker_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
+ head_extended_resources: typing.Dict[str, int] = field(default_factory=dict)
@dataclass
@@ -292,7 +293,7 @@
import pathlib
+import shutil
+
+package_dir = pathlib.Path(__file__).parent.parent.resolve()
+demo_dir = f"{package_dir}/demo-notebooks"
+
+
+def copy_demo_nbs(dir: str = "./demo-notebooks", overwrite: bool = False):
+ """
+ Copy the demo notebooks from the package to the current working directory
+
+ overwrite=True will overwrite any files that exactly match files written by copy_demo_nbs in the target directory.
+ Any files that exist in the directory that don't match these values will remain untouched.
+
+ Args:
+        dir (str): The directory to copy the demo notebooks to. Defaults to "./demo-notebooks".
+ overwrite (bool): Whether to overwrite files in the directory if it already exists. Defaults to False.
+ Raises:
+ FileExistsError: If the directory already exists.
+ """
+ # does dir exist already?
+ if overwrite is False and pathlib.Path(dir).exists():
+ raise FileExistsError(
+ f"Directory {dir} already exists. Please remove it or provide a different location."
+ )
+
+ shutil.copytree(demo_dir, dir, dirs_exist_ok=True)
Copy the demo notebooks from the package to the current working directory
+
overwrite=True will overwrite any files that exactly match files written by copy_demo_nbs in the target directory.
+Any files that exist in the directory that don't match these values will remain untouched.
+
Args
+
+
dir : str
+
The directory to copy the demo notebooks to. Defaults to "./demo-notebooks".
+
overwrite : bool
+
Whether to overwrite files in the directory if it already exists. Defaults to False.
+
+
Raises
+
+
FileExistsError
+
If the directory already exists.
+
+
+
+Expand source code
+
+
def copy_demo_nbs(dir: str = "./demo-notebooks", overwrite: bool = False):
+ """
+ Copy the demo notebooks from the package to the current working directory
+
+ overwrite=True will overwrite any files that exactly match files written by copy_demo_nbs in the target directory.
+ Any files that exist in the directory that don't match these values will remain untouched.
+
+ Args:
+        dir (str): The directory to copy the demo notebooks to. Defaults to "./demo-notebooks".
+ overwrite (bool): Whether to overwrite files in the directory if it already exists. Defaults to False.
+ Raises:
+ FileExistsError: If the directory already exists.
+ """
+ # does dir exist already?
+ if overwrite is False and pathlib.Path(dir).exists():
+ raise FileExistsError(
+ f"Directory {dir} already exists. Please remove it or provide a different location."
+ )
+
+ shutil.copytree(demo_dir, dir, dirs_exist_ok=True)
namespace=namespace,
plural="localqueues",
)
- except Exception as e: # pragma: no cover
- return _kube_api_error_handling(e)
+ except ApiException as e: # pragma: no cover
+ if e.status == 404 or e.status == 403:
+ return
+ else:
+ return _kube_api_error_handling(e)
for lq in local_queues["items"]:
if (
"annotations" in lq["metadata"]
@@ -218,9 +286,6 @@
Module codeflare_sdk.utils.generate_yaml
== "true"
):
return lq["metadata"]["name"]
- raise ValueError(
- "Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
- )
def local_queue_exists(namespace: str, local_queue_name: str):
@@ -245,7 +310,9 @@
Module codeflare_sdk.utils.generate_yaml
def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
lq_name = local_queue or get_default_kueue_name(namespace)
- if not local_queue_exists(namespace, lq_name):
+ if lq_name == None:
+ return
+ elif not local_queue_exists(namespace, lq_name):
raise ValueError(
"local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration"
)
@@ -291,65 +358,32 @@
def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
lq_name = local_queue or get_default_kueue_name(namespace)
- if not local_queue_exists(namespace, lq_name):
+ if lq_name == None:
+ return
+ elif not local_queue_exists(namespace, lq_name):
raise ValueError(
"local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration"
)
@@ -428,7 +464,7 @@
namespace=namespace,
plural="localqueues",
)
- except Exception as e: # pragma: no cover
- return _kube_api_error_handling(e)
+ except ApiException as e: # pragma: no cover
+ if e.status == 404 or e.status == 403:
+ return
+ else:
+ return _kube_api_error_handling(e)
for lq in local_queues["items"]:
if (
"annotations" in lq["metadata"]
@@ -527,10 +533,64 @@
Functions
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
== "true"
):
- return lq["metadata"]["name"]
- raise ValueError(
- "Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
- )