diff --git a/flowrl/agent/online/dpmd.py b/flowrl/agent/online/dpmd.py index c0211f0..f76d782 100644 --- a/flowrl/agent/online/dpmd.py +++ b/flowrl/agent/online/dpmd.py @@ -25,6 +25,11 @@ def solve_normalizer_exp(q: jnp.ndarray, temp: float): def solve_normalizer_linear(q: jnp.ndarray, temp: float, negative: float=0.0): num_particles = q.shape[-1] + + # convert from sum to 1 to sum to num_particles + temp = temp * num_particles + negative = negative / num_particles + q_sorted = jnp.sort(q, axis=-1)[..., ::-1] q_cumsum = jnp.cumsum(q_sorted, axis=-1) @@ -40,7 +45,7 @@ def solve_normalizer_linear(q: jnp.ndarray, temp: float, negative: float=0.0): def solve_normalizer_square(q: jnp.ndarray, temp: float): num_particles = q.shape[-1] - target_sum = temp ** 2 + target_sum = temp ** 2 * num_particles q_sorted = jnp.sort(q, axis=-1)[..., ::-1] q_cumsum = jnp.cumsum(q_sorted, axis=-1) @@ -159,6 +164,7 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar if reweight == "exp": nu = solve_normalizer_exp(q_batch, temp()) weights = jnp.exp((q_batch - nu) / temp()) + weights = weights * num_particles elif reweight == "linear": nu = solve_normalizer_linear(q_batch, temp(), negative=negative_bound) weights = jnp.maximum((q_batch - nu) / temp(), negative_bound) @@ -170,7 +176,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar ent_weights = jnp.maximum(weights, 1e-6) ent_weights = ent_weights / ent_weights.sum(axis=-1, keepdims=True) entropy = - jnp.sum(ent_weights * jnp.log(ent_weights+1e-6), axis=-1) - weights = weights * num_particles _, at, t, eps = actor.add_noise(add_noise_rng, action_batch)