diff --git a/api/v1alpha1/nodeset_types.go b/api/v1alpha1/nodeset_types.go index 33090d0c9..585640845 100644 --- a/api/v1alpha1/nodeset_types.go +++ b/api/v1alpha1/nodeset_types.go @@ -132,6 +132,18 @@ type NodeSetSpec struct { // +kubebuilder:default="20%" MaxUnavailable *intstr.IntOrString `json:"maxUnavailable,omitempty"` + // MaxConcurrentStartup caps the number of worker pods created in parallel + // during initial NodeSet scale-out (i.e. cluster creation or NodeSet growth). + // Value can be an absolute number (ex: 500) or a percentage of desired pods (ex: 10%). + // Maps to the underlying kruise AdvancedStatefulSet's scaleStrategy.maxUnavailable. + // Prevents overloading the Slurm controller with simultaneous slurmd registrations + // on large clusters. + // Defaults to 500. + // + // +kubebuilder:validation:Optional + // +kubebuilder:default=500 + MaxConcurrentStartup *intstr.IntOrString `json:"maxConcurrentStartup,omitempty"` + // EphemeralNodes enables ephemeral node behavior for this NodeSet. // When true, nodes will use dynamic topology injection instead of legacy topology.conf. 
// Topology data is read from the topology-node-labels ConfigMap at runtime diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index 179cad9e0..75a58d7ff 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -1034,6 +1034,11 @@ func (in *NodeSetSpec) DeepCopyInto(out *NodeSetSpec) { *out = new(intstr.IntOrString) **out = **in } + if in.MaxConcurrentStartup != nil { + in, out := &in.MaxConcurrentStartup, &out.MaxConcurrentStartup + *out = new(intstr.IntOrString) + **out = **in + } if in.EphemeralNodes != nil { in, out := &in.EphemeralNodes, &out.EphemeralNodes *out = new(bool) diff --git a/config/crd/bases/slurm.nebius.ai_nodesets.yaml b/config/crd/bases/slurm.nebius.ai_nodesets.yaml index 3cad84721..bfaa33734 100644 --- a/config/crd/bases/slurm.nebius.ai_nodesets.yaml +++ b/config/crd/bases/slurm.nebius.ai_nodesets.yaml @@ -2562,6 +2562,20 @@ spec: format: int32 minimum: 0 type: integer + maxConcurrentStartup: + anyOf: + - type: integer + - type: string + default: 500 + description: |- + MaxConcurrentStartup caps the number of worker pods created in parallel + during initial NodeSet scale-out (i.e. cluster creation or NodeSet growth). + Value can be an absolute number (ex: 500) or a percentage of desired pods (ex: 10%). + Maps to the underlying kruise AdvancedStatefulSet's scaleStrategy.maxUnavailable. + Prevents overloading the Slurm controller with simultaneous slurmd registrations + on large clusters. + Defaults to 500. + x-kubernetes-int-or-string: true maxUnavailable: anyOf: - type: integer diff --git a/helm/nodesets/templates/nodeset.yaml b/helm/nodesets/templates/nodeset.yaml index abed378ba..08e2905d7 100644 --- a/helm/nodesets/templates/nodeset.yaml +++ b/helm/nodesets/templates/nodeset.yaml @@ -27,6 +27,10 @@ spec: maxUnavailable: {{ . }} {{- end }} + {{- with .maxConcurrentStartup }} + maxConcurrentStartup: {{ . 
}} + {{- end }} + {{- if .ephemeralNodes }} ephemeralNodes: {{ .ephemeralNodes }} {{- end }} diff --git a/helm/nodesets/tests/node_config_test.yaml b/helm/nodesets/tests/node_config_test.yaml index 5ae69c883..f77b32cfe 100644 --- a/helm/nodesets/tests/node_config_test.yaml +++ b/helm/nodesets/tests/node_config_test.yaml @@ -249,6 +249,59 @@ tests: value: 2 documentIndex: 1 + - it: should configure maxConcurrentStartup when set and omit it when unset + set: + nodesets: + - name: gpu-workers + replicas: 3 + maxConcurrentStartup: 250 + slurmd: + image: + repository: "test/slurm" + resources: + cpu: "4" + memory: "8Gi" + volumes: + spool: + emptyDir: {} + jail: + emptyDir: {} + jailSubMounts: [] + munge: + image: + repository: "test/munge" + resources: + cpu: "100m" + memory: "128Mi" + - name: cpu-workers + replicas: 5 + slurmd: + image: + repository: "test/slurm" + resources: + cpu: "2" + memory: "4Gi" + volumes: + spool: + emptyDir: {} + jail: + emptyDir: {} + jailSubMounts: [] + munge: + image: + repository: "test/munge" + resources: + cpu: "50m" + memory: "64Mi" + asserts: + - equal: + path: spec.maxConcurrentStartup + value: 250 + documentIndex: 0 + - notExists: + path: spec.maxConcurrentStartup + documentIndex: 1 + - it: should configure worker annotations correctly set: nodesets: diff --git a/helm/nodesets/values.yaml b/helm/nodesets/values.yaml index 5c9534fea..d37ef66bc 100644 --- a/helm/nodesets/values.yaml +++ b/helm/nodesets/values.yaml @@ -49,6 +49,12 @@ nodesets: # Could be a count (number) or percent (string) # Optional, defaults to 20% maxUnavailable: 1 + # Maximum number of worker pods that can be created in parallel during + # initial scale-out, to avoid overloading the Slurm controller with + # simultaneous slurmd registrations. + # Could be a count (number) or percent (string). + # Optional, defaults to 500 + maxConcurrentStartup: 500 # Enable ephemeral node behavior for this NodeSet. 
# When true, nodes will use dynamic topology injection instead of legacy topology.conf. # Topology data is read from the topology-node-labels ConfigMap at runtime. diff --git a/helm/soperator-crds/templates/slurmcluster-crd.yaml b/helm/soperator-crds/templates/slurmcluster-crd.yaml index 15adfa732..791f6ee67 100644 --- a/helm/soperator-crds/templates/slurmcluster-crd.yaml +++ b/helm/soperator-crds/templates/slurmcluster-crd.yaml @@ -17656,6 +17656,20 @@ spec: format: int32 minimum: 0 type: integer + maxConcurrentStartup: + anyOf: + - type: integer + - type: string + default: 500 + description: |- + MaxConcurrentStartup caps the number of worker pods created in parallel + during initial NodeSet scale-out (i.e. cluster creation or NodeSet growth). + Value can be an absolute number (ex: 500) or a percentage of desired pods (ex: 10%). + Maps to the underlying kruise AdvancedStatefulSet's scaleStrategy.maxUnavailable. + Prevents overloading the Slurm controller with simultaneous slurmd registrations + on large clusters. + Defaults to 500. + x-kubernetes-int-or-string: true maxUnavailable: anyOf: - type: integer diff --git a/helm/soperator/crds/slurmcluster-crd.yaml b/helm/soperator/crds/slurmcluster-crd.yaml index 15adfa732..791f6ee67 100644 --- a/helm/soperator/crds/slurmcluster-crd.yaml +++ b/helm/soperator/crds/slurmcluster-crd.yaml @@ -17656,6 +17656,20 @@ spec: format: int32 minimum: 0 type: integer + maxConcurrentStartup: + anyOf: + - type: integer + - type: string + default: 500 + description: |- + MaxConcurrentStartup caps the number of worker pods created in parallel + during initial NodeSet scale-out (i.e. cluster creation or NodeSet growth). + Value can be an absolute number (ex: 500) or a percentage of desired pods (ex: 10%). + Maps to the underlying kruise AdvancedStatefulSet's scaleStrategy.maxUnavailable. + Prevents overloading the Slurm controller with simultaneous slurmd registrations + on large clusters. + Defaults to 500. 
+ x-kubernetes-int-or-string: true maxUnavailable: anyOf: - type: integer diff --git a/internal/controller/reconciler/k8s_statefulset_advanced.go b/internal/controller/reconciler/k8s_statefulset_advanced.go index 33a988158..59cad0cec 100644 --- a/internal/controller/reconciler/k8s_statefulset_advanced.go +++ b/internal/controller/reconciler/k8s_statefulset_advanced.go @@ -57,6 +57,7 @@ func (r *AdvancedStatefulSetReconciler) patch(existing, desired client.Object) ( } dst.Spec.Replicas = src.Spec.Replicas dst.Spec.UpdateStrategy = src.Spec.UpdateStrategy + dst.Spec.ScaleStrategy = src.Spec.ScaleStrategy dst.Spec.Template.Spec = src.Spec.Template.Spec dst.Spec.ReserveOrdinals = src.Spec.ReserveOrdinals dst.Spec.PersistentVolumeClaimRetentionPolicy = src.Spec.PersistentVolumeClaimRetentionPolicy diff --git a/internal/controller/reconciler/k8s_statefulset_test.go b/internal/controller/reconciler/k8s_statefulset_test.go index 76f66f8cc..abf7b5bfd 100644 --- a/internal/controller/reconciler/k8s_statefulset_test.go +++ b/internal/controller/reconciler/k8s_statefulset_test.go @@ -10,6 +10,7 @@ import ( "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/controller-runtime/pkg/client/fake" slurmv1 "nebius.ai/slurm-operator/api/v1" @@ -184,6 +185,48 @@ func TestAdvancedStatefulSetPatchCopiesPVCDeletionPolicy(t *testing.T) { } } +func TestAdvancedStatefulSetPatchCopiesScaleStrategy(t *testing.T) { + tests := []struct { + name string + existing *kruisev1b1.StatefulSetScaleStrategy + desired *kruisev1b1.StatefulSetScaleStrategy + }{ + { + name: "absolute MaxUnavailable is propagated", + existing: &kruisev1b1.StatefulSetScaleStrategy{MaxUnavailable: ptrIntOrString(intstr.FromInt32(100))}, + desired: &kruisev1b1.StatefulSetScaleStrategy{MaxUnavailable: ptrIntOrString(intstr.FromInt32(500))}, + }, + { + name: "percentage MaxUnavailable is propagated", + existing: 
&kruisev1b1.StatefulSetScaleStrategy{MaxUnavailable: ptrIntOrString(intstr.FromString("5%"))}, + desired: &kruisev1b1.StatefulSetScaleStrategy{MaxUnavailable: ptrIntOrString(intstr.FromString("25%"))}, + }, + { + name: "newly set ScaleStrategy is populated", + existing: nil, + desired: &kruisev1b1.StatefulSetScaleStrategy{MaxUnavailable: ptrIntOrString(intstr.FromInt32(500))}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + existing := &kruisev1b1.StatefulSet{Spec: kruisev1b1.StatefulSetSpec{ScaleStrategy: tt.existing}} + desired := &kruisev1b1.StatefulSet{Spec: kruisev1b1.StatefulSetSpec{ScaleStrategy: tt.desired}} + + r := &AdvancedStatefulSetReconciler{} + if _, err := r.patch(existing, desired); err != nil { + t.Fatalf("patch returned error: %v", err) + } + + if !equality.Semantic.DeepEqual(existing.Spec.ScaleStrategy, tt.desired) { + t.Fatalf("expected ScaleStrategy=%+v, got %+v", tt.desired, existing.Spec.ScaleStrategy) + } + }) + } +} + +func ptrIntOrString(v intstr.IntOrString) *intstr.IntOrString { return &v } + func TestStatefulSetPatchCopiesPVCDeletionPolicy(t *testing.T) { existing := &appsv1.StatefulSet{ Spec: appsv1.StatefulSetSpec{ diff --git a/internal/render/common/configmap.go b/internal/render/common/configmap.go index df698585c..3419696bf 100644 --- a/internal/render/common/configmap.go +++ b/internal/render/common/configmap.go @@ -103,11 +103,15 @@ func AddNodeSetsToSlurmConfig(res *renderutils.PropertiesConfig, cluster *values } } -// AddNodesToSlurmConfig adds all node names to the slurm config +// AddNodesToSlurmConfig adds all node names to the slurm config. +// +// All replicas in a NodeSet share an identical body, so we emit one line per +// NodeSet using Slurm hostlist range syntax. This keeps the ConfigMap well +// under the Kubernetes 1 MiB object size limit on large clusters. 
// // Example output: -// NodeName=gb200-0-0 State=CLOUD NodeHostname=gb200-0-0 NodeAddr=gb200-0-0.gb200-0.soperator.svc RealMemory=1612639 Features=platform-gb200,gb200-rack-0 Gres=gpu:nvidia-b200:4 NodeCPUs=128 Boards=1 SocketsPerBoard=2 CoresPerSocket=32 ThreadsPerCode=2 -// NodeName=gb200-0-1 State=CLOUD NodeHostname=gb200-0-1 NodeAddr=gb200-0-1.gb200-0.soperator.svc RealMemory=1612639 Features=platform-gb200,gb200-rack-0 Gres=gpu:nvidia-b200:4 NodeCPUs=128 Boards=1 SocketsPerBoard=2 CoresPerSocket=32 ThreadsPerCode=2 +// NodeName=gb200-0-0 State=CLOUD NodeAddr=gb200-0-0.gb200-0.soperator.svc RealMemory=1612639 Feature=platform-gb200,gb200-rack-0 Gres=gpu:nvidia-b200:4 NodeCPUs=128 Boards=1 SocketsPerBoard=2 CoresPerSocket=32 ThreadsPerCore=2 +// NodeName=gb200-1-[0-17] State=CLOUD NodeAddr=gb200-1-[0-17].gb200-1.soperator.svc RealMemory=1612639 Feature=platform-gb200,gb200-rack-1 Gres=gpu:nvidia-b200:4 NodeCPUs=128 Boards=1 SocketsPerBoard=2 CoresPerSocket=32 ThreadsPerCore=2 func AddNodesToSlurmConfig(res *renderutils.PropertiesConfig, cluster *values.SlurmCluster) { res.AddComment("Nodes section") @@ -125,59 +129,60 @@ func AddNodesToSlurmConfig(res *renderutils.PropertiesConfig, cluster *values.Sl continue } - for i := int32(0); i < nodeSet.Spec.Replicas; i++ { - nodeName := fmt.Sprintf("%s-%d", nodeSet.Name, i) - - nodeAddr := fmt.Sprintf( - "%s.%s", - nodeName, - naming.BuildNodeSetUmbrellaServiceFQDN(nodeSet.Namespace, cluster.Name), - ) - realMemory := strconv.FormatInt( - RenderRealMemorySlurmd(corev1.ResourceRequirements{Requests: nodeSet.Spec.Slurmd.Resources}), - 10, - ) - - var nodeConfigParts []string - nodeConfigParts = append(nodeConfigParts, - fmt.Sprintf("NodeHostname=%s", nodeName), - fmt.Sprintf("NodeAddr=%s", nodeAddr), - fmt.Sprintf("RealMemory=%s", realMemory), - ) - if nodeSet.Spec.Slurmd.Port != 0 { - nodeConfigParts = append(nodeConfigParts, fmt.Sprintf("Port=%d", nodeSet.Spec.Slurmd.Port)) - } - nodeConfig := 
strings.Join(nodeConfigParts, " ") + var nodeRange string + if nodeSet.Spec.Replicas == 1 { + nodeRange = fmt.Sprintf("%s-0", nodeSet.Name) + } else { + nodeRange = fmt.Sprintf("%s-[0-%d]", nodeSet.Name, nodeSet.Spec.Replicas-1) + } - if len(nodeSet.Spec.NodeConfig.Features) > 0 { - features := strings.Join(nodeSet.Spec.NodeConfig.Features, ",") - nodeConfig = fmt.Sprintf("%s %s%s", nodeConfig, nodeFeatureKey, features) - } + nodeAddr := fmt.Sprintf( + "%s.%s", + nodeRange, + naming.BuildNodeSetUmbrellaServiceFQDN(nodeSet.Namespace, cluster.Name), + ) + realMemory := strconv.FormatInt( + RenderRealMemorySlurmd(corev1.ResourceRequirements{Requests: nodeSet.Spec.Slurmd.Resources}), + 10, + ) + + var nodeConfigParts []string + nodeConfigParts = append(nodeConfigParts, + fmt.Sprintf("NodeAddr=%s", nodeAddr), + fmt.Sprintf("RealMemory=%s", realMemory), + ) + if nodeSet.Spec.Slurmd.Port != 0 { + nodeConfigParts = append(nodeConfigParts, fmt.Sprintf("Port=%d", nodeSet.Spec.Slurmd.Port)) + } + nodeConfig := strings.Join(nodeConfigParts, " ") - // Remove feature and state overrides - staticConfig := strings.Join( - slices.Filter( - nil, - strings.Split(nodeSet.Spec.NodeConfig.Static, " "), - func(s string) bool { - return !strings.HasPrefix(s, nodeFeatureKey) && - !strings.HasPrefix(s, stateKey) - }, - ), - " ", - ) - - if len(nodeConfig) > 0 { - nodeConfig = fmt.Sprintf("%s %s", nodeConfig, staticConfig) - } + if len(nodeSet.Spec.NodeConfig.Features) > 0 { + features := strings.Join(nodeSet.Spec.NodeConfig.Features, ",") + nodeConfig = fmt.Sprintf("%s %s%s", nodeConfig, nodeFeatureKey, features) + } - // Create static nodes with state CLOUD. - // Otherwise, nodes will disappear from the Slurm state every time the corresponding K8s pods don't run. 
- res.AddProperty( - "NodeName", - fmt.Sprintf("%s State=CLOUD %s", nodeName, nodeConfig), - ) + staticConfig := strings.Join( + slices.Filter( + nil, + strings.Split(nodeSet.Spec.NodeConfig.Static, " "), + func(s string) bool { + return !strings.HasPrefix(s, nodeFeatureKey) && + !strings.HasPrefix(s, stateKey) + }, + ), + " ", + ) + + if len(nodeConfig) > 0 { + nodeConfig = fmt.Sprintf("%s %s", nodeConfig, staticConfig) } + + // State=CLOUD keeps nodes registered in Slurm even when the + // corresponding K8s pods are not running. + res.AddProperty( + "NodeName", + fmt.Sprintf("%s State=CLOUD %s", nodeRange, nodeConfig), + ) } } diff --git a/internal/render/common/configmap_test.go b/internal/render/common/configmap_test.go index 72ddfd647..2190dde2b 100644 --- a/internal/render/common/configmap_test.go +++ b/internal/render/common/configmap_test.go @@ -585,7 +585,7 @@ func TestAddNodesToSlurmConfig(t *testing.T) { }, }, }, - expected: "NodeName=nodeA-0 State=CLOUD NodeHostname=nodeA-0 NodeAddr=nodeA-0.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=2048 Feature=a,b Gres=gpu:nvidia-a100:4 NodeCPUs=64 Boards=1 SocketsPerBoard=2 CoresPerSocket=32 ThreadsPerCode=1", + expected: "NodeName=nodeA-0 State=CLOUD NodeAddr=nodeA-0.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=2048 Feature=a,b Gres=gpu:nvidia-a100:4 NodeCPUs=64 Boards=1 SocketsPerBoard=2 CoresPerSocket=32 ThreadsPerCode=1", }, { name: "Single nodeset with multiple replicas", @@ -614,9 +614,7 @@ func TestAddNodesToSlurmConfig(t *testing.T) { }, }, }, - expected: "NodeName=nodeB-0 State=CLOUD NodeHostname=nodeB-0 NodeAddr=nodeB-0.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=4096 Gres=gpu:nvidia-a100:8 NodeCPUs=128 Boards=1 SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1\n" + - "NodeName=nodeB-1 State=CLOUD NodeHostname=nodeB-1 NodeAddr=nodeB-1.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=4096 Gres=gpu:nvidia-a100:8 NodeCPUs=128 Boards=1 
SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1\n" + - "NodeName=nodeB-2 State=CLOUD NodeHostname=nodeB-2 NodeAddr=nodeB-2.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=4096 Gres=gpu:nvidia-a100:8 NodeCPUs=128 Boards=1 SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1", + expected: "NodeName=nodeB-[0-2] State=CLOUD NodeAddr=nodeB-[0-2].slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=4096 Gres=gpu:nvidia-a100:8 NodeCPUs=128 Boards=1 SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1", }, { name: "Multiple nodesets with varying replicas", @@ -662,9 +660,8 @@ func TestAddNodesToSlurmConfig(t *testing.T) { }, }, }, - expected: "NodeName=nodeC-0 State=CLOUD NodeHostname=nodeC-0 NodeAddr=nodeC-0.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=8192 Gres=gpu:nvidia-a100:16 NodeCPUs=256 Boards=2 SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1\n" + - "NodeName=nodeC-1 State=CLOUD NodeHostname=nodeC-1 NodeAddr=nodeC-1.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=8192 Gres=gpu:nvidia-a100:16 NodeCPUs=256 Boards=2 SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1\n" + - "NodeName=nodeD-0 State=CLOUD NodeHostname=nodeD-0 NodeAddr=nodeD-0.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=16384 Gres=gpu:nvidia-a100:32 NodeCPUs=512 Boards=4 SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1", + expected: "NodeName=nodeC-[0-1] State=CLOUD NodeAddr=nodeC-[0-1].slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=8192 Gres=gpu:nvidia-a100:16 NodeCPUs=256 Boards=2 SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1\n" + + "NodeName=nodeD-0 State=CLOUD NodeAddr=nodeD-0.slurm-test-nodeset-svc.soperator.svc.cluster.local RealMemory=16384 Gres=gpu:nvidia-a100:32 NodeCPUs=512 Boards=4 SocketsPerBoard=4 CoresPerSocket=32 ThreadsPerCode=1", }, { name: "Nodeset with zero replicas", diff --git a/internal/render/worker/statefulset.go b/internal/render/worker/statefulset.go index 
a844cfef3..5ba867a20 100644 --- a/internal/render/worker/statefulset.go +++ b/internal/render/worker/statefulset.go @@ -159,6 +159,9 @@ func RenderNodeSetStatefulSet( ServiceName: nodeSet.ServiceUmbrella.Name, Replicas: replicas, ReserveOrdinals: reserveOrdinals, + ScaleStrategy: &kruisev1b1.StatefulSetScaleStrategy{ + MaxUnavailable: &nodeSet.StatefulSet.MaxConcurrentStartup, + }, UpdateStrategy: kruisev1b1.StatefulSetUpdateStrategy{ Type: appsv1.RollingUpdateStatefulSetStrategyType, RollingUpdate: &kruisev1b1.RollingUpdateStatefulSetStrategy{ diff --git a/internal/render/worker/statefulset_test.go b/internal/render/worker/statefulset_test.go index 3300754be..dc407651b 100644 --- a/internal/render/worker/statefulset_test.go +++ b/internal/render/worker/statefulset_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" @@ -418,6 +419,90 @@ func TestRenderNodeSetStatefulSet_PersistentVolumeClaimRetentionPolicy(t *testin } } +func TestRenderNodeSetStatefulSet_ScaleStrategy(t *testing.T) { + makeNodeSet := func(maxConcurrentStartup, maxUnavailable intstr.IntOrString) *values.SlurmNodeSet { + return &values.SlurmNodeSet{ + Name: "test-nodeset", + ParentalCluster: client.ObjectKey{ + Namespace: "test-namespace", + Name: "test-cluster", + }, + ContainerSlurmd: values.Container{ + NodeContainer: slurmv1.NodeContainer{ + Image: "test-image", + ImagePullPolicy: corev1.PullIfNotPresent, + Resources: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("1Gi"), + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceEphemeralStorage: resource.MustParse("1Gi"), + }, + }, + }, + ContainerMunge: values.Container{ + NodeContainer: slurmv1.NodeContainer{Image: "munge-image"}, + }, + VolumeSpool: corev1.VolumeSource{HostPath: &corev1.HostPathVolumeSource{Path: "/tmp/spool"}}, + 
VolumeJail: corev1.VolumeSource{HostPath: &corev1.HostPathVolumeSource{Path: "/tmp/jail"}}, + StatefulSet: values.StatefulSet{ + Replicas: 1, + MaxUnavailable: maxUnavailable, + MaxConcurrentStartup: maxConcurrentStartup, + }, + SupervisorDConfigMapName: "supervisord-config", + SSHDConfigMapName: "sshd-config", + GPU: &slurmv1alpha1.GPUSpec{Enabled: false}, + } + } + + tests := []struct { + name string + maxConcurrentStartup intstr.IntOrString + maxUnavailable intstr.IntOrString + expectedMaxConcurrentStart intstr.IntOrString + expectedMaxUnavailable intstr.IntOrString + }{ + { + name: "absolute MaxConcurrentStartup is propagated to ScaleStrategy", + maxConcurrentStartup: intstr.FromInt32(500), + maxUnavailable: intstr.FromString("20%"), + expectedMaxConcurrentStart: intstr.FromInt32(500), + expectedMaxUnavailable: intstr.FromString("20%"), + }, + { + name: "percentage MaxConcurrentStartup is propagated to ScaleStrategy", + maxConcurrentStartup: intstr.FromString("10%"), + maxUnavailable: intstr.FromInt32(1), + expectedMaxConcurrentStart: intstr.FromString("10%"), + expectedMaxUnavailable: intstr.FromInt32(1), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := worker.RenderNodeSetStatefulSet( + "test-cluster", + makeNodeSet(tt.maxConcurrentStartup, tt.maxUnavailable), + &slurmv1.Secrets{}, + consts.CGroupV2, + false, + ) + assert.NoError(t, err) + + if assert.NotNil(t, result.Spec.ScaleStrategy) && + assert.NotNil(t, result.Spec.ScaleStrategy.MaxUnavailable) { + assert.Equal(t, tt.expectedMaxConcurrentStart, *result.Spec.ScaleStrategy.MaxUnavailable, + "ScaleStrategy.MaxUnavailable should mirror StatefulSet.MaxConcurrentStartup") + } + + if assert.NotNil(t, result.Spec.UpdateStrategy.RollingUpdate) && + assert.NotNil(t, result.Spec.UpdateStrategy.RollingUpdate.MaxUnavailable) { + assert.Equal(t, tt.expectedMaxUnavailable, *result.Spec.UpdateStrategy.RollingUpdate.MaxUnavailable, + 
"UpdateStrategy.RollingUpdate.MaxUnavailable governs the update path and must not be affected") + } + }) + } +} + func TestRenderNodeSetStatefulSet_EphemeralNodesReserveOrdinals(t *testing.T) { createNodeSetWithActiveNodes := func(ephemeralNodes bool, activeNodes []int32) *values.SlurmNodeSet { return &values.SlurmNodeSet{ diff --git a/internal/values/slurm_controller.go b/internal/values/slurm_controller.go index 8c9bcfcfb..84f92695c 100644 --- a/internal/values/slurm_controller.go +++ b/internal/values/slurm_controller.go @@ -42,6 +42,7 @@ func buildSlurmControllerFrom(clusterName string, maintenance *consts.Maintenanc naming.BuildStatefulSetName(consts.ComponentTypeController), consts.SingleReplicas, nil, + nil, ) daemonSet := buildDaemonSetFrom( diff --git a/internal/values/slurm_nodeset.go b/internal/values/slurm_nodeset.go index 266cd5ee1..608950540 100644 --- a/internal/values/slurm_nodeset.go +++ b/internal/values/slurm_nodeset.go @@ -119,6 +119,7 @@ func BuildSlurmNodeSetFrom( naming.BuildNodeSetStatefulSetName(nodeSet.Name), nsSpec.Replicas, nsSpec.MaxUnavailable, + nsSpec.MaxConcurrentStartup, ), Service: buildServiceFrom(naming.BuildNodeSetServiceName(clusterName, nodeSet.Name)), ServiceUmbrella: buildServiceFrom(naming.BuildServiceName(consts.ComponentTypeNodeSet, clusterName)), diff --git a/internal/values/types.go b/internal/values/types.go index 8eeafb044..861a5ff0a 100644 --- a/internal/values/types.go +++ b/internal/values/types.go @@ -64,9 +64,10 @@ func buildServiceFrom( // region StatefulSet type StatefulSet struct { - Name string - Replicas int32 - MaxUnavailable intstr.IntOrString + Name string + Replicas int32 + MaxUnavailable intstr.IntOrString + MaxConcurrentStartup intstr.IntOrString } func buildStatefulSetFrom( @@ -84,6 +85,7 @@ func buildStatefulSetWithMaxUnavailableFrom( name string, size int32, maxUnavailable *intstr.IntOrString, + maxConcurrentStartup *intstr.IntOrString, ) StatefulSet { result := StatefulSet{ Name: name, @@ -94,6 
+96,10 @@ func buildStatefulSetWithMaxUnavailableFrom( result.MaxUnavailable = *maxUnavailable } + if maxConcurrentStartup != nil { + result.MaxConcurrentStartup = *maxConcurrentStartup + } + return result }