Skip to content

Commit 25162f1

Browse files
committed
Added controlnet seg support,added schedulers - KDPM2DiscreteScheduler,HeunDiscreteScheduler,KDPM2AncestralDiscreteScheduler,DPMSolverSinglestepScheduler
1 parent 070851a commit 25162f1

10 files changed

Lines changed: 252 additions & 31 deletions

File tree

Readme.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ DiffusionMagic focused on the following areas:
2222
- Pose
2323
- Depth
2424
- Scribble
25+
- Segmentation
2526
- Pytorch 2.0 support
2627
- Supports all stable diffusion Hugging Face models
2728
- Supports Stable diffusion v1 and v2 models, derived models
@@ -109,6 +110,10 @@ E.g `https://huggingface.co/dreamlike-art/dreamlike-diffusion-1.0`
109110
Here model id is `dreamlike-art/dreamlike-diffusion-1.0`
110111
Or we can clone the model use the local folder path as model id.
111112
- Adding locally copied model path to configs/stable_diffusion_models.txt file
113+
## Linting (Development)
114+
Run the following commands from src folder
115+
`mypy --ignore-missing-imports --explicit-package-bases .`
116+
`flake8 --max-line-length=100 .`
112117
## Contribute
113118
Contributions are welcomed.
114119

configs/stable_diffusion_models.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ lllyasviel/sd-controlnet-normal
1515
lllyasviel/sd-controlnet-hed
1616
lllyasviel/sd-controlnet-openpose
1717
lllyasviel/sd-controlnet-depth
18-
lllyasviel/sd-controlnet-scribble
18+
lllyasviel/sd-controlnet-scribble
19+
lllyasviel/sd-controlnet-seg

src/backend/controlnet/ControlContext.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
from time import time
2-
from typing import Any
3-
import numpy as np
4-
from cv2 import Canny, bitwise_not
2+
53
import torch
6-
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
7-
from PIL import Image, ImageOps
4+
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
5+
from PIL import ImageOps
86

97
from backend.computing import Computing
8+
from backend.controlnet.controls.image_control_factory import ImageControlFactory
9+
from backend.image_ops import resize_pil_image
1010
from backend.stablediffusion.models.scheduler_types import SchedulerType
11-
from backend.stablediffusion.models.setting import (
12-
StableDiffusionControlnetSetting,
13-
)
11+
from backend.stablediffusion.models.setting import StableDiffusionControlnetSetting
1412
from backend.stablediffusion.scheduler_mixin import SamplerMixin
15-
from backend.image_ops import resize_pil_image
1613
from backend.stablediffusion.stable_diffusion_types import (
17-
get_diffusion_type,
1814
StableDiffusionType,
15+
get_diffusion_type,
1916
)
20-
from backend.controlnet.controls.image_control_factory import ImageControlFactory
2117

2218

2319
class ControlnetContext(SamplerMixin):
@@ -134,16 +130,6 @@ def _load_model(self):
134130
else:
135131
self._load_full_precision_model()
136132

137-
def get_canny_image(self, image: Image) -> Any:
138-
low_threshold = 100
139-
high_threshold = 200
140-
image = np.array(image)
141-
image = Canny(image, low_threshold, high_threshold)
142-
image_inv = bitwise_not(image)
143-
image = image[:, :, None]
144-
image = np.concatenate([image, image, image], axis=2)
145-
return Image.fromarray(image), Image.fromarray(image_inv)
146-
147133
def _enable_slicing(self, setting: StableDiffusionControlnetSetting):
148134
if setting.attention_slicing:
149135
self.controlnet_pipeline.enable_attention_slicing()

src/backend/controlnet/controls/image_control_factory.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
from backend.stablediffusion.stable_diffusion_types import (
2-
StableDiffusionType,
3-
)
41
from backend.controlnet.controls.canny_control import CannyControl
2+
from backend.controlnet.controls.depth_control import DepthControl
3+
from backend.controlnet.controls.hed_control import HedControl
54
from backend.controlnet.controls.line_control import LineControl
65
from backend.controlnet.controls.normal_control import NormalControl
7-
from backend.controlnet.controls.hed_control import HedControl
86
from backend.controlnet.controls.pose_control import PoseControl
9-
from backend.controlnet.controls.depth_control import DepthControl
107
from backend.controlnet.controls.scribble_control import ScribbleControl
8+
from backend.controlnet.controls.seg_control import SegControl
9+
from backend.stablediffusion.stable_diffusion_types import StableDiffusionType
1110

1211

1312
class ImageControlFactory:
@@ -26,6 +25,8 @@ def create_control(self, controlnet_type: StableDiffusionType):
2625
return DepthControl()
2726
elif controlnet_type == StableDiffusionType.controlnet_scribble:
2827
return ScribbleControl()
28+
elif controlnet_type == StableDiffusionType.controlnet_seg:
29+
return SegControl()
2930
else:
3031
print("Error: Control type not implemented!")
3132
raise Exception("Error: Control type not implemented!")
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import numpy as np
2+
import torch
3+
from PIL import Image
4+
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
5+
6+
from backend.controlnet.controls.control_interface import ControlInterface
7+
8+
palette = np.asarray(
9+
[
10+
[0, 0, 0],
11+
[120, 120, 120],
12+
[180, 120, 120],
13+
[6, 230, 230],
14+
[80, 50, 50],
15+
[4, 200, 3],
16+
[120, 120, 80],
17+
[140, 140, 140],
18+
[204, 5, 255],
19+
[230, 230, 230],
20+
[4, 250, 7],
21+
[224, 5, 255],
22+
[235, 255, 7],
23+
[150, 5, 61],
24+
[120, 120, 70],
25+
[8, 255, 51],
26+
[255, 6, 82],
27+
[143, 255, 140],
28+
[204, 255, 4],
29+
[255, 51, 7],
30+
[204, 70, 3],
31+
[0, 102, 200],
32+
[61, 230, 250],
33+
[255, 6, 51],
34+
[11, 102, 255],
35+
[255, 7, 71],
36+
[255, 9, 224],
37+
[9, 7, 230],
38+
[220, 220, 220],
39+
[255, 9, 92],
40+
[112, 9, 255],
41+
[8, 255, 214],
42+
[7, 255, 224],
43+
[255, 184, 6],
44+
[10, 255, 71],
45+
[255, 41, 10],
46+
[7, 255, 255],
47+
[224, 255, 8],
48+
[102, 8, 255],
49+
[255, 61, 6],
50+
[255, 194, 7],
51+
[255, 122, 8],
52+
[0, 255, 20],
53+
[255, 8, 41],
54+
[255, 5, 153],
55+
[6, 51, 255],
56+
[235, 12, 255],
57+
[160, 150, 20],
58+
[0, 163, 255],
59+
[140, 140, 140],
60+
[250, 10, 15],
61+
[20, 255, 0],
62+
[31, 255, 0],
63+
[255, 31, 0],
64+
[255, 224, 0],
65+
[153, 255, 0],
66+
[0, 0, 255],
67+
[255, 71, 0],
68+
[0, 235, 255],
69+
[0, 173, 255],
70+
[31, 0, 255],
71+
[11, 200, 200],
72+
[255, 82, 0],
73+
[0, 255, 245],
74+
[0, 61, 255],
75+
[0, 255, 112],
76+
[0, 255, 133],
77+
[255, 0, 0],
78+
[255, 163, 0],
79+
[255, 102, 0],
80+
[194, 255, 0],
81+
[0, 143, 255],
82+
[51, 255, 0],
83+
[0, 82, 255],
84+
[0, 255, 41],
85+
[0, 255, 173],
86+
[10, 0, 255],
87+
[173, 255, 0],
88+
[0, 255, 153],
89+
[255, 92, 0],
90+
[255, 0, 255],
91+
[255, 0, 245],
92+
[255, 0, 102],
93+
[255, 173, 0],
94+
[255, 0, 20],
95+
[255, 184, 184],
96+
[0, 31, 255],
97+
[0, 255, 61],
98+
[0, 71, 255],
99+
[255, 0, 204],
100+
[0, 255, 194],
101+
[0, 255, 82],
102+
[0, 10, 255],
103+
[0, 112, 255],
104+
[51, 0, 255],
105+
[0, 194, 255],
106+
[0, 122, 255],
107+
[0, 255, 163],
108+
[255, 153, 0],
109+
[0, 255, 10],
110+
[255, 112, 0],
111+
[143, 255, 0],
112+
[82, 0, 255],
113+
[163, 255, 0],
114+
[255, 235, 0],
115+
[8, 184, 170],
116+
[133, 0, 255],
117+
[0, 255, 92],
118+
[184, 0, 255],
119+
[255, 0, 31],
120+
[0, 184, 255],
121+
[0, 214, 255],
122+
[255, 0, 112],
123+
[92, 255, 0],
124+
[0, 224, 255],
125+
[112, 224, 255],
126+
[70, 184, 160],
127+
[163, 0, 255],
128+
[153, 0, 255],
129+
[71, 255, 0],
130+
[255, 0, 163],
131+
[255, 204, 0],
132+
[255, 0, 143],
133+
[0, 255, 235],
134+
[133, 255, 0],
135+
[255, 0, 235],
136+
[245, 0, 255],
137+
[255, 0, 122],
138+
[255, 245, 0],
139+
[10, 190, 212],
140+
[214, 255, 0],
141+
[0, 204, 255],
142+
[20, 0, 255],
143+
[255, 255, 0],
144+
[0, 153, 255],
145+
[0, 41, 255],
146+
[0, 255, 204],
147+
[41, 0, 255],
148+
[41, 255, 0],
149+
[173, 0, 255],
150+
[0, 245, 255],
151+
[71, 0, 255],
152+
[122, 0, 255],
153+
[0, 255, 184],
154+
[0, 92, 255],
155+
[184, 255, 0],
156+
[0, 133, 255],
157+
[255, 214, 0],
158+
[25, 194, 194],
159+
[102, 255, 0],
160+
[92, 0, 255],
161+
]
162+
)
163+
164+
165+
class SegControl(ControlInterface):
166+
def get_control_image(self, image: Image) -> Image:
167+
image_processor = AutoImageProcessor.from_pretrained(
168+
"openmmlab/upernet-convnext-small"
169+
)
170+
image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
171+
"openmmlab/upernet-convnext-small"
172+
)
173+
pixel_values = image_processor(image, return_tensors="pt").pixel_values
174+
with torch.no_grad():
175+
outputs = image_segmentor(pixel_values)
176+
seg = image_processor.post_process_semantic_segmentation(
177+
outputs, target_sizes=[image.size[::-1]]
178+
)[0]
179+
180+
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
181+
182+
for label, color in enumerate(palette):
183+
color_seg[seg == label, :] = color
184+
color_seg = color_seg.astype(np.uint8)
185+
image = Image.fromarray(color_seg)
186+
return image

src/backend/stablediffusion/models/scheduler_types.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
EulerDiscreteScheduler,
1010
LMSDiscreteScheduler,
1111
PNDMScheduler,
12-
UniPCMultistepScheduler
12+
UniPCMultistepScheduler,
13+
KDPM2DiscreteScheduler,
14+
HeunDiscreteScheduler,
15+
KDPM2AncestralDiscreteScheduler,
16+
DPMSolverSinglestepScheduler,
1317
)
1418

1519

@@ -25,6 +29,10 @@ class SchedulerType(Enum):
2529
PNDMScheduler = "PNDM"
2630
DEISScheduler = "DEISMultistep"
2731
UniPCMultistepScheduler = "UniPCMultistep"
32+
KDPM2DiscreteScheduler = "KDPM2DiscreteScheduler"
33+
HeunDiscreteScheduler = "HeunDiscreteScheduler"
34+
KDPM2AncestralDiscreteScheduler = "KDPM2AncestralDiscreteScheduler"
35+
DPMSolverSinglestepScheduler = "DPMSolverSinglestepScheduler"
2836

2937

3038
Scheduler = Union[
@@ -36,6 +44,10 @@ class SchedulerType(Enum):
3644
LMSDiscreteScheduler,
3745
PNDMScheduler,
3846
UniPCMultistepScheduler,
47+
KDPM2DiscreteScheduler,
48+
HeunDiscreteScheduler,
49+
KDPM2AncestralDiscreteScheduler,
50+
DPMSolverSinglestepScheduler,
3951
]
4052

4153

src/backend/stablediffusion/scheduler_factory.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
LMSDiscreteScheduler,
88
PNDMScheduler,
99
UniPCMultistepScheduler,
10+
KDPM2DiscreteScheduler,
11+
HeunDiscreteScheduler,
12+
KDPM2AncestralDiscreteScheduler,
13+
DPMSolverSinglestepScheduler,
1014
)
1115

1216
from backend.stablediffusion.models.scheduler_types import SchedulerType, Scheduler
@@ -58,6 +62,26 @@ def get_scheduler(
5862
repo_id,
5963
subfolder="scheduler",
6064
)
65+
elif scheduler_type == SchedulerType.KDPM2DiscreteScheduler.value:
66+
return KDPM2DiscreteScheduler.from_pretrained(
67+
repo_id,
68+
subfolder="scheduler",
69+
)
70+
elif scheduler_type == SchedulerType.HeunDiscreteScheduler.value:
71+
return HeunDiscreteScheduler.from_pretrained(
72+
repo_id,
73+
subfolder="scheduler",
74+
)
75+
elif scheduler_type == SchedulerType.KDPM2AncestralDiscreteScheduler.value:
76+
return KDPM2AncestralDiscreteScheduler.from_pretrained(
77+
repo_id,
78+
subfolder="scheduler",
79+
)
80+
elif scheduler_type == SchedulerType.DPMSolverSinglestepScheduler.value:
81+
return DPMSolverSinglestepScheduler.from_pretrained(
82+
repo_id,
83+
subfolder="scheduler",
84+
)
6185
else:
6286
print(f"Scheduler {scheduler_type} not found")
6387
return None

src/backend/stablediffusion/stable_diffusion_types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class StableDiffusionType(str, Enum):
1515
controlnet_pose = "controlnet_pose"
1616
controlnet_depth = "controlnet_depth"
1717
controlnet_scribble = "controlnet_scribble"
18+
controlnet_seg = "controlnet_seg"
1819

1920

2021
def get_diffusion_type(
@@ -39,6 +40,8 @@ def get_diffusion_type(
3940
stable_diffusion_type = StableDiffusionType.controlnet_depth
4041
elif "depth" in model_id:
4142
stable_diffusion_type = StableDiffusionType.depth2img
42-
elif "scribble" in model_id:
43+
elif "controlnet-scribble" in model_id:
4344
stable_diffusion_type = StableDiffusionType.controlnet_scribble
45+
elif "controlnet-seg" in model_id:
46+
stable_diffusion_type = StableDiffusionType.controlnet_seg
4447
return stable_diffusion_type

src/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
VERSION = "2.0.0-alpha.0"
1+
VERSION = "2.0.0-beta.0"
22
STABLE_DIFFUSION_MODELS_FILE = "stable_diffusion_models.txt"
33
APP_SETTINGS_FILE = "settings.yaml"
44
CONFIG_DIRECTORY = "configs"

0 commit comments

Comments
 (0)