@@ -94,17 +94,20 @@ def load_hf_dataset(
9494 The IterableDataset object
9595 """
9696 data = load_dataset (
97- dataset_name , split = dataset_split , trust_remote_code = True , streaming = True
98- ).with_format ("torch" )
97+ dataset_name ,
98+ split = dataset_split ,
99+ trust_remote_code = True ,
100+ streaming = True ,
101+ )
99102
100103 if shuffle :
101- data = data .shuffle (). with_format ( "torch" )
104+ data = data .shuffle ()
102105
103106 if validate :
104- data = data .filter (validate ). with_format ( "torch" )
107+ data = data .filter (validate )
105108
106109 if prepare :
107- data = data .map (prepare ). with_format ( "torch" )
110+ data = data .map (prepare )
108111
109112 return data .with_format ("torch" )
110113
@@ -172,15 +175,13 @@ def mnist_dataset(batch_size=32, shuffle=True):
172175
173176 def prepare (sample ):
174177 images , labels = sample ["image" ], sample ["label" ]
175- # Ensure the images are float tensors
176- images = images .to (torch .float32 )
177- # Normalize the images
178- images = images / 255.0
179- # Convert labels to one-hot encoding
180- labels = labels .to (torch .int64 ) # Ensure labels are int32 tensors
181- labels = F .one_hot (labels , num_classes = 10 ).to (torch .int32 )
182-
183- sample ["image" ], sample ["label" ] = images , labels
178+ images = transforms .ToTensor ()(images ).to (torch .float32 )
179+ labels = F .one_hot (
180+ torch .as_tensor (labels , dtype = torch .int64 ),
181+ num_classes = 10 ,
182+ ).to (torch .int32 )
183+
184+ sample ["image" ], sample ["label" ] = images .numpy (), labels .numpy ()
184185 return sample
185186
186187 train_data , test_data = (
@@ -219,20 +220,21 @@ def tiny_imagenet_dataset(batch_size=32, shuffle=True):
219220
220221 def prepare (sample ):
221222 images , labels = sample ["image" ], sample ["label" ]
222- # Ensure the images are float tensors
223- images = images .clone ().detach ().to (torch .float32 )
224- # Normalize the images
225- images = images / 255.0
226- # Convert labels to one-hot encoding
227- labels = labels .clone ().detach ().to (torch .int64 ) # Ensure labels are int32
228- labels = F .one_hot (labels , num_classes = 200 ).to (torch .int32 )
229-
230- sample ["image" ], sample ["label" ] = images , labels
223+ images = transforms .ToTensor ()(images ).to (torch .float32 )
224+ labels = F .one_hot (
225+ torch .as_tensor (labels , dtype = torch .int64 ),
226+ num_classes = 200 ,
227+ ).to (torch .int32 )
228+
229+ sample ["image" ], sample ["label" ] = images .numpy (), labels .numpy ()
231230 return sample
232231
233232 def validate (sample ):
234- transform = transforms .ToTensor ()
235- img = transform (sample ["image" ])
233+ img = (
234+ sample ["image" ]
235+ if isinstance (sample ["image" ], torch .Tensor )
236+ else transforms .ToTensor ()(sample ["image" ])
237+ )
236238 return (
237239 img .shape == (3 , 64 , 64 )
238240 and torch .isnan (img ).sum () == 0
0 commit comments