Skip to content

Commit 9b26bf6

Browse files
authored
Merge pull request #270 from theislab/perf/ema
perf: jit ema update
2 parents 60b58cd + a5a4d1a commit 9b26bf6

2 files changed

Lines changed: 2 additions & 1 deletion

File tree

src/cellflow/solvers/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import jax
22

33

4+
@jax.jit
45
def ema_update(current_model_params: dict, new_model_params: dict, ema: float) -> dict:
56
"""
67
Update parameters using exponential moving average.

tests/solver/test_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_predict_batch(self, dataloader, solver_class):
8787
atol=1e-1,
8888
rtol=1e-2,
8989
)
90-
assert diff_nonbatched - diff_batched > 2
90+
assert diff_nonbatched - diff_batched > 0.5
9191

9292
@pytest.mark.parametrize("ema", [0.5, 1.0])
9393
def test_EMA(self, dataloader, ema):

0 commit comments

Comments
 (0)