From 93e328b61b86581f874dd862f27494799d64a337 Mon Sep 17 00:00:00 2001 From: hationgma Date: Thu, 26 Mar 2026 16:15:42 +0000 Subject: [PATCH 1/2] fix solve square normalizer --- flowrl/agent/online/dpmd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/flowrl/agent/online/dpmd.py b/flowrl/agent/online/dpmd.py index c0211f0..4d4e74e 100644 --- a/flowrl/agent/online/dpmd.py +++ b/flowrl/agent/online/dpmd.py @@ -25,6 +25,11 @@ def solve_normalizer_exp(q: jnp.ndarray, temp: float): def solve_normalizer_linear(q: jnp.ndarray, temp: float, negative: float=0.0): num_particles = q.shape[-1] + + # convert from sum to 1 to sum to num_particles + temp = temp * num_particles + negative = negative / num_particles + q_sorted = jnp.sort(q, axis=-1)[..., ::-1] q_cumsum = jnp.cumsum(q_sorted, axis=-1) @@ -159,6 +164,7 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar if reweight == "exp": nu = solve_normalizer_exp(q_batch, temp()) weights = jnp.exp((q_batch - nu) / temp()) + weights = weights * num_particles elif reweight == "linear": nu = solve_normalizer_linear(q_batch, temp(), negative=negative_bound) weights = jnp.maximum((q_batch - nu) / temp(), negative_bound) @@ -170,7 +176,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar ent_weights = jnp.maximum(weights, 1e-6) ent_weights = ent_weights / ent_weights.sum(axis=-1, keepdims=True) entropy = - jnp.sum(ent_weights * jnp.log(ent_weights+1e-6), axis=-1) - weights = weights * num_particles _, at, t, eps = actor.add_noise(add_noise_rng, action_batch) From fcb3a8462d79f0ca2bc7c87bd4791d049cbddf44 Mon Sep 17 00:00:00 2001 From: hationgma Date: Thu, 26 Mar 2026 16:15:57 +0000 Subject: [PATCH 2/2] fix solve linear and square normalizer --- flowrl/agent/online/dpmd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flowrl/agent/online/dpmd.py b/flowrl/agent/online/dpmd.py index 4d4e74e..f76d782 100644 --- a/flowrl/agent/online/dpmd.py +++ b/flowrl/agent/online/dpmd.py @@ -45,7 +45,7 @@ def solve_normalizer_linear(q: jnp.ndarray, temp: float, negative: float=0.0): def solve_normalizer_square(q: jnp.ndarray, temp: float): num_particles = q.shape[-1] - target_sum = temp ** 2 + target_sum = temp ** 2 * num_particles q_sorted = jnp.sort(q, axis=-1)[..., ::-1] q_cumsum = jnp.cumsum(q_sorted, axis=-1)