from __future__ import annotations

import glob
import pickle
import random
from typing import Iterator

import torch


class CharTokenizer:
    """Character-level tokenizer with a reserved <PAD> symbol at id 0."""

    def __init__(self):
        self.symbols = ["<PAD>"]  # special symbols, always kept at the front of the vocab
        self.tokens = set()       # characters observed during training
        self.vocab = list(self.symbols)
        self.stoi = {s: i for i, s in enumerate(self.vocab)}

    def pad_id(self): return self.stoi["<PAD>"]

    def get_id(self, tok: str): return self.stoi[tok]

    def vocab_size(self): return len(self.vocab)

    def train(self, sequences: list[str]) -> None:
        # Collect every character seen in the training sequences, then rebuild the vocab.
        for seq in sequences:
            for symbol in self._tokenize_to_symbols(seq):
                self.tokens.add(symbol)
        self.vocab = list(self.symbols) + sorted(self.tokens)
        self.stoi = {s: i for i, s in enumerate(self.vocab)}

    def _tokenize_to_symbols(self, text: str) -> list[str]:
        return list(text)

    def tokenize(self, text: str) -> list[int]:
        seq: list[str] = self._tokenize_to_symbols(text)
        return [self.stoi[s] for s in seq]

    def detokenize(self, tokens: list[int], keep_symbols=True) -> str:
        strs: list[str] = [self.vocab[t] for t in tokens]
        if not keep_symbols:
            # Special symbols such as <PAD> are the only multi-character vocab entries.
            strs = [s for s in strs if len(s) == 1]
        return "".join(strs)

    def save(self, path: str) -> None:
        with open(path, "wb") as outfile:
            pickle.dump(self, outfile)

    @staticmethod
    def load(path: str) -> CharTokenizer:
        with open(path, "rb") as infile:
            tokenizer = pickle.load(infile)
        return tokenizer
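

# A minimal usage sketch for CharTokenizer (illustrative only; the strings and ids
# below are made-up examples, not data from this repo):
#
#   tok = CharTokenizer()
#   tok.train(["hello world"])
#   ids = tok.tokenize("hello")         # ids depend on the trained vocab
#   text = tok.detokenize(ids)          # "hello"
#   padded = ids + [tok.pad_id()] * 3
#   tok.detokenize(padded, keep_symbols=False)   # "<PAD>" entries are dropped -> "hello"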


class RandomOrderDataIterator:
    """Yields random fixed-length windows sampled from the tokenized sequences, forever."""

    def __init__(self, data: list[list[int]], desired_length: int):
        self.desired_length = desired_length
        # Keep only sequences long enough to cut a full window from.
        self.data: list[list[int]] = [seq for seq in data if len(seq) > self.desired_length]

    def __iter__(self):
        if len(self.data) == 0:
            return
        while True:
            seq = random.choice(self.data)
            idx = random.randint(0, len(seq) - self.desired_length)
            yield seq[idx:idx + self.desired_length]
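

# A sketch of how the iterator behaves (illustrative; the token list is made up):
#
#   it = iter(RandomOrderDataIterator([[1, 2, 3, 4, 5, 6]], desired_length=4))
#   next(it)   # e.g. [2, 3, 4, 5] -- a random length-4 window from one of the sequences
#   next(it)   # sampling never stops, so the consumer decides when to break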


# This both creates the tokenizer and uses it to tokenize the data.
# In a real system you would want to split it into two separate functions
# (see the sketch below); feel free to do that in this code as well.
def load_data(path: str) -> tuple[CharTokenizer, list[list[int]]]:
    tokenizer = CharTokenizer()
    # First pass: build the vocabulary from every text file.
    for fname in glob.glob(f"{path}/*.txt"):
        with open(fname) as fh:
            text = fh.read()
            tokenizer.train([text])
    # Second pass: tokenize each file with the finished vocabulary.
    data: list[list[int]] = []
    for fname in glob.glob(f"{path}/*.txt"):
        with open(fname) as fh:
            text = fh.read()
            data.append(tokenizer.tokenize(text))
    return tokenizer, data
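

# A sketch of the suggested split into two functions. The names train_tokenizer and
# tokenize_files are hypothetical (not part of this module's API), assuming the same
# directory-of-*.txt-files layout that load_data above expects:
#
#   def train_tokenizer(path: str) -> CharTokenizer:
#       tokenizer = CharTokenizer()
#       for fname in glob.glob(f"{path}/*.txt"):
#           with open(fname) as fh:
#               tokenizer.train([fh.read()])
#       return tokenizer
#
#   def tokenize_files(path: str, tokenizer: CharTokenizer) -> list[list[int]]:
#       data = []
#       for fname in glob.glob(f"{path}/*.txt"):
#           with open(fname) as fh:
#               data.append(tokenizer.tokenize(fh.read()))
#       return data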


def batch_items(data_iter: Iterator[list[int]], batch_size: int = 2) -> Iterator[torch.LongTensor]:
    """Groups fixed-length token sequences into (batch_size, seq_len) LongTensors."""
    batch = []
    for seq in data_iter:
        batch.append(seq)
        if len(batch) >= batch_size:
            yield torch.tensor(batch, dtype=torch.long)
            batch = []
    # Flush a final, possibly smaller, batch if the iterator ran out mid-batch.
    if len(batch) > 0:
        yield torch.tensor(batch, dtype=torch.long)


class DataFeeder:
    """Loads and tokenizes a text directory, then exposes an endless iterator of windows."""

    def __init__(self, seq_len: int, data_path: str):
        self.seq_len = seq_len
        self.data_path = data_path
        self.tokenizer, self.tokenized_data = load_data(self.data_path)
        self.data_iter = iter(
            RandomOrderDataIterator(self.tokenized_data, self.seq_len))

    def get_data_iter(self):
        return self.data_iter
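

# A minimal end-to-end sketch tying the pieces together. "data" is a placeholder
# directory of *.txt files (an assumption, not a path defined by this repo), and the
# block only runs when the module is executed directly.
if __name__ == "__main__":
    feeder = DataFeeder(seq_len=32, data_path="data")
    batches = batch_items(feeder.get_data_iter(), batch_size=2)
    for step, batch in enumerate(batches):
        print(step, batch.shape)   # expected shape: (batch_size, seq_len)
        if step >= 2:              # the random iterator never stops, so cap the demo
            break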