Skip to content

Commit 0229aed

Browse files
committed
fix: handle hf dataset tensor warnings
1 parent 6310613 commit 0229aed

5 files changed

Lines changed: 62 additions & 66 deletions

src/scratch/datasets/causal_langauge_modeling_dataset.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,18 @@ def load_hf_dataset(
9090
)
9191

9292
if shuffle:
93-
data = data.shuffle().with_format("torch")
93+
data = data.shuffle()
9494

9595
if validate:
96-
data = data.filter(validate).with_format("torch")
96+
data = data.filter(validate)
9797

9898
def tokenize_function(examples):
9999
return tokenizer(examples["text"], padding="max_length", truncation=True)
100100

101-
data = data.map(tokenize_function, batched=True).with_format("torch")
101+
data = data.map(tokenize_function, batched=True)
102102

103103
if prepare:
104-
data = data.map(prepare).with_format("torch")
104+
data = data.map(prepare)
105105

106106
return data.with_format("torch")
107107

@@ -162,9 +162,9 @@ def transform(batch: CausalLanguageModelingBatch):
162162
batch["attention_mask"],
163163
batch["labels"],
164164
)
165-
input_ids = torch.tensor(input_ids, dtype=torch.int64)
166-
attention_mask = torch.tensor(attention_mask, dtype=torch.int64)
167-
labels = torch.tensor(labels, dtype=torch.int64)
165+
input_ids = torch.as_tensor(input_ids, dtype=torch.int64)
166+
attention_mask = torch.as_tensor(attention_mask, dtype=torch.int64)
167+
labels = torch.as_tensor(labels, dtype=torch.int64)
168168
return CausalLanguageModelingBatch(
169169
input_ids=input_ids, attention_mask=attention_mask, labels=labels
170170
)
@@ -196,12 +196,12 @@ def wikitext2_dataset(
196196
tokenizer = load_tokenizer(tokenizer_name, max_length=max_length)
197197

198198
def prepare(sample):
199-
input_ids = sample["input_ids"]
200-
input_ids = torch.tensor(input_ids, dtype=torch.int64)
201-
labels = input_ids.clone()
199+
input_ids = np.array(sample["input_ids"], dtype=np.int64)
200+
labels = input_ids.copy()
202201
# Make a lower triangular attention mask
203-
attention_mask = np.tril(np.ones((len(input_ids), len(input_ids))))
204-
attention_mask = torch.tensor(attention_mask, dtype=torch.int64)
202+
attention_mask = np.tril(
203+
np.ones((len(input_ids), len(input_ids)), dtype=np.int64)
204+
)
205205
sample["input_ids"], sample["attention_mask"], sample["labels"] = (
206206
input_ids,
207207
attention_mask,

src/scratch/datasets/image_classification_dataset.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,20 @@ def load_hf_dataset(
9494
The IterableDataset object
9595
"""
9696
data = load_dataset(
97-
dataset_name, split=dataset_split, trust_remote_code=True, streaming=True
98-
).with_format("torch")
97+
dataset_name,
98+
split=dataset_split,
99+
trust_remote_code=True,
100+
streaming=True,
101+
)
99102

100103
if shuffle:
101-
data = data.shuffle().with_format("torch")
104+
data = data.shuffle()
102105

103106
if validate:
104-
data = data.filter(validate).with_format("torch")
107+
data = data.filter(validate)
105108

106109
if prepare:
107-
data = data.map(prepare).with_format("torch")
110+
data = data.map(prepare)
108111

109112
return data.with_format("torch")
110113

@@ -172,15 +175,13 @@ def mnist_dataset(batch_size=32, shuffle=True):
172175

173176
def prepare(sample):
174177
images, labels = sample["image"], sample["label"]
175-
# Ensure the images are float tensors
176-
images = images.to(torch.float32)
177-
# Normalize the images
178-
images = images / 255.0
179-
# Convert labels to one-hot encoding
180-
labels = labels.to(torch.int64) # Ensure labels are int32 tensors
181-
labels = F.one_hot(labels, num_classes=10).to(torch.int32)
182-
183-
sample["image"], sample["label"] = images, labels
178+
images = transforms.ToTensor()(images).to(torch.float32)
179+
labels = F.one_hot(
180+
torch.as_tensor(labels, dtype=torch.int64),
181+
num_classes=10,
182+
).to(torch.int32)
183+
184+
sample["image"], sample["label"] = images.numpy(), labels.numpy()
184185
return sample
185186

186187
train_data, test_data = (
@@ -219,20 +220,21 @@ def tiny_imagenet_dataset(batch_size=32, shuffle=True):
219220

220221
def prepare(sample):
221222
images, labels = sample["image"], sample["label"]
222-
# Ensure the images are float tensors
223-
images = images.clone().detach().to(torch.float32)
224-
# Normalize the images
225-
images = images / 255.0
226-
# Convert labels to one-hot encoding
227-
labels = labels.clone().detach().to(torch.int64) # Ensure labels are int32
228-
labels = F.one_hot(labels, num_classes=200).to(torch.int32)
229-
230-
sample["image"], sample["label"] = images, labels
223+
images = transforms.ToTensor()(images).to(torch.float32)
224+
labels = F.one_hot(
225+
torch.as_tensor(labels, dtype=torch.int64),
226+
num_classes=200,
227+
).to(torch.int32)
228+
229+
sample["image"], sample["label"] = images.numpy(), labels.numpy()
231230
return sample
232231

233232
def validate(sample):
234-
transform = transforms.ToTensor()
235-
img = transform(sample["image"])
233+
img = (
234+
sample["image"]
235+
if isinstance(sample["image"], torch.Tensor)
236+
else transforms.ToTensor()(sample["image"])
237+
)
236238
return (
237239
img.shape == (3, 64, 64)
238240
and torch.isnan(img).sum() == 0

src/scratch/datasets/question_answering_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,10 @@ def transform(batch):
8888
batch["start_positions"],
8989
batch["end_positions"],
9090
)
91-
input_ids = torch.tensor(input_ids, dtype=torch.long)
92-
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
93-
start_positions = torch.tensor(start_positions, dtype=torch.long)
94-
end_positions = torch.tensor(end_positions, dtype=torch.long)
91+
input_ids = torch.as_tensor(input_ids, dtype=torch.long)
92+
attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
93+
start_positions = torch.as_tensor(start_positions, dtype=torch.long)
94+
end_positions = torch.as_tensor(end_positions, dtype=torch.long)
9595
return QuestionAnsweringBatch(
9696
input_ids=input_ids,
9797
attention_mask=attention_mask,

src/scratch/datasets/sequence_classification_dataset.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -81,22 +81,25 @@ def load_hf_dataset(
8181
The IterableDataset object
8282
"""
8383
data = load_dataset(
84-
dataset_name, split=dataset_split, trust_remote_code=True, streaming=True
84+
dataset_name,
85+
split=dataset_split,
86+
trust_remote_code=True,
87+
streaming=True,
8588
)
8689

8790
if shuffle:
88-
data = data.shuffle().with_format("torch")
91+
data = data.shuffle()
8992

9093
if validate:
91-
data = data.filter(validate).with_format("torch")
94+
data = data.filter(validate)
9295

9396
def tokenize_function(examples):
9497
return tokenizer(examples["text"], padding="max_length", truncation=True)
9598

96-
data = data.map(tokenize_function, batched=True).with_format("torch")
99+
data = data.map(tokenize_function, batched=True)
97100

98101
if prepare:
99-
data = data.map(prepare).with_format("torch")
102+
data = data.map(prepare)
100103

101104
return data.with_format("torch")
102105

@@ -157,8 +160,8 @@ def transform(batch: SequenceClassificationBatch):
157160
batch["input_ids"],
158161
batch["label"],
159162
)
160-
input_ids = torch.tensor(input_ids, dtype=torch.int64)
161-
label = torch.tensor(label, dtype=torch.int64)
163+
input_ids = torch.as_tensor(input_ids, dtype=torch.int64)
164+
label = torch.as_tensor(label, dtype=torch.int64)
162165
label = F.one_hot(label, num_classes=num_classes).to(torch.int32)
163166
return SequenceClassificationBatch(
164167
input_ids=input_ids,
@@ -192,21 +195,12 @@ def imdb_dataset(
192195
tokenizer = load_tokenizer(tokenizer_name, max_length=max_length)
193196

194197
def prepare(sample):
195-
input_ids, labels = (
196-
sample["input_ids"],
197-
sample["label"],
198-
)
199-
input_ids = torch.tensor(input_ids, dtype=torch.int64)
200-
labels = torch.tensor(labels, dtype=torch.int64)
201-
labels = F.one_hot(labels, num_classes=2).to(torch.int32)
202-
203-
(
204-
sample["input_ids"],
205-
sample["label"],
206-
) = (
207-
input_ids,
208-
labels,
209-
)
198+
input_ids, labels = sample["input_ids"], sample["label"]
199+
input_ids = np.array(input_ids, dtype=np.int64)
200+
labels_tensor = torch.as_tensor(labels, dtype=torch.int64)
201+
labels = F.one_hot(labels_tensor, num_classes=2).to(torch.int32).numpy()
202+
203+
sample["input_ids"], sample["label"] = input_ids, labels
210204
return sample
211205

212206
train_data, test_data = (

src/scratch/datasets/token_classification_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def transform(batch):
9191
batch["attention_mask"],
9292
batch["labels"],
9393
)
94-
input_ids = torch.tensor(input_ids, dtype=torch.long)
95-
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
96-
labels = torch.tensor(labels, dtype=torch.long)
94+
input_ids = torch.as_tensor(input_ids, dtype=torch.long)
95+
attention_mask = torch.as_tensor(attention_mask, dtype=torch.long)
96+
labels = torch.as_tensor(labels, dtype=torch.long)
9797
return TokenClassificationBatch(
9898
input_ids=input_ids, attention_mask=attention_mask, labels=labels
9999
)

0 commit comments

Comments
 (0)