Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/cellflow/solvers/_genot.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,26 +293,31 @@ def _get_predict_fn(self, kwargs_frozen: frozen_dict.FrozenDict) -> Callable:

kwargs = dict(kwargs_frozen)

def vf(t: float, x: jnp.ndarray, args: tuple[dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
params = self.vf_state.params
x_0, condition, encoder_noise = args
def vf(
t: float, x: jnp.ndarray, args: tuple[Any, jnp.ndarray, dict[str, jnp.ndarray], jnp.ndarray]
) -> jnp.ndarray:
params, x_0, condition, encoder_noise = args
return self.vf_state.apply_fn({"params": params}, t, x, x_0, condition, encoder_noise, train=False)[0]

def solve_ode(
latent: jnp.ndarray, x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray
params: Any,
latent: jnp.ndarray,
x: jnp.ndarray,
condition: dict[str, jnp.ndarray],
encoder_noise: jnp.ndarray,
) -> jnp.ndarray:
term = diffrax.ODETerm(vf)
sol = diffrax.diffeqsolve(
term,
t0=0.0,
t1=1.0,
y0=latent,
args=(x, condition, encoder_noise),
args=(params, x, condition, encoder_noise),
**kwargs,
)
return sol.ys[0]

fn = jax.jit(jax.vmap(solve_ode, in_axes=[0, 0, None, None]))
fn = jax.jit(jax.vmap(solve_ode, in_axes=[None, 0, 0, None, None]))
self._predict_fn_cache[kwargs_frozen] = fn
return fn

Expand All @@ -337,7 +342,7 @@ def _predict_jit(
latent = self.latent_noise_fn(rng_genot, (x.shape[0],))

predict_fn = self._get_predict_fn(kwargs_frozen)
return predict_fn(latent, x, condition, encoder_noise)
return predict_fn(self.vf_state.params, latent, x, condition, encoder_noise)

@property
def is_trained(self) -> bool:
Expand Down
15 changes: 8 additions & 7 deletions src/cellflow/solvers/_otfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,25 @@ def _get_predict_fn(self, kwargs_frozen: frozen_dict.FrozenDict) -> Callable:

kwargs = dict(kwargs_frozen)

def vf(t: jnp.ndarray, x: jnp.ndarray, args: tuple[dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
params = self.vf_state_inference.params
condition, encoder_noise = args
def vf(t: jnp.ndarray, x: jnp.ndarray, args: tuple[Any, dict[str, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
params, condition, encoder_noise = args
return self.vf_state_inference.apply_fn({"params": params}, t, x, condition, encoder_noise, train=False)[0]

def solve_ode(x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray) -> jnp.ndarray:
def solve_ode(
params: Any, x: jnp.ndarray, condition: dict[str, jnp.ndarray], encoder_noise: jnp.ndarray
) -> jnp.ndarray:
ode_term = diffrax.ODETerm(vf)
result = diffrax.diffeqsolve(
ode_term,
t0=0.0,
t1=1.0,
y0=x,
args=(condition, encoder_noise),
args=(params, condition, encoder_noise),
**kwargs,
)
return result.ys[0]

fn = jax.jit(jax.vmap(solve_ode, in_axes=[0, None, None]))
fn = jax.jit(jax.vmap(solve_ode, in_axes=[None, 0, None, None]))
self._predict_fn_cache[kwargs_frozen] = fn
return fn

Expand All @@ -238,7 +239,7 @@ def _predict_jit(
encoder_noise = jnp.zeros(noise_dim) if use_mean else jax.random.normal(rng, noise_dim)

predict_fn = self._get_predict_fn(kwargs_frozen)
return predict_fn(x, condition, encoder_noise)
return predict_fn(self.vf_state_inference.params, x, condition, encoder_noise)

def predict(
self,
Expand Down
67 changes: 66 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import cellflow
from cellflow._compat import ConstantNoiseFlow
from cellflow.solvers import _otfm
from cellflow.solvers import _genot, _otfm
from cellflow.training import CellFlowTrainer, ComputationCallback, Metrics
from cellflow.utils import match_linear

Expand Down Expand Up @@ -227,3 +227,68 @@ def test_predict_kwargs_iter(self, dataloader, valid_loader):
diff_2 = end_2 - start_2

assert diff_2 - diff_1 > 0.4

@pytest.mark.parametrize("solver_class", ["otfm", "genot"])
def test_validation_metrics_change_when_predict_cache_is_reused(
    self,
    dataloader,
    valid_loader,
    solver_class,
):
    """Regression test: validation metrics must keep changing across training steps
    even though the jitted predict function is served from ``_predict_fn_cache``.

    Previously the cached predict fn closed over the params captured at trace time,
    so reusing the cache froze predictions; the fix threads ``params`` through as an
    explicit ODE arg. This test trains for a few steps while reusing one cache entry
    and checks the validation metric actually moves.
    """
    opt = optax.adam(1e-3)
    # Build a solver of the requested flavor; both share the same optimizer,
    # probability path and condition setup so the assertion logic is identical.
    # NOTE(review): `cond` and `vf_rng` appear to be module-level test fixtures
    # defined earlier in this file — not visible in this diff hunk.
    if solver_class == "otfm":
        vf = cellflow.networks.ConditionalVelocityField(
            output_dim=5,
            max_combination_length=2,
            condition_embedding_dim=12,
            hidden_dims=(32, 32),
            decoder_dims=(32, 32),
        )
        solver = _otfm.OTFlowMatching(
            vf=vf,
            match_fn=match_linear,
            probability_path=ConstantNoiseFlow(0.0),
            optimizer=opt,
            conditions=cond,
            rng=vf_rng,
        )
    else:
        vf = cellflow.networks.GENOTConditionalVelocityField(
            output_dim=5,
            max_combination_length=2,
            condition_embedding_dim=12,
            hidden_dims=(32, 32),
            decoder_dims=(32, 32),
        )
        solver = _genot.GENOT(
            vf=vf,
            data_match_fn=match_linear,
            probability_path=ConstantNoiseFlow(0.0),
            optimizer=opt,
            source_dim=5,
            target_dim=5,
            conditions=cond,
            rng=vf_rng,
        )

    # Fixed predict_kwargs -> one frozen-dict cache key -> the same cached
    # predict fn is reused for every validation pass below.
    trainer = CellFlowTrainer(solver=solver, predict_kwargs={"max_steps": 3, "throw": False})
    metrics_callback = Metrics(metrics=["e_distance"])

    # Baseline metric before any training step.
    valid_source_data, valid_true_data, valid_pred_data = trainer._validation_step(valid_loader)
    metric_before = metrics_callback.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)[
        "val_e_distance_mean"
    ]

    batch = dataloader.sample(None)
    metric_diffs = []
    for i in range(3):
        # One optimizer update, then re-validate with the (cached) predict fn.
        solver.step_fn(jax.random.PRNGKey(i), batch)
        valid_source_data, valid_true_data, valid_pred_data = trainer._validation_step(valid_loader)
        metric_after = metrics_callback.on_log_iteration(
            valid_source_data, valid_true_data, valid_pred_data, solver
        )["val_e_distance_mean"]
        metric_diffs.append(abs(metric_after - metric_before))
        metric_before = metric_after

    # Exactly one cache entry proves the jitted fn was reused, not retraced.
    assert len(solver._predict_fn_cache) == 1
    # At least one step must have moved the metric — i.e. updated params reached
    # the cached predict fn instead of being frozen at trace time.
    assert any(diff > 0 for diff in metric_diffs)
Loading