Commit 5bf4610

Merge pull request #3 from rupeshs/sdxl-feature-dev
Sdxl feature dev
2 parents 25162f1 + 198e4e4 commit 5bf4610

18 files changed: 656 additions & 17 deletions

Readme.md

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,10 @@ DiffusionMagic focused on the following areas:
 - Cross-platform (Windows/Linux/Mac)
 - Modular design, latest best optimizations for speed and memory
 
+## Stable diffusion XL Colab
+We can run StableDiffusion XL 0.9 on Google Colab
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1KrmcU2gONIQ2WihI1s6uITgDDzkbKaJK?usp=sharing)
+
 ![ DiffusionMagic](https://raw.githubusercontent.com/rupeshs/diffusionmagic/main/docs/images/diffusion_magic.PNG)
 ## Features
 - Supports various Stable Diffusion workflows
@@ -113,6 +117,7 @@ Or we can clone the model use the local folder path as model id.
 ## Linting (Development)
 Run the following commands from src folder
 `mypy --ignore-missing-imports --explicit-package-bases .`
+
 `flake8 --max-line-length=100 .`
 ## Contribute
 Contributions are welcomed.

configs/stable_diffusion_models.txt

Lines changed: 2 additions & 1 deletion
@@ -16,4 +16,5 @@ lllyasviel/sd-controlnet-hed
 lllyasviel/sd-controlnet-openpose
 lllyasviel/sd-controlnet-depth
 lllyasviel/sd-controlnet-scribble
-lllyasviel/sd-controlnet-seg
+lllyasviel/sd-controlnet-seg
+stabilityai/stable-diffusion-xl-base-1.0

environment.yml

Lines changed: 7 additions & 6 deletions
@@ -12,16 +12,17 @@ dependencies:
   - torchvision=0.15.0
   - numpy=1.19.2
   - pip:
-    - accelerate==0.17.1
-    - diffusers==0.14.0
-    - gradio==3.17.1
-    - safetensors==0.2.8
+    - accelerate==0.21.0
+    - diffusers==0.19.3
+    - gradio==3.32.0
+    - safetensors==0.3.1
     - scipy==1.10.0
-    - transformers==4.26.0
+    - transformers==4.31.0
     - pydantic==1.10.4
     - mypy==1.0.0
     - black==23.1.0
     - flake8==6.0.0
     - markupsafe==2.0.1
     - opencv-contrib-python==4.7.0.72
-    - controlnet-aux==0.0.1
+    - controlnet-aux==0.0.1
+    - invisible-watermark==0.2.0
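
These bumps track the SDXL requirements: `StableDiffusionXLImg2ImgPipeline` (used in the new backend module below) only exists in newer `diffusers` releases, and `invisible-watermark` is an optional dependency the SDXL pipelines can use to watermark outputs. A small sanity-check sketch (the version pins simply mirror the ones in the diff above):

```python
# Sketch: confirm the packages the SDXL code path depends on are present.
from importlib.metadata import version

# This import only exists in diffusers releases that ship SDXL support,
# so it doubles as a quick compatibility check.
from diffusers import StableDiffusionXLImg2ImgPipeline  # noqa: F401

for pkg, pinned in {
    "diffusers": "0.19.3",
    "transformers": "4.31.0",
    "accelerate": "0.21.0",
    "invisible-watermark": "0.2.0",
}.items():
    print(f"{pkg}: installed {version(pkg)} (pinned {pinned})")
```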

src/backend/generate.py

Lines changed: 86 additions & 0 deletions
@@ -15,6 +15,7 @@
 )
 from backend.controlnet.ControlContext import ControlnetContext
 from backend.stablediffusion.stablediffusion import StableDiffusion
+from backend.stablediffusion.stablediffusionxl import StableDiffusionXl
 from settings import AppSettings
 
 
@@ -30,6 +31,7 @@ def __init__(self, compute: Computing):
         self.stable_diffusion_depth = StableDiffusionDepthToImage(compute)
         self.stable_diffusion_pix_to_pix = StableDiffusionInstructPixToPix(compute)
         self.controlnet = ControlnetContext(compute)
+        self.stable_diffusion_xl = StableDiffusionXl(compute)
         self.app_settings = AppSettings().get_settings()
         self.model_id = self.app_settings.model_settings.model_id
         self.low_vram_mode = self.app_settings.low_memory_mode
@@ -78,6 +80,15 @@ def _init_stable_diffusion(self):
             )
             self.pipe_initialized = True
 
+    def _init_stable_diffusion_xl(self):
+        if not self.pipe_initialized:
+            print("Initializing stable diffusion xl pipeline")
+            self.stable_diffusion_xl.get_text_to_image_xl_pipleline(
+                self.model_id,
+                self.low_vram_mode,
+            )
+            self.pipe_initialized = True
+
     def diffusion_image_to_image(
         self,
         image,
@@ -355,3 +366,78 @@ def diffusion_control_to_image(
             "CannyToImage",
         )
         return images
+
+    def diffusion_text_to_image_xl(
+        self,
+        prompt,
+        neg_prompt,
+        image_height,
+        image_width,
+        inference_steps,
+        scheduler,
+        guidance_scale,
+        num_images,
+        attention_slicing,
+        vae_slicing,
+        seed,
+    ) -> Any:
+        stable_diffusion_settings = StableDiffusionSetting(
+            prompt=prompt,
+            negative_prompt=neg_prompt,
+            image_height=image_height,
+            image_width=image_width,
+            inference_steps=inference_steps,
+            guidance_scale=guidance_scale,
+            number_of_images=num_images,
+            scheduler=scheduler,
+            seed=seed,
+            attention_slicing=attention_slicing,
+            vae_slicing=vae_slicing,
+        )
+        self._init_stable_diffusion_xl()
+        images = self.stable_diffusion_xl.text_to_image_xl(stable_diffusion_settings)
+        self._save_images(
+            images,
+            "TextToImage",
+        )
+        return images
+
+    def diffusion_image_to_image_xl(
+        self,
+        image,
+        strength,
+        prompt,
+        neg_prompt,
+        image_height,
+        image_width,
+        inference_steps,
+        scheduler,
+        guidance_scale,
+        num_images,
+        attention_slicing,
+        seed,
+    ) -> Any:
+        stable_diffusion_image_settings = StableDiffusionImageToImageSetting(
+            image=image,
+            strength=strength,
+            prompt=prompt,
+            negative_prompt=neg_prompt,
+            image_height=image_height,
+            image_width=image_width,
+            inference_steps=inference_steps,
+            guidance_scale=guidance_scale,
+            number_of_images=num_images,
+            scheduler=scheduler,
+            seed=seed,
+            attention_slicing=attention_slicing,
+        )
+        self._init_stable_diffusion_xl()
+        images = self.stable_diffusion_xl.image_to_image(
+            stable_diffusion_image_settings
+        )
+
+        self._save_images(
+            images,
+            "ImageToImage",
+        )
+        return images
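
As a rough usage sketch of the new text-to-image entry point: the enclosing class name (assumed here to be `Generate`), its import path, and the exact scheduler string are assumptions, since the diff shows the new methods but not how the class is named or constructed elsewhere in the app.

```python
# Sketch only: "Generate", its import path, and the scheduler value are assumptions.
from backend.computing import Computing
from backend.generate import Generate

context = Generate(Computing())
images = context.diffusion_text_to_image_xl(
    prompt="an astronaut riding a horse, photorealistic",
    neg_prompt="blurry, low quality",
    image_height=1024,
    image_width=1024,
    inference_steps=25,
    scheduler="DPMSolverMultistepScheduler",  # assumed to match a SchedulerType value
    guidance_scale=7.5,
    num_images=1,
    attention_slicing=False,
    vae_slicing=False,
    seed=-1,
)
images[0].save("sdxl_text_to_image.png")
```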

src/backend/stablediffusion/stable_diffusion_types.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@ class StableDiffusionType(str, Enum):
     controlnet_depth = "controlnet_depth"
     controlnet_scribble = "controlnet_scribble"
     controlnet_seg = "controlnet_seg"
+    stable_diffusion_xl = "StableDiffusionXl"
 
 
 def get_diffusion_type(
@@ -44,4 +45,6 @@ def get_diffusion_type(
         stable_diffusion_type = StableDiffusionType.controlnet_scribble
     elif "controlnet-seg" in model_id:
         stable_diffusion_type = StableDiffusionType.controlnet_seg
+    elif "stable-diffusion-xl" in model_id:
+        stable_diffusion_type = StableDiffusionType.stable_diffusion_xl
     return stable_diffusion_type
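
The new branch routes any model id containing "stable-diffusion-xl" to the new enum member, which is presumably what lets the rest of the app select the SDXL workflow. A quick illustration (the single `model_id` argument is assumed from the call sites, since the full signature is cut off in this hunk):

```python
from backend.stablediffusion.stable_diffusion_types import (
    StableDiffusionType,
    get_diffusion_type,
)

# Any repo id containing "stable-diffusion-xl" now maps to the SDXL type.
diffusion_type = get_diffusion_type("stabilityai/stable-diffusion-xl-base-1.0")
assert diffusion_type == StableDiffusionType.stable_diffusion_xl
```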

src/backend/stablediffusion/stablediffusionxl.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+from time import time
+
+import torch
+from diffusers import (
+    DiffusionPipeline,
+    StableDiffusionXLImg2ImgPipeline,
+)
+from PIL import Image
+
+from backend.computing import Computing
+from backend.stablediffusion.modelmeta import ModelMeta
+from backend.stablediffusion.models.scheduler_types import SchedulerType
+from backend.stablediffusion.models.setting import (
+    StableDiffusionImageToImageSetting,
+    StableDiffusionSetting,
+)
+from backend.stablediffusion.scheduler_mixin import SamplerMixin
+
+
+class StableDiffusionXl(SamplerMixin):
+    def __init__(self, compute: Computing):
+        self.compute = compute
+        self.pipeline = None
+        self.device = self.compute.name
+
+        super().__init__()
+
+    def get_text_to_image_xl_pipleline(
+        self,
+        model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
+        low_vram_mode: bool = False,
+        sampler: str = SchedulerType.DPMSolverMultistepScheduler.value,
+    ):
+        repo_id = model_id
+        model_meta = ModelMeta(repo_id)
+        is_lora_model = model_meta.is_loramodel()
+        if is_lora_model:
+            print("LoRA model detected")
+            self.model_id = model_meta.get_lora_base_model()
+            print(f"LoRA base model - {self.model_id}")
+        else:
+            self.model_id = model_id
+
+        self.low_vram_mode = low_vram_mode
+        print(f"StableDiffusion - {self.compute.name},{self.compute.datatype}")
+        print(f"using model {model_id}")
+        self.default_sampler = self.find_sampler(
+            sampler,
+            self.model_id,
+        )
+        tic = time()
+        self._load_model()
+        delta = time() - tic
+        print(f"Model loaded in {delta:.2f}s ")
+
+        if self.pipeline is None:
+            raise Exception("Text to image pipeline not initialized")
+        if is_lora_model:
+            self.pipeline.unet.load_attn_procs(repo_id)
+        self._pipeline_to_device()
+        components = self.pipeline.components
+        self.img_to_img_pipeline = StableDiffusionXLImg2ImgPipeline(**components)
+
+    def text_to_image_xl(self, setting: StableDiffusionSetting):
+        if self.pipeline is None:
+            raise Exception("Text to image pipeline not initialized")
+
+        self.pipeline.scheduler = self.find_sampler(
+            setting.scheduler,
+            self.model_id,
+        )
+        generator = None
+        if setting.seed != -1:
+            print(f"Using seed {setting.seed}")
+            generator = torch.Generator(self.device).manual_seed(setting.seed)
+
+        # if setting.attention_slicing:
+        #     self.pipeline.enable_attention_slicing()
+        # else:
+        #     self.pipeline.disable_attention_slicing()
+
+        if setting.vae_slicing:
+            self.pipeline.enable_vae_slicing()
+        else:
+            self.pipeline.disable_vae_slicing()
+
+        images = self.pipeline(
+            setting.prompt,
+            guidance_scale=setting.guidance_scale,
+            num_inference_steps=setting.inference_steps,
+            height=setting.image_height,
+            width=setting.image_width,
+            negative_prompt=setting.negative_prompt,
+            num_images_per_prompt=setting.number_of_images,
+            generator=generator,
+        ).images
+
+        # self.pipeline.unet = torch.compile(
+        #     self.pipeline.unet,
+        #     mode="reduce-overhead",
+        #     fullgraph=True,
+        # )
+        return images
+
+    def _pipeline_to_device(self):
+        if self.low_vram_mode:
+            print("Running in low VRAM mode,slower to generate images")
+            self.pipeline.enable_sequential_cpu_offload()
+        else:
+            if self.compute.name == "cuda":
+                self.pipeline = self.pipeline.to("cuda")
+            elif self.compute.name == "mps":
+                self.pipeline = self.pipeline.to("mps")
+
+    def _load_full_precision_model(self):
+        self.pipeline = DiffusionPipeline.from_pretrained(
+            self.model_id,
+            torch_dtype=self.compute.datatype,
+            scheduler=self.default_sampler,
+        )
+
+    def _load_model(self):
+        if self.compute.name == "cuda":
+            try:
+                self.pipeline = DiffusionPipeline.from_pretrained(
+                    self.model_id,
+                    torch_dtype=self.compute.datatype,
+                    scheduler=self.default_sampler,
+                    use_safetensors=True,
+                    variant="fp16",
+                )
+            except Exception as ex:
+                print(
+                    f" The fp16 of the model not found using full precision model, {ex}"
+                )
+                self._load_full_precision_model()
+        else:
+            self._load_full_precision_model()
+
+    def image_to_image(self, setting: StableDiffusionImageToImageSetting):
+        if setting.scheduler is None:
+            raise Exception("Scheduler cannot be empty")
+
+        print("Running image to image pipeline")
+        self.img_to_img_pipeline.scheduler = self.find_sampler(  # type: ignore
+            setting.scheduler,
+            self.model_id,
+        )
+        generator = None
+        if setting.seed != -1 and setting.seed:
+            print(f"Using seed {setting.seed}")
+            generator = torch.Generator(self.device).manual_seed(setting.seed)
+
+        if setting.attention_slicing:
+            self.img_to_img_pipeline.enable_attention_slicing()  # type: ignore
+        else:
+            self.img_to_img_pipeline.disable_attention_slicing()  # type: ignore
+
+        if setting.vae_slicing:
+            self.pipeline.enable_vae_slicing()  # type: ignore
+        else:
+            self.pipeline.disable_vae_slicing()  # type: ignore
+
+        init_image = setting.image.resize(
+            (
+                setting.image_width,
+                setting.image_height,
+            ),
+            Image.Resampling.LANCZOS,
+        )
+        images = self.img_to_img_pipeline(  # type: ignore
+            image=init_image,
+            strength=setting.strength,
+            prompt=setting.prompt,
+            guidance_scale=setting.guidance_scale,
+            num_inference_steps=setting.inference_steps,
+            negative_prompt=setting.negative_prompt,
+            num_images_per_prompt=setting.number_of_images,
+            generator=generator,
+        ).images
+        return images
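
For readers who just want the gist of what `_load_model` and `text_to_image_xl` do, here is a stripped-down, self-contained sketch using only public diffusers APIs. The model id and parameters mirror the defaults above; the app's scheduler handling, LoRA support, and low-VRAM offloading are omitted.

```python
import torch
from diffusers import DiffusionPipeline

# Load SDXL base with fp16 safetensors weights when a CUDA GPU is available,
# falling back to full precision otherwise - the same idea as _load_model above.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
if torch.cuda.is_available():
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")
else:
    pipe = DiffusionPipeline.from_pretrained(model_id)

# Rough equivalent of text_to_image_xl: prompt in, PIL images out.
images = pipe(
    "a scenic mountain landscape, highly detailed",
    negative_prompt="blurry, low quality",
    num_inference_steps=25,
    guidance_scale=7.5,
    height=1024,
    width=1024,
).images
images[0].save("sdxl.png")
```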

src/constants.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-VERSION = "2.0.0-beta.0"
+VERSION = "3.0.0"
 STABLE_DIFFUSION_MODELS_FILE = "stable_diffusion_models.txt"
 APP_SETTINGS_FILE = "settings.yaml"
 CONFIG_DIRECTORY = "configs"

src/frontend/web/depth_to_image_ui.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def random_seed():
             show_label=True,
             elem_id="gallery",
         ).style(
-            grid=2,
+            columns=2,
         )
         generate_btn.click(
             fn=generate_callback_fn,

src/frontend/web/image_inpainting_ui.py

Lines changed: 2 additions & 2 deletions
@@ -69,7 +69,7 @@ def random_seed():
             label="Number of images to generate",
         )
         attn_slicing = gr.Checkbox(
-            label="Attention slicing (Enable if low VRAM)",
+            label="Attention slicing (Not used)",
            value=True,
         )
         seed = gr.Number(
@@ -105,7 +105,7 @@ def random_seed():
             show_label=True,
             elem_id="gallery",
         ).style(
-            grid=2,
+            columns=2,
         )
         generate_btn.click(
             fn=generate_callback_fn,
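
Both UI files make the same gallery change: newer Gradio releases (such as the 3.32.0 pinned above) take a `columns` argument in `Gallery.style()` where older releases used `grid`. A minimal standalone sketch of the pattern (not the actual DiffusionMagic layout):

```python
import gradio as gr

# Minimal sketch of the changed pattern: Gallery.style() now takes
# "columns" where older gradio releases used "grid".
with gr.Blocks() as demo:
    gallery = gr.Gallery(
        label="Generated images",
        show_label=True,
        elem_id="gallery",
    ).style(
        columns=2,
    )

if __name__ == "__main__":
    demo.launch()
```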
