Commit 07a655b

feat: Add ViT models, HSI Hang2020, and prepare for DeepForest integration
BREAKING CHANGES:
- Default RGB image size changed from 128x128 to 224x224
- Default RGB normalization changed from 0_1 to imagenet

For backward compatibility, explicitly pass rgb_size=(128, 128) and rgb_norm_method='0_1'

New Features:
- Add Vision Transformer (ViT) support: vit_b_16, vit_b_32, vit_l_16, vit_l_32
- Implement Hang2020 dual-pathway attention architecture for HSI classification
- Add model_variant parameter to training script for architecture selection
- Add preliminary DeepForest CropModel compatibility methods (WIP):
  * normalize() method for transforms
  * label_dict persistence in checkpoints
  * set_label_dict() and get_label_dict() helpers
- Add HuggingFace upload script (experimental, needs further testing)
- Add multi-output training support with auxiliary losses (Hang2020)

Improvements:
- Better experiment naming to prevent collisions in SLURM array jobs
- Enhanced test logging with detailed statistics
- Add rgb_size and rgb_norm_method CLI arguments for flexibility
- Update README with project roadmap

Note: Full DeepForest CropModel integration and HuggingFace loading are still in progress and may require additional work.

Files changed: 10 files
- Added: scripts/upload_to_huggingface.py, sample_plots/test_PSMEM_douglas_fir.png
- Modified: train.py, rgb_models.py, hsi_models.py, lightning_modules.py, dataset.py, datamodule.py, README.md, visualization.ipynb
1 parent 7f2f046 commit 07a655b

10 files changed

Lines changed: 1128 additions & 113 deletions

README.md

Lines changed: 2 additions & 2 deletions
@@ -6,9 +6,9 @@ A comprehensive toolkit for multi-modal tree species classification using NEON e
 
 This repository aims to provide an end-to-end solution for tree species classification:
 
-- [x] **Dataset**: Ready-to-use multi-modal tree crown dataset with 167 species
+- [x] **Dataset**: Ready-to-use multi-modal tree crown dataset with 167 species. It's curated using the code in preprocessing directory in this repo.
 - [ ] **Data Processing**: Tools for downloading and processing raw NEON data products
-- [ ] **Classification Models**: Pre-trained models and training pipelines
+- [ ] **Classification Models**: Pre-trained models and training pipelines (Ongoing. ETA End of Feb 2026)
 - [ ] **DeepForest Integration**: Automated crown detection and classification workflow
 
 ## What's Available Now

examples/train.py

Lines changed: 35 additions & 1 deletion
@@ -181,6 +181,12 @@ def main():
     parser.add_argument(
         "--model_type", type=str, default="simple", help="Model architecture type"
     )
+    parser.add_argument(
+        "--model_variant",
+        type=str,
+        default=None,
+        help="Model variant (e.g., 'vit_b_16', 'vit_l_16' for ViT models)",
+    )
     parser.add_argument(
         "--num_classes",
         type=int,
@@ -237,6 +243,23 @@ def main():
         action="store_true",
         help="Use WeightedRandomSampler for balanced class sampling (recommended for imbalanced datasets)",
     )
+
+    # Image size arguments
+    parser.add_argument(
+        "--rgb_size",
+        type=int,
+        default=224,
+        help="RGB image size (single value for square images, e.g., 224 for 224x224). Default matches ImageNet pretraining.",
+    )
+
+    # Normalization arguments
+    parser.add_argument(
+        "--rgb_norm_method",
+        type=str,
+        default="imagenet",
+        choices=["none", "0_1", "imagenet"],
+        help="RGB normalization method: 'imagenet' (recommended for pretrained models), '0_1' (simple [0,1] range), 'none'",
+    )
 
     # Reproducibility arguments
     parser.add_argument(
@@ -293,8 +316,10 @@ def main():
     worker_init_fn.base_seed = args.seed
 
     # Set up experiment name (auto-generate)
+    # Include model_variant and taxonomic_level to avoid collisions in array jobs
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    experiment_name = f"{args.modality}_{args.model_type}_{args.batch_size}_{timestamp}"
+    model_name = args.model_variant if args.model_variant else args.model_type
+    experiment_name = f"{args.modality}_{model_name}_{args.taxonomic_level}_{timestamp}"
 
     # Set up output directory with dynamic naming within provided path
     if args.output_dir is None:
@@ -324,6 +349,8 @@ def main():
         external_test_csv_path=args.external_test_csv,  # External test support
         external_test_hdf5_path=args.external_test_hdf5,  # External test support
         modalities=[args.modality],
+        rgb_size=(args.rgb_size, args.rgb_size),  # Image size for RGB
+        rgb_norm_method=args.rgb_norm_method,  # Normalization for RGB (imagenet for pretrained models)
         taxonomic_level=args.taxonomic_level,  # Species or genus level
         use_balanced_sampler=args.use_balanced_sampler,  # Balanced sampling
         split_method=args.split_method,
@@ -381,6 +408,11 @@ def main():
 
     # Create classifier based on modality
     if args.modality == "rgb":
+        # Prepare model kwargs
+        model_kwargs = {}
+        if args.model_variant is not None:
+            model_kwargs["model_variant"] = args.model_variant
+
         classifier = RGBClassifier(
             model_type=args.model_type,
             num_classes=args.num_classes,
@@ -389,6 +421,8 @@ def main():
             scheduler=args.scheduler,
             weight_decay=args.weight_decay,
             log_images=True,  # Enable image logging for RGB
+            idx_to_label=datamodule.full_dataset.idx_to_label,  # For DeepForest CropModel compatibility
+            **model_kwargs,  # Pass model variant for ViT and other models
         )
     elif args.modality == "hsi":
         classifier = HSIClassifier(
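
The CLI changes in this file can be exercised in isolation. The following is a minimal sketch, not the repository's code: it rebuilds only the arguments visible in this diff (the defaults for `--modality` and `--taxonomic_level` are assumptions, since their definitions fall outside the hunks), and `build_experiment_name` and `rgb_size_tuple` are hypothetical helpers wrapping the logic shown above. The optional `task_id` suffix is also an assumption added for illustration: the timestamp has one-second resolution, so two array tasks with identical settings that start in the same second would still produce the same name, and Slurm's `SLURM_ARRAY_TASK_ID` environment variable is one way to tell them apart.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Defaults for modality and taxonomic_level are assumed for this sketch.
    parser.add_argument("--modality", type=str, default="rgb")
    parser.add_argument("--taxonomic_level", type=str, default="species")
    parser.add_argument("--model_type", type=str, default="simple")
    parser.add_argument("--model_variant", type=str, default=None)
    parser.add_argument("--rgb_size", type=int, default=224)
    parser.add_argument(
        "--rgb_norm_method",
        type=str,
        default="imagenet",
        choices=["none", "0_1", "imagenet"],
    )
    return parser


def build_experiment_name(args, timestamp, task_id=None):
    """Hypothetical helper mirroring the naming logic in the diff."""
    model_name = args.model_variant if args.model_variant else args.model_type
    name = f"{args.modality}_{model_name}_{args.taxonomic_level}_{timestamp}"
    # Assumed extension: append the array task ID when one is supplied.
    if task_id is not None:
        name = f"{name}_task{task_id}"
    return name


def rgb_size_tuple(args):
    """Expand the single --rgb_size integer into a (height, width) tuple."""
    return (args.rgb_size, args.rgb_size)


# Parse an explicit argument list so the sketch does not depend on sys.argv.
demo_args = build_parser().parse_args(["--model_variant", "vit_b_16"])
print(build_experiment_name(demo_args, "20260101_120000",
                            os.environ.get("SLURM_ARRAY_TASK_ID")))
print(rgb_size_tuple(demo_args))
```

Passing `--rgb_size 128 --rgb_norm_method 0_1` to the same parser yields the pre-commit settings named in the commit message.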

neon_tree_classification/core/datamodule.py

Lines changed: 2 additions & 2 deletions
@@ -60,13 +60,13 @@ def __init__(
         species_filter: Optional[List[str]] = None,
         site_filter: Optional[List[str]] = None,
         year_filter: Optional[List[int]] = None,
-        rgb_size: Tuple[int, int] = (128, 128),
+        rgb_size: Tuple[int, int] = (224, 224),  # Matches ImageNet pretraining
         hsi_size: Tuple[int, int] = (12, 12),
         lidar_size: Tuple[int, int] = (12, 12),
         rgb_resize_mode: str = "nearest",
         hsi_resize_mode: str = "nearest",
         lidar_resize_mode: str = "nearest",
-        rgb_norm_method: str = "0_1",
+        rgb_norm_method: str = "imagenet",  # ImageNet normalization for pretrained models
         hsi_norm_method: str = "per_sample",
         lidar_norm_method: str = "height",
         custom_transforms: Optional[Dict[str, Callable]] = None,
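
Because both defaults change here, code that constructs the datamodule without RGB arguments will now receive 224x224 ImageNet-normalized crops. A minimal sketch of pinning the old behavior follows, assuming a stand-in function in place of the real constructor (`make_rgb_settings` and `LEGACY_RGB_KWARGS` are hypothetical names; only the two parameter names and their old and new defaults come from the diff).

```python
from typing import Dict, Tuple, Union

RGBSettings = Dict[str, Union[Tuple[int, int], str]]

# Pre-commit defaults, as shown on the removed lines of this diff.
LEGACY_RGB_KWARGS: RGBSettings = {
    "rgb_size": (128, 128),
    "rgb_norm_method": "0_1",
}


def make_rgb_settings(
    rgb_size: Tuple[int, int] = (224, 224),
    rgb_norm_method: str = "imagenet",
) -> RGBSettings:
    """Stand-in for the constructor: returns the settings it would receive."""
    return {"rgb_size": rgb_size, "rgb_norm_method": rgb_norm_method}


new_behavior = make_rgb_settings()                     # post-commit defaults
old_behavior = make_rgb_settings(**LEGACY_RGB_KWARGS)  # explicitly pinned
print(new_behavior)
print(old_behavior)
```

The same `**LEGACY_RGB_KWARGS` unpacking should work against the real constructor, since it accepts both keyword arguments. Checkpoints trained with the old defaults would presumably need the legacy settings at evaluation time as well.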

neon_tree_classification/core/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -44,15 +44,15 @@ def __init__(
         site_filter: Optional[List[str]] = None,
         year_filter: Optional[List[int]] = None,
         # Target sizes for training (required for consistent batching)
-        rgb_size: Tuple[int, int] = (128, 128),
+        rgb_size: Tuple[int, int] = (224, 224),  # Matches ImageNet pretraining
         hsi_size: Tuple[int, int] = (12, 12),
         lidar_size: Tuple[int, int] = (12, 12),
         # Resize methods (optimized for speed)
         rgb_resize_mode: str = "nearest",  # Fastest for RGB images
         hsi_resize_mode: str = "nearest",  # Changed to nearest for speed
         lidar_resize_mode: str = "nearest",  # Changed to nearest for speed
         # Normalization methods (performance-first defaults)
-        rgb_norm_method: str = "0_1",  # Simple division, fastest
+        rgb_norm_method: str = "imagenet",  # ImageNet normalization for pretrained models
         hsi_norm_method: str = "per_sample",  # Per-sample z-score, faster than per_pixel
         lidar_norm_method: str = "height",  # Simple max scaling, fastest
         # Custom transforms (optional, per-modality)
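
The practical difference between the normalization modes can be shown on a single pixel. This is a sketch under assumptions: the repository's actual implementation is not part of this diff, so the code below assumes the common convention of scaling 8-bit values to [0, 1] and then, for `imagenet`, standardizing each channel with the widely used ImageNet mean and standard deviation. The behavior of `none` (values passed through unchanged) is likewise assumed, and `normalize_pixel` is a hypothetical name.

```python
from typing import Sequence, Tuple

# Per-channel statistics commonly used with ImageNet-pretrained models.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: Sequence[int], method: str) -> Tuple[float, ...]:
    """Normalize one 8-bit RGB pixel (assumed behavior, not the repo's code)."""
    if method == "none":
        return tuple(float(v) for v in rgb)
    scaled = tuple(v / 255.0 for v in rgb)
    if method == "0_1":
        return scaled
    if method == "imagenet":
        return tuple(
            (s - m) / sd for s, m, sd in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)
        )
    raise ValueError(f"Unknown normalization method: {method}")


pixel = (128, 64, 255)
print(normalize_pixel(pixel, "0_1"))
print(normalize_pixel(pixel, "imagenet"))
```

Under this convention `0_1` outputs stay within [0, 1], while `imagenet` outputs are roughly zero-centered and can be negative, so a model trained with one mode should be evaluated with the same mode.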
