diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 5d00cdae8..5fe7914c4 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -60,7 +60,7 @@ def __init__(self, config: ClusterConfiguration): """ self.config = config self.app_wrapper_yaml = self.create_app_wrapper() - self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0] + self.app_wrapper_name = self.config.name def evaluate_dispatch_priority(self): priority_class = self.config.dispatch_priority @@ -147,6 +147,7 @@ def create_app_wrapper(self): image_pull_secrets=image_pull_secrets, dispatch_priority=dispatch_priority, priority_val=priority_val, + write_to_file=self.config.write_to_file, ) # creates a new cluster with the provided or default spec @@ -159,8 +160,8 @@ def up(self): try: config_check() api_instance = client.CustomObjectsApi(api_config_handler()) - with open(self.app_wrapper_yaml) as f: - aw = yaml.load(f, Loader=yaml.FullLoader) + + aw = self.app_wrapper_yaml api_instance.create_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta1", diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index bde3f4ca0..6da62c734 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -51,3 +51,4 @@ class ClusterConfiguration: local_interactive: bool = False image_pull_secrets: list = field(default_factory=list) dispatch_priority: str = None + write_to_file: bool = False diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 95e1c5ecb..0a8f2ac58 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -390,6 +390,7 @@ def generate_appwrapper( image_pull_secrets: list, dispatch_priority: str, priority_val: int, + write_to_file: bool, ): user_yaml = read_template(template) appwrapper_name, cluster_name = gen_names(name) @@ -433,6 +434,9 @@ def generate_appwrapper( enable_local_interactive(resources, cluster_name, namespace) else: disable_raycluster_tls(resources["resources"]) - outfile = appwrapper_name + ".yaml" - write_user_appwrapper(user_yaml, outfile) - return outfile + + if write_to_file: + outfile = appwrapper_name + ".yaml" + write_user_appwrapper(user_yaml, outfile) + + return user_yaml diff --git a/tests/unit_test.py b/tests/unit_test.py index 78925226a..1be18fe36 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -237,11 +237,15 @@ def test_config_creation(): assert config.machine_types == ["cpu.small", "gpu.large"] assert config.image_pull_secrets == ["unit-test-pull-secret"] assert config.dispatch_priority == None + assert config.write_to_file == True def test_cluster_creation(): cluster = createClusterWithConfig() - assert cluster.app_wrapper_yaml == "unit-test-cluster.yaml" + # load yaml file and compare with in memory yaml + with open("unit-test-cluster.yaml") as f: + aw = yaml.load(f, Loader=yaml.FullLoader) + assert aw == cluster.app_wrapper_yaml assert cluster.app_wrapper_name == "unit-test-cluster" assert filecmp.cmp( "unit-test-cluster.yaml", f"{parent}/tests/test-case.yaml", shallow=True @@ -258,7 +262,10 @@ def test_cluster_creation_priority(mocker): config.name = "prio-test-cluster" config.dispatch_priority = "default" cluster = Cluster(config) - assert cluster.app_wrapper_yaml == "prio-test-cluster.yaml" + # load yaml file and compare with in memory yaml + with open("prio-test-cluster.yaml") as f: + aw = yaml.load(f, Loader=yaml.FullLoader) + assert aw == cluster.app_wrapper_yaml assert cluster.app_wrapper_name == "prio-test-cluster" assert filecmp.cmp( "prio-test-cluster.yaml", f"{parent}/tests/test-case-prio.yaml", shallow=True @@ -272,10 +279,14 @@ def test_default_cluster_creation(mocker): ) default_config = ClusterConfiguration( name="unit-test-default-cluster", + write_to_file=True, ) cluster = Cluster(default_config) - assert cluster.app_wrapper_yaml == "unit-test-default-cluster.yaml" + # open yaml file and compare with in memory yaml + with open("unit-test-default-cluster.yaml") as f: + aw = yaml.load(f, Loader=yaml.FullLoader) + assert aw == cluster.app_wrapper_yaml assert cluster.app_wrapper_name == "unit-test-default-cluster" assert cluster.config.namespace == "opendatahub" diff --git a/tests/unit_test_support.py b/tests/unit_test_support.py index a4ea056a0..725009bb5 100644 --- a/tests/unit_test_support.py +++ b/tests/unit_test_support.py @@ -46,6 +46,7 @@ def createClusterConfig(): instascale=True, machine_types=["cpu.small", "gpu.large"], image_pull_secrets=["unit-test-pull-secret"], + write_to_file=True, ) return config