Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/platform-support/dgpu/dgpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ func NewForMigDevice(d device.Device, mig device.MigDevice, opts ...Option) (dis
return nil, err
}
o.isMigDevice = true
migProfile, err := mig.GetProfile()
if err != nil {
return nil, fmt.Errorf("error getting MIG Profile attributes: %w", err)
}
o.migAttributes = migProfile.GetInfo().Attributes

var discoverers []discover.Discover
var errs error
Expand Down
50 changes: 46 additions & 4 deletions internal/platform-support/dgpu/nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package dgpu

import (
"fmt"
"slices"

"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/go-nvml/pkg/nvml"
Expand Down Expand Up @@ -75,13 +76,22 @@ func (o *options) newNvmlDGPUDiscoverer(d requiredInfo) (discover.Discover, erro
type requiredMigInfo interface {
getPlacementInfo() (int, int, int, error)
getDevNodePath() (string, error)
getPCIBusID() (string, error)
}

func (o *options) newNvmlMigDiscoverer(d requiredMigInfo) (discover.Discover, error) {
if o.migCaps == nil || o.migCapsError != nil {
return nil, fmt.Errorf("error getting MIG capability device paths: %v", o.migCapsError)
}

var charDevicePaths []string

parentPath, err := d.getDevNodePath()
if err != nil {
return nil, err
}
charDevicePaths = append(charDevicePaths, parentPath)

gpu, gi, ci, err := d.getPlacementInfo()
if err != nil {
return nil, fmt.Errorf("error getting placement info: %w", err)
Expand All @@ -92,16 +102,44 @@ func (o *options) newNvmlMigDiscoverer(d requiredMigInfo) (discover.Discover, er
if err != nil {
return nil, fmt.Errorf("failed to get GI cap device path: %v", err)
}
charDevicePaths = append(charDevicePaths, giCapDevicePath)

ciCap := nvcaps.NewComputeInstanceCap(gpu, gi, ci)
ciCapDevicePath, err := o.migCaps.GetCapDevicePath(ciCap)
if err != nil {
return nil, fmt.Errorf("failed to get CI cap device path: %v", err)
}

parentPath, err := d.getDevNodePath()
if err != nil {
return nil, err
charDevicePaths = append(charDevicePaths, ciCapDevicePath)

supportsDRI := slices.Contains(o.migAttributes, "gfx")
if supportsDRI {
pciBusID, err := d.getPCIBusID()
if err != nil {
return nil, fmt.Errorf("error getting PCI info for device: %w", err)
}

drmDeviceNodes, err := drm.GetDeviceNodesByBusID(pciBusID)
if err != nil {
return nil, fmt.Errorf("failed to determine DRM devices for %q: %w", pciBusID, err)
}

charDevicePaths = append(charDevicePaths, drmDeviceNodes...)
deviceNodes := discover.NewCharDeviceDiscoverer(
o.logger,
o.driver.DevRoot,
charDevicePaths,
)
byPathHooks := &byPathHookDiscoverer{
logger: o.logger,
devRoot: o.driver.DevRoot,
hookCreator: o.hookCreator,
pciBusID: pciBusID,
deviceNodes: deviceNodes,
}
return discover.Merge(
deviceNodes,
byPathHooks,
), nil
}

deviceNodes := discover.NewCharDeviceDiscoverer(
Expand Down Expand Up @@ -165,3 +203,7 @@ func (d *toRequiredMigInfo) getPlacementInfo() (int, int, int, error) {
func (d *toRequiredMigInfo) getDevNodePath() (string, error) {
return d.parent.getDevNodePath()
}

func (d *toRequiredMigInfo) getPCIBusID() (string, error) {
return d.parent.GetPCIBusID()
}
118 changes: 92 additions & 26 deletions internal/platform-support/dgpu/nvml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ func TestNewNvmlMIGDiscoverer(t *testing.T) {
nvmllib,
)

parentMock, migMock := newNvmlMigDiscovererTestMocks()

testCases := []struct {
description string
mig *mock.Device
Expand All @@ -109,32 +111,8 @@ func TestNewNvmlMIGDiscoverer(t *testing.T) {
}{
{
description: "",
mig: &mock.Device{
IsMigDeviceHandleFunc: func() (bool, nvml.Return) {
return true, nvml.SUCCESS
},
GetGpuInstanceIdFunc: func() (int, nvml.Return) {
return 1, nvml.SUCCESS
},
GetComputeInstanceIdFunc: func() (int, nvml.Return) {
return 2, nvml.SUCCESS
},
},
parent: &mock.Device{
GetMinorNumberFunc: func() (int, nvml.Return) {
return 3, nvml.SUCCESS
},
GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) {
var busID [32]uint8
for i, b := range []byte("00000000:45:00:00") {
busID[i] = b
}
info := nvml.PciInfo{
BusId: busID,
}
return info, nvml.SUCCESS
},
},
mig: migMock,
parent: parentMock,
migCaps: nvcaps.MigCaps{
"gpu3/gi1/access": 31,
"gpu3/gi1/ci2/access": 312,
Expand Down Expand Up @@ -172,3 +150,91 @@ func TestNewNvmlMIGDiscoverer(t *testing.T) {
})
}
}

// newNvmlMigDiscovererTestMocks returns parent and MIG NVML device mocks wired so
// go-nvlib device.MigDevice.GetProfile succeeds (GPU instance id 1, compute instance id 2).
func newNvmlMigDiscovererTestMocks() (parent *mock.Device, mig *mock.Device) {
const (
gpuInstanceID = 1
computeInstanceID = 2
)

mockCI := &mock.ComputeInstance{
GetInfoFunc: func() (nvml.ComputeInstanceInfo, nvml.Return) {
return nvml.ComputeInstanceInfo{
ProfileId: nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE,
}, nvml.SUCCESS
},
}

mockGI := &mock.GpuInstance{
GetInfoFunc: func() (nvml.GpuInstanceInfo, nvml.Return) {
return nvml.GpuInstanceInfo{
ProfileId: nvml.GPU_INSTANCE_PROFILE_1_SLICE,
}, nvml.SUCCESS
},
GetComputeInstanceByIdFunc: func(n int) (nvml.ComputeInstance, nvml.Return) {
if n == computeInstanceID {
return mockCI, nvml.SUCCESS
}
return nil, nvml.ERROR_INVALID_ARGUMENT
},
GetComputeInstanceProfileInfoFunc: func(n1, n2 int) (nvml.ComputeInstanceProfileInfo, nvml.Return) {
if n1 == 0 && n2 == 0 {
return nvml.ComputeInstanceProfileInfo{
Id: nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE,
}, nvml.SUCCESS
}
return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED
},
}

parent = &mock.Device{
GetMinorNumberFunc: func() (int, nvml.Return) {
return 3, nvml.SUCCESS
},
GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) {
var busID [32]uint8
for i, b := range []byte("00000000:45:00:00") {
busID[i] = b
}
return nvml.PciInfo{BusId: busID}, nvml.SUCCESS
},
GetMemoryInfoFunc: func() (nvml.Memory, nvml.Return) {
// Non-zero total memory is required for MIG profile memory math in go-nvlib.
return nvml.Memory{Total: 40 * 1024 * 1024 * 1024}, nvml.SUCCESS
},
GetGpuInstanceByIdFunc: func(n int) (nvml.GpuInstance, nvml.Return) {
if n == gpuInstanceID {
return mockGI, nvml.SUCCESS
}
return nil, nvml.ERROR_INVALID_ARGUMENT
},
GetGpuInstanceProfileInfoFunc: func(n int) (nvml.GpuInstanceProfileInfo, nvml.Return) {
if n == 0 {
return nvml.GpuInstanceProfileInfo{
Id: nvml.GPU_INSTANCE_PROFILE_1_SLICE,
MemorySizeMB: 5120,
}, nvml.SUCCESS
}
return nvml.GpuInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED
},
}

mig = &mock.Device{
IsMigDeviceHandleFunc: func() (bool, nvml.Return) {
return true, nvml.SUCCESS
},
GetGpuInstanceIdFunc: func() (int, nvml.Return) {
return gpuInstanceID, nvml.SUCCESS
},
GetComputeInstanceIdFunc: func() (int, nvml.Return) {
return computeInstanceID, nvml.SUCCESS
},
GetAttributesFunc: func() (nvml.DeviceAttributes, nvml.Return) {
return nvml.DeviceAttributes{MemorySizeMB: 5120}, nvml.SUCCESS
},
}

return parent, mig
}
11 changes: 7 additions & 4 deletions internal/platform-support/dgpu/nvsandboxutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package dgpu
import (
"fmt"
"path/filepath"
"slices"
"strings"

"github.com/NVIDIA/go-nvml/pkg/nvml"
Expand All @@ -31,7 +32,7 @@ type nvsandboxutilsDGPU struct {
lib nvsandboxutils.Interface
uuid string
devRoot string
isMig bool
supportsDRI bool
hookCreator discover.HookCreator
deviceLinks []string
}
Expand All @@ -52,11 +53,13 @@ func (o *options) newNvsandboxutilsDGPUDiscoverer(d UUIDer) (discover.Discover,
return nil, fmt.Errorf("failed to get device UUID: %w", nvmlRet)
}

supportsDRI := !o.isMigDevice || slices.Contains(o.migAttributes, "gfx")

nvd := nvsandboxutilsDGPU{
lib: o.nvsandboxutilslib,
uuid: uuid,
devRoot: strings.TrimSuffix(filepath.Clean(o.driver.DevRoot), "/dev"),
isMig: o.isMigDevice,
supportsDRI: supportsDRI,
hookCreator: o.hookCreator,
}

Expand All @@ -73,7 +76,7 @@ func (d *nvsandboxutilsDGPU) Devices() ([]discover.Device, error) {
for _, info := range gpuFileInfos {
switch info.SubType {
case nvsandboxutils.NV_DEV_DRI_CARD, nvsandboxutils.NV_DEV_DRI_RENDERD:
if d.isMig {
if !d.supportsDRI {
continue
}
fallthrough
Expand All @@ -90,7 +93,7 @@ func (d *nvsandboxutilsDGPU) Devices() ([]discover.Device, error) {
}
devices = append(devices, device)
case nvsandboxutils.NV_DEV_DRI_CARD_SYMLINK, nvsandboxutils.NV_DEV_DRI_RENDERD_SYMLINK:
if d.isMig {
if !d.supportsDRI {
continue
}
if info.Flags == nvsandboxutils.NV_FILE_FLAG_CONTENT {
Expand Down
4 changes: 3 additions & 1 deletion internal/platform-support/dgpu/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ type options struct {
driver *root.Driver
hookCreator discover.HookCreator

isMigDevice bool
isMigDevice bool
migAttributes []string

// migCaps stores the MIG capabilities for the system.
// If MIG is not available, this is nil.
migCaps nvcaps.MigCaps
Expand Down
Loading