Commit 00a5d08

Lower fp8 lora memory usage.
1 parent d043997 commit 00a5d08

1 file changed

File tree

comfy/float.py

Lines changed: 7 additions & 3 deletions
@@ -41,9 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
         (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
         (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
     )
-    del abs_x
 
-    return sign.to(dtype=dtype)
+    return sign
 
 
 
@@ -57,6 +56,11 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)
-        return manual_stochastic_round_to_float8(value, dtype, generator=generator)
+        output = torch.empty_like(value, dtype=dtype)
+        num_slices = max(1, (value.numel() / (4096 * 4096)))
+        slice_size = max(1, round(value.shape[0] / num_slices))
+        for i in range(0, value.shape[0], slice_size):
+            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+        return output
 
     return value.to(dtype=dtype)
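
What the change does: instead of stochastically rounding the whole tensor to fp8 in one call, stochastic_rounding now preallocates the fp8 output and converts the tensor slice by slice along its first dimension, so the full-precision intermediates created by manual_stochastic_round_to_float8 only exist for one slice (roughly 4096*4096 elements) at a time. The helper itself no longer casts to fp8; the caller's copy_ into the fp8 buffer performs that cast. Below is a minimal standalone sketch of the same chunking pattern, with a plain cast standing in for the stochastic rounding; the function name and the elements_per_chunk parameter are illustrative, not part of the commit.

import torch

def chunked_cast_to_fp8(value, dtype=torch.float8_e4m3fn, elements_per_chunk=4096 * 4096):
    # Preallocate the low-precision result so each slice is written in place.
    output = torch.empty_like(value, dtype=dtype)
    # Choose a slice size so each slice holds roughly elements_per_chunk elements.
    num_slices = max(1, value.numel() / elements_per_chunk)
    slice_size = max(1, round(value.shape[0] / num_slices))
    for i in range(0, value.shape[0], slice_size):
        # A plain cast stands in here for manual_stochastic_round_to_float8;
        # only this slice's intermediates are alive at any given time.
        output[i:i + slice_size].copy_(value[i:i + slice_size].to(dtype=dtype))
    return output

# Example: converting an 8192x4096 weight processes it a slice at a time
# instead of materializing a full-size rounding workspace.
w_fp8 = chunked_cast_to_fp8(torch.randn(8192, 4096))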
