99
def arrangement(
    input,
    mean,
    var,
    running_mean,
    running_var,
    weight,
    bias,
    eps,
    output,
    num_normalized_elements,
    use_input_stats,
    dims,
    block_size=None,
):
    """Arrange the kernel's tensors and pick the set matching the stats mode.

    ``input`` and ``output`` are arranged together by
    ``reduction_arrangement`` over ``dims``. ``mean`` and ``var`` are tiled
    with a ``(1, 1)`` tile; ``running_mean``, ``running_var``, ``weight`` and
    ``bias`` are tiled with a ``(1,)`` tile and expanded along a new leading
    axis to ``input.shape[0]``. ``eps`` and ``num_normalized_elements`` are
    passed through unchanged.

    Args:
        input, output: Tensors handed to ``reduction_arrangement``.
        mean, var: Tensors arranged with a ``(1, 1)`` tile.
        running_mean, running_var, weight, bias: Tensors arranged with a
            ``(1,)`` tile.
        eps: Returned as-is.
        num_normalized_elements: Returned as-is (only when
            ``use_input_stats`` is truthy).
        use_input_stats: Selects which tuple is returned.
        dims: Forwarded to ``reduction_arrangement`` as ``dim``.
        block_size: Forwarded to ``reduction_arrangement``; when ``None`` it
            is replaced by ``ninetoothed.block_size()``.

    Returns:
        If ``use_input_stats`` is truthy:
        ``(input, mean, var, weight, bias, eps, output,
        num_normalized_elements)``; otherwise
        ``(input, running_mean, running_var, weight, bias, eps, output)``,
        with every tensor in its arranged form.
    """
    if block_size is None:
        block_size = ninetoothed.block_size()

    def _per_channel(tensor):
        # Single-element tiles with the tile axis squeezed out of the inner
        # dtype, then expanded along a new leading axis to ``input.shape[0]``.
        tiled = tensor.tile((1,))
        tiled.dtype = tiled.dtype.squeeze(0)

        return tiled.unsqueeze(0).expand((input.shape[0], -1))

    def _scalar_stat(tensor):
        # ``(1, 1)`` tiles with both tile axes squeezed out of the inner dtype.
        tiled = tensor.tile((1, 1))
        tiled.dtype = tiled.dtype.squeeze((0, 1))

        return tiled

    input_arranged, output_arranged = reduction_arrangement(
        input, output, dim=dims, block_size=block_size
    )

    # All auxiliary tensors are arranged up front, in the same order as the
    # original, regardless of which branch ends up using them.
    mean_arranged = _scalar_stat(mean)
    var_arranged = _scalar_stat(var)
    running_mean_arranged = _per_channel(running_mean)
    running_var_arranged = _per_channel(running_var)
    weight_arranged = _per_channel(weight)
    bias_arranged = _per_channel(bias)

    if not use_input_stats:
        return (
            input_arranged,
            running_mean_arranged,
            running_var_arranged,
            weight_arranged,
            bias_arranged,
            eps,
            output_arranged,
        )

    return (
        input_arranged,
        mean_arranged,
        var_arranged,
        weight_arranged,
        bias_arranged,
        eps,
        output_arranged,
        num_normalized_elements,
    )
75+
76+
77+ def application_using_input_stats (
11278 input ,
113- running_mean ,
114- running_var ,
115- tmp_mean ,
116- tmp_var ,
79+ mean ,
80+ var ,
11781 weight ,
11882 bias ,
119- momentum ,
12083 eps ,
12184 output ,
12285 num_normalized_elements ,
@@ -137,22 +100,6 @@ def application_with_tracking(
137100
138101 var = ntl .sum (_var , 0 ) / num_normalized_elements
139102
140- ntl .atomic_add (
141- tmp_mean .source .data_ptr () + tmp_mean .offsets (0 ), ntl .cast (mean , ntl .float32 )
142- )
143- ntl .atomic_add (
144- tmp_var .source .data_ptr () + tmp_mean .offsets (0 ), ntl .cast (var , ntl .float32 )
145- )
146-
147- ntl .debug_barrier ()
148-
149- if input [0 ].offsets (0 ) == 0 :
150- tmp_mean = tmp_mean / input .source .shape [0 ]
151- tmp_var = tmp_var / input .source .shape [0 ]
152-
153- running_mean = running_mean * (1 - momentum ) + tmp_mean * momentum
154- running_var = running_var * (1 - momentum ) + tmp_var * momentum
155-
156103 application_with_mean_var (input , mean , var , weight , bias , eps , output )
157104
158105
@@ -174,7 +121,6 @@ def application_with_mean_var(
174121def premake (
175122 ndim ,
176123 use_input_stats ,
177- tracking_running_stats ,
178124 num_normalized_elements ,
179125 dtype = None ,
180126 block_size = None ,
@@ -184,36 +130,30 @@ def premake(
184130 arrangement_ = functools .partial (
185131 arrangement ,
186132 use_input_stats = use_input_stats ,
187- tracking_running_stats = tracking_running_stats ,
188133 dims = dims ,
189134 block_size = block_size ,
190135 )
191136
192137 input = Tensor (ndim , other = 0 , dtype = dtype )
193- running_mean , running_var , tmp_mean , tmp_var , weight , bias = (
194- Tensor (1 , dtype = dtype ) for _ in range (6 )
195- )
196- momentum , eps = (Tensor (0 , dtype = ninetoothed .float64 ) for _ in range (2 ))
138+ mean , var = (Tensor (2 , dtype = dtype ) for _ in range (2 ))
139+ running_mean , running_var , weight , bias = (Tensor (1 , dtype = dtype ) for _ in range (4 ))
140+ eps = Tensor (0 , dtype = ninetoothed .float64 )
197141 output = Tensor (ndim , dtype = dtype )
198142 num_normalized_elements = Tensor (0 , constexpr = True , value = num_normalized_elements )
199143
200144 if use_input_stats :
201- if tracking_running_stats :
202- application = application_with_tracking
203- else :
204- application = application_without_tracking
145+ application = application_using_input_stats
205146 else :
206147 application = application_with_mean_var
207148
208149 tensors = (
209150 input ,
151+ mean ,
152+ var ,
210153 running_mean ,
211154 running_var ,
212- tmp_mean ,
213- tmp_var ,
214155 weight ,
215156 bias ,
216- momentum ,
217157 eps ,
218158 output ,
219159 num_normalized_elements ,
0 commit comments