From 26cc12405bd4554f8a5f3d441252cc30af5129a3 Mon Sep 17 00:00:00 2001 From: "d.savchenkov" Date: Thu, 18 Jun 2026 15:56:22 +0300 Subject: [PATCH] [quantization] Implement PTQ wrapper for Gemma4VisionModel with static export support Replace the skeleton Gemma4VisionModel wrapper with a complete implementation TICO-DCO-1.0-Signed-off-by: d.savchenkov --- .../gemma4/test_quant_vision_model.py | 367 ++++++++++++++++++ .../gemma4/test_quantize_vision_model.py | 225 +++++++++++ .../debug/wrapper_smoke/cases/gemma4.py | 110 ++++++ .../examples/gemma4/quantize_vision_model.py | 200 ++++++++++ .../wrapq/wrappers/gemma4/export_adapters.py | 37 ++ .../wrappers/gemma4/quant_vision_attention.py | 4 +- .../wrappers/gemma4/quant_vision_model.py | 318 ++++++++++++++- tico/quantization/wrapq/wrappers/registry.py | 4 +- 8 files changed, 1249 insertions(+), 16 deletions(-) create mode 100644 test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py create mode 100644 test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py create mode 100644 tico/quantization/wrapq/examples/gemma4/quantize_vision_model.py diff --git a/test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py b/test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py new file mode 100644 index 00000000..c016f095 --- /dev/null +++ b/test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py @@ -0,0 +1,367 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the Gemma4 vision model PTQ wrapper.""" + +import unittest + +import torch + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.mode import Mode + + +_SKIP_MSG = "required transformers Gemma4 vision modules are not installed" + + +def _has_gemma4_vision() -> bool: + """Return whether the installed transformers package provides Gemma4 vision.""" + try: + from transformers.models.gemma4.configuration_gemma4 import ( # noqa: F401 + Gemma4VisionConfig, + ) + from transformers.models.gemma4.modeling_gemma4 import ( # noqa: F401 + Gemma4VisionModel, + ) + except Exception: + return False + return True + + +def _make_vision_config(**overrides): + """Create a tiny Gemma4 vision config for synthetic vision model tests.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4VisionConfig + + kwargs = dict( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + patch_size=4, + position_embedding_size=8, + pooling_kernel_size=2, + attention_dropout=0.0, + max_position_embeddings=128, + rms_norm_eps=1e-6, + use_clipped_linears=False, + rope_parameters={"rope_type": "default", "rope_theta": 100.0}, + standardize=True, + ) + kwargs.update(overrides) + cfg = Gemma4VisionConfig(**kwargs) + if hasattr(cfg, "_attn_implementation"): + cfg._attn_implementation = "eager" + else: + setattr(cfg, "_attn_implementation", "eager") + return cfg + + +def _vision_position_ids(batch_size: int, num_patches: int) -> torch.Tensor: + """Create deterministic 2-D pixel position ids for a tiny patch sequence.""" + side = int(num_patches**0.5) + coords = torch.arange(num_patches) + xy = torch.stack((coords % side, coords // side), dim=-1) + return xy.unsqueeze(0).expand(batch_size, -1, -1).long() + + +@unittest.skipUnless(_has_gemma4_vision(), _SKIP_MSG) +class TestQuantGemma4VisionModel(unittest.TestCase): + """Validate Gemma4 vision model wrapper behavior.""" + + def setUp(self): + """Create deterministic test inputs.""" + torch.manual_seed(2026) + self.cfg = _make_vision_config() + self.batch_size = 1 + self.num_patches = 16 + self.patch_size = self.cfg.patch_size + self.position_embedding_size = self.cfg.position_embedding_size + + @staticmethod + def _make_vision_model(cfg=None): + """Create a floating-point Gemma4 vision model.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4VisionModel + + cfg = cfg if cfg is not None else _make_vision_config() + return Gemma4VisionModel(cfg).eval() + + def _sample_inputs(self, batch_size=None): + """Create synthetic vision model inputs. + + The HF Gemma4VisionModel expects pre-flattened patches: + pixel_values: (B, num_patches, 3*patch_size^2) + pixel_position_ids: (B, num_patches, 2) + """ + batch_size = batch_size or self.batch_size + patch_dim = 3 * self.patch_size**2 + + pixel_values = torch.randn(batch_size, self.num_patches, patch_dim) + pixel_position_ids = _vision_position_ids(batch_size, self.num_patches) + + return { + "pixel_values": pixel_values, + "pixel_position_ids": pixel_position_ids, + } + + def test_00_prepare_wraps_vision_model_when_registered(self): + """Check that registry-based prepare wraps Gemma4VisionModel.""" + from tico.quantization import prepare + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper + + fp_model = self._make_vision_model() + prepared = prepare(fp_model, PTQConfig()) + + self.assertIsInstance(prepared, PTQWrapper) + self.assertIsInstance(prepared.wrapped, QuantGemma4VisionModel) + + def test_no_quant_forward_matches_hf_vision_model(self): + """Check that the wrapper matches Hugging Face eager vision model output.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + cfg = _make_vision_config() + fp_model = self._make_vision_model(cfg) + q_model = QuantGemma4VisionModel(fp_model).eval() + sample = self._sample_inputs() + + with torch.no_grad(): + quant_out = q_model(**sample, return_dict=True) + fp_out = fp_model(**sample, return_dict=True) + + # HF model strips batch dim via hidden_states[pooler_mask], so output + # may be 2D (num_soft_tokens, hidden_size). The wrapper preserves batch dim. + self.assertEqual( + quant_out.last_hidden_state.shape[-1], fp_out.last_hidden_state.shape[-1] + ) + + def test_mode_transitions(self): + """Check lifecycle transitions: NO_QUANT → CALIB → QUANT.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + fp_model = self._make_vision_model() + q_model = QuantGemma4VisionModel(fp_model).eval() + + self.assertIs(q_model._mode, Mode.NO_QUANT) + + q_model.enable_calibration() + self.assertIs(q_model._mode, Mode.CALIB) + + sample = self._sample_inputs() + with torch.no_grad(): + _ = q_model(**sample, return_dict=True) + + q_model.freeze_qparams() + self.assertIs(q_model._mode, Mode.QUANT) + + def test_observers_are_collected(self): + """Check that _all_observers returns expected observers.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + fp_model = self._make_vision_model() + q_model = QuantGemma4VisionModel(fp_model).eval() + + all_obs = list(q_model._all_observers()) + self.assertGreaterEqual(len(all_obs), 3) + + def test_quant_mode_output_is_finite(self): + """In QUANT mode the output should be finite and have the correct shape.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + fp_model = self._make_vision_model() + q_model = QuantGemma4VisionModel(fp_model).eval() + q_model.enable_calibration() + + sample = self._sample_inputs() + with torch.no_grad(): + _ = q_model(**sample, return_dict=True) + q_model.freeze_qparams() + + with torch.no_grad(): + output = q_model(**sample, return_dict=True) + + self.assertTrue(torch.isfinite(output.last_hidden_state).all()) + # Check output has the right hidden_size dimension + self.assertEqual(output.last_hidden_state.shape[-1], self.cfg.hidden_size) + + def test_config_attributes_are_stored(self): + """Check that config attributes are stored on the wrapper.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + cfg = _make_vision_config(hidden_size=64, patch_size=8) + fp_model = self._make_vision_model(cfg) + q_model = QuantGemma4VisionModel(fp_model).eval() + + self.assertEqual(q_model.config.hidden_size, 64) + self.assertEqual(q_model.config.patch_size, 8) + + def test_standardize_buffers_are_registered(self): + """Check that std_bias and std_scale are registered when standardize=True.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + cfg = _make_vision_config(standardize=True) + fp_model = self._make_vision_model(cfg) + q_model = QuantGemma4VisionModel(fp_model).eval() + + self.assertTrue(hasattr(q_model, "std_bias")) + self.assertTrue(hasattr(q_model, "std_scale")) + self.assertIsInstance(q_model.std_bias, torch.Tensor) + self.assertIsInstance(q_model.std_scale, torch.Tensor) + self.assertEqual(q_model.std_bias.shape[0], cfg.hidden_size) + self.assertEqual(q_model.std_scale.shape[0], cfg.hidden_size) + + def test_standardize_false_no_buffers(self): + """Check that std_bias/std_scale are NOT registered when standardize=False.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + cfg = _make_vision_config(standardize=False) + fp_model = self._make_vision_model(cfg) + q_model = QuantGemma4VisionModel(fp_model).eval() + + self.assertFalse(hasattr(q_model, "std_bias")) + self.assertFalse(hasattr(q_model, "std_scale")) + + def test_as_export_module_requires_quant_mode(self): + """as_export_module should assert that mode is QUANT.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + fp_model = self._make_vision_model() + q_model = QuantGemma4VisionModel(fp_model).eval() + + # Should fail in NO_QUANT mode + with self.assertRaises(AssertionError): + q_model.as_export_module(mode="prefill", pixel_position_ids=None) + + def test_as_export_module_requires_standardize(self): + """as_export_module should assert that config.standardize is True.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + cfg = _make_vision_config(standardize=True) + fp_model = self._make_vision_model(cfg) + q_model = QuantGemma4VisionModel(fp_model).eval() + q_model.enable_calibration() + + sample = self._sample_inputs() + with torch.no_grad(): + _ = q_model(**sample, return_dict=True) + q_model.freeze_qparams() + + # Should succeed with standardize=True + export_module = q_model.as_export_module( + mode="prefill", + pixel_position_ids=sample["pixel_position_ids"], + ) + self.assertIsNotNone(export_module) + + def test_forward_export_via_as_export_module(self): + """Test the export flow via as_export_module which sets up export adapters.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + cfg = _make_vision_config(standardize=True) + fp_model = self._make_vision_model(cfg) + q_model = QuantGemma4VisionModel(fp_model).eval() + q_model.enable_calibration() + + sample = self._sample_inputs() + with torch.no_grad(): + _ = q_model(**sample, return_dict=True) + q_model.freeze_qparams() + + # as_export_module sets up export adapters + export_module = q_model.as_export_module( + mode="prefill", + pixel_position_ids=sample["pixel_position_ids"], + ) + + # Test forward (adapter delegates to wrapped_model.forward_export) + with torch.no_grad(): + output = export_module(**sample) + + self.assertTrue(torch.isfinite(output.last_hidden_state).all()) + + def test_as_export_module_creates_export_adapter_attributes(self): + """as_export_module should create patch_embedder_export and pooler_export attributes.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + cfg = _make_vision_config(standardize=True) + fp_model = self._make_vision_model(cfg) + q_model = QuantGemma4VisionModel(fp_model).eval() + q_model.enable_calibration() + + sample = self._sample_inputs() + with torch.no_grad(): + _ = q_model(**sample, return_dict=True) + q_model.freeze_qparams() + + # Before as_export_module, no export adapter attributes + self.assertFalse(hasattr(q_model, "patch_embedder_export")) + self.assertFalse(hasattr(q_model, "pooler_export")) + + q_model.as_export_module( + mode="prefill", + pixel_position_ids=sample["pixel_position_ids"], + ) + + # After as_export_module, export adapter attributes should exist + self.assertTrue(hasattr(q_model, "patch_embedder_export")) + self.assertTrue(hasattr(q_model, "pooler_export")) + + # Original wrappers should still be intact (not mutated) + from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper + + self.assertIsInstance(q_model.patch_embedder, PTQWrapper) + self.assertIsInstance(q_model.pooler, PTQWrapper) + + def test_submodules_are_wrapped(self): + """Check that patch_embedder, encoder, and pooler are wrapped.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper + + fp_model = self._make_vision_model() + q_model = QuantGemma4VisionModel(fp_model).eval() + + self.assertIsInstance(q_model.patch_embedder, PTQWrapper) + self.assertIsInstance(q_model.encoder, PTQWrapper) + self.assertIsInstance(q_model.pooler, PTQWrapper) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py b/test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py new file mode 100644 index 00000000..6f7a7380 --- /dev/null +++ b/test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py @@ -0,0 +1,225 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Smoke tests for Gemma4 vision model prepare-calibrate-convert flow.""" + +import copy +import os +import unittest + +import torch + +from tico.quantization import convert, prepare +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper + + +IS_INTERNAL_TEST = os.environ.get("RUN_INTERNAL_TESTS", "0") == "1" +_SKIP_MSG = "required transformers Gemma4 vision modules are not installed" + + +def _has_gemma4_vision() -> bool: + """Return whether the installed transformers package provides Gemma4 vision.""" + try: + from transformers.models.gemma4.configuration_gemma4 import ( # noqa: F401 + Gemma4VisionConfig, + ) + from transformers.models.gemma4.modeling_gemma4 import ( # noqa: F401 + Gemma4VisionModel, + ) + except Exception: + return False + return True + + +def _make_vision_config(**overrides): + """Create a tiny Gemma4 vision config for synthetic smoke tests.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4VisionConfig + + kwargs = dict( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + patch_size=4, + position_embedding_size=8, + pooling_kernel_size=2, + attention_dropout=0.0, + max_position_embeddings=128, + rms_norm_eps=1e-6, + use_clipped_linears=False, + rope_parameters={"rope_type": "default", "rope_theta": 100.0}, + standardize=True, + ) + kwargs.update(overrides) + cfg = Gemma4VisionConfig(**kwargs) + if not hasattr(cfg, "_attn_implementation"): + setattr(cfg, "_attn_implementation", "eager") + else: + cfg._attn_implementation = "eager" + return cfg + + +def _vision_position_ids(batch_size: int, num_patches: int) -> torch.Tensor: + """Create deterministic 2-D pixel position ids for a tiny patch sequence.""" + side = int(num_patches**0.5) + coords = torch.arange(num_patches) + xy = torch.stack((coords % side, coords // side), dim=-1) + return xy.unsqueeze(0).expand(batch_size, -1, -1).long() + + +@unittest.skipIf( + not IS_INTERNAL_TEST, + "Internal smoke test — set RUN_INTERNAL_TESTS=1 to enable it.", +) +@unittest.skipUnless(_has_gemma4_vision(), _SKIP_MSG) +class TestGemma4VisionModelSmoke(unittest.TestCase): + """Exercise Gemma4 vision model wrapper parity and PTQ flow.""" + + def setUp(self): + """Create deterministic tiny Gemma4 vision model modules.""" + torch.manual_seed(2026) + from transformers.models.gemma4.modeling_gemma4 import Gemma4VisionModel + + self.cfg = _make_vision_config() + self.fp_model = Gemma4VisionModel(self.cfg).eval() + self.fp_ref = copy.deepcopy(self.fp_model).eval() + # For 16 patches with pooling_kernel_size=2: output_length = 16 / 4 = 4 + self.num_patches = 16 + self.patch_size = self.cfg.patch_size + self.batch_size = 1 + + def _sample(self): + """Create one synthetic Gemma4 vision model sample. + + The HF Gemma4VisionModel expects pre-flattened patches: + pixel_values: (B, num_patches, 3*patch_size^2) + pixel_position_ids: (B, num_patches, 2) + """ + patch_dim = 3 * self.patch_size**2 + return { + "pixel_values": torch.randn(self.batch_size, self.num_patches, patch_dim), + "pixel_position_ids": _vision_position_ids( + self.batch_size, self.num_patches + ), + } + + def test_no_quant_vision_model_matches_reference(self): + """The wrapper should match the floating-point module before quantization.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + wrapped = QuantGemma4VisionModel(self.fp_model, qcfg=PTQConfig()).eval() + sample = self._sample() + + with torch.no_grad(): + quant_out = wrapped(**sample, return_dict=True) + fp_out = self.fp_ref(**sample, return_dict=True) + + self.assertEqual( + quant_out.last_hidden_state.shape, fp_out.last_hidden_state.shape + ) + self.assertTrue( + torch.allclose( + quant_out.last_hidden_state, + fp_out.last_hidden_state, + atol=1e-5, + rtol=1e-5, + ) + ) + + def test_prepare_convert_vision_model_flow(self): + """Quantize Gemma4 vision model and validate a synthetic output.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + + prepared = prepare(self.fp_model, PTQConfig()) + self.assertIsInstance(prepared, PTQWrapper) + self.assertIsInstance(prepared.wrapped, QuantGemma4VisionModel) + + with torch.no_grad(): + for _ in range(3): + prepared(**self._sample(), return_dict=True) + + quantized = convert(prepared) + self.assertIs(quantized._mode, Mode.QUANT) + + sample = self._sample() + with torch.no_grad(): + quant_out = quantized(**sample, return_dict=True) + fp_out = self.fp_ref(**sample, return_dict=True) + + self.assertEqual( + quant_out.last_hidden_state.shape, fp_out.last_hidden_state.shape + ) + self.assertTrue(torch.isfinite(quant_out.last_hidden_state).all()) + + def test_as_export_module_flow(self): + """Test the as_export_module flow for Circle export.""" + from tico.quantization.wrapq.wrappers.gemma4.export_adapters import ( + Gemma4VisionModelPrefillExportAdapter, + ) + + prepared = prepare(self.fp_model, PTQConfig()) + + with torch.no_grad(): + for _ in range(3): + prepared(**self._sample(), return_dict=True) + + quantized = convert(prepared) + + sample = self._sample() + export_module = quantized.wrapped.as_export_module( + mode="prefill", + pixel_position_ids=sample["pixel_position_ids"], + ) + + # as_export_module returns Gemma4VisionModelPrefillExportAdapter + self.assertIsInstance(export_module, Gemma4VisionModelPrefillExportAdapter) + + # Verify forward works (adapter delegates to wrapped_model.forward_export) + with torch.no_grad(): + output = export_module(**sample) + + self.assertTrue(torch.isfinite(output.last_hidden_state).all()) + + def test_vision_model_with_standardize_false(self): + """Test vision model when standardize=False.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_vision_model import ( + QuantGemma4VisionModel, + ) + from transformers.models.gemma4.modeling_gemma4 import Gemma4VisionModel + + cfg = _make_vision_config(standardize=False) + fp_model = Gemma4VisionModel(cfg).eval() + wrapped = QuantGemma4VisionModel(fp_model, qcfg=PTQConfig()).eval() + + sample = self._sample() + with torch.no_grad(): + quant_out = wrapped(**sample, return_dict=True) + fp_out = self.fp_ref(**sample, return_dict=True) + + self.assertEqual( + quant_out.last_hidden_state.shape, fp_out.last_hidden_state.shape + ) + self.assertTrue(torch.isfinite(quant_out.last_hidden_state).all()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py b/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py index b0781fad..a9310ef2 100644 --- a/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py +++ b/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py @@ -561,6 +561,7 @@ def _make_vision_config() -> Any: rms_norm_eps=1e-6, use_clipped_linears=False, rope_parameters={"rope_type": "default", "rope_theta": 100.0}, + standardize=True, ) return _set_eager_attention(cfg) @@ -923,6 +924,114 @@ def export_input( return ForwardInput((hidden, pixel_position_ids, padding_positions), {}) +class Gemma4VisionModelCase(Gemma4BaseCase): + """Smoke case for one tiny Gemma4 vision model.""" + + name = "gemma4_vision_model" + description = ( + "Quantize one tiny Gemma4 vision model (patch_embedder + encoder + pooler)." + ) + tags = ("gemma4", "e2b", "vision", "model") + max_mean_abs_diff = 3.0 + # seq_len=36 and output_length=4 so that k=2 (36 / 3^2 = 4, sqrt(4) = 2). + seq_len = 36 + + def build(self, cfg: Mapping[str, Any]) -> tuple[torch.nn.Module, torch.nn.Module]: + """Build a tiny Gemma4 vision model and reference copy.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4VisionModel + + torch.manual_seed(123) + self.vision_cfg = _make_vision_config() + module = Gemma4VisionModel(self.vision_cfg).eval() + return module, clone_module(module) + + def _sample(self) -> ForwardInput: + """Create one synthetic Gemma4 vision model input. + + The HF Gemma4VisionModel expects pre-flattened patches: + pixel_values: (B, num_patches, 3*patch_size^2) + pixel_position_ids: (B, num_patches, 2) + """ + batch_size = 1 + patch_size = self.vision_cfg.patch_size + patch_dim = 3 * patch_size**2 + pixel_values = torch.randn(batch_size, self.seq_len, patch_dim) + pixel_position_ids = _pixel_position_ids(batch_size, self.seq_len) + return ForwardInput( + (), + { + "pixel_values": pixel_values, + "pixel_position_ids": pixel_position_ids, + "return_dict": True, + }, + ) + + def forward(self, module: torch.nn.Module, sample: ForwardInput) -> Any: + """Run a Gemma4 vision model without sharing mutable sample state.""" + cloned = _clone_forward_input(sample) + output = module(*cloned.args, **dict(cloned.kwargs)) + if hasattr(output, "last_hidden_state"): + return output.last_hidden_state + return output + + def reference_forward( + self, reference: torch.nn.Module, sample: ForwardInput + ) -> Any: + """Run the original Gemma4 vision model without sharing mutable sample state.""" + cloned = _clone_forward_input(sample) + output = reference(*cloned.args, **dict(cloned.kwargs)) + if hasattr(output, "last_hidden_state"): + return output.last_hidden_state + return output + + def calibration_inputs( + self, + prepared: torch.nn.Module, + cfg: Mapping[str, Any], + ) -> list[ForwardInput]: + """Create Gemma4 vision model calibration samples.""" + return [self._sample() for _ in range(3)] + + def eval_input( + self, + prepared: torch.nn.Module, + cfg: Mapping[str, Any], + ) -> ForwardInput: + """Create the Gemma4 vision model evaluation sample.""" + return self._sample() + + def export_module( + self, quantized: torch.nn.Module, cfg: Mapping[str, Any] + ) -> torch.nn.Module: + """Export the wrapped vision model in prefill mode. + + Passes ``pixel_position_ids`` so the pooler's export adapter can + precompute the pooling weight matrix and output mask at construction + time. + """ + wrapped = getattr(quantized, "wrapped", quantized) + if hasattr(wrapped, "as_export_module"): + pixel_pos_ids = _pixel_position_ids(1, self.seq_len) + return wrapped.as_export_module( + mode="prefill", + pixel_position_ids=pixel_pos_ids, + ).eval() + return quantized + + def export_input( + self, eval_sample: ForwardInput, cfg: Mapping[str, Any] + ) -> ForwardInput: + """Create static export inputs expected by the vision model adapter. + + The export adapter's forward() takes pixel_values and pixel_position_ids. + """ + cloned = _clone_forward_input(eval_sample) + kwargs = dict(cloned.kwargs) + pixel_values = kwargs["pixel_values"] + pixel_position_ids = kwargs["pixel_position_ids"] + return ForwardInput((pixel_values, pixel_position_ids), {}) + + GEMMA4_CASES = ( Gemma4TextMLPCase(), Gemma4TextAttentionCase(), @@ -937,4 +1046,5 @@ def export_input( Gemma4VisionAttentionCase(), Gemma4VisionEncoderLayerCase(), Gemma4VisionPoolerCase(), + Gemma4VisionModelCase(), ) diff --git a/tico/quantization/wrapq/examples/gemma4/quantize_vision_model.py b/tico/quantization/wrapq/examples/gemma4/quantize_vision_model.py new file mode 100644 index 00000000..545959b6 --- /dev/null +++ b/tico/quantization/wrapq/examples/gemma4/quantize_vision_model.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example: PTQ quantization of Gemma4VisionModel. + +The Gemma4 vision model encodes image pixels into visual soft tokens through +a pipeline of patch embedding, transformer encoding, spatial pooling, and +optional standardization. It accepts: + +- ``pixel_values``: Pre-flattened image patches of shape ``(B, num_patches, 3*patch_size^2)`` +- ``pixel_position_ids``: 2D grid coordinates of shape ``(B, num_patches, 2)`` + +The output is a ``BaseModelOutputWithPast`` whose ``last_hidden_state`` contains +visual soft tokens after pooling and standardization. + +This script demonstrates the full PTQ flow: + +1. Create a tiny Gemma4VisionModel with random weights (no download needed). +2. Prepare the model for quantization. +3. Calibrate with synthetic data. +4. Convert to a fake-quantized model. +5. Compare FP vs. quantized outputs. +6. Export and convert to Circle format. +""" + +import copy +import sys + +import torch + +import tico +import tico.quantization +import tico.quantization.config.ptq +from tico.quantization.evaluation.metric import compute_peir +from tico.quantization.evaluation.utils import plot_two_outputs +from tico.quantization.wrapq.utils.version import has_transformers_for + +torch.manual_seed(123) + + +# Check if transformers is available +if not has_transformers_for("gemma4"): + print( + "Error: transformers package with Gemma4 support not installed. " + "Cannot test Gemma4VisionModel." + ) + sys.exit(1) + +from transformers.models.gemma4.configuration_gemma4 import Gemma4VisionConfig +from transformers.models.gemma4.modeling_gemma4 import Gemma4VisionModel + + +def _pixel_position_ids(batch_size: int, num_patches: int) -> torch.Tensor: + """Create deterministic 2D pixel position ids for a patch grid. + + The vision model requires ``pixel_position_ids`` with shape ``(B, num_patches, 2)`` + where the last dimension encodes ``(x, y)`` patch coordinates. We build a + simple square grid layout. + """ + side = int(num_patches**0.5) + coords = torch.arange(num_patches) + xy = torch.stack((coords % side, coords // side), dim=-1) + return xy.unsqueeze(0).expand(batch_size, -1, -1).long() + + +def generate_calibration_data( + batch_size: int, + num_patches: int, + patch_size: int, + num_samples: int = 20, +) -> list[dict]: + """Generate calibration data for PTQ. + + Each sample is a dict of keyword arguments matching the vision model's forward + signature: ``pixel_values``, ``pixel_position_ids``. + """ + patch_dim = 3 * patch_size**2 + calibration_data = [] + for _ in range(num_samples): + sample = { + "pixel_values": torch.randn(batch_size, num_patches, patch_dim), + "pixel_position_ids": _pixel_position_ids(batch_size, num_patches), + "return_dict": True, + } + calibration_data.append(sample) + return calibration_data + + +def main(): + # Create the vision model with a tiny config (no download needed). + cfg = Gemma4VisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + patch_size=4, + position_embedding_size=8, + pooling_kernel_size=2, + attention_dropout=0.0, + max_position_embeddings=128, + rms_norm_eps=1e-6, + use_clipped_linears=False, + rope_parameters={"rope_type": "default", "rope_theta": 100.0}, + standardize=True, + ) + if not hasattr(cfg, "_attn_implementation"): + setattr(cfg, "_attn_implementation", "eager") + else: + cfg._attn_implementation = "eager" + + model = Gemma4VisionModel(cfg) + orig_model = copy.deepcopy(model) + model.eval() + + # Generate calibration data + batch_size = 1 + num_patches = 16 + patch_size = cfg.patch_size + + calibration_data = generate_calibration_data( + batch_size=batch_size, + num_patches=num_patches, + patch_size=patch_size, + num_samples=20, + ) + + # Configure PTQ + ptq_config = tico.quantization.config.ptq.PTQConfig() + + # Prepare the model for quantization + prepared_model = tico.quantization.prepare( + model, ptq_config, inplace=True # Transform the model in place + ) + + # Calibrate the model (collect statistics) + print("Calibrating...") + with torch.no_grad(): + for sample in calibration_data: + prepared_model(**sample) + + # Convert to quantized model + print("Converting to quantized model...") + quantized_model = tico.quantization.convert(prepared_model, inplace=True) + + # Compute PEIR (Peak Error-to-Interval Ratio) between quantized model and original model + eval_sample = calibration_data[0] + with torch.no_grad(): + quant_out = quantized_model(**eval_sample).last_hidden_state + fp_out = orig_model(**eval_sample).last_hidden_state + + print(f"\n┌───────────── Quantization Error Summary ─────────────") + print(f"│ FP output shape : {tuple(fp_out.shape)}") + print(f"│ Quant output shape : {tuple(quant_out.shape)}") + print(f"│ Mean |diff| : {(quant_out - fp_out).abs().mean().item():.6f}") + print(f"│ PEIR : {compute_peir(fp_out, quant_out) * 100:.6f} %") + print(f"└──────────────────────────────────────────────────────") + print(plot_two_outputs(fp_out, quant_out)) + + # Export and convert to Circle format. + # The vision model's as_export_module requires pixel_position_ids to + # precompute the pooler's static weight matrix. + wrapped = getattr(quantized_model, "wrapped", quantized_model) + if hasattr(wrapped, "as_export_module"): + pixel_pos_ids = _pixel_position_ids(batch_size, num_patches) + export_module = wrapped.as_export_module( + mode="prefill", + pixel_position_ids=pixel_pos_ids, + ).eval() + + example_inputs = ( + torch.randn(batch_size, num_patches, 3 * patch_size**2), # pixel_values + _pixel_position_ids(batch_size, num_patches), # pixel_position_ids + ) + + print("\nConverting to Circle format...") + circle_model = tico.convert(export_module, example_inputs) + + filename = "gemma4_vision_model.q.circle" + circle_model.save(filename) + print(f"Circle model saved as '{filename}'") + else: + print("Note: as_export_module not available; skipping Circle export.") + + +if __name__ == "__main__": + main() diff --git a/tico/quantization/wrapq/wrappers/gemma4/export_adapters.py b/tico/quantization/wrapq/wrappers/gemma4/export_adapters.py index fb16c63c..ee8a7d30 100644 --- a/tico/quantization/wrapq/wrappers/gemma4/export_adapters.py +++ b/tico/quantization/wrapq/wrappers/gemma4/export_adapters.py @@ -322,3 +322,40 @@ def __init__(self, wrapped_conditional_generation_model: nn.Module): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Return vocabulary logits for the final hidden state.""" return self.lm_head(self.norm(hidden_states)) + + +class Gemma4VisionModelPrefillExportAdapter(nn.Module): + """Export adapter for the Gemma4 vision model with static-shape contract. + + This adapter wraps a ``QuantGemma4VisionModel`` that has been prepared + for export via ``as_export_module()``. Calling ``forward()`` delegates + to the wrapped model's ``forward_export()`` method, which uses the + export-friendly submodule adapters (``patch_embedder_export``, + ``pooler_export``) when they are available. + + Input contract: + ``pixel_values`` has shape ``(1, num_patches, 3*patch_size^2)``. + ``pixel_position_ids`` has shape ``(1, num_patches, 2)``. + + Output contract: + Returns ``BaseModelOutputWithPast`` with ``last_hidden_state`` + containing visual soft tokens. + """ + + def __init__( + self, + wrapped_model: nn.Module, + ): + super().__init__() + self.wrapped_model = wrapped_model + + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_position_ids: torch.LongTensor, + ): + """Run the vision model export path via the wrapped model's forward_export.""" + return self.wrapped_model.forward_export( + pixel_values=pixel_values, + pixel_position_ids=pixel_position_ids, + ) diff --git a/tico/quantization/wrapq/wrappers/gemma4/quant_vision_attention.py b/tico/quantization/wrapq/wrappers/gemma4/quant_vision_attention.py index 6e4e83b8..1352aee0 100644 --- a/tico/quantization/wrapq/wrappers/gemma4/quant_vision_attention.py +++ b/tico/quantization/wrapq/wrappers/gemma4/quant_vision_attention.py @@ -359,8 +359,10 @@ def _build_attention_mask( dtype=dtype, device=device, ) + # Use ``== False`` instead of ``~`` to avoid ``aten::bitwise_not`` + # which is not supported by the Circle conversion pipeline. additive = additive.masked_fill( - ~keep_mask, + keep_mask == False, float(self.qcfg.attention_mask_fill_value), ) return self._fq(additive, self.obs_attn_mask) diff --git a/tico/quantization/wrapq/wrappers/gemma4/quant_vision_model.py b/tico/quantization/wrapq/wrappers/gemma4/quant_vision_model.py index b970147f..4ad431d7 100644 --- a/tico/quantization/wrapq/wrappers/gemma4/quant_vision_model.py +++ b/tico/quantization/wrapq/wrappers/gemma4/quant_vision_model.py @@ -18,7 +18,11 @@ import torch.nn as nn from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.mode import Mode from tico.quantization.wrapq.utils.utils import join_name +from tico.quantization.wrapq.wrappers.llama.export_adapters import ( + register_fake_quant_meta_kernels_for_dynamic_export, +) from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase from tico.quantization.wrapq.wrappers.registry import try_register @@ -26,7 +30,20 @@ @try_register("transformers.models.gemma4.modeling_gemma4.Gemma4VisionModel") class QuantGemma4VisionModel(QuantModuleBase): - """PTQ wrapper skeleton for the Gemma4 vision model.""" + """PTQ wrapper for the Gemma4 vision model. + + This wrapper supports two modes: + 1. Runtime mode (forward): Supports dynamic tensor shapes and conditional branching + (config.standardize can be True or False). Not exportable. + 2. Export mode (forward_export): Static tensor shapes, no conditional branching. + Assumes config.standardize=True. Exportable via torch.export. + + The vision model encodes image pixels into visual soft tokens through: + - Patch embedder: Projects pixels to patch embeddings with position encoding + - Encoder: Processes embeddings through transformer layers + - Pooler: Reduces spatial dimension to fixed number of soft tokens + - Standardization: Applies learned std_bias and std_scale (if enabled) + """ def __init__( self, @@ -38,6 +55,10 @@ def __init__( super().__init__(qcfg, fp_name=fp_name) self.module = fp_model self.config = fp_model.config + + # Wrap submodules with PTQWrapper + # Note: These will use specialized wrappers (QuantGemma4VisionPatchEmbedder, + # QuantGemma4VisionEncoder, QuantGemma4VisionPooler) if registered self.patch_embedder = PTQWrapper( fp_model.patch_embedder, qcfg=qcfg.child("patch_embedder") if qcfg else None, @@ -53,30 +74,301 @@ def __init__( qcfg=qcfg.child("pooler") if qcfg else None, fp_name=join_name(fp_name, "pooler"), ) + + # Register std_bias and std_scale as buffers if standardize is enabled + if self.config.standardize: + self.register_buffer( + "std_bias", + fp_model.std_bias.clone() + if hasattr(fp_model, "std_bias") + else torch.empty(self.config.hidden_size), + persistent=False, + ) + self.register_buffer( + "std_scale", + fp_model.std_scale.clone() + if hasattr(fp_model, "std_scale") + else torch.empty(self.config.hidden_size), + persistent=False, + ) + + # Observers + self.obs_minus_bias = self._make_obs("minus_bias") + self.obs_strip_padding = self._make_obs("strip_padding") self.obs_last_hidden_state = self._make_obs("last_hidden_state") + self.obs_std_bias = ( + self._make_obs("std_bias") if self.config.standardize else None + ) + self.obs_std_scale = ( + self._make_obs("std_scale") if self.config.standardize else None + ) + + def enable_calibration(self) -> None: + """Enable calibration and collect static weight ranges.""" + super().enable_calibration() + # Collect std_bias and std_scale statistics if standardize is enabled + if ( + self.config.standardize + and self.obs_std_bias is not None + and self.obs_std_scale is not None + ): + self.obs_std_bias.collect(self.std_bias) + self.obs_std_scale.collect(self.std_scale) def forward( self, pixel_values: torch.Tensor, - pixel_position_ids: Optional[torch.Tensor] = None, + pixel_position_ids: torch.Tensor, **kwargs, ): - """Run Gemma4 vision model. + """Run Gemma4 vision model with dynamic shapes and conditional branching. + + This forward method supports: + - Dynamic output_length computation from pixel_values shape + - Conditional standardization based on config.standardize + + This method is NOT exportable via torch.export due to dynamic operations. + Use forward_export() for export. - TODO: Replace fallback pieces with a fully static implementation once the - patch embedder and pooler wrappers are complete. + Args: + pixel_values: Image pixels with shape (batch, channels, height, width) + or list of (1, channels, height, width) for variable sizes. + pixel_position_ids: Patch positions with shape (batch_size, max_patches, 2). + Padding patches are indicated by (-1, -1). + **kwargs: Additional arguments passed to encoder. + + Returns: + BaseModelOutputWithPast with last_hidden_state containing visual soft tokens. """ - outputs = self.module( - pixel_values=pixel_values, + from transformers.models.gemma4.modeling_gemma4 import BaseModelOutputWithPast + + # Compute output_length dynamically + pooling_kernel_size = self.config.pooling_kernel_size + output_length = pixel_values.shape[-2] // ( + pooling_kernel_size * pooling_kernel_size + ) + + # Create padding mask from pixel_position_ids + padding_positions = (pixel_position_ids == -1).all(dim=-1) + + # Patch embedder + inputs_embeds = self.patch_embedder( + pixel_values, pixel_position_ids, padding_positions + ) + + # Encoder + output = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~padding_positions, # encoder expects True=valid pixel_position_ids=pixel_position_ids, - return_dict=True, **kwargs, ) - outputs.last_hidden_state = self._fq( - outputs.last_hidden_state, self.obs_last_hidden_state + + # The encoder may return a BaseModelOutputWithPast (HF original) or a + # plain tensor (QuantGemma4VisionEncoder wrapper). Handle both cases. + if isinstance(output, torch.Tensor): + encoder_hidden = output + else: + encoder_hidden = output.last_hidden_state + + # Pooler + hidden_states, pooler_mask = self.pooler( + hidden_states=encoder_hidden, + pixel_position_ids=pixel_position_ids, + padding_positions=padding_positions, + output_length=output_length, + ) + + # Strip padding tokens + hidden_states = hidden_states[pooler_mask] + hidden_states = self._fq(hidden_states, self.obs_strip_padding) + + # Standardization (conditional based on config) + if self.config.standardize: + std_bias = self.std_bias + std_scale = self.std_scale + if self._mode is Mode.QUANT: + assert self.obs_std_bias is not None + assert self.obs_std_scale is not None + std_bias = self.obs_std_bias.fake_quant(std_bias) + std_scale = self.obs_std_scale.fake_quant(std_scale) + hidden_states = hidden_states - std_bias.float() + hidden_states = self._fq(hidden_states, self.obs_minus_bias) + hidden_states = hidden_states * std_scale.float() + + # Cast to input dtype + hidden_states = hidden_states.to(inputs_embeds.dtype) + + # Quantize output + hidden_states = self._fq(hidden_states, self.obs_last_hidden_state) + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + def forward_export( + self, + pixel_values: torch.Tensor, + pixel_position_ids: torch.Tensor, + ): + """Run Gemma4 vision model with static shapes for torch.export. + + This forward method assumes: + - config.standardize is True (std_bias and std_scale are always applied) + - output_length is precomputed and fixed + - No conditional branching + - as_export_module() has been called to set up export adapters + + This method IS exportable via torch.export. + + Args: + pixel_values: Image pixels with shape (batch, channels, height, width). + pixel_position_ids: Patch positions with shape (batch_size, max_patches, 2). + + Returns: + BaseModelOutputWithPast with last_hidden_state containing visual soft tokens. + """ + from transformers.models.gemma4.modeling_gemma4 import BaseModelOutputWithPast + + # Create padding mask from pixel_position_ids + padding_positions = self.padding_positions + + # Patch embedder (use export adapter if available, otherwise original) + patch_embedder = self.patch_embedder_export + inputs_embeds = patch_embedder( + pixel_values, pixel_position_ids, padding_positions + ) + + # Encoder (no export adapter yet — uses original wrapper) + output = self.encoder( + inputs_embeds=inputs_embeds, + # Use ``== False`` instead of ``~`` to avoid ``aten::bitwise_not`` + # which is not supported by the Circle conversion pipeline. + attention_mask=(padding_positions == False), + pixel_position_ids=pixel_position_ids, + return_dict=True, + ) + + # The encoder may return a BaseModelOutputWithPast (HF original) or a + # plain tensor (QuantGemma4VisionEncoder wrapper). Handle both cases. + encoder_hidden = output + + # Pooler (use export adapter if available, otherwise original) + pooler = self.pooler_export + hidden_states, pooler_mask = pooler( + hidden_states=encoder_hidden, + pixel_position_ids=pixel_position_ids, + padding_positions=padding_positions, + output_length=self.output_length, ) - return outputs + + # Strip padding tokens + hidden_states = hidden_states[pooler_mask] + hidden_states = self._fq(hidden_states, self.obs_strip_padding) + + # Standardization (always applied in export mode) + assert self.obs_std_bias is not None + assert self.obs_std_scale is not None + std_bias = self.obs_std_bias.fake_quant(self.std_bias) + std_scale = self.obs_std_scale.fake_quant(self.std_scale) + hidden_states = hidden_states - std_bias.float() + hidden_states = self._fq(hidden_states, self.obs_minus_bias) + hidden_states = hidden_states * std_scale.float() + + # Cast to input dtype + hidden_states = hidden_states.to(inputs_embeds.dtype) + + # Quantize output + hidden_states = self._fq(hidden_states, self.obs_last_hidden_state) + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + def as_export_module( + self, + mode: str = "prefill", + *, + pixel_position_ids: torch.Tensor, + **kwargs, + ) -> nn.Module: + """Prepare the model for torch.export by precomputing static tensors. + + This method: + 1. Asserts that config.standardize is True (required for export) + 2. Asserts that the model is in QUANT mode + 3. Recursively converts submodules to their export adapters + 4. Registers output_length as a buffer for static export + + Submodule export adapters are stored as separate attributes + (e.g. ``patch_embedder_export``, ``pooler_export``) so that the + original wrapper attributes are not mutated. ``forward_export()`` + uses these export adapter attributes when they exist. + + Args: + mode: Export mode (only "prefill" is supported). + pixel_position_ids: Patch position ids tensor with shape + ``(1, num_patches, 2)``. Required by the pooler's + ``as_export_module()`` to precompute pooling weights. + **kwargs: Additional arguments (unused). + + Returns: + Gemma4VisionModelPrefillExportAdapter wrapping this module. + """ + # Assert standardize is True for export + assert ( + self.config.standardize + ), "Gemma4VisionModel export requires config.standardize=True" + + # Assert QUANT mode + assert self._mode is Mode.QUANT, "Must be in QUANT mode for export" + + # Make sure that all observers are calibrated + for obs in self._all_observers(): + assert obs.has_qparams, f"Observer {obs.name} has not been calibrated" + + # Store output_length for use in forward_export + pooling_kernel_size = self.config.pooling_kernel_size + max_patches = pixel_position_ids.shape[-2] + self.output_length = max_patches // (pooling_kernel_size * pooling_kernel_size) + assert ( + self.output_length * (pooling_kernel_size * pooling_kernel_size) + == max_patches + ), "max_patches must be divisible by pooling_kernel_size^2" + + # Recursively convert submodules to their export adapters. + # Store as separate attributes to avoid mutating the original wrappers. + # forward_export() will use these via getattr(..., self.). + self.patch_embedder_export = self.patch_embedder.as_export_module(mode=mode) + + # Encoder: no as_export_module yet — will use original wrapper + # in forward_export via getattr fallback. + + # Pooler: requires pixel_position_ids to precompute pooling weights + assert pixel_position_ids is not None, ( + "pixel_position_ids is required by the pooler's as_export_module() " + "to precompute pooling weights for static export." + ) + self.pooler_export = self.pooler.as_export_module( + mode=mode, + output_length=self.output_length, + pixel_position_ids=pixel_position_ids, + ) + + # Precompute padding mask from pixel_position_ids + padding_positions = (pixel_position_ids == -1).all(dim=-1) + self.register_buffer("padding_positions", padding_positions) + + register_fake_quant_meta_kernels_for_dynamic_export() + + from tico.quantization.wrapq.wrappers.gemma4.export_adapters import ( + Gemma4VisionModelPrefillExportAdapter, + ) + + return Gemma4VisionModelPrefillExportAdapter(wrapped_model=self) def _all_observers(self) -> Iterable: - """Return observers owned directly by this wrapper.""" - return (self.obs_last_hidden_state,) + """Return all observers owned by this wrapper.""" + yield self.obs_minus_bias + yield self.obs_last_hidden_state + yield self.obs_strip_padding + if self.obs_std_bias is not None: + yield self.obs_std_bias + if self.obs_std_scale is not None: + yield self.obs_std_scale diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index 45827cbe..0bcd33bf 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py @@ -76,8 +76,8 @@ "tico.quantization.wrapq.wrappers.gemma4.quant_vision_mlp", "tico.quantization.wrapq.wrappers.gemma4.quant_vision_attention", "tico.quantization.wrapq.wrappers.gemma4.quant_vision_encoder_layer", - # "tico.quantization.wrapq.wrappers.gemma4.quant_vision_encoder", - # "tico.quantization.wrapq.wrappers.gemma4.quant_vision_model", + "tico.quantization.wrapq.wrappers.gemma4.quant_vision_encoder", + "tico.quantization.wrapq.wrappers.gemma4.quant_vision_model", # "tico.quantization.wrapq.wrappers.gemma4.quant_multimodal_embedder", # "tico.quantization.wrapq.wrappers.gemma4.quant_model", # "tico.quantization.wrapq.wrappers.gemma4.quant_for_conditional_generation",