ViT Inference Optimization

Production-ready Vision Transformer (ViT-L/16) inference optimization using 4-bit NF4 quantization and FlashAttention-2. Achieves 40.5% latency reduction (58.9ms → 35.0ms) with <2% accuracy loss on 102-class flower classification.

Key Results

| Metric | Baseline (FP32) | Optimized (4-bit + FA2) | Improvement |
|---|---|---|---|
| Latency | 58.9 ms | 35.0 ms | 40.5% faster |
| Model Size | 1157 MB | ~290 MB | 4x smaller |
| Top-1 Accuracy | 90.6% | 88.6% | -2.0% |
| Top-5 Accuracy | 96.7% | 96.7% | 0% |
| Throughput | ~17 img/s | ~29 img/s | 1.7x higher |
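
The throughput figures are consistent with single-image (batch size 1) latency: 1000 ms / 58.9 ms ≈ 17 img/s for the baseline and 1000 ms / 35.0 ms ≈ 29 img/s for the optimized model.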

Problem Statement

Vision Transformers deliver state-of-the-art accuracy but face critical deployment challenges:

  • High Memory Footprint: ViT-L/16 requires >1GB memory, prohibitive for edge devices
  • Slow Inference: FP32 models achieve only ~17 images/second on consumer GPUs
  • Attention Bottleneck: Standard SDPA attention accounts for 30-40% of inference time
  • Production Gap: Research models rarely optimize for real-world latency constraints

This repository bridges that gap with systematic quantization and attention optimization.

Solution

We implement a comprehensive optimization pipeline:

  1. 4-bit NF4 Quantization: Reduces model size 4x while maintaining accuracy
  2. FlashAttention-2 Integration: Accelerates attention computation by 20-30%
  3. Layer-wise Sensitivity Analysis: Identifies critical layers for selective precision
  4. Parameter-Efficient Fine-tuning: LoRA/QLoRA for memory-efficient training

Model Variants

Five configurations tested across quantization and attention mechanisms:

| Variant | Quantization | Attention | Top-1 Acc | Latency | Model Size |
|---|---|---|---|---|---|
| Baseline | FP32/FP16 | SDPA | 90.6% | 58.9 ms | 1157 MB |
| 4-bit + FA2 | 4-bit NF4 | FlashAttention-2 | 88.6% | 35.0 ms | ~290 MB |
| 8-bit + FA2 | 8-bit | FlashAttention-2 | 90.1% | Variable | ~580 MB |
| 4-bit SDPA | 4-bit NF4 | SDPA | 88.6% | 42.1 ms | ~290 MB |
| 8-bit SDPA | 8-bit | SDPA | 90.1% | 60.8 ms | ~580 MB |

Benchmarks

Latency Distribution

All measurements on NVIDIA A10G, 100 iterations, warm-up excluded:

| Model | Mean | Median (p50) | p95 | p99 |
|---|---|---|---|---|
| Baseline FP32 | 58.9 ms | 58.7 ms | 59.4 ms | 59.8 ms |
| 4-bit + FA2 | 35.0 ms | 34.8 ms | 35.6 ms | 36.1 ms |
| 4-bit SDPA | 42.1 ms | 41.9 ms | 42.8 ms | 43.2 ms |
| 8-bit SDPA | 60.8 ms | 60.5 ms | 61.3 ms | 61.7 ms |

Memory Footprint

| Configuration | Model Size (Disk) | Peak VRAM (Inference) | Reduction |
|---|---|---|---|
| FP32 Baseline | 1157 MB | ~1.2 GB | - |
| 8-bit Quantized | ~580 MB | ~650 MB | 50% |
| 4-bit Quantized | ~290 MB | ~380 MB | 75% |
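
Peak-VRAM figures of this kind can be reproduced with PyTorch's CUDA memory statistics; a minimal sketch (the function name is illustrative, not from the repo):

import torch

def peak_vram_mb(model, inputs):
    """Run one forward pass and return peak allocated CUDA memory in MB."""
    torch.cuda.reset_peak_memory_stats()
    with torch.inference_mode():
        model(**inputs)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / (1024 ** 2)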

Accuracy vs Speed Trade-off

Accuracy (%)
│
90.6%├─── Baseline ■
     │
90.1%├─── 8-bit + FA2 ●
     │         
88.6%├───────────────── 4-bit SDPA ▲
     │                       4-bit + FA2 ▼
     │
     └─────────────────────────────────→ Latency (ms)
        35         42        59

Layer-wise Quantization Sensitivity

Ablation study results for 4-bit quantization (Top-1 accuracy when a specific layer is preserved in FP16):

| Layer Preserved | Accuracy | Delta vs Full 4-bit |
|---|---|---|
| None (Full 4-bit) | 88.63% | - |
| Layer 21 | 89.61% | +0.98% |
| Layer 22 | 89.00% | +0.37% |
| Layer 19 | 88.98% | +0.35% |
| Layers 0-1 | 88.88% | +0.25% |
| First/Last | 88.51% | -0.12% |

Key Insight: Late encoder layers (especially 21 and 22) are the most sensitive to 4-bit quantization; preserving layer 21 alone in FP16 recovers ~1% Top-1 accuracy.
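
Applying this insight amounts to excluding layer 21 from the 4-bit conversion; a minimal sketch using bitsandbytes' skip-module mechanism (the "encoder.layer.21" pattern assumes HuggingFace's ViT module naming):

import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 everywhere except encoder layer 21, which stays in FP16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["encoder.layer.21"],  # assumed HF ViT module naming
)

The config is then passed via quantization_config at load time, exactly as in the Quickstart and Configuration sections.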

Technical Stack

| Component | Technology | Version |
|---|---|---|
| Framework | PyTorch | 2.0+ |
| Model | ViT-L/16 (HuggingFace) | transformers 4.30+ |
| Quantization | bitsandbytes (NF4) | latest |
| Attention | FlashAttention-2 | flash-attn 2.0+ |
| Fine-tuning | PEFT (LoRA) | latest |
| Compute | CUDA 11.8+ | Ampere/Ada |

Quickstart

Prerequisites

  • GPU: NVIDIA GPU with SM80+ (A10, A100, RTX 30/40 series)
  • CUDA: 11.8 or higher
  • RAM: 32 GB recommended
  • Storage: 10 GB for models and dataset

Installation

# Clone repository
git clone https://github.com/puneethkotha/vit-inference-optimization.git
cd vit-inference-optimization

# Install core dependencies
pip install "torch>=2.0" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install "transformers>=4.30.0" timm bitsandbytes accelerate peft datasets safetensors

# Install FlashAttention-2 (requires compilation)
pip install flash-attn --no-build-isolation

# Evaluation and visualization dependencies
pip install matplotlib seaborn scikit-learn pandas evaluate
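
Optional sanity check (not part of the repo) that the GPU, the CUDA build of PyTorch, and the FlashAttention-2 wheel are all usable:

python -c "import torch, flash_attn; print(torch.__version__, torch.cuda.is_available(), torch.cuda.get_device_capability())"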

Quick Inference

import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image

# Load optimized 4-bit + FlashAttention-2 model
model = AutoModelForImageClassification.from_pretrained(
    "path/to/quant_4bit",
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2"
)
processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")

# Run inference
image = Image.open("flower.jpg").convert("RGB")  # ensure a 3-channel RGB input
inputs = processor(images=image, return_tensors="pt").to("cuda")

with torch.inference_mode():
    outputs = model(**inputs)
    predictions = outputs.logits.softmax(dim=-1)
    
print(f"Top prediction: {predictions.argmax().item()}")

Training Pipeline

End-to-end workflow across notebooks:

# 1. Fine-tune base model with LoRA
jupyter notebook Model_Finetuning+Prep.ipynb

# 2. Evaluate all variants
jupyter notebook Evaluation.ipynb

# 3. Run production inference
jupyter notebook Inferencing.ipynb

Configuration

Quantization Settings

4-bit NF4 configuration:

import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

8-bit configuration:

bnb_config = BitsAndBytesConfig(load_in_8bit=True)
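
Either config takes effect at load time through the quantization_config argument; a minimal sketch (the checkpoint path is a placeholder):

from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    "path/to/baseline_model",        # placeholder path to the fine-tuned checkpoint
    quantization_config=bnb_config,  # 4-bit NF4 or 8-bit config defined above
    device_map="auto",
)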

LoRA Hyperparameters

from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                        # Rank
    lora_alpha=32,              # Scaling factor
    target_modules=["query", "key", "value", "dense"],
    lora_dropout=0.1,
    bias="none",
    task_type="IMAGE_CLASSIFICATION"
)
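
The config is applied by wrapping the fine-tuned classifier with PEFT; a minimal sketch, assuming model is the loaded ViT classifier:

from peft import get_peft_model

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # reports trainable vs. total parameter counts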

FlashAttention-2 Activation

import torch

model = AutoModelForImageClassification.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                 # FA2 kernels require FP16/BF16
    attn_implementation="flash_attention_2",   # Enable FA2
    device_map="auto"
)

Methodology

1. Baseline Fine-tuning

  • Dataset: 102-class flower classification (Oxford Flowers)
  • Training: 6 epochs, LoRA with r=8, α=32
  • Precision: Mixed FP16/FP32
  • Hardware: NVIDIA A100 (40GB)

2. Quantization

Apply post-training quantization:

  • 4-bit: NF4 with double quantization
  • 8-bit: Standard int8 quantization
  • Preservation: Optional FP16 for sensitive layers

3. Attention Optimization

Compare attention implementations:

  • SDPA: PyTorch scaled dot-product attention
  • FlashAttention-2: Memory-efficient attention with tiling

4. Benchmarking Protocol

  • Warm-up: 10 iterations discarded
  • Measurement: 100 iterations averaged
  • Metrics: Latency (mean/p50/p95/p99), throughput, accuracy
  • Hardware: Consistent A10G instance
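
The latency tables above follow this protocol; a minimal sketch of the measurement loop (not the repo's exact script), using CUDA events and NumPy percentiles:

import numpy as np
import torch

def benchmark(model, inputs, warmup=10, iters=100):
    """Per-image latency in ms: warm-up discarded, CUDA-synchronized timing."""
    with torch.inference_mode():
        for _ in range(warmup):            # warm-up iterations, not recorded
            model(**inputs)
        torch.cuda.synchronize()
        times = []
        for _ in range(iters):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            model(**inputs)
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # milliseconds
    t = np.array(times)
    return {"mean": t.mean(), "p50": np.percentile(t, 50),
            "p95": np.percentile(t, 95), "p99": np.percentile(t, 99)}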

5. Ablation Studies

Systematic layer-wise sensitivity analysis (a scripted sketch follows the list):

  • Test all 24 encoder layers individually
  • Measure accuracy impact of FP16 preservation
  • Identify optimization opportunities
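
The study can be scripted by reloading the quantized model once per encoder layer with that layer excluded from 4-bit conversion. A minimal sketch, assuming HuggingFace module naming ("encoder.layer.{i}") and a placeholder evaluate_top1(model, val_loader) helper:

import torch
from transformers import AutoModelForImageClassification, BitsAndBytesConfig

results = {}
for layer_idx in range(24):  # ViT-L/16 has 24 encoder layers
    cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        llm_int8_skip_modules=[f"encoder.layer.{layer_idx}"],  # keep this layer in FP16
    )
    model = AutoModelForImageClassification.from_pretrained(
        "path/to/baseline_model",            # placeholder checkpoint path
        quantization_config=cfg,
        device_map="auto",
    )
    results[layer_idx] = evaluate_top1(model, val_loader)  # placeholder accuracy helper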

Repository Structure

vit-inference-optimization/
├── Model_Finetuning+Prep.ipynb      # LoRA fine-tuning and quantization
├── Evaluation.ipynb                  # Comprehensive benchmarking
├── Inferencing.ipynb                 # Production inference demo
├── Midpoint Check/                   # Initial exploration
│   └── Midpoint Project Checkpoint.ipynb
├── Finetuned_Models/                 # Model checkpoints
│   ├── baseline_model.zip            # FP16 baseline
│   ├── quant_4bit.zip                # 4-bit quantized
│   └── quant_8bit.zip                # 8-bit quantized
├── Images/                           # Benchmark visualizations
│   ├── Comparison_1.png
│   ├── Comparison_2.png
│   ├── Sensitivity_bar.png
│   └── Effect_of_keeping_first:last_layers.png
├── ablation_results.txt              # Layer sensitivity data
├── Presentation.pdf                  # Project overview
├── Project Report.pdf                # Detailed methodology
└── README.md                         # This file

Troubleshooting

FlashAttention-2 Installation Fails

Symptom: Compilation errors during pip install flash-attn

Solution:

# Ensure correct CUDA toolkit version
nvcc --version

# Retry with build isolation disabled and a clean cache
pip install flash-attn --no-build-isolation --no-cache-dir

# Fallback: Use SDPA instead
model = AutoModelForImageClassification.from_pretrained(
    model_id,
    attn_implementation="sdpa"  # No FA2 required
)

Out of Memory During Inference

Symptom: CUDA OOM error

Solution:

  • Use 4-bit quantization instead of 8-bit
  • Reduce batch size to 1
  • Call torch.cuda.empty_cache() between batches
  • Use smaller input resolution (224x224 instead of 384x384)

Accuracy Drop Too Large

Symptom: >3% accuracy degradation

Solution:

  1. Preserve sensitive layers in FP16:

# Keep layer 21 in FP16
quantization_config.llm_int8_skip_modules = ["encoder.layer.21"]

  2. Use 8-bit instead of 4-bit
  3. Fine-tune the quantized model with QLoRA

Model Loading Errors

Symptom: RuntimeError: Error loading checkpoint

Solution:

# Verify checkpoint integrity
unzip -t quant_4bit.zip

# Clear cache and reload
rm -rf ~/.cache/huggingface/
python -c "from transformers import AutoModel; AutoModel.from_pretrained(...)"

GPU Compatibility Issues

Symptom: FlashAttention-2 not available on GPU

Check:

import torch
print(torch.cuda.get_device_capability())  # Must be (8, 0) or higher

Workaround: Use SDPA fallback for older GPUs (SM75/Turing).

Contributing

Contributions welcome! Areas of interest:

  • Model Support: Extend to ViT-B/32, DeiT, Swin Transformer
  • Quantization: Test alternative methods (GPTQ, AWQ, SmoothQuant)
  • Optimization: Integrate TensorRT, ONNX Runtime
  • Benchmarks: Add different hardware platforms (V100, T4, RTX 3090)

Development Setup

git clone https://github.com/puneethkotha/vit-inference-optimization.git
cd vit-inference-optimization
pip install -e .
pre-commit install

Running Tests

pytest tests/

# Latency benchmarks are run from Evaluation.ipynb (see the Benchmarking Protocol above)

Citation

If you use this work, please cite:

@misc{kotha2024vit-inference-optimization,
  author = {Kotha, Puneeth and Pamnani, Jayraj},
  title = {ViT Inference Optimization with Quantization and FlashAttention-2},
  year = {2024},
  publisher = {GitHub},
  url = {https://github.com/puneethkotha/vit-inference-optimization}
}

References

Papers

  1. Vision Transformer: Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", ICLR 2021
  2. QLoRA: Dettmers et al., "QLoRA: Efficient Finetuning of Quantized LLMs", NeurIPS 2023
  3. FlashAttention-2: Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning", ICLR 2024
  4. bitsandbytes: Dettmers et al., "8-bit Optimizers via Block-wise Quantization", ICLR 2022

License

This project is released under the MIT License. See LICENSE for details.

Contact

Maintainers:

  • Puneeth Kotha - GitHub
  • Jayraj Pamnani

For questions or collaboration: Open an issue
