Skip to content

Commit d572260

Browse files
Load settings from config.toml file during CDI generation
Signed-off-by: Carlos Eduardo Arango Gutierrez <[email protected]>
1 parent 4930f68 commit d572260

File tree

4 files changed

+171
-0
lines changed

4 files changed

+171
-0
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,31 @@ func (m command) build() *cli.Command {
182182
}
183183

184184
func (m command) validateFlags(c *cli.Context, opts *options) error {
185+
// Load config file as base configuration
186+
cfg, err := config.GetConfig()
187+
if err != nil {
188+
return fmt.Errorf("failed to load config: %v", err)
189+
}
190+
191+
// Apply config file values if command line or environment variables are not set.
192+
// order (1) command line, (2) environment variable, (3) config file
193+
if opts.nvidiaCDIHookPath == "" && cfg.NVIDIAContainerRuntimeHookConfig.Path != "" {
194+
opts.nvidiaCDIHookPath = cfg.NVIDIAContainerRuntimeHookConfig.Path
195+
}
196+
197+
if opts.ldconfigPath == "" && string(cfg.NVIDIAContainerCLIConfig.Ldconfig) != "" {
198+
opts.ldconfigPath = string(cfg.NVIDIAContainerCLIConfig.Ldconfig)
199+
}
200+
201+
if opts.mode == "" && cfg.NVIDIAContainerRuntimeConfig.Mode != "" {
202+
opts.mode = cfg.NVIDIAContainerRuntimeConfig.Mode
203+
}
204+
205+
if opts.csv.files.Value() == nil && len(cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath) > 0 {
206+
opts.csv.files = *cli.NewStringSlice(cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
207+
}
208+
209+
// Continue with existing validation
185210
opts.format = strings.ToLower(opts.format)
186211
switch opts.format {
187212
case spec.FormatJSON:

cmd/nvidia-ctk/cdi/generate/generate_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package generate
1818

1919
import (
2020
"bytes"
21+
"os"
2122
"path/filepath"
2223
"strings"
2324
"testing"
@@ -26,11 +27,35 @@ import (
2627
"github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100"
2728
testlog "github.com/sirupsen/logrus/hooks/test"
2829
"github.com/stretchr/testify/require"
30+
"github.com/urfave/cli/v2"
2931

3032
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
3133
)
3234

3335
func TestGenerateSpec(t *testing.T) {
36+
// Create a temporary directory for config
37+
tmpDir, err := os.MkdirTemp("", "nvidia-container-toolkit-test-*")
38+
require.NoError(t, err)
39+
defer os.RemoveAll(tmpDir)
40+
41+
// Create a temporary config file
42+
configContent := `
43+
[nvidia-container-runtime]
44+
mode = "nvml"
45+
[[nvidia-container-runtime.modes.cdi]]
46+
spec-dirs = ["/etc/cdi", "/usr/local/cdi"]
47+
[nvidia-container-runtime.modes.csv]
48+
mount-spec-path = "/etc/nvidia-container-runtime/host-files-for-container.d"
49+
`
50+
configPath := filepath.Join(tmpDir, "config.toml")
51+
err = os.WriteFile(configPath, []byte(configContent), 0600)
52+
require.NoError(t, err)
53+
54+
// Set XDG_CONFIG_HOME to point to our temporary directory
55+
oldXDGConfigHome := os.Getenv("XDG_CONFIG_HOME")
56+
os.Setenv("XDG_CONFIG_HOME", tmpDir)
57+
defer os.Setenv("XDG_CONFIG_HOME", oldXDGConfigHome)
58+
3459
t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true")
3560
moduleRoot, err := test.GetModuleRoot()
3661
require.NoError(t, err)
@@ -62,6 +87,13 @@ func TestGenerateSpec(t *testing.T) {
6287
class: "device",
6388
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
6489
driverRoot: driverRoot,
90+
csv: struct {
91+
files cli.StringSlice
92+
ignorePatterns cli.StringSlice
93+
}{
94+
files: *cli.NewStringSlice("/etc/nvidia-container-runtime/host-files-for-container.d"),
95+
ignorePatterns: cli.StringSlice{},
96+
},
6597
},
6698
expectedSpec: `---
6799
cdiVersion: 0.5.0
@@ -125,6 +157,10 @@ containerEdits:
125157

126158
err := c.validateFlags(nil, &tc.options)
127159
require.ErrorIs(t, err, tc.expectedValidateError)
160+
// Set the ldconfig path to empty.
161+
// This is required during test because config.GetConfig() returns
162+
// the default ldconfig path, even if it is not set in the config file.
163+
tc.options.ldconfigPath = ""
128164
require.EqualValues(t, tc.expectedOptions, tc.options)
129165

130166
// Set up a mock server, reusing the DGX A100 mock.

cmd/nvidia-ctk/cdi/list/list.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/urfave/cli/v2"
2424
"tags.cncf.io/container-device-interface/pkg/cdi"
2525

26+
ctkconfig "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
2627
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2728
)
2829

@@ -64,16 +65,30 @@ func (m command) build() *cli.Command {
6465
Usage: "specify the directories to scan for CDI specifications",
6566
Value: cli.NewStringSlice(cdi.DefaultSpecDirs...),
6667
Destination: &cfg.cdiSpecDirs,
68+
EnvVars: []string{"NVIDIA_CTK_CDI_SPEC_DIRS"},
6769
},
6870
}
6971

7072
return &c
7173
}
7274

7375
func (m command) validateFlags(c *cli.Context, cfg *config) error {
76+
// Load config file as base configuration
77+
config, err := ctkconfig.GetConfig()
78+
if err != nil {
79+
return fmt.Errorf("failed to load config: %v", err)
80+
}
81+
82+
// Apply config file values if command line or environment variables are not set.
83+
// order (1) command line, (2) environment variable, (3) config file
84+
if !c.IsSet("spec-dir") && len(config.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs) > 0 {
85+
cfg.cdiSpecDirs = *cli.NewStringSlice(config.NVIDIAContainerRuntimeConfig.Modes.CDI.SpecDirs...)
86+
}
87+
7488
if len(cfg.cdiSpecDirs.Value()) == 0 {
7589
return errors.New("at least one CDI specification directory must be specified")
7690
}
91+
7792
return nil
7893
}
7994

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package list
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
"github.com/urfave/cli/v2"
10+
11+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
12+
)
13+
14+
func TestValidateFlags(t *testing.T) {
15+
// Create a temporary directory for config
16+
tmpDir, err := os.MkdirTemp("", "nvidia-container-toolkit-test-*")
17+
require.NoError(t, err)
18+
defer os.RemoveAll(tmpDir)
19+
20+
// Create a temporary config file
21+
configContent := `
22+
[nvidia-container-runtime]
23+
mode = "cdi"
24+
[[nvidia-container-runtime.modes.cdi]]
25+
spec-dirs = ["/etc/cdi", "/usr/local/cdi"]
26+
`
27+
configPath := filepath.Join(tmpDir, "config.toml")
28+
err = os.WriteFile(configPath, []byte(configContent), 0600)
29+
require.NoError(t, err)
30+
31+
// Set XDG_CONFIG_HOME to point to our temporary directory
32+
oldXDGConfigHome := os.Getenv("XDG_CONFIG_HOME")
33+
os.Setenv("XDG_CONFIG_HOME", tmpDir)
34+
defer os.Setenv("XDG_CONFIG_HOME", oldXDGConfigHome)
35+
36+
tests := []struct {
37+
name string
38+
cliArgs []string
39+
envVars map[string]string
40+
expectedDirs []string
41+
expectError bool
42+
errorContains string
43+
}{
44+
{
45+
name: "command line takes precedence",
46+
cliArgs: []string{"--spec-dir=/custom/dir1", "--spec-dir=/custom/dir2"},
47+
expectedDirs: []string{"/custom/dir1", "/custom/dir2"},
48+
},
49+
{
50+
name: "environment variable takes precedence over config",
51+
envVars: map[string]string{"NVIDIA_CTK_CDI_SPEC_DIRS": "/env/dir1:/env/dir2"},
52+
expectedDirs: []string{"/env/dir1", "/env/dir2"},
53+
},
54+
{
55+
name: "config file used as fallback",
56+
expectedDirs: []string{"/etc/cdi", "/usr/local/cdi"},
57+
},
58+
}
59+
60+
for _, tt := range tests {
61+
t.Run(tt.name, func(t *testing.T) {
62+
// Set up environment variables
63+
for k, v := range tt.envVars {
64+
old := os.Getenv(k)
65+
os.Setenv(k, v)
66+
defer os.Setenv(k, old)
67+
}
68+
69+
// Create command
70+
cmd := NewCommand(logger.NewMockLogger())
71+
72+
// Create a new context with the command
73+
app := &cli.App{
74+
Commands: []*cli.Command{
75+
{
76+
Name: "cdi",
77+
Subcommands: []*cli.Command{cmd},
78+
},
79+
},
80+
}
81+
82+
// Run command
83+
args := append([]string{"nvidia-ctk", "cdi", "list"}, tt.cliArgs...)
84+
err := app.Run(args)
85+
86+
if tt.expectError {
87+
require.Error(t, err)
88+
require.Contains(t, err.Error(), tt.errorContains)
89+
return
90+
}
91+
92+
require.NoError(t, err)
93+
})
94+
}
95+
}

0 commit comments

Comments
 (0)