Implementation of Grad-CAM: visualizes the gradient-weighted activations of a chosen layer with respect to the model's prediction. Currently supports CNN, ViT, and Swin Transformer models from timm. The visualization includes: a heatmap, an RGB-channel-injected heatmap, an overlapped image, and an overview image.
| model_type | Description |
|---|---|
| `Normal` | Standard CNN-based models (ResNet, etc.) |
| `ViT` | Vision Transformer models |
| `DeiT` | Vision Transformer-structured models with one extra distillation-token embedding row |
| `SwinT` | Swin Transformer models |
```
pip install torch torchvision timm matplotlib pillow numpy
```

```python
GradCAM(
    model,                   # torch.nn.Module — the model to visualize
    layer_name,              # str — full name of the layer targeted for Grad-CAM
    img_path=None,           # str — path to the input image
    img_value=None,          # tensor — pre-loaded image tensor (alternative to img_path)
    input_shape=(224, 224),  # tuple — image resize shape (used when no transform provided)
    model_type='Normal',     # str — 'Normal', 'ViT', 'DeiT', or 'SwinT'
    transform=None,          # transforms.Compose — custom preprocessing (recommended for ViT/SwinT)
    verbose=False,           # bool — print debug info about shapes and predictions
)
```

Calling the instance runs the forward and backward pass to compute the Grad-CAM heatmap.
- `heatmap_threshold`: controls how much of the heatmap is shown. Must be > 1; higher values suppress more low-activation regions (fewer highlights).

`imposing_visualization()` displays a 2×2 figure: original image, overlapped colormap, raw heatmap, and projected colormap.

- `save_path`: if provided, saves the figure and the individual images (original, overlapped, heatmap, colormap) as PNGs.
- `denormalize`: tuple of `(mean, std)` lists used to undo normalization before saving the visualization images (e.g. for ImageNet-normalized inputs).
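For intuition, undoing channel-wise normalization is just the inverse affine map `x * std + mean`. A minimal numpy sketch (the function name `denormalize_image` is illustrative, not this library's API):

```python
import numpy as np

def denormalize_image(img, mean, std):
    """Undo channel-wise normalization: x_denorm = x * std + mean.

    img: array of shape (H, W, C), previously normalized per channel.
    mean, std: per-channel lists, e.g. the ImageNet statistics.
    """
    mean = np.asarray(mean)
    std = np.asarray(std)
    out = img * std + mean          # broadcast over the channel axis
    return np.clip(out, 0.0, 1.0)  # keep values displayable in [0, 1]

# Round-trip check with the ImageNet statistics used in the examples below
mean = [0.4850, 0.4560, 0.4060]
std = [0.2290, 0.2240, 0.2250]
rng = np.random.default_rng(0)
original = rng.uniform(0.0, 1.0, size=(4, 4, 3))
normalized = (original - mean) / std
restored = denormalize_image(normalized, mean, std)
print(np.allclose(restored, original))  # → True
```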
```python
from grad_cam_code.grad_cam import *

model = create_model('timm/resnet18.a1_in1k', pretrained=True)
model.eval()
print_layername(model)
```

Note: ResNet ends with an `AdaptiveAvgPool` before the classifier — use `layer4.1.conv2` instead of the last layer of the backbone, or it will raise a dimension error.

Note: it is important to call `model.eval()` for Grad-CAM to succeed. This makes the prediction stable and disables training-time behavior (e.g. BatchNorm updates and Dropout).
```python
from grad_cam_code.grad_cam import *

model = create_model('timm/resnet18.a1_in1k', pretrained=True)
model.eval()

img_path = 'graphs/test_images/test2-pug-dog.png'
cam = GradCAM(model, layer_name='layer4.1.conv2', img_path=img_path, model_type='Normal')
cam(heatmap_threshold=20)
cam.imposing_visualization()
```

Note: ViT ends by taking only the [cls] token of the backbone (encoder) into the classification head.
```python
from grad_cam_code.grad_cam import *
from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config

model = create_model('vit_base_patch16_224', pretrained=True)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
model.eval()

img_path = 'graphs/test_images/test2-pug-dog.png'
cam_vit = GradCAM(model, layer_name='blocks.10.drop_path2', img_path=img_path, model_type='ViT', transform=transform, verbose=True)
cam_vit(heatmap_threshold=5)
cam_vit.imposing_visualization()
```

```python
from grad_cam_code.grad_cam import *
from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config

model = create_model('timm/deit_small_distilled_patch16_224.fb_in1k', pretrained=True)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
model.eval()

img_path = 'graphs/test_images/test2-pug-dog.png'
cam = GradCAM(model, layer_name='blocks.10.drop_path2', img_path=img_path, model_type='DeiT', transform=transform, verbose=True)
cam(heatmap_threshold=8)
# Specify denormalize to undo ImageNet normalization when saving
cam.imposing_visualization(
    save_path="img/swt_test",
    denormalize=([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250])
)
```

```python
from grad_cam_code.grad_cam import *
from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config

model = create_model('swin_base_patch4_window7_224', pretrained=True)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
model.eval()

img_path = 'graphs/test_images/test2-pug-dog.png'
cam = GradCAM(model, layer_name='layers.3.blocks.1.drop_path2', img_path=img_path, model_type='SwinT', transform=transform, verbose=True)
cam(heatmap_threshold=40)
# Specify denormalize to undo ImageNet normalization when saving
cam.imposing_visualization(
    save_path="img/swt_test",
    denormalize=([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250])
)
```

Grad-CAM computes a class-discriminative localization map by:
- Running a forward pass through the model up to the target layer to get activations $A^k$.
- Computing the gradient of the predicted class score $y^c$ with respect to those activations: $\frac{\partial y^c}{\partial A^k}$.
- Global-average-pooling the gradients over the spatial dimensions to get per-channel weights $\alpha_k^c$.
- Computing the weighted combination: $L^c_{\text{Grad-CAM}} = \text{ReLU}\left(\sum_k \alpha_k^c A^k\right)$.
- Upsampling the resulting heatmap back to the original image size via bilinear interpolation and overlaying it using a jet colormap.
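The weighting steps above can be sketched in plain numpy, given activations $A^k$ and gradients $\frac{\partial y^c}{\partial A^k}$ already extracted from the model (shapes and the toy values below are purely illustrative):

```python
import numpy as np

def grad_cam_map(activations, gradients):
    """Compute the Grad-CAM localization map from one layer's tensors.

    activations: (C, H, W) feature maps A^k from the target layer.
    gradients:   (C, H, W) gradients dy^c/dA^k from the backward pass.
    Returns a non-negative (H, W) heatmap.
    """
    # Global-average-pool the gradients -> per-channel weights alpha_k
    alpha = gradients.mean(axis=(1, 2))               # shape (C,)
    # Weighted combination of the feature maps, then ReLU
    cam = np.einsum('c,chw->hw', alpha, activations)  # shape (H, W)
    return np.maximum(cam, 0.0)

# Toy example: 2 channels on a 2x2 grid
A = np.array([[[1.0, 0.0], [0.0, 0.0]],
              [[0.0, 0.0], [0.0, 1.0]]])
G = np.array([[[4.0, 4.0], [4.0, 4.0]],        # alpha_0 = 4
              [[-4.0, -4.0], [-4.0, -4.0]]])   # alpha_1 = -4
# ReLU zeroes the negatively weighted cell, keeping [[4, 0], [0, 0]]
print(grad_cam_map(A, G))
```

In the real implementation this map is then bilinearly upsampled to the input resolution before the colormap overlay.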
The `model_type` parameter controls how activations are reshaped:

- `Normal` (CNN-based): expects `(B, C, H, W)` feature maps (standard CNN conv outputs).
- `ViT`/`SwinT`/`DeiT` (transformer-based): reshapes the sequence dimension `(B, HW, C)` into a spatial grid `(B, H, W, C)`.
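As an illustration of the transformer branch, a ViT-style token sequence can be turned back into a spatial grid by dropping the non-spatial tokens first. This is a hedged sketch (the helper name `tokens_to_grid` and exact token layout are assumptions about the model, not this library's internals):

```python
import numpy as np

def tokens_to_grid(tokens, num_prefix_tokens=1):
    """Reshape (B, N_prefix + H*W, C) token activations to (B, H, W, C).

    num_prefix_tokens: non-spatial tokens to drop, e.g. 1 for ViT's [cls]
    token, 2 for DeiT's [cls] + distillation tokens.
    """
    spatial = tokens[:, num_prefix_tokens:, :]  # drop prefix tokens
    b, hw, c = spatial.shape
    side = int(round(hw ** 0.5))                # assume a square patch grid
    assert side * side == hw, "token count is not a square grid"
    return spatial.reshape(b, side, side, c)

# ViT-Base/16 at 224x224: 1 [cls] token + 14*14 patch tokens, 768 channels
tokens = np.zeros((2, 1 + 14 * 14, 768))
print(tokens_to_grid(tokens).shape)  # → (2, 14, 14, 768)
```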
Theoretical support:

```bibtex
@article{Selvaraju_2019,
  title={Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization},
  volume={128},
  ISSN={1573-1405},
  url={http://dx.doi.org/10.1007/s11263-019-01228-7},
  DOI={10.1007/s11263-019-01228-7},
  number={2},
  journal={International Journal of Computer Vision},
  publisher={Springer Science and Business Media LLC},
  author={Selvaraju, Ramprasaath R. and Cogswell, Michael and Das, Abhishek and Vedantam, Ramakrishna and Parikh, Devi and Batra, Dhruv},
  year={2019},
  month=oct,
  pages={336–359}
}
```
Part of this code was implemented during my research for the following paper:

```bibtex
@article{CHEN2024100332,
  title = {A vision transformer machine learning model for COVID-19 diagnosis using chest X-ray images},
  journal = {Healthcare Analytics},
  volume = {5},
  pages = {100332},
  year = {2024},
  issn = {2772-4425},
  doi = {https://doi.org/10.1016/j.health.2024.100332},
  url = {https://www.sciencedirect.com/science/article/pii/S2772442524000340},
  author = {Tianyi Chen and Ian Philippi and Quoc Bao Phan and Linh Nguyen and Ngoc Thang Bui and Carlo daCunha and Tuy Tan Nguyen},
  keywords = {Computer-aided diagnosis, Machine learning, Vision transformer, Efficient neural networks, COVID-19, Chest X-ray}
}
```
