Skip to content

Commit f18c26d

Browse files
committed
added metrics
1 parent a89b932 commit f18c26d

2 files changed

Lines changed: 33 additions & 14 deletions

File tree

src/cli/02_fine_tune_generator.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from src.data import GeneratorDataModule
1717
from src.models import build_generator
1818
from src.utils.wandb_setup import setup_wandb
19-
19+
from src.utils.metrics import perplexity_metrics
2020

2121
logging.basicConfig(level=logging.INFO,
2222
format="%(asctime)s - %(levelname)s - %(message)s")
@@ -25,18 +25,6 @@
2525
app = typer.Typer()
2626

2727

28-
def perplexity_metrics(eval_pred):
    """Compute perplexity for causal-LM evaluation.

    Hugging Face's ``Trainer`` passes ``compute_metrics`` an
    ``EvalPrediction`` that unpacks to ``(logits, labels)`` -- it does NOT
    carry a ``loss`` attribute, so the cross-entropy loss is computed here
    from the logits and then exponentiated (PPL = exp(CE)).  Positions
    labelled ``-100`` (the HF ignore/padding convention) are excluded.

    Parameters
    ----------
    eval_pred : tuple or EvalPrediction
        ``(logits, labels)``; logits of shape (batch, seq, vocab),
        integer labels of shape (batch, seq).

    Returns
    -------
    dict
        ``{"perplexity": float}``.
    """
    # Local imports keep this helper self-contained for the CLI script.
    import math
    import numpy as np

    logits, labels = eval_pred
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)

    # Shift so the token at position t predicts the token at t+1
    # (standard causal-LM alignment).
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]

    # Numerically stable log-softmax over the vocabulary axis.
    peak = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - peak - np.log(
        np.exp(shift_logits - peak).sum(axis=-1, keepdims=True)
    )

    # Mask out ignored positions (-100) before gathering log-likelihoods;
    # the masked labels are remapped to 0 only so the gather is in-bounds.
    mask = shift_labels != -100
    safe_labels = np.where(mask, shift_labels, 0)
    token_ll = np.take_along_axis(
        log_probs, safe_labels[..., None], axis=-1
    ).squeeze(-1)

    # Mean negative log-likelihood over non-ignored tokens only.
    loss = -(token_ll * mask).sum() / mask.sum()
    return {"perplexity": math.exp(loss)}
4028

4129

4230
@app.command()

src/utils/metrics.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import numpy as np
1+
import numpy as np
2+
import math
23
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
34

45

@@ -25,3 +26,33 @@ def compute_metrics(p):
2526
'true_negatives': int(true_negatives),
2627
'false_negatives': int(false_negatives)
2728
}
29+
30+
31+
32+
def perplexity_metrics(eval_pred):
    """Compute mean cross-entropy loss and perplexity for a causal LM.

    Hugging Face's ``Trainer`` passes an ``EvalPrediction`` that unpacks to
    ``(logits, labels)`` -- there is no ``loss`` attribute -- so the CE loss
    is computed here and PPL = exp(loss).  Positions labelled ``-100`` (the
    HF padding/ignore convention) are excluded from the average.

    Parameters
    ----------
    eval_pred : tuple or EvalPrediction
        ``(logits, labels)``; logits of shape (batch, seq, vocab),
        integer labels of shape (batch, seq).

    Returns
    -------
    dict
        ``{"loss": float, "perplexity": float}``.
    """
    # Imported locally: this module's top-level imports (numpy, math,
    # sklearn) do not include torch, so module-level use of ``torch``/``F``
    # would raise NameError and break every metric in the file.
    import torch
    import torch.nn.functional as F

    # Unpack EvalPrediction / tuple -> tensors.
    logits, labels = eval_pred
    logits = torch.as_tensor(logits, dtype=torch.float32)
    labels = torch.as_tensor(labels, dtype=torch.long)

    # Shift so that the token at position t predicts the token at t+1
    # (standard causal-LM alignment).
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Mean cross-entropy over non-ignored tokens only.
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # HF Trainer marks padding with -100
        reduction="mean",
    )

    loss_val = loss.item()
    return {
        "loss": loss_val,
        "perplexity": math.exp(loss_val),
    }

0 commit comments

Comments
 (0)