Commit d3f56b7

T1-1-19 Add alpha_dropout operator
1 parent fa33ecc commit d3f56b7

5 files changed: 132 additions & 0 deletions


src/ntops/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
     abs,
     add,
     addmm,
+    alpha_dropout,
     avg_pool2d,
     bitwise_and,
     bitwise_not,
@@ -47,6 +48,7 @@
     "abs",
     "add",
     "addmm",
+    "alpha_dropout",
     "avg_pool2d",
     "bitwise_and",
     "bitwise_not",

src/ntops/kernels/alpha_dropout.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, a, b, sat, p, seed, output):
+    keep = ntl.rand(seed, input.offsets()) > p
+    output = ntl.where(keep, a * input + b, sat)  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.int64),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
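For readers unfamiliar with the ninetoothed DSL: `application` keeps an element when its per-element uniform sample exceeds p, applies the affine rescale a * x + b to kept elements, and writes the saturation value sat to dropped ones. Below is a rough pure-PyTorch mirror of that formula, illustrative only and not part of the commit; the function name is made up here, and torch.rand_like will not reproduce ntl.rand's seeded stream.

import torch


def _alpha_dropout_formula(x, a, b, sat, p):
    # Per-element keep mask, analogous to `ntl.rand(seed, input.offsets()) > p`.
    keep = torch.rand_like(x, dtype=torch.float32) > p
    # Kept elements: affine rescale a * x + b; dropped elements: saturation value sat.
    return torch.where(keep, a * x + b, torch.full_like(x, sat))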

src/ntops/torch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 from ntops.torch.abs import abs
 from ntops.torch.add import add
 from ntops.torch.addmm import addmm
+from ntops.torch.alpha_dropout import alpha_dropout
 from ntops.torch.avg_pool2d import avg_pool2d
 from ntops.torch.bitwise_and import bitwise_and
 from ntops.torch.bitwise_not import bitwise_not
@@ -46,6 +47,7 @@
     "abs",
     "add",
     "addmm",
+    "alpha_dropout",
     "avg_pool2d",
     "bitwise_and",
     "bitwise_not",

src/ntops/torch/alpha_dropout.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import math
+import random
+
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+# SELU saturation value: -lambda * alpha
+_ALPHA_P = -1.7580993408473766
+
+
+def alpha_dropout(input, p=0.5, training=False, inplace=False):
+    if not training or p == 0:
+        if inplace:
+            return input
+        else:
+            return input.clone()
+
+    q = 1.0 - p
+    a = 1.0 / math.sqrt(q * (1.0 + p * _ALPHA_P * _ALPHA_P))
+    b = -a * p * _ALPHA_P
+    sat = a * _ALPHA_P + b
+
+    seed = random.randrange(0, 2**31)
+
+    if inplace:
+        output = input
+    else:
+        output = torch.empty_like(input)
+
+    kernel = _cached_make(ntops.kernels.alpha_dropout.premake, input.ndim)
+
+    kernel(input, a, b, sat, p, seed, output)
+
+    return output
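As a sanity check on the affine parameters above (illustrative only, not part of the commit): with q = 1 - p and alpha' = _ALPHA_P, choosing a = 1 / sqrt(q * (1 + p * alpha'^2)) and b = -a * p * alpha' keeps the output at zero mean and unit variance for standard-normal input, which is the defining property of alpha dropout. A quick numerical verification:

import math

import torch

p = 0.3
q = 1.0 - p
alpha_prime = -1.7580993408473766  # same constant as _ALPHA_P
a = 1.0 / math.sqrt(q * (1.0 + p * alpha_prime * alpha_prime))
b = -a * p * alpha_prime
sat = a * alpha_prime + b

x = torch.randn(1_000_000)
keep = torch.rand_like(x) > p
y = torch.where(keep, a * x + b, torch.full_like(x, sat))
print(y.mean().item(), y.var().item())  # should be close to 0 and 1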

tests/test_alpha_dropout.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+import math
+import random
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+_ALPHA_P = -1.7580993408473766
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_alpha_dropout(shape, dtype, device, rtol, atol):
+    input = torch.randn(shape, dtype=dtype, device=device)
+    p = random.uniform(0.1, 0.5)
+
+    ninetoothed_output = ntops.torch.alpha_dropout(input, p=p, training=True)
+    reference_output = F.alpha_dropout(input, p=p, training=True)
+
+    # 1. Shape must match.
+    assert ninetoothed_output.shape == reference_output.shape
+
+    # 2. Compute expected affine parameters.
+    q = 1.0 - p
+    a = 1.0 / math.sqrt(q * (1.0 + p * _ALPHA_P * _ALPHA_P))
+    b = -a * p * _ALPHA_P
+    sat = a * _ALPHA_P + b
+
+    # 3. Drop ratios should be close to each other.
+    ninetoothed_drop_ratio = (
+        torch.isclose(
+            ninetoothed_output, torch.full_like(ninetoothed_output, sat), atol=atol
+        )
+        .float()
+        .mean()
+        .item()
+    )
+    reference_drop_ratio = (
+        torch.isclose(
+            reference_output, torch.full_like(reference_output, sat), atol=atol
+        )
+        .float()
+        .mean()
+        .item()
+    )
+
+    assert abs(ninetoothed_drop_ratio - reference_drop_ratio) < 0.1
+
+    # 4. Kept elements should satisfy the same affine transform.
+    kept_mask = ~torch.isclose(
+        ninetoothed_output, torch.full_like(ninetoothed_output, sat), atol=atol
+    )
+    expected_kept = a * input[kept_mask].float() + b
+    actual_kept = ninetoothed_output[kept_mask].float()
+
+    assert torch.allclose(actual_kept, expected_kept, rtol=rtol, atol=atol)
+
+    # 5. training=False should return input unchanged.
+    output_eval = ntops.torch.alpha_dropout(input, p=p, training=False)
+    assert torch.equal(output_eval, input)
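A minimal usage sketch (not part of the commit; it assumes a CUDA device is available, matching the skip_if_cuda_not_available guard in the test):

import torch

import ntops

x = torch.randn(4, 8, device="cuda")
y = ntops.torch.alpha_dropout(x, p=0.2, training=True)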
