Skip to content

Commit 8e39b57

Browse files
committed
Unify ArgMin/ArgMax to IL kernels (NaN/bool)
Route all elementwise ArgMax/ArgMin cases (including Boolean, Single, Double) through the IL kernel path and remove the old scalar fallbacks. Added specialized IL helpers for float/double NaN-aware semantics and Boolean semantics (ArgMax/ArgMin helpers and EmitArgReductionStep variants), updated kernel generator to emit correct initial min/max for Boolean, and dispatch to type-specific helpers. Deleted legacy SimdReductionOptimized and a Boolean elementwise template, and adjusted engine calls to use ExecuteElementReduction for the unified path. These changes consolidate logic, ensure NumPy-like NaN handling (first NaN wins), and reduce duplication.
1 parent f631363 commit 8e39b57

9 files changed

Lines changed: 343 additions & 1167 deletions

File tree

src/NumSharp.Core/Backends/Default/Math/DefaultEngine.ReductionOp.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ protected object min_elementwise_il(NDArray arr, NPTypeCode? typeCode)
219219
/// <summary>
220220
/// Execute element-wise argmax reduction using IL kernels.
221221
/// Returns the index of the maximum value.
222+
/// All types including Boolean, Single, Double now use unified IL kernel path.
222223
/// </summary>
223224
[MethodImpl(MethodImplOptions.AggressiveInlining)]
224225
protected long argmax_elementwise_il(NDArray arr)
@@ -230,20 +231,19 @@ protected long argmax_elementwise_il(NDArray arr)
230231

231232
// ArgMax returns long (int64) to match NumPy 2.x behavior
232233
// Internally uses int kernels (arrays rarely exceed 2^31 elements), widens to long for API
233-
// For floating point types, use scalar implementation which handles NaN correctly (NumPy: first NaN wins)
234-
// For Boolean, use scalar implementation (IL doesn't support bool comparison directly)
234+
// All types use IL kernels - NaN-aware helpers for float/double, bool-aware for boolean
235235
return inputType switch
236236
{
237-
NPTypeCode.Boolean => (long)(int)argmax_elementwise(arr), // Boolean scalar path
237+
NPTypeCode.Boolean => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.Boolean),
238238
NPTypeCode.Byte => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.Byte),
239239
NPTypeCode.Int16 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.Int16),
240240
NPTypeCode.UInt16 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.UInt16),
241241
NPTypeCode.Int32 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.Int32),
242242
NPTypeCode.UInt32 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.UInt32),
243243
NPTypeCode.Int64 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.Int64),
244244
NPTypeCode.UInt64 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.UInt64),
245-
NPTypeCode.Single => (long)(int)argmax_elementwise(arr), // NaN-aware scalar path
246-
NPTypeCode.Double => (long)(int)argmax_elementwise(arr), // NaN-aware scalar path
245+
NPTypeCode.Single => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.Single),
246+
NPTypeCode.Double => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.Double),
247247
NPTypeCode.Decimal => ExecuteElementReduction<int>(arr, ReductionOp.ArgMax, NPTypeCode.Decimal),
248248
_ => throw new NotSupportedException($"ArgMax not supported for type {inputType}")
249249
};
@@ -252,6 +252,7 @@ protected long argmax_elementwise_il(NDArray arr)
252252
/// <summary>
253253
/// Execute element-wise argmin reduction using IL kernels.
254254
/// Returns the index of the minimum value.
255+
/// All types including Boolean, Single, Double now use unified IL kernel path.
255256
/// </summary>
256257
[MethodImpl(MethodImplOptions.AggressiveInlining)]
257258
protected long argmin_elementwise_il(NDArray arr)
@@ -263,20 +264,19 @@ protected long argmin_elementwise_il(NDArray arr)
263264

264265
// ArgMin returns long (int64) to match NumPy 2.x behavior
265266
// Internally uses int kernels (arrays rarely exceed 2^31 elements), widens to long for API
266-
// For floating point types, use scalar implementation which handles NaN correctly (NumPy: first NaN wins)
267-
// For Boolean, use scalar implementation (IL doesn't support bool comparison directly)
267+
// All types use IL kernels - NaN-aware helpers for float/double, bool-aware for boolean
268268
return inputType switch
269269
{
270-
NPTypeCode.Boolean => (long)(int)argmin_elementwise(arr), // Boolean scalar path
270+
NPTypeCode.Boolean => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.Boolean),
271271
NPTypeCode.Byte => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.Byte),
272272
NPTypeCode.Int16 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.Int16),
273273
NPTypeCode.UInt16 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.UInt16),
274274
NPTypeCode.Int32 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.Int32),
275275
NPTypeCode.UInt32 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.UInt32),
276276
NPTypeCode.Int64 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.Int64),
277277
NPTypeCode.UInt64 => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.UInt64),
278-
NPTypeCode.Single => (long)(int)argmin_elementwise(arr), // NaN-aware scalar path
279-
NPTypeCode.Double => (long)(int)argmin_elementwise(arr), // NaN-aware scalar path
278+
NPTypeCode.Single => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.Single),
279+
NPTypeCode.Double => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.Double),
280280
NPTypeCode.Decimal => ExecuteElementReduction<int>(arr, ReductionOp.ArgMin, NPTypeCode.Decimal),
281281
_ => throw new NotSupportedException($"ArgMin not supported for type {inputType}")
282282
};

src/NumSharp.Core/Backends/Default/Math/Reduction/Default.Reduction.ArgMax.cs

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -163,83 +163,5 @@ private unsafe NDArray ExecuteAxisArgReduction(NDArray arr, int axis, bool keepd
163163
return ret;
164164
}
165165

166-
/// <summary>
167-
/// Element-wise argmax for types requiring special handling (Boolean, Single, Double with NaN).
168-
/// Called by argmax_elementwise_il for these types only.
169-
/// </summary>
170-
protected object argmax_elementwise(NDArray arr)
171-
{
172-
if (arr.Shape.IsScalar || (arr.Shape.size == 1 && arr.Shape.NDim == 1))
173-
return 0;
174-
175-
// This method is only called for Boolean, Single, Double types
176-
// All other types use IL-generated kernels via argmax_elementwise_il
177-
switch (arr.GetTypeCode)
178-
{
179-
case NPTypeCode.Boolean:
180-
{
181-
// Boolean: True=1, False=0, so argmax finds first True
182-
var iter = arr.AsIterator<bool>();
183-
var moveNext = iter.MoveNext;
184-
var hasNext = iter.HasNext;
185-
int idx = 1, maxAt = 0;
186-
bool max = moveNext();
187-
while (hasNext())
188-
{
189-
var val = moveNext();
190-
// For argmax: True > False
191-
if (val && !max)
192-
{
193-
max = val;
194-
maxAt = idx;
195-
}
196-
idx++;
197-
}
198-
return maxAt;
199-
}
200-
case NPTypeCode.Single:
201-
{
202-
var iter = arr.AsIterator<float>();
203-
var moveNext = iter.MoveNext;
204-
var hasNext = iter.HasNext;
205-
int idx = 1, maxAt = 0;
206-
float max = moveNext();
207-
while (hasNext())
208-
{
209-
var val = moveNext();
210-
// NumPy: first NaN always wins
211-
if (val > max || (float.IsNaN(val) && !float.IsNaN(max)))
212-
{
213-
max = val;
214-
maxAt = idx;
215-
}
216-
idx++;
217-
}
218-
return maxAt;
219-
}
220-
case NPTypeCode.Double:
221-
{
222-
var iter = arr.AsIterator<double>();
223-
var moveNext = iter.MoveNext;
224-
var hasNext = iter.HasNext;
225-
int idx = 1, maxAt = 0;
226-
double max = moveNext();
227-
while (hasNext())
228-
{
229-
var val = moveNext();
230-
// NumPy: first NaN always wins
231-
if (val > max || (double.IsNaN(val) && !double.IsNaN(max)))
232-
{
233-
max = val;
234-
maxAt = idx;
235-
}
236-
idx++;
237-
}
238-
return maxAt;
239-
}
240-
default:
241-
throw new NotSupportedException($"argmax_elementwise should only be called for Boolean, Single, Double. Got: {arr.GetTypeCode}");
242-
}
243-
}
244166
}
245167
}

src/NumSharp.Core/Backends/Default/Math/Reduction/Default.Reduction.ArgMin.cs

Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -127,84 +127,5 @@ public override NDArray ReduceArgMin(NDArray arr, int? axis_, bool keepdims = fa
127127
// Use IL kernel for axis reduction (reuse the ArgMax method which handles both ArgMax and ArgMin)
128128
return ExecuteAxisArgReduction(arr, axis, keepdims, outputShape, axisedShape, ReductionOp.ArgMin);
129129
}
130-
131-
/// <summary>
132-
/// Element-wise argmin for types requiring special handling (Boolean, Single, Double with NaN).
133-
/// Called by argmin_elementwise_il for these types only.
134-
/// </summary>
135-
protected object argmin_elementwise(NDArray arr)
136-
{
137-
if (arr.Shape.IsScalar || (arr.Shape.size == 1 && arr.Shape.NDim == 1))
138-
return 0;
139-
140-
// This method is only called for Boolean, Single, Double types
141-
// All other types use IL-generated kernels via argmin_elementwise_il
142-
switch (arr.GetTypeCode)
143-
{
144-
case NPTypeCode.Boolean:
145-
{
146-
// Boolean: True=1, False=0, so argmin finds first False
147-
var iter = arr.AsIterator<bool>();
148-
var moveNext = iter.MoveNext;
149-
var hasNext = iter.HasNext;
150-
int idx = 1, minAt = 0;
151-
bool min = moveNext();
152-
while (hasNext())
153-
{
154-
var val = moveNext();
155-
// For argmin: False < True
156-
if (!val && min)
157-
{
158-
min = val;
159-
minAt = idx;
160-
}
161-
idx++;
162-
}
163-
return minAt;
164-
}
165-
case NPTypeCode.Single:
166-
{
167-
var iter = arr.AsIterator<float>();
168-
var moveNext = iter.MoveNext;
169-
var hasNext = iter.HasNext;
170-
int idx = 1, minAt = 0;
171-
float min = moveNext();
172-
while (hasNext())
173-
{
174-
var val = moveNext();
175-
// NumPy: first NaN always wins
176-
if (val < min || (float.IsNaN(val) && !float.IsNaN(min)))
177-
{
178-
min = val;
179-
minAt = idx;
180-
}
181-
idx++;
182-
}
183-
return minAt;
184-
}
185-
case NPTypeCode.Double:
186-
{
187-
var iter = arr.AsIterator<double>();
188-
var moveNext = iter.MoveNext;
189-
var hasNext = iter.HasNext;
190-
int idx = 1, minAt = 0;
191-
double min = moveNext();
192-
while (hasNext())
193-
{
194-
var val = moveNext();
195-
// NumPy: first NaN always wins
196-
if (val < min || (double.IsNaN(val) && !double.IsNaN(min)))
197-
{
198-
min = val;
199-
minAt = idx;
200-
}
201-
idx++;
202-
}
203-
return minAt;
204-
}
205-
default:
206-
throw new NotSupportedException($"argmin_elementwise should only be called for Boolean, Single, Double. Got: {arr.GetTypeCode}");
207-
}
208-
}
209130
}
210131
}

0 commit comments

Comments
 (0)