Shapley values are a theoretically grounded model explanation approach, but their exponential computational cost makes them difficult to use with large deep learning models. ViT-Shapley makes Shapley values practical for vision transformer (ViT) models by learning an amortized explainer model that generates explanations in a single forward pass.
Please see our ICLR spotlight paper for more details, as well as the work that ViT-Shapley builds on (KernelSHAP, FastSHAP).
git clone https://github.com/chanwkimlab/vit-shapley.git && cd vit-shapley
conda create -n vit-shapley python=3.10 -y
conda activate vit-shapley
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
pip install -e . # core deps: timm, pydantic, pyyaml, tqdm
pip install -e ".[dev]" # optional: pytest, ruff, jupyterlab
pip install captum # optional: gradient baselines

Stage 1: Classifier Obtain your initial ViT image classifier to explain
|
Stage 2: Surrogate If your model was not trained to accommodate held-out image patches, fine-tune it with random masking
|
Stage 3: Explainer Train an explainer that produces Shapley values in one forward pass
Each stage produces a checkpoint that feeds into the next. The final explainer generates per-patch importance scores (196 values for a 14x14 patch grid) without the exponential cost of exact Shapley computation.
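To illustrate why a single forward pass can suffice, amortized explainers in the FastSHAP family typically apply an additive efficient normalization: the raw per-patch outputs are shifted by an equal share of the residual so that, per class, they sum to the gap between the full-input and fully-masked surrogate predictions (the Shapley efficiency constraint). A minimal NumPy sketch, with shapes and function names that are illustrative rather than the repo's actual API:

```python
import numpy as np

def efficient_normalize(phi, grand, null):
    # Additive efficient normalization: spread the residual
    # (grand - null - sum of patch values) evenly over all patches
    # so attributions satisfy the Shapley efficiency constraint.
    gap = grand - null - phi.sum(axis=1)          # (batch, num_classes)
    return phi + gap[:, None, :] / phi.shape[1]

rng = np.random.default_rng(0)
phi = rng.normal(size=(2, 196, 10))   # 14x14 patch grid, 10 classes
grand = rng.normal(size=(2, 10))      # surrogate output, all patches kept
null = rng.normal(size=(2, 10))       # surrogate output, all patches masked
phi_n = efficient_normalize(phi, grand, null)

# Per-class attributions now sum to the prediction gap exactly.
assert np.allclose(phi_n.sum(axis=1), grand - null)
```

This constraint is what lets a learned explainer approximate exact Shapley values without enumerating the 2^196 patch subsets.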
# 1. Classifier
python scripts/train_classifier.py --config configs/classifier.yaml
# 2. Surrogates (one per masking strategy)
python scripts/train_surrogate.py --config configs/surrogate_attn.yaml
python scripts/train_surrogate.py --config configs/surrogate_zero.yaml
# 3. Explainer
python scripts/train_explainer.py --config configs/explainer.yaml
# Evaluate / visualise
python scripts/plot_surrogate_kl.py --config configs/plot_surrogate_kl.yaml
python scripts/visualize_explainer.py --config configs/visualize_explainer.yaml

Switch dataset with --set dataset=pet on any command. All checkpoint paths and output filenames update automatically via ${dataset} self-referencing in the YAML. See Config System for details.
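As an illustration of the self-referencing mechanism, a config fragment of the following shape (the key names here are hypothetical, not the repo's actual schema) would let a single --set dataset=pet override ripple through every dependent path:

```yaml
# Hypothetical fragment: overriding `dataset` updates every
# value that interpolates ${dataset}.
dataset: imagenette
checkpoint_dir: checkpoints/${dataset}
classifier_ckpt: ${checkpoint_dir}/classifier.ckpt
output_plot: results/${dataset}_surrogate_kl.pdf
```

With --set dataset=pet, the checkpoint and output paths would resolve to checkpoints/pet/... and results/pet_surrogate_kl.pdf without any further edits.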
For binary classification datasets like MURA, add --set target_type=binary to all pipeline commands. This switches the loss function (BCEWithLogitsLoss), KL divergence computation (sigmoid-based), and link function (sigmoid instead of softmax).
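The multiclass/binary distinction matters because the KL divergence is computed over different distributions: softmax over all classes in the multiclass case, versus the two-outcome distribution {p, 1 - p} implied by a single sigmoid logit in the binary case. A minimal sketch of the distinction (not the repo's implementation):

```python
import numpy as np

def kl_softmax(p_logits, q_logits):
    # Multiclass: KL between two softmax distributions.
    p = np.exp(p_logits - p_logits.max()); p /= p.sum()
    q = np.exp(q_logits - q_logits.max()); q /= q.sum()
    return float(np.sum(p * np.log(p / q)))

def kl_sigmoid(p_logit, q_logit):
    # Binary (target_type=binary): one sigmoid probability defines
    # a Bernoulli distribution {p, 1 - p}.
    p = 1.0 / (1.0 + np.exp(-p_logit))
    q = 1.0 / (1.0 + np.exp(-q_logit))
    return float(p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q)))

# Identical logits give zero divergence in both cases.
assert kl_sigmoid(0.3, 0.3) == 0.0
```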
| Name | Classes | Source |
|---|---|---|
| imagenette (default) | 10 | fast.ai subset of ImageNet |
| mura | 2 | Stanford MURA X-rays |
| pet | 37 | Oxford-IIIT Pet |
MURA (Musculoskeletal Radiographs) is a binary classification dataset of bone X-ray images (negative/positive) across 7 body parts. It requires a Stanford AIMI license — download from Stanford AIMI and place at <data_root>/MURA-v1.1/. Use with --set dataset=mura.
python -m pytest tests/ -v

- Config — YAML variables, --set overrides, .env, multi-GPU
- Evaluation — KL divergence plots and Shapley heatmaps