@@ -556,12 +556,12 @@ def train_dataloader(self) -> DataLoader:
556556 # Compute sampler if balanced sampling is enabled
557557 sampler = None
558558 shuffle = True
559-
559+
560560 if self .use_balanced_sampler :
561561 print ("⚖️ Using WeightedRandomSampler for balanced class sampling" )
562562 sampler = self ._create_weighted_sampler ()
563563 shuffle = False # Can't use shuffle with sampler
564-
564+
565565 return DataLoader (
566566 self .train_dataset ,
567567 batch_size = self .batch_size ,
@@ -612,10 +612,10 @@ def test_dataloader(self) -> Optional[DataLoader]:
612612 def _create_weighted_sampler (self ) -> WeightedRandomSampler :
613613 """
614614 Create WeightedRandomSampler for balanced class sampling.
615-
615+
616616 Computes sample weights inversely proportional to class frequency,
617617 so rare classes are sampled more often and common classes less often.
618-
618+
619619 Returns:
620620 WeightedRandomSampler for training dataset
621621 """
@@ -640,29 +640,30 @@ def _create_weighted_sampler(self) -> WeightedRandomSampler:
640640
641641 # Count class frequencies
642642 class_counts = sample_labels .value_counts ().to_dict ()
643-
643+
644644 # Compute weight for each class (inverse frequency)
645645 num_samples = len (sample_labels )
646646 class_weights = {
647647 cls : num_samples / count for cls , count in class_counts .items ()
648648 }
649-
649+
650650 # Assign weight to each sample based on its class
651651 sample_weights = [class_weights [label ] for label in sample_labels ]
652652 sample_weights = torch .DoubleTensor (sample_weights )
653-
653+
654654 # Create sampler
655655 sampler = WeightedRandomSampler (
656656 weights = sample_weights ,
657657 num_samples = len (sample_weights ),
658- replacement = True # Sample with replacement to oversample rare classes
658+ replacement = True , # Sample with replacement to oversample rare classes
659659 )
660-
660+
661661 print (f" Created sampler for { len (sample_weights )} samples" )
662- print (f" Sample weight range: { sample_weights . min ():.3f } - { sample_weights . max ():.3f } " )
663-
664- return sampler
662+ print (
663+ f" Sample weight range: { sample_weights . min ():.3f } - { sample_weights . max ():.3f } "
664+ )
665665
666+ return sampler
666667
667668 def get_class_weights (self ) -> torch .Tensor :
668669 """
@@ -723,32 +724,32 @@ def get_class_weights(self) -> torch.Tensor:
723724 def _create_genus_label_mapping (self ) -> Dict [str , int ]:
724725 """
725726 Create genus-level label mapping from species names in the CSV.
726-
727+
727728 Extracts genus (first word) from species_name column.
728-
729+
729730 Returns:
730731 Dictionary mapping genus name to integer index
731732 """
732733 import warnings
733-
734+
734735 # Load CSV to extract species names
735736 df = pd .read_csv (self .csv_path )
736-
737+
737738 # Apply any filters that were specified
738739 if self .dataset_params .get ("species_filter" ):
739740 df = df [df ["species" ].isin (self .dataset_params ["species_filter" ])]
740741 if self .dataset_params .get ("site_filter" ):
741742 df = df [df ["site" ].isin (self .dataset_params ["site_filter" ])]
742743 if self .dataset_params .get ("year_filter" ):
743744 df = df [df ["year" ].isin (self .dataset_params ["year_filter" ])]
744-
745+
745746 # Extract genus from species_name (first word)
746747 df ["genus" ] = df ["species_name" ].apply (lambda x : str (x ).split ()[0 ])
747-
748+
748749 # Get unique genera and create mapping
749750 unique_genera = sorted (df ["genus" ].unique ())
750751 label_to_idx = {genus : idx for idx , genus in enumerate (unique_genera )}
751-
752+
752753 # Validate genus names and warn about edge cases
753754 non_alpha_genera = [g for g in unique_genera if not g .isalpha ()]
754755 if non_alpha_genera :
@@ -758,7 +759,7 @@ def _create_genus_label_mapping(self) -> Dict[str, int]:
758759 f"Run 'python processing/misc/inspect_labels.py' to review. "
759760 f"To exclude, use: species_filter=[...]"
760761 )
761-
762+
762763 # Check for known family names
763764 known_families = {"Pinaceae" , "Rosaceae" , "Fabaceae" , "Asteraceae" }
764765 found_families = set (unique_genera ) & known_families
@@ -769,7 +770,7 @@ def _create_genus_label_mapping(self) -> Dict[str, int]:
769770 f"These represent unidentified species within that family. "
770771 f"See docs/taxonomic_levels.md for more information."
771772 )
772-
773+
773774 return label_to_idx
774775
775776 def get_dataset_info (self ) -> Dict [str , Any ]:
0 commit comments