Production-ready Vision Transformer (ViT-L/16) inference optimization using 4-bit NF4 quantization and FlashAttention-2. Achieves 40.5% latency reduction (58.9ms → 35.0ms) with <2% accuracy loss on 102-class flower classification.
| Metric | Baseline (FP32) | Optimized (4-bit + FA2) | Improvement |
|---|---|---|---|
| Latency | 58.9 ms | 35.0 ms | 40.5% faster |
| Model Size | 1157 MB | ~290 MB | 4x smaller |
| Top-1 Accuracy | 90.6% | 88.6% | -2.0% |
| Top-5 Accuracy | 96.7% | 96.7% | 0% |
| Throughput | ~17 img/s | ~29 img/s | 1.7x higher |
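The headline numbers are internally consistent and easy to sanity-check from the table above (the small difference between 40.5% and the computed 40.6% is reporting precision):

```python
# Sanity-check the headline numbers from the results table.
baseline_ms, optimized_ms = 58.9, 35.0

# Latency reduction: (58.9 - 35.0) / 58.9
reduction = (baseline_ms - optimized_ms) / baseline_ms * 100
print(f"Latency reduction: {reduction:.1f}%")  # ≈ 40.6%

# Single-image throughput is roughly the reciprocal of latency
baseline_ips = 1000 / baseline_ms    # ≈ 17 img/s
optimized_ips = 1000 / optimized_ms  # ≈ 29 img/s
print(f"Throughput gain: {optimized_ips / baseline_ips:.2f}x")  # ≈ 1.68x
```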
Vision Transformers deliver state-of-the-art accuracy but face critical deployment challenges:
- High Memory Footprint: ViT-L/16 requires >1GB memory, prohibitive for edge devices
- Slow Inference: FP32 models achieve only ~17 images/second on consumer GPUs
- Attention Bottleneck: Standard SDPA attention accounts for 30-40% of inference time
- Production Gap: Research models rarely optimize for real-world latency constraints
This repository bridges that gap with systematic quantization and attention optimization.
We implement a comprehensive optimization pipeline:
- 4-bit NF4 Quantization: Reduces model size 4x while maintaining accuracy
- FlashAttention-2 Integration: Accelerates attention computation by 20-30%
- Layer-wise Sensitivity Analysis: Identifies critical layers for selective precision
- Parameter-Efficient Fine-tuning: LoRA/QLoRA for memory-efficient training
Five configurations tested across quantization and attention mechanisms:
| Variant | Quantization | Attention | Top-1 Acc | Latency | Model Size |
|---|---|---|---|---|---|
| Baseline | FP32/FP16 | SDPA | 90.6% | 58.9 ms | 1157 MB |
| 4-bit + FA2 | 4-bit NF4 | FlashAttention-2 | 88.6% | 35.0 ms | ~290 MB |
| 8-bit + FA2 | 8-bit | FlashAttention-2 | 90.1% | Variable | ~580 MB |
| 4-bit SDPA | 4-bit NF4 | SDPA | 88.6% | 42.1 ms | ~290 MB |
| 8-bit SDPA | 8-bit | SDPA | 90.1% | 60.8 ms | ~580 MB |
All measurements on NVIDIA A10G, 100 iterations, warm-up excluded:
| Model | Mean | Median (p50) | p95 | p99 |
|---|---|---|---|---|
| Baseline FP32 | 58.9 ms | 58.7 ms | 59.4 ms | 59.8 ms |
| 4-bit + FA2 | 35.0 ms | 34.8 ms | 35.6 ms | 36.1 ms |
| 4-bit SDPA | 42.1 ms | 41.9 ms | 42.8 ms | 43.2 ms |
| 8-bit SDPA | 60.8 ms | 60.5 ms | 61.3 ms | 61.7 ms |
| Configuration | Model Size (Disk) | Peak VRAM (Inference) | Reduction |
|---|---|---|---|
| FP32 Baseline | 1157 MB | ~1.2 GB | - |
| 8-bit Quantized | ~580 MB | ~650 MB | 50% |
| 4-bit Quantized | ~290 MB | ~380 MB | 75% |
```
Accuracy (%)
      │
90.6% ├── Baseline ■
      │
90.1% ├── 8-bit + FA2 ●
      │
88.6% ├── 4-bit SDPA ▲    4-bit + FA2 ▼
      │
      └───────────────────────────────→ Latency (ms)
         35        42        59
```
Ablation study results for 4-bit quantization (Top-1 accuracy when preserving a specific layer in FP16):
| Layer Preserved | Accuracy | Delta from Baseline |
|---|---|---|
| None (Full 4-bit) | 88.63% | - |
| Layer 21 | 89.61% | +0.98% |
| Layer 22 | 89.00% | +0.37% |
| Layer 19 | 88.98% | +0.35% |
| Layer 0-1 | 88.88% | +0.25% |
| First/Last | 88.51% | -0.12% |
Key Insight: Late encoder layers (21-23) are most sensitive to quantization. Preserving layer 21 in FP16 recovers ~1% accuracy.
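Selective preservation can be expressed directly in the `bitsandbytes` config. A minimal sketch, assuming the HF ViT module layout (`encoder.layer.<i>`); verify the exact names against your checkpoint's `named_modules()`:

```python
# Sketch: keep the most quantization-sensitive layers in FP16.
# Layer indices come from the ablation table above; module names assume
# the HF ViT layout ("encoder.layer.<i>") -- check model.named_modules().
sensitive_layers = [21, 22, 19]
skip_modules = [f"encoder.layer.{i}" for i in sensitive_layers]
print(skip_modules)  # ['encoder.layer.21', 'encoder.layer.22', 'encoder.layer.19']

# Passed to BitsAndBytesConfig via llm_int8_skip_modules (also honored for 4-bit):
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     llm_int8_skip_modules=skip_modules,
# )
```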
| Component | Technology | Version |
|---|---|---|
| Framework | PyTorch | 2.0+ |
| Model | ViT-L/16 (HuggingFace) | transformers 4.30+ |
| Quantization | bitsandbytes (NF4) | latest |
| Attention | FlashAttention-2 | flash-attn 2.0+ |
| Fine-tuning | PEFT (LoRA) | latest |
| Compute | CUDA 11.8+ | Ampere/Ada |
- GPU: NVIDIA GPU with SM80+ (A10, A100, RTX 30/40 series)
- CUDA: 11.8 or higher
- RAM: 32 GB recommended
- Storage: 10 GB for models and dataset
```bash
# Clone repository
git clone https://github.com/puneethkotha/vit-inference-optimization.git
cd vit-inference-optimization

# Install core dependencies (quote the version specifier so ">" is not
# interpreted as shell redirection)
pip install "torch>=2.0" torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install "transformers>=4.30.0" timm bitsandbytes accelerate peft datasets safetensors

# Install FlashAttention-2 (requires compilation)
pip install flash-attn --no-build-isolation

# Evaluation and visualization dependencies
pip install matplotlib seaborn scikit-learn pandas evaluate
```

```python
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image

# Load optimized 4-bit + FlashAttention-2 model
model = AutoModelForImageClassification.from_pretrained(
    "path/to/quant_4bit",
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2"
)
processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")

# Run inference
image = Image.open("flower.jpg")
inputs = processor(images=image, return_tensors="pt").to("cuda")
with torch.inference_mode():
    outputs = model(**inputs)

predictions = outputs.logits.softmax(dim=-1)
print(f"Top prediction: {predictions.argmax().item()}")
```

End-to-end workflow across the notebooks:
```bash
# 1. Fine-tune base model with LoRA
jupyter notebook Model_Finetuning+Prep.ipynb

# 2. Evaluate all variants
jupyter notebook Evaluation.ipynb

# 3. Run production inference
jupyter notebook Inferencing.ipynb
```

4-bit NF4 configuration:

```python
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)
```

8-bit configuration:

```python
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
```

LoRA configuration for parameter-efficient fine-tuning:

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                    # Rank
    lora_alpha=32,          # Scaling factor
    target_modules=["query", "key", "value", "dense"],
    lora_dropout=0.1,
    bias="none",
    task_type="IMAGE_CLASSIFICATION"
)
```

FlashAttention-2 is enabled at load time:

```python
model = AutoModelForImageClassification.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",  # Enable FA2
    device_map="auto"
)
```

- Dataset: 102-class flower classification (Oxford Flowers)
- Training: 6 epochs, LoRA with r=8, α=32
- Precision: Mixed FP16/FP32
- Hardware: NVIDIA A100 (40GB)
Apply post-training quantization:
- 4-bit: NF4 with double quantization
- 8-bit: Standard int8 quantization
- Preservation: Optional FP16 for sensitive layers
Compare attention implementations:
- SDPA: PyTorch scaled dot-product attention
- FlashAttention-2: Memory-efficient attention with tiling
- Warm-up: 10 iterations discarded
- Measurement: 100 iterations averaged
- Metrics: Latency (mean/p50/p95/p99), throughput, accuracy
- Hardware: Consistent A10G instance
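The protocol above is easy to reproduce. A minimal timing harness (shown with `time.perf_counter` on a dummy workload; for GPU inference you would call `torch.cuda.synchronize()` before each timestamp, or use CUDA events):

```python
import time
import statistics

def benchmark(fn, warmup=10, iters=100):
    """Time fn() per the protocol above: discard warm-up, report percentiles."""
    for _ in range(warmup):          # warm-up iterations, discarded
        fn()
    times = []
    for _ in range(iters):           # measured iterations
        t0 = time.perf_counter()
        fn()
        times.append((time.perf_counter() - t0) * 1000)  # milliseconds
    times.sort()
    return {
        "mean": statistics.mean(times),
        "p50": times[len(times) // 2],
        "p95": times[int(len(times) * 0.95)],
        "p99": times[int(len(times) * 0.99)],
    }

# Dummy workload standing in for model(**inputs)
stats = benchmark(lambda: sum(i * i for i in range(10_000)))
print({k: round(v, 3) for k, v in stats.items()})
```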
Systematic layer-wise sensitivity analysis:
- Test all 24 encoder layers individually
- Measure accuracy impact of FP16 preservation
- Identify optimization opportunities
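The sweep itself is a loop over `llm_int8_skip_modules` (one load plus evaluation per layer); once per-layer accuracies are collected, ranking sensitivity is simple. A sketch using numbers from the ablation table above (the accuracy dict is an illustrative subset, not the full 24-layer sweep):

```python
# Rank encoder layers by how much FP16 preservation recovers over full 4-bit.
full_4bit_acc = 88.63  # Top-1 baseline from the ablation table

# Illustrative subset of sweep results: layer index -> top-1 accuracy (%)
acc_when_preserved = {21: 89.61, 22: 89.00, 19: 88.98, 0: 88.88}

deltas = {layer: acc - full_4bit_acc for layer, acc in acc_when_preserved.items()}
ranked = sorted(deltas.items(), key=lambda kv: kv[1], reverse=True)
for layer, delta in ranked:
    print(f"layer {layer:2d}: {delta:+.2f}%")
# layer 21 tops the ranking (+0.98%), matching the table
```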
```
vit-inference-optimization/
├── Model_Finetuning+Prep.ipynb       # LoRA fine-tuning and quantization
├── Evaluation.ipynb                  # Comprehensive benchmarking
├── Inferencing.ipynb                 # Production inference demo
├── Midpoint Check/                   # Initial exploration
│   └── Midpoint Project Checkpoint.ipynb
├── Finetuned_Models/                 # Model checkpoints
│   ├── baseline_model.zip            # FP16 baseline
│   ├── quant_4bit.zip                # 4-bit quantized
│   └── quant_8bit.zip                # 8-bit quantized
├── Images/                           # Benchmark visualizations
│   ├── Comparison_1.png
│   ├── Comparison_2.png
│   ├── Sensitivity_bar.png
│   └── Effect_of_keeping_first:last_layers.png
├── ablation_results.txt              # Layer sensitivity data
├── Presentation.pdf                  # Project overview
├── Project Report.pdf                # Detailed methodology
└── README.md                         # This file
```
Symptom: Compilation errors during `pip install flash-attn`
Solution:

```bash
# Ensure correct CUDA toolkit version
nvcc --version

# Install with specific CUDA version
pip install flash-attn --no-build-isolation --no-cache-dir
```

Fallback: use SDPA instead:

```python
model = AutoModelForImageClassification.from_pretrained(
    model_id,
    attn_implementation="sdpa"  # No FA2 required
)
```

Symptom: CUDA OOM error
Solution:
- Use 4-bit quantization instead of 8-bit
- Reduce batch size to 1
- Call `torch.cuda.empty_cache()` between batches
- Use smaller input resolution (224x224 instead of 384x384)

Symptom: >3% accuracy degradation
Solution:
- Preserve sensitive layers in FP16:

```python
# Keep layer 21 in FP16
quantization_config.llm_int8_skip_modules = ["encoder.layer.21"]
```

- Use 8-bit instead of 4-bit
- Fine-tune the quantized model with QLoRA

Symptom: RuntimeError: Error loading checkpoint
Solution:

```bash
# Verify checkpoint integrity
unzip -t quant_4bit.zip

# Clear cache and reload
rm -rf ~/.cache/huggingface/
python -c "from transformers import AutoModel; AutoModel.from_pretrained(...)"
```

Symptom: FlashAttention-2 not available on GPU
Check:

```python
import torch
print(torch.cuda.get_device_capability())  # Must be (8, 0) or higher
```

Workaround: use the SDPA fallback for older GPUs (SM75/Turing).
Contributions welcome! Areas of interest:
- Model Support: Extend to ViT-B/32, DeiT, Swin Transformer
- Quantization: Test alternative methods (GPTQ, AWQ, SmoothQuant)
- Optimization: Integrate TensorRT, ONNX Runtime
- Benchmarks: Add different hardware platforms (V100, T4, RTX 3090)
```bash
git clone https://github.com/puneethkotha/vit-inference-optimization.git
cd vit-inference-optimization
pip install -e .
pre-commit install

# Run tests
pytest tests/
python -m torch.utils.benchmark <script>
```

If you use this work, please cite:
```bibtex
@misc{kotha2024vit-inference-optimization,
  author    = {Kotha, Puneeth and Pamnani, Jayraj},
  title     = {ViT Inference Optimization with Quantization and FlashAttention-2},
  year      = {2024},
  publisher = {GitHub},
  url       = {https://github.com/puneethkotha/vit-inference-optimization}
}
```

- Vision Transformer: Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", ICLR 2021
- QLoRA: Dettmers et al., "QLoRA: Efficient Finetuning of Quantized LLMs", NeurIPS 2023
- FlashAttention-2: Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning", ICLR 2024
- bitsandbytes: Dettmers et al., "8-bit Optimizers via Block-wise Quantization", ICLR 2022
- PyTorch - Deep learning framework
- HuggingFace Transformers - Pre-trained models
- bitsandbytes - Quantization library
- PEFT - Parameter-efficient fine-tuning
- FlashAttention - Optimized attention kernels
- timm - Image models
This project is released under the MIT License. See LICENSE for details.
- Project developed as part of HPML coursework
- Flower dataset from Oxford Visual Geometry Group
- Compute resources provided by NYU HPC
Maintainers:
- Puneeth Kotha - GitHub
- Jayraj Pamnani
For questions or collaboration: Open an issue