Skip to content

Commit 75a81e9

Browse files
authored
fix: restore PR 607 KL loss removal (#639)
PR #619 reverted PR #607, which had modified the KL divergence; this commit restores that change.
1 parent 7ffa8b4 commit 75a81e9

3 files changed

Lines changed: 2 additions & 11 deletions

File tree

src/art/loss.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ class Loss(BaseModel):
1515
model_config = ConfigDict(arbitrary_types_allowed=True)
1616
reduction: Literal["mean", "sum"]
1717
policy_loss: torch.Tensor
18-
kl: torch.Tensor
1918
entropy: torch.Tensor | None
2019
policy_loss_sum: torch.Tensor
2120
probs_corr: torch.Tensor
@@ -126,17 +125,9 @@ def loss_fn(
126125
logprob_diff = old_logprobs - original_logprobs
127126
prob_ratio = torch.exp(logprob_diff)
128127
policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
129-
if ref_logprobs is not None:
130-
kl_div = (
131-
torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0
132-
)
133-
else:
134-
kl_div = torch.zeros_like(policy_loss)
135128
policy_loss = policy_loss * weights * assistant_mask
136-
kl_div = kl_div * weights * assistant_mask
137129
denominator = assistant_mask.sum() + 1e-6 if reduction == "mean" else 1.0
138130
reduced_policy_loss = policy_loss.sum() / denominator
139-
kl = kl_div.sum() / denominator
140131
# Compute reduced entropy for the current step.
141132
if entropies is not None:
142133
shifted_entropies = shift_tensor(entropies, 0.0)
@@ -146,7 +137,6 @@ def loss_fn(
146137
return Loss(
147138
reduction=reduction,
148139
policy_loss=reduced_policy_loss,
149-
kl=kl,
150140
entropy=entropy,
151141
policy_loss_sum=policy_loss.sum(),
152142
probs_corr=probs_corr,

src/art/test/test_kl_advantage.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def test_kl_advantage_no_effect_when_disabled():
4646

4747
assert loss_no_kl.kl_policy_ref is None
4848
assert loss_without_ref.kl_policy_ref is None
49+
assert loss_no_kl.reduction == "mean"
50+
assert not hasattr(loss_no_kl, "kl")
4951

5052

5153
def test_kl_advantage_enabled():

tests/integration/megatron_oracle_worker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,6 @@ def _scaled_loss_fn(*args: Any, **kwargs: Any):
682682
return loss.model_copy(
683683
update={
684684
"policy_loss": loss.policy_loss * effective_loss_scale,
685-
"kl": loss.kl * effective_loss_scale,
686685
"policy_loss_sum": loss.policy_loss_sum * effective_loss_scale,
687686
}
688687
)

0 commit comments

Comments (0)