Skip to content

Commit 03bff12

Browse files
Fix: validation caching (#292)
* pass current params into cached predict functions
* add regression test for cached validation predictions
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0a2d28e commit 03bff12

3 files changed

Lines changed: 86 additions & 15 deletions

File tree

src/cellflow/solvers/_genot.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -293,26 +293,31 @@ def _get_predict_fn(self, kwargs_frozen: frozen_dict.FrozenDict) -> Callable:
293293

294294
kwargs = dict(kwargs_frozen)
295295

296-
def vf(t: float, x: jnp.ndarray, args: tuple[dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
297-
params = self.vf_state.params
298-
x_0, condition, encoder_noise = args
296+
def vf(
297+
t: float, x: jnp.ndarray, args: tuple[Any, jnp.ndarray, dict[str, jnp.ndarray], jnp.ndarray]
298+
) -> jnp.ndarray:
299+
params, x_0, condition, encoder_noise = args
299300
return self.vf_state.apply_fn({"params": params}, t, x, x_0, condition, encoder_noise, train=False)[0]
300301

301302
def solve_ode(
302-
latent: jnp.ndarray, x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray
303+
params: Any,
304+
latent: jnp.ndarray,
305+
x: jnp.ndarray,
306+
condition: dict[str, jnp.ndarray],
307+
encoder_noise: jnp.ndarray,
303308
) -> jnp.ndarray:
304309
term = diffrax.ODETerm(vf)
305310
sol = diffrax.diffeqsolve(
306311
term,
307312
t0=0.0,
308313
t1=1.0,
309314
y0=latent,
310-
args=(x, condition, encoder_noise),
315+
args=(params, x, condition, encoder_noise),
311316
**kwargs,
312317
)
313318
return sol.ys[0]
314319

315-
fn = jax.jit(jax.vmap(solve_ode, in_axes=[0, 0, None, None]))
320+
fn = jax.jit(jax.vmap(solve_ode, in_axes=[None, 0, 0, None, None]))
316321
self._predict_fn_cache[kwargs_frozen] = fn
317322
return fn
318323

@@ -337,7 +342,7 @@ def _predict_jit(
337342
latent = self.latent_noise_fn(rng_genot, (x.shape[0],))
338343

339344
predict_fn = self._get_predict_fn(kwargs_frozen)
340-
return predict_fn(latent, x, condition, encoder_noise)
345+
return predict_fn(self.vf_state.params, latent, x, condition, encoder_noise)
341346

342347
@property
343348
def is_trained(self) -> bool:

src/cellflow/solvers/_otfm.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,24 +198,25 @@ def _get_predict_fn(self, kwargs_frozen: frozen_dict.FrozenDict) -> Callable:
198198

199199
kwargs = dict(kwargs_frozen)
200200

201-
def vf(t: jnp.ndarray, x: jnp.ndarray, args: tuple[dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
202-
params = self.vf_state_inference.params
203-
condition, encoder_noise = args
201+
def vf(t: jnp.ndarray, x: jnp.ndarray, args: tuple[Any, dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
202+
params, condition, encoder_noise = args
204203
return self.vf_state_inference.apply_fn({"params": params}, t, x, condition, encoder_noise, train=False)[0]
205204

206-
def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray:
205+
def solve_ode(
206+
params: Any, x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray
207+
) -> jnp.ndarray:
207208
ode_term = diffrax.ODETerm(vf)
208209
result = diffrax.diffeqsolve(
209210
ode_term,
210211
t0=0.0,
211212
t1=1.0,
212213
y0=x,
213-
args=(condition, encoder_noise),
214+
args=(params, condition, encoder_noise),
214215
**kwargs,
215216
)
216217
return result.ys[0]
217218

218-
fn = jax.jit(jax.vmap(solve_ode, in_axes=[0, None, None]))
219+
fn = jax.jit(jax.vmap(solve_ode, in_axes=[None, 0, None, None]))
219220
self._predict_fn_cache[kwargs_frozen] = fn
220221
return fn
221222

@@ -238,7 +239,7 @@ def _predict_jit(
238239
encoder_noise = jnp.zeros(noise_dim) if use_mean else jax.random.normal(rng, noise_dim)
239240

240241
predict_fn = self._get_predict_fn(kwargs_frozen)
241-
return predict_fn(x, condition, encoder_noise)
242+
return predict_fn(self.vf_state_inference.params, x, condition, encoder_noise)
242243

243244
def predict(
244245
self,

tests/trainer/test_trainer.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import cellflow
1010
from cellflow._compat import ConstantNoiseFlow
11-
from cellflow.solvers import _otfm
11+
from cellflow.solvers import _genot, _otfm
1212
from cellflow.training import CellFlowTrainer, ComputationCallback, Metrics
1313
from cellflow.utils import match_linear
1414

@@ -227,3 +227,68 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader):
227227
diff_2 = end_2 - start_2
228228

229229
assert diff_2 - diff_1 > 0.4
230+
231+
@pytest.mark.parametrize("solver_class", ["otfm", "genot"])
232+
def test_validation_metrics_change_when_predict_cache_is_reused(
233+
self,
234+
dataloader,
235+
valid_loader,
236+
solver_class,
237+
):
238+
opt = optax.adam(1e-3)
239+
if solver_class == "otfm":
240+
vf = cellflow.networks.ConditionalVelocityField(
241+
output_dim=5,
242+
max_combination_length=2,
243+
condition_embedding_dim=12,
244+
hidden_dims=(32, 32),
245+
decoder_dims=(32, 32),
246+
)
247+
solver = _otfm.OTFlowMatching(
248+
vf=vf,
249+
match_fn=match_linear,
250+
probability_path=ConstantNoiseFlow(0.0),
251+
optimizer=opt,
252+
conditions=cond,
253+
rng=vf_rng,
254+
)
255+
else:
256+
vf = cellflow.networks.GENOTConditionalVelocityField(
257+
output_dim=5,
258+
max_combination_length=2,
259+
condition_embedding_dim=12,
260+
hidden_dims=(32, 32),
261+
decoder_dims=(32, 32),
262+
)
263+
solver = _genot.GENOT(
264+
vf=vf,
265+
data_match_fn=match_linear,
266+
probability_path=ConstantNoiseFlow(0.0),
267+
optimizer=opt,
268+
source_dim=5,
269+
target_dim=5,
270+
conditions=cond,
271+
rng=vf_rng,
272+
)
273+
274+
trainer = CellFlowTrainer(solver=solver, predict_kwargs={"max_steps": 3, "throw": False})
275+
metrics_callback = Metrics(metrics=["e_distance"])
276+
277+
valid_source_data, valid_true_data, valid_pred_data = trainer._validation_step(valid_loader)
278+
metric_before = metrics_callback.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)[
279+
"val_e_distance_mean"
280+
]
281+
282+
batch = dataloader.sample(None)
283+
metric_diffs = []
284+
for i in range(3):
285+
solver.step_fn(jax.random.PRNGKey(i), batch)
286+
valid_source_data, valid_true_data, valid_pred_data = trainer._validation_step(valid_loader)
287+
metric_after = metrics_callback.on_log_iteration(
288+
valid_source_data, valid_true_data, valid_pred_data, solver
289+
)["val_e_distance_mean"]
290+
metric_diffs.append(abs(metric_after - metric_before))
291+
metric_before = metric_after
292+
293+
assert len(solver._predict_fn_cache) == 1
294+
assert any(diff > 0 for diff in metric_diffs)

0 commit comments

Comments (0)