Commit d1ad1c4

minor changes
1 parent 0860ea9 commit d1ad1c4

6 files changed

Lines changed: 35 additions & 17 deletions

.DS_Store

2 KB
Binary file not shown.

.gitignore

Lines changed: 6 additions & 1 deletion
@@ -76,4 +76,9 @@ outputs/
 *.DS_Store
 .DS_Store
 .DS_Store
-.DS_Store?
+.DS_Store?
+
+.venv/
+.env/
+
+.DS_Store
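
Several of these patterns overlap: `*.DS_Store` already covers a bare `.DS_Store`, and that bare entry now appears three times. A rough way to spot such redundancy is sketched below. It assumes shell-style globbing is a fair approximation of gitignore matching for simple, slash-free patterns (real gitignore rules are richer), and `redundant_patterns` is a hypothetical helper, not something in the repository.

```python
from fnmatch import fnmatchcase


def redundant_patterns(patterns):
    """Return patterns already covered by an earlier pattern.

    Approximation only: each earlier pattern is treated as a shell glob,
    and a later pattern is tested against it only when that later pattern
    is a plain literal (no *, ?, or [ characters).
    """
    seen = []
    redundant = []
    for pat in patterns:
        is_literal = not any(ch in pat for ch in "*?[")
        covered = is_literal and any(fnmatchcase(pat, s) for s in seen)
        if pat in seen or covered:
            redundant.append(pat)
        else:
            seen.append(pat)
    return redundant


# Patterns as they appear in .gitignore after this commit (lines 76-84).
patterns = [
    "*.DS_Store", ".DS_Store", ".DS_Store", ".DS_Store?",
    ".venv/", ".env/", ".DS_Store",
]
print(redundant_patterns(patterns))
```

Under that approximation, all three bare `.DS_Store` entries are reported as covered by `*.DS_Store`.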

README.md

Lines changed: 3 additions & 1 deletion
@@ -54,8 +54,10 @@ To run on `weftdrive`:
 1. `export WANDB_API_KEY="..."`
 2. `python -m wandb login`
 7. Create the logs directory and file: `mkdir -p ~/scratch/paramkapur/logs` and `touch ~/scratch/paramkapur/logs/$(date +%Y%m%d_%H%M).log`
-8. Run the training script: `nohup /srv/gpurun.pl python src/cli/01_train_teacher.py configs/teacher/stt2_hf.yaml > ~/scratch/paramkapur/logs/$(date +%Y%m%d_%H%M).log 2>&1 &`
+8. Run the training script: `nohup /srv/gpurun.pl python src/cli/01_train_teacher.py configs/teacher/sst2_hf.yaml > ~/scratch/paramkapur/logs/$(date +%Y%m%d_%H%M).log 2>&1 &`


 /scratch/paramkapur/data/clean/clean

+
+nohup /srv/gpurun.pl python src/cli/01_train_teacher.py configs/teacher/sst2_hf.yaml > ~/scratch/paramkapur/logs/$(date +%Y%m%d_%H%M).log 2>&1 &
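
Steps 7 and 8 each evaluate `$(date +%Y%m%d_%H%M)` on their own, so if the minute rolls over between them, the `touch`ed file and the redirect target end up with different names. The shell redirect creates its target file anyway. One way to avoid the mismatch is to compute the timestamped path once and reuse it. The sketch below shows that idea in Python. `make_log_path` is a hypothetical helper, not part of the repository.

```python
from datetime import datetime
from pathlib import Path


def make_log_path(log_dir, now=None):
    """Build <log_dir>/<YYYYMMDD_HHMM>.log, mirroring `date +%Y%m%d_%H%M`."""
    now = now or datetime.now()
    return Path(log_dir).expanduser() / f"{now:%Y%m%d_%H%M}.log"


# A fixed timestamp is used here so the result is reproducible.
log_path = make_log_path("~/scratch/paramkapur/logs", datetime(2025, 1, 2, 3, 4))
print(log_path.name)  # 20250102_0304.log

# In real use, compute the path once and pass the same value to every step:
# log_path = make_log_path("~/scratch/paramkapur/logs")
# log_path.parent.mkdir(parents=True, exist_ok=True)
```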

configs/teacher/sst2_hf.yaml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ model:
   use_fast_tokenizer: true

 data:
-  dataset_path: "~/scratch/paramkapur/data/clean/clean" # Use HF dataset identifier
+  dataset_path: "./data/clean/" # Use HF dataset identifier
   max_len: 32
   train_split: "train"
   validation_split: "val"
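
The new `dataset_path` is relative, so where it points depends on the working directory the training script is launched from. The old value began with `~`, which Python does not expand unless the code calls something like `os.path.expanduser`. The diff does not show how the training script reads this field, so the following is only a sketch of one way a loader could normalize both forms. `resolve_dataset_path` and the `/home/user/project` base directory are assumptions for illustration.

```python
import os
from pathlib import Path


def resolve_dataset_path(raw, base_dir):
    """Expand a leading '~' and resolve relative paths against base_dir."""
    path = Path(os.path.expanduser(raw))
    if not path.is_absolute():
        # Relative paths are interpreted relative to base_dir
        # (for example, the repository root), not the current directory.
        path = Path(base_dir) / path
    return Path(os.path.normpath(path))


# New config value, resolved against a hypothetical project root.
print(resolve_dataset_path("./data/clean/", "/home/user/project"))

# Old config value: '~' is expanded to the current user's home directory.
print(resolve_dataset_path("~/scratch/paramkapur/data/clean/clean", "/home/user/project"))
```

Resolving against a fixed base directory keeps the config working when the script is launched from somewhere other than the repository root.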

requirements.txt

Lines changed: 13 additions & 13 deletions
@@ -1,22 +1,22 @@
-torch>=1.9.0
+torch
 transformers
-datasets>=1.11.0
-numpy>=1.19.5
-scikit-learn>=0.24.2
-pandas>=1.3.0
-matplotlib>=3.4.2
-seaborn>=0.11.1
-tqdm>=4.61.2
-datasets>=3.5.0
+datasets
+numpy
+scikit-learn
+pandas
+matplotlib
+seaborn
+tqdm
+datasets

-accelerate==1.6.0
+accelerate

 # Configuration and Metrics
-PyYAML>=6.0.2
-scikit-learn>=1.6.1
+PyYAML
+scikit-learn

 # Logging
-wandb>=0.19.9
+wandb
 Cmake
 sentencepiece
 protobuf
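
With the version pins removed, `datasets` and `scikit-learn` each appear twice with no constraint to tell the entries apart. A small check like the one below could flag such duplicates. It is a sketch only: `duplicate_requirements` is a hypothetical helper, and its name parsing handles simple lines like the ones in this file, not the full requirements-file syntax.

```python
import re
from collections import Counter


def duplicate_requirements(lines):
    """Return package names (lowercased) that are listed more than once."""
    names = []
    for line in lines:
        line = line.split("#", 1)[0].strip()  # drop comments and whitespace
        if not line:
            continue
        # Keep the text before the first version specifier, extra, or marker.
        name = re.split(r"[<>=!~\[; ]", line, maxsplit=1)[0]
        names.append(name.lower().replace("_", "-"))
    counts = Counter(names)
    return sorted(name for name, count in counts.items() if count > 1)


# requirements.txt as it stands after this commit.
new_requirements = """torch
transformers
datasets
numpy
scikit-learn
pandas
matplotlib
seaborn
tqdm
datasets

accelerate

# Configuration and Metrics
PyYAML
scikit-learn

# Logging
wandb
Cmake
sentencepiece
protobuf
""".splitlines()

print(duplicate_requirements(new_requirements))  # ['datasets', 'scikit-learn']
```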

src/utils/metrics.py

Lines changed: 12 additions & 1 deletion
@@ -8,9 +8,20 @@ def compute_metrics(p):
     labels = p.label_ids
     precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary') # Assuming binary
     acc = accuracy_score(labels, preds)
+
+    # Calculate confusion matrix metrics
+    true_positives = np.sum((preds == 1) & (labels == 1))
+    false_positives = np.sum((preds == 1) & (labels == 0))
+    true_negatives = np.sum((preds == 0) & (labels == 0))
+    false_negatives = np.sum((preds == 0) & (labels == 1))
+
     return {
         'accuracy': acc,
         'f1': f1,
         'precision': precision,
-        'recall': recall
+        'recall': recall,
+        'true_positives': int(true_positives),
+        'false_positives': int(false_positives),
+        'true_negatives': int(true_negatives),
+        'false_negatives': int(false_negatives)
     }
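
The added lines rely on `np`, so `numpy` must be imported at the top of `src/utils/metrics.py`; that import is outside the hunk shown here. As a sanity check on the new counts, the sketch below recomputes them on a toy example and derives accuracy, precision, recall, and F1 directly from them. It uses NumPy only, in place of the scikit-learn calls in the actual function. `confusion_counts` and `metrics_from_counts` are hypothetical names, and returning 0.0 for a zero denominator is an assumption of this sketch.

```python
import numpy as np


def confusion_counts(preds, labels):
    """Count TP/FP/TN/FN for binary labels, as in the added lines."""
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    return {
        "true_positives": int(np.sum((preds == 1) & (labels == 1))),
        "false_positives": int(np.sum((preds == 1) & (labels == 0))),
        "true_negatives": int(np.sum((preds == 0) & (labels == 0))),
        "false_negatives": int(np.sum((preds == 0) & (labels == 1))),
    }


def metrics_from_counts(counts):
    """Derive accuracy, precision, recall, and F1 from the four counts."""
    tp = counts["true_positives"]
    fp = counts["false_positives"]
    tn = counts["true_negatives"]
    fn = counts["false_negatives"]
    total = tp + fp + tn + fn
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    accuracy = (tp + tn) / total if total else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


preds = [1, 0, 1, 1, 0, 0]
labels = [1, 0, 0, 1, 1, 0]
counts = confusion_counts(preds, labels)
print(counts)
print(metrics_from_counts(counts))
```

On this toy input the counts are 2 true positives, 1 false positive, 2 true negatives, and 1 false negative, so all four derived metrics equal 2/3.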
