Commit 30e37f5

Refactor instance_norm operator to reduce in PyTorch instead of using tl.atomic_add in Triton
1 parent d4ee060 commit 30e37f5

3 files changed: 82 additions & 151 deletions
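At a high level, the change replaces a kernel that accumulated running statistics through atomics with a two-stage scheme: the kernel writes per-instance statistics into (N, C) buffers, and the host reduces them and applies the momentum update once. A minimal sketch of that data flow, with hypothetical names (`kernel` stands in for the compiled ninetoothed/Triton kernel; `x`, `n`, and the helper are illustrative, not identifiers from this commit):

import torch

# Hypothetical driver showing the two stages after this refactor.
def instance_norm_sketch(x, weight, bias, running_mean, running_var,
                         momentum, eps, kernel):
    N, C = x.shape[:2]
    n = x[0, 0].numel()  # number of normalized elements per (n, c) slice

    # Stage 1 (device): each program normalizes one (n, c) slice and
    # writes its mean/var into (N, C) buffers -- no cross-block atomics.
    mean = torch.empty(N, C, device=x.device, dtype=x.dtype)
    var = torch.empty(N, C, device=x.device, dtype=x.dtype)
    output = torch.empty_like(x)
    kernel(x, mean, var, weight, bias, eps, output, n)

    # Stage 2 (host): reduce over the batch dimension in PyTorch and
    # fold the result into the running statistics, with the unbiased
    # correction applied to the variance.
    unbiased = var.mean(0) * n / (n - 1) if n > 1 else var.mean(0)
    running_mean.mul_(1 - momentum).add_(momentum * mean.mean(0))
    running_var.mul_(1 - momentum).add_(momentum * unbiased)

    return output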


src/ntops/kernels/instance_norm.py

Lines changed: 49 additions & 109 deletions
@@ -9,114 +9,77 @@

 def arrangement(
     input,
+    mean,
+    var,
     running_mean,
     running_var,
-    tmp_mean,
-    tmp_var,
     weight,
     bias,
-    momentum,
     eps,
     output,
     num_normalized_elements,
     use_input_stats,
-    tracking_running_stats,
     dims,
     block_size=None,
 ):
-    def _arrange_per_channel_tensor(tensor):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    def _arrange_channel_tensor(tensor):
         arranged = tensor.tile((1,))
         arranged.dtype = arranged.dtype.squeeze(0)
         arranged = arranged.unsqueeze(0)
         arranged = arranged.expand((input.shape[0], -1))

         return arranged

+    def _arrange_mean_or_var(tensor):
+        arranged = tensor.tile((1, 1))
+        arranged.dtype = arranged.dtype.squeeze((0, 1))
+
+        return arranged
+
     input_arranged, output_arranged = reduction_arrangement(
         input, output, dim=dims, block_size=block_size
     )
-    running_mean_arranged = _arrange_per_channel_tensor(running_mean)
-    running_var_arranged = _arrange_per_channel_tensor(running_var)
-    tmp_mean_arranged = _arrange_per_channel_tensor(tmp_mean)
-    tmp_var_arranged = _arrange_per_channel_tensor(tmp_var)
-    weight_arranged = _arrange_per_channel_tensor(weight)
-    bias_arranged = _arrange_per_channel_tensor(bias)
-    momentum_arranged = momentum
+    mean_arranged = _arrange_mean_or_var(mean)
+    var_arranged = _arrange_mean_or_var(var)
+    running_mean_arranged = _arrange_channel_tensor(running_mean)
+    running_var_arranged = _arrange_channel_tensor(running_var)
+    weight_arranged = _arrange_channel_tensor(weight)
+    bias_arranged = _arrange_channel_tensor(bias)
     eps_arranged = eps
     num_normalized_elements_arranged = num_normalized_elements

     if use_input_stats:
-        if tracking_running_stats:
-            return (
-                input_arranged,
-                running_mean_arranged,
-                running_var_arranged,
-                tmp_mean_arranged,
-                tmp_var_arranged,
-                weight_arranged,
-                bias_arranged,
-                momentum_arranged,
-                eps_arranged,
-                output_arranged,
-                num_normalized_elements_arranged,
-            )
-        else:
-            return (
-                input_arranged,
-                weight_arranged,
-                bias_arranged,
-                eps_arranged,
-                output_arranged,
-                num_normalized_elements_arranged,
-            )
-
-    return (
-        input_arranged,
-        running_mean_arranged,
-        running_var_arranged,
-        weight_arranged,
-        bias_arranged,
-        eps_arranged,
-        output_arranged,
-    )
-
-
-def application_without_tracking(
-    input,
-    weight,
-    bias,
-    eps,
-    output,
-    num_normalized_elements,
-):
-    _mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
-
-    for i in range(input.shape[0]):
-        _mean += ntl.cast(input[i], ntl.float32)
-
-    mean = ntl.sum(_mean, 0) / num_normalized_elements
-
-    _var = ntl.zeros(input.dtype.shape, dtype=ntl.float32)
-
-    for i in range(input.shape[0]):
-        diff = ntl.cast(input[i], ntl.float32) - mean
-        diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0)
-        _var += diff * diff
-
-    var = ntl.sum(_var, 0) / num_normalized_elements
-
-    application_with_mean_var(input, mean, var, weight, bias, eps, output)
-
-
-def application_with_tracking(
+        return (
+            input_arranged,
+            mean_arranged,
+            var_arranged,
+            weight_arranged,
+            bias_arranged,
+            eps_arranged,
+            output_arranged,
+            num_normalized_elements_arranged,
+        )
+    else:
+        return (
+            input_arranged,
+            running_mean_arranged,
+            running_var_arranged,
+            weight_arranged,
+            bias_arranged,
+            eps_arranged,
+            output_arranged,
+        )
+
+
+def application_using_input_stats(
     input,
-    running_mean,
-    running_var,
-    tmp_mean,
-    tmp_var,
+    mean,
+    var,
     weight,
     bias,
-    momentum,
     eps,
     output,
     num_normalized_elements,
@@ -137,22 +100,6 @@ def application_with_tracking(

     var = ntl.sum(_var, 0) / num_normalized_elements

-    ntl.atomic_add(
-        tmp_mean.source.data_ptr() + tmp_mean.offsets(0), ntl.cast(mean, ntl.float32)
-    )
-    ntl.atomic_add(
-        tmp_var.source.data_ptr() + tmp_mean.offsets(0), ntl.cast(var, ntl.float32)
-    )
-
-    ntl.debug_barrier()
-
-    if input[0].offsets(0) == 0:
-        tmp_mean = tmp_mean / input.source.shape[0]
-        tmp_var = tmp_var / input.source.shape[0]
-
-        running_mean = running_mean * (1 - momentum) + tmp_mean * momentum
-        running_var = running_var * (1 - momentum) + tmp_var * momentum
-
     application_with_mean_var(input, mean, var, weight, bias, eps, output)


@@ -174,7 +121,6 @@ def application_with_mean_var(
 def premake(
     ndim,
     use_input_stats,
-    tracking_running_stats,
     num_normalized_elements,
     dtype=None,
     block_size=None,
@@ -184,36 +130,30 @@ def premake(
     arrangement_ = functools.partial(
         arrangement,
         use_input_stats=use_input_stats,
-        tracking_running_stats=tracking_running_stats,
        dims=dims,
         block_size=block_size,
     )

     input = Tensor(ndim, other=0, dtype=dtype)
-    running_mean, running_var, tmp_mean, tmp_var, weight, bias = (
-        Tensor(1, dtype=dtype) for _ in range(6)
-    )
-    momentum, eps = (Tensor(0, dtype=ninetoothed.float64) for _ in range(2))
+    mean, var = (Tensor(2, dtype=dtype) for _ in range(2))
+    running_mean, running_var, weight, bias = (Tensor(1, dtype=dtype) for _ in range(4))
+    eps = Tensor(0, dtype=ninetoothed.float64)
     output = Tensor(ndim, dtype=dtype)
     num_normalized_elements = Tensor(0, constexpr=True, value=num_normalized_elements)

     if use_input_stats:
-        if tracking_running_stats:
-            application = application_with_tracking
-        else:
-            application = application_without_tracking
+        application = application_using_input_stats
     else:
         application = application_with_mean_var

     tensors = (
         input,
+        mean,
+        var,
         running_mean,
         running_var,
-        tmp_mean,
-        tmp_var,
         weight,
         bias,
-        momentum,
         eps,
         output,
         num_normalized_elements,
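For orientation, the per-(n, c) statistics that `application_using_input_stats` computes on-device use a biased variance over the normalized dimensions. A minimal eager-mode sketch of the same math (the function name and shapes are illustrative, not part of the kernel):

import torch

# Eager-mode equivalent of the per-instance statistics (sketch).
# The biased variance mirrors `ntl.sum(_var, 0) / num_normalized_elements`.
def per_instance_stats(x, eps):
    dims = tuple(range(2, x.ndim))  # the normalized (spatial) dims
    mean = x.mean(dim=dims, keepdim=True)                # (N, C, 1, ...)
    var = x.var(dim=dims, unbiased=False, keepdim=True)  # biased
    normalized = (x - mean) / torch.sqrt(var + eps)
    return mean.reshape(x.shape[:2]), var.reshape(x.shape[:2]), normalized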

src/ntops/torch/instance_norm.py

Lines changed: 30 additions & 40 deletions
@@ -22,23 +22,11 @@ def instance_norm(
     if bias is None:
         bias = torch.zeros(input.shape[1], device=input.device, dtype=input.dtype)

-    tracking_running_stats = False
+    has_running_stats = running_mean is not None and running_var is not None

-    if not use_input_stats:
-        assert running_mean is not None and running_var is not None, (
-            "`running_mean` and `running_var` must be provided when `use_input_stats=False`."
-        )
-        assert running_mean.shape == (input.shape[1],) and running_var.shape == (
-            input.shape[1],
-        ), "`running_mean` and `running_var` must have shape (C,)"
-    else:
-        if running_mean is not None and running_var is not None:
-            assert running_mean.shape == (input.shape[1],) and running_var.shape == (
-                input.shape[1],
-            ), "`running_mean` and `running_var` must have shape (C,)"
-            tracking_running_stats = True
-            tmp_mean = torch.zeros_like(running_mean)
-            tmp_var = torch.zeros_like(running_var)
+    if use_input_stats:
+        mean = torch.empty(input.shape[:2], device=input.device, dtype=input.dtype)
+        var = torch.empty(input.shape[:2], device=input.device, dtype=input.dtype)

     output = torch.empty_like(input)

@@ -47,35 +35,37 @@ def instance_norm(
         ntops.kernels.instance_norm.premake,
         input.ndim,
         use_input_stats,
-        tracking_running_stats,
         num_normalized_elements,
-        block_size=32,
+        dtype=input.dtype,
     )

     if use_input_stats:
-        if tracking_running_stats:
-            kernel(
-                input,
-                running_mean,
-                running_var,
-                tmp_mean,
-                tmp_var,
-                weight,
-                bias,
-                momentum,
-                eps,
-                output,
-                num_normalized_elements,
-            )
-        else:
-            kernel(
-                input,
-                weight,
-                bias,
-                eps,
-                output,
-                num_normalized_elements,
+        kernel(
+            input,
+            mean,
+            var,
+            weight,
+            bias,
+            eps,
+            output,
+            num_normalized_elements,
+        )
+
+    # We reduce in PyTorch instead of using tl.atomic_add in Triton because:
+    # 1. Triton blocks cannot synchronize to safely apply the momentum update after all additions finish.
+    # 2. N blocks atomically adding to the same C addresses creates severe memory contention.
+    if use_input_stats and has_running_stats:
+        batch_mean = mean.mean(0)
+        avg_vars = var.mean(0)
+
+        unbiased_var = (
+            avg_vars * num_normalized_elements / (num_normalized_elements - 1)
+            if num_normalized_elements > 1
+            else avg_vars
         )
+
+        running_mean.mul_(1 - momentum).add_(momentum * batch_mean)
+        running_var.mul_(1 - momentum).add_(momentum * unbiased_var)
     else:
         kernel(input, running_mean, running_var, weight, bias, eps, output)
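One way to sanity-check the host-side reduction above is to compare it against the in-place running-stats update performed by torch.nn.functional.instance_norm. A sketch of such a check, with shapes and tolerances chosen purely for illustration:

import torch
import torch.nn.functional as F

N, C, H, W = 4, 3, 8, 8
momentum, eps = 0.1, 1e-5
x = torch.randn(N, C, H, W)

# Reference: F.instance_norm updates its running stats in place.
ref_mean, ref_var = torch.zeros(C), torch.ones(C)
F.instance_norm(x, ref_mean, ref_var, use_input_stats=True,
                momentum=momentum, eps=eps)

# Manual update, mirroring the code in this commit.
n = H * W  # num_normalized_elements
mean = x.mean(dim=(2, 3))                # per-instance means, (N, C)
var = x.var(dim=(2, 3), unbiased=False)  # biased per-instance vars, (N, C)
batch_mean = mean.mean(0)
unbiased_var = var.mean(0) * n / (n - 1)

my_mean = torch.zeros(C).mul_(1 - momentum).add_(momentum * batch_mean)
my_var = torch.ones(C).mul_(1 - momentum).add_(momentum * unbiased_var)

assert torch.allclose(my_mean, ref_mean, atol=1e-6)
assert torch.allclose(my_var, ref_var, atol=1e-5)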

tests/test_instance_norm.py

Lines changed: 3 additions & 2 deletions
@@ -81,5 +81,6 @@ def test_instance_norm(
     assert torch.allclose(
         ninetoothed_running_mean, reference_running_mean, rtol=rtol, atol=atol
     )
-    # TODO: The running var is not close.
-    # assert torch.allclose(ninetoothed_running_var, reference_running_var, rtol=rtol, atol=atol)
+    assert torch.allclose(
+        ninetoothed_running_var, reference_running_var, rtol=rtol, atol=atol
+    )
