diff --git a/internal/platform-support/dgpu/dgpu.go b/internal/platform-support/dgpu/dgpu.go index 6d3636f52..190154662 100644 --- a/internal/platform-support/dgpu/dgpu.go +++ b/internal/platform-support/dgpu/dgpu.go @@ -72,6 +72,11 @@ func NewForMigDevice(d device.Device, mig device.MigDevice, opts ...Option) (dis return nil, err } o.isMigDevice = true + migProfile, err := mig.GetProfile() + if err != nil { + return nil, fmt.Errorf("error getting MIG Profile attributes: %w", err) + } + o.migAttributes = migProfile.GetInfo().Attributes var discoverers []discover.Discover var errs error diff --git a/internal/platform-support/dgpu/nvml.go b/internal/platform-support/dgpu/nvml.go index bf2fec86f..57c3bcf13 100644 --- a/internal/platform-support/dgpu/nvml.go +++ b/internal/platform-support/dgpu/nvml.go @@ -18,6 +18,7 @@ package dgpu import ( "fmt" + "slices" "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvml/pkg/nvml" @@ -75,6 +76,7 @@ func (o *options) newNvmlDGPUDiscoverer(d requiredInfo) (discover.Discover, erro type requiredMigInfo interface { getPlacementInfo() (int, int, int, error) getDevNodePath() (string, error) + getPCIBusID() (string, error) } func (o *options) newNvmlMigDiscoverer(d requiredMigInfo) (discover.Discover, error) { @@ -82,6 +84,14 @@ func (o *options) newNvmlMigDiscoverer(d requiredMigInfo) (discover.Discover, er return nil, fmt.Errorf("error getting MIG capability device paths: %v", o.migCapsError) } + var charDevicePaths []string + + parentPath, err := d.getDevNodePath() + if err != nil { + return nil, err + } + charDevicePaths = append(charDevicePaths, parentPath) + gpu, gi, ci, err := d.getPlacementInfo() if err != nil { return nil, fmt.Errorf("error getting placement info: %w", err) @@ -92,16 +102,44 @@ func (o *options) newNvmlMigDiscoverer(d requiredMigInfo) (discover.Discover, er if err != nil { return nil, fmt.Errorf("failed to get GI cap device path: %v", err) } + charDevicePaths = append(charDevicePaths, giCapDevicePath) ciCap := nvcaps.NewComputeInstanceCap(gpu, gi, ci) ciCapDevicePath, err := o.migCaps.GetCapDevicePath(ciCap) if err != nil { return nil, fmt.Errorf("failed to get CI cap device path: %v", err) } - - parentPath, err := d.getDevNodePath() - if err != nil { - return nil, err + charDevicePaths = append(charDevicePaths, ciCapDevicePath) + + supportsDRI := slices.Contains(o.migAttributes, "gfx") + if supportsDRI { + pciBusID, err := d.getPCIBusID() + if err != nil { + return nil, fmt.Errorf("error getting PCI info for device: %w", err) + } + + drmDeviceNodes, err := drm.GetDeviceNodesByBusID(pciBusID) + if err != nil { + return nil, fmt.Errorf("failed to determine DRM devices for %q: %w", pciBusID, err) + } + + charDevicePaths = append(charDevicePaths, drmDeviceNodes...) + deviceNodes := discover.NewCharDeviceDiscoverer( + o.logger, + o.driver.DevRoot, + charDevicePaths, + ) + byPathHooks := &byPathHookDiscoverer{ + logger: o.logger, + devRoot: o.driver.DevRoot, + hookCreator: o.hookCreator, + pciBusID: pciBusID, + deviceNodes: deviceNodes, + } + return discover.Merge( + deviceNodes, + byPathHooks, + ), nil } deviceNodes := discover.NewCharDeviceDiscoverer( @@ -165,3 +203,7 @@ func (d *toRequiredMigInfo) getPlacementInfo() (int, int, int, error) { func (d *toRequiredMigInfo) getDevNodePath() (string, error) { return d.parent.getDevNodePath() } + +func (d *toRequiredMigInfo) getPCIBusID() (string, error) { + return d.parent.GetPCIBusID() +} diff --git a/internal/platform-support/dgpu/nvml_test.go b/internal/platform-support/dgpu/nvml_test.go index c03510f64..4a079c98f 100644 --- a/internal/platform-support/dgpu/nvml_test.go +++ b/internal/platform-support/dgpu/nvml_test.go @@ -97,6 +97,8 @@ func TestNewNvmlMIGDiscoverer(t *testing.T) { nvmllib, ) + parentMock, migMock := newNvmlMigDiscovererTestMocks() + testCases := []struct { description string mig *mock.Device @@ -109,32 +111,8 @@ func TestNewNvmlMIGDiscoverer(t *testing.T) { }{ { description: "", - mig: &mock.Device{ - IsMigDeviceHandleFunc: func() (bool, nvml.Return) { - return true, nvml.SUCCESS - }, - GetGpuInstanceIdFunc: func() (int, nvml.Return) { - return 1, nvml.SUCCESS - }, - GetComputeInstanceIdFunc: func() (int, nvml.Return) { - return 2, nvml.SUCCESS - }, - }, - parent: &mock.Device{ - GetMinorNumberFunc: func() (int, nvml.Return) { - return 3, nvml.SUCCESS - }, - GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) { - var busID [32]uint8 - for i, b := range []byte("00000000:45:00:00") { - busID[i] = b - } - info := nvml.PciInfo{ - BusId: busID, - } - return info, nvml.SUCCESS - }, - }, + mig: migMock, + parent: parentMock, migCaps: nvcaps.MigCaps{ "gpu3/gi1/access": 31, "gpu3/gi1/ci2/access": 312, @@ -172,3 +150,91 @@ func TestNewNvmlMIGDiscoverer(t *testing.T) { }) } } + +// newNvmlMigDiscovererTestMocks returns parent and MIG NVML device mocks wired so +// go-nvlib device.MigDevice.GetProfile succeeds (GPU instance id 1, compute instance id 2). +func newNvmlMigDiscovererTestMocks() (parent *mock.Device, mig *mock.Device) { + const ( + gpuInstanceID = 1 + computeInstanceID = 2 + ) + + mockCI := &mock.ComputeInstance{ + GetInfoFunc: func() (nvml.ComputeInstanceInfo, nvml.Return) { + return nvml.ComputeInstanceInfo{ + ProfileId: nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE, + }, nvml.SUCCESS + }, + } + + mockGI := &mock.GpuInstance{ + GetInfoFunc: func() (nvml.GpuInstanceInfo, nvml.Return) { + return nvml.GpuInstanceInfo{ + ProfileId: nvml.GPU_INSTANCE_PROFILE_1_SLICE, + }, nvml.SUCCESS + }, + GetComputeInstanceByIdFunc: func(n int) (nvml.ComputeInstance, nvml.Return) { + if n == computeInstanceID { + return mockCI, nvml.SUCCESS + } + return nil, nvml.ERROR_INVALID_ARGUMENT + }, + GetComputeInstanceProfileInfoFunc: func(n1, n2 int) (nvml.ComputeInstanceProfileInfo, nvml.Return) { + if n1 == 0 && n2 == 0 { + return nvml.ComputeInstanceProfileInfo{ + Id: nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE, + }, nvml.SUCCESS + } + return nvml.ComputeInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED + }, + } + + parent = &mock.Device{ + GetMinorNumberFunc: func() (int, nvml.Return) { + return 3, nvml.SUCCESS + }, + GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) { + var busID [32]uint8 + for i, b := range []byte("00000000:45:00:00") { + busID[i] = b + } + return nvml.PciInfo{BusId: busID}, nvml.SUCCESS + }, + GetMemoryInfoFunc: func() (nvml.Memory, nvml.Return) { + // Non-zero total memory is required for MIG profile memory math in go-nvlib. + return nvml.Memory{Total: 40 * 1024 * 1024 * 1024}, nvml.SUCCESS + }, + GetGpuInstanceByIdFunc: func(n int) (nvml.GpuInstance, nvml.Return) { + if n == gpuInstanceID { + return mockGI, nvml.SUCCESS + } + return nil, nvml.ERROR_INVALID_ARGUMENT + }, + GetGpuInstanceProfileInfoFunc: func(n int) (nvml.GpuInstanceProfileInfo, nvml.Return) { + if n == 0 { + return nvml.GpuInstanceProfileInfo{ + Id: nvml.GPU_INSTANCE_PROFILE_1_SLICE, + MemorySizeMB: 5120, + }, nvml.SUCCESS + } + return nvml.GpuInstanceProfileInfo{}, nvml.ERROR_NOT_SUPPORTED + }, + } + + mig = &mock.Device{ + IsMigDeviceHandleFunc: func() (bool, nvml.Return) { + return true, nvml.SUCCESS + }, + GetGpuInstanceIdFunc: func() (int, nvml.Return) { + return gpuInstanceID, nvml.SUCCESS + }, + GetComputeInstanceIdFunc: func() (int, nvml.Return) { + return computeInstanceID, nvml.SUCCESS + }, + GetAttributesFunc: func() (nvml.DeviceAttributes, nvml.Return) { + return nvml.DeviceAttributes{MemorySizeMB: 5120}, nvml.SUCCESS + }, + } + + return parent, mig +} diff --git a/internal/platform-support/dgpu/nvsandboxutils.go b/internal/platform-support/dgpu/nvsandboxutils.go index 4e8eb2c80..620cc1b5b 100644 --- a/internal/platform-support/dgpu/nvsandboxutils.go +++ b/internal/platform-support/dgpu/nvsandboxutils.go @@ -19,6 +19,7 @@ package dgpu import ( "fmt" "path/filepath" + "slices" "strings" "github.com/NVIDIA/go-nvml/pkg/nvml" @@ -31,7 +32,7 @@ type nvsandboxutilsDGPU struct { lib nvsandboxutils.Interface uuid string devRoot string - isMig bool + supportsDRI bool hookCreator discover.HookCreator deviceLinks []string } @@ -52,11 +53,13 @@ func (o *options) newNvsandboxutilsDGPUDiscoverer(d UUIDer) (discover.Discover, return nil, fmt.Errorf("failed to get device UUID: %w", nvmlRet) } + supportsDRI := !o.isMigDevice || slices.Contains(o.migAttributes, "gfx") + nvd := nvsandboxutilsDGPU{ lib: o.nvsandboxutilslib, uuid: uuid, devRoot: strings.TrimSuffix(filepath.Clean(o.driver.DevRoot), "/dev"), - isMig: o.isMigDevice, + supportsDRI: supportsDRI, hookCreator: o.hookCreator, } @@ -73,7 +76,7 @@ func (d *nvsandboxutilsDGPU) Devices() ([]discover.Device, error) { for _, info := range gpuFileInfos { switch info.SubType { case nvsandboxutils.NV_DEV_DRI_CARD, nvsandboxutils.NV_DEV_DRI_RENDERD: - if d.isMig { + if !d.supportsDRI { continue } fallthrough @@ -90,7 +93,7 @@ func (d *nvsandboxutilsDGPU) Devices() ([]discover.Device, error) { } devices = append(devices, device) case nvsandboxutils.NV_DEV_DRI_CARD_SYMLINK, nvsandboxutils.NV_DEV_DRI_RENDERD_SYMLINK: - if d.isMig { + if !d.supportsDRI { continue } if info.Flags == nvsandboxutils.NV_FILE_FLAG_CONTENT { diff --git a/internal/platform-support/dgpu/options.go b/internal/platform-support/dgpu/options.go index 7f8bf2403..6301e4ccb 100644 --- a/internal/platform-support/dgpu/options.go +++ b/internal/platform-support/dgpu/options.go @@ -29,7 +29,9 @@ type options struct { driver *root.Driver hookCreator discover.HookCreator - isMigDevice bool + isMigDevice bool + migAttributes []string + // migCaps stores the MIG capabilities for the system. // If MIG is not available, this is nil. migCaps nvcaps.MigCaps