Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bbba746
import Flux from https://github.com/black-forest-labs/flux/
wkpark Aug 31, 2024
8e847ee
fix for A1111 webui
wkpark Aug 31, 2024
c0d4ab0
support Flux1
wkpark Aug 31, 2024
7932335
fix for flux
wkpark Aug 31, 2024
4eee381
add cheap approximation for flux
wkpark Aug 31, 2024
67bf106
fix for float8_*
wkpark Aug 31, 2024
47a601c
fix for t5xxl
wkpark Sep 4, 2024
495212d
fix misc
wkpark Sep 5, 2024
4dc5c90
check Unet/VAE and load as is
wkpark Sep 5, 2024
90ff052
patch reset_parameters
wkpark Sep 5, 2024
233e05f
preserve detected dtype_inference
wkpark Sep 5, 2024
0124ec3
add diffusers weight mapping for flux lora
wkpark Sep 5, 2024
a5e057d
fix for Lora flux
wkpark Sep 6, 2024
45a3dae
misc fixes to support float8 dtype_unet
wkpark Sep 6, 2024
673e665
add shared.opts.lora_without_backup_weight option to reduce ram usage
wkpark Sep 7, 2024
0e627de
support copy option to reduce ram usage
wkpark Sep 7, 2024
c64dd9a
optimize
wkpark Sep 8, 2024
c2c44c6
reduce memort usage
wkpark Sep 8, 2024
9a50880
vae fix for flux
wkpark Sep 8, 2024
c5d84c4
check vae/ text_encoders dtype and use as intended
wkpark Sep 8, 2024
f744462
minor update
wkpark Sep 10, 2024
88135af
reduce intermediate steps and optimize
wkpark Sep 10, 2024
aaacdbc
fix to support dtype_inference != dtype case
wkpark Sep 11, 2024
5075552
use empty_like() and partial revert for speed
wkpark Sep 13, 2024
a6c55b2
fix for pytest
wkpark Sep 13, 2024
5cb6200
pytest with --precision full --no-half
wkpark Sep 13, 2024
4499384
minor fixes
wkpark Sep 13, 2024
16e8590
fix lora without backup
wkpark Sep 15, 2024
517c395
revert to use without_autocast()
wkpark Sep 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ jobs:
--test-server
--do-not-download-clip
--no-half
--precision full
--disable-opt-split-attention
--use-cpu all
--api-server-stop
Expand Down
4 changes: 4 additions & 0 deletions configs/flux1-inference.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model:
target: modules.models.flux.FLUX1Inferencer
params:
state_dict: null
3 changes: 2 additions & 1 deletion extensions-builtin/Lora/network_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import lyco_helpers
import modules.models.sd3.mmdit
import modules.models.flux.modules.layers
import network
from modules import devices

Expand Down Expand Up @@ -37,7 +38,7 @@ def create_module(self, weights, key, none_ok=False):
if weight is None and none_ok:
return None

is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear ]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]

if is_linear:
Expand Down
40 changes: 36 additions & 4 deletions extensions-builtin/Lora/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_x_proj = re.compile(r"(.*)_((?:[qkv]|mlp)_proj)$")
re_compiled = {}

suffix_conversion = {
Expand Down Expand Up @@ -377,6 +377,8 @@ def store_weights_backup(weight):
if weight is None:
return None

if shared.opts.lora_without_backup_weight:
return True
return weight.to(devices.cpu, copy=True)


Expand All @@ -395,6 +397,9 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
if weights_backup is None and bias_backup is None:
return

if shared.opts.lora_without_backup_weight:
return

if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
Expand Down Expand Up @@ -455,7 +460,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
try:
with torch.no_grad():
if getattr(self, 'fp16_weight', None) is None:
Expand Down Expand Up @@ -515,7 +520,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

continue

if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
module_mlp = net.modules.get(network_layer_name + "_mlp_proj", None)

if any(isinstance(self, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)) and module_q and module_k and module_v and module_mlp is None:
try:
with torch.no_grad():
# Send "real" orig_weight into MHA's lora module
Expand All @@ -526,6 +533,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
del qw, kw, vw
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.weight += updown_qkv
del updown_qkv

except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

continue

if any(isinstance(self, linear) for linear in (modules.models.flux.modules.layers.QkvLinear,)) and module_q and module_k and module_v and module_mlp:
try:
with torch.no_grad():
qw, kw, vw, mlp = torch.tensor_split(self.weight, (3072, 6144, 9216,), 0)
updown_q, _ = module_q.calc_updown(qw)
updown_k, _ = module_k.calc_updown(kw)
updown_v, _ = module_v.calc_updown(vw)
updown_mlp, _ = module_v.calc_updown(mlp)
del qw, kw, vw, mlp
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
self.weight += updown_qkv_mlp
del updown_qkv_mlp

except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
Expand All @@ -539,7 +566,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

self.network_current_names = wanted_names

if shared.opts.lora_without_backup_weight:
self.network_weights_backup = None
self.network_bias_backup = None
else:
self.network_current_names = wanted_names


def network_forward(org_module, input, original_forward):
Expand Down
102 changes: 80 additions & 22 deletions modules/devices.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import contextlib
from copy import deepcopy
from functools import lru_cache

import torch
Expand Down Expand Up @@ -128,6 +129,26 @@ def enable_tf32():
dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False

supported_vae_dtypes = [torch.float16, torch.float32]


# prepare available dtypes
if torch.version.cuda:
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes
if has_xpu():
supported_vae_dtypes = [torch.bfloat16] + supported_vae_dtypes


def supports_non_blocking():
if has_mps() or has_xpu():
return False

if npu_specific.has_npu:
return False

return True


def cond_cast_unet(input):
if force_fp16:
Expand All @@ -149,53 +170,79 @@ def cond_cast_float(input):
]


def manual_cast_forward(target_dtype):
def manual_cast_forward(target_dtype, target_device=None, copy=False):
params = dict()
if supports_non_blocking():
params['non_blocking'] = True

def forward_wrapper(self, *args, **kwargs):
if any(
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
for arg in args
):
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
if target_device is not None:
params['device'] = target_device
params['dtype'] = target_dtype

args = list(args)
for j in (i for i, arg in enumerate(args) if isinstance(arg, torch.Tensor) and arg.dtype != target_dtype):
args[j] = args[j].to(**params)
args = tuple(args)

for key in (k for k, v in kwargs.items() if isinstance(v, torch.Tensor) and v.dtype != target_dtype):
kwargs[key] = kwargs[key].to(**params)

org_dtype = target_dtype
for param in self.parameters():
if param.dtype != target_dtype:
org_dtype = param.dtype
break
else:
break

if copy:
copied = deepcopy(self)
if org_dtype != target_dtype:
copied.to(**params)

result = copied.org_forward(*args, **kwargs)
del copied
else:
if org_dtype != target_dtype:
self.to(**params)

if org_dtype != target_dtype:
self.to(target_dtype)
result = self.org_forward(*args, **kwargs)
if org_dtype != target_dtype:
self.to(org_dtype)
result = self.org_forward(*args, **kwargs)

if org_dtype != target_dtype:
params['dtype'] = org_dtype
self.to(**params)

if target_dtype != dtype_inference:
params['dtype'] = dtype_inference
if isinstance(result, tuple):
result = tuple(
i.to(dtype_inference)
i.to(**params)
if isinstance(i, torch.Tensor)
else i
for i in result
)
elif isinstance(result, torch.Tensor):
result = result.to(dtype_inference)
result = result.to(**params)
return result
return forward_wrapper


@contextlib.contextmanager
def manual_cast(target_dtype):
def manual_cast(target_dtype, target_device=None, copy=None):
applied = False

copy = copy if copy is not None else shared.opts.lora_without_backup_weight

for module_type in patch_module_list:
if hasattr(module_type, "org_forward"):
continue
applied = True
org_forward = module_type.forward
if module_type == torch.nn.MultiheadAttention:
module_type.forward = manual_cast_forward(torch.float32)
module_type.forward = manual_cast_forward(torch.float32, target_device, copy)
else:
module_type.forward = manual_cast_forward(target_dtype)
module_type.forward = manual_cast_forward(target_dtype, target_device, copy)
module_type.org_forward = org_forward
try:
yield None
Expand All @@ -207,26 +254,37 @@ def manual_cast(target_dtype):
delattr(module_type, "org_forward")


def autocast(disable=False):
def autocast(disable=False, current_dtype=None, target_dtype=None, target_device=None):
if disable:
return contextlib.nullcontext()

if target_dtype is None:
target_dtype = dtype
if target_device is None:
target_device = device

if force_fp16:
# No casting during inference if force_fp16 is enabled.
# All tensor dtype conversion happens before inference.
return contextlib.nullcontext()

if fp8 and device==cpu:
if fp8 and target_device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

if fp8 and dtype_inference == torch.float32:
return manual_cast(dtype)
return manual_cast(target_dtype, target_device)

if target_dtype != dtype_inference:
return manual_cast(target_dtype, target_device)

if current_dtype is not None and current_dtype != target_dtype:
return manual_cast(target_dtype, target_device)

if dtype == torch.float32 or dtype_inference == torch.float32:
if target_dtype == torch.float32 or dtype_inference == torch.float32:
return contextlib.nullcontext()

if has_xpu() or has_mps() or cuda_no_autocast():
return manual_cast(dtype)
return manual_cast(target_dtype, target_device)

return torch.autocast("cuda")

Expand Down
5 changes: 5 additions & 0 deletions modules/models/flux/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .flux import FLUX1Inferencer

__all__ = [
"FLUX1Inferencer",
]
Loading