Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions flowrl/agent/online/dpmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def solve_normalizer_exp(q: jnp.ndarray, temp: float):

def solve_normalizer_linear(q: jnp.ndarray, temp: float, negative: float=0.0):
num_particles = q.shape[-1]

# convert from the sum-to-1 convention to the sum-to-num_particles convention
temp = temp * num_particles
negative = negative / num_particles

q_sorted = jnp.sort(q, axis=-1)[..., ::-1]
q_cumsum = jnp.cumsum(q_sorted, axis=-1)

Expand All @@ -40,7 +45,7 @@ def solve_normalizer_linear(q: jnp.ndarray, temp: float, negative: float=0.0):

def solve_normalizer_square(q: jnp.ndarray, temp: float):
num_particles = q.shape[-1]
target_sum = temp ** 2
target_sum = temp ** 2 * num_particles

q_sorted = jnp.sort(q, axis=-1)[..., ::-1]
q_cumsum = jnp.cumsum(q_sorted, axis=-1)
Expand Down Expand Up @@ -159,6 +164,7 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar
if reweight == "exp":
nu = solve_normalizer_exp(q_batch, temp())
weights = jnp.exp((q_batch - nu) / temp())
weights = weights * num_particles
elif reweight == "linear":
nu = solve_normalizer_linear(q_batch, temp(), negative=negative_bound)
weights = jnp.maximum((q_batch - nu) / temp(), negative_bound)
Expand All @@ -170,7 +176,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar
ent_weights = jnp.maximum(weights, 1e-6)
ent_weights = ent_weights / ent_weights.sum(axis=-1, keepdims=True)
entropy = - jnp.sum(ent_weights * jnp.log(ent_weights+1e-6), axis=-1)
weights = weights * num_particles

_, at, t, eps = actor.add_noise(add_noise_rng, action_batch)

Expand Down