Skip to content

Commit b7975a6

Browse files
authored
Reduce the peak memory usage of skillmodels (#76) and refactor the likelihood functions modules (#77)
* Add a version of mem consumption test that fails if there are increases in the repo. * Decorate kalman_update, _calculate_sigma_points, with jax.checkpoint, can use `prevent_cse=False` due to being inside lax.scan, as per https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html\#jax.checkpoint. * Use checkpoint() on kalman_predict as well, requires changing `transition_func` to be its first argument and using `jnp.array` in tests. * Provision for testing on GPUs, timing information is not useful yet.
1 parent bbd4ce4 commit b7975a6

19 files changed

Lines changed: 3026 additions & 1071 deletions

.github/workflows/main.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ jobs:
2828
- uses: actions/checkout@v4
2929
- uses: prefix-dev/setup-pixi@v0.8.0
3030
with:
31-
pixi-version: v0.28.2
31+
pixi-version: v0.29.0
3232
cache: true
3333
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
34-
environments: test
34+
environments: test-cpu
3535
activate-environment: true
3636
- name: Run pytest
3737
shell: bash -l {0}
38-
run: pixi run -e test tests-with-cov
38+
run: pixi run -e test-cpu tests-with-cov
3939
- name: Upload coverage report
4040
if: runner.os == 'Linux' && matrix.python-version == '3.12'
4141
uses: codecov/codecov-action@v4

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ entry for a suggested citation. The suggested citation will be updated once the
7373
becomes part of a published paper.
7474

7575
```
76-
@Unpublished{Gabler2018,
76+
@Unpublished{Gabler2024,
7777
Title = {A Python Library to Estimate Nonlinear Dynamic Latent Factor Models},
7878
Author = {Janos Gabler},
79-
Year = {2018},
79+
Year = {2024},
8080
Url = {https://github.com/OpenSourceEconomics/skillmodels}
8181
}
8282
```

docs/source/getting_started/tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"import yaml\n",
2121
"\n",
2222
"from skillmodels.config import TEST_DIR\n",
23-
"from skillmodels.likelihood_function import get_maximization_inputs"
23+
"from skillmodels.maximization_inputs import get_maximization_inputs"
2424
]
2525
},
2626
{

docs/source/how_to_guides/how_to_visualize_pairwise_factor_distribution.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"import yaml\n",
2323
"\n",
2424
"from skillmodels.config import TEST_DIR\n",
25-
"from skillmodels.likelihood_function import get_maximization_inputs\n",
25+
"from skillmodels.maximization_inputs import get_maximization_inputs\n",
2626
"from skillmodels.simulate_data import simulate_dataset\n",
2727
"from skillmodels.visualize_factor_distributions import (\n",
2828
" bivariate_density_contours,\n",

pixi.lock

Lines changed: 2710 additions & 767 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,33 @@ scipy = "<=1.13"
106106
# Development Dependencies (pypi)
107107
# --------------------------------------------------------------------------------------
108108

109-
[tool.pixi.pypi-dependencies]
110-
jax = { version = ">=0.4.20", extras = ["cpu"] }
109+
[tool.pixi.target.unix.dependencies]
110+
jax = ">=0.4.20"
111111
jaxlib = ">=0.4.20"
112+
113+
# Development Dependencies (pypi)
114+
# --------------------------------------------------------------------------------------
115+
116+
[tool.pixi.pypi-dependencies]
112117
pdbp = "*"
113118
skillmodels = {path = ".", editable = true}
114119

120+
[tool.pixi.target.win-64.pypi-dependencies]
121+
jax = { version = ">=0.4.20", extras = ["cpu"] }
122+
jaxlib = ">=0.4.20"
123+
115124
# Features and Tasks
116125
# --------------------------------------------------------------------------------------
117126

127+
[tool.pixi.feature.cuda]
128+
platforms = ["linux-64"]
129+
system-requirements = {cuda = "12"}
130+
131+
[tool.pixi.feature.cuda.target.linux-64.dependencies]
132+
cuda-nvcc = ">=12"
133+
jax = ">=0.4.20"
134+
jaxlib = { version = ">=0.4.20", build = "cuda12*" }
135+
118136
[tool.pixi.feature.test.dependencies]
119137
pytest = "*"
120138
pytest-cov = "*"
@@ -128,6 +146,10 @@ pytest-memray = "*"
128146
tests = "pytest tests"
129147
tests-with-cov = "pytest tests --cov-report=xml --cov=./"
130148
mem = "pytest -x -s --pdb --memray --fail-on-increase tests/test_likelihood_regression.py::test_likelihood_contributions_large_nobs"
149+
mem-on-clean-repo = "git status --porcelain && git diff-index --quiet HEAD -- && git rev-parse HEAD && pytest -x -s --pdb --memray --fail-on-increase tests/test_likelihood_regression.py::test_likelihood_contributions_large_nobs"
150+
151+
[tool.pixi.feature.cuda.tasks]
152+
mem-cuda = "pytest -x -s --pdb --memray --fail-on-increase tests/test_likelihood_regression.py::test_likelihood_contributions_large_nobs"
131153

132154
[tool.pixi.feature.mypy.dependencies]
133155
mypy = "*"
@@ -141,8 +163,10 @@ mypy = "mypy src"
141163
# --------------------------------------------------------------------------------------
142164

143165
[tool.pixi.environments]
166+
cuda = {features = ["cuda"], solve-group = "cuda"}
144167
mypy = {features = ["mypy"], solve-group = "default"}
145-
test = {features = ["test"], solve-group = "default"}
168+
test-cpu = {features = ["test"], solve-group = "default"}
169+
test-gpu = {features = ["test", "cuda"], solve-group = "cuda"}
146170

147171

148172
# ======================================================================================

src/skillmodels/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
contextlib.suppress(Exception)
77

88
from skillmodels.filtered_states import get_filtered_states
9-
from skillmodels.likelihood_function import get_maximization_inputs
9+
from skillmodels.maximization_inputs import get_maximization_inputs
1010
from skillmodels.simulate_data import simulate_dataset
1111

1212
__all__ = ["get_maximization_inputs", "simulate_dataset", "get_filtered_states"]

src/skillmodels/filtered_states.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import jax.numpy as jnp
22
import numpy as np
33

4-
from skillmodels.likelihood_function import get_maximization_inputs
4+
from skillmodels.maximization_inputs import get_maximization_inputs
55
from skillmodels.params_index import get_params_index
66
from skillmodels.parse_params import create_parsing_info, parse_params
77
from skillmodels.process_debug_data import create_state_ranges

src/skillmodels/kalman_filters.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import functools
2+
13
import jax
24
import jax.numpy as jnp
35

@@ -9,6 +11,7 @@
911
# ======================================================================================
1012

1113

14+
@functools.partial(jax.checkpoint, prevent_cse=False)
1215
def kalman_update(
1316
states,
1417
upper_chols,
@@ -152,12 +155,13 @@ def calculate_sigma_scaling_factor_and_weights(n_states, kappa=2):
152155
return scaling_factor, weights
153156

154157

158+
@functools.partial(jax.checkpoint, static_argnums=0, prevent_cse=False)
155159
def kalman_predict(
160+
transition_func,
156161
states,
157162
upper_chols,
158163
sigma_scaling_factor,
159164
sigma_weights,
160-
transition_info,
161165
trans_coeffs,
162166
shock_sds,
163167
anchoring_scaling_factors,
@@ -167,6 +171,7 @@ def kalman_predict(
167171
"""Make a unscented Kalman predict.
168172
169173
Args:
174+
transition_func (Callable): The transition function.
170175
states (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states) with
171176
pre-update states estimates.
172177
upper_chols (jax.numpy.array): Array of shape (n_obs, n_mixtures, n_states,
@@ -177,9 +182,6 @@ def kalman_predict(
177182
the sigma_point algorithm chosen.
178183
sigma_weights (jax.numpy.array): 1d array of length n_sigma with non-negative
179184
sigma weights.
180-
transition_info (dict): Dict with the entries "func" (the actual transition
181-
function) and "columns" (a dictionary mapping factors that are needed
182-
as individual columns to positions in the factor array).
183185
trans_coeffs (tuple): Tuple of 1d jax.numpy.arrays with transition parameters.
184186
anchoring_scaling_factors (jax.numpy.array): Array of shape (2, n_fac) with
185187
the scaling factors for anchoring. The first row corresponds to the input
@@ -203,7 +205,7 @@ def kalman_predict(
203205
)
204206
transformed = transform_sigma_points(
205207
sigma_points,
206-
transition_info,
208+
transition_func,
207209
trans_coeffs,
208210
anchoring_scaling_factors,
209211
anchoring_constants,
@@ -225,6 +227,7 @@ def kalman_predict(
225227
return predicted_states, predicted_covs
226228

227229

230+
@functools.partial(jax.checkpoint, prevent_cse=False)
228231
def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factors):
229232
"""Calculate the array of sigma_points for the unscented transform.
230233
@@ -272,7 +275,7 @@ def _calculate_sigma_points(states, upper_chols, scaling_factor, observed_factor
272275

273276
def transform_sigma_points(
274277
sigma_points,
275-
transition_info,
278+
transition_func,
276279
trans_coeffs,
277280
anchoring_scaling_factors,
278281
anchoring_constants,
@@ -281,9 +284,7 @@ def transform_sigma_points(
281284
282285
Args:
283286
sigma_points (jax.numpy.array) of shape n_obs, n_mixtures, n_sigma, n_fac.
284-
transition_info (dict): Dict with the entries "func" (the actual transition
285-
function) and "columns" (a dictionary mapping factors that are needed
286-
as individual columns to positions in the factor array).
287+
transition_func (Callable): The transition function.
287288
trans_coeffs (tuple): Tuple of 1d jax.numpy.arrays with transition parameters.
288289
anchoring_scaling_factors (jax.numpy.array): Array of shape (2, n_states) with
289290
the scaling factors for anchoring. The first row corresponds to the input
@@ -303,9 +304,7 @@ def transform_sigma_points(
303304

304305
anchored = flat_sigma_points * anchoring_scaling_factors[0] + anchoring_constants[0]
305306

306-
transition_function = transition_info["func"]
307-
308-
transformed_anchored = transition_function(trans_coeffs, anchored)
307+
transformed_anchored = transition_func(trans_coeffs, anchored)
309308

310309
n_observed = transformed_anchored.shape[-1]
311310

0 commit comments

Comments
 (0)