Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions olive/passes/qairt/encapsulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
MAX_GENIE_CONTEXT_LENGTH = 4096


def _deep_merge(base: dict, overrides: dict) -> dict:
"""Recursively merge *overrides* into *base*, returning a new dict.

Nested dicts are merged rather than replaced, so only the keys present in
*overrides* are changed; all other keys from *base* are preserved.
"""
result = dict(base)
for k, v in overrides.items():
if k in result and isinstance(result[k], dict) and isinstance(v, dict):
result[k] = _deep_merge(result[k], v)
else:
result[k] = v
return result


class QairtEncapsulation(Pass):
"""Encapsulates a QAIRT DLC model with an onnx protobuf."""

Expand All @@ -49,6 +64,21 @@
required=False,
description="Opset name and version to be added in the generated context model",
),
"genie_overrides": PassConfigParam(
type_=dict,
default_value=None,
required=False,
description=(
"Deep-merged into the GenAIConfig before the Genie DLC is produced. "
"Use Python field names (underscores). Nested dicts are merged recursively — "
"only the specified keys are overridden; all other GenAIBuilder defaults are "
"preserved. Any field on GenAIConfig is valid: kv_dim, rope_theta, n_heads, "
"n_layer, n_embd, allow_async_init, enable_graph_switching, "
"positional_encoding (nested dict), etc. Note: top-level rope_theta and "
"rope_scaling are not forwarded by the Genie factory — use "
"positional_encoding.rope_theta to override RoPE theta in the DLC."
),
),
}

def _run_for_config(
Expand Down Expand Up @@ -76,6 +106,13 @@

container: qairt_genai.LLMContainer = qairt_genai.LLMContainer.load(model.model_path)

if config.genie_overrides:
gen_ai_cfg = container._gen_ai_config

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.
current = gen_ai_cfg.model_dump(mode="json", by_alias=False, exclude_none=True)
merged = _deep_merge(current, config.genie_overrides)
container._gen_ai_config = gen_ai_cfg.model_validate(merged)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.
logger.info("Applied genie_overrides to GenAIConfig: %s", list(config.genie_overrides.keys()))

# Input/Output metadata
container.inputs = [("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])]
container.outputs = [("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab_size"])]
Expand Down
191 changes: 191 additions & 0 deletions test/passes/qairt/test_encapsulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,194 @@
assert len(provider_options) == 1
assert "qnn" in provider_options[0]
assert "QNN" not in provider_options[0]


# ---------------------------------------------------------------------------
# _deep_merge unit tests
# ---------------------------------------------------------------------------


def test_deep_merge_flat():
"""Flat keys in overrides replace or add keys in base."""
from olive.passes.qairt.encapsulation import _deep_merge

result = _deep_merge({"a": 1, "b": 2}, {"b": 99, "c": 3})
assert result == {"a": 1, "b": 99, "c": 3}


def test_deep_merge_nested_dicts_are_merged_not_replaced():
"""Nested dicts are recursively merged, preserving keys not in overrides."""
from olive.passes.qairt.encapsulation import _deep_merge

base = {"positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 10000.0}}
overrides = {"positional_encoding": {"rope_theta": 500000.0}}
result = _deep_merge(base, overrides)
assert result["positional_encoding"] == {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0}


def test_deep_merge_nested_override_replaces_non_dict():
"""A dict override replaces a non-dict base value at the same key."""
from olive.passes.qairt.encapsulation import _deep_merge

result = _deep_merge({"a": 42}, {"a": {"nested": 1}})
assert result == {"a": {"nested": 1}}


def test_deep_merge_base_unmodified():
"""_deep_merge does not mutate base."""
from olive.passes.qairt.encapsulation import _deep_merge

base = {"a": {"b": 1}}
overrides = {"a": {"b": 2}}
_deep_merge(base, overrides)
assert base["a"]["b"] == 1


# ---------------------------------------------------------------------------
# genie_overrides integration tests
# ---------------------------------------------------------------------------


def test_encapsulation_default_config_includes_genie_overrides(mock_accelerator_spec):
"""genie_overrides is present in _default_config with None default."""
config = QairtEncapsulation._default_config(mock_accelerator_spec) # pylint: disable=protected-access
assert "genie_overrides" in config
assert config["genie_overrides"].default_value is None
assert config["genie_overrides"].required is False


def test_encapsulation_genie_overrides_applied(tmp_path, mock_qairt_model, mock_qairt_modules):
"""When genie_overrides is set, _gen_ai_config is deep-merged before export."""
output_path = tmp_path / "output"
output_path.mkdir(parents=True, exist_ok=True)

model_path = Path(mock_qairt_model.model_path)
(model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096}))
(model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2}))

mock_container = MagicMock()
mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])]
mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])]

# Represent the existing GenAIConfig state after LLMContainer.load()
initial_gen_ai_state = {
"context_length": 4096,
"n_vocab": 32000,
"bos_token": 1,
"eos_token": 2,
"tokenizer_path": str(tmp_path / "tokenizer.json"),
"kv_dim": None,
"positional_encoding": {"type": "rope", "rope_dim": 64},
}
mock_container._gen_ai_config.model_dump.return_value = initial_gen_ai_state

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning test

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.

def mock_export(output_dir, export_format):
Path(output_dir).mkdir(parents=True, exist_ok=True)
(Path(output_dir) / "model.dlc").write_text("dummy dlc")

mock_container.export.side_effect = mock_export
mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container

def mock_save_func(model_def, path):
import onnx
from onnx import TensorProto

inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"])
out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"])
node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"])
graph = onnx.helper.make_graph([node], "g", [inp], [out])
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)])
onnx.save(model, path)

overrides = {"kv_dim": 128, "positional_encoding": {"rope_theta": 500000.0}}

with (
patch("olive.passes.qairt.encapsulation.helper") as mock_helper,
patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func),
patch("olive.passes.qairt.encapsulation.checker"),
):
mock_helper.make_node.return_value = MagicMock()
mock_helper.make_attribute.return_value = MagicMock()
mock_helper.make_tensor_value_info.return_value = MagicMock()
mock_helper.make_graph.return_value = MagicMock()
mock_helper.make_opsetid.return_value = MagicMock()
mock_helper.make_model.return_value = MagicMock()

encap_pass = create_pass_from_dict(
QairtEncapsulation,
{"backend": "CPU", "genie_overrides": overrides},
disable_search=True,
)

encap_pass.run(mock_qairt_model, str(output_path))

# model_dump was called to capture current state
mock_container._gen_ai_config.model_dump.assert_called_once_with(mode="json", by_alias=False, exclude_none=True)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning test

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.
# model_validate was called with the deep-merged result
expected_merged = {
**initial_gen_ai_state,
"kv_dim": 128,
"positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0},
}
mock_container._gen_ai_config.model_validate.assert_called_once_with(expected_merged)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning test

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.
# _gen_ai_config was reassigned to the validated result
assert (
mock_container._gen_ai_config

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning test

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.
is not mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value._gen_ai_config

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning test

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.
)


def test_encapsulation_no_genie_overrides_leaves_gen_ai_config_untouched(
tmp_path, mock_qairt_model, mock_qairt_modules
):
"""When genie_overrides is None, _gen_ai_config.model_dump is never called."""
output_path = tmp_path / "output"
output_path.mkdir(parents=True, exist_ok=True)

model_path = Path(mock_qairt_model.model_path)
(model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096}))
(model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2}))

mock_container = MagicMock()
mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])]
mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])]

def mock_export(output_dir, export_format):
Path(output_dir).mkdir(parents=True, exist_ok=True)
(Path(output_dir) / "model.dlc").write_text("dummy dlc")

mock_container.export.side_effect = mock_export
mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container

def mock_save_func(model_def, path):
import onnx
from onnx import TensorProto

inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"])
out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"])
node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"])
graph = onnx.helper.make_graph([node], "g", [inp], [out])
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)])
onnx.save(model, path)

with (
patch("olive.passes.qairt.encapsulation.helper") as mock_helper,
patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func),
patch("olive.passes.qairt.encapsulation.checker"),
):
mock_helper.make_node.return_value = MagicMock()
mock_helper.make_attribute.return_value = MagicMock()
mock_helper.make_tensor_value_info.return_value = MagicMock()
mock_helper.make_graph.return_value = MagicMock()
mock_helper.make_opsetid.return_value = MagicMock()
mock_helper.make_model.return_value = MagicMock()

encap_pass = create_pass_from_dict(
QairtEncapsulation,
{"backend": "CPU"},
disable_search=True,
)

encap_pass.run(mock_qairt_model, str(output_path))

mock_container._gen_ai_config.model_dump.assert_not_called()

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning test

Access to a protected member _gen_ai_config of a client class (protected-access)
See protected-access.
Loading