Skip to content

Add functions to construct devices by identifier #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/nvlib/device/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ type Interface interface {
GetMigDevices() ([]MigDevice, error)
GetMigProfiles() ([]MigProfile, error)
NewDevice(d nvml.Device) (Device, error)
NewDeviceByIdentifier(Identifier) (Device, error)
NewDeviceByUUID(uuid string) (Device, error)
NewMigDevice(d nvml.Device) (MigDevice, error)
NewMigDeviceByIdentifier(Identifier) (MigDevice, error)
NewMigDeviceByUUID(uuid string) (MigDevice, error)
NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error)
ParseMigProfile(profile string) (MigProfile, error)
Expand Down
26 changes: 26 additions & 0 deletions pkg/nvlib/device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package device

import (
"fmt"
"strconv"

"github.com/NVIDIA/go-nvml/pkg/nvml"
)
Expand Down Expand Up @@ -49,6 +50,31 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
return d.newDevice(dev)
}

// NewDeviceByIdentifier constructs a Device from the supplied identifier.
// An identifier that reports itself as a GPU UUID is resolved through
// NewDeviceByUUID; one that reports itself as a GPU index is parsed as an
// integer and resolved through NewDeviceByIndex. Any other identifier
// results in an error.
func (d *devicelib) NewDeviceByIdentifier(id Identifier) (Device, error) {
	if id.IsGpuUUID() {
		return d.NewDeviceByUUID(string(id))
	}
	if !id.IsGpuIndex() {
		return nil, fmt.Errorf("invalid device identifier: %v", id)
	}

	index, convErr := strconv.Atoi(string(id))
	if convErr != nil {
		return nil, fmt.Errorf("failed to convert device index to an int: %w", convErr)
	}
	return d.NewDeviceByIndex(index)
}

// NewDeviceByIndex looks up the NVML device handle at the given index and
// wraps it in a Device. An error is returned when NVML does not report
// success for the lookup.
func (d *devicelib) NewDeviceByIndex(index int) (Device, error) {
	handle, status := d.nvmllib.DeviceGetHandleByIndex(index)
	if status == nvml.SUCCESS {
		return d.newDevice(handle)
	}
	return nil, fmt.Errorf("error getting device handle for index '%v': %v", index, status)
}

// NewDeviceByUUID builds a new Device from a UUID.
func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) {
dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid)
Expand Down
38 changes: 38 additions & 0 deletions pkg/nvlib/device/mig_device.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package device

import (
"fmt"
"strconv"
"strings"

"github.com/NVIDIA/go-nvml/pkg/nvml"
)
Expand Down Expand Up @@ -45,6 +47,42 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
if !isMig {
return nil, fmt.Errorf("not a MIG device")
}
return d.newMigDevice(handle)
}

// NewMigDeviceByIdentifier builds a new MigDevice for the specified identifier.
//
// An identifier that reports itself as a MIG UUID is resolved through
// NewMigDeviceByUUID. An identifier that reports itself as a MIG index is
// split on the first ":" into a parent GPU index and a MIG index; the parent
// device is constructed by index and its MIG device handle at the MIG index
// is wrapped in a MigDevice.
//
// If the identifier is not a valid MIG identifier, or any lookup fails, an
// error is returned.
func (d *devicelib) NewMigDeviceByIdentifier(id Identifier) (MigDevice, error) {
	switch {
	case id.IsMigUUID():
		return d.NewMigDeviceByUUID(string(id))
	case id.IsMigIndex():
		split := strings.SplitN(string(id), ":", 2)
		// Do not rely solely on IsMigIndex to guarantee two components;
		// indexing split[1] on a malformed identifier would panic.
		if len(split) != 2 {
			return nil, fmt.Errorf("invalid MIG device identifier: %v", id)
		}
		gpuIdx, err := strconv.Atoi(split[0])
		if err != nil {
			return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
		}
		migIdx, err := strconv.Atoi(split[1])
		if err != nil {
			// Distinct from the GPU index message so the failing half of
			// the identifier can be identified from the error alone.
			return nil, fmt.Errorf("failed to convert MIG device index to an int: %w", err)
		}
		parent, err := d.NewDeviceByIndex(gpuIdx)
		if err != nil {
			return nil, fmt.Errorf("failed to get parent device handle: %w", err)
		}
		migDevice, ret := parent.GetMigDeviceHandleByIndex(migIdx)
		if ret != nvml.SUCCESS {
			return nil, fmt.Errorf("failed to get mig device by index: %w", ret)
		}
		return d.newMigDevice(migDevice)
	default:
		return nil, fmt.Errorf("invalid MIG device identifier: %v", id)
	}
}

// newMigDevice wraps the supplied handle, together with this devicelib, in a
// MigDevice. No validation of the handle is performed here.
func (d *devicelib) newMigDevice(handle nvml.Device) (MigDevice, error) {
	mig := &migdevice{handle, d, nil}
	return mig, nil
}

Expand Down