@@ -1,8 +1,9 @@
 import logging

+from accelerate import DistributedType
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import DataCollatorWithPadding, default_data_collator


 logger = logging.getLogger(__name__)
@@ -94,7 +95,7 @@ def get_dataloaders(tokenizer, args):
     text_column_name = "text" if "text" in column_names else column_names[0]

     def tokenize_function(examples):
-        return tokenizer(examples[text_column_name])
+        return tokenizer(examples[text_column_name], truncation=True, max_length=args.max_seq_len)

     tokenized_datasets = raw_datasets.map(
         tokenize_function,
@@ -157,15 +158,18 @@ def group_texts(examples):
     val_dataset = lm_datasets["validation"]

     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
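
For context, a minimal standalone sketch of the collator switch above (not part of the commit): on TPU, XLA compiles a separate graph for every tensor shape it sees, so padding each batch to a fixed max_length keeps shapes static and avoids recompilation, while default_data_collator simply stacks features into tensors without padding. The sketch assumes a stock Hugging Face tokenizer and uses a hard-coded 128 as a placeholder for args.max_seq_len.

# Sketch only: collator selection in isolation, mirroring the diff above.
# "bert-base-uncased" and the length 128 are illustrative placeholders.
from accelerate import Accelerator, DistributedType
from transformers import AutoTokenizer, DataCollatorWithPadding, default_data_collator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
accelerator = Accelerator()

if accelerator.distributed_type == DistributedType.TPU:
    # Pad every batch to the same fixed length so XLA sees one static shape
    # instead of recompiling for each new sequence length.
    data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=128)
else:
    # default_data_collator only stacks features; it assumes examples in a
    # batch already share a length (e.g. after the group_texts step).
    data_collator = default_data_collator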