Skip to content

Commit 3a03eb2

Browse files
committed
fix: add torchvision dependency and apply black formatting
1 parent 07a655b commit 3a03eb2

15 files changed

Lines changed: 685 additions & 668 deletions

examples/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,15 +243,15 @@ def main():
243243
action="store_true",
244244
help="Use WeightedRandomSampler for balanced class sampling (recommended for imbalanced datasets)",
245245
)
246-
246+
247247
# Image size arguments
248248
parser.add_argument(
249249
"--rgb_size",
250250
type=int,
251251
default=224,
252252
help="RGB image size (single value for square images, e.g., 224 for 224x224). Default matches ImageNet pretraining.",
253253
)
254-
254+
255255
# Normalization arguments
256256
parser.add_argument(
257257
"--rgb_norm_method",
@@ -412,7 +412,7 @@ def main():
412412
model_kwargs = {}
413413
if args.model_variant is not None:
414414
model_kwargs["model_variant"] = args.model_variant
415-
415+
416416
classifier = RGBClassifier(
417417
model_type=args.model_type,
418418
num_classes=args.num_classes,

neon_tree_classification/core/datamodule.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,12 @@ def train_dataloader(self) -> DataLoader:
556556
# Compute sampler if balanced sampling is enabled
557557
sampler = None
558558
shuffle = True
559-
559+
560560
if self.use_balanced_sampler:
561561
print("⚖️ Using WeightedRandomSampler for balanced class sampling")
562562
sampler = self._create_weighted_sampler()
563563
shuffle = False # Can't use shuffle with sampler
564-
564+
565565
return DataLoader(
566566
self.train_dataset,
567567
batch_size=self.batch_size,
@@ -612,10 +612,10 @@ def test_dataloader(self) -> Optional[DataLoader]:
612612
def _create_weighted_sampler(self) -> WeightedRandomSampler:
613613
"""
614614
Create WeightedRandomSampler for balanced class sampling.
615-
615+
616616
Computes sample weights inversely proportional to class frequency,
617617
so rare classes are sampled more often and common classes less often.
618-
618+
619619
Returns:
620620
WeightedRandomSampler for training dataset
621621
"""
@@ -640,29 +640,30 @@ def _create_weighted_sampler(self) -> WeightedRandomSampler:
640640

641641
# Count class frequencies
642642
class_counts = sample_labels.value_counts().to_dict()
643-
643+
644644
# Compute weight for each class (inverse frequency)
645645
num_samples = len(sample_labels)
646646
class_weights = {
647647
cls: num_samples / count for cls, count in class_counts.items()
648648
}
649-
649+
650650
# Assign weight to each sample based on its class
651651
sample_weights = [class_weights[label] for label in sample_labels]
652652
sample_weights = torch.DoubleTensor(sample_weights)
653-
653+
654654
# Create sampler
655655
sampler = WeightedRandomSampler(
656656
weights=sample_weights,
657657
num_samples=len(sample_weights),
658-
replacement=True # Sample with replacement to oversample rare classes
658+
replacement=True, # Sample with replacement to oversample rare classes
659659
)
660-
660+
661661
print(f" Created sampler for {len(sample_weights)} samples")
662-
print(f" Sample weight range: {sample_weights.min():.3f} - {sample_weights.max():.3f}")
663-
664-
return sampler
662+
print(
663+
f" Sample weight range: {sample_weights.min():.3f} - {sample_weights.max():.3f}"
664+
)
665665

666+
return sampler
666667

667668
def get_class_weights(self) -> torch.Tensor:
668669
"""
@@ -723,32 +724,32 @@ def get_class_weights(self) -> torch.Tensor:
723724
def _create_genus_label_mapping(self) -> Dict[str, int]:
724725
"""
725726
Create genus-level label mapping from species names in the CSV.
726-
727+
727728
Extracts genus (first word) from species_name column.
728-
729+
729730
Returns:
730731
Dictionary mapping genus name to integer index
731732
"""
732733
import warnings
733-
734+
734735
# Load CSV to extract species names
735736
df = pd.read_csv(self.csv_path)
736-
737+
737738
# Apply any filters that were specified
738739
if self.dataset_params.get("species_filter"):
739740
df = df[df["species"].isin(self.dataset_params["species_filter"])]
740741
if self.dataset_params.get("site_filter"):
741742
df = df[df["site"].isin(self.dataset_params["site_filter"])]
742743
if self.dataset_params.get("year_filter"):
743744
df = df[df["year"].isin(self.dataset_params["year_filter"])]
744-
745+
745746
# Extract genus from species_name (first word)
746747
df["genus"] = df["species_name"].apply(lambda x: str(x).split()[0])
747-
748+
748749
# Get unique genera and create mapping
749750
unique_genera = sorted(df["genus"].unique())
750751
label_to_idx = {genus: idx for idx, genus in enumerate(unique_genera)}
751-
752+
752753
# Validate genus names and warn about edge cases
753754
non_alpha_genera = [g for g in unique_genera if not g.isalpha()]
754755
if non_alpha_genera:
@@ -758,7 +759,7 @@ def _create_genus_label_mapping(self) -> Dict[str, int]:
758759
f"Run 'python processing/misc/inspect_labels.py' to review. "
759760
f"To exclude, use: species_filter=[...]"
760761
)
761-
762+
762763
# Check for known family names
763764
known_families = {"Pinaceae", "Rosaceae", "Fabaceae", "Asteraceae"}
764765
found_families = set(unique_genera) & known_families
@@ -769,7 +770,7 @@ def _create_genus_label_mapping(self) -> Dict[str, int]:
769770
f"These represent unidentified species within that family. "
770771
f"See docs/taxonomic_levels.md for more information."
771772
)
772-
773+
773774
return label_to_idx
774775

775776
def get_dataset_info(self) -> Dict[str, Any]:

neon_tree_classification/core/dataset.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,26 +250,32 @@ def _validate_species_consistency(self) -> None:
250250
# If the first mapping key is a species code (all uppercase, short), it's species-level
251251
# If it's a genus name (capitalized, longer), it's genus-level
252252
sample_label = next(iter(mapping_labels)) if mapping_labels else ""
253-
is_genus_mapping = sample_label and sample_label[0].isupper() and sample_label[1:].islower()
254-
253+
is_genus_mapping = (
254+
sample_label and sample_label[0].isupper() and sample_label[1:].islower()
255+
)
256+
255257
if is_genus_mapping:
256258
# Genus-level mapping: validate that all species have extractable genus
257259
if "species_name" not in self.data.columns:
258260
raise ValueError(
259261
"Genus-level mapping detected but 'species_name' column not found in data. "
260262
"Cannot extract genus from species names."
261263
)
262-
264+
263265
# Extract genera from species names and check they're all in mapping
264-
data_genera = set(self.data["species_name"].apply(lambda x: str(x).split()[0]).unique())
266+
data_genera = set(
267+
self.data["species_name"].apply(lambda x: str(x).split()[0]).unique()
268+
)
265269
missing_genera = data_genera - mapping_labels
266270
if missing_genera:
267271
raise ValueError(
268272
f"Genera extracted from dataset not found in external label mapping: {sorted(missing_genera)}. "
269273
f"External mapping has: {sorted(mapping_labels)}"
270274
)
271-
272-
print(f"✓ Genus-level validation passed: All {len(data_genera)} genera found in mapping")
275+
276+
print(
277+
f"✓ Genus-level validation passed: All {len(data_genera)} genera found in mapping"
278+
)
273279
else:
274280
# Species-level mapping: check species codes
275281
missing_in_mapping = data_species - mapping_labels

neon_tree_classification/inference/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
from .utils import load_label_mapping, format_predictions
2525

2626
__all__ = [
27-
'TreeClassifier',
28-
'preprocess_image',
29-
'prepare_tensor',
30-
'load_label_mapping',
31-
'format_predictions',
27+
"TreeClassifier",
28+
"preprocess_image",
29+
"prepare_tensor",
30+
"load_label_mapping",
31+
"format_predictions",
3232
]
3333

34-
__version__ = '1.0.0'
34+
__version__ = "1.0.0"

0 commit comments

Comments (0)