[quantization] Implement PTQ wrapper for Gemma4VisionModel with static export support#793

Merged
dvsav merged 1 commit into Samsung:main from dvsav:vision
Jun 25, 2026

@dvsav dvsav commented Jun 24, 2026

Contributor

What

This PR replaces the skeleton QuantGemma4VisionModel wrapper with a full PTQ implementation that decomposes the forward pass into individual submodules (patch embedder, encoder, pooler), adds a static-shape forward_export() path for torch.export, and activates the wrapper in the registry.

Why

The previous QuantGemma4VisionModel was a skeleton that delegated the entire forward pass to self.module() (the original Hugging Face model). This meant:

  • No per-submodule quantization (patch embedder, encoder, pooler were not individually wrapped)
  • No activation observers between submodule boundaries
  • No static-shape export path for Circle conversion
  • The wrapper was commented out in the registry, so it was not usable

The Gemma4 E2B static runtime requires the vision model to be fully quantized with per-submodule observers and a forward_export() path that avoids dynamic operations (conditional branching, dynamic shapes) incompatible with torch.export and Circle conversion.
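The split between a runtime path and a static export path can be sketched as follows. This is a minimal illustration, not the TICO implementation: `ToyVisionWrapper` and `ToyExportAdapter` are hypothetical names, the standardization formula `(y - std_bias) * std_scale` is an assumption, and the buffers hold placeholder values.

```python
import torch
from torch import nn


class ToyVisionWrapper(nn.Module):
    """Hypothetical stand-in for a quantization wrapper (not the TICO class)."""

    def __init__(self, dim: int, standardize: bool = True):
        super().__init__()
        self.standardize = standardize
        self.proj = nn.Linear(dim, dim)
        # Placeholder statistics; the formula used below is an assumption.
        self.register_buffer("std_bias", torch.zeros(dim))
        self.register_buffer("std_scale", torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Runtime path: may branch on configuration.
        y = self.proj(x)
        if self.standardize:
            y = (y - self.std_bias) * self.std_scale
        return y

    def forward_export(self, x: torch.Tensor) -> torch.Tensor:
        # Export path: straight-line code that assumes standardize=True.
        return (self.proj(x) - self.std_bias) * self.std_scale


class ToyExportAdapter(nn.Module):
    """Hypothetical adapter whose forward() delegates to forward_export()."""

    def __init__(self, wrapped_model: ToyVisionWrapper):
        super().__init__()
        self.wrapped_model = wrapped_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.wrapped_model.forward_export(x)


torch.manual_seed(0)
model = ToyVisionWrapper(dim=8).eval()
adapter = ToyExportAdapter(model).eval()
x = torch.randn(4, 8)
with torch.no_grad():
    runtime_out = model(x)
    export_out = adapter(x)
```

In a setup like this, the adapter rather than the wrapper is the object that would be handed to `torch.export.export`, so the traced graph only ever sees the static path.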

Key Design Decisions

  1. Separate forward() and forward_export() methods: The runtime forward() supports dynamic shapes and conditional config.standardize branching. The export forward_export() assumes config.standardize=True and uses precomputed static tensors. This follows the same pattern as the existing Llama and Qwen text decoder wrappers.

  2. Separate export adapter attributes: as_export_module() stores submodule export adapters as self.patch_embedder_export and self.pooler_export rather than mutating the original self.patch_embedder and self.pooler wrappers. This preserves the original wrappers for potential re-export with different parameters.

  3. pixel_position_ids required for as_export_module(): The pooler's export adapter needs pixel_position_ids at construction time to precompute static pooling weights (replacing dynamic F.one_hot and torch.div with a static matmul). This is enforced via an assertion.

  4. keep_mask == False instead of ~keep_mask: Replaced ~keep_mask (which lowers to aten::bitwise_not) with a keep_mask == False comparison in both quant_vision_attention.py and forward_export(), because the Circle conversion pipeline does not support aten::bitwise_not.

  5. register_fake_quant_meta_kernels_for_dynamic_export(): Called during as_export_module() to register fake quantize meta kernels needed for torch.export with dynamic shapes in the encoder path.

  6. Encoder returns plain tensor: The QuantGemma4VisionEncoder wrapper returns a plain tensor rather than BaseModelOutputWithPast. Both forward() and forward_export() handle this with isinstance(output, torch.Tensor) checks.
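Decision 3 can be illustrated with a small sketch of the precomputation idea. It is not the actual pooler code: `build_pool_weights` is a hypothetical helper, and it assumes that `pixel_position_ids` holds `(x, y)` patch coordinates with `-1` marking padding, that pooling is a plain average over `kernel x kernel` blocks, and that the grid dimensions are multiples of `kernel`.

```python
import torch
import torch.nn.functional as F


def build_pool_weights(pixel_position_ids: torch.Tensor, kernel: int) -> torch.Tensor:
    """Precompute a (num_cells, num_patches) average-pooling matrix.

    Hypothetical helper; the layout of pixel_position_ids is an assumption.
    """
    valid = (pixel_position_ids >= 0).all(dim=-1)
    pos = pixel_position_ids.clamp(min=0)
    cells_per_row = (int(pos[:, 0].max()) + 1) // kernel
    cells_per_col = (int(pos[:, 1].max()) + 1) // kernel
    # Index of the pooled cell that each patch falls into.
    cell = torch.div(pos[:, 0], kernel, rounding_mode="floor") + cells_per_row * torch.div(
        pos[:, 1], kernel, rounding_mode="floor"
    )
    one_hot = F.one_hot(cell, num_classes=cells_per_row * cells_per_col).float()
    one_hot = one_hot * valid.unsqueeze(-1)  # padding patches contribute nothing
    return one_hot.t() / (kernel * kernel)


# Construction time: the dynamic ops (one_hot, div) run once, outside the graph.
side, kernel, dim = 4, 2, 3
ys, xs = torch.meshgrid(torch.arange(side), torch.arange(side), indexing="ij")
pixel_position_ids = torch.stack([xs.flatten(), ys.flatten()], dim=-1)  # (16, 2)
pool_weights = build_pool_weights(pixel_position_ids, kernel)

# Export time: pooling is a single matmul against a constant tensor.
hidden = torch.randn(side * side, dim)
pooled = pool_weights @ hidden

# Cross-check against ordinary average pooling on the same grid.
reference = F.avg_pool2d(hidden.t().reshape(1, dim, side, side), kernel)
reference = reference.reshape(dim, -1).t()
```

Because `pool_weights` depends only on the position ids, fixing them at construction time is what allows the exported graph to contain just the matmul.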

Changes

  • tico/quantization/wrapq/wrappers/gemma4/quant_vision_model.py — Replaced skeleton with full implementation: decomposed forward() into patch_embedder → encoder → pooler → strip_padding → standardization pipeline; added forward_export() for static-shape export; added as_export_module() with recursive submodule export adapter conversion; registered std_bias/std_scale as buffers; added observers for minus_bias, strip_padding, std_bias, std_scale; added enable_calibration() to collect std_bias/std_scale statistics

  • tico/quantization/wrapq/wrappers/gemma4/export_adapters.py — Added Gemma4VisionModelPrefillExportAdapter that wraps a QuantGemma4VisionModel and delegates forward() to wrapped_model.forward_export()

  • tico/quantization/wrapq/wrappers/gemma4/quant_vision_attention.py — Replaced ~keep_mask with keep_mask == False to avoid aten::bitwise_not, which is unsupported by the Circle conversion pipeline

  • tico/quantization/wrapq/wrappers/registry.py — Activated the quant_vision_encoder and quant_vision_model entries (uncommented in _CORE_MODULES)

  • test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py — Added 13 unit tests covering: prepare wrapping, no-quant forward parity, mode transitions, observer collection, quant mode output finiteness, config attribute storage, standardize buffer registration, standardize=False path, as_export_module preconditions, forward_export via as_export_module, export adapter attribute creation, submodule wrapping

  • test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py — Added 4 smoke tests covering: no-quant reference parity, prepare-convert flow, as_export_module flow with Gemma4VisionModelPrefillExportAdapter, standardize=False path

  • tico/quantization/recipes/debug/wrapper_smoke/cases/gemma4.py — Added Gemma4VisionModelCase to GEMMA4_CASES with build(), calibration_inputs(), eval_input(), export_module(), and export_input() methods

  • tico/quantization/wrapq/examples/gemma4/quantize_vision_model.py — Added example script demonstrating full PTQ flow: create model, prepare, calibrate, convert, compare FP vs quantized, export to Circle
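The mask change in quant_vision_attention.py is value-preserving for boolean tensors, which the following sketch checks. The tensors and the `masked_fill` usage are illustrative assumptions, not code from the PR.

```python
import torch

keep_mask = torch.tensor([[True, True, False], [True, False, False]])

# Original form: bitwise inversion, which the PR reports as aten::bitwise_not.
inverted_bitwise = ~keep_mask

# Replacement form: an equality comparison against False.
inverted_compare = keep_mask == False  # noqa: E712 (comparison is intentional)

# Illustrative use: mask out the positions that are not kept.
scores = torch.zeros(2, 3)
masked_scores = scores.masked_fill(inverted_compare, float("-inf"))
```

Both expressions produce the same boolean tensor; only the operator that ends up in the exported graph differs.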

Tests

  • Unit tests (test_quant_vision_model.py): 13 tests covering wrapper lifecycle, forward parity, mode transitions, observer collection, export module preconditions, and export adapter attribute creation. All pass.
  • Smoke tests (test_quantize_vision_model.py): 4 tests covering end-to-end prepare-calibrate-convert flow, reference parity, and as_export_module flow. Gated behind RUN_INTERNAL_TESTS=1.
  • Smoke case (Gemma4VisionModelCase): Registered in GEMMA4_CASES for the wrapper smoke test framework.
  • Full Gemma4 test suite: 89 passed, 21 skipped (internal tests), 0 failures.

Unit Tests

$ python -m pytest test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py -v
=========================================================================== test session starts ============================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 13 items                                                                                                                                                         

test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_00_prepare_wraps_vision_model_when_registered PASSED            [  7%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_as_export_module_creates_export_adapter_attributes PASSED       [ 15%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_as_export_module_requires_quant_mode PASSED                     [ 23%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_as_export_module_requires_standardize PASSED                    [ 30%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_config_attributes_are_stored PASSED                             [ 38%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_forward_export_via_as_export_module PASSED                      [ 46%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_mode_transitions PASSED                                         [ 53%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_no_quant_forward_matches_hf_vision_model PASSED                 [ 61%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_observers_are_collected PASSED                                  [ 69%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_quant_mode_output_is_finite PASSED                              [ 76%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_standardize_buffers_are_registered PASSED                       [ 84%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_standardize_false_no_buffers PASSED                             [ 92%]
test/quantization/wrapq/wrappers/gemma4/test_quant_vision_model.py::TestQuantGemma4VisionModel::test_submodules_are_wrapped PASSED                                   [100%]

====================================================================== 13 passed, 2 warnings in 5.24s ======================================================================

Internal Tests

$ RUN_INTERNAL_TESTS=1 python -m pytest test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py -v
=========================================================================== test session starts ============================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 4 items                                                                                                                                                          

test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py::TestGemma4VisionModelSmoke::test_as_export_module_flow PASSED                                 [ 25%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py::TestGemma4VisionModelSmoke::test_no_quant_vision_model_matches_reference PASSED               [ 50%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py::TestGemma4VisionModelSmoke::test_prepare_convert_vision_model_flow PASSED                     [ 75%]
test/quantization/wrapq/wrappers/gemma4/test_quantize_vision_model.py::TestGemma4VisionModelSmoke::test_vision_model_with_standardize_false PASSED                   [100%]

====================================================================== 4 passed, 2 warnings in 5.36s =======================================================================

Smoke Tests

$ python -m tico.quantization.examples.inspect \
    --config tico/quantization/examples/configs/wrapper_smoke.yaml \
    --mode wrapper-smoke \
    --case gemma4_vision_model \
    --export circle \
    --output-dir ./out/wrapper_smoke

[QuantCheck] WARNING: 68 nodes without qparam detected (see logs).
┌───────────── Wrapper Smoke Summary ─────────────
│ Case             : gemma4_vision_model
│ Status           : PASS
│ Mean |diff|      : 0.206352
│ Max |diff|       : 1.385050
│ PEIR             : 0.040837
│ Shape match      : True
│ Quant finite     : True
└─────────────────────────────────────────────────
Artifacts:
  - circle: out/wrapper_smoke/gemma4_vision_model.q.circle
    ┌────────────────────────────────────────────┐
32.7┤                                            │
    │                                            │
    │                                         •  │
26.5┤                                    •       │
    │                                  •         │
    │                               •••          │
    │                                            │
20.3┤                            •               │
    │                          •                 │
    │                       •••                  │
14.1┤                    • ••                    │
    │                   ••                       │
    │                 ••                         │
 7.9┤                                            │
    │            •••                             │
    │           •                                │
    │         •                                  │
 1.6┤     •                                      │
    │                                            │
    │  •                                         │
-4.6┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -4.6        4.7       14.1      23.4      32.7

NOTE: The 68 nodes without qparam are caused mostly by QuantGemma4VisionEncoder, which is not implemented yet.

Example Script

  • tico/quantization/wrapq/examples/gemma4/quantize_vision_model.py — Demonstrates the complete PTQ pipeline for Gemma4VisionModel:
    1. Creates a tiny Gemma4VisionModel with standardize=True and pooling_kernel_size=2
    2. Prepares the model with PTQConfig()
    3. Calibrates with 20 synthetic samples
    4. Converts to a fake-quantized model
    5. Computes PEIR between FP and quantized outputs
    6. Exports via as_export_module() with pixel_position_ids for pooler precomputation
    7. Converts to Circle format and saves as gemma4_vision_model.q.circle
$ python tico/quantization/wrapq/examples/gemma4/quantize_vision_model.py
Calibrating...
Converting to quantized model...

┌───────────── Quantization Error Summary ─────────────
│ FP output shape    : (4, 32)
│ Quant output shape : (4, 32)
│ Mean |diff|        : 0.078656
│ PEIR               : 1.128772 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
22.0┤                                            │
    │                                       •••  │
    │                                     •••    │
18.2┤                                   •        │
    │                                ••          │
    │                             ••••           │
14.3┤                           •••              │
    │                        •••                 │
10.5┤                     ••••                   │
    │                    •••                     │
    │                 •                          │
 6.7┤              ••••                          │
    │             ••                             │
    │         ••••                               │
 2.9┤       ••                                   │
    │    •••                                     │
    │  ••                                        │
-0.9┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -0.9        4.8       10.5      16.2      22.0 


Converting to Circle format...
/home/d.savchenkov/myenv/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  getattr_node = gm.graph.get_attr(lifted_node)
/home/d.savchenkov/myenv/lib/python3.10/site-packages/torch/fx/graph.py:1801: UserWarning: Node wrapped_model_encoder_wrapped_layers_0_wrapped_self_attn_wrapped_lifted_tensor_0 target wrapped_model.encoder.wrapped.layers.0.wrapped.self_attn.wrapped.lifted_tensor_0 lifted_tensor_0 of wrapped_model.encoder.wrapped.layers.0.wrapped.self_attn.wrapped does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
/home/d.savchenkov/myenv/lib/python3.10/site-packages/torch/fx/graph.py:1801: UserWarning: Node wrapped_model_encoder_wrapped_layers_0_wrapped_self_attn_wrapped_lifted_tensor_1 target wrapped_model.encoder.wrapped.layers.0.wrapped.self_attn.wrapped.lifted_tensor_1 lifted_tensor_1 of wrapped_model.encoder.wrapped.layers.0.wrapped.self_attn.wrapped does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
[QuantCheck] WARNING: 68 nodes without qparam detected (see logs).
Circle model saved as 'gemma4_vision_model.q.circle'

NOTE: The UserWarnings are caused by QuantGemma4VisionEncoder, which is not implemented yet.
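PEIR is not defined in this PR. Assuming it stands for peak-error-to-interval ratio, i.e. the largest absolute difference divided by the value range of the FP reference, it could be computed roughly as below. The function `peir` and the sample tensors are hypothetical; the helper that TICO actually uses may differ.

```python
import torch


def peir(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    """Peak absolute error divided by the reference value range (assumed definition)."""
    peak_error = (reference - candidate).abs().max()
    interval = reference.max() - reference.min()
    return (peak_error / interval).item()


fp_out = torch.tensor([[-1.0, 0.0], [1.0, 3.0]])
q_out = torch.tensor([[-1.0, 0.1], [0.8, 3.0]])
ratio = peir(fp_out, q_out)  # peak error 0.2 over an interval of 4.0
```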

@dvsav dvsav force-pushed the vision branch 2 times, most recently from 34c1362 to 907f90f on June 24, 2026 09:38
@dvsav dvsav requested review from Torrero and mhs4670go June 24, 2026 09:43
mhs4670go previously approved these changes Jun 24, 2026

@mhs4670go mhs4670go left a comment

Contributor

LGTM

from transformers.models.gemma4.modeling_gemma4 import BaseModelOutputWithPast

# Create padding mask from pixel_position_ids
padding_positions = (pixel_position_ids == -1).all(dim=-1)

Contributor

Maybe here we can precompute padding_positions for static input image?

Contributor Author

👍 Good point, thank you! Implemented.
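The reviewer's suggestion can be sketched like this. `StaticPaddingMask` is a hypothetical helper, not the merged code: it evaluates the reviewed expression once at construction time and keeps the result as a buffer. Zeroing the padded patches in `forward()` is an assumed use of the mask, chosen only to make the example concrete.

```python
import torch
from torch import nn


class StaticPaddingMask(nn.Module):
    """Hypothetical helper: freezes the padding mask for one fixed image layout."""

    def __init__(self, pixel_position_ids: torch.Tensor):
        super().__init__()
        # Same expression as in the reviewed snippet, evaluated once up front.
        padding_positions = (pixel_position_ids == -1).all(dim=-1)
        self.register_buffer("padding_positions", padding_positions, persistent=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The mask is now a constant; zeroing padded patches is an assumed use.
        return hidden_states.masked_fill(self.padding_positions.unsqueeze(-1), 0.0)


# One image with three real patches and one padding patch.
pixel_position_ids = torch.tensor([[[0, 0], [1, 0], [0, 1], [-1, -1]]])
mask_module = StaticPaddingMask(pixel_position_ids)
hidden_states = torch.ones(1, 4, 3)
masked = mask_module(hidden_states)
```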

…c export support

Replace the skeleton Gemma4VisionModel wrapper with a complete implementation

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>

@Torrero Torrero left a comment

Contributor

LGTM

@dvsav dvsav merged commit 219adb7 into Samsung:main Jun 25, 2026
7 checks passed
@dvsav dvsav deleted the vision branch June 25, 2026 12:05