Commit 31e457f

fix: address code review issues from Copilot
- Remove sys.path manipulation from predictor.py (use package imports directly)
- Remove unused OrderedDict import in create_label_mappings.py
- Update preprocessing defaults to 224x224 and imagenet normalization
- Update preprocess_image_batch and resize_image defaults to match
- Fix normalize_rgb docstring to accurately describe both normalization modes
- Update model_registry input_size to 224x224 and add norm_method field
- Make TreeClassifier norm_method configurable (default: imagenet)
- Fix predictor to use self.norm_method instead of hardcoded '0_1'
- Update from_checkpoint to use (224, 224) and imagenet defaults
- Add rgb_norm_method param to RGBClassifier; normalize() now reflects it
- Validate numeric_to_label_dict in upload_to_huggingface.py
- Fix docs: species_filter is inclusion filter, not exclusion
- Fix docs: add --csv_path to inspect_labels.py example commands
- Fix warning message: clarify species_filter is an inclusion filter
1 parent 73d29cb commit 31e457f

8 files changed

Lines changed: 66 additions & 42 deletions

docs/taxonomic_levels.md

Lines changed: 17 additions & 12 deletions
@@ -63,7 +63,7 @@ datamodule = NeonCrownDataModule(
 ### Step 1: Run Label Inspection

 ```bash
-python processing/misc/inspect_labels.py
+python processing/misc/inspect_labels.py --csv_path path/to/your/labels.csv
 ```

 This will show:
@@ -103,20 +103,22 @@ Pinus 6,600 samples 19 species (Pines)
 If you want taxonomically pure genus-level training:

 ```python
-# Option A: Filter specific species codes
+# Option A: Include only specific species codes (all others are excluded)
 datamodule = NeonCrownDataModule(
     ...,
     taxonomic_level="genus",
-    species_filter=["PINACE"], # Exclude Pinaceae (will filter BEFORE genus extraction)
+    species_filter=["PSMEM", "TSHE"], # Include only these USDA codes
 )

-# Option B: Filter after inspecting
-# See inspect_labels.py output for USDA codes to exclude
-species_to_exclude = ["PINACE", "2PLANT", "2PLANT-S"] # Example
+# Option B: Build an inclusion list after inspecting
+# See inspect_labels.py output for USDA codes present in your data
+# species_filter keeps only rows WHERE species IS IN the list
+all_codes = [...] # full list from inspect_labels.py
+species_to_include = [c for c in all_codes if c not in ["PINACE", "2PLANT", "2PLANT-S"]]
 datamodule = NeonCrownDataModule(
     ...,
     taxonomic_level="genus",
-    species_filter=species_to_exclude,
+    species_filter=species_to_include,
 )
 ```
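The inclusion semantics documented in this hunk can be sketched in plain Python. This is a minimal illustration, not code from the repository: the helper name `build_inclusion_list` and the sample contents of `all_codes` are hypothetical, and the real list would come from the `inspect_labels.py` output.

```python
from typing import Iterable, List


def build_inclusion_list(all_codes: Iterable[str], codes_to_drop: Iterable[str]) -> List[str]:
    """Return every code in all_codes that is not in codes_to_drop, preserving order."""
    drop = set(codes_to_drop)
    return [code for code in all_codes if code not in drop]


# Hypothetical codes for illustration only; use the codes reported for your data.
all_codes = ["PSMEM", "TSHE", "PINACE", "2PLANT"]
species_to_include = build_inclusion_list(all_codes, ["PINACE", "2PLANT", "2PLANT-S"])
print(species_to_include)  # ['PSMEM', 'TSHE']
```

Codes listed for removal that never occur in the data (here `"2PLANT-S"`) are simply ignored, so the same drop list can be reused across datasets.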

@@ -175,13 +177,16 @@ trainer.fit(model, datamodule)
 ### With Filtering

 ```python
-# Clean genus-level training (exclude edge cases)
+# Clean genus-level training (include only true genera, omit edge cases)
+# species_filter keeps only rows where species code is in the list
+all_codes = [...] # get from inspect_labels.py output
+clean_codes = [c for c in all_codes if c not in ["PINACE"]] # drop Pinaceae
 datamodule = NeonCrownDataModule(
     csv_path="data/metadata/combined_dataset.csv",
     hdf5_path="data/combined_dataset.h5",
     modalities=["rgb"],
     taxonomic_level="genus",
-    species_filter=["PINACE"], # Exclude Pinaceae family
+    species_filter=clean_codes, # include all except Pinaceae
     batch_size=64,
 )
 # Now training on 59 true genera only
@@ -262,7 +267,7 @@ These represent unidentified species within that family.
 See docs/taxonomic_levels.md for more information.
 ```

-**These are informational** - training will proceed normally. Filter if desired using `species_filter`.
+**These are informational** - training will proceed normally. To exclude them, build an inclusion list with all other codes and pass it to `species_filter` (which keeps only species in the list).

 ## FAQ

@@ -277,7 +282,7 @@ See docs/taxonomic_levels.md for more information.
 **Q: What about Pinaceae?**
 - It's a family name, not genus, but only 26 samples (0.05%)
 - Keep it (recommended): Represents "unidentified conifer" class
-- Filter it: Use `species_filter=["PINACE"]` if you need taxonomic purity
+- Exclude it: Build an inclusion list of all codes except `"PINACE"` and pass to `species_filter`

 **Q: How do I know how many classes I have?**
 ```python
@@ -304,7 +309,7 @@ Expected accuracy ranges on NEON combined dataset (RGB only, ResNet50):
 ## Additional Resources

-- **Data inspection**: `python processing/misc/inspect_labels.py`
+- **Data inspection**: `python processing/misc/inspect_labels.py --csv_path path/to/labels.csv`
 - **Training examples**: `examples/train.py`
 - **Model architectures**: `docs/training.md`
 - **Data processing**: `docs/processing.md`

neon_tree_classification/core/datamodule.py

Lines changed: 2 additions & 2 deletions
@@ -756,8 +756,8 @@ def _create_genus_label_mapping(self) -> Dict[str, int]:
             warnings.warn(
                 f"Found non-alphabetic genus names: {non_alpha_genera}. "
                 f"These may be unidentified species or family names. "
-                f"Run 'python processing/misc/inspect_labels.py' to review. "
-                f"To exclude, use: species_filter=[...]"
+                f"Run 'python processing/misc/inspect_labels.py --csv_path <path>' to review. "
+                f"To include only specific species, use: species_filter=[...]"
             )

         # Check for known family names

neon_tree_classification/inference/model_registry.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
"num_classes": 167,
1717
"architecture": "resnet",
1818
"modality": "rgb",
19-
"input_size": (128, 128),
19+
"input_size": (224, 224),
20+
"norm_method": "imagenet",
2021
"accuracy": 75.88, # Test accuracy percentage
2122
"parameters": "11.2M",
2223
"url": None, # To be added when uploaded to HuggingFace
@@ -28,7 +29,8 @@
         "num_classes": 60,
         "architecture": "resnet",
         "modality": "rgb",
-        "input_size": (128, 128),
+        "input_size": (224, 224),
+        "norm_method": "imagenet",
         "accuracy": 72.24, # Test accuracy percentage
         "parameters": "11.2M",
         "url": None, # To be added when uploaded to HuggingFace

neon_tree_classification/inference/predictor.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,6 @@
88
import warnings
99
from pathlib import Path
1010
from typing import Union, List, Dict, Optional, Tuple
11-
import sys
12-
13-
# Add project root to path for imports
14-
project_root = Path(__file__).parent.parent.parent
15-
if str(project_root) not in sys.path:
16-
sys.path.insert(0, str(project_root))
1711

1812
from neon_tree_classification.models.rgb_models import create_rgb_model
1913
from .preprocessing import preprocess_image, preprocess_image_batch
@@ -61,7 +55,8 @@ def __init__(
         label_mapping: Dict,
         taxonomic_level: str,
         device: str = None,
-        input_size: Tuple[int, int] = (128, 128),
+        input_size: Tuple[int, int] = (224, 224),
+        norm_method: str = "imagenet",
     ):
         """
         Initialize tree classifier.
@@ -72,11 +67,13 @@ def __init__(
             taxonomic_level: 'species' or 'genus'
             device: Device for inference ('cpu', 'cuda', 'mps'). Auto-detected if None.
             input_size: Input image size (width, height)
+            norm_method: Normalization method ('imagenet' or '0_1')
         """
         self.model = model
         self.label_mapping = label_mapping
         self.taxonomic_level = taxonomic_level
         self.input_size = input_size
+        self.norm_method = norm_method

         # Auto-detect device if not specified
         if device is None:
@@ -176,7 +173,8 @@ def from_checkpoint(
             label_mapping=label_mapping,
             taxonomic_level=taxonomic_level,
             device=device,
-            input_size=(128, 128),
+            input_size=(224, 224),
+            norm_method="imagenet",
         )

     @classmethod
@@ -240,7 +238,7 @@ def predict(
             image_input,
             target_size=self.input_size,
             normalize=True,
-            norm_method="0_1",
+            norm_method=self.norm_method,
             return_tensor=True,
             add_batch_dim=True,
             device=self.device,
@@ -299,7 +297,7 @@ def predict_batch(
             batch,
             target_size=self.input_size,
             normalize=True,
-            norm_method="0_1",
+            norm_method=self.norm_method,
             device=self.device,
         )

neon_tree_classification/inference/preprocessing.py

Lines changed: 11 additions & 8 deletions
@@ -82,7 +82,7 @@ def load_image(
     )


-def resize_image(image: Image.Image, target_size: tuple = (128, 128)) -> Image.Image:
+def resize_image(image: Image.Image, target_size: tuple = (224, 224)) -> Image.Image:
     """
     Resize image to target size.
@@ -97,14 +97,17 @@ def resize_image(image: Image.Image, target_size: tuple = (128, 128)) -> Image.I
 def normalize_rgb(
-    image: Union[Image.Image, np.ndarray], method: str = "0_1"
+    image: Union[Image.Image, np.ndarray], method: str = "imagenet"
 ) -> np.ndarray:
     """
-    Normalize RGB image to 0-1 range.
+    Normalize RGB image.

     Args:
         image: PIL Image or numpy array (H, W, 3) in 0-255 range
-        method: Normalization method ('0_1' or 'imagenet')
+        method: Normalization method:
+            - '0_1': scales pixel values to [0, 1]
+            - 'imagenet': scales to [0, 1] then standardizes using
+              ImageNet mean/std (produces values outside [0, 1])

     Returns:
         Normalized numpy array (H, W, 3) as float32
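The two modes described in the updated docstring can be worked through for a single pixel. This is a sketch of the arithmetic only, written in plain Python rather than NumPy; the function name `normalize_pixel` is hypothetical, and the mean/std constants are the ImageNet values that appear elsewhere in this commit.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, method="imagenet"):
    """Normalize one (R, G, B) pixel given in the 0-255 range."""
    scaled = tuple(channel / 255.0 for channel in rgb)  # '0_1' step
    if method == "0_1":
        return scaled
    if method == "imagenet":
        # Standardize each channel: (x - mean) / std
        return tuple(
            (x - mean) / std
            for x, mean, std in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)
        )
    raise ValueError(f"Unknown normalization method: {method}")


white_01 = normalize_pixel((255, 255, 255), method="0_1")    # stays within [0, 1]
white_in = normalize_pixel((255, 255, 255), method="imagenet")  # every channel > 1
black_in = normalize_pixel((0, 0, 0), method="imagenet")        # every channel < 0
```

Because the two modes produce inputs on different scales, a model trained with one and served with the other sees a shifted distribution, which is presumably why the predictor now reads `self.norm_method` instead of a hardcoded value.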
@@ -161,9 +164,9 @@ def prepare_tensor(
 def preprocess_image(
     image_input: Union[str, Path, Image.Image, np.ndarray, torch.Tensor],
-    target_size: tuple = (128, 128),
+    target_size: tuple = (224, 224),
     normalize: bool = True,
-    norm_method: str = "0_1",
+    norm_method: str = "imagenet",
     return_tensor: bool = True,
     add_batch_dim: bool = True,
     device: str = "cpu",
@@ -221,9 +224,9 @@ def preprocess_image(
 # Convenience functions for batch processing
 def preprocess_image_batch(
     image_inputs: list,
-    target_size: tuple = (128, 128),
+    target_size: tuple = (224, 224),
     normalize: bool = True,
-    norm_method: str = "0_1",
+    norm_method: str = "imagenet",
     device: str = "cpu",
 ) -> torch.Tensor:
     """

neon_tree_classification/models/lightning_modules.py

Lines changed: 16 additions & 6 deletions
@@ -397,6 +397,7 @@ def __init__(
         class_weights: Optional[torch.Tensor] = None,
         log_images: bool = False,
         idx_to_label: Optional[Dict[int, str]] = None,
+        rgb_norm_method: str = "imagenet",
         **model_kwargs,
     ):
         """
@@ -413,6 +414,7 @@ def __init__(
             log_images: Whether to log sample images during validation
             idx_to_label: Optional label mapping {0: "Species1", 1: "Species2", ...}
                 for DeepForest CropModel compatibility
+            rgb_norm_method: Normalization method used during training ('imagenet' or '0_1')
             **model_kwargs: Additional arguments for model creation
         """
         # Create RGB model
@@ -432,6 +434,7 @@ def __init__(
         self.log_images = log_images
         self.logged_images_this_epoch = False
+        self.rgb_norm_method = rgb_norm_method

         # Set label_dict for DeepForest CropModel compatibility
         if idx_to_label is not None:
@@ -445,17 +448,24 @@ def _extract_modality_data(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor
         return batch["rgb"]

     def normalize(self):
-        """Return normalization transform for DeepForest CropModel compatibility.
+        """Return normalization transform matching the training configuration.

-        Returns ImageNet normalization transform as used in training.
-        This method is required for DeepForest CropModel integration.
+        Required for DeepForest CropModel integration. Returns a transform
+        consistent with the rgb_norm_method used during training.

         Returns:
             torchvision.transforms.Normalize object
         """
-        return transforms.Normalize(
-            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-        )
+        if self.rgb_norm_method == "imagenet":
+            return transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+            )
+        elif self.rgb_norm_method == "0_1":
+            # Scale to [0,1]: equivalent to dividing by 255 in ToTensor,
+            # represented as zero-mean, unit-std (no-op standardization)
+            return transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
+        else:
+            raise ValueError(f"Unknown rgb_norm_method: {self.rgb_norm_method}")

     def set_label_dict(self, idx_to_label: Dict[int, str]):
         """Set label dictionaries from idx_to_label mapping.
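The "no-op standardization" comment in `normalize()` can be checked with the per-channel formula that `transforms.Normalize` applies, `(x - mean) / std`. The sketch below reproduces that formula in plain Python so it runs without torchvision; the helper name `standardize` and the sample pixel are hypothetical.

```python
def standardize(values, mean, std):
    """Apply (x - mean) / std per channel."""
    return [(v - m) / s for v, m, s in zip(values, mean, std)]


pixel = [0.2, 0.5, 0.9]  # one pixel already scaled to [0, 1]

# '0_1' branch: zero mean and unit std leave the values untouched.
unchanged = standardize(pixel, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])

# 'imagenet' branch: values are shifted and rescaled.
shifted = standardize(pixel, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
```

Returning an identity `Normalize` for `'0_1'` keeps the return type the same in both branches, so callers that expect a transform object do not need a special case.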

scripts/create_label_mappings.py

Lines changed: 1 addition & 2 deletions
@@ -6,13 +6,12 @@
 from the training CSV and saves them as JSON files for use in inference.

 Usage:
-    python scripts/create_label_mappings.py
+    python scripts/create_label_mappings.py --csv_path path/to/labels.csv
 """

 import json
 import pandas as pd
 from pathlib import Path
-from collections import OrderedDict
 import sys

 # Add project root to path

scripts/upload_to_huggingface.py

Lines changed: 7 additions & 0 deletions
@@ -235,6 +235,13 @@ def upload_to_huggingface(
         )
         sys.exit(1)

+    if not checkpoint_data["numeric_to_label_dict"]:
+        print(
+            "❌ Checkpoint missing numeric_to_label_dict! "
+            "Was the model trained with idx_to_label?"
+        )
+        sys.exit(1)
+
     num_classes = len(checkpoint_data["label_dict"])
     print(f"✅ Found {num_classes} classes in label_dict")
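One possible follow-up for reviewers: if `checkpoint_data` is a plain dict and the key can be absent altogether, the subscript in the new check would raise `KeyError` before the friendly message is printed. Whether that can happen depends on code outside this hunk. The sketch below shows a variant using `dict.get`; the function name `count_classes` is hypothetical, and it raises `ValueError` instead of calling `sys.exit` only so that it is easy to test.

```python
def count_classes(checkpoint_data: dict) -> int:
    """Validate both label dictionaries and return the number of classes."""
    # .get() treats a missing key and an empty dict the same way.
    if not checkpoint_data.get("numeric_to_label_dict"):
        raise ValueError(
            "Checkpoint missing numeric_to_label_dict! "
            "Was the model trained with idx_to_label?"
        )
    label_dict = checkpoint_data.get("label_dict")
    if not label_dict:
        raise ValueError("Checkpoint missing label_dict!")
    return len(label_dict)


# Hypothetical checkpoint contents for illustration only.
example = {
    "label_dict": {"PSMEM": 0, "TSHE": 1},
    "numeric_to_label_dict": {0: "PSMEM", 1: "TSHE"},
}
num_classes = count_classes(example)
```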
