|
1 | 1 | import torch |
| 2 | +import math |
| 3 | + |
| 4 | +def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): |
| 5 | + mantissa_scaled = torch.where( |
| 6 | + normal_mask, |
| 7 | + (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), |
| 8 | + (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) |
| 9 | + ) |
| 10 | + |
| 11 | + mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator) |
| 12 | + return mantissa_scaled.floor() / (2**MANTISSA_BITS) |
2 | 13 |
|
3 | 14 | #Not 100% sure about this |
4 | | -def manual_stochastic_round_to_float8(x, dtype): |
| 15 | +def manual_stochastic_round_to_float8(x, dtype, generator=None): |
5 | 16 | if dtype == torch.float8_e4m3fn: |
6 | 17 | EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7 |
7 | 18 | elif dtype == torch.float8_e5m2: |
8 | 19 | EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15 |
9 | 20 | else: |
10 | 21 | raise ValueError("Unsupported dtype") |
11 | 22 |
|
| 23 | + x = x.half() |
12 | 24 | sign = torch.sign(x) |
13 | 25 | abs_x = x.abs() |
| 26 | + sign = torch.where(abs_x == 0, 0, sign) |
14 | 27 |
|
15 | 28 | # Combine exponent calculation and clamping |
16 | 29 | exponent = torch.clamp( |
17 | | - torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS, |
| 30 | + torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS, |
18 | 31 | 0, 2**EXPONENT_BITS - 1 |
19 | 32 | ) |
20 | 33 |
|
21 | 34 | # Combine mantissa calculation and rounding |
22 | | - # min_normal = 2.0 ** (-EXPONENT_BIAS + 1) |
23 | | - # zero_mask = (abs_x == 0) |
24 | | - # subnormal_mask = (exponent == 0) & (abs_x != 0) |
25 | 35 | normal_mask = ~(exponent == 0) |
26 | 36 |
|
27 | | - mantissa_scaled = torch.where( |
28 | | - normal_mask, |
29 | | - (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), |
30 | | - (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) |
31 | | - ) |
32 | | - mantissa_floor = mantissa_scaled.floor() |
33 | | - mantissa = torch.where( |
34 | | - torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor), |
35 | | - (mantissa_floor + 1) / (2**MANTISSA_BITS), |
36 | | - mantissa_floor / (2**MANTISSA_BITS) |
37 | | - ) |
38 | | - result = torch.where( |
| 37 | + abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator) |
| 38 | + |
| 39 | + sign *= torch.where( |
39 | 40 | normal_mask, |
40 | | - sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa), |
41 | | - sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa |
| 41 | + (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), |
| 42 | + (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x |
42 | 43 | ) |
| 44 | + del abs_x |
43 | 45 |
|
44 | | - result = torch.where(abs_x == 0, 0, result) |
45 | | - return result.to(dtype=dtype) |
| 46 | + return sign.to(dtype=dtype) |
46 | 47 |
|
47 | 48 |
|
48 | 49 |
|
49 | | -def stochastic_rounding(value, dtype): |
| 50 | +def stochastic_rounding(value, dtype, seed=0): |
50 | 51 | if dtype == torch.float32: |
51 | 52 | return value.to(dtype=torch.float32) |
52 | 53 | if dtype == torch.float16: |
53 | 54 | return value.to(dtype=torch.float16) |
54 | 55 | if dtype == torch.bfloat16: |
55 | 56 | return value.to(dtype=torch.bfloat16) |
56 | 57 | if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: |
57 | | - return manual_stochastic_round_to_float8(value, dtype) |
| 58 | + generator = torch.Generator(device=value.device) |
| 59 | + generator.manual_seed(seed) |
| 60 | + return manual_stochastic_round_to_float8(value, dtype, generator=generator) |
58 | 61 |
|
59 | 62 | return value.to(dtype=dtype) |
0 commit comments