Skip to content

Commit f70d21e

Browse files
committed
Use multiple-of-16 resolutions for all models #2436
* This was already the case for Flux/SD3. * Other models used 8, which was their latent downscale factor, but they benefit from a multiple of 16 to avoid border artifacts.
1 parent d69b7e7 commit f70d21e

6 files changed

Lines changed: 43 additions & 46 deletions

File tree

ai_diffusion/model.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from PyQt5.QtCore import QMetaObject, QObject, Qt, QUuid, pyqtSignal
1717
from PyQt5.QtGui import QBrush, QColor, QPainter
1818

19-
from . import eventloop, util, workflow
19+
from . import eventloop, resolution, util, workflow
2020
from .api import (
2121
ConditioningInput,
2222
ControlInput,
@@ -1559,14 +1559,11 @@ def get_selection_modifiers(
15591559
feather = min(feather, 0.01)
15601560
invert = True
15611561

1562-
if isinstance(arch, InpaintContext):
1563-
if arch is InpaintContext.mask_bounds:
1564-
min_size = 0
1565-
multiple = 1
1566-
else:
1567-
multiple = 8
1562+
if arch is InpaintContext.mask_bounds:
1563+
min_size = 0
1564+
multiple = 1
15681565
else:
1569-
multiple = arch.latent_compression_factor
1566+
multiple = resolution.diffusion_multiple
15701567

15711568
return SelectionModifiers(
15721569
feather_rel=feather * strength,

ai_diffusion/resolution.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,11 @@ def target_scaling(self):
133133
return ScaleMode.resize
134134

135135

136+
# Image resolution for diffusion should be divisible by this factor, either because it is the
137+
# required latent compression factor, or to avoid border artifacts with UNET models.
138+
diffusion_multiple = 16
139+
140+
136141
class CheckpointResolution(NamedTuple):
137142
"""Preferred resolution for a SD checkpoint, typically the resolution it was trained on."""
138143

@@ -156,7 +161,7 @@ def compute(extent: Extent, arch: Arch, style: Style | None = None, inpaint=Fals
156161
default = (640, 1280, 512**2, 1024**2)
157162
min_size, max_size, min_pixel_count, max_pixel_count = res.get(arch, default)
158163
else:
159-
range_offset = multiple_of(round(0.2 * style.preferred_resolution), 8)
164+
range_offset = multiple_of(round(0.2 * style.preferred_resolution), diffusion_multiple)
160165
min_size = style.preferred_resolution - range_offset
161166
max_size = style.preferred_resolution + range_offset
162167
min_pixel_count = max_pixel_count = style.preferred_resolution**2
@@ -186,7 +191,6 @@ def prepare_diffusion_input(
186191
desired = apply_resolution_settings(extent, perf)
187192

188193
# The checkpoint may require a different resolution than what is requested.
189-
mult = arch.latent_compression_factor
190194
if arch.is_edit:
191195
downscale = False # Never use 2-pass generation for edit models
192196

@@ -197,8 +201,8 @@ def prepare_diffusion_input(
197201
if downscale and max_scale < 0.9 and any(x > max_size for x in desired):
198202
# Desired resolution is significantly larger than the maximum size. Do 2 passes:
199203
# first pass at checkpoint resolution, then upscale to desired resolution and refine.
200-
input = initial = (desired * max_scale).multiple_of(mult)
201-
desired = desired.multiple_of(mult)
204+
input = initial = (desired * max_scale).multiple_of(diffusion_multiple)
205+
desired = desired.multiple_of(diffusion_multiple)
202206
# Input images are scaled down here for the initial pass directly to avoid encoding
203207
# and processing large images in subsequent steps.
204208
image = Image.scale(image, initial) if image else None
@@ -209,13 +213,13 @@ def prepare_diffusion_input(
209213
scaled = desired * min_scale
210214
# Avoid unnecessary scaling if too small resolution is caused by resolution multiplier
211215
if all(x >= min_size and x <= max_size for x in extent):
212-
initial = desired = extent.multiple_of(mult)
216+
initial = desired = extent.multiple_of(diffusion_multiple)
213217
else:
214-
initial = desired = scaled.multiple_of(mult)
218+
initial = desired = scaled.multiple_of(diffusion_multiple)
215219

216220
else: # Desired resolution is in acceptable range. Do 1 pass at desired resolution.
217221
input = extent
218-
initial = desired = desired.multiple_of(mult)
222+
initial = desired = desired.multiple_of(diffusion_multiple)
219223

220224
# Scale down input images if needed due to resolution_multiplier or max_pixel_count
221225
if extent.pixel_count > desired.pixel_count:

ai_diffusion/resources.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,6 @@ def text_encoders(self):
241241
return ["qwen_3_4b"]
242242
raise ValueError(f"Unsupported architecture: {self}")
243243

244-
@property
245-
def latent_compression_factor(self):
246-
return 16 if self.is_flux2 or self is Arch.sd3 else 8
247-
248244
@staticmethod
249245
def list():
250246
return [

ai_diffusion/workflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,7 +1327,7 @@ def upscale_tiled(
13271327
models: ModelDict,
13281328
):
13291329
upscale_factor = extent.initial.width / extent.input.width
1330-
multiple = models.arch.latent_compression_factor
1330+
multiple = resolution.diffusion_multiple
13311331
if upscale.tile_overlap >= 0:
13321332
layout = TileLayout(extent.initial, extent.desired.width, upscale.tile_overlap, multiple)
13331333
else:
@@ -1714,9 +1714,9 @@ def prepare(
17141714
else:
17151715
tile_size = 1024
17161716
tile_size = max(tile_size, target_extent.longest_side // 12) # max 12x12 tiles total
1717-
tile_size = multiple_of(tile_size - 128, arch.latent_compression_factor)
1717+
tile_size = multiple_of(tile_size - 128, resolution.diffusion_multiple)
17181718
tile_size = Extent(tile_size, tile_size)
1719-
initial_extent = target_extent.multiple_of(arch.latent_compression_factor)
1719+
initial_extent = target_extent.multiple_of(resolution.diffusion_multiple)
17201720
extent = ExtentInput(canvas.extent, initial_extent, tile_size, target_extent)
17211721
i.images = ImageInput(extent, canvas)
17221722
assert upscale is not None

tests/test_resolution.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ def test_inpaint_context(area, expected_extent, expected_crop: tuple[int, int] |
126126
@pytest.mark.parametrize(
127127
"input,expected_initial,expected_desired",
128128
[
129-
(Extent(1536, 600), Extent(1008, 392), Extent(1536, 600)),
129+
(Extent(1536, 600), Extent(1008, 400), Extent(1536, 608)),
130130
(Extent(400, 1024), Extent(400, 1024), Extent(400, 1024)),
131-
(Extent(777, 999), Extent(560, 712), Extent(784, 1000)),
131+
(Extent(777, 999), Extent(560, 720), Extent(784, 1008)),
132132
],
133133
)
134134
def test_prepare_highres(input, expected_initial, expected_desired):
@@ -144,20 +144,20 @@ def test_prepare_highres(input, expected_initial, expected_desired):
144144
)
145145

146146

147-
def test_prepare_hightres_inpaint():
148-
input = Extent(3000, 2000)
147+
def test_prepare_highres_inpaint():
148+
input = Extent(3008, 2000)
149149
image = Image.create(input)
150150
r, _ = resolution.prepare_image(image, Arch.flux, dummy_style, perf, inpaint=True)
151-
assert r.extent.initial == Extent(1256, 840)
151+
assert r.extent.initial == Extent(1264, 848)
152152
assert r.extent.desired == input
153153

154154

155155
@pytest.mark.parametrize(
156156
"input,expected",
157157
[
158158
(Extent(256, 256), Extent(512, 512)),
159-
(Extent(128, 450), Extent(280, 960)),
160-
(Extent(256, 333), Extent(456, 584)), # multiple of 8
159+
(Extent(128, 450), Extent(288, 960)),
160+
(Extent(256, 333), Extent(464, 592)), # multiple of 16
161161
],
162162
)
163163
def test_prepare_lowres(input: Extent, expected: Extent):
@@ -174,7 +174,7 @@ def test_prepare_lowres(input: Extent, expected: Extent):
174174

175175
@pytest.mark.parametrize(
176176
"input",
177-
[Extent(512, 512), Extent(128, 600), Extent(768, 240)],
177+
[Extent(512, 512), Extent(128, 608), Extent(768, 240)],
178178
)
179179
def test_prepare_passthrough(input: Extent):
180180
image = Image.create(input)
@@ -190,23 +190,23 @@ def test_prepare_passthrough(input: Extent):
190190

191191

192192
@pytest.mark.parametrize(
193-
"input,expected", [(Extent(512, 513), Extent(512, 520)), (Extent(300, 1024), Extent(304, 1024))]
193+
"input,expected", [(Extent(512, 513), Extent(512, 528)), (Extent(300, 1024), Extent(304, 1024))]
194194
)
195-
def test_prepare_multiple8(input: Extent, expected: Extent):
195+
def test_prepare_multiple16(input: Extent, expected: Extent):
196196
r, _ = resolution.prepare_extent(input, Arch.sd15, dummy_style, perf)
197197
assert (
198198
r.extent.input == input
199199
and r.extent.initial == expected
200200
and r.extent.target == input
201-
and r.extent.desired == input.multiple_of(8)
201+
and r.extent.desired == input.multiple_of(16)
202202
)
203203

204204

205205
@pytest.mark.parametrize("sdver", [Arch.sd15, Arch.sdxl])
206206
def test_prepare_extent(sdver: Arch):
207207
input = Extent(1024, 1536)
208208
r, _ = resolution.prepare_extent(input, sdver, dummy_style, perf)
209-
expected = Extent(512, 768) if sdver == Arch.sd15 else Extent(840, 1256)
209+
expected = Extent(512, 768) if sdver == Arch.sd15 else Extent(848, 1264)
210210
assert r.extent.initial == expected and r.extent.desired == input and r.extent.target == input
211211

212212

@@ -228,20 +228,20 @@ def test_prepare_no_downscale(input: Extent):
228228
assert (
229229
r.initial_image
230230
and r.initial_image == image
231-
and r.extent.initial == input.multiple_of(8)
232-
and r.extent.desired == input.multiple_of(8)
231+
and r.extent.initial == input.multiple_of(16)
232+
and r.extent.desired == input.multiple_of(16)
233233
and r.extent.target == input
234234
)
235235

236236

237237
@pytest.mark.parametrize(
238238
"sd_ver,input,expected_initial,expected_desired",
239239
[
240-
(Arch.sd15, Extent(2000, 2000), (632, 632), (1000, 1000)),
241-
(Arch.sd15, Extent(1000, 1000), (632, 632), (1000, 1000)),
240+
(Arch.sd15, Extent(2000, 2000), (640, 640), (1008, 1008)),
241+
(Arch.sd15, Extent(1000, 1000), (640, 640), (1008, 1008)),
242242
(Arch.sdxl, Extent(1024, 1024), (1024, 1024), (1024, 1024)),
243-
(Arch.sdxl, Extent(2000, 2000), (1000, 1000), (1000, 1000)),
244-
(Arch.sd15, Extent(801, 801), (632, 632), (808, 808)),
243+
(Arch.sdxl, Extent(2000, 2000), (1008, 1008), (1008, 1008)),
244+
(Arch.sd15, Extent(801, 801), (640, 640), (816, 816)),
245245
],
246246
ids=["sd15_large", "sd15_small", "sdxl_small", "sdxl_large", "sd15_odd"],
247247
)
@@ -260,11 +260,11 @@ def test_prepare_max_pixel_count(input, sd_ver, expected_initial, expected_desir
260260
[
261261
(Extent(512, 512), 1.0, Extent(512, 512), Extent(512, 512)),
262262
(Extent(1024, 800), 0.5, Extent(512, 400), Extent(512, 400)),
263-
(Extent(2048, 1536), 0.5, Extent(728, 544), Extent(1024, 768)),
263+
(Extent(2048, 1536), 0.5, Extent(736, 544), Extent(1024, 768)),
264264
(Extent(1024, 1024), 0.4, Extent(512, 512), Extent(512, 512)),
265265
(Extent(512, 768), 0.5, Extent(512, 768), Extent(512, 768)),
266-
(Extent(512, 512), 2.0, Extent(632, 632), Extent(1024, 1024)),
267-
(Extent(512, 512), 1.1, Extent(568, 568), Extent(568, 568)),
266+
(Extent(512, 512), 2.0, Extent(640, 640), Extent(1024, 1024)),
267+
(Extent(512, 512), 1.1, Extent(576, 576), Extent(576, 576)),
268268
],
269269
ids=["1.0", "0.5", "0.5_large", "0.4", "0.5_tall", "2.0", "1.1"],
270270
)
@@ -296,13 +296,13 @@ def test_prepare_resolution_multiplier_inputs(multiplier):
296296

297297
@pytest.mark.parametrize(
298298
"multiplier,expected",
299-
[(0.5, Extent(1024, 1024)), (2, Extent(1000, 1000)), (0.25, Extent(512, 512))],
299+
[(0.5, Extent(1024, 1024)), (2, Extent(1008, 1008)), (0.25, Extent(512, 512))],
300300
)
301301
def test_prepare_resolution_multiplier_max(multiplier, expected):
302302
perf_settings = PerformanceSettings(resolution_multiplier=multiplier, max_pixel_count=1)
303303
input = Extent(2048, 2048)
304304
r, _ = resolution.prepare_extent(input, Arch.sd15, dummy_style, perf_settings)
305-
assert r.extent.initial.width <= 632 and r.extent.desired == expected
305+
assert r.extent.initial.width <= 640 and r.extent.desired == expected
306306

307307

308308
tile_layouts = {

tests/test_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def test_refine(qtapp, client, setup):
503503

504504
sdver, extent, strength = {
505505
"sd15": (Arch.sd15, Extent(768, 508), 0.5),
506-
"sdxl": (Arch.sdxl, Extent(1111, 741), 0.65),
506+
"sdxl": (Arch.sdxl, Extent(1111, 741), 0.5),
507507
"flux": (Arch.flux, Extent(1111, 741), 0.65),
508508
"flux_k": (Arch.flux_k, Extent(1111, 741), 1.0),
509509
"flux2": (Arch.flux2_4b, Extent(1111, 741), 1.0),

0 commit comments

Comments
 (0)