Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions pkg/gpu/nvidia/health_check/health_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@ import (
"github.com/NVIDIA/gpu-monitoring-tools/bindings/go/nvml"
"github.com/golang/glog"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes/scheme"
clientv1 "k8s.io/client-go/kubernetes/typed/core/v1"
listersv1 "k8s.io/client-go/listers/core/v1"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
client "k8s.io/client-go/kubernetes"
Expand All @@ -44,6 +40,8 @@ import (
const (
// XIDConditionType is the node condition type used for the XID condition;
// node status conditions are matched against it by their Type field.
XIDConditionType = "XidCriticalError"
// eventSource is presumably the component name attached to events emitted by
// this plugin — NOTE(review): its use is not visible here, confirm at the recorder setup.
eventSource = "nvidia-gpu-device-plugin"

// resetXIDConditionTimeout bounds how long resetXIDConditionWithBackoff keeps
// retrying before it gives up.
resetXIDConditionTimeout = 2 * time.Minute
)

// GPUHealthChecker checks the health of nvidia GPUs. Note that with the current
Expand All @@ -60,9 +58,7 @@ type GPUHealthChecker struct {
monitorCriticalXid map[uint64]bool
kubeClient client.Interface
nodeName string
nodeLister listersv1.NodeLister
recorder record.EventRecorder
informerFactory informers.SharedInformerFactory
}

// NewGPUHealthChecker returns a GPUHealthChecker object for a given device name
Expand Down Expand Up @@ -102,11 +98,36 @@ func NewGPUHealthChecker(devices map[string]pluginapi.Device, health chan plugin
return hc
}

// resetXIDConditionWithBackoff calls resetXIDCondition until it succeeds or
// resetXIDConditionTimeout elapses.
//
// The first attempt is made immediately. After each failure it waits before
// retrying, starting at 1s and doubling up to a cap of 30s. The wait between
// attempts is interruptible by the overall deadline, so the function returns
// promptly once the timeout expires instead of sleeping through it. An attempt
// that is already in flight is not cancelled.
//
// Failures and the final timeout are logged; nothing is returned because the
// function is intended to run in its own goroutine.
func (hc *GPUHealthChecker) resetXIDConditionWithBackoff() {
	const (
		initialBackoff = 1 * time.Second
		maxBackoff     = 30 * time.Second
	)

	// A single timer covers the whole retry sequence.
	deadline := time.NewTimer(resetXIDConditionTimeout)
	defer deadline.Stop()

	backoff := initialBackoff
	for {
		err := hc.resetXIDCondition()
		if err == nil {
			return
		}
		glog.Errorf("Failed to reset XID condition, will retry in %v. Error: %v", backoff, err)

		// Wait for the backoff interval, but give up as soon as the overall
		// deadline fires rather than sleeping past it.
		wait := time.NewTimer(backoff)
		select {
		case <-deadline.C:
			wait.Stop()
			glog.Errorf("Timeout resetting XID condition after %v.", resetXIDConditionTimeout)
			return
		case <-wait.C:
		}

		backoff *= 2
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
	}
}

// Check whether the XID condition should be removed. If the condition exists:
// 1. If the bootId changes, consider the node fixed through auto-repair
// 2. If the bootId stays unchanged, consider a pure gpu-device-plugin restart
func (hc *GPUHealthChecker) resetXIDCondition() error {
node, err := hc.nodeLister.Get(hc.nodeName)
node, err := hc.kubeClient.CoreV1().Nodes().Get(context.Background(), hc.nodeName, metav1.GetOptions{})
if err != nil {
return err
}
Expand Down Expand Up @@ -140,28 +161,14 @@ func (hc *GPUHealthChecker) resetXIDCondition() error {

// Start registers NVML events and starts listening to them
func (hc *GPUHealthChecker) Start() error {
ctx := context.Background()
nodeName, err := metadata.InstanceNameWithContext(ctx)
nodeName, err := metadata.InstanceNameWithContext(context.Background())
if err != nil {
glog.Errorf("failed to get nodeName, err: %v", err)
}
hc.nodeName = nodeName
hc.informerFactory = informers.NewSharedInformerFactoryWithOptions(
hc.kubeClient,
0,
informers.WithTweakListOptions(func(options *metav1.ListOptions) {
options.FieldSelector = fields.OneTermEqualSelector("metadata.name", hc.nodeName).String()
}),
)

hc.nodeLister = hc.informerFactory.Core().V1().Nodes().Lister()
hc.informerFactory.Start(ctx.Done())
hc.informerFactory.WaitForCacheSync(wait.NeverStop)

err = hc.resetXIDCondition()
if err != nil {
glog.Errorf("failed to reset XID Condition, err: %v", err)
}

go hc.resetXIDConditionWithBackoff()

go hc.setXIDheartbeat()

glog.Info("Starting GPU Health Checker")
Expand Down Expand Up @@ -282,7 +289,7 @@ func (hc *GPUHealthChecker) monitorXidevent(e nvml.Event) {
if _, ok := hc.monitorCriticalXid[e.Edata]; ok {
glog.Info("Monitoring XID event")
// Set XID condition
node, err := hc.nodeLister.Get(hc.nodeName)
node, err := hc.kubeClient.CoreV1().Nodes().Get(context.Background(), hc.nodeName, metav1.GetOptions{})
if err != nil {
glog.Errorf("Failed to get node %s: %v", hc.nodeName, err)
return
Expand Down Expand Up @@ -351,7 +358,7 @@ func (hc *GPUHealthChecker) setXIDheartbeat() {
}

func (hc *GPUHealthChecker) updateLastHeartbeatTime() {
node, err := hc.nodeLister.Get(hc.nodeName)
node, err := hc.kubeClient.CoreV1().Nodes().Get(context.Background(), hc.nodeName, metav1.GetOptions{})
if err != nil {
glog.Errorf("Failed to get node %s for heartbeat update: %v", hc.nodeName, err)
return
Expand All @@ -377,7 +384,7 @@ func (hc *GPUHealthChecker) updateLastHeartbeatTime() {
}

func (hc *GPUHealthChecker) recordXIDEvent(e nvml.Event) error {
node, err := hc.nodeLister.Get(hc.nodeName)
node, err := hc.kubeClient.CoreV1().Nodes().Get(context.Background(), hc.nodeName, metav1.GetOptions{})
if err != nil {
return err
}
Expand Down Expand Up @@ -462,7 +469,6 @@ func (hc *GPUHealthChecker) listenToEvents() error {

// Stop deletes the NVML events and stops the listening go routine
func (hc *GPUHealthChecker) Stop() {
hc.informerFactory.Shutdown()
hc.recorder.(record.EventBroadcaster).Shutdown()
nvml.DeleteEventSet(hc.eventSet)
hc.stop <- true
Expand Down
85 changes: 44 additions & 41 deletions pkg/gpu/nvidia/health_check/health_checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package healthcheck

import (
"context"
"errors"
"reflect"
"testing"
"time"
Expand All @@ -24,9 +25,9 @@ import (

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/fake"
"k8s.io/client-go/tools/record"
k8sclienttesting "k8s.io/client-go/testing"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

Expand All @@ -40,17 +41,6 @@ func (gp *mockGPUDevice) parseMigDeviceUUID(UUID string) (string, uint, uint, er
return UUID, 3173334309191009974, 1015241, nil
}

// mockNodeLister is a test double that serves nodes from a fixed in-memory
// slice instead of an informer cache.
type mockNodeLister struct {
	nodes []*v1.Node
}

// List ignores the selector and returns every stored node with a nil error.
func (m mockNodeLister) List(selector labels.Selector) ([]*v1.Node, error) {
	all := m.nodes
	return all, nil
}

// Get ignores the requested name and always returns the first stored node.
func (m mockNodeLister) Get(name string) (*v1.Node, error) {
	first := m.nodes[0]
	return first, nil
}

func TestCatchError(t *testing.T) {
gp := mockGPUDevice{}
device1 := pluginapi.Device{
Expand Down Expand Up @@ -248,8 +238,6 @@ func TestCatchError(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
tt.hc.kubeClient = fakeClient
tt.hc.health = make(chan pluginapi.Device, len(tt.hc.devices))
tt.hc.nodeLister = mockNodeLister{nodes: []*v1.Node{&node}}
tt.hc.recorder = record.NewFakeRecorder(200)
tt.hc.catchError(tt.event, &gp)
gotErrorDevices := make(map[string]pluginapi.Device)
for range tt.wantErrorDevices {
Expand Down Expand Up @@ -294,8 +282,6 @@ func TestUpdateLastHeartbeatTime(t *testing.T) {

hc := NewGPUHealthChecker(nil, nil, nil, fakeClient)
hc.nodeName = "test-node"
hc.nodeLister = mockNodeLister{nodes: []*v1.Node{&node}}
hc.recorder = record.NewFakeRecorder(200)

time.Sleep(2 * time.Second)
hc.updateLastHeartbeatTime()
Expand All @@ -308,7 +294,7 @@ func TestUpdateLastHeartbeatTime(t *testing.T) {
}
}

func TestResetXIDCondition(t *testing.T) {
func TestResetXIDConditionWithBackoff(t *testing.T) {
// Initialize the node with condition
node := makeNode(nil, nil, nil)
initialTime := metav1.Now()
Expand All @@ -320,34 +306,53 @@ func TestResetXIDCondition(t *testing.T) {
Reason: "XID",
Message: "0",
})
node.Status.Conditions = append(node.Status.Conditions, v1.NodeCondition{
Type: v1.NodeReady,
Status: "True",
})
node.Status.NodeInfo.BootID = "0"
node.Status.NodeInfo.BootID = "1" // New boot ID to ensure condition is removed on success

fakeClient := fake.NewSimpleClientset(&v1.NodeList{Items: []v1.Node{node}})

callCount := 0
// Mock the first 2 calls to fail, and the 3rd to succeed.
fakeClient.Fake.PrependReactor("get", "nodes", func(action k8sclienttesting.Action) (handled bool, ret runtime.Object, err error) {
callCount++
if callCount < 3 {
return true, nil, errors.New("fake API server error")
}
// On the 3rd call, we let the default reactor handle it, which will return the node.
return false, nil, nil
})

hc := NewGPUHealthChecker(nil, nil, nil, fakeClient)
hc.nodeName = "test-node"
hc.nodeLister = mockNodeLister{nodes: []*v1.Node{&node}}
hc.recorder = record.NewFakeRecorder(200)
// Try reset without rebootId changed, conditions remain the same
hc.resetXIDCondition()
updatedNode, _ := fakeClient.CoreV1().Nodes().Get(context.Background(), "test-node", metav1.GetOptions{})
if len(updatedNode.Status.Conditions) < 2 {
t.Errorf("The XID condition should persist without reboot")

startTime := time.Now()
hc.resetXIDConditionWithBackoff()
duration := time.Since(startTime)

if callCount != 3 {
t.Errorf("Expected 3 calls to get node, but got %d", callCount)
}

// The first sleep is 1s, the second is 2s. Total should be at least 3s.
if duration < 3*time.Second {
t.Errorf("Expected backoff to take at least 3 seconds, but it took %v", duration)
}
// Try reset with rebootId changed, conditions get reset
updatedNode.Status.NodeInfo.BootID = "1"
_, err := fakeClient.CoreV1().Nodes().Update(context.Background(), updatedNode, metav1.UpdateOptions{})

// Check that the condition is removed
updatedNode, err := fakeClient.CoreV1().Nodes().Get(context.Background(), "test-node", metav1.GetOptions{})
if err != nil {
t.Errorf("Failed to update node: %v", err)
t.Fatalf("Failed to get node after backoff: %v", err)
}
hc.nodeLister = mockNodeLister{nodes: []*v1.Node{updatedNode}}
hc.resetXIDCondition()
updatedNode, _ = fakeClient.CoreV1().Nodes().Get(context.Background(), "test-node", metav1.GetOptions{})
if len(updatedNode.Status.Conditions) == 2 {
t.Errorf("The XID condition should be reset after reboot")

conditionRemoved := true
for _, c := range updatedNode.Status.Conditions {
if c.Type == XIDConditionType {
conditionRemoved = false
break
}
}

if !conditionRemoved {
t.Errorf("XID condition was not removed after successful retry")
}
}

Expand Down Expand Up @@ -429,8 +434,6 @@ func TestMonitorXidevent(t *testing.T) {

hc := NewGPUHealthChecker(nil, nil, nil, fakeClient)
hc.nodeName = "test-node"
hc.nodeLister = mockNodeLister{nodes: []*v1.Node{&node}}
hc.recorder = record.NewFakeRecorder(200)

for _, event := range test.events {
hc.monitorXidevent(event)
Expand Down