diff --git a/test/quantization/wrapq/wrappers/gemma4/test_quant_multimodal_embedder.py b/test/quantization/wrapq/wrappers/gemma4/test_quant_multimodal_embedder.py new file mode 100644 index 00000000..3e845a90 --- /dev/null +++ b/test/quantization/wrapq/wrappers/gemma4/test_quant_multimodal_embedder.py @@ -0,0 +1,300 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the Gemma4 multimodal embedder PTQ wrapper.""" + +import unittest + +import torch + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.wrappers.gemma4.quant_multimodal_embedder import ( + QuantGemma4MultimodalEmbedder, +) +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper + + +_SKIP_MSG = "required transformers Gemma4 modules are not installed" + + +def _has_gemma4() -> bool: + """Return whether the installed transformers package provides Gemma4 support.""" + try: + from transformers.models.gemma4.modeling_gemma4 import ( # noqa: F401 + Gemma4MultimodalEmbedder, + ) + except Exception: + return False + return True + + +def _make_vision_config(**overrides): + """Create a tiny Gemma4 vision config for synthetic tests.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4VisionConfig + + kwargs = dict( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + patch_size=4, + position_embedding_size=8, + pooling_kernel_size=2, + attention_dropout=0.0, + max_position_embeddings=128, + rms_norm_eps=1e-6, + use_clipped_linears=False, + rope_parameters={"rope_type": "default", "rope_theta": 100.0}, + standardize=True, + ) + kwargs.update(overrides) + cfg = Gemma4VisionConfig(**kwargs) + if hasattr(cfg, "_attn_implementation"): + cfg._attn_implementation = "eager" + else: + setattr(cfg, "_attn_implementation", "eager") + return cfg + + +def _make_text_config(**overrides): + """Create a tiny Gemma4 text config for synthetic tests.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + + kwargs = dict( + vocab_size=256, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + num_global_key_value_heads=2, + head_dim=32, + global_head_dim=32, + max_position_embeddings=128, + layer_types=["full_attention"], + rope_parameters={ + "full_attention": {"rope_type": "default", "rope_theta": 10000.0} + }, + attention_bias=False, + attention_dropout=0.0, + use_cache=False, + enable_moe_block=False, + ) + kwargs.update(overrides) + cfg = Gemma4TextConfig(**kwargs) + if hasattr(cfg, "_attn_implementation"): + cfg._attn_implementation = "eager" + else: + setattr(cfg, "_attn_implementation", "eager") + return cfg + + +def _make_multimodal_embedder(vision_cfg=None, text_cfg=None): + """Create a tiny Gemma4MultimodalEmbedder for testing.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + + vision_cfg = vision_cfg if vision_cfg is not None else _make_vision_config() + text_cfg = text_cfg if text_cfg is not None else _make_text_config() + return Gemma4MultimodalEmbedder(vision_cfg, text_cfg).eval() + + +@unittest.skipUnless(_has_gemma4(), _SKIP_MSG) +class TestQuantGemma4MultimodalEmbedder(unittest.TestCase): + """Validate Gemma4 multimodal embedder wrapper behavior.""" + + def setUp(self): + """Create deterministic inputs.""" + torch.manual_seed(2026) + self.vision_cfg = _make_vision_config() + self.text_cfg = _make_text_config() + self.batch_size = 1 + self.seq_len = 16 + self.multimodal_hidden_size = self.vision_cfg.hidden_size + self.text_hidden_size = self.text_cfg.hidden_size + + def _sample_inputs(self): + """Create synthetic inputs.""" + inputs_embeds = torch.randn( + self.batch_size, self.seq_len, self.multimodal_hidden_size + ) + return (inputs_embeds,) + + # ------------------------------------------------------------------ + # NO_QUANT mode + # ------------------------------------------------------------------ + + def test_no_quant_forward_matches_fp(self): + """In NO_QUANT mode the wrapper should match the floating-point module.""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + + self.assertIs(q_module._mode, Mode.NO_QUANT) + + inputs = self._sample_inputs() + with torch.no_grad(): + q_out = q_module(*inputs) + fp_out = fp_module(*inputs) + + # Shapes must match + self.assertEqual(q_out.shape, fp_out.shape) + + # Values must be close (the wrapper delegates to original operations) + self.assertTrue(torch.allclose(q_out, fp_out, atol=1e-5, rtol=1e-5)) + + def test_no_quant_output_shape(self): + """Check that the output has the expected shape (B, seq_len, text_hidden_size).""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + + inputs = self._sample_inputs() + with torch.no_grad(): + output = q_module(*inputs) + + expected_shape = (self.batch_size, self.seq_len, self.text_hidden_size) + self.assertEqual(output.shape, expected_shape) + + # ------------------------------------------------------------------ + # Mode transitions + # ------------------------------------------------------------------ + + def test_mode_transitions(self): + """Check the calibration lifecycle: NO_QUANT → CALIB → QUANT.""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + + self.assertIs(q_module._mode, Mode.NO_QUANT) + + q_module.enable_calibration() + self.assertIs(q_module._mode, Mode.CALIB) + + inputs = self._sample_inputs() + with torch.no_grad(): + _ = q_module(*inputs) + + q_module.freeze_qparams() + self.assertIs(q_module._mode, Mode.QUANT) + + def test_observers_are_collected(self): + """Check that _all_observers returns no direct observers (delegated to sub-wrappers).""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + + all_obs = list(q_module._all_observers()) + # The multimodal embedder has no own observers; quantization is + # handled by the sub-wrappers (QuantGemma4RMSNorm, QuantLinear). + self.assertEqual(len(all_obs), 0) + + # ------------------------------------------------------------------ + # Calibration and fake quantization + # ------------------------------------------------------------------ + + def test_quant_mode_output_is_finite(self): + """In QUANT mode the output should be finite and have the correct shape.""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + q_module.enable_calibration() + + inputs = self._sample_inputs() + with torch.no_grad(): + _ = q_module(*inputs) + q_module.freeze_qparams() + + with torch.no_grad(): + output = q_module(*inputs) + + expected_shape = (self.batch_size, self.seq_len, self.text_hidden_size) + self.assertEqual(output.shape, expected_shape) + self.assertTrue(torch.isfinite(output).all()) + + # ------------------------------------------------------------------ + # Submodule wrapping + # ------------------------------------------------------------------ + + def test_submodules_are_wrapped(self): + """Check that embedding_pre_projection_norm and embedding_projection are wrapped.""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + + self.assertIsInstance(q_module.embedding_pre_projection_norm, PTQWrapper) + self.assertIsInstance(q_module.embedding_projection, PTQWrapper) + + # ------------------------------------------------------------------ + # Config attributes + # ------------------------------------------------------------------ + + def test_config_attributes_are_stored(self): + """Check that multimodal_hidden_size and text_hidden_size are accessible.""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + + self.assertEqual(fp_module.multimodal_hidden_size, self.multimodal_hidden_size) + self.assertEqual(fp_module.text_hidden_size, self.text_hidden_size) + + # ------------------------------------------------------------------ + # as_export_module + # ------------------------------------------------------------------ + + def test_as_export_module_returns_self(self): + """as_export_module should return self (this wrapper is already exportable).""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + q_module.enable_calibration() + + inputs = self._sample_inputs() + with torch.no_grad(): + _ = q_module(*inputs) + q_module.freeze_qparams() + + export_module = q_module.as_export_module(mode="prefill") + self.assertIs(export_module, q_module) + + def test_as_export_module_forward_matches_quant_forward(self): + """Export module forward should produce the same output as quant forward.""" + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + q_module = QuantGemma4MultimodalEmbedder(fp_module).eval() + q_module.enable_calibration() + + inputs = self._sample_inputs() + with torch.no_grad(): + _ = q_module(*inputs) + q_module.freeze_qparams() + + export_module = q_module.as_export_module(mode="prefill") + + with torch.no_grad(): + quant_out = q_module(*inputs) + export_out = export_module(*inputs) + + self.assertTrue(torch.allclose(quant_out, export_out, atol=1e-6)) + + # ------------------------------------------------------------------ + # prepare integration + # ------------------------------------------------------------------ + + def test_prepare_wraps_multimodal_embedder_when_registered(self): + """Check that registry-based prepare wraps Gemma4MultimodalEmbedder.""" + from tico.quantization import prepare + + fp_module = _make_multimodal_embedder(self.vision_cfg, self.text_cfg) + prepared = prepare(fp_module, PTQConfig()) + + self.assertIsInstance(prepared, PTQWrapper) + self.assertIsInstance(prepared.wrapped, QuantGemma4MultimodalEmbedder) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/quantization/wrapq/wrappers/gemma4/test_quantize_multimodal_embedder.py b/test/quantization/wrapq/wrappers/gemma4/test_quantize_multimodal_embedder.py new file mode 100644 index 00000000..201eb9c9 --- /dev/null +++ b/test/quantization/wrapq/wrappers/gemma4/test_quantize_multimodal_embedder.py @@ -0,0 +1,200 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Smoke tests for Gemma4 multimodal embedder prepare-calibrate-convert flow.""" + +import copy +import os +import unittest + +import torch + +from tico.quantization import convert, prepare +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper + + +IS_INTERNAL_TEST = os.environ.get("RUN_INTERNAL_TESTS", "0") == "1" +_SKIP_MSG = "required transformers Gemma4 modules are not installed" + + +def _has_gemma4() -> bool: + """Return whether the installed transformers package provides Gemma4 support.""" + try: + from transformers.models.gemma4.modeling_gemma4 import ( # noqa: F401 + Gemma4MultimodalEmbedder, + ) + except Exception: + return False + return True + + +def _make_vision_config(): + """Create a tiny Gemma4 vision config for synthetic smoke tests.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4VisionConfig + + cfg = Gemma4VisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + patch_size=4, + position_embedding_size=8, + pooling_kernel_size=2, + attention_dropout=0.0, + max_position_embeddings=128, + rms_norm_eps=1e-6, + use_clipped_linears=False, + rope_parameters={"rope_type": "default", "rope_theta": 100.0}, + standardize=True, + ) + if not hasattr(cfg, "_attn_implementation"): + setattr(cfg, "_attn_implementation", "eager") + else: + cfg._attn_implementation = "eager" + return cfg + + +def _make_text_config(): + """Create a tiny Gemma4 text config for synthetic smoke tests.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + + cfg = Gemma4TextConfig( + vocab_size=256, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + num_global_key_value_heads=2, + head_dim=32, + global_head_dim=32, + max_position_embeddings=128, + layer_types=["full_attention"], + rope_parameters={ + "full_attention": {"rope_type": "default", "rope_theta": 10000.0} + }, + attention_bias=False, + attention_dropout=0.0, + use_cache=False, + enable_moe_block=False, + ) + if not hasattr(cfg, "_attn_implementation"): + setattr(cfg, "_attn_implementation", "eager") + else: + cfg._attn_implementation = "eager" + return cfg + + +@unittest.skipIf( + not IS_INTERNAL_TEST, + "Internal smoke test — set RUN_INTERNAL_TESTS=1 to enable it.", +) +@unittest.skipUnless(_has_gemma4(), _SKIP_MSG) +class TestGemma4MultimodalEmbedderSmoke(unittest.TestCase): + """Exercise Gemma4 multimodal embedder wrapper parity and PTQ flow.""" + + def setUp(self): + """Create deterministic tiny Gemma4 multimodal embedder modules.""" + torch.manual_seed(2026) + from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + + self.vision_cfg = _make_vision_config() + self.text_cfg = _make_text_config() + self.fp_embedder = Gemma4MultimodalEmbedder( + self.vision_cfg, self.text_cfg + ).eval() + self.fp_ref = copy.deepcopy(self.fp_embedder).eval() + self.seq_len = 16 + self.multimodal_hidden_size = self.fp_embedder.multimodal_hidden_size + self.text_hidden_size = self.fp_embedder.text_hidden_size + + def _sample(self): + """Create one synthetic Gemma4 multimodal embedder sample.""" + batch_size = 1 + return { + "inputs_embeds": torch.randn( + batch_size, self.seq_len, self.multimodal_hidden_size + ), + } + + def test_no_quant_multimodal_embedder_matches_reference(self): + """The wrapper should match the floating-point module before quantization.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_multimodal_embedder import ( + QuantGemma4MultimodalEmbedder, + ) + + wrapped = QuantGemma4MultimodalEmbedder( + self.fp_embedder, qcfg=PTQConfig() + ).eval() + sample = self._sample() + + with torch.no_grad(): + quant_out = wrapped(**sample) + fp_out = self.fp_ref(**sample) + + self.assertEqual(quant_out.shape, fp_out.shape) + self.assertTrue(torch.allclose(quant_out, fp_out, atol=1e-5, rtol=1e-5)) + + def test_prepare_convert_multimodal_embedder_flow(self): + """Quantize Gemma4 multimodal embedder and validate a synthetic output.""" + from tico.quantization.wrapq.wrappers.gemma4.quant_multimodal_embedder import ( + QuantGemma4MultimodalEmbedder, + ) + + prepared = prepare(self.fp_embedder, PTQConfig()) + self.assertIsInstance(prepared, PTQWrapper) + self.assertIsInstance(prepared.wrapped, QuantGemma4MultimodalEmbedder) + + with torch.no_grad(): + for _ in range(3): + prepared(**self._sample()) + + quantized = convert(prepared) + self.assertIs(quantized._mode, Mode.QUANT) + + sample = self._sample() + with torch.no_grad(): + quant_out = quantized(**sample) + fp_out = self.fp_ref(**sample) + + self.assertEqual(quant_out.shape, fp_out.shape) + self.assertTrue(torch.isfinite(quant_out).all()) + + def test_as_export_module_flow(self): + """Test the as_export_module flow for Circle export.""" + prepared = prepare(self.fp_embedder, PTQConfig()) + + with torch.no_grad(): + for _ in range(3): + prepared(**self._sample()) + + quantized = convert(prepared) + + export_module = quantized.wrapped.as_export_module(mode="prefill") + + # Verify export module forward works + sample = self._sample() + with torch.no_grad(): + out = export_module(**sample) + + self.assertEqual(out.shape, (1, self.seq_len, self.text_hidden_size)) + self.assertTrue(torch.isfinite(out).all()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py b/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py index 4de149e4..90aa6d36 100644 --- a/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py +++ b/tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py @@ -1264,6 +1264,72 @@ def export_input( return ForwardInput((pixel_values, pixel_position_ids), {}) +class Gemma4MultimodalEmbedderCase(Gemma4BaseCase): + """Smoke case for one tiny Gemma4 multimodal embedder module.""" + + name = "gemma4_multimodal_embedder" + description = ( + "Quantize one tiny Gemma4 multimodal embedder (RMSNorm + Linear projection)." + ) + tags = ("gemma4", "e2b", "multimodal", "embedder") + max_mean_abs_diff = 2.0 + seq_len = 16 + + def build(self, cfg: Mapping[str, Any]) -> tuple[torch.nn.Module, torch.nn.Module]: + """Build a tiny Gemma4 multimodal embedder and reference copy.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + + torch.manual_seed(123) + self.vision_cfg = _make_vision_config() + self.text_cfg = _make_text_config() + module = Gemma4MultimodalEmbedder(self.vision_cfg, self.text_cfg).eval() + return module, clone_module(module) + + def _sample(self) -> ForwardInput: + """Create one synthetic Gemma4 multimodal embedder input.""" + batch_size = 1 + inputs_embeds = torch.randn( + batch_size, self.seq_len, self.vision_cfg.hidden_size + ) + return ForwardInput((inputs_embeds,)) + + def calibration_inputs( + self, + prepared: torch.nn.Module, + cfg: Mapping[str, Any], + ) -> list[ForwardInput]: + """Create Gemma4 multimodal embedder calibration samples.""" + return [self._sample() for _ in range(3)] + + def eval_input( + self, + prepared: torch.nn.Module, + cfg: Mapping[str, Any], + ) -> ForwardInput: + """Create the Gemma4 multimodal embedder evaluation sample.""" + return self._sample() + + def export_module( + self, quantized: torch.nn.Module, cfg: Mapping[str, Any] + ) -> torch.nn.Module: + """Export the wrapped multimodal embedder in prefill mode.""" + wrapped = getattr(quantized, "wrapped", quantized) + if hasattr(wrapped, "as_export_module"): + return wrapped.as_export_module(mode="prefill").eval() + return quantized + + def export_input( + self, eval_sample: ForwardInput, cfg: Mapping[str, Any] + ) -> ForwardInput: + """Create static export inputs expected by the multimodal embedder adapter. + + The export adapter's forward() takes inputs_embeds. + """ + cloned = _clone_forward_input(eval_sample) + inputs_embeds = cloned.args[0] + return ForwardInput((inputs_embeds,), {}) + + GEMMA4_CASES = ( Gemma4TextMLPCase(), Gemma4TextAttentionCase(), @@ -1281,4 +1347,5 @@ def export_input( Gemma4VisionEncoderLayerCase(), Gemma4VisionPoolerCase(), Gemma4VisionModelCase(), + Gemma4MultimodalEmbedderCase(), ) diff --git a/tico/quantization/wrapq/examples/gemma4/quantize_multimodal_embedder.py b/tico/quantization/wrapq/examples/gemma4/quantize_multimodal_embedder.py new file mode 100644 index 00000000..ae5eb0c2 --- /dev/null +++ b/tico/quantization/wrapq/examples/gemma4/quantize_multimodal_embedder.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example: PTQ quantization of Gemma4MultimodalEmbedder. + +The Gemma4 multimodal embedder projects multimodal soft tokens (e.g. visual +features from the vision model) into the text model's hidden space. It applies: + +1. RMS normalization (``embedding_pre_projection_norm``) +2. Linear projection (``embedding_projection``) + +It accepts: + +- ``inputs_embeds``: Soft token embeddings of shape ``(B, seq_len, multimodal_hidden_size)`` + +and produces text-hidden-space embeddings of shape ``(B, seq_len, text_hidden_size)``. + +This script demonstrates the full PTQ flow: + +1. Create a tiny Gemma4MultimodalEmbedder with random weights (no download needed). +2. Prepare the model for quantization. +3. Calibrate with synthetic data. +4. Convert to a fake-quantized model. +5. Compare FP vs. quantized outputs. +6. Export and convert to Circle format. +""" + +import copy +import sys + +import torch + +import tico +import tico.quantization +import tico.quantization.config.ptq +from tico.quantization.evaluation.metric import compute_peir +from tico.quantization.evaluation.utils import plot_two_outputs +from tico.quantization.wrapq.utils.version import has_transformers_for + +torch.manual_seed(123) + + +# Check if transformers is available +if not has_transformers_for("gemma4"): + print( + "Error: transformers package with Gemma4 support not installed. " + "Cannot test Gemma4MultimodalEmbedder." + ) + sys.exit(1) + +from transformers.models.gemma4.configuration_gemma4 import ( + Gemma4TextConfig, + Gemma4VisionConfig, +) +from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + + +def generate_calibration_data( + batch_size: int, + seq_len: int, + multimodal_hidden_size: int, + num_samples: int = 20, +) -> list[dict]: + """Generate calibration data for PTQ. + + Each sample is a dict of keyword arguments matching the multimodal embedder's + forward signature: ``inputs_embeds``. + """ + calibration_data = [] + for _ in range(num_samples): + sample = { + "inputs_embeds": torch.randn(batch_size, seq_len, multimodal_hidden_size), + } + calibration_data.append(sample) + return calibration_data + + +def main(): + # Create tiny configs for the multimodal embedder (no download needed). + # Gemma4MultimodalEmbedder requires a multimodal_config and a text_config. + vision_cfg = Gemma4VisionConfig( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + patch_size=4, + position_embedding_size=8, + pooling_kernel_size=2, + attention_dropout=0.0, + max_position_embeddings=128, + rms_norm_eps=1e-6, + use_clipped_linears=False, + rope_parameters={"rope_type": "default", "rope_theta": 100.0}, + ) + if not hasattr(vision_cfg, "_attn_implementation"): + setattr(vision_cfg, "_attn_implementation", "eager") + else: + vision_cfg._attn_implementation = "eager" + + text_cfg = Gemma4TextConfig( + vocab_size=256, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + num_global_key_value_heads=2, + head_dim=32, + global_head_dim=32, + max_position_embeddings=128, + layer_types=["full_attention"], + rope_parameters={ + "full_attention": {"rope_type": "default", "rope_theta": 10000.0} + }, + attention_bias=False, + attention_dropout=0.0, + use_cache=False, + enable_moe_block=False, + ) + if not hasattr(text_cfg, "_attn_implementation"): + setattr(text_cfg, "_attn_implementation", "eager") + else: + text_cfg._attn_implementation = "eager" + + model = Gemma4MultimodalEmbedder(vision_cfg, text_cfg) + orig_model = copy.deepcopy(model) + model.eval() + + # Gemma4MultimodalEmbedder( + # (embedding_pre_projection_norm): Gemma4RMSNorm(32, eps=1e-06) + # (embedding_projection): Linear(32, 64, bias=False) + # ) + multimodal_hidden_size = model.multimodal_hidden_size + text_hidden_size = model.text_hidden_size + assert multimodal_hidden_size == 32 + assert text_hidden_size == 64 + + # Generate calibration data + batch_size = 1 + seq_len = 16 + + calibration_data = generate_calibration_data( + batch_size=batch_size, + seq_len=seq_len, + multimodal_hidden_size=multimodal_hidden_size, + num_samples=20, + ) + + # Configure PTQ + ptq_config = tico.quantization.config.ptq.PTQConfig() + + # Prepare the model for quantization + prepared_model = tico.quantization.prepare( + model, ptq_config, inplace=True # Transform the model in place + ) + + # Calibrate the model (collect statistics) + print("Calibrating...") + with torch.no_grad(): + for sample in calibration_data: + prepared_model(**sample) + + # Convert to quantized model + print("Converting to quantized model...") + quantized_model = tico.quantization.convert(prepared_model, inplace=True) + + # Compute PEIR (Peak Error-to-Interval Ratio) between quantized model and original model + eval_sample = calibration_data[0] + with torch.no_grad(): + quant_out = quantized_model(**eval_sample) + fp_out = orig_model(**eval_sample) + + print(f"\n┌───────────── Quantization Error Summary ─────────────") + print(f"│ FP output shape : {tuple(fp_out.shape)}") + print(f"│ Quant output shape : {tuple(quant_out.shape)}") + print(f"│ Mean |diff| : {(quant_out - fp_out).abs().mean().item():.6f}") + print(f"│ PEIR : {compute_peir(fp_out, quant_out) * 100:.6f} %") + print(f"└──────────────────────────────────────────────────────") + print(plot_two_outputs(fp_out, quant_out)) + + # Export and convert to Circle format. + # The multimodal embedder is a simple sequential module (RMSNorm + Linear), + # so as_export_module returns self. + wrapped = getattr(quantized_model, "wrapped", quantized_model) + if hasattr(wrapped, "as_export_module"): + export_module = wrapped.as_export_module(mode="prefill").eval() + + example_inputs = ( + torch.randn(batch_size, seq_len, multimodal_hidden_size), # inputs_embeds + ) + + print("\nConverting to Circle format...") + circle_model = tico.convert(export_module, example_inputs) + + filename = "gemma4_multimodal_embedder.q.circle" + circle_model.save(filename) + print(f"Circle model saved as '{filename}'") + else: + print("Note: as_export_module not available; skipping Circle export.") + + +if __name__ == "__main__": + main() diff --git a/tico/quantization/wrapq/wrappers/gemma4/quant_multimodal_embedder.py b/tico/quantization/wrapq/wrappers/gemma4/quant_multimodal_embedder.py index e38c1368..cd31cd60 100644 --- a/tico/quantization/wrapq/wrappers/gemma4/quant_multimodal_embedder.py +++ b/tico/quantization/wrapq/wrappers/gemma4/quant_multimodal_embedder.py @@ -53,6 +53,10 @@ def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: self.embedding_pre_projection_norm(inputs_embeds) ) + def as_export_module(self, mode: str = "prefill", **kwargs) -> nn.Module: + """Return self for export (this wrapper is already exportable).""" + return self + def _all_observers(self) -> Iterable: """Return observers owned directly by this wrapper.""" return () diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index dfa9f49a..83d219d3 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py @@ -78,7 +78,7 @@ "tico.quantization.wrapq.wrappers.gemma4.quant_vision_encoder_layer", "tico.quantization.wrapq.wrappers.gemma4.quant_vision_encoder", "tico.quantization.wrapq.wrappers.gemma4.quant_vision_model", - # "tico.quantization.wrapq.wrappers.gemma4.quant_multimodal_embedder", + "tico.quantization.wrapq.wrappers.gemma4.quant_multimodal_embedder", # "tico.quantization.wrapq.wrappers.gemma4.quant_model", # "tico.quantization.wrapq.wrappers.gemma4.quant_for_conditional_generation", # add future core wrappers here