diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 005c42f48..260005925 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -51,6 +51,14 @@ func (f *Factory) newCDIModifier(isJitCDI bool) (oci.SpecModifier, error) { defaultKind, ) devices := deviceRequestor.DeviceRequests() + + // Run before the empty-device return so NVIDIA_REQUIRE_* is still enforced when + // len(devices)==0 (e.g. CRI CDI injection without matching spec signals). When + // there are no requirements, checkRequirements returns immediately. + if err := checkRequirements(f.logger, f.image, f.driver); err != nil { + return nil, fmt.Errorf("requirements not met: %w", err) + } + if len(devices) == 0 { f.logger.Debugf("No devices requested; no modification required.") return nil, nil diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index b20fdb134..2b93d6041 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -20,10 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" - "github.com/NVIDIA/nvidia-container-toolkit/internal/cuda" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" - "github.com/NVIDIA/nvidia-container-toolkit/internal/requirements" ) // newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper. @@ -36,45 +33,13 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) { } f.logger.Infof("Constructing modifier from config: %+v", *f.cfg) - if err := checkRequirements(f.logger, f.image); err != nil { + if err := checkRequirements(f.logger, f.image, f.driver); err != nil { return nil, fmt.Errorf("requirements not met: %v", err) } return f.newAutomaticCDISpecModifier(devices) } -func checkRequirements(logger logger.Interface, image *image.CUDA) error { - if image == nil || image.HasDisableRequire() { - // TODO: We could print the real value here instead - logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true) - return nil - } - - imageRequirements, err := image.GetRequirements() - if err != nil { - // TODO: Should we treat this as a failure, or just issue a warning? - return fmt.Errorf("failed to get image requirements: %v", err) - } - - r := requirements.New(logger, imageRequirements) - - cudaVersion, err := cuda.Version() - if err != nil { - logger.Warningf("Failed to get CUDA version: %v", err) - } else { - r.AddVersionProperty(requirements.CUDA, cudaVersion) - } - - compteCapability, err := cuda.ComputeCapability(0) - if err != nil { - logger.Warningf("Failed to get CUDA Compute Capability: %v", err) - } else { - r.AddVersionProperty(requirements.ARCH, compteCapability) - } - - return r.Assert() -} - type csvDevices image.CUDA func (d csvDevices) DeviceRequests() []string { diff --git a/internal/modifier/image_requirements.go b/internal/modifier/image_requirements.go new file mode 100644 index 000000000..b36ff4480 --- /dev/null +++ b/internal/modifier/image_requirements.go @@ -0,0 +1,200 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package modifier + +import ( + "fmt" + "strconv" + "strings" + + "github.com/NVIDIA/go-nvml/pkg/nvml" + "golang.org/x/mod/semver" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" + "github.com/NVIDIA/nvidia-container-toolkit/internal/cuda" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" + "github.com/NVIDIA/nvidia-container-toolkit/internal/requirements" +) + +// checkRequirements evaluates NVIDIA_REQUIRE_* constraints using the host +// CUDA driver API version from libcuda, the NVIDIA display driver version from +// the driver root (libcuda / libnvidia-ml soname), the compute capability of +// CUDA device 0, and (when requirements reference brand) the GPU product brand +// from NVML. It is used for CSV and CDI / JIT-CDI modes. +func checkRequirements(logger logger.Interface, image *image.CUDA, driver *root.Driver) error { + if image == nil || image.HasDisableRequire() { + logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true) + return nil + } + + imageRequirements, err := image.GetRequirements() + if err != nil { + return fmt.Errorf("failed to get image requirements: %v", err) + } + if len(imageRequirements) == 0 { + return nil + } + + r := requirements.New(logger, imageRequirements) + + cudaVersion, err := cuda.Version() + if err != nil { + logger.Warningf("Failed to get CUDA version: %v", err) + } else { + r.AddVersionProperty(requirements.CUDA, cudaVersion) + } + + compteCapability, err := cuda.ComputeCapability(0) + if err != nil { + logger.Warningf("Failed to get CUDA Compute Capability: %v", err) + } else { + r.AddVersionProperty(requirements.ARCH, compteCapability) + } + + driverVersion, err := driver.Version() + if err != nil { + logger.Warningf("Failed to get NVIDIA driver version: %v", err) + } else { + normalized, normErr := normalizeDriverVersionForSemver(driverVersion) + if normErr != nil { + logger.Warningf("NVIDIA driver version %q is not semver-normalizable: %v", driverVersion, normErr) + } else { + r.AddVersionProperty(requirements.DRIVER, normalized) + } + } + + brand, err := getBrandFromNVML(driver) + if err != nil { + logger.Warningf("Failed to get GPU brand from NVML: %v", err) + } else { + r.AddStringProperty(requirements.BRAND, brand) + } + + return r.Assert() +} + +// normalizeDriverVersionForSemver converts a driver version taken from a +// libcuda / libnvidia-ml soname suffix into a form accepted by +// golang.org/x/mod/semver (no leading zeros in numeric segments) +func normalizeDriverVersionForSemver(raw string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", fmt.Errorf("empty driver version") + } + parts := strings.Split(raw, ".") + out := make([]string, 0, len(parts)) + for _, p := range parts { + if p == "" { + return "", fmt.Errorf("empty version segment in %q", raw) + } + if strings.TrimLeft(p, "0123456789") != "" { + return "", fmt.Errorf("non-numeric version segment %q in %q", p, raw) + } + n, err := strconv.ParseUint(p, 10, 64) + if err != nil { + return "", fmt.Errorf("invalid version segment %q in %q: %w", p, raw, err) + } + out = append(out, strconv.FormatUint(n, 10)) + } + normalized := strings.Join(out, ".") + if !semver.IsValid("v" + normalized) { + return "", fmt.Errorf("normalized driver version %q is not valid semver", normalized) + } + return normalized, nil +} + +// getBrandFromNVML returns a lowercase brand token for the first visible GPU +// (index 0), using NVML. When driver is non-nil, NVML is loaded from the +// versioned libnvidia-ml under the driver root when possible. +func getBrandFromNVML(driver *root.Driver) (string, error) { + var lib nvml.Interface + var opts []nvml.LibraryOption + v, err := driver.Version() + if err == nil && v != "" && v != "*.*" { + paths, err := driver.Libraries().Locate("libnvidia-ml.so." + v) + if err == nil && len(paths) > 0 { + opts = append(opts, nvml.WithLibraryPath(paths[0])) + } + } + + lib = nvml.New(opts...) + if ret := lib.Init(); ret != nvml.SUCCESS { + return "", fmt.Errorf("nvml.Init: %s", lib.ErrorString(ret)) + } + defer func() { + _ = lib.Shutdown() + }() + + device, ret := lib.DeviceGetHandleByIndex(0) + if ret != nvml.SUCCESS { + return "", fmt.Errorf("nvml.DeviceGetHandleByIndex(0): %s", lib.ErrorString(ret)) + } + + brandType, ret := lib.DeviceGetBrand(device) + if ret != nvml.SUCCESS { + return "", fmt.Errorf("nvml.DeviceGetBrand: %s", lib.ErrorString(ret)) + } + brand, ok := brandTypeToRequirementString(brandType) + if !ok { + return "", fmt.Errorf("unknown NVML brand type %v", brandType) + } + return brand, nil +} + +// brandTypeToRequirementString maps NVML brand enums to lowercase tokens +// consistent with typical NVIDIA_REQUIRE_* image constraints. +func brandTypeToRequirementString(b nvml.BrandType) (string, bool) { + switch b { + case nvml.BRAND_UNKNOWN: + return "", false + case nvml.BRAND_QUADRO: + return "quadro", true + case nvml.BRAND_TESLA: + return "tesla", true + case nvml.BRAND_NVS: + return "nvs", true + case nvml.BRAND_GRID: + return "grid", true + case nvml.BRAND_GEFORCE: + return "geforce", true + case nvml.BRAND_TITAN: + return "titan", true + case nvml.BRAND_NVIDIA_VAPPS: + return "nvidiavapps", true + case nvml.BRAND_NVIDIA_VPC: + return "nvidiavpc", true + case nvml.BRAND_NVIDIA_VCS: + return "nvidiavcs", true + case nvml.BRAND_NVIDIA_VWS: + return "nvidiavws", true + case nvml.BRAND_NVIDIA_CLOUD_GAMING: + return "nvidiacloudgaming", true + case nvml.BRAND_QUADRO_RTX: + return "quadrortx", true + case nvml.BRAND_NVIDIA_RTX: + return "nvidiartx", true + case nvml.BRAND_NVIDIA: + return "nvidia", true + case nvml.BRAND_GEFORCE_RTX: + return "geforcertx", true + case nvml.BRAND_TITAN_RTX: + return "titanrtx", true + default: + return "", false + } +}