@@ -15,7 +15,6 @@ class Loss(BaseModel):
1515 model_config = ConfigDict (arbitrary_types_allowed = True )
1616 reduction : Literal ["mean" , "sum" ]
1717 policy_loss : torch .Tensor
18- kl : torch .Tensor
1918 entropy : torch .Tensor | None
2019 policy_loss_sum : torch .Tensor
2120 probs_corr : torch .Tensor
@@ -126,17 +125,9 @@ def loss_fn(
126125 logprob_diff = old_logprobs - original_logprobs
127126 prob_ratio = torch .exp (logprob_diff )
128127 policy_loss *= torch .clamp (prob_ratio , max = upper_bound ).detach ()
129- if ref_logprobs is not None :
130- kl_div = (
131- torch .exp (ref_logprobs - new_logprobs ) - (ref_logprobs - new_logprobs ) - 1.0
132- )
133- else :
134- kl_div = torch .zeros_like (policy_loss )
135128 policy_loss = policy_loss * weights * assistant_mask
136- kl_div = kl_div * weights * assistant_mask
137129 denominator = assistant_mask .sum () + 1e-6 if reduction == "mean" else 1.0
138130 reduced_policy_loss = policy_loss .sum () / denominator
139- kl = kl_div .sum () / denominator
140131 # Compute reduced entropy for the current step.
141132 if entropies is not None :
142133 shifted_entropies = shift_tensor (entropies , 0.0 )
@@ -146,7 +137,6 @@ def loss_fn(
146137 return Loss (
147138 reduction = reduction ,
148139 policy_loss = reduced_policy_loss ,
149- kl = kl ,
150140 entropy = entropy ,
151141 policy_loss_sum = policy_loss .sum (),
152142 probs_corr = probs_corr ,
0 commit comments