Skip to content

Commit 49a4357

Browse files
authored
Merge branch 'comfyanonymous:master' into master
2 parents 656759a + 38c22e6 commit 49a4357

8 files changed

Lines changed: 151 additions & 49 deletions

File tree

app/frontend_management.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from dataclasses import dataclass
99
from functools import cached_property
1010
from pathlib import Path
11-
from typing import TypedDict
11+
from typing import TypedDict, Optional
1212

1313
import requests
1414
from typing_extensions import NotRequired
@@ -132,12 +132,13 @@ def parse_version_string(cls, value: str) -> tuple[str, str, str]:
132132
return match_result.group(1), match_result.group(2), match_result.group(3)
133133

134134
@classmethod
135-
def init_frontend_unsafe(cls, version_string: str) -> str:
135+
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
136136
"""
137137
Initializes the frontend for the specified version.
138138
139139
Args:
140140
version_string (str): The version string.
141+
provider (FrontEndProvider, optional): The provider to use. Defaults to None.
141142
142143
Returns:
143144
str: The path to the initialized frontend.
@@ -150,23 +151,29 @@ def init_frontend_unsafe(cls, version_string: str) -> str:
150151
return cls.DEFAULT_FRONTEND_PATH
151152

152153
repo_owner, repo_name, version = cls.parse_version_string(version_string)
153-
provider = FrontEndProvider(repo_owner, repo_name)
154+
provider = provider or FrontEndProvider(repo_owner, repo_name)
154155
release = provider.get_release(version)
155156

156157
semantic_version = release["tag_name"].lstrip("v")
157158
web_root = str(
158159
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
159160
)
160161
if not os.path.exists(web_root):
161-
os.makedirs(web_root, exist_ok=True)
162-
logging.info(
163-
"Downloading frontend(%s) version(%s) to (%s)",
164-
provider.folder_name,
165-
semantic_version,
166-
web_root,
167-
)
168-
logging.debug(release)
169-
download_release_asset_zip(release, destination_path=web_root)
162+
try:
163+
os.makedirs(web_root, exist_ok=True)
164+
logging.info(
165+
"Downloading frontend(%s) version(%s) to (%s)",
166+
provider.folder_name,
167+
semantic_version,
168+
web_root,
169+
)
170+
logging.debug(release)
171+
download_release_asset_zip(release, destination_path=web_root)
172+
finally:
173+
# Clean up the directory if it is empty, i.e. the download failed
174+
if not os.listdir(web_root):
175+
os.rmdir(web_root)
176+
170177
return web_root
171178

172179
@classmethod

comfy/float.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,62 @@
11
import torch
2+
import math
3+
4+
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
5+
mantissa_scaled = torch.where(
6+
normal_mask,
7+
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
8+
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
9+
)
10+
11+
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
12+
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
213

314
#Not 100% sure about this
4-
def manual_stochastic_round_to_float8(x, dtype):
15+
def manual_stochastic_round_to_float8(x, dtype, generator=None):
516
if dtype == torch.float8_e4m3fn:
617
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
718
elif dtype == torch.float8_e5m2:
819
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
920
else:
1021
raise ValueError("Unsupported dtype")
1122

23+
x = x.half()
1224
sign = torch.sign(x)
1325
abs_x = x.abs()
26+
sign = torch.where(abs_x == 0, 0, sign)
1427

1528
# Combine exponent calculation and clamping
1629
exponent = torch.clamp(
17-
torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
30+
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
1831
0, 2**EXPONENT_BITS - 1
1932
)
2033

2134
# Combine mantissa calculation and rounding
22-
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
23-
# zero_mask = (abs_x == 0)
24-
# subnormal_mask = (exponent == 0) & (abs_x != 0)
2535
normal_mask = ~(exponent == 0)
2636

27-
mantissa_scaled = torch.where(
28-
normal_mask,
29-
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
30-
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
31-
)
32-
mantissa_floor = mantissa_scaled.floor()
33-
mantissa = torch.where(
34-
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
35-
(mantissa_floor + 1) / (2**MANTISSA_BITS),
36-
mantissa_floor / (2**MANTISSA_BITS)
37-
)
38-
result = torch.where(
37+
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
38+
39+
sign *= torch.where(
3940
normal_mask,
40-
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
41-
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
41+
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
42+
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
4243
)
44+
del abs_x
4345

44-
result = torch.where(abs_x == 0, 0, result)
45-
return result.to(dtype=dtype)
46+
return sign.to(dtype=dtype)
4647

4748

4849

49-
def stochastic_rounding(value, dtype):
50+
def stochastic_rounding(value, dtype, seed=0):
5051
if dtype == torch.float32:
5152
return value.to(dtype=torch.float32)
5253
if dtype == torch.float16:
5354
return value.to(dtype=torch.float16)
5455
if dtype == torch.bfloat16:
5556
return value.to(dtype=torch.bfloat16)
5657
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
57-
return manual_stochastic_round_to_float8(value, dtype)
58+
generator = torch.Generator(device=value.device)
59+
generator.manual_seed(seed)
60+
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
5861

5962
return value.to(dtype=dtype)

comfy/ldm/flux/layers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,8 @@ def __init__(self, dim: int, dtype=None, device=None, operations=None):
6363
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
6464

6565
def forward(self, x: Tensor):
66-
x_dtype = x.dtype
67-
x = x.float()
6866
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
69-
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
67+
return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)
7068

7169

7270
class QKNorm(torch.nn.Module):

comfy/lora.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
along with this program. If not, see <https://www.gnu.org/licenses/>.
1717
"""
1818

19+
from __future__ import annotations
1920
import comfy.utils
2021
import comfy.model_management
2122
import comfy.model_base
@@ -347,6 +348,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
347348
weight[:] = weight_calc
348349
return weight
349350

351+
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
352+
"""
353+
Pad a tensor to a new shape with zeros.
354+
355+
Args:
356+
tensor (torch.Tensor): The original tensor to be padded.
357+
new_shape (List[int]): The desired shape of the padded tensor.
358+
359+
Returns:
360+
torch.Tensor: A new tensor padded with zeros to the specified shape.
361+
362+
Note:
363+
If the new shape is smaller than the original tensor in any dimension,
364+
the original tensor will be truncated in that dimension.
365+
"""
366+
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
367+
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
368+
369+
if len(new_shape) != len(tensor.shape):
370+
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
371+
372+
# Create a new tensor filled with zeros
373+
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
374+
375+
# Create slicing tuples for both tensors
376+
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
377+
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
378+
379+
# Copy the original tensor into the new tensor
380+
padded_tensor[new_slices] = tensor[orig_slices]
381+
382+
return padded_tensor
383+
350384
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
351385
for p in patches:
352386
strength = p[0]
@@ -375,12 +409,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
375409
v = v[1]
376410

377411
if patch_type == "diff":
378-
w1 = v[0]
412+
diff: torch.Tensor = v[0]
413+
# An extra flag to pad the weight if the diff's shape is larger than the weight
414+
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
415+
if do_pad_weight and diff.shape != weight.shape:
416+
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
417+
weight = pad_tensor_to_shape(weight, diff.shape)
418+
379419
if strength != 0.0:
380-
if w1.shape != weight.shape:
381-
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
420+
if diff.shape != weight.shape:
421+
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
382422
else:
383-
weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
423+
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
384424
elif patch_type == "lora": #lora/locon
385425
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
386426
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)

comfy/model_management.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,8 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
405405
if not force_unload:
406406
if unload_weights_only and unload_weight == False:
407407
return None
408+
else:
409+
unload_weight = True
408410

409411
for i in to_unload:
410412
logging.debug("unload clone {} {}".format(i, unload_weight))

comfy/model_patcher.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@
3030
import comfy.lora
3131
from comfy.types import UnetWrapperFunction
3232

33+
def string_to_seed(data):
34+
crc = 0xFFFFFFFF
35+
for byte in data:
36+
if isinstance(byte, str):
37+
byte = ord(byte)
38+
crc ^= byte
39+
for _ in range(8):
40+
if crc & 1:
41+
crc = (crc >> 1) ^ 0xEDB88320
42+
else:
43+
crc >>= 1
44+
return crc ^ 0xFFFFFFFF
3345

3446
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
3547
to = model_options["transformer_options"].copy()
@@ -309,7 +321,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
309321
else:
310322
temp_weight = weight.to(torch.float32, copy=True)
311323
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
312-
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
324+
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
313325
if inplace_update:
314326
comfy.utils.copy_to_param(self.model, key, out_weight)
315327
else:
@@ -319,12 +331,21 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
319331
mem_counter = 0
320332
patch_counter = 0
321333
lowvram_counter = 0
322-
load_completely = []
334+
loading = []
323335
for n, m in self.model.named_modules():
336+
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
337+
loading.append((comfy.model_management.module_size(m), n, m))
338+
339+
load_completely = []
340+
loading.sort(reverse=True)
341+
for x in loading:
342+
n = x[1]
343+
m = x[2]
344+
module_mem = x[0]
345+
324346
lowvram_weight = False
325347

326348
if not full_load and hasattr(m, "comfy_cast_weights"):
327-
module_mem = comfy.model_management.module_size(m)
328349
if mem_counter + module_mem >= lowvram_model_memory:
329350
lowvram_weight = True
330351
lowvram_counter += 1
@@ -356,9 +377,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
356377
wipe_lowvram_weight(m)
357378

358379
if hasattr(m, "weight"):
359-
mem_used = comfy.model_management.module_size(m)
360-
mem_counter += mem_used
361-
load_completely.append((mem_used, n, m))
380+
mem_counter += module_mem
381+
load_completely.append((module_mem, n, m))
362382

363383
load_completely.sort(reverse=True)
364384
for x in load_completely:

server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,9 @@ async def post_history(request):
586586
@routes.post("/internal/models/download")
587587
async def download_handler(request):
588588
async def report_progress(filename: str, status: DownloadModelStatus):
589-
await self.send_json("download_progress", status.to_dict())
589+
payload = status.to_dict()
590+
payload['download_path'] = filename
591+
await self.send_json("download_progress", payload)
590592

591593
data = await request.json()
592594
url = data.get('url')

tests-unit/app_test/frontend_manager_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22
import pytest
33
from requests.exceptions import HTTPError
4+
from unittest.mock import patch
45

56
from app.frontend_management import (
67
FrontendManager,
@@ -83,6 +84,35 @@ def test_init_frontend_invalid_provider():
8384
with pytest.raises(HTTPError):
8485
FrontendManager.init_frontend_unsafe(version_string)
8586

87+
@pytest.fixture
88+
def mock_os_functions():
89+
with patch('app.frontend_management.os.makedirs') as mock_makedirs, \
90+
patch('app.frontend_management.os.listdir') as mock_listdir, \
91+
patch('app.frontend_management.os.rmdir') as mock_rmdir:
92+
mock_listdir.return_value = [] # Simulate empty directory
93+
yield mock_makedirs, mock_listdir, mock_rmdir
94+
95+
@pytest.fixture
96+
def mock_download():
97+
with patch('app.frontend_management.download_release_asset_zip') as mock:
98+
mock.side_effect = Exception("Download failed") # Simulate download failure
99+
yield mock
100+
101+
def test_finally_block(mock_os_functions, mock_download, mock_provider):
102+
# Arrange
103+
mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions
104+
version_string = 'test-owner/test-repo@1.0.0'
105+
106+
# Act & Assert
107+
with pytest.raises(Exception):
108+
FrontendManager.init_frontend_unsafe(version_string, mock_provider)
109+
110+
# Assert
111+
mock_makedirs.assert_called_once()
112+
mock_download.assert_called_once()
113+
mock_listdir.assert_called_once()
114+
mock_rmdir.assert_called_once()
115+
86116

87117
def test_parse_version_string():
88118
version_string = "owner/repo@1.0.0"

0 commit comments

Comments
 (0)