@@ -316,17 +316,13 @@ def train_ensemble(
316316 when early stopping. Defaults to None.
317317 verbose (bool, optional): Whether to show progress bars for each epoch.
318318 """
319- if isinstance (train_set , Subset ):
320- train_set = train_set .dataset
321- if isinstance (val_set , Subset ):
322- val_set = val_set .dataset
323-
324319 train_loader = DataLoader (train_set , ** data_params )
325320 print (f"Training on { len (train_set ):,} samples" )
326321
327322 if val_set is not None :
328323 data_params .update ({"batch_size" : 16 * data_params ["batch_size" ]})
329324 val_loader = DataLoader (val_set , ** data_params )
325+ print (f"Validating on { len (val_set ):,} samples" )
330326 else :
331327 val_loader = None
332328
@@ -354,7 +350,13 @@ def train_ensemble(
354350
355351 for target , normalizer in normalizer_dict .items ():
356352 if normalizer is not None :
357- sample_target = Tensor (train_set .df [target ].values )
353+ if isinstance (train_set , Subset ):
354+ sample_target = Tensor (
355+ train_set .dataset .df [target ].iloc [train_set .indices ].values
356+ )
357+ else :
358+ sample_target = Tensor (train_set .df [target ].values )
359+
358360 if not restart_params ["resume" ]:
359361 normalizer .fit (sample_target )
360362 print (f"Dummy MAE: { (sample_target - normalizer .mean ).abs ().mean ():.4f} " )
@@ -455,10 +457,6 @@ def results_multitask(
455457 "------------Evaluate model on Test Set------------\n "
456458 "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n "
457459 )
458-
459- if isinstance (test_set , Subset ):
460- test_set = test_set .dataset
461-
462460 test_loader = DataLoader (test_set , ** data_params )
463461 print (f"Testing on { len (test_set ):,} samples" )
464462
0 commit comments