Skip to content

Commit 52b94c4

Browse files
authored
Add support for image pull secrets for Ray Cluster images (#162)
* add: Support for image pull secrets for Ray Cluster images * test: adjust unit tests to incorporate image_pull_secrets * refactor: refactor update_image_pull_secrets
1 parent df48547 commit 52b94c4

File tree

6 files changed

+30
-0
lines changed

6 files changed

+30
-0
lines changed

src/codeflare_sdk/cluster/cluster.py

+2
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def create_app_wrapper(self):
8585
instance_types = self.config.machine_types
8686
env = self.config.envs
8787
local_interactive = self.config.local_interactive
88+
image_pull_secrets = self.config.image_pull_secrets
8889
return generate_appwrapper(
8990
name=name,
9091
namespace=namespace,
@@ -100,6 +101,7 @@ def create_app_wrapper(self):
100101
instance_types=instance_types,
101102
env=env,
102103
local_interactive=local_interactive,
104+
image_pull_secrets=image_pull_secrets,
103105
)
104106

105107
# creates a new cluster with the provided or default spec

src/codeflare_sdk/cluster/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ class ClusterConfiguration:
4949
envs: dict = field(default_factory=dict)
5050
image: str = "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
5151
local_interactive: bool = False
52+
image_pull_secrets: list = field(default_factory=list)

src/codeflare_sdk/utils/generate_yaml.py

+19
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ def update_image(spec, image):
141141
container["image"] = image
142142

143143

144+
def update_image_pull_secrets(spec, image_pull_secrets):
145+
template_secrets = spec.get("imagePullSecrets", [])
146+
spec["imagePullSecrets"] = template_secrets + [
147+
{"name": x} for x in image_pull_secrets
148+
]
149+
150+
144151
def update_env(spec, env):
145152
containers = spec.get("containers")
146153
for container in containers:
@@ -178,6 +185,7 @@ def update_nodes(
178185
image,
179186
instascale,
180187
env,
188+
image_pull_secrets,
181189
):
182190
if "generictemplate" in item.keys():
183191
head = item.get("generictemplate").get("spec").get("headGroupSpec")
@@ -193,6 +201,7 @@ def update_nodes(
193201
for comp in [head, worker]:
194202
spec = comp.get("template").get("spec")
195203
update_affinity(spec, appwrapper_name, instascale)
204+
update_image_pull_secrets(spec, image_pull_secrets)
196205
update_image(spec, image)
197206
update_env(spec, env)
198207
if comp == head:
@@ -295,6 +304,7 @@ def generate_appwrapper(
295304
instance_types: list,
296305
env,
297306
local_interactive: bool,
307+
image_pull_secrets: list,
298308
):
299309
user_yaml = read_template(template)
300310
appwrapper_name, cluster_name = gen_names(name)
@@ -318,6 +328,7 @@ def generate_appwrapper(
318328
image,
319329
instascale,
320330
env,
331+
image_pull_secrets,
321332
)
322333
update_dashboard_route(route_item, cluster_name, namespace)
323334
if local_interactive:
@@ -409,6 +420,12 @@ def main(): # pragma: no cover
409420
default=False,
410421
help="Enable local interactive mode",
411422
)
423+
parser.add_argument(
424+
"--image-pull-secrets",
425+
required=False,
426+
default=[],
427+
help="Set image pull secrets for private registries",
428+
)
412429

413430
args = parser.parse_args()
414431
name = args.name
@@ -425,6 +442,7 @@ def main(): # pragma: no cover
425442
namespace = args.namespace
426443
local_interactive = args.local_interactive
427444
env = {}
445+
image_pull_secrets = args.image_pull_secrets
428446

429447
outfile = generate_appwrapper(
430448
name,
@@ -441,6 +459,7 @@ def main(): # pragma: no cover
441459
instance_types,
442460
local_interactive,
443461
env,
462+
image_pull_secrets,
444463
)
445464
return outfile
446465

tests/test-case-cmd.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ spec:
9696
cpu: 2
9797
memory: 8G
9898
nvidia.com/gpu: 0
99+
imagePullSecrets: []
99100
rayVersion: 2.1.0
100101
workerGroupSpecs:
101102
- groupName: small-group-unit-cmd-cluster
@@ -144,6 +145,7 @@ spec:
144145
cpu: 1
145146
memory: 2G
146147
nvidia.com/gpu: 1
148+
imagePullSecrets: []
147149
initContainers:
148150
- command:
149151
- sh

tests/test-case.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ spec:
107107
cpu: 2
108108
memory: 8G
109109
nvidia.com/gpu: 0
110+
imagePullSecrets:
111+
- name: unit-test-pull-secret
110112
rayVersion: 2.1.0
111113
workerGroupSpecs:
112114
- groupName: small-group-unit-test-cluster
@@ -164,6 +166,8 @@ spec:
164166
cpu: 3
165167
memory: 5G
166168
nvidia.com/gpu: 7
169+
imagePullSecrets:
170+
- name: unit-test-pull-secret
167171
initContainers:
168172
- command:
169173
- sh

tests/unit_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def test_config_creation():
220220
gpu=7,
221221
instascale=True,
222222
machine_types=["cpu.small", "gpu.large"],
223+
image_pull_secrets=["unit-test-pull-secret"],
223224
)
224225

225226
assert config.name == "unit-test-cluster" and config.namespace == "ns"
@@ -234,6 +235,7 @@ def test_config_creation():
234235
assert config.template == f"{parent}/src/codeflare_sdk/templates/base-template.yaml"
235236
assert config.instascale
236237
assert config.machine_types == ["cpu.small", "gpu.large"]
238+
assert config.image_pull_secrets == ["unit-test-pull-secret"]
237239
return config
238240

239241

0 commit comments

Comments
 (0)