Skip to content

Commit 3512b6f

Browse files
committed
updates
1 parent fd75045 commit 3512b6f

4 files changed

Lines changed: 149 additions & 16 deletions

File tree

.DS_Store

0 Bytes
Binary file not shown.

.gitignore

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,9 @@ logs/
7171
*.log
7272

7373
wandb/
74-
outputs/
74+
outputs/
75+
76+
*.DS_Store
77+
.DS_Store
78+
.DS_Store
79+
.DS_Store?

notebooks/test.ipynb

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 17,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"data": {
10+
"application/vnd.jupyter.widget-view+json": {
11+
"model_id": "ae8b72691bc94e5b8406681bb5b11d59",
12+
"version_major": 2,
13+
"version_minor": 0
14+
},
15+
"text/plain": [
16+
"Map (num_proc=4): 0%| | 0/63981 [00:00<?, ? examples/s]"
17+
]
18+
},
19+
"metadata": {},
20+
"output_type": "display_data"
21+
},
22+
{
23+
"data": {
24+
"application/vnd.jupyter.widget-view+json": {
25+
"model_id": "f04a9293b4a14796aa1050f40b8f0135",
26+
"version_major": 2,
27+
"version_minor": 0
28+
},
29+
"text/plain": [
30+
"Map (num_proc=4): 0%| | 0/872 [00:00<?, ? examples/s]"
31+
]
32+
},
33+
"metadata": {},
34+
"output_type": "display_data"
35+
},
36+
{
37+
"data": {
38+
"application/vnd.jupyter.widget-view+json": {
39+
"model_id": "0e7c45fd8b974bc1ac58f95e05d69cb6",
40+
"version_major": 2,
41+
"version_minor": 0
42+
},
43+
"text/plain": [
44+
"Map (num_proc=4): 0%| | 0/3368 [00:00<?, ? examples/s]"
45+
]
46+
},
47+
"metadata": {},
48+
"output_type": "display_data"
49+
},
50+
{
51+
"data": {
52+
"application/vnd.jupyter.widget-view+json": {
53+
"model_id": "9968f3bb38ad4d26b78e6d1cc7ae295c",
54+
"version_major": 2,
55+
"version_minor": 0
56+
},
57+
"text/plain": [
58+
"Map (num_proc=4): 0%| | 0/1821 [00:00<?, ? examples/s]"
59+
]
60+
},
61+
"metadata": {},
62+
"output_type": "display_data"
63+
},
64+
{
65+
"data": {
66+
"text/plain": [
67+
"Dataset({\n",
68+
" features: ['idx', 'labels', 'text', 'input_ids', 'attention_mask'],\n",
69+
" num_rows: 63981\n",
70+
"})"
71+
]
72+
},
73+
"execution_count": 17,
74+
"metadata": {},
75+
"output_type": "execute_result"
76+
}
77+
],
78+
"source": [
79+
"from datasets import load_from_disk, Dataset\n",
80+
"from transformers import GPT2Tokenizer\n",
81+
"\n",
82+
"dataset = load_from_disk(\"data/clean\")\n",
83+
"\n",
84+
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
85+
"tokenizer.pad_token = tokenizer.eos_token\n",
86+
"\n",
87+
"def tokenize_function(examples):\n",
88+
" return tokenizer(examples[\"text\"], padding=True, truncation=True, max_length=32)\n",
89+
"\n",
90+
"tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4)\n",
91+
"\n",
92+
"tokenized_datasets[\"train\"]\n",
93+
"\n"
94+
]
95+
},
96+
{
97+
"cell_type": "code",
98+
"execution_count": 9,
99+
"metadata": {},
100+
"outputs": [
101+
{
102+
"ename": "AttributeError",
103+
"evalue": "'list' object has no attribute 'schema'",
104+
"output_type": "error",
105+
"traceback": [
106+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
107+
"\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
108+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m Dataset(\u001b[43mTable\u001b[49m\u001b[43m(\u001b[49m\u001b[43md\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mtrain\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mtext\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m)\n",
109+
"\u001b[36mFile \u001b[39m\u001b[32m~/projects/SentiSynth/venv311/lib/python3.11/site-packages/datasets/table.py:167\u001b[39m, in \u001b[36mTable.__init__\u001b[39m\u001b[34m(self, table)\u001b[39m\n\u001b[32m 166\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, table: pa.Table):\n\u001b[32m--> \u001b[39m\u001b[32m167\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtable\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 168\u001b[39m \u001b[38;5;28mself\u001b[39m.table = table\n",
110+
"\u001b[36mFile \u001b[39m\u001b[32m~/projects/SentiSynth/venv311/lib/python3.11/site-packages/datasets/table.py:107\u001b[39m, in \u001b[36mIndexedTableMixin.__init__\u001b[39m\u001b[34m(self, table)\u001b[39m\n\u001b[32m 106\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, table: pa.Table):\n\u001b[32m--> \u001b[39m\u001b[32m107\u001b[39m \u001b[38;5;28mself\u001b[39m._schema: pa.Schema = \u001b[43mtable\u001b[49m\u001b[43m.\u001b[49m\u001b[43mschema\u001b[49m\n\u001b[32m 108\u001b[39m \u001b[38;5;28mself\u001b[39m._batches: \u001b[38;5;28mlist\u001b[39m[pa.RecordBatch] = [\n\u001b[32m 109\u001b[39m recordbatch \u001b[38;5;28;01mfor\u001b[39;00m recordbatch \u001b[38;5;129;01min\u001b[39;00m table.to_batches() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(recordbatch) > \u001b[32m0\u001b[39m\n\u001b[32m 110\u001b[39m ]\n\u001b[32m 111\u001b[39m \u001b[38;5;28mself\u001b[39m._offsets: np.ndarray = np.cumsum([\u001b[32m0\u001b[39m] + [\u001b[38;5;28mlen\u001b[39m(b) \u001b[38;5;28;01mfor\u001b[39;00m b \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._batches], dtype=np.int64)\n",
111+
"\u001b[31mAttributeError\u001b[39m: 'list' object has no attribute 'schema'"
112+
]
113+
}
114+
],
115+
"source": [
116+
"Dataset(Table(d[\"train\"][\"text\"]))"
117+
]
118+
}
119+
],
120+
"metadata": {
121+
"kernelspec": {
122+
"display_name": "Python3.11 (sentisynth)",
123+
"language": "python",
124+
"name": "auctionn"
125+
},
126+
"language_info": {
127+
"codemirror_mode": {
128+
"name": "ipython",
129+
"version": 3
130+
},
131+
"file_extension": ".py",
132+
"mimetype": "text/x-python",
133+
"name": "python",
134+
"nbconvert_exporter": "python",
135+
"pygments_lexer": "ipython3",
136+
"version": "3.11.2"
137+
}
138+
},
139+
"nbformat": 4,
140+
"nbformat_minor": 2
141+
}

src/cli/01_train_teacher.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,21 @@
44
import logging
55
from pathlib import Path
66

7-
import numpy as np
8-
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
97
import torch
108
from transformers import DataCollatorWithPadding, IntervalStrategy, TrainingArguments, Trainer
119

1210
from src.models import build_teacher
1311
from src.data import ClassificationDataModule
1412
from utils.wandb_setup import setup_wandb
15-
13+
from utils.metrics import compute_metrics
1614

1715
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
1816
logger = logging.getLogger(__name__)
1917

2018
app = typer.Typer()
2119

2220

23-
def compute_metrics(p):
24-
"""Computes metrics for HF Trainer."""
25-
preds = np.argmax(p.predictions, axis=1)
26-
labels = p.label_ids
27-
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary') # Assuming binary
28-
acc = accuracy_score(labels, preds)
29-
return {
30-
'accuracy': acc,
31-
'f1': f1,
32-
'precision': precision,
33-
'recall': recall
34-
}
21+
3522

3623

3724
@app.command()

0 commit comments

Comments
 (0)