66using System . Reflection . Emit ;
77using System . Runtime . Intrinsics ;
88using System . Runtime . Intrinsics . X86 ;
9- using System . Threading . Tasks ;
109
1110// =============================================================================
1211// ILKernelGenerator.Reduction.Axis.Simd.cs - SIMD Axis Reduction Kernels
1716// - AxisReductionSimdHelper<T> - main SIMD helper
1817// - ReduceContiguousAxis variants (SIMD256, SIMD128, scalar)
1918// - ReduceStridedAxis with AVX2 gather for float/double
20- // - Parallel outer loop for large output sizes
2119// - Vector identity/combine/horizontal helpers
2220// - IKernelProvider interface implementation
2321//
@@ -42,16 +40,9 @@ private static unsafe AxisReductionKernel CreateAxisReductionKernelTyped<T>(Axis
4240 } ;
4341 }
4442
45- /// <summary>
46- /// Threshold for parallelizing the outer loop in axis reductions.
47- /// Only parallelize when output size exceeds this threshold.
48- /// </summary>
49- private const int AxisReductionParallelThreshold = 1000 ;
50-
5143 /// <summary>
5244 /// SIMD helper for axis reduction operations.
5345 /// Reduces along a specific axis, writing results to output array.
54- /// Uses parallel outer loop for large output sizes.
5546 /// </summary>
5647 /// <typeparam name="T">Element type</typeparam>
5748 /// <param name="input">Input data pointer</param>
@@ -97,125 +88,47 @@ internal static unsafe void AxisReductionSimdHelper<T>(
9788 ReductionOp actualOp = op == ReductionOp . Mean ? ReductionOp . Sum : op ;
9889 bool isMean = op == ReductionOp . Mean ;
9990
100- // Use parallel loop for large output sizes
101- if ( outputSize > AxisReductionParallelThreshold )
91+ // Sequential loop over output elements
92+ for ( int outIdx = 0 ; outIdx < outputSize ; outIdx ++ )
10293 {
103- // Copy strides to managed arrays for safe parallel access
104- int [ ] inputStridesArray = new int [ ndim ] ;
105- for ( int i = 0 ; i < ndim ; i ++ )
106- inputStridesArray [ i ] = inputStrides [ i ] ;
107-
108- int [ ] outputStridesArray = new int [ outputNdim > 0 ? outputNdim : 1 ] ;
109- for ( int i = 0 ; i < outputStridesArray . Length && i < outputNdim ; i ++ )
110- outputStridesArray [ i ] = outputStrides [ i ] ;
94+ // Convert linear output index to coordinates and compute input base offset
95+ int remaining = outIdx ;
96+ int inputBaseOffset = 0 ;
97+ int outputOffset = 0 ;
11198
112- // Capture pointers for lambda
113- T * inputPtr = input ;
114- T * outputPtr = output ;
115-
116- Parallel . For ( 0 , outputSize , outIdx =>
99+ for ( int d = 0 ; d < outputNdim ; d ++ )
117100 {
118- ReduceAxisElement (
119- inputPtr , outputPtr ,
120- inputStridesArray , outputStridesArray , outputDimStridesArray ,
121- axis , axisSize , axisStride , outputNdim ,
122- axisContiguous , actualOp , isMean , outIdx ) ;
123- } ) ;
124- }
125- else
126- {
127- // Sequential loop for small output sizes
128- for ( int outIdx = 0 ; outIdx < outputSize ; outIdx ++ )
129- {
130- // Convert linear output index to coordinates and compute input base offset
131- int remaining = outIdx ;
132- int inputBaseOffset = 0 ;
133- int outputOffset = 0 ;
134-
135- for ( int d = 0 ; d < outputNdim ; d ++ )
136- {
137- // Map output dimension d to input dimension
138- int inputDim = d >= axis ? d + 1 : d ;
139-
140- int coord = remaining / outputDimStridesArray [ d ] ;
141- remaining = remaining % outputDimStridesArray [ d ] ;
142-
143- inputBaseOffset += coord * inputStrides [ inputDim ] ;
144- outputOffset += coord * outputStrides [ d ] ;
145- }
146-
147- // Now reduce along the axis
148- T * axisStart = input + inputBaseOffset ;
149-
150- T result ;
151- if ( axisContiguous )
152- {
153- // Fast path: axis is contiguous, use SIMD
154- result = ReduceContiguousAxis ( axisStart , axisSize , actualOp ) ;
155- }
156- else
157- {
158- // Strided path: axis is not contiguous, use SIMD gather if beneficial
159- result = ReduceStridedAxis ( axisStart , axisSize , axisStride , actualOp ) ;
160- }
161-
162- // For Mean, divide by count
163- if ( isMean )
164- result = DivideByCountTyped ( result , axisSize ) ;
165-
166- output [ outputOffset ] = result ;
167- }
168- }
169- }
101+ // Map output dimension d to input dimension
102+ int inputDim = d >= axis ? d + 1 : d ;
170103
171- /// <summary>
172- /// Process a single output element for axis reduction.
173- /// Used by parallel loop to process each output position independently.
174- /// </summary>
175- private static unsafe void ReduceAxisElement < T > (
176- T * input , T * output ,
177- int [ ] inputStrides , int [ ] outputStrides , int [ ] outputDimStrides ,
178- int axis , int axisSize , int axisStride , int outputNdim ,
179- bool axisContiguous , ReductionOp op , bool isMean , int outIdx )
180- where T : unmanaged
181- {
182- // Convert linear output index to coordinates and compute input base offset
183- int remaining = outIdx ;
184- int inputBaseOffset = 0 ;
185- int outputOffset = 0 ;
104+ int coord = remaining / outputDimStridesArray [ d ] ;
105+ remaining = remaining % outputDimStridesArray [ d ] ;
186106
187- for ( int d = 0 ; d < outputNdim ; d ++ )
188- {
189- // Map output dimension d to input dimension
190- int inputDim = d >= axis ? d + 1 : d ;
107+ inputBaseOffset += coord * inputStrides [ inputDim ] ;
108+ outputOffset += coord * outputStrides [ d ] ;
109+ }
191110
192- int coord = remaining / outputDimStrides [ d ] ;
193- remaining = remaining % outputDimStrides [ d ] ;
111+ // Now reduce along the axis
112+ T * axisStart = input + inputBaseOffset ;
194113
195- inputBaseOffset += coord * inputStrides [ inputDim ] ;
196- outputOffset += coord * outputStrides [ d ] ;
197- }
114+ T result ;
115+ if ( axisContiguous )
116+ {
117+ // Fast path: axis is contiguous, use SIMD
118+ result = ReduceContiguousAxis ( axisStart , axisSize , actualOp ) ;
119+ }
120+ else
121+ {
122+ // Strided path: axis is not contiguous, use SIMD gather if beneficial
123+ result = ReduceStridedAxis ( axisStart , axisSize , axisStride , actualOp ) ;
124+ }
198125
199- // Now reduce along the axis
200- T * axisStart = input + inputBaseOffset ;
126+ // For Mean, divide by count
127+ if ( isMean )
128+ result = DivideByCountTyped ( result , axisSize ) ;
201129
202- T result ;
203- if ( axisContiguous )
204- {
205- // Fast path: axis is contiguous, use SIMD
206- result = ReduceContiguousAxis ( axisStart , axisSize , op ) ;
130+ output [ outputOffset ] = result ;
207131 }
208- else
209- {
210- // Strided path: axis is not contiguous, use SIMD gather if beneficial
211- result = ReduceStridedAxis ( axisStart , axisSize , axisStride , op ) ;
212- }
213-
214- // For Mean, divide by count
215- if ( isMean )
216- result = DivideByCountTyped ( result , axisSize ) ;
217-
218- output [ outputOffset ] = result ;
219132 }
220133
221134 /// <summary>
0 commit comments