This Python repository hosts the code for the blog post "How to compute Hessian-vector products?", published in the blog post track of ICLR 2024, where it received a highlight.
Start by cloning the repo:
$ git clone https://github.com/MatDag/bench_hvp.git
$ cd bench_hvp
We advise working in two separate environments: one for the Jax experiments and another for the PyTorch experiments.
To set up the Jax environment, run:
$ conda create -n bench_hvp_jax python=3.11
$ conda activate bench_hvp_jax
$ pip install -e .
$ pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
In the blog post, the experiments were run with Jax version 0.4.21.
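The install command above does not pin a Jax version, so the installed release may differ from 0.4.21. A small check along these lines can confirm what ended up in the environment. It is a sketch, not part of the repository: the file name check_jax_env.py and the function names are hypothetical. It reports the installed version, whether it matches the one used in the blog post, and which backend Jax selected.

```python
"""Sanity check for the bench_hvp_jax environment (hypothetical helper, not in the repo)."""
from importlib import metadata

BLOG_JAX_VERSION = "0.4.21"  # version used for the blog post experiments


def public_version(version: str) -> str:
    """Drop a PEP 440 local label, e.g. '0.4.21+abc' -> '0.4.21'."""
    return version.split("+", 1)[0]


def jax_env_report() -> dict:
    """Summarize the Jax install of the currently active environment."""
    report = {"jax_version": None, "matches_blog": False, "backend": None}
    try:
        version = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return report  # Jax is not installed in this environment
    report["jax_version"] = version
    report["matches_blog"] = public_version(version) == BLOG_JAX_VERSION
    try:
        import jax  # imported lazily so the report still works without Jax
    except ImportError:  # e.g. jax present but jaxlib missing
        return report
    # Typically "gpu" when the CUDA wheels are used, "cpu" if Jax fell back to CPU.
    report["backend"] = jax.default_backend()
    return report


if __name__ == "__main__":
    print(jax_env_report())
```

Run it inside the activated environment with:
$ python check_jax_env.py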
To set up the PyTorch environment, run:
$ conda create -n bench_hvp_torch python=3.11
$ conda activate bench_hvp_torch
$ pip install -e .
$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
In the blog post, the experiments were run with PyTorch version 2.1.1.
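Wheels from the cu118 index carry a local version label (for instance 2.1.1+cu118), so a plain string comparison with 2.1.1 would fail. The sketch below, a hypothetical check_torch_env.py that is not shipped with the repository, strips that label before comparing and also reports whether PyTorch can see a CUDA device.

```python
"""Sanity check for the bench_hvp_torch environment (hypothetical helper, not in the repo)."""
from importlib import metadata

BLOG_TORCH_VERSION = "2.1.1"  # version used for the blog post experiments


def public_version(version: str) -> str:
    """Drop the PEP 440 local label, e.g. '2.1.1+cu118' -> '2.1.1'."""
    return version.split("+", 1)[0]


def torch_env_report() -> dict:
    """Summarize the PyTorch install of the currently active environment."""
    report = {"torch_version": None, "matches_blog": False,
              "cuda_available": False, "cuda_build": None}
    try:
        version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        return report  # PyTorch is not installed in this environment
    report["torch_version"] = version
    report["matches_blog"] = public_version(version) == BLOG_TORCH_VERSION
    try:
        import torch  # imported lazily so the report still works without PyTorch
    except ImportError:
        return report
    report["cuda_available"] = torch.cuda.is_available()
    # CUDA version PyTorch was built against ("11.8" for cu118 wheels), None for CPU-only builds.
    report["cuda_build"] = torch.version.cuda
    return report


if __name__ == "__main__":
    print(torch_env_report())
```

Run it inside the activated environment with:
$ python check_torch_env.py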
To reproduce the computational time comparison in Jax or in PyTorch, run the following commands, replacing XXX with jax or torch accordingly:
$ conda activate bench_hvp_XXX
$ cd bench_hvp
$ python bench_hvp_time.py -f XXX -n 100
The raw results are stored in the folder outputs. Then, the figure can be plotted with:
$ cd ../figures
$ python plot_bench_time.py -f ../outputs/bench_hvp_time_XXX.parquet
The figures are stored in the folder figures.
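The benchmark and plotting steps above can also be chained from Python. This is a sketch under stated assumptions, not a script from the repository: it is meant to be launched from the repository root, it replaces conda activate with conda run -n bench_hvp_XXX (which runs a command inside a named environment without activating it), and it passes -n 100 exactly as in the command above. The function names are hypothetical.

```python
import subprocess
from pathlib import Path

FRAMEWORKS = ("jax", "torch")


def time_benchmark_steps(framework: str, repo_root: Path) -> list:
    """(working directory, command) pairs mirroring the manual timing steps."""
    if framework not in FRAMEWORKS:
        raise ValueError(f"framework must be one of {FRAMEWORKS}, got {framework!r}")
    # conda run executes the command inside the named environment, no activation needed.
    in_env = ["conda", "run", "-n", f"bench_hvp_{framework}", "python"]
    bench = in_env + ["bench_hvp_time.py", "-f", framework, "-n", "100"]
    plot = in_env + ["plot_bench_time.py", "-f",
                     f"../outputs/bench_hvp_time_{framework}.parquet"]
    # The benchmark runs from bench_hvp/, the plot from figures/, as in the README.
    return [(repo_root / "bench_hvp", bench), (repo_root / "figures", plot)]


def run_time_benchmark(framework: str, repo_root: Path = Path(".")) -> None:
    """Run the benchmark, then the plot; stop at the first failing step."""
    for cwd, cmd in time_benchmark_steps(framework, repo_root):
        subprocess.run(cmd, cwd=cwd, check=True)
```

For example, calling run_time_benchmark("jax") from the repository root would produce the Jax timing figure, provided the bench_hvp_jax environment was created as described above. Keeping command construction separate from execution makes the commands easy to inspect before anything runs.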

To reproduce the memory footprint experiment, run the following bash script:
$ cd bench_hvp
$ bash bench_hvp_memory.sh
The raw results are stored in the folder outputs. Then, the figure can be plotted with:
$ cd ../figures
$ python plot_bench_memory.py -f ../outputs/bench_hvp_memory_XXX.parquet
Here, XXX can be replaced by jax, jax_without_jit or torch.
The figures are stored in the folder figures.
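Since the memory plot is produced once per result file, a loop over the three variants saves retyping the command. A minimal sketch, assuming it is launched from the repository root and that the currently active environment has the plotting dependencies; it only plots the variants whose parquet file already exists in outputs. The function names are hypothetical and not part of the repository.

```python
import subprocess
import sys
from pathlib import Path

# Result-file suffixes listed in the README for the memory experiment.
MEMORY_VARIANTS = ("jax", "jax_without_jit", "torch")


def memory_plot_command(variant: str) -> list:
    """Command that plots one memory result, to be run from the figures folder."""
    if variant not in MEMORY_VARIANTS:
        raise ValueError(f"unknown variant {variant!r}; expected one of {MEMORY_VARIANTS}")
    # Assumption: the interpreter of the active environment can run the plot script.
    return [sys.executable, "plot_bench_memory.py", "-f",
            f"../outputs/bench_hvp_memory_{variant}.parquet"]


def available_variants(repo_root: Path) -> list:
    """Variants whose parquet file already exists in repo_root/outputs."""
    outputs = repo_root / "outputs"
    return [v for v in MEMORY_VARIANTS
            if (outputs / f"bench_hvp_memory_{v}.parquet").is_file()]


def plot_all_memory_results(repo_root: Path = Path(".")) -> None:
    """Plot every memory result found, one subprocess per file."""
    for variant in available_variants(repo_root):
        subprocess.run(memory_plot_command(variant), cwd=repo_root / "figures", check=True)
```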


