From 96d08e9c5da356aabb57e7e11e68ce20863d49e6 Mon Sep 17 00:00:00 2001 From: Kyle Romero Date: Tue, 19 May 2026 23:35:05 +0000 Subject: [PATCH] Add genie_overrides to QairtEncapsulation for GenAIConfig customization Introduce a genie_overrides PassConfigParam that deep-merges user-supplied fields into the GenAIConfig before LLMContainer.export() bakes them into the Genie DLC. This allows callers to override any GenAIConfig field (engine config, positional encoding, etc.) without modifying QairtGenAIBuilder or QairtPipelinePass. Nested dicts are merged recursively so only the specified keys are changed; all other values set by the upstream builder pass are preserved. --- olive/passes/qairt/encapsulation.py | 37 +++++ test/passes/qairt/test_encapsulation.py | 191 ++++++++++++++++++++++++ 2 files changed, 228 insertions(+) diff --git a/olive/passes/qairt/encapsulation.py b/olive/passes/qairt/encapsulation.py index a6fa9ebae..306535fc7 100644 --- a/olive/passes/qairt/encapsulation.py +++ b/olive/passes/qairt/encapsulation.py @@ -24,6 +24,21 @@ MAX_GENIE_CONTEXT_LENGTH = 4096 +def _deep_merge(base: dict, overrides: dict) -> dict: + """Recursively merge *overrides* into *base*, returning a new dict. + + Nested dicts are merged rather than replaced, so only the keys present in + *overrides* are changed; all other keys from *base* are preserved. + """ + result = dict(base) + for k, v in overrides.items(): + if k in result and isinstance(result[k], dict) and isinstance(v, dict): + result[k] = _deep_merge(result[k], v) + else: + result[k] = v + return result + + class QairtEncapsulation(Pass): """Encapsulates a QAIRT DLC model with an onnx protobuf.""" @@ -49,6 +64,21 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon required=False, description="Opset name and version to be added in the generated context model", ), + "genie_overrides": PassConfigParam( + type_=dict, + default_value=None, + required=False, + description=( + "Deep-merged into the GenAIConfig before the Genie DLC is produced. " + "Use Python field names (underscores). Nested dicts are merged recursively — " + "only the specified keys are overridden; all other GenAIBuilder defaults are " + "preserved. Any field on GenAIConfig is valid: kv_dim, rope_theta, n_heads, " + "n_layer, n_embd, allow_async_init, enable_graph_switching, " + "positional_encoding (nested dict), etc. Note: top-level rope_theta and " + "rope_scaling are not forwarded by the Genie factory — use " + "positional_encoding.rope_theta to override RoPE theta in the DLC." + ), + ), } def _run_for_config( @@ -76,6 +106,13 @@ def _run_for_config( container: qairt_genai.LLMContainer = qairt_genai.LLMContainer.load(model.model_path) + if config.genie_overrides: + gen_ai_cfg = container._gen_ai_config + current = gen_ai_cfg.model_dump(mode="json", by_alias=False, exclude_none=True) + merged = _deep_merge(current, config.genie_overrides) + container._gen_ai_config = gen_ai_cfg.model_validate(merged) + logger.info("Applied genie_overrides to GenAIConfig: %s", list(config.genie_overrides.keys())) + # Input/Output metadata container.inputs = [("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])] container.outputs = [("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab_size"])] diff --git a/test/passes/qairt/test_encapsulation.py b/test/passes/qairt/test_encapsulation.py index 0fb45a8b7..2ef566bfa 100644 --- a/test/passes/qairt/test_encapsulation.py +++ b/test/passes/qairt/test_encapsulation.py @@ -894,3 +894,194 @@ def test_create_genai_config_provider_options_key_lowercase(tmp_path): assert len(provider_options) == 1 assert "qnn" in provider_options[0] assert "QNN" not in provider_options[0] + + +# --------------------------------------------------------------------------- +# _deep_merge unit tests +# --------------------------------------------------------------------------- + + +def test_deep_merge_flat(): + """Flat keys in overrides replace or add keys in base.""" + from olive.passes.qairt.encapsulation import _deep_merge + + result = _deep_merge({"a": 1, "b": 2}, {"b": 99, "c": 3}) + assert result == {"a": 1, "b": 99, "c": 3} + + +def test_deep_merge_nested_dicts_are_merged_not_replaced(): + """Nested dicts are recursively merged, preserving keys not in overrides.""" + from olive.passes.qairt.encapsulation import _deep_merge + + base = {"positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 10000.0}} + overrides = {"positional_encoding": {"rope_theta": 500000.0}} + result = _deep_merge(base, overrides) + assert result["positional_encoding"] == {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0} + + +def test_deep_merge_nested_override_replaces_non_dict(): + """A dict override replaces a non-dict base value at the same key.""" + from olive.passes.qairt.encapsulation import _deep_merge + + result = _deep_merge({"a": 42}, {"a": {"nested": 1}}) + assert result == {"a": {"nested": 1}} + + +def test_deep_merge_base_unmodified(): + """_deep_merge does not mutate base.""" + from olive.passes.qairt.encapsulation import _deep_merge + + base = {"a": {"b": 1}} + overrides = {"a": {"b": 2}} + _deep_merge(base, overrides) + assert base["a"]["b"] == 1 + + +# --------------------------------------------------------------------------- +# genie_overrides integration tests +# --------------------------------------------------------------------------- + + +def test_encapsulation_default_config_includes_genie_overrides(mock_accelerator_spec): + """genie_overrides is present in _default_config with None default.""" + config = QairtEncapsulation._default_config(mock_accelerator_spec) # pylint: disable=protected-access + assert "genie_overrides" in config + assert config["genie_overrides"].default_value is None + assert config["genie_overrides"].required is False + + +def test_encapsulation_genie_overrides_applied(tmp_path, mock_qairt_model, mock_qairt_modules): + """When genie_overrides is set, _gen_ai_config is deep-merged before export.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + # Represent the existing GenAIConfig state after LLMContainer.load() + initial_gen_ai_state = { + "context_length": 4096, + "n_vocab": 32000, + "bos_token": 1, + "eos_token": 2, + "tokenizer_path": str(tmp_path / "tokenizer.json"), + "kv_dim": None, + "positional_encoding": {"type": "rope", "rope_dim": 64}, + } + mock_container._gen_ai_config.model_dump.return_value = initial_gen_ai_state + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + overrides = {"kv_dim": 128, "positional_encoding": {"rope_theta": 500000.0}} + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU", "genie_overrides": overrides}, + disable_search=True, + ) + + encap_pass.run(mock_qairt_model, str(output_path)) + + # model_dump was called to capture current state + mock_container._gen_ai_config.model_dump.assert_called_once_with(mode="json", by_alias=False, exclude_none=True) + # model_validate was called with the deep-merged result + expected_merged = { + **initial_gen_ai_state, + "kv_dim": 128, + "positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0}, + } + mock_container._gen_ai_config.model_validate.assert_called_once_with(expected_merged) + # _gen_ai_config was reassigned to the validated result + assert ( + mock_container._gen_ai_config + is not mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value._gen_ai_config + ) + + +def test_encapsulation_no_genie_overrides_leaves_gen_ai_config_untouched( + tmp_path, mock_qairt_model, mock_qairt_modules +): + """When genie_overrides is None, _gen_ai_config.model_dump is never called.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU"}, + disable_search=True, + ) + + encap_pass.run(mock_qairt_model, str(output_path)) + + mock_container._gen_ai_config.model_dump.assert_not_called()