Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions olive/olive_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,15 @@
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ],
"run_on_target": true
},
"VitisGenerateModelSD": {
"module_path": "olive.passes.onnx.vitis_ai.vitis_generate_model_sd.VitisGenerateModelSD",
"supported_providers": [ "CPUExecutionProvider" ],
"supported_accelerators": [ "cpu" ],
"supported_precisions": [ "int8" ],
"supported_algorithms": [ ],
"supported_quantization_encodings": [ ],
"run_on_target": true
}
},
"extra_dependencies": {
Expand Down
148 changes: 148 additions & 0 deletions olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# -------------------------------------------------------------------------
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
# -------------------------------------------------------------------------

"""Olive Pass for Vitis NPU Stable Diffusion submodel generation.

Accepts ONNX input only. Run OnnxConversion first to produce the ONNX input model;
this pass then calls model_generate.generate_model in SD mode to generate NPU-ready models.
"""
Comment thread
liujij marked this conversation as resolved.

from __future__ import annotations

import logging
import shutil
from pathlib import Path

from olive.model import ONNXModelHandler
from olive.passes import Pass
from olive.passes.pass_config import BasePassConfig, PassConfigParam

# Module-level logger; every message emitted by this pass is prefixed with "[VitisGenerateModelSD]".
logger = logging.getLogger(__name__)


class VitisGenerateModelSD(Pass):
    """Olive pass that turns an ONNX model into a Vitis NPU-ready Stable Diffusion submodel.

    The input must already be ONNX (for example the output of OnnxConversion). The pass
    delegates to ``model_generate.generate_model`` in ``"sd"`` mode and returns a handler
    for ``model.onnx`` inside the output directory. ``resolutions`` is optional and
    defaults to ``["512x512"]``.
    """

    @classmethod
    def _default_config(cls, accelerator_spec):
        """Return the pass parameters: required ``model_type`` and optional ``resolutions``."""
        model_type_param = PassConfigParam(
            type_=str,
            required=True,
            description="SD submodel type.",
        )
        resolutions_param = PassConfigParam(
            type_=list[str],
            default_value=["512x512"],
            required=False,
            description="List of resolutions (e.g. ['512x512', '1024x1024']) Default is [512x512].",
        )
        return {"model_type": model_type_param, "resolutions": resolutions_param}

    def _run_for_config(
        self,
        model: ONNXModelHandler,
        config: BasePassConfig,
        output_model_path: str,
    ) -> ONNXModelHandler:
        """Run SD model generation and return a handler for the generated ``model.onnx``.

        Args:
            model: ONNX model to process.
            config: Pass configuration; ``model_type`` and ``resolutions`` are read from it.
            output_model_path: Destination path. When it ends in ``.onnx`` its parent
                directory is used as the output directory.

        Returns:
            An ONNXModelHandler pointing at ``model.onnx`` in the output directory.

        Raises:
            ImportError: If the ``model_generate`` package cannot be imported.
            TypeError: If ``model`` is not an ONNXModelHandler.
            FileNotFoundError: If the input model cannot be located, or generation leaves
                neither ``model.onnx`` nor a known fallback file behind.
            ValueError: If the input directory holds several ``.onnx`` files and none is selected.
        """
        # Imported lazily so the optional dependency is only needed when the pass actually runs.
        try:
            from model_generate import generate_model
        except ImportError as exc:
            raise ImportError(
                "model_generate is required for VitisGenerateModelSD. Please install the model_generate package."
            ) from exc

        if not isinstance(model, ONNXModelHandler):
            raise TypeError(
                f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}"
            )

        sd_model_type = config.model_type

        # A path ending in ".onnx" names a file, so generation output goes to its parent directory.
        destination = Path(output_model_path)
        target_dir = destination.parent if destination.suffix == ".onnx" else destination
        target_dir.mkdir(parents=True, exist_ok=True)
        logger.info("[VitisGenerateModelSD] output_dir=%s, model_type=%s", target_dir, sd_model_type)

        source_onnx = self.resolve_onnx_input_path(model)
        logger.info("[VitisGenerateModelSD] ONNX input path: %s", source_onnx)

        generator_options = {"model_type": sd_model_type}
        requested_resolutions = getattr(config, "resolutions", None)
        if requested_resolutions:
            logger.info("[VitisGenerateModelSD] Using resolutions: %s", requested_resolutions)
            # The generator receives the resolutions as one comma-separated string.
            generator_options["resolutions"] = ",".join(requested_resolutions)

        generate_model(
            mode="sd",
            input_model=str(source_onnx),
            output_dir=str(target_dir),
            extra_options=generator_options,
        )

        self._ensure_model_onnx(target_dir)
        return ONNXModelHandler(model_path=str(target_dir), onnx_file_name="model.onnx")

    def resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path:
        """Locate the ``.onnx`` file that ``model`` refers to.

        A file path is returned unchanged. For a directory the lookup order is: the
        handler's ``onnx_file_name`` (when set), then ``model.onnx``, then the single
        ``*.onnx`` file in the directory.

        Raises:
            FileNotFoundError: If the path does not exist, the named file is missing, or
                the directory contains no ``.onnx`` file.
            ValueError: If the directory contains more than one ``.onnx`` file and no
                ``onnx_file_name`` selects among them.
        """
        root = Path(model.model_path)
        if root.is_file():
            return root
        if not root.is_dir():
            raise FileNotFoundError(f"Model path does not exist: {root}")

        explicit_name = getattr(model, "onnx_file_name", None)
        if explicit_name:
            explicit_path = root / explicit_name
            if not explicit_path.exists():
                raise FileNotFoundError(f"Specified onnx_file_name does not exist under {root}: {explicit_name}")
            return explicit_path

        conventional = root / "model.onnx"
        if conventional.exists():
            return conventional

        found = sorted(entry for entry in root.glob("*.onnx") if entry.is_file())
        if not found:
            raise FileNotFoundError(f"No .onnx file found under {root}")
        if len(found) == 1:
            return found[0]

        candidates = ", ".join(entry.name for entry in found)
        raise ValueError(
            f"Multiple .onnx model files found under {root}: {candidates}. Please specify one using the onnx_file_name argument."
        )

    def _ensure_model_onnx(self, output_dir: Path) -> None:
        """Make sure ``output_dir/model.onnx`` exists, copying a generator output if needed.

        An existing ``model.onnx`` is left untouched. Otherwise ``dd/replaced.onnx`` is
        preferred over ``optimized.onnx`` as the copy source.

        Raises:
            FileNotFoundError: If ``model.onnx`` is absent and neither fallback file exists.
        """
        final_model = output_dir / "model.onnx"
        if final_model.exists():
            return

        # Ordered by preference: the first existing source wins.
        fallbacks = (
            (output_dir / "dd" / "replaced.onnx", "[VitisGenerateModelSD] Wrote model.onnx from dd/replaced.onnx"),
            (output_dir / "optimized.onnx", "[VitisGenerateModelSD] Wrote model.onnx from optimized.onnx"),
        )
        for source, message in fallbacks:
            if source.exists():
                shutil.copy2(source, final_model)
                logger.info(message)
                return

        raise FileNotFoundError(
            f"[VitisGenerateModelSD] No optimized.onnx or dd/replaced.onnx found under {output_dir}. Please check the output directory.",
        )
222 changes: 222 additions & 0 deletions test/passes/vitis_ai/test_vitis_generate_model_sd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# -------------------------------------------------------------------------
# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
# -------------------------------------------------------------------------

from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from olive.model import ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.vitis_ai.vitis_generate_model_sd import VitisGenerateModelSD
from test.utils import ONNX_MODEL_PATH, get_onnx_model

# Skip every test in this module when the optional model_generate package is not importable.
pytest.importorskip("model_generate", reason="model_generate is not installed; skipping all SD model generation tests")

# Dotted target handed to unittest.mock.patch; the pass imports generate_model inside _run_for_config.
_PATCH_GEN = "model_generate.generate_model"


def _make_pass(**kwargs):
    """Build a VitisGenerateModelSD pass with model_type "unet" and no resolutions.

    Keyword arguments override or extend the base configuration.
    """
    pass_config = {"model_type": "unet", "resolutions": []}
    pass_config.update(kwargs)
    return create_pass_from_dict(VitisGenerateModelSD, pass_config, disable_search=True)


def _generate_writes_placeholder(**kwargs):
    """Stand-in for generate_model that writes a placeholder optimized.onnx.

    Only the ``output_dir`` keyword is used; the directory is created when missing, so
    the pass's _ensure_model_onnx step finds a file it can copy to model.onnx.
    """
    target_dir = Path(kwargs["output_dir"])
    target_dir.mkdir(parents=True, exist_ok=True)
    target_dir.joinpath("optimized.onnx").write_bytes(b"placeholder")


def test_run_raises_on_missing_model_generate(tmp_path):
    """Running the pass without model_generate available raises a descriptive ImportError."""
    p = _make_pass()
    # A None entry in sys.modules makes any import of that module fail with ImportError.
    # Both context managers share one `with` statement (ruff SIM117).
    with patch.dict("sys.modules", {"model_generate": None}), pytest.raises(
        ImportError, match="model_generate is required for VitisGenerateModelSD"
    ):
        p.run(get_onnx_model(), str(tmp_path / "out"))


def test_run_includes_resolutions_in_extra_options(tmp_path):
    """Configured resolutions reach generate_model as a single comma-separated string."""
    gen = MagicMock(side_effect=_generate_writes_placeholder)
    with patch(_PATCH_GEN, gen):
        # Reuse the shared factory rather than rebuilding the config dict by hand.
        p = _make_pass(resolutions=["512x512", "768x768"])
        p.run(get_onnx_model(), str(tmp_path / "sd_out"))

    gen.assert_called_once()
    kwargs = gen.call_args.kwargs
    assert kwargs["mode"] == "sd"
    assert kwargs["extra_options"]["model_type"] == "unet"
    assert kwargs["extra_options"]["resolutions"] == "512x512,768x768"


def test_run_default_resolutions_passed_when_using_defaults(tmp_path):
    """Omitting ``resolutions`` from the config forwards the default "512x512"."""
    mock_generate = MagicMock(side_effect=_generate_writes_placeholder)
    # Built directly: the config must not mention "resolutions" at all for the default to apply.
    pass_config = {"model_type": "unet"}
    with patch(_PATCH_GEN, mock_generate):
        sd_pass = create_pass_from_dict(VitisGenerateModelSD, pass_config, disable_search=True)
        sd_pass.run(get_onnx_model(), str(tmp_path / "out"))

    forwarded_options = mock_generate.call_args.kwargs["extra_options"]
    assert forwarded_options.get("resolutions") == "512x512"


def test_run_omits_resolutions_when_empty_list(tmp_path):
    """An empty ``resolutions`` list is not forwarded to generate_model."""
    gen = MagicMock(side_effect=_generate_writes_placeholder)
    with patch(_PATCH_GEN, gen):
        # Passed explicitly because the empty list is the case under test.
        p = _make_pass(resolutions=[])
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert "resolutions" not in gen.call_args.kwargs["extra_options"]


def test_ensure_model_onnx_copies_optimized(tmp_path):
    """When only optimized.onnx is produced, it is copied to model.onnx."""

    def write_optimized(**kwargs):
        # The pass creates output_dir before calling the generator, so no mkdir is needed here.
        out = Path(kwargs["output_dir"])
        (out / "optimized.onnx").write_text("from_optimized", encoding="utf-8")

    gen = MagicMock(side_effect=write_optimized)
    with patch(_PATCH_GEN, gen):
        # Shared factory: same config as the hand-written {"model_type": "unet", "resolutions": []}.
        p = _make_pass()
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "from_optimized"


def test_ensure_model_onnx_prefers_dd_replaced_over_optimized(tmp_path):
    """With both fallback files present, dd/replaced.onnx is the one copied to model.onnx."""

    def write_both(**kwargs):
        out = Path(kwargs["output_dir"])
        (out / "optimized.onnx").write_text("from_optimized", encoding="utf-8")
        dd = out / "dd"
        dd.mkdir(parents=True)
        (dd / "replaced.onnx").write_text("from_dd", encoding="utf-8")

    gen = MagicMock(side_effect=write_both)
    with patch(_PATCH_GEN, gen):
        # Shared factory: same config as the hand-written {"model_type": "unet", "resolutions": []}.
        p = _make_pass()
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "from_dd"


def test_ensure_model_onnx_skips_copy_when_model_onnx_exists(tmp_path):
    """A model.onnx written by the generator is kept; optimized.onnx does not overwrite it."""

    def write_only_original(**kwargs):
        out = Path(kwargs["output_dir"])
        (out / "model.onnx").write_text("original", encoding="utf-8")
        (out / "optimized.onnx").write_text("optimized", encoding="utf-8")

    gen = MagicMock(side_effect=write_only_original)
    with patch(_PATCH_GEN, gen):
        # Shared factory: same config as the hand-written {"model_type": "unet", "resolutions": []}.
        p = _make_pass()
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "original"


def test_ensure_model_onnx_raises_when_no_candidate_files(tmp_path):
    """If the generator writes nothing usable, the pass raises FileNotFoundError."""
    # A bare MagicMock leaves the output directory empty.
    gen = MagicMock()
    with patch(_PATCH_GEN, gen):
        # Shared factory: same config as the hand-written {"model_type": "unet", "resolutions": []}.
        p = _make_pass()
        with pytest.raises(FileNotFoundError, match=r"No optimized\.onnx or dd/replaced\.onnx"):
            p.run(get_onnx_model(), str(tmp_path / "out"))


def test_resolve_onnx_input_path_single_file():
    """A handler whose model path is a file resolves to that same file."""
    handler = ONNXModelHandler(model_path=str(ONNX_MODEL_PATH))
    resolved = _make_pass().resolve_onnx_input_path(handler)
    assert resolved == Path(ONNX_MODEL_PATH)


def test_resolve_onnx_input_path_dir_with_model_onnx(tmp_path):
    """A directory containing model.onnx resolves to that file."""
    expected = tmp_path / "model.onnx"
    expected.write_bytes(b"x")
    handler = ONNXModelHandler(model_path=str(tmp_path))
    resolved = _make_pass().resolve_onnx_input_path(handler)
    assert resolved == expected


def test_resolve_onnx_input_path_dir_with_onnx_file_name(tmp_path):
    """An explicit onnx_file_name selects that file inside the model directory."""
    expected = tmp_path / "custom.onnx"
    expected.write_bytes(b"x")
    handler = ONNXModelHandler(model_path=str(tmp_path), onnx_file_name="custom.onnx")
    resolved = _make_pass().resolve_onnx_input_path(handler)
    assert resolved == expected


def test_resolve_onnx_input_path_dir_onnx_file_name_missing_raises(tmp_path):
    """Naming a file that is not in the directory raises FileNotFoundError."""
    # SimpleNamespace supplies just the two attributes the resolver reads.
    fake_model = SimpleNamespace(model_path=str(tmp_path), onnx_file_name="missing.onnx")
    sd_pass = _make_pass()
    with pytest.raises(FileNotFoundError, match="Specified onnx_file_name"):
        sd_pass.resolve_onnx_input_path(fake_model)


def test_resolve_onnx_input_path_dir_single_unnamed_onnx(tmp_path):
    """A directory holding exactly one .onnx file resolves to that file."""
    expected = tmp_path / "only.onnx"
    expected.write_bytes(b"x")
    handler = ONNXModelHandler(model_path=str(tmp_path))
    resolved = _make_pass().resolve_onnx_input_path(handler)
    assert resolved == expected


def test_resolve_onnx_input_path_dir_multiple_onnx_raises(tmp_path):
    """Several .onnx files with no onnx_file_name to choose among them raise ValueError."""
    for file_name, payload in (("a.onnx", b"x"), ("b.onnx", b"y")):
        (tmp_path / file_name).write_bytes(payload)
    fake_model = SimpleNamespace(model_path=str(tmp_path))
    sd_pass = _make_pass()
    with pytest.raises(ValueError, match=r"Multiple \.onnx model files found"):
        sd_pass.resolve_onnx_input_path(fake_model)


def test_resolve_onnx_input_path_dir_no_onnx_raises(tmp_path):
    """An existing directory without any .onnx file raises FileNotFoundError."""
    fake_model = SimpleNamespace(model_path=str(tmp_path))
    sd_pass = _make_pass()
    with pytest.raises(FileNotFoundError, match=r"No \.onnx file found"):
        sd_pass.resolve_onnx_input_path(fake_model)


def test_resolve_onnx_input_path_missing_path_raises(tmp_path):
    """A model path that is neither a file nor a directory raises FileNotFoundError."""
    absent_path = tmp_path / "nope"
    fake_model = SimpleNamespace(model_path=str(absent_path))
    sd_pass = _make_pass()
    with pytest.raises(FileNotFoundError, match="Model path does not exist"):
        sd_pass.resolve_onnx_input_path(fake_model)


def test_run_requires_onnx_model_handler(tmp_path):
    """Passing something other than an ONNXModelHandler raises TypeError before generation."""
    gen = MagicMock()
    with patch(_PATCH_GEN, gen):
        # Shared factory: same config as the hand-written {"model_type": "unet", "resolutions": []}.
        p = _make_pass()
        bad = MagicMock()
        with pytest.raises(TypeError, match="ONNXModelHandler"):
            p.run(bad, str(tmp_path / "out"))

    # The type check happens before the generator call, so it must never be invoked.
    gen.assert_not_called()
Loading