diff --git a/olive/passes/qairt/encapsulation.py b/olive/passes/qairt/encapsulation.py index a6fa9ebae..306535fc7 100644 --- a/olive/passes/qairt/encapsulation.py +++ b/olive/passes/qairt/encapsulation.py @@ -24,6 +24,21 @@ MAX_GENIE_CONTEXT_LENGTH = 4096 +def _deep_merge(base: dict, overrides: dict) -> dict: + """Recursively merge *overrides* into *base*, returning a new dict. + + Nested dicts are merged rather than replaced, so only the keys present in + *overrides* are changed; all other keys from *base* are preserved. + """ + result = dict(base) + for k, v in overrides.items(): + if k in result and isinstance(result[k], dict) and isinstance(v, dict): + result[k] = _deep_merge(result[k], v) + else: + result[k] = v + return result + + class QairtEncapsulation(Pass): """Encapsulates a QAIRT DLC model with an onnx protobuf.""" @@ -49,6 +64,21 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon required=False, description="Opset name and version to be added in the generated context model", ), + "genie_overrides": PassConfigParam( + type_=dict, + default_value=None, + required=False, + description=( + "Deep-merged into the GenAIConfig before the Genie DLC is produced. " + "Use Python field names (underscores). Nested dicts are merged recursively — " + "only the specified keys are overridden; all other GenAIBuilder defaults are " + "preserved. Any field on GenAIConfig is valid: kv_dim, rope_theta, n_heads, " + "n_layer, n_embd, allow_async_init, enable_graph_switching, " + "positional_encoding (nested dict), etc. Note: top-level rope_theta and " + "rope_scaling are not forwarded by the Genie factory — use " + "positional_encoding.rope_theta to override RoPE theta in the DLC." + ), + ), } def _run_for_config( @@ -76,6 +106,13 @@ def _run_for_config( container: qairt_genai.LLMContainer = qairt_genai.LLMContainer.load(model.model_path) + if config.genie_overrides: + gen_ai_cfg = container._gen_ai_config + current = gen_ai_cfg.model_dump(mode="json", by_alias=False, exclude_none=True) + merged = _deep_merge(current, config.genie_overrides) + container._gen_ai_config = gen_ai_cfg.model_validate(merged) + logger.info("Applied genie_overrides to GenAIConfig: %s", list(config.genie_overrides.keys())) + # Input/Output metadata container.inputs = [("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])] container.outputs = [("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab_size"])] diff --git a/test/passes/qairt/test_encapsulation.py b/test/passes/qairt/test_encapsulation.py index 0fb45a8b7..2ef566bfa 100644 --- a/test/passes/qairt/test_encapsulation.py +++ b/test/passes/qairt/test_encapsulation.py @@ -894,3 +894,194 @@ def test_create_genai_config_provider_options_key_lowercase(tmp_path): assert len(provider_options) == 1 assert "qnn" in provider_options[0] assert "QNN" not in provider_options[0] + + +# --------------------------------------------------------------------------- +# _deep_merge unit tests +# --------------------------------------------------------------------------- + + +def test_deep_merge_flat(): + """Flat keys in overrides replace or add keys in base.""" + from olive.passes.qairt.encapsulation import _deep_merge + + result = _deep_merge({"a": 1, "b": 2}, {"b": 99, "c": 3}) + assert result == {"a": 1, "b": 99, "c": 3} + + +def test_deep_merge_nested_dicts_are_merged_not_replaced(): + """Nested dicts are recursively merged, preserving keys not in overrides.""" + from olive.passes.qairt.encapsulation import _deep_merge + + base = {"positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 10000.0}} + overrides = {"positional_encoding": {"rope_theta": 500000.0}} + result = _deep_merge(base, overrides) + assert result["positional_encoding"] == {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0} + + +def test_deep_merge_nested_override_replaces_non_dict(): + """A dict override replaces a non-dict base value at the same key.""" + from olive.passes.qairt.encapsulation import _deep_merge + + result = _deep_merge({"a": 42}, {"a": {"nested": 1}}) + assert result == {"a": {"nested": 1}} + + +def test_deep_merge_base_unmodified(): + """_deep_merge does not mutate base.""" + from olive.passes.qairt.encapsulation import _deep_merge + + base = {"a": {"b": 1}} + overrides = {"a": {"b": 2}} + _deep_merge(base, overrides) + assert base["a"]["b"] == 1 + + +# --------------------------------------------------------------------------- +# genie_overrides integration tests +# --------------------------------------------------------------------------- + + +def test_encapsulation_default_config_includes_genie_overrides(mock_accelerator_spec): + """genie_overrides is present in _default_config with None default.""" + config = QairtEncapsulation._default_config(mock_accelerator_spec) # pylint: disable=protected-access + assert "genie_overrides" in config + assert config["genie_overrides"].default_value is None + assert config["genie_overrides"].required is False + + +def test_encapsulation_genie_overrides_applied(tmp_path, mock_qairt_model, mock_qairt_modules): + """When genie_overrides is set, _gen_ai_config is deep-merged before export.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + # Represent the existing GenAIConfig state after LLMContainer.load() + initial_gen_ai_state = { + "context_length": 4096, + "n_vocab": 32000, + "bos_token": 1, + "eos_token": 2, + "tokenizer_path": str(tmp_path / "tokenizer.json"), + "kv_dim": None, + "positional_encoding": {"type": "rope", "rope_dim": 64}, + } + mock_container._gen_ai_config.model_dump.return_value = initial_gen_ai_state + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + overrides = {"kv_dim": 128, "positional_encoding": {"rope_theta": 500000.0}} + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU", "genie_overrides": overrides}, + disable_search=True, + ) + + encap_pass.run(mock_qairt_model, str(output_path)) + + # model_dump was called to capture current state + mock_container._gen_ai_config.model_dump.assert_called_once_with(mode="json", by_alias=False, exclude_none=True) + # model_validate was called with the deep-merged result + expected_merged = { + **initial_gen_ai_state, + "kv_dim": 128, + "positional_encoding": {"type": "rope", "rope_dim": 64, "rope_theta": 500000.0}, + } + mock_container._gen_ai_config.model_validate.assert_called_once_with(expected_merged) + # _gen_ai_config was reassigned to the validated result + assert ( + mock_container._gen_ai_config + is not mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value._gen_ai_config + ) + + +def test_encapsulation_no_genie_overrides_leaves_gen_ai_config_untouched( + tmp_path, mock_qairt_model, mock_qairt_modules +): + """When genie_overrides is None, _gen_ai_config.model_dump is never called.""" + output_path = tmp_path / "output" + output_path.mkdir(parents=True, exist_ok=True) + + model_path = Path(mock_qairt_model.model_path) + (model_path / "config.json").write_text(json.dumps({"model_type": "llama", "hidden_size": 4096})) + (model_path / "generation_config.json").write_text(json.dumps({"eos_token_id": 2})) + + mock_container = MagicMock() + mock_container.inputs = [("input_ids", 7, ["batch_size", "sequence_length"])] + mock_container.outputs = [("logits", 1, ["batch_size", 1, "vocab_size"])] + + def mock_export(output_dir, export_format): + Path(output_dir).mkdir(parents=True, exist_ok=True) + (Path(output_dir) / "model.dlc").write_text("dummy dlc") + + mock_container.export.side_effect = mock_export + mock_qairt_modules["gen_ai_api"].LLMContainer.load.return_value = mock_container + + def mock_save_func(model_def, path): + import onnx + from onnx import TensorProto + + inp = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "seq"]) + out = onnx.helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["batch_size", 1, "vocab"]) + node = onnx.helper.make_node("Identity", inputs=["input_ids"], outputs=["logits"]) + graph = onnx.helper.make_graph([node], "g", [inp], [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 14)]) + onnx.save(model, path) + + with ( + patch("olive.passes.qairt.encapsulation.helper") as mock_helper, + patch("olive.passes.qairt.encapsulation.save", side_effect=mock_save_func), + patch("olive.passes.qairt.encapsulation.checker"), + ): + mock_helper.make_node.return_value = MagicMock() + mock_helper.make_attribute.return_value = MagicMock() + mock_helper.make_tensor_value_info.return_value = MagicMock() + mock_helper.make_graph.return_value = MagicMock() + mock_helper.make_opsetid.return_value = MagicMock() + mock_helper.make_model.return_value = MagicMock() + + encap_pass = create_pass_from_dict( + QairtEncapsulation, + {"backend": "CPU"}, + disable_search=True, + ) + + encap_pass.run(mock_qairt_model, str(output_path)) + + mock_container._gen_ai_config.model_dump.assert_not_called()