From 6b41505976a0219c2169e9c0dfd0908839c506b3 Mon Sep 17 00:00:00 2001 From: seongwoo Date: Mon, 22 Jun 2026 22:45:41 +0900 Subject: [PATCH 1/3] [quantization] Add a Gemma4 wrapper for TextModel This commit adds a PTQ wrapper for the Gemma4 text model, along with unit tests and wrapper smoke cases. TICO-DCO-1.0-Signed-off-by: seongwoo --- .../wrappers/gemma4/test_quant_text_model.py | 372 ++++++++ .../debug/wrapper_smoke/cases/gemma4.py | 231 ++++- .../wrapq/wrappers/gemma4/quant_text_model.py | 793 ++++++++++++++++-- tico/quantization/wrapq/wrappers/registry.py | 2 +- 4 files changed, 1330 insertions(+), 68 deletions(-) create mode 100644 test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py diff --git a/test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py b/test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py new file mode 100644 index 00000000..dd87097c --- /dev/null +++ b/test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py @@ -0,0 +1,372 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the Gemma4 text-model PTQ wrapper.""" + +import unittest + +import torch + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.mode import Mode + + +_SKIP_MSG = "required transformers Gemma4 modules are not installed" + +_GEMMA4_FULL_ROPE_PARAMETERS = { + "rope_type": "proportional", + "partial_rotary_factor": 0.25, + "rope_theta": 1_000_000.0, +} +_GEMMA4_SLIDING_ROPE_PARAMETERS = { + "rope_type": "default", + "rope_theta": 10_000.0, +} + + +def _has_gemma4() -> bool: + """Return whether the installed transformers package provides Gemma4.""" + try: + from transformers.models.gemma4.configuration_gemma4 import ( # noqa: F401 + Gemma4TextConfig, + ) + from transformers.models.gemma4.modeling_gemma4 import ( # noqa: F401 + Gemma4TextModel, + ) + except Exception: + return False + return True + + +def _rope_parameters_for_layer_types(layer_types: list[str]) -> dict[str, dict]: + """Return Gemma4 RoPE parameters matching the requested layer types.""" + params = {} + if "sliding_attention" in layer_types: + params["sliding_attention"] = dict(_GEMMA4_SLIDING_ROPE_PARAMETERS) + if "full_attention" in layer_types: + params["full_attention"] = dict(_GEMMA4_FULL_ROPE_PARAMETERS) + return params + + +def _make_text_config(**overrides): + """Create a tiny dense Gemma4 text config for synthetic text-model tests.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + + layer_types = list( + overrides.pop("layer_types", ["full_attention", "full_attention"]) + ) + kwargs = dict( + vocab_size=128, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=len(layer_types), + num_attention_heads=4, + num_key_value_heads=2, + num_global_key_value_heads=2, + head_dim=4, + global_head_dim=4, + attention_bias=False, + attention_dropout=0.0, + max_position_embeddings=128, + rms_norm_eps=1e-6, + sliding_window=8, + layer_types=layer_types, + rope_parameters=_rope_parameters_for_layer_types(layer_types), + 
hidden_size_per_layer_input=0, + attention_k_eq_v=False, + num_kv_shared_layers=0, + enable_moe_block=False, + use_cache=False, + ) + kwargs.update(overrides) + cfg = Gemma4TextConfig(**kwargs) + if not hasattr(cfg, "_attn_implementation"): + setattr(cfg, "_attn_implementation", "eager") + else: + cfg._attn_implementation = "eager" + return cfg + + +def _make_text_model(cfg=None): + """Create a floating-point Gemma4 text model in eval mode.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel + + cfg = cfg if cfg is not None else _make_text_config() + return Gemma4TextModel(cfg).eval() + + +def _assert_close( + testcase: unittest.TestCase, actual: torch.Tensor, expected: torch.Tensor +): + """Assert that two tensors are numerically close for no-quant wrapper parity.""" + testcase.assertEqual(actual.shape, expected.shape) + testcase.assertTrue(torch.allclose(actual, expected, atol=1e-5, rtol=1e-5)) + + +@unittest.skipUnless(_has_gemma4(), _SKIP_MSG) +class TestQuantGemma4TextModel(unittest.TestCase): + """Validate dense Gemma4 text-model wrapper behavior.""" + + def setUp(self): + """Create deterministic test inputs.""" + torch.manual_seed(2026) + + def test_00_prepare_wraps_text_model_when_registered(self): + """Check that registry-based prepare wraps Gemma4TextModel.""" + from tico.quantization import prepare + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper + + prepared = prepare(_make_text_model(_make_text_config()), PTQConfig()) + + self.assertIsInstance(prepared, PTQWrapper) + self.assertIsInstance(prepared.wrapped, QuantGemma4TextModel) + + def test_no_quant_forward_matches_hf_text_model_with_input_ids(self): + """Check that the wrapper matches Hugging Face text-model output.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + cfg = 
_make_text_config(layer_types=["sliding_attention", "full_attention"]) + fp_model = _make_text_model(cfg) + qmodel = QuantGemma4TextModel(fp_model).eval() + + input_ids = torch.randint(0, cfg.vocab_size, (2, 5)) + attention_mask = torch.ones_like(input_ids) + + with torch.no_grad(): + quant_out = qmodel( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + ) + fp_out = fp_model( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + ) + + _assert_close(self, quant_out.last_hidden_state, fp_out.last_hidden_state) + + def test_static_attention_mask_mapping_matches_hf_text_model(self): + """Check the static CPU-provided mask mapping path.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, + ) + + cfg = _make_text_config(layer_types=["sliding_attention", "full_attention"]) + fp_model = _make_text_model(cfg) + qmodel = QuantGemma4TextModel(fp_model).eval() + + input_ids = torch.randint(0, cfg.vocab_size, (1, 6)) + inputs_embeds = fp_model.embed_tokens(input_ids) + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0) + mask_kwargs = { + "config": cfg, + "inputs_embeds": inputs_embeds, + "attention_mask": torch.ones_like(input_ids), + "past_key_values": None, + "position_ids": position_ids, + } + mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + with torch.no_grad(): + quant_out = qmodel( + input_ids=input_ids, + attention_mask=mask_mapping, + position_ids=position_ids, + return_shared_kv_states=True, + return_dict=True, + ) + fp_out = fp_model( + input_ids=input_ids, + attention_mask=mask_mapping, + position_ids=position_ids, + return_shared_kv_states=True, + return_dict=True, + ) + + _assert_close(self, quant_out.last_hidden_state, fp_out.last_hidden_state) + 
self.assertIsNotNone(quant_out.shared_kv_states) + + def test_ple_path_matches_hf_with_input_ids(self): + """Check Hugging Face parity when Gemma4 PLE is enabled.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + cfg = _make_text_config( + layer_types=["full_attention", "full_attention"], + hidden_size_per_layer_input=8, + ) + fp_model = _make_text_model(cfg) + qmodel = QuantGemma4TextModel(fp_model).eval() + + input_ids = torch.randint(0, cfg.vocab_size, (1, 4)) + + with torch.no_grad(): + quant_out = qmodel(input_ids=input_ids, return_dict=True) + fp_out = fp_model(input_ids=input_ids, return_dict=True) + + _assert_close(self, quant_out.last_hidden_state, fp_out.last_hidden_state) + + def test_ple_path_matches_hf_with_inputs_embeds_and_explicit_per_layer_inputs(self): + """Check multimodal-style PLE entry with explicit token-identity inputs.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + cfg = _make_text_config( + layer_types=["full_attention", "full_attention"], + hidden_size_per_layer_input=8, + ) + fp_model = _make_text_model(cfg) + qmodel = QuantGemma4TextModel(fp_model).eval() + + input_ids = torch.randint(0, cfg.vocab_size, (1, 4)) + inputs_embeds = fp_model.embed_tokens(input_ids) + per_layer_inputs = fp_model.get_per_layer_inputs(input_ids, inputs_embeds) + + with torch.no_grad(): + quant_out = qmodel( + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + return_dict=True, + ) + fp_out = fp_model( + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + return_dict=True, + ) + + _assert_close(self, quant_out.last_hidden_state, fp_out.last_hidden_state) + + def test_shared_kv_text_model_returns_shared_state_when_requested(self): + """Check full text-model execution with one shared-KV consumer layer.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + cfg = 
_make_text_config( + num_hidden_layers=2, + layer_types=["full_attention", "full_attention"], + num_kv_shared_layers=1, + ) + fp_model = _make_text_model(cfg) + qmodel = QuantGemma4TextModel(fp_model).eval() + + input_ids = torch.randint(0, cfg.vocab_size, (1, 5)) + + with torch.no_grad(): + quant_out = qmodel( + input_ids=input_ids, + return_shared_kv_states=True, + return_dict=True, + ) + fp_out = fp_model( + input_ids=input_ids, + return_shared_kv_states=True, + return_dict=True, + ) + + _assert_close(self, quant_out.last_hidden_state, fp_out.last_hidden_state) + self.assertIsNotNone(quant_out.shared_kv_states) + self.assertIn("full_attention", quant_out.shared_kv_states) + + def test_validation_errors_match_expected_contract(self): + """Check user-facing validation errors for invalid input combinations.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + cfg = _make_text_config(hidden_size_per_layer_input=8) + qmodel = QuantGemma4TextModel(_make_text_model(cfg)).eval() + input_ids = torch.randint(0, cfg.vocab_size, (1, 3)) + inputs_embeds = qmodel.embed_tokens(input_ids) + per_layer_inputs = torch.randn( + 1, + 3, + cfg.num_hidden_layers, + cfg.hidden_size_per_layer_input, + ) + + with self.assertRaisesRegex(ValueError, "exactly one"): + qmodel(input_ids=input_ids, inputs_embeds=inputs_embeds) + + with self.assertRaisesRegex(ValueError, "per_layer_inputs"): + qmodel(input_ids=input_ids, per_layer_inputs=per_layer_inputs) + + qmodel.force_export = True + with self.assertRaisesRegex(NotImplementedError, "static masks"): + qmodel(input_ids=input_ids) + + def test_text_model_wrapper_does_not_own_export_adapter_hook(self): + """Check that TextModel export subgraphs are created outside this wrapper.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + qmodel = QuantGemma4TextModel(_make_text_model(_make_text_config())).eval() + + 
self.assertFalse(hasattr(qmodel, "as_export_module")) + + def test_mode_transitions_and_observer_collection(self): + """Check calibration and quantization lifecycle for the text model.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + cfg = _make_text_config(layer_types=["full_attention", "full_attention"]) + qmodel = QuantGemma4TextModel(_make_text_model(cfg), qcfg=PTQConfig()).eval() + self.assertIs(qmodel._mode, Mode.NO_QUANT) + + qmodel.enable_calibration() + self.assertIs(qmodel._mode, Mode.CALIB) + qmodel(input_ids=torch.randint(0, cfg.vocab_size, (1, 4))) + qmodel.freeze_qparams() + + self.assertIs(qmodel._mode, Mode.QUANT) + self.assertIsNotNone(qmodel.get_observer("inputs_embeds")) + self.assertIsNotNone(qmodel.get_observer("layers.0.self_attn.q_proj.act_in")) + + def test_moe_text_model_is_rejected_for_e2b_scope(self): + """Check that the E2B text-model wrapper rejects MoE configs explicitly.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + cfg = _make_text_config( + enable_moe_block=True, + num_experts=2, + top_k_experts=1, + moe_intermediate_size=16, + ) + fp_model = _make_text_model(cfg) + + with self.assertRaisesRegex(ValueError, "dense decoder layers only"): + QuantGemma4TextModel(fp_model) + + +if __name__ == "__main__": + unittest.main() diff --git a/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py b/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py index b0781fad..07590885 100644 --- a/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py +++ b/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py @@ -130,6 +130,13 @@ def _attention_mask(seq_len: int, kv_len: int | None = None) -> torch.Tensor: return torch.zeros(1, 1, seq_len, kv_len) +def _causal_mask(seq_len: int, fill_value: float = -120.0) -> torch.Tensor: + """Create an additive causal mask with a large negative upper triangle.""" + 
mask = torch.zeros(1, 1, seq_len, seq_len) + blocked = torch.full_like(mask, float(fill_value)) + return torch.triu(blocked, diagonal=1) + + def _clone_value(value: Any) -> Any: """Clone tensors nested inside a small smoke-test value.""" if isinstance(value, torch.Tensor): @@ -151,6 +158,62 @@ def _clone_forward_input(sample: ForwardInput) -> ForwardInput: ) +def _sliding_window_causal_mask( + seq_len: int, + sliding_window: int, + *, + batch_size: int = 1, + dtype: torch.dtype = torch.float32, + device: torch.device | None = None, + fill_value: float = -120.0, +) -> torch.Tensor: + """Create a fixed-shape additive causal sliding-window mask. + + A query at position ``q`` can attend to keys in the inclusive interval + ``[max(0, q - sliding_window + 1), q]``. Future keys and keys older than + the configured window receive ``fill_value``. + + Parameters + ---------- + seq_len: + Static query and key/value sequence length. + sliding_window: + Number of visible tokens including the current query token. + batch_size: + Static batch size represented by the returned mask. + dtype: + Floating-point dtype of the additive mask. + device: + Device on which to create the mask. + fill_value: + Additive value assigned to blocked positions. + + Returns + ------- + torch.Tensor + A tensor with shape ``(batch_size, 1, seq_len, seq_len)``. 
+ """ + if seq_len <= 0: + raise ValueError(f"seq_len must be positive, got {seq_len}.") + if sliding_window <= 0: + raise ValueError(f"sliding_window must be positive, got {sliding_window}.") + + query_positions = torch.arange(seq_len, device=device).view(seq_len, 1) + key_positions = torch.arange(seq_len, device=device).view(1, seq_len) + + future_positions = key_positions > query_positions + positions_before_window = key_positions < query_positions - sliding_window + 1 + blocked_positions = future_positions | positions_before_window + + mask = torch.zeros((seq_len, seq_len), dtype=dtype, device=device) + mask.masked_fill_(blocked_positions, float(fill_value)) + return ( + mask.view(1, 1, seq_len, seq_len) + .expand(batch_size, 1, seq_len, seq_len) + .contiguous() + ) + + class Gemma4BaseCase(WrapperSmokeCase): """Base class for Gemma4 E2B wrapper smoke cases.""" @@ -382,7 +445,10 @@ def _sample(self) -> ForwardInput: "position_embeddings": _text_rope( 1, self.seq_len, self.text_cfg.head_dim ), - "attention_mask": _attention_mask(self.seq_len), + "attention_mask": _causal_mask( + self.seq_len, + fill_value=float(self.ptq_config({}).attention_mask_fill_value), + ), "shared_kv_states": {}, }, ) @@ -453,6 +519,109 @@ class Gemma4TextDecoderLayerPrefillCase(Gemma4TextDecoderLayerBaseCase): export_mode = "prefill" +class Gemma4TextDecoderLayerSlidingPrefillCase(Gemma4TextDecoderLayerBaseCase): + """Smoke case for one Gemma4 sliding-attention decoder layer. + + The case creates a two-layer text configuration and selects layer zero. + This keeps layer zero as sliding attention while satisfying Gemma4's + requirement that the final decoder layer use full attention. + + The sliding window is intentionally smaller than the sequence length so + the input covers both future-token masking and left-side window masking. 
+ """ + + name = "gemma4_text_decoder_layer_sliding_prefill" + description = ( + "Quantize one tiny Gemma4 sliding-attention decoder layer with " + "a causal sliding-window mask." + ) + tags = ( + "gemma4", + "e2b", + "text", + "decoder_layer", + "prefill", + "sliding", + ) + + layer_types = ("sliding_attention", "full_attention") + layer_idx = 0 + export_mode = "prefill" + + seq_len = 8 + sliding_window = 4 + mask_fill_value = -120.0 + + def ptq_config(self, cfg: Mapping[str, Any]) -> Any: + """Build a PTQ config matching the sample mask fill value.""" + from tico.quantization.config.ptq import PTQConfig + + return PTQConfig( + model_args={"profile": "reference_eval"}, + attention_mask_fill_value=self.mask_fill_value, + ) + + def build( + self, + cfg: Mapping[str, Any], + ) -> tuple[torch.nn.Module, torch.nn.Module]: + """Build a sliding-attention decoder layer and reference copy.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer + + torch.manual_seed(123) + self.text_cfg = _make_text_config(layer_types=self.layer_types) + self.text_cfg.sliding_window = self.sliding_window + + module = Gemma4TextDecoderLayer( + self.text_cfg, + layer_idx=self.layer_idx, + ).eval() + + if not module.self_attn.is_sliding: + raise RuntimeError( + "The smoke case did not build a sliding-attention layer." + ) + if module.self_attn.sliding_window != self.sliding_window: + raise RuntimeError( + "The decoder layer does not use the requested sliding window: " + f"expected {self.sliding_window}, " + f"got {module.self_attn.sliding_window}." 
+ ) + + return module, clone_module(module) + + def _sample(self) -> ForwardInput: + """Create one fixed-shape sliding-window prefill sample.""" + batch_size = 1 + hidden = torch.randn( + batch_size, + self.seq_len, + self.text_cfg.hidden_size, + ) + attention_mask = _sliding_window_causal_mask( + self.seq_len, + self.sliding_window, + batch_size=batch_size, + dtype=hidden.dtype, + device=hidden.device, + fill_value=self.mask_fill_value, + ) + + return ForwardInput( + (), + { + "hidden_states": hidden, + "position_embeddings": _text_rope( + batch_size, + self.seq_len, + self.text_cfg.head_dim, + ), + "attention_mask": attention_mask, + "shared_kv_states": {}, + }, + ) + + class Gemma4TextDecoderLayerDecodeCase(Gemma4TextDecoderLayerBaseCase): """Smoke case for one tiny Gemma4 text decoder layer in decode mode.""" @@ -545,6 +714,64 @@ def _sample(self) -> ForwardInput: ) +class Gemma4TextModelCase(Gemma4BaseCase): + """Smoke case for one tiny dense Gemma4 text model.""" + + name = "gemma4_text_model" + description = ( + "Quantize one tiny dense Gemma4 text model with full and sliding attention." 
+ ) + tags = ("gemma4", "e2b", "text", "model") + max_mean_abs_diff = 3.0 + seq_len = 8 + + def ptq_config(self, cfg: Mapping[str, Any]) -> Any: + """Build the PTQ config used by Gemma4 text-model smoke checks.""" + from tico.quantization.config.ptq import PTQConfig + + return PTQConfig(model_args={"profile": "reference_eval"}) + + def build(self, cfg: Mapping[str, Any]) -> tuple[torch.nn.Module, torch.nn.Module]: + """Build a tiny Gemma4 text model and reference copy.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel + + torch.manual_seed(123) + self.text_cfg = _make_text_config( + layer_types=("sliding_attention", "full_attention"), + ) + module = Gemma4TextModel(self.text_cfg).eval() + return module, clone_module(module) + + def _sample(self) -> ForwardInput: + """Create one synthetic Gemma4 text-model input.""" + input_ids = torch.randint(0, self.text_cfg.vocab_size, (1, self.seq_len)) + attention_mask = torch.ones_like(input_ids) + return ForwardInput( + (), + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "return_dict": True, + }, + ) + + def calibration_inputs( + self, + prepared: torch.nn.Module, + cfg: Mapping[str, Any], + ) -> list[ForwardInput]: + """Create Gemma4 text-model calibration samples.""" + return [self._sample() for _ in range(3)] + + def eval_input( + self, + prepared: torch.nn.Module, + cfg: Mapping[str, Any], + ) -> ForwardInput: + """Create the Gemma4 text-model evaluation sample.""" + return self._sample() + + def _make_vision_config() -> Any: """Create a tiny Gemma4 vision config for synthetic smoke tests.""" from transformers.models.gemma4.configuration_gemma4 import Gemma4VisionConfig @@ -930,10 +1157,12 @@ def export_input( Gemma4TextAttentionKEqVCase(), Gemma4TextAttentionSharedKVCase(), Gemma4TextDecoderLayerPrefillCase(), + Gemma4TextDecoderLayerSlidingPrefillCase(), Gemma4TextDecoderLayerDecodeCase(), Gemma4TextDecoderLayerSharedKVCase(), Gemma4TextScaledWordEmbeddingCase(), 
Gemma4VisionPatchEmbedderCase(), + Gemma4TextModelCase(), Gemma4VisionAttentionCase(), Gemma4VisionEncoderLayerCase(), Gemma4VisionPoolerCase(), diff --git a/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py b/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py index c8c4f43d..65c6a92c 100644 --- a/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py +++ b/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""PTQ wrapper for the dense Gemma4 E2B text model.""" + from collections import UserDict -from typing import Iterable, Optional +from collections.abc import Mapping +from typing import Any, Iterable, Optional import torch import torch.nn as nn from tico.quantization.config.ptq import PTQConfig -from tico.quantization.wrapq.utils.utils import join_name +from tico.quantization.wrapq.utils.utils import get_model_arg, join_name from tico.quantization.wrapq.wrappers.gemma4.utils import assert_gemma4_e2b_no_moe from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase @@ -28,7 +31,26 @@ @try_register("transformers.models.gemma4.modeling_gemma4.Gemma4TextModel") class QuantGemma4TextModel(QuantModuleBase): - """PTQ wrapper skeleton for the Gemma4 E2B text model.""" + """Quantization wrapper for the dense Gemma4 E2B text model. + + Full and sliding causal masks are built from bounded static templates owned + by this wrapper. The forward path does not call Hugging Face mask factories. + Masked entries use ``PTQConfig.attention_mask_fill_value`` instead of the + minimum finite value of the activation dtype, which keeps affine observer + ranges usable around unmasked attention logits. + + The static template capacity can be selected with + ``PTQConfig.model_args["text"]["max_seq"]``. 
When it is omitted, the + wrapper uses the smaller of the model context limit and 2048 tokens to avoid + allocating an unexpectedly large quadratic mask. + + The E2B scope intentionally rejects MoE. Cache allocation and writes, + sampling, and CPU/NPU orchestration remain runtime responsibilities. + """ + + force_export: bool = False + _DEFAULT_STATIC_MAX_SEQ = 2048 + _SUPPORTED_LAYER_TYPES = frozenset(("full_attention", "sliding_attention")) def __init__( self, @@ -39,13 +61,22 @@ def __init__( ): assert_gemma4_e2b_no_moe(fp_model) super().__init__(qcfg, fp_name=fp_name) + self.module = fp_model self.config = fp_model.config - self.unique_layer_types = set(self.config.layer_types) - self.hidden_size_per_layer_input = getattr( - fp_model, "hidden_size_per_layer_input", None + self.padding_idx = getattr(fp_model, "padding_idx", None) + self.vocab_size = getattr(fp_model, "vocab_size", self.config.vocab_size) + self.unique_layer_types = tuple(sorted(set(self.config.layer_types))) + self.hidden_size_per_layer_input = int( + getattr(fp_model, "hidden_size_per_layer_input", 0) or 0 ) + unsupported = set(self.unique_layer_types) - self._SUPPORTED_LAYER_TYPES + if unsupported: + raise ValueError( + "Unsupported Gemma4 text layer types: " f"{sorted(unsupported)}." 
+ ) + self.embed_tokens = PTQWrapper( fp_model.embed_tokens, qcfg=qcfg.child("embed_tokens") if qcfg else None, @@ -67,6 +98,9 @@ def __init__( fp_name=join_name(fp_name, "norm"), ) self.rotary_emb = fp_model.rotary_emb + self.gradient_checkpointing = bool( + getattr(fp_model, "gradient_checkpointing", False) + ) self.embed_tokens_per_layer: Optional[nn.Module] = None self.per_layer_model_projection: Optional[nn.Module] = None @@ -89,99 +123,726 @@ def __init__( qcfg=qcfg.child("per_layer_projection_norm") if qcfg else None, fp_name=join_name(fp_name, "per_layer_projection_norm"), ) - self.per_layer_input_scale = fp_model.per_layer_input_scale - self.per_layer_model_projection_scale = ( + self.per_layer_input_scale = float(fp_model.per_layer_input_scale) + self.per_layer_model_projection_scale = float( fp_model.per_layer_model_projection_scale ) - self.obs_inputs_embeds = self._make_obs("inputs_embeds") - self.obs_per_layer_inputs = self._make_obs("per_layer_inputs") + self.static_max_seq = self._resolve_static_max_seq() + self._register_static_templates(fp_model) + + mk = self._make_obs + self.obs_inputs_embeds = mk("inputs_embeds") + self.obs_attention_masks = nn.ModuleDict( + { + layer_type: mk(f"attention_mask_{layer_type}") + for layer_type in self.unique_layer_types + } + ) + self.obs_position_cos = nn.ModuleDict( + { + layer_type: mk(f"position_embeddings_{layer_type}_cos") + for layer_type in self.unique_layer_types + } + ) + self.obs_position_sin = nn.ModuleDict( + { + layer_type: mk(f"position_embeddings_{layer_type}_sin") + for layer_type in self.unique_layer_types + } + ) + + self.obs_per_layer_token_inputs = None + self.obs_per_layer_projection = None + self.obs_per_layer_inputs = None + if self.hidden_size_per_layer_input: + self.obs_per_layer_token_inputs = mk("per_layer_token_inputs") + self.obs_per_layer_projection = mk("per_layer_projection") + self.obs_per_layer_inputs = mk("per_layer_inputs") + + def _resolve_static_max_seq(self) -> int: + 
"""Resolve the capacity of static mask, position, and RoPE templates.""" + configured = get_model_arg(self.qcfg, "text", "max_seq", default=None) + if configured is None: + configured = get_model_arg(self.qcfg, "max_seq", default=None) + + model_capacity = int(self.config.max_position_embeddings) + max_seq = ( + min(model_capacity, self._DEFAULT_STATIC_MAX_SEQ) + if configured is None + else int(configured) + ) + if max_seq <= 0: + raise ValueError(f"Gemma4 text max_seq must be positive, got {max_seq}.") + if max_seq > model_capacity: + raise ValueError( + "Gemma4 text max_seq exceeds max_position_embeddings: " + f"max_seq={max_seq}, model_capacity={model_capacity}." + ) + return max_seq + + @staticmethod + def _mask_template_name(layer_type: str) -> str: + """Return the mask-template buffer name for a layer type.""" + return f"{layer_type}_attention_mask_template" + + @staticmethod + def _cos_template_name(layer_type: str) -> str: + """Return the cosine-template buffer name for a layer type.""" + return f"{layer_type}_cos_template" + + @staticmethod + def _sin_template_name(layer_type: str) -> str: + """Return the sine-template buffer name for a layer type.""" + return f"{layer_type}_sin_template" + + def _build_full_attention_mask_template(self, device: torch.device) -> torch.Tensor: + """Build a bounded full causal mask template.""" + mask = torch.full( + (1, 1, self.static_max_seq, self.static_max_seq), + float(self.qcfg.attention_mask_fill_value), + dtype=torch.float32, + device=device, + ) + return mask.triu_(1) + + def _build_sliding_attention_mask_template( + self, device: torch.device + ) -> torch.Tensor: + """Build a bounded causal sliding-window mask template.""" + window = int(getattr(self.config, "sliding_window", 0) or 0) + if window <= 0: + raise ValueError( + "Gemma4 sliding_attention requires a positive sliding_window, " + f"got {window}." 
+ ) + + query = torch.arange(self.static_max_seq, device=device).view(-1, 1) + key = torch.arange(self.static_max_seq, device=device).view(1, -1) + keep = (key <= query) & (key > query - window) + + mask = torch.full( + (self.static_max_seq, self.static_max_seq), + float(self.qcfg.attention_mask_fill_value), + dtype=torch.float32, + device=device, + ) + mask.masked_fill_(keep, 0.0) + return mask.unsqueeze(0).unsqueeze(0) + + def _register_static_templates(self, fp_model: nn.Module) -> None: + """Register bounded masks, position ids, and per-layer-type RoPE tables.""" + embedding_weight = fp_model.embed_tokens.weight + device = embedding_weight.device + dtype = embedding_weight.dtype + + position_ids = torch.arange( + self.static_max_seq, dtype=torch.long, device=device + ).unsqueeze(0) + self.register_buffer("position_ids_template", position_ids, persistent=False) + + for layer_type in self.unique_layer_types: + if layer_type == "full_attention": + mask = self._build_full_attention_mask_template(device) + elif layer_type == "sliding_attention": + mask = self._build_sliding_attention_mask_template(device) + else: + raise AssertionError(f"Unexpected layer type: {layer_type!r}.") + self.register_buffer( + self._mask_template_name(layer_type), mask, persistent=False + ) + + dummy = torch.empty(0, device=device, dtype=dtype) + with torch.no_grad(): + for layer_type in self.unique_layer_types: + cos, sin = self.rotary_emb(dummy, position_ids, layer_type) + if cos.shape[:2] != (1, self.static_max_seq): + raise RuntimeError( + "Unexpected Gemma4 RoPE template shape: " + f"layer_type={layer_type!r}, shape={tuple(cos.shape)}." + ) + if sin.shape != cos.shape: + raise RuntimeError( + "Gemma4 RoPE sine and cosine shapes differ: " + f"cos={tuple(cos.shape)}, sin={tuple(sin.shape)}." 
+ ) + self.register_buffer( + self._cos_template_name(layer_type), cos, persistent=False + ) + self.register_buffer( + self._sin_template_name(layer_type), sin, persistent=False + ) + + @staticmethod + def _is_torch_export_context() -> bool: + """Return whether execution is inside a torch compile or export context.""" + compiler = getattr(torch, "compiler", None) + is_compiling = getattr(compiler, "is_compiling", None) + return bool(is_compiling()) if callable(is_compiling) else False + + def _requires_static_inputs(self) -> bool: + """Return whether precomputed RoPE tables should be used.""" + return bool(self.force_export or self._is_torch_export_context()) + + @staticmethod + def _unwrap_fp_module(module: nn.Module) -> nn.Module: + """Return the floating-point module hidden behind PTQWrapper layers.""" + wrapped = getattr(module, "wrapped", module) + return getattr(wrapped, "module", wrapped) + + @staticmethod + def _past_seen_tokens(past_key_values: Any) -> int: + """Return the number of tokens already stored in the cache.""" + if past_key_values is None: + return 0 + get_seq_length = getattr(past_key_values, "get_seq_length", None) + return int(get_seq_length()) if callable(get_seq_length) else 0 + + @staticmethod + def _normalize_position_ids(position_ids: torch.Tensor) -> torch.Tensor: + """Normalize position ids to shape ``(B, S)``.""" + if position_ids.dim() == 1: + return position_ids.unsqueeze(0) + if position_ids.dim() != 2: + raise ValueError( + "Gemma4 position_ids must have rank 1 or 2, " + f"got shape={tuple(position_ids.shape)}." 
+ ) + return position_ids + + def _make_position_ids( + self, + *, + inputs_embeds: torch.Tensor, + past_key_values: Any, + position_ids: Optional[torch.Tensor], + ) -> torch.Tensor: + """Create or validate position ids for the current model step.""" + batch_size, seq_len = inputs_embeds.shape[:2] + if position_ids is not None: + position_ids = self._normalize_position_ids(position_ids).to( + device=inputs_embeds.device + ) + if position_ids.size(-1) != seq_len: + raise ValueError( + "position_ids length does not match inputs_embeds: " + f"position_ids={position_ids.size(-1)}, seq_len={seq_len}." + ) + if position_ids.size(0) not in (1, batch_size): + raise ValueError( + "position_ids batch must be 1 or match inputs_embeds: " + f"position_ids={position_ids.size(0)}, batch={batch_size}." + ) + return position_ids.expand(batch_size, -1) + + start = self._past_seen_tokens(past_key_values) + end = start + seq_len + if end > self.static_max_seq: + raise ValueError( + "Gemma4 position range exceeds static_max_seq: " + f"end={end}, static_max_seq={self.static_max_seq}." + ) + return ( + self.position_ids_template[:, start:end] + .to(device=inputs_embeds.device) + .expand(batch_size, -1) + ) + + def _slice_mask_template( + self, + layer_type: str, + *, + q_len: int, + kv_len: int, + past_seen_tokens: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Slice a bounded static mask for the current query and KV spans.""" + row_end = past_seen_tokens + q_len + if row_end > self.static_max_seq or kv_len > self.static_max_seq: + raise ValueError( + "Gemma4 attention span exceeds static_max_seq: " + f"q_end={row_end}, kv_len={kv_len}, " + f"static_max_seq={self.static_max_seq}." 
+ ) + template = getattr(self, self._mask_template_name(layer_type)) + return template[..., past_seen_tokens:row_end, :kv_len].to( + device=device, dtype=dtype + ) + + @staticmethod + def _normalize_explicit_mask_shape( + mask: torch.Tensor, + *, + batch_size: int, + q_len: int, + kv_len: int, + ) -> torch.Tensor: + """Normalize an explicit mask to shape ``(B, 1, Q, K)``.""" + if mask.dim() == 3: + mask = mask.unsqueeze(1) + if mask.dim() != 4: + raise ValueError( + "Explicit Gemma4 masks must have rank 3 or 4, " + f"got shape={tuple(mask.shape)}." + ) + if mask.size(0) not in (1, batch_size): + raise ValueError( + "Explicit mask batch must be 1 or match inputs: " + f"mask_batch={mask.size(0)}, batch={batch_size}." + ) + if mask.size(1) != 1: + raise ValueError( + "Per-head Gemma4 masks are unsupported; expected head dim 1, " + f"got shape={tuple(mask.shape)}." + ) + if mask.size(-1) < kv_len: + raise ValueError( + "Explicit mask is shorter than the KV span: " + f"mask_k={mask.size(-1)}, kv_len={kv_len}." + ) + if mask.size(-1) > kv_len: + mask = mask[..., :kv_len] + if mask.size(-2) < q_len: + raise ValueError( + "Explicit mask has fewer query rows than required: " + f"mask_q={mask.size(-2)}, q_len={q_len}." 
+ ) + if mask.size(-2) > q_len: + mask = mask[..., -q_len:, :] + return mask + + def _bounded_additive_mask( + self, + mask: torch.Tensor, + *, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Map an additive mask to zero or the configured finite fill value.""" + mask = mask.to(device=device, dtype=dtype) + fill = torch.full_like(mask, float(self.qcfg.attention_mask_fill_value)) + return torch.where(mask < 0, fill, torch.zeros_like(mask)) + + def _normalize_attention_mask_for_layer( + self, + layer_type: str, + attention_mask: Optional[torch.Tensor], + *, + inputs_embeds: torch.Tensor, + past_key_values: Any, + ) -> torch.Tensor: + """Build one bounded additive mask directly from static templates.""" + batch_size, q_len = inputs_embeds.shape[:2] + past_seen_tokens = self._past_seen_tokens(past_key_values) + kv_len = past_seen_tokens + q_len + device = inputs_embeds.device + dtype = inputs_embeds.dtype + + causal_mask = self._slice_mask_template( + layer_type, + q_len=q_len, + kv_len=kv_len, + past_seen_tokens=past_seen_tokens, + device=device, + dtype=dtype, + ) + if attention_mask is None: + return causal_mask + + if attention_mask.dim() == 2: + if attention_mask.size(0) != batch_size: + raise ValueError( + "2D attention-mask batch does not match inputs: " + f"mask_batch={attention_mask.size(0)}, batch={batch_size}." + ) + if attention_mask.size(1) == q_len and past_seen_tokens > 0: + prefix = torch.ones( + batch_size, + past_seen_tokens, + device=attention_mask.device, + dtype=attention_mask.dtype, + ) + attention_mask = torch.cat((prefix, attention_mask), dim=-1) + if attention_mask.size(1) != kv_len: + raise ValueError( + "2D attention-mask length does not match KV length: " + f"mask_len={attention_mask.size(1)}, kv_len={kv_len}." 
+ ) + + keep = attention_mask.to(device=device) != 0 + padding_mask = torch.zeros( + batch_size, 1, 1, kv_len, device=device, dtype=dtype + ).masked_fill( + ~keep[:, None, None, :], + float(self.qcfg.attention_mask_fill_value), + ) + return torch.clamp( + causal_mask + padding_mask, + min=float(self.qcfg.attention_mask_fill_value), + max=0.0, + ) + + explicit = self._normalize_explicit_mask_shape( + attention_mask, + batch_size=batch_size, + q_len=q_len, + kv_len=kv_len, + ) + if torch.is_floating_point(explicit): + return self._bounded_additive_mask(explicit, device=device, dtype=dtype) + + keep = explicit.to(device=device).bool() + return torch.zeros(keep.shape, device=device, dtype=dtype).masked_fill( + ~keep, float(self.qcfg.attention_mask_fill_value) + ) + + def _create_attention_mask_mapping( + self, + *, + attention_mask: Any, + inputs_embeds: torch.Tensor, + past_key_values: Any, + ) -> dict[str, torch.Tensor]: + """Create bounded masks for every layer type without HF mask factories.""" + if isinstance(attention_mask, Mapping): + missing = [ + layer_type + for layer_type in self.unique_layer_types + if layer_type not in attention_mask + ] + if missing: + raise KeyError( + "Gemma4 mask mapping is missing layer types: " f"{missing}." + ) + return { + layer_type: self._normalize_attention_mask_for_layer( + layer_type, + attention_mask[layer_type], + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + ) + for layer_type in self.unique_layer_types + } + + # The complete TextModel is not the static NPU export unit. Preserve the + # existing export contract and require CPU-provided masks in that mode. + if self._requires_static_inputs(): + raise NotImplementedError( + "QuantGemma4TextModel static/export mode requires static masks " + "as a dict keyed by layer type." 
+ ) + + return { + layer_type: self._normalize_attention_mask_for_layer( + layer_type, + attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + ) + for layer_type in self.unique_layer_types + } + + def _observe_attention_mask_mapping( + self, mapping: Mapping[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Observe each bounded attention mask.""" + return { + layer_type: self._fq(mask, self.obs_attention_masks[layer_type]) + for layer_type, mask in mapping.items() + } + + def _make_position_embeddings( + self, + *, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Any, + use_static_templates: bool, + ) -> dict[str, tuple[torch.Tensor, torch.Tensor]]: + """Create or slice observed RoPE tables for each layer type.""" + batch_size, seq_len = hidden_states.shape[:2] + if use_static_templates: + start = self._past_seen_tokens(past_key_values) + end = start + seq_len + if end > self.static_max_seq: + raise ValueError( + "Gemma4 RoPE span exceeds static_max_seq: " + f"end={end}, static_max_seq={self.static_max_seq}." 
+ ) + + outputs = {} + for layer_type in self.unique_layer_types: + if use_static_templates: + cos = getattr(self, self._cos_template_name(layer_type))[ + :, start:end, : + ] + sin = getattr(self, self._sin_template_name(layer_type))[ + :, start:end, : + ] + cos = cos.to( + device=hidden_states.device, dtype=hidden_states.dtype + ).expand(batch_size, -1, -1) + sin = sin.to( + device=hidden_states.device, dtype=hidden_states.dtype + ).expand(batch_size, -1, -1) + else: + cos, sin = self.rotary_emb(hidden_states, position_ids, layer_type) + outputs[layer_type] = ( + self._fq(cos, self.obs_position_cos[layer_type]), + self._fq(sin, self.obs_position_sin[layer_type]), + ) + return outputs + + def _reverse_input_ids_from_embeddings( + self, inputs_embeds: torch.Tensor + ) -> torch.Tensor: + """Recover input ids from exact main embeddings for PLE compatibility.""" + embedding = self._unwrap_fp_module(self.embed_tokens) + weight = embedding.weight.to( + device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + scale = getattr(embedding, "embed_scale", None) + if scale is not None: + weight = weight * scale.to( + device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + + with torch.no_grad(): + matches = (inputs_embeds[:, :, None, :] == weight[None, None, :, :]).all( + dim=-1 + ) + indices = matches.nonzero() + expected = inputs_embeds.shape[0] * inputs_embeds.shape[1] + if indices.size(0) != expected: + raise RuntimeError( + "Gemma4 PLE could not recover input_ids from inputs_embeds. " + "Provide input_ids or explicit per_layer_inputs." + ) + try: + return indices[:, 2].view(inputs_embeds.shape[:2]) + except RuntimeError as exc: + raise RuntimeError( + "Gemma4 PLE recovered a non-rectangular input-id layout." 
+ ) from exc + + def get_per_layer_inputs( + self, + input_ids: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + """Compute the token-identity component of per-layer embeddings.""" + if not self.hidden_size_per_layer_input: + raise RuntimeError("Per-layer embeddings are disabled for this config.") + if self.embed_tokens_per_layer is None: + raise RuntimeError("Gemma4 PLE embedding is not initialized.") + if input_ids is None: + if inputs_embeds is None: + raise ValueError("inputs_embeds is required when input_ids is None.") + input_ids = self._reverse_input_ids_from_embeddings(inputs_embeds) + + result = self.embed_tokens_per_layer(input_ids).reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + if self.obs_per_layer_token_inputs is None: + raise RuntimeError("Gemma4 PLE token observer is not initialized.") + return self._fq(result, self.obs_per_layer_token_inputs) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Project and combine Gemma4 context-aware per-layer inputs.""" + if not self.hidden_size_per_layer_input: + raise RuntimeError("Per-layer projection is disabled for this config.") + projection = self.per_layer_model_projection + projection_norm = self.per_layer_projection_norm + if projection is None or projection_norm is None: + raise RuntimeError("Gemma4 PLE projection modules are not initialized.") + if self.obs_per_layer_projection is None or self.obs_per_layer_inputs is None: + raise RuntimeError("Gemma4 PLE observers are not initialized.") + + projected = projection(inputs_embeds) * self.per_layer_model_projection_scale + projected = projected.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + projected = projection_norm(projected) + projected = self._fq(projected, self.obs_per_layer_projection) + if 
per_layer_inputs is None: + return self._fq(projected, self.obs_per_layer_inputs) + combined = (projected + per_layer_inputs) * self.per_layer_input_scale + return self._fq(combined, self.obs_per_layer_inputs) + + @staticmethod + def _unwrap_layer_output(output: Any) -> tuple[torch.Tensor, Optional[Any]]: + """Extract hidden states and optional cache output from a layer result.""" + if not isinstance(output, tuple): + return output, None + if not output: + raise RuntimeError("Gemma4 decoder layer returned an empty tuple.") + return output[0], output[1] if len(output) > 1 else None + + @staticmethod + def _output_cls(): + """Return the Hugging Face Gemma4 text output class.""" + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4TextModelOutputWithPast, + ) + + return Gemma4TextModelOutputWithPast def forward( self, input_ids: Optional[torch.Tensor] = None, - attention_mask=None, + attention_mask: Any = None, position_ids: Optional[torch.Tensor] = None, - past_key_values=None, + past_key_values: Any = None, inputs_embeds: Optional[torch.Tensor] = None, per_layer_inputs: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, + return_dict: bool = True, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, **kwargs, ): - """Run the wrapped text model. - - TODO: Replace HF mask creation with CPU-provided static mask mapping in - the static runtime path. This method remains HF-compatible for wrapper - smoke tests and calibration. - """ - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("Specify exactly one of input_ids or inputs_embeds.") - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = self._fq(inputs_embeds, self.obs_inputs_embeds) + """Run Gemma4 text inference with direct bounded mask construction.""" + if (input_ids is None) == (inputs_embeds is None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds." 
+ ) + if input_ids is not None and per_layer_inputs is not None: + raise ValueError( + "You cannot specify per_layer_inputs if input_ids is provided." + ) + if per_layer_inputs is not None and not self.hidden_size_per_layer_input: + raise ValueError("per_layer_inputs was provided, but PLE is disabled.") - if self.hidden_size_per_layer_input and per_layer_inputs is None: - per_layer_inputs = self.project_per_layer_inputs(inputs_embeds) - if per_layer_inputs is not None: - per_layer_inputs = self._fq(per_layer_inputs, self.obs_per_layer_inputs) + return_shared_kv_states = bool(kwargs.pop("return_shared_kv_states", False)) + output_hidden_states = bool(output_hidden_states) + output_attentions = bool(output_attentions) + position_ids_were_provided = position_ids is not None - if position_ids is None: - position_ids = torch.arange( - inputs_embeds.shape[1], device=inputs_embeds.device - ).unsqueeze(0) + if use_cache is None: + use_cache = bool(getattr(self.config, "use_cache", False)) + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + from transformers.cache_utils import DynamicCache - if not isinstance(attention_mask, dict): - raise NotImplementedError( - "QuantGemma4TextModel expects static attention mask mapping for the first implementation." 
+ past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + else: + inputs_embeds = self._fq(inputs_embeds, self.obs_inputs_embeds) + + if self.hidden_size_per_layer_input: + if per_layer_inputs is None: + per_layer_inputs = self.get_per_layer_inputs(input_ids, inputs_embeds) + else: + if self.obs_per_layer_token_inputs is None: + raise RuntimeError("Gemma4 PLE token observer is not initialized.") + per_layer_inputs = self._fq( + per_layer_inputs, self.obs_per_layer_token_inputs + ) + per_layer_inputs = self.project_per_layer_inputs( + inputs_embeds, per_layer_inputs ) - position_embeddings = { - layer_type: self.rotary_emb(inputs_embeds, position_ids, layer_type) - for layer_type in self.unique_layer_types - } + position_ids = self._make_position_ids( + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + position_ids=position_ids, + ) + mask_mapping = self._create_attention_mask_mapping( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + ) + mask_mapping = self._observe_attention_mask_mapping(mask_mapping) hidden_states = inputs_embeds + position_embeddings = self._make_position_embeddings( + hidden_states=hidden_states, + position_ids=position_ids, + past_key_values=past_key_values, + use_static_templates=( + self._requires_static_inputs() and not position_ids_were_provided + ), + ) + shared_kv_states = kwargs.pop("shared_kv_states", UserDict()) - for i, decoder_layer in enumerate(self.layers): + if shared_kv_states is None: + shared_kv_states = UserDict() + + all_hidden_states = () if output_hidden_states else None + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore[operator] + layer_type = self.config.layer_types[i] per_layer_input = ( per_layer_inputs[:, :, i, :] if per_layer_inputs is not None 
else None ) - hidden_states = decoder_layer( + layer_output = decoder_layer( hidden_states, per_layer_input=per_layer_input, - shared_key_value=shared_kv_states.get(layer_type), + shared_kv_states=shared_kv_states, position_embeddings=position_embeddings[layer_type], - attention_mask=attention_mask[layer_type], - past_key_value=None, - use_cache=False, + attention_mask=mask_mapping[layer_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=bool(use_cache), **kwargs, ) + hidden_states, cache_output = self._unwrap_layer_output(layer_output) + if cache_output is not None and hasattr(cache_output, "get_seq_length"): + past_key_values = cache_output + hidden_states = self.norm(hidden_states) - return hidden_states + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) # type: ignore[operator] - def project_per_layer_inputs(self, inputs_embeds: torch.Tensor) -> torch.Tensor: - """Compute the context-aware per-layer input projection for Gemma4 PLE.""" - if not self.hidden_size_per_layer_input: - raise RuntimeError( - "Per-layer input projection is not enabled for this Gemma4 config." - ) - per_layer_model_projection = self.per_layer_model_projection - per_layer_projection_norm = self.per_layer_projection_norm - if per_layer_model_projection is None or per_layer_projection_norm is None: - raise RuntimeError("Gemma4 PLE projection modules are not initialized.") - per_layer_projection = ( - per_layer_model_projection(inputs_embeds) - * self.per_layer_model_projection_scale - ) - per_layer_projection = per_layer_projection.reshape( - *inputs_embeds.shape[:-1], - self.config.num_hidden_layers, - self.hidden_size_per_layer_input, + attentions = None + if not return_dict: + output: tuple[Any, ...] 
= (hidden_states,) + if use_cache: + output = output + (past_key_values,) + if output_hidden_states: + output = output + (all_hidden_states,) + if output_attentions: + output = output + (attentions,) + if return_shared_kv_states: + output = output + (shared_kv_states,) + return output + + return self._output_cls()( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=attentions, + shared_kv_states=(shared_kv_states if return_shared_kv_states else None), ) - per_layer_projection = per_layer_projection_norm(per_layer_projection) - return per_layer_projection def _all_observers(self) -> Iterable: """Return observers owned directly by this wrapper.""" - return (self.obs_inputs_embeds, self.obs_per_layer_inputs) + observers = [ + self.obs_inputs_embeds, + *tuple(self.obs_attention_masks.values()), + *tuple(self.obs_position_cos.values()), + *tuple(self.obs_position_sin.values()), + ] + observers.extend( + observer + for observer in ( + self.obs_per_layer_token_inputs, + self.obs_per_layer_projection, + self.obs_per_layer_inputs, + ) + if observer is not None + ) + return tuple(observers) diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index 45827cbe..f915faa0 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py @@ -69,7 +69,7 @@ "tico.quantization.wrapq.wrappers.gemma4.quant_text_mlp", "tico.quantization.wrapq.wrappers.gemma4.quant_text_attention", "tico.quantization.wrapq.wrappers.gemma4.quant_text_decoder_layer", - # "tico.quantization.wrapq.wrappers.gemma4.quant_text_model", + "tico.quantization.wrapq.wrappers.gemma4.quant_text_model", # "tico.quantization.wrapq.wrappers.gemma4.quant_for_causal_lm", "tico.quantization.wrapq.wrappers.gemma4.quant_vision_patch_embedder", "tico.quantization.wrapq.wrappers.gemma4.quant_vision_pooler", From ea0b7a367fbd0992f7618b99bcce8bfc60f58fd5 Mon Sep 17 
00:00:00 2001 From: seongwoo Date: Wed, 24 Jun 2026 14:30:04 +0900 Subject: [PATCH 2/3] Disallow the reverse flow in QUANT mode. --- .../wrappers/gemma4/test_quant_text_model.py | 43 +++++++++++++++++++ .../wrapq/wrappers/gemma4/quant_text_model.py | 27 +++++++++--- 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py b/test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py index dd87097c..c4fa67c8 100644 --- a/test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py +++ b/test/quantization/wrapq/wrappers/gemma4/test_quant_text_model.py @@ -367,6 +367,49 @@ def test_moe_text_model_is_rejected_for_e2b_scope(self): with self.assertRaisesRegex(ValueError, "dense decoder layers only"): QuantGemma4TextModel(fp_model) + def test_quant_mode_requires_explicit_per_layer_inputs_with_inputs_embeds(self): + """Require explicit PLE token inputs for inputs_embeds in QUANT mode.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_text_model import ( + QuantGemma4TextModel, + ) + + cfg = _make_text_config( + layer_types=["full_attention", "full_attention"], + hidden_size_per_layer_input=8, + ) + fp_model = _make_text_model(cfg) + qmodel = QuantGemma4TextModel(fp_model, qcfg=PTQConfig()).eval() + + input_ids = torch.randint(0, cfg.vocab_size, (1, 4)) + inputs_embeds = fp_model.embed_tokens(input_ids) + per_layer_inputs = fp_model.get_per_layer_inputs(input_ids, inputs_embeds) + + # Calibrate the same inputs_embeds + explicit PLE path. + qmodel.enable_calibration() + with torch.no_grad(): + qmodel( + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + return_dict=True, + ) + qmodel.freeze_qparams() + + with self.assertRaisesRegex(ValueError, "explicit per_layer_inputs"): + qmodel( + inputs_embeds=inputs_embeds, + return_dict=True, + ) + + # The supported explicit path should still work. 
+ with torch.no_grad(): + output = qmodel( + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + return_dict=True, + ) + + self.assertTrue(torch.isfinite(output.last_hidden_state).all()) + if __name__ == "__main__": unittest.main() diff --git a/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py b/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py index 65c6a92c..43061b5c 100644 --- a/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py +++ b/tico/quantization/wrapq/wrappers/gemma4/quant_text_model.py @@ -22,6 +22,7 @@ import torch.nn as nn from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.mode import Mode from tico.quantization.wrapq.utils.utils import get_model_arg, join_name from tico.quantization.wrapq.wrappers.gemma4.utils import assert_gemma4_e2b_no_moe from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper @@ -600,14 +601,14 @@ def _reverse_input_ids_from_embeddings( ) -> torch.Tensor: """Recover input ids from exact main embeddings for PLE compatibility.""" embedding = self._unwrap_fp_module(self.embed_tokens) - weight = embedding.weight.to( - device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) + weight = embedding.weight scale = getattr(embedding, "embed_scale", None) if scale is not None: - weight = weight * scale.to( - device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) + weight = weight * scale.to(device=weight.device, dtype=weight.dtype) + weight = weight.to( + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) with torch.no_grad(): matches = (inputs_embeds[:, :, None, :] == weight[None, None, :, :]).all( @@ -722,6 +723,20 @@ def forward( ) if per_layer_inputs is not None and not self.hidden_size_per_layer_input: raise ValueError("per_layer_inputs was provided, but PLE is disabled.") + # In QUANT mode, inputs_embeds may already be fake-quantized before PLE + # construction. 
Recovering token ids by exact floating-point comparison + # against the raw embedding table is therefore not a valid quantized-path + # contract. Require callers to pass explicit per_layer_inputs instead. + if ( + self._mode is Mode.QUANT + and input_ids is None + and self.hidden_size_per_layer_input + and per_layer_inputs is None + ): + raise ValueError( + "Gemma4 PLE requires explicit per_layer_inputs when " + "inputs_embeds is provided in QUANT mode." + ) return_shared_kv_states = bool(kwargs.pop("return_shared_kv_states", False)) output_hidden_states = bool(output_hidden_states) From 4b0c4c24b28274732496ffabae2a2560dc0e55bc Mon Sep 17 00:00:00 2001 From: seongwoo Date: Wed, 24 Jun 2026 14:39:19 +0900 Subject: [PATCH 3/3] explicit error. --- tico/quantization/recipes/debug/wrapper_smoke/case.py | 2 ++ .../recipes/debug/wrapper_smoke/cases/gemma4.py | 5 +++++ tico/quantization/recipes/debug/wrapper_smoke/runner.py | 6 ++++++ 3 files changed, 13 insertions(+) diff --git a/tico/quantization/recipes/debug/wrapper_smoke/case.py b/tico/quantization/recipes/debug/wrapper_smoke/case.py index 1c64c6c0..31b7bb5d 100644 --- a/tico/quantization/recipes/debug/wrapper_smoke/case.py +++ b/tico/quantization/recipes/debug/wrapper_smoke/case.py @@ -53,6 +53,8 @@ class WrapperSmokeCase: compare_reference_source: str = "reference" inplace_prepare: bool = False inplace_convert: bool = False + supports_circle_export: bool = True + circle_export_unsupported_reason: str | None = None def availability(self) -> CaseAvailability: """Return whether this case can run in the current environment.""" diff --git a/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py b/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py index 07590885..88889d5f 100644 --- a/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py +++ b/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py @@ -724,6 +724,11 @@ class Gemma4TextModelCase(Gemma4BaseCase): tags = ("gemma4", "e2b", 
"text", "model") max_mean_abs_diff = 3.0 seq_len = 8 + supports_circle_export = False + circle_export_unsupported_reason = ( + "This case validates PTQ numerical parity only. " + "Full Gemma4TextModel Circle export requires a dedicated static adapter." + ) def ptq_config(self, cfg: Mapping[str, Any]) -> Any: """Build the PTQ config used by Gemma4 text-model smoke checks.""" diff --git a/tico/quantization/recipes/debug/wrapper_smoke/runner.py b/tico/quantization/recipes/debug/wrapper_smoke/runner.py index b8b592e7..428b6681 100644 --- a/tico/quantization/recipes/debug/wrapper_smoke/runner.py +++ b/tico/quantization/recipes/debug/wrapper_smoke/runner.py @@ -246,6 +246,12 @@ def run_wrapper_smoke( if export_kind != "circle": result.passed = False result.messages.append(f"unsupported export artifact: {export_kind}") + elif not case.supports_circle_export: + result.passed = False + reason = case.circle_export_unsupported_reason or ( + f"Case '{case.name}' does not support Circle export." + ) + result.messages.append(reason) else: _export_circle(case, quantized, eval_sample, cfg, output_path, result)