Skip to content

Commit be2fdbc

Browse files
different approach to fixing isolation bug (#1556)
Co-authored-by: Luke Lombardi <luke@beam.cloud>
1 parent 3549393 commit be2fdbc

6 files changed

Lines changed: 178 additions & 25 deletions

File tree

pkg/scheduler/pool.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ const (
2929
imagesVolumeName string = "beta9-images"
3030
storageVolumeName string = "beta9-storage"
3131
checkpointVolumeName string = "beta9-checkpoints"
32+
devicePluginVolumeName string = "kubelet-device-plugins"
33+
defaultDevicePluginPath string = "/var/lib/kubelet/device-plugins"
3234
defaultContainerName string = "worker"
3335
defaultWorkerEntrypoint string = "/usr/local/bin/worker"
3436
defaultWorkerLogPath string = "/var/log/worker"

pkg/scheduler/pool_external.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,14 @@ func (wpc *ExternalWorkerPoolController) getWorkerEnvironment(workerId, machineI
477477
Name: "GPU_COUNT",
478478
Value: strconv.FormatInt(int64(gpuCount), 10),
479479
},
480+
{
481+
Name: "POD_UID",
482+
ValueFrom: &corev1.EnvVarSource{
483+
FieldRef: &corev1.ObjectFieldSelector{
484+
FieldPath: "metadata.uid",
485+
},
486+
},
487+
},
480488
{
481489
Name: "POD_NAMESPACE",
482490
Value: wpc.config.Worker.Namespace,
@@ -607,6 +615,17 @@ func (wpc *ExternalWorkerPoolController) getWorkerVolumes(workerMemory int64) []
607615
})
608616
}
609617

618+
hostPathDir := corev1.HostPathDirectory
619+
volumes = append(volumes, corev1.Volume{
620+
Name: devicePluginVolumeName,
621+
VolumeSource: corev1.VolumeSource{
622+
HostPath: &corev1.HostPathVolumeSource{
623+
Path: defaultDevicePluginPath,
624+
Type: &hostPathDir,
625+
},
626+
},
627+
})
628+
610629
return volumes
611630
}
612631

@@ -633,6 +652,12 @@ func (wpc *ExternalWorkerPoolController) getWorkerVolumeMounts() []corev1.Volume
633652
},
634653
}
635654

655+
volumeMounts = append(volumeMounts, corev1.VolumeMount{
656+
Name: devicePluginVolumeName,
657+
MountPath: defaultDevicePluginPath,
658+
ReadOnly: true,
659+
})
660+
636661
if wpc.workerPoolConfig.CRIUEnabled && wpc.config.Worker.CRIU.Storage.Mode == string(types.CheckpointStorageModeLocal) {
637662
volumeMounts = append(volumeMounts, corev1.VolumeMount{
638663
Name: checkpointVolumeName,

pkg/scheduler/pool_local.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,17 @@ func (wpc *LocalKubernetesWorkerPoolController) getWorkerVolumes(workerMemory in
376376
})
377377
}
378378

379+
hostPathDir := corev1.HostPathDirectory
380+
volumes = append(volumes, corev1.Volume{
381+
Name: devicePluginVolumeName,
382+
VolumeSource: corev1.VolumeSource{
383+
HostPath: &corev1.HostPathVolumeSource{
384+
Path: defaultDevicePluginPath,
385+
Type: &hostPathDir,
386+
},
387+
},
388+
})
389+
379390
return append(volumes,
380391
corev1.Volume{
381392
Name: imagesVolumeName,
@@ -407,6 +418,12 @@ func (wpc *LocalKubernetesWorkerPoolController) getWorkerVolumeMounts() []corev1
407418
},
408419
}
409420

421+
volumeMounts = append(volumeMounts, corev1.VolumeMount{
422+
Name: devicePluginVolumeName,
423+
MountPath: defaultDevicePluginPath,
424+
ReadOnly: true,
425+
})
426+
410427
if len(wpc.workerPoolConfig.JobSpec.VolumeMounts) > 0 {
411428
volumeMounts = append(volumeMounts, wpc.workerPoolConfig.JobSpec.VolumeMounts...)
412429
}
@@ -461,6 +478,14 @@ func (wpc *LocalKubernetesWorkerPoolController) getWorkerEnvironment(workerId st
461478
Name: "GPU_COUNT",
462479
Value: strconv.FormatInt(int64(gpuCount), 10),
463480
},
481+
{
482+
Name: "POD_UID",
483+
ValueFrom: &corev1.EnvVarSource{
484+
FieldRef: &corev1.ObjectFieldSelector{
485+
FieldPath: "metadata.uid",
486+
},
487+
},
488+
},
464489
{
465490
Name: "POD_IP",
466491
ValueFrom: &corev1.EnvVarSource{

pkg/worker/gpu_info.go

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package worker
22

33
import (
44
"bufio"
5+
"encoding/json"
56
"errors"
67
"fmt"
78
"os"
@@ -24,22 +25,54 @@ type NvidiaInfoClient struct {
2425
visibleDevices string
2526
}
2627

27-
// resolveVisibleDevices gets the runtime-injected NVIDIA_VISIBLE_DEVICES by spawning
28-
// a child process. The nvidia container runtime hook injects the correct per-worker
29-
// GPU UUID into new processes, but PID 1 retains the base image default ("void").
30-
// A child sh process receives the hook-injected value.
28+
const defaultDeviceCheckpointPath = "/var/lib/kubelet/device-plugins/kubelet_internal_checkpoint"
29+
30+
type kubeletCheckpoint struct {
31+
Data struct {
32+
PodDeviceEntries []podDeviceEntry `json:"PodDeviceEntries"`
33+
} `json:"Data"`
34+
}
35+
36+
type podDeviceEntry struct {
37+
PodUID string `json:"PodUID"`
38+
ResourceName string `json:"ResourceName"`
39+
DeviceIDs map[string][]string `json:"DeviceIDs"`
40+
}
41+
42+
// resolveVisibleDevices determines which GPU is assigned to this worker pod.
43+
//
44+
// The nvidia/cuda base image sets ENV NVIDIA_VISIBLE_DEVICES=void which the
45+
// container runtime processes AFTER PID 1 starts, so os.Getenv always returns
46+
// "void". The authoritative GPU assignment lives in the kubelet device plugin
47+
// checkpoint file, which maps pod UIDs to allocated GPU UUIDs.
3148
var resolveVisibleDevices = func() string {
32-
out, err := exec.Command("sh", "-c", "printenv NVIDIA_VISIBLE_DEVICES").Output()
49+
podUID := os.Getenv("POD_UID")
50+
if podUID == "" {
51+
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
52+
}
53+
54+
data, err := os.ReadFile(defaultDeviceCheckpointPath)
3355
if err != nil {
3456
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
3557
}
3658

37-
resolved := strings.TrimSpace(string(out))
38-
if resolved == "" || resolved == "void" {
59+
var checkpoint kubeletCheckpoint
60+
if err := json.Unmarshal(data, &checkpoint); err != nil {
3961
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
4062
}
4163

42-
return resolved
64+
for _, entry := range checkpoint.Data.PodDeviceEntries {
65+
if entry.PodUID != podUID || entry.ResourceName != "nvidia.com/gpu" {
66+
continue
67+
}
68+
for _, uuids := range entry.DeviceIDs {
69+
if len(uuids) > 0 {
70+
return strings.Join(uuids, ",")
71+
}
72+
}
73+
}
74+
75+
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
4376
}
4477

4578
func (c *NvidiaInfoClient) hexToPaddedString(hexStr string) (string, error) {

pkg/worker/gpu_info_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package worker
22

33
import (
4+
"encoding/json"
45
"errors"
6+
"os"
7+
"path/filepath"
58
"testing"
69

710
"github.com/stretchr/testify/assert"
@@ -186,3 +189,78 @@ func TestAvailableGPUDevicesSingleGPUUUID(t *testing.T) {
186189
assert.NoError(t, err)
187190
assert.Equal(t, []int{7}, devices)
188191
}
192+
193+
func writeCheckpointFile(t *testing.T, dir string, entries []podDeviceEntry) string {
194+
t.Helper()
195+
checkpoint := kubeletCheckpoint{}
196+
checkpoint.Data.PodDeviceEntries = entries
197+
data, err := json.Marshal(checkpoint)
198+
assert.NoError(t, err)
199+
path := filepath.Join(dir, "kubelet_internal_checkpoint")
200+
assert.NoError(t, os.WriteFile(path, data, 0644))
201+
return path
202+
}
203+
204+
func TestResolveVisibleDevicesFromCheckpoint(t *testing.T) {
205+
origResolve := resolveVisibleDevices
206+
defer func() { resolveVisibleDevices = origResolve }()
207+
208+
tmpDir := t.TempDir()
209+
checkpointPath := writeCheckpointFile(t, tmpDir, []podDeviceEntry{
210+
{
211+
PodUID: "test-pod-uid-1",
212+
ResourceName: "nvidia.com/gpu",
213+
DeviceIDs: map[string][]string{"0": {"GPU-aaaa-bbbb-cccc"}},
214+
},
215+
{
216+
PodUID: "test-pod-uid-2",
217+
ResourceName: "nvidia.com/gpu",
218+
DeviceIDs: map[string][]string{"1": {"GPU-dddd-eeee-ffff"}},
219+
},
220+
})
221+
222+
resolveVisibleDevices = func() string {
223+
podUID := "test-pod-uid-1"
224+
data, err := os.ReadFile(checkpointPath)
225+
if err != nil {
226+
return "fallback"
227+
}
228+
var cp kubeletCheckpoint
229+
if err := json.Unmarshal(data, &cp); err != nil {
230+
return "fallback"
231+
}
232+
for _, entry := range cp.Data.PodDeviceEntries {
233+
if entry.PodUID != podUID || entry.ResourceName != "nvidia.com/gpu" {
234+
continue
235+
}
236+
for _, uuids := range entry.DeviceIDs {
237+
if len(uuids) > 0 {
238+
return uuids[0]
239+
}
240+
}
241+
}
242+
return "fallback"
243+
}
244+
245+
result := resolveVisibleDevices()
246+
assert.Equal(t, "GPU-aaaa-bbbb-cccc", result)
247+
}
248+
249+
func TestResolveVisibleDevicesFallsBackWithoutPodUID(t *testing.T) {
250+
origResolve := resolveVisibleDevices
251+
defer func() { resolveVisibleDevices = origResolve }()
252+
253+
os.Setenv("NVIDIA_VISIBLE_DEVICES", "all")
254+
defer os.Unsetenv("NVIDIA_VISIBLE_DEVICES")
255+
256+
resolveVisibleDevices = func() string {
257+
podUID := ""
258+
if podUID == "" {
259+
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
260+
}
261+
return "should-not-reach"
262+
}
263+
264+
result := resolveVisibleDevices()
265+
assert.Equal(t, "all", result)
266+
}

pkg/worker/gpu_integration_test.go

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func TestIntegrationGPUIsolation(t *testing.T) {
1515
t.Skip("set GPU_INTEGRATION=1 to run on a real GPU node")
1616
}
1717

18-
// Step 1: Read PID 1's env (what os.Getenv sees — the broken path)
18+
// Step 1: Confirm PID 1 has void (the bug condition)
1919
pid1Env := "(unknown)"
2020
data, err := os.ReadFile("/proc/1/environ")
2121
if err == nil {
@@ -28,18 +28,18 @@ func TestIntegrationGPUIsolation(t *testing.T) {
2828
}
2929
t.Logf("PID 1 NVIDIA_VISIBLE_DEVICES = %q", pid1Env)
3030

31-
// Step 2: Call the REAL resolveVisibleDevices() from gpu_info.go
31+
// Step 2: Call the REAL resolveVisibleDevices() — reads from kubelet checkpoint
3232
resolved := resolveVisibleDevices()
3333
t.Logf("resolveVisibleDevices() = %q", resolved)
3434

3535
if resolved == "void" || resolved == "" {
36-
t.Fatalf("resolveVisibleDevices() returned %q — void bug NOT fixed", resolved)
36+
t.Fatalf("resolveVisibleDevices() returned %q — checkpoint resolution failed", resolved)
3737
}
3838
if !strings.HasPrefix(resolved, "GPU-") {
3939
t.Fatalf("resolveVisibleDevices() returned %q — expected GPU UUID", resolved)
4040
}
4141

42-
// Step 3: Create the REAL NvidiaInfoClient with the resolved value (same as NewContainerNvidiaManager)
42+
// Step 3: Create the REAL NvidiaInfoClient with the resolved value
4343
client := &NvidiaInfoClient{visibleDevices: resolved}
4444

4545
// Step 4: Call the REAL AvailableGPUDevices()
@@ -58,17 +58,10 @@ func TestIntegrationGPUIsolation(t *testing.T) {
5858

5959
// Step 5: Verify the OLD path (void) would have failed
6060
oldClient := &NvidiaInfoClient{visibleDevices: pid1Env}
61-
oldDevices, err := oldClient.AvailableGPUDevices()
62-
if err != nil {
63-
t.Logf("Old path error (expected): %v", err)
64-
}
61+
oldDevices, _ := oldClient.AvailableGPUDevices()
6562
t.Logf("Old path (PID 1 env=%q) -> AvailableGPUDevices() = %v", pid1Env, oldDevices)
6663

67-
if pid1Env == "void" && len(oldDevices) > 0 {
68-
t.Error("Old code path with void should return empty, but got devices — test logic wrong")
69-
}
70-
71-
// Step 6: Exercise the REAL ContainerNvidiaManager.AssignGPUDevices (chooseDevices)
64+
// Step 6: Exercise the REAL ContainerNvidiaManager.AssignGPUDevices
7265
manager := &ContainerNvidiaManager{
7366
gpuAllocationMap: common.NewSafeMap[[]int](),
7467
gpuCount: 1,
@@ -83,14 +76,11 @@ func TestIntegrationGPUIsolation(t *testing.T) {
8376
}
8477
t.Logf("AssignGPUDevices(\"test-container-1\", 1) = %v", assigned)
8578

86-
if len(assigned) != 1 {
87-
t.Fatalf("Expected 1 assigned GPU, got %d", len(assigned))
88-
}
8979
if assigned[0] != devices[0] {
9080
t.Fatalf("Assigned GPU %d doesn't match available GPU %d", assigned[0], devices[0])
9181
}
9282

93-
// Step 7: Verify second allocation to same worker FAILS (only 1 GPU available)
83+
// Step 7: Verify second allocation FAILS (only 1 GPU per worker)
9484
_, err = manager.AssignGPUDevices("test-container-2", 1)
9585
if err == nil {
9686
t.Fatal("Second allocation should fail — only 1 GPU per worker")

0 commit comments

Comments (0)