-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
110 lines (92 loc) · 3.4 KB
/
main.py
File metadata and controls
110 lines (92 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import annotations
import os
import datetime
import torch
from torch import optim
from transformer import TransformerLM
import data
import lm
def _checkpoint_path(training_time: str, num_batches: int) -> str:
    """Return the checkpoint filename for a run timestamp and a batch count."""
    return "model " + training_time + "-batch-" + str(num_batches) + ".pth"


def _save_checkpoint(model: torch.nn.Module, training_time: str, num_batches: int) -> None:
    """Write the model's state dict to the checkpoint file for this batch count."""
    torch.save(model.state_dict(), _checkpoint_path(training_time, num_batches))


def _print_sample(model, tokenizer) -> None:
    """Generate a continuation of the prefix "Hello" and print it.

    The model is switched to eval mode for generation and is always put back
    into train mode afterwards, even if generation raises.
    """
    model.eval()
    try:
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            # NOTE(review): temperature=10 looks unusually high -- confirm it
            # is intentional in combination with topK=5.
            sampled = tokenizer.detokenize(
                model.better_sample_continuation(
                    prefix=tokenizer.tokenize("Hello"),
                    max_tokens_to_generate=500,
                    temperature=10,
                    topK=5,
                )
            )
    finally:
        model.train()
    print(f"Model sample: '''{sampled}'''")
    print("")


def main():
    """Train a TransformerLM on the data under ``data/`` and save checkpoints.

    The loop runs for at most ``num_batches_to_train`` parameter updates. It
    prints the last loss every 10 batches, prints a generated sample every
    100 batches and writes a checkpoint every 500 batches. A final checkpoint
    is written when training ends, whether it ran to completion or was
    interrupted with Ctrl-C.
    """
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")

    # Data / batching configuration.
    seq_len = 128
    batch_size = 64
    data_path = "data" + os.sep

    # Model configuration.
    n_layers = 6
    n_heads = 6
    embed_size = 192
    mlp_hidden_size = embed_size * 4

    # Optimization configuration.
    learning_rate = 5e-4
    gradient_clipping = 1.0
    num_batches_to_train = 50000

    tokenizer, tokenized_data = data.load_data(data_path)
    # NOTE: our data items are longer by one than the sequence length.
    # They will be shortened by 1 when converted to training examples.
    data_iter = iter(data.RandomOrderDataIterator(tokenized_data, seq_len + 1))

    model: torch.nn.Module = TransformerLM(
        n_layers,
        n_heads,
        embed_size,
        seq_len,
        tokenizer.vocab_size(),
        mlp_hidden_size,
        with_residuals=True,
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.95))
    model.train()

    num_batches = 0
    # Filesystem-friendly run timestamp, e.g. "2024-01-31_17-05-09".
    training_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # The try wraps the whole loop so that a Ctrl-C arriving while the next
    # batch is being fetched is handled too, not only one inside the body.
    try:
        for batch in data.batch_items(data_iter, batch_size):
            if num_batches >= num_batches_to_train:
                break

            batch_x, batch_y = lm.batch_to_labeled_samples(batch)
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            loss = lm.compute_loss(logits, batch_y)

            # parameters update
            model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            optimizer.step()

            # Count each batch exactly once, after its update was applied, so
            # the logged and checkpointed numbers match the real update count.
            num_batches += 1

            if num_batches % 10 == 0:
                print(f"Seen {num_batches} batches. last loss is: {loss.item()}")
            if num_batches % 100 == 0:
                _print_sample(model, tokenizer)
            if num_batches % 500 == 0:
                _save_checkpoint(model, training_time, num_batches)
    except KeyboardInterrupt:
        _save_checkpoint(model, training_time, num_batches)
        print("Interrupted by user -- current weights were saved on batch", num_batches)
        return

    # Normal completion: persist the final weights.
    _save_checkpoint(model, training_time, num_batches)
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()