diff --git a/docs/source/features/quantization.rst b/docs/source/features/quantization.rst
index 65b03a5..2761e15 100644
--- a/docs/source/features/quantization.rst
+++ b/docs/source/features/quantization.rst
@@ -81,8 +81,8 @@ Uses standard PyTorch quantization APIs (``GenericTinyMLQATFxModule`` /
      quantization_weight_bitwidth: 8
      quantization_activation_bitwidth: 8
 
-For more details on the underlying wrappers, see
-:ref:`quantization-wrapper-architecture` below.
+For more details on the underlying wrappers, see the
+`tinyml-modeloptimization documentation `_.
 
 **TI Style Optimised Quantization (quantization: 2)**
 
@@ -331,7 +331,7 @@ Example: Full Quantization Workflow
      variables: 1
 
    training:
-     model_name: 'CLS_4k_NPU'
+     model_name: 'ArcFault_model_400_t'
      training_epochs: 30
      batch_size: 256
      quantization: 2
@@ -410,337 +410,108 @@ Even without NPU, integer operations are faster:
    Float32: ~5000 µs
    INT8: ~2000 µs
 
-.. _quantization-wrapper-architecture:
-
-Quantization Wrapper Architecture
----------------------------------
-
-Under the hood, Tiny ML Tensorlab uses quantization wrapper classes from
-the ``tinyml-modeloptimization`` package. Understanding the wrapper
-architecture helps when customizing quantization or debugging.
-
-**Class Hierarchy:**
-
-.. code-block:: text
-
-   TinyMLQuantFxBaseModule (base class)
-   ├── TINPUTinyMLQuantFxModule
-   │   ├── TINPUTinyMLQATFxModule (quantization: 2, QAT)
-   │   └── TINPUTinyMLPTQFxModule (quantization: 2, PTQ)
-   │
-   └── GenericTinyMLQuantFxModule
-       ├── GenericTinyMLQATFxModule (quantization: 1, QAT)
-       └── GenericTinyMLPTQFxModule (quantization: 1, PTQ)
-
-**TINPUTinyML wrappers** (``quantization: 2``) incorporate the constraints
-of TI NPU Hardware accelerator. They perform extensive graph transformations
-including 13+ layer pattern replacements to produce NPU-compatible integer
-operations. Key characteristics:
-
-* Enforces power-of-2 scale factors (mandatory for 8-bit quantization)
-* Transforms convolution, pooling, linear, and batch normalization layers
-  to NPU-compatible patterns
-* Implements the NPU BNORM sequence:
-  ``Add (bias) → Mul (scale) → Div (2^n, right shift) → Floor → Clip``
-* All operations in integer domain, no dequantization step
-
-**GenericTinyML wrappers** (``quantization: 1``) use standard PyTorch
-quantization APIs with minimal modifications, relying on ONNX Runtime for
-optimization. Key characteristics:
-
-* Flexible scaling (no power-of-2 constraint)
-* Only 1 pattern replacement (permute + unsqueeze)
-* Uses PyTorch's native quantized operations
-* Relies on ONNX Runtime optimization for deployment
-
-.. note::
-
-   When using the toolchain via YAML configs, you do not need to interact
-   with these wrapper classes directly. Setting ``quantization: 1`` or
-   ``quantization: 2`` in the config selects the appropriate wrapper
-   automatically.
-
-NPU Hardware Constraints
+QAT Training Performance
 ------------------------
 
-When using TI style optimised quantization (``quantization: 2``), the
-following hardware constraints are enforced automatically by the TINPU
-wrapper:
-
-**Channel Alignment:**
-
-Input and output channels must be multiples of 4. The NPU processes data
-in SIMD fashion with 4-channel vectors.
-
-.. list-table::
-   :header-rows: 1
-   :widths: 25 35 40
-
-   * - Layer Type
-     - Channel Requirement
-     - Notes
-   * - FCONV (First Conv)
-     - Input: exactly 1, Output: multiple of 4
-     - First layer in the network
-   * - GCONV (Generic Conv)
-     - Input and Output: multiple of 4
-     - General convolution layers
-   * - DWCONV (Depthwise Conv)
-     - Input/Output: multiple of 4
-     - Depthwise separable layers
-   * - PWCONV (Pointwise Conv)
-     - Input/Output: multiple of 4
-     - 1x1 convolution layers
-   * - FC (Fully Connected)
-     - Input: multiple of 4
-     - Dense/linear layers
-
-**Power-of-2 Scaling:**
-
-For 8-bit quantization, scale factors must be powers of 2. This enables
-efficient implementation as bit shifts in hardware, avoiding expensive
-division operations. For sub-8-bit quantization (4-bit, 2-bit),
-non-power-of-2 scales are supported and may provide better accuracy.
-
-**Bitwidth Constraints:**
-
-.. list-table::
-   :header-rows: 1
-   :widths: 20 30 50
-
-   * - Parameter
-     - Allowed Values
-     - Notes
-   * - Weight bitwidth
-     - 2, 4, or 8 bits (signed)
-     - Determines model compression ratio
-   * - Activation bitwidth
-     - 8 bits (signed or unsigned)
-     - Fixed at 8 bits for NPU acceleration
-   * - Bias
-     - 16-bit (2b/4b weights), 24-bit (8b weights)
-     - Automatically computed
-   * - Scale
-     - 8-bit unsigned (2b/4b), power-of-2 shift (8b)
-     - Automatically computed
-
-**Supported NPU Layer Patterns:**
-
-The NPU accelerates the following layer types: FCONV, GCONV, DWCONV,
-PWCONV, FC, AVGPOOL, MAXPOOL. Each layer includes a BNORM sequence
-(bias → scale → shift → floor → clip) that maps directly to NPU hardware
-units.
-
-.. warning::
+Quantization-Aware Training is significantly slower than float training.
+This section explains why, and which factors dominate the overhead.
 
-   Models with layers that do not meet NPU constraints will fall back to
-   CPU execution for those layers. Use ``quantization: 1`` (Generic) for
-   models that cannot satisfy these constraints.
+**FakeQuantize Nodes in the Forward Pass**
 
-Using Quantization Wrappers Directly
-------------------------------------
+``prepare_qat_fx`` rewrites the model's FX graph by inserting a
+``FakeQuantize`` module at every weight tensor and every activation
+output. For an N-layer model this adds at least ``2N + 1`` extra
+operations to both the forward and backward pass, on every batch.
 
-For advanced users who want to use the quantization wrappers outside the
-Tiny ML Tensorlab toolchain (e.g., in custom PyTorch training scripts),
-the wrappers can be imported and used directly.
+Each ``FakeQuantize`` node performs, per batch:
 
-**TINPU QAT Example:**
+1. **Observer forward** — runs a ``torch.min`` / ``torch.max`` reduction
+   over the full activation or weight tensor to update running statistics.
+2. **Scale computation** (``_calculate_qparams``) — derives ``scale`` and
+   ``zero_point`` from the stored statistics.
+3. **Power-of-2 scale snapping** — TI's NPU requires power-of-2 scales.
+   ``ceil2_tensor`` computes ``torch.pow(2, torch.ceil(torch.log2(x)))``
+   and also calls ``x.data.abs().sum()`` which **forces a device-to-host
+   synchronisation** — the same class of GPU pipeline stall that the
+   deferred ``.item()`` optimisation eliminates for metric logging, but
+   here it occurs in every layer, every batch.
+4. **Fake-quantize operation** — ``torch.fake_quantize_per_tensor_affine``
+   performs ``round(x / scale) * scale`` with STE gradient propagation.
 
-.. code-block:: python
+**Soft-Quantize Variants (4-bit and 2-bit)**
 
-   from tinyml_torchmodelopt.quantization import TINPUTinyMLQATFxModule
+Lower bit widths use ``SoftSigmoidFakeQuantize`` (4-bit) or
+``SoftTanhFakeQuantize`` (2-bit), which run the standard ``FakeQuantize``
+forward AND then a second full quantize-dequantize pass with
+sigmoid- or tanh-based differentiable rounding over the flattened
+activation tensor. This roughly triples the per-node cost compared
+to standard 8-bit ``FakeQuantize``.
 
-   # Create and pretrain your model
-   model = MyNeuralNetwork()
-   model.load_state_dict(torch.load('pretrained.pth'))
+**Backward Pass Complexity**
 
-   # Wrap with TINPU quantization
-   model = TINPUTinyMLQATFxModule(model, total_epochs=epochs)
+Every ``FakeQuantize`` node adds autograd nodes to the computation
+graph. The soft-quantize variants additionally record ``floor``,
+``detach``, ``sigmoid`` / ``tanh``, ``clone``, and STE propagation
+nodes. The backward graph is substantially larger than the float
+model's graph.
 
-   # Train the wrapped model (your usual training loop)
-   model.train()
-   for e in range(epochs):
-       for images, target in train_loader:
-           output = model(images)
-           # loss, backward(), optimizer step as usual
+**Per-Epoch Module Traversals**
 
-   model.eval()
+The QAT wrapper overrides ``model.train()`` to perform three full
+module-tree traversals every epoch:
 
-   # Convert to integer operations
-   model = model.convert()
+1. ``self.apply(enable_observer)`` or ``self.apply(disable_observer)``
+2. ``self.apply(update_bn_stats)`` or ``self.apply(freeze_bn_stats)``
+3. ``for m in self.modules()`` to update soft-quantize temperatures
 
-   # Export to ONNX
-   dummy_input = torch.rand((1, 1, 256, 1))
-   model.export(dummy_input, 'model_int8.onnx', input_names=['input'])
+**torch.compile Ordering**
 
-**Generic QAT Example:**
+In the current training flow, ``torch.compile`` is applied to the float
+model *before* ``prepare_qat_fx`` rewrites the graph. The FX graph
+transformation discards the compiled version, so the QAT model runs
+in eager mode while the float model benefits from fused kernels.
+This is likely the single largest factor in the speed difference.
 
-.. code-block:: python
-
-   from tinyml_torchmodelopt.quantization import GenericTinyMLQATFxModule
-
-   # Create and pretrain your model
-   model = MyNeuralNetwork()
-   model.load_state_dict(torch.load('pretrained.pth'))
-
-   # Wrap with Generic quantization
-   model = GenericTinyMLQATFxModule(model, total_epochs=epochs)
-
-   # Train, convert, and export (same API as TINPU)
-   # ...
-   model = model.convert()
-   model.export(dummy_input, 'model_int8.onnx', input_names=['input'])
-
-**PTQ (Post-Training Quantization):**
-
-For PTQ, replace the QAT module with the PTQ variant. PTQ only requires
-a calibration pass (forward pass on representative data) instead of full
-retraining:
-
-.. code-block:: python
-
-   from tinyml_torchmodelopt.quantization import TINPUTinyMLPTQFxModule
-
-   model = TINPUTinyMLPTQFxModule(model, total_epochs=1)
-
-   # Calibration pass (no backward, no optimizer)
-   model.eval()
-   with torch.no_grad():
-       for images, _ in calibration_loader:
-           model(images)
-
-   model = model.convert()
-   model.export(dummy_input, 'model_int8.onnx', input_names=['input'])
-
-**Evaluating Exported ONNX Models:**
-
-After exporting, you can evaluate the quantized ONNX model using ONNX
-Runtime:
-
-.. code-block:: python
-
-   import onnxruntime as ort
-
-   ort_session_options = ort.SessionOptions()
-   ort_session_options.graph_optimization_level = (
-       ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
-   )
-
-   ort_session = ort.InferenceSession('model_int8.onnx', ort_session_options)
-   prediction = ort_session.run(None, {'input': example_input.numpy()})
-
-Wrapper API Reference
----------------------
-
-All quantization wrappers inherit from ``TinyMLQuantFxBaseModule``, which
-accepts the following parameters:
-
-.. list-table::
+.. list-table:: QAT Overhead Summary
    :header-rows: 1
-   :widths: 30 10 60
-
-   * - Parameter
-     - Type
-     - Description
-   * - ``model``
-     - nn.Module
-     - The PyTorch model to quantize
-   * - ``qconfig_type``
-     - dict/None
-     - QConfig mapping for quantization. ``None`` uses wrapper defaults.
-   * - ``example_inputs``
-     - Tensor
-     - Example input with batch size 1
-   * - ``is_qat``
-     - bool
-     - Toggle between QAT (True) and PTQ (False)
-   * - ``backend``
-     - str
-     - Backend: ``'qnnpack'`` (Linux) or ``'fbgemm'`` (Windows).
-       Automatically selected.
-   * - ``total_epochs``
-     - int
-     - Total number of quantized training epochs
-   * - ``num_batch_norm_update_epochs``
-     - bool/int
-     - BatchNorm freezing control (see below)
-   * - ``num_observer_update_epochs``
-     - bool/int
-     - Observer freezing control (see below)
-   * - ``bias_calibration_factor``
-     - float
-     - Bias calibration factor (0.0 = disabled)
-   * - ``verbose``
-     - bool
-     - Enable verbose logging
-   * - ``float_ops``
-     - bool
-     - Use float bias for Conv/Linear layers. Increases accuracy but
-       disables BNORM on TINPU hardware.
-
-**BatchNorm Freezing (``num_batch_norm_update_epochs``):**
-
-* ``None`` (default): Freezes BatchNorm statistics at the midpoint of
-  training
-* ``False``: Never freezes BatchNorm (may cause overfitting)
-* Integer value: Freezes after the specified epoch. Best results with
-  half to 3/4 of total epochs.
-
-**Observer Freezing (``num_observer_update_epochs``):**
-
-* ``False`` (default): Observers remain active throughout training
-* Integer value: Freezes observers after the specified epoch
-
-.. tip::
-
-   For best QAT results, set ``num_batch_norm_update_epochs`` to
-   approximately half of ``total_epochs``. This allows the model to
-   learn quantization-aware representations before freezing statistics.
-
-Model Surgery
--------------
-
-The ``tinyml-modeloptimization`` package includes model surgery utilities
-that use ``torch.fx`` to replace unsupported modules with efficient
-alternatives. This is useful for adapting existing models to meet NPU
-constraints.
-
-**Basic Usage:**
-
-.. code-block:: python
-
-   from tinyml_torchmodelopt.surgery import convert_to_lite_fx
-
-   # Replace unsupported layers with default replacements
-   model = convert_to_lite_fx(model)
-
-**Custom Replacements:**
-
-You can define custom replacement rules:
-
-.. code-block:: python
-
-   import copy
-   from tinyml_torchmodelopt.surgery import (
-       convert_to_lite_fx, get_replacement_dict_default
-   )
-
-   # Get and modify the default replacement dictionary
-   replacement_dict = copy.deepcopy(get_replacement_dict_default())
-   replacement_dict.update({torch.nn.GELU: torch.nn.ReLU})
-
-   # Apply with custom replacements
-   model = convert_to_lite_fx(model, replacement_dict=replacement_dict)
-
-The replacement value can also be a function for complex transformations:
-
-.. code-block:: python
-
-   replacement_dict.update({'my_layer': my_replacement_function})
-   model = convert_to_lite_fx(model, replacement_dict=replacement_dict)
-
-Model surgery is applied automatically during the quantization pipeline
-when needed. Direct usage is only necessary for custom workflows.
+   :widths: 40 20 40
+
+   * - Factor
+     - Frequency
+     - Impact
+   * - ``torch.compile`` only applies to float model
+     - All batches
+     - High — float gets fused kernels, QAT runs eager
+   * - ``2N+1`` FakeQuantize forward + backward ops
+     - Per batch, per layer
+     - High — doubles+ the computation graph
+   * - Observer min/max tensor reductions
+     - Per batch, per layer
+     - Medium — full-tensor reduction per observer
+   * - ``ceil2_tensor`` ``.sum()`` GPU syncs
+     - Per batch, per layer
+     - Medium — forces ``2N+1`` pipeline stalls
+   * - Soft-round sigmoid/tanh pass (4-bit / 2-bit)
+     - Per batch, per layer
+     - High — triples per-node cost
+   * - ``model.train()`` triple module traversal
+     - Per epoch
+     - Low — amortised over batches
+
+**Key Source Files**
+
+* ``tinyml-modeloptimization/torchmodelopt/.../quantization/base/fx/quant_base.py``
+  — ``TinyMLQuantFxBaseModule``: wraps the model, drives ``train()``/``freeze()``
+  lifecycle, epoch counter, temperature schedule.
+* ``tinyml-modeloptimization/torchmodelopt/.../quantization/base/fx/fake_quant_types.py``
+  — ``SoftSigmoidFakeQuantize``, ``SoftTanhFakeQuantize``: the most
+  expensive per-batch ops.
+* ``tinyml-modeloptimization/torchmodelopt/.../quantization/base/fx/functional_utils.py``
+  — ``ceil2_tensor``, ``_propagate_quant_ste``: power-of-2 scale snapping
+  with the ``.sum()`` sync.
+* ``tinyml-modeloptimization/torchmodelopt/.../quantization/base/fx/observer_types.py``
+  — ``SimplePerChannelWeightObserver``, ``SimpleActivationObserver``:
+  per-batch statistics with ``power2_scale`` call.
 
 Next Steps
 ----------
diff --git a/tinyml-modelmaker/DEVICE_TASK_SUPPORT.md b/tinyml-modelmaker/DEVICE_TASK_SUPPORT.md
index 6ca6f17..f26dcfd 100644
--- a/tinyml-modelmaker/DEVICE_TASK_SUPPORT.md
+++ b/tinyml-modelmaker/DEVICE_TASK_SUPPORT.md
@@ -154,14 +154,18 @@ These devices support **all** timeseries tasks (classification, regression, anom
 - Pattern recognition in sensor data
 
 **Available Models:**
-- TimeSeries_Generic_13k_t (13K parameters)
-- TimeSeries_Generic_6k_t (6K parameters)
-- TimeSeries_Generic_4k_t (4K parameters)
-- TimeSeries_Generic_1k_t (1K parameters)
-- TimeSeries_Generic_100_t (100 parameters)
-- TimeSeries_Generic_55k_t (55K parameters)
-- Res_Add_TimeSeries_Generic_3k_t (Residual addition, 3K parameters)
-- Res_Cat_TimeSeries_Generic_3k_t (Residual concatenation, 3K parameters)
+- CLS_100_NPU (100 parameters)
+- CLS_500_NPU (500 parameters)
+- CLS_1k_NPU (1K parameters)
+- CLS_2k_NPU (2K parameters)
+- CLS_4k_NPU (4K parameters)
+- CLS_6k_NPU (6K parameters)
+- CLS_8k_NPU (8K parameters)
+- CLS_13k_NPU (13K parameters)
+- CLS_20k_NPU (20K parameters)
+- CLS_55k_NPU (55K parameters)
+- CLS_ResAdd_3k (Residual addition, 3K parameters)
+- CLS_ResCat_3k (Residual concatenation, 3K parameters)
 
 **Key Features:**
 - Multiple model sizes for different memory constraints
@@ -180,11 +184,17 @@ These devices support **all** timeseries tasks (classification, regression, anom
 - Sensor calibration
 
 **Available Models:**
-- TimeSeries_Generic_Regr_13k_t (13K parameters, CNN-based)
-- TimeSeries_Generic_Regr_10k_t (10K parameters)
-- TimeSeries_Generic_Regr_4k_t (4K parameters, CNN-based)
-- TimeSeries_Generic_Regr_3k_t (3K parameters, MLP-based)
-- TimeSeries_Generic_Regr_1k_t (1K parameters)
+- REGR_1k (1K parameters)
+- REGR_2k (2K parameters)
+- REGR_3k (3K parameters, MLP-based)
+- REGR_4k (4K parameters, CNN-based)
+- REGR_10k (10K parameters)
+- REGR_13k (13K parameters, CNN-based)
+- REGR_500_NPU (500 parameters, NPU)
+- REGR_2k_NPU (2K parameters, NPU)
+- REGR_6k_NPU (6K parameters, NPU)
+- REGR_8k_NPU (8K parameters, NPU)
+- REGR_20k_NPU (20K parameters, NPU)
 
 **Key Features:**
 - Multiple architectures (CNN, MLP)
@@ -203,12 +213,18 @@ These devices support **all** timeseries tasks (classification, regression, anom
 - Security monitoring
 
 **Available Models:**
-- TimeSeries_Generic_AD_17k_t (17K parameters)
-- TimeSeries_Generic_AD_16k_t (16K parameters)
-- TimeSeries_Generic_AD_4k_t (4K parameters)
-- TimeSeries_Generic_AD_1k_t (1K parameters)
-- TimeSeries_Generic_Linear_AD (Linear model)
-- Ondevice_Trainable_TimeSeries_Generic_Linear_AD (On-device trainable)
+- AD_1k (1K parameters)
+- AD_4k (4K parameters)
+- AD_16k (16K parameters)
+- AD_17k (17K parameters)
+- AD_Linear (Linear model)
+- AD_500_NPU (500 parameters, NPU)
+- AD_2k_NPU (2K parameters, NPU)
+- AD_6k_NPU (6K parameters, NPU)
+- AD_8k_NPU (8K parameters, NPU)
+- AD_10k_NPU (10K parameters, NPU)
+- AD_20k_NPU (20K parameters, NPU)
+- Ondevice_Trainable_AD_Linear (On-device trainable)
 
 **Key Features:**
 - Unsupervised and semi-supervised approaches
@@ -227,10 +243,18 @@ These devices support **all** timeseries tasks (classification, regression, anom
 - Trend prediction
 
 **Available Models:**
-- TimeSeries_Generic_Forecasting_13k_t (13K parameters, CNN-based)
-- TimeSeries_Generic_Forecasting_3k_t (3K parameters, MLP-based)
-- TimeSeries_Generic_Forecasting_LSTM10 (LSTM with hidden size 10)
-- TimeSeries_Generic_Forecasting_LSTM8 (LSTM with hidden size 8)
+- FCST_3k (3K parameters, MLP-based)
+- FCST_13k (13K parameters, CNN-based)
+- FCST_LSTM8 (LSTM with hidden size 8)
+- FCST_LSTM10 (LSTM with hidden size 10)
+- FCST_500_NPU (500 parameters, NPU)
+- FCST_1k_NPU (1K parameters, NPU)
+- FCST_2k_NPU (2K parameters, NPU)
+- FCST_4k_NPU (4K parameters, NPU)
+- FCST_6k_NPU (6K parameters, NPU)
+- FCST_8k_NPU (8K parameters, NPU)
+- FCST_10k_NPU (10K parameters, NPU)
+- FCST_20k_NPU (20K parameters, NPU)
 
 **Key Features:**
 - Multiple forecasting horizons