GradNorm - PyTorch

A practical implementation of GradNorm, Gradient Normalization for Adaptive Loss Balancing, in PyTorch

I am increasingly coming across neural network architectures that require more than three auxiliary losses, so this installable package handles loss balancing, including in distributed settings and with gradient accumulation. I am also open to incorporating any follow-up research; just let me know in the issues.

It will be dog-fooded for SoundStream, MagViT2, as well as MetNet3.

Install

$ pip install gradnorm-pytorch

Usage

import torch
from torch.optim import Adam

from gradnorm_pytorch import (
    GradNormLossWeighter,
    MockNetworkWithMultipleLosses
)

# a mock network with multiple discriminator losses

network = MockNetworkWithMultipleLosses(
    dim = 512,
    num_losses = 4
)

optim = Adam(network.parameters(), lr = 3e-4)

# backbone shared parameter

backbone_parameter = network.backbone[-1].weight

# grad norm based loss weighter

loss_weighter = GradNormLossWeighter(
    num_losses = 4,
    learning_rate = 1e-4,
    restoring_force_alpha = 0.,                  # 0. aims for perfectly balanced gradient norms, while anything greater than 0 also accounts for the relative training rate of each loss. in the paper, they go as high as 3.
    grad_norm_parameters = backbone_parameter
)

# mock input

mock_input = torch.randn(2, 512)
losses = network(mock_input)

# backwards with the loss weights
# will update on each backward based on gradnorm algorithm

loss_weighter.backward(losses)

# the usual

optim.step()
optim.zero_grad()
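For intuition, the core of what a GradNorm step does can be sketched in plain PyTorch. This is a hypothetical simplification for illustration, not the library's internals: the function name `gradnorm_update` and its signature are invented here, and it performs a single SGD step on the loss weights.

```python
import torch

def gradnorm_update(losses, loss_weights, shared_param, initial_losses, alpha = 0., lr = 1e-4):
    # per-loss gradient norms at the shared parameter, scaled by the current weights
    grad_norms = torch.stack([
        loss_weights[i] * torch.autograd.grad(l, shared_param, retain_graph = True)[0].norm()
        for i, l in enumerate(losses)
    ])

    with torch.no_grad():
        # inverse training rates, measured relative to the initial losses
        loss_ratios = torch.stack([l.detach() for l in losses]) / initial_losses
        rel_inverse_rates = loss_ratios / loss_ratios.mean()

        # gradient norm targets, treated as constants
        targets = grad_norms.detach().mean() * rel_inverse_rates ** alpha

    # L1 distance between gradient norms and targets drives the weight update
    gradnorm_loss = (grad_norms - targets).abs().sum()
    weight_grad = torch.autograd.grad(gradnorm_loss, loss_weights)[0]

    with torch.no_grad():
        new_weights = loss_weights - lr * weight_grad
        # renormalize so the weights always sum to the number of losses
        new_weights *= len(losses) / new_weights.sum()

    return new_weights
```

With `alpha = 0` the targets collapse to the mean gradient norm, i.e. the weighter pushes all losses toward equal gradient contribution at the shared parameter.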

You can also do it with respect to the gradients flowing through an intermediate activation, say a generated modality.

# same as above ...

loss_weighter = GradNormLossWeighter(
    num_losses = 4,
    learning_rate = 1e-4,
    restoring_force_alpha = 0.,
    grad_norm_parameters = None # this is now None and the activations need to be returned on network forward and passed in on backwards
)

# mock input

mock_input = torch.randn(2, 512)
losses, backbone_output_activations = network(mock_input, return_backbone_outputs = True)

# backwards with the loss weights and backbone activations from which gradients backpropagate through from all losses

loss_weighter.backward(losses, backbone_output_activations)

# optimizer

optim.step()
optim.zero_grad()
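The idea behind the activation-based variant can be sketched in plain PyTorch: since every loss backpropagates through the shared activation, per-loss gradient norms can be measured there instead of at a parameter. This is a hypothetical illustration with made-up tensors, not the library's code.

```python
import torch

torch.manual_seed(0)

x = torch.randn(2, 8)
w = torch.randn(8, 8, requires_grad = True)

# a shared intermediate activation that all losses flow through
backbone_out = torch.tanh(x @ w)

losses = [backbone_out.pow(2).mean(), backbone_out.sin().mean()]

# gradient norm of each loss, taken with respect to the activation
grad_norms = [
    torch.autograd.grad(l, backbone_out, retain_graph = True)[0].norm()
    for l in losses
]
```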

You can also switch it to basic static loss weighting, in case you want to run experiments against fixed weighting.

loss_weighter = GradNormLossWeighter(
    loss_weights = [1., 10., 5., 2.],
    ...,
    frozen = True
)

# or you can also freeze it on invoking the instance

loss_weighter.backward(..., freeze = True)
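When frozen, no GradNorm update happens, so the call should reduce to backpropagating a fixed weighted sum of the losses, as in this sketch (the helper name is hypothetical):

```python
import torch

def frozen_weighted_backward(losses, loss_weights):
    # with frozen weights, the losses are simply combined with the
    # fixed weights and backpropagated; no weight update occurs
    total = sum(w * l for w, l in zip(loss_weights, losses))
    total.backward()
    return total
```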

To control which losses are subjected to GradNorm, pass a list[bool] via the loss_mask kwarg.

loss_weighter = GradNormLossWeighter(
    loss_mask = [True, True, False, True], # 1st, 2nd, and 4th losses are grad normed
    ...,
)

# you can also override on .backward

loss_weighter.backward(..., loss_mask = [True, True, False, False])

For use with 🤗 Huggingface Accelerate, just pass the Accelerator instance in via the accelerator keyword on initialization.

ex.

from accelerate import Accelerator

accelerator = Accelerator()

network = accelerator.prepare(network)

loss_weighter = GradNormLossWeighter(
    ...,
    accelerator = accelerator
)

# backwards will now use accelerator

Todo

  • take care of gradient accumulation
  • handle freezing of some loss weights, but not others
  • handle sets of loss weights
  • allow for a prior weighting, accounted for when calculating gradient targets

Citations

@article{Chen2017GradNormGN,
    title   = {GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks},
    author  = {Zhao Chen and Vijay Badrinarayanan and Chen-Yu Lee and Andrew Rabinovich},
    journal = {ArXiv},
    year    = {2017},
    volume  = {abs/1711.02257},
    url     = {https://api.semanticscholar.org/CorpusID:4703661}
}
