# Advanced Usage

This guide covers advanced features for experienced users who need custom data filtering, specialized training configurations, or want to use the PyTorch Lightning DataModule directly.

## Custom Data Filtering with Lightning DataModule

The `NeonCrownDataModule` provides flexible filtering and splitting options for advanced use cases.

### Basic Configuration

```python
from neon_tree_classification.core.datamodule import NeonCrownDataModule

# Basic configuration with species/site filtering
datamodule = NeonCrownDataModule(
    csv_path="_neon_tree_classification_dataset_files/metadata/combined_dataset.csv",
    hdf5_path="_neon_tree_classification_dataset_files/neon_dataset.h5",
    modalities=["rgb"],  # Single modality training
    batch_size=32,
    # Filtering options
    species_filter=["PSMEM", "TSHE"],  # Train on specific species
    site_filter=["HARV", "OSBS"],      # Train on specific sites
    year_filter=[2018, 2019],          # Train on specific years
    # Split method options
    split_method="random",  # Options: "random", "site", "year"
    val_ratio=0.15,
    test_ratio=0.15
)

datamodule.setup("fit")
```
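
After `setup("fit")`, the standard `LightningDataModule` accessors return ready-to-use dataloaders. A minimal sketch of pulling one batch (the `rgb` and `species_idx` keys follow the batch format shown in the metadata section below):

```python
# Standard LightningDataModule accessors, available after setup("fit")
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()

# Inspect one batch to confirm shapes; keys follow the configured modalities
batch = next(iter(train_loader))
print(batch["rgb"].shape, batch["species_idx"].shape)
```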

### Split Methods

The DataModule supports three splitting strategies:

**1. Random Split** (default)
```python
datamodule = NeonCrownDataModule(
    csv_path="path/to/dataset.csv",
    hdf5_path="path/to/dataset.h5",
    split_method="random",
    val_ratio=0.15,
    test_ratio=0.15
)
```

**2. Site-Based Split**

Useful for testing generalization across geographic locations:
```python
datamodule = NeonCrownDataModule(
    csv_path="path/to/dataset.csv",
    hdf5_path="path/to/dataset.h5",
    split_method="site",
    val_ratio=0.15,
    test_ratio=0.15
)
```

**3. Year-Based Split**

Useful for testing temporal generalization:
```python
datamodule = NeonCrownDataModule(
    csv_path="path/to/dataset.csv",
    hdf5_path="path/to/dataset.h5",
    split_method="year",
    val_ratio=0.15,
    test_ratio=0.15
)
```

### External Test Sets

For domain adaptation or cross-site validation:

```python
datamodule = NeonCrownDataModule(
    csv_path="_neon_tree_classification_dataset_files/metadata/combined_dataset.csv",
    hdf5_path="_neon_tree_classification_dataset_files/neon_dataset.h5",
    external_test_csv_path="path/to/external_test.csv",
    external_test_hdf5_path="path/to/external_test.h5",  # Optional, uses main HDF5 if not provided
    modalities=["rgb"]
)

datamodule.setup("fit")  # Auto-filters species for compatibility
```
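
Before training against an external test set, it can be worth checking the species overlap yourself. A hypothetical sketch using pandas, assuming both metadata CSVs carry a `species` column (consistent with the `species_filter` option above):

```python
import pandas as pd

# Hypothetical overlap check between the main and external metadata CSVs
main_df = pd.read_csv("_neon_tree_classification_dataset_files/metadata/combined_dataset.csv")
ext_df = pd.read_csv("path/to/external_test.csv")

# Species present in both sets; anything outside this intersection
# is what the automatic filtering would have to drop
shared = sorted(set(main_df["species"]) & set(ext_df["species"]))
print(f"{len(shared)} shared species:", shared)
```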

## Advanced DataLoader Configuration

### Custom Normalization

Each modality supports different normalization methods:

**RGB Normalization:**
- `"0_1"`: Scale to [0, 1] range (default)
- `"imagenet"`: ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- `"per_sample"`: Normalize each sample independently

**HSI Normalization:**
- `"per_sample"`: Normalize each sample independently (default)
- `"global"`: Use global dataset statistics
- `"none"`: No normalization

**LiDAR Normalization:**
- `"height"`: Normalize by maximum canopy height (default)
- `"per_sample"`: Normalize each sample independently
- `"none"`: No normalization

Example:
```python
train_loader, test_loader = get_dataloaders(
    config='large',
    modalities=['rgb', 'hsi', 'lidar'],
    batch_size=32,
    rgb_norm_method='imagenet',
    hsi_norm_method='global',
    lidar_norm_method='height'
)
```
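
For intuition, per-sample normalization standardizes each image by its own statistics rather than dataset-wide ones. A minimal sketch of the idea (illustrative only, not necessarily the library's exact implementation):

```python
import torch

def per_sample_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Standardize one sample (C, H, W) by its own mean and std."""
    return (x - x.mean()) / (x.std() + eps)

sample = torch.rand(3, 128, 128)  # e.g. a single RGB crown
normed = per_sample_normalize(sample)
print(normed.mean().item(), normed.std().item())  # ~0 and ~1
```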

### Custom Image Sizes

Adjust the spatial resolution for each modality:

```python
train_loader, test_loader = get_dataloaders(
    config='large',
    modalities=['rgb', 'hsi', 'lidar'],
    batch_size=32,
    rgb_size=(224, 224),  # Larger RGB for fine-grained features
    hsi_size=(16, 16),    # Higher HSI resolution
    lidar_size=(16, 16)   # Higher LiDAR resolution
)
```

## Direct Dataset Usage

For maximum control, use the `NeonCrownDataset` class directly:

```python
from neon_tree_classification.core.dataset import NeonCrownDataset
from torch.utils.data import DataLoader

# Create dataset with custom parameters
dataset = NeonCrownDataset(
    csv_path="_neon_tree_classification_dataset_files/metadata/large_dataset.csv",
    hdf5_path="_neon_tree_classification_dataset_files/neon_dataset.h5",
    modalities=['rgb', 'hsi'],
    species_filter=['ACRU', 'TSCA'],  # Limit to specific species
    site_filter=['HARV', 'MLBS'],     # Limit to specific sites
    year_filter=[2018, 2019, 2020],   # Limit to specific years
    include_metadata=True,  # Include crown_id, species names, etc.
    rgb_size=(128, 128),
    hsi_size=(12, 12),
    rgb_norm_method='imagenet',
    hsi_norm_method='per_sample'
)

# Create custom DataLoader
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,
    pin_memory=True
)
```

## Accessing Metadata

Enable metadata in batches to access crown IDs, species names, and site information. Note that `get_dataloaders` doesn't support `include_metadata` yet, so use `NeonCrownDataset` directly:

```python
from torch.utils.data import DataLoader

from neon_tree_classification.core.dataset import NeonCrownDataset

dataset = NeonCrownDataset(
    csv_path="path/to/dataset.csv",
    hdf5_path="path/to/dataset.h5",
    modalities=['rgb'],
    include_metadata=True
)

# Access metadata in batches
for batch in DataLoader(dataset, batch_size=32):
    rgb = batch['rgb']
    labels = batch['species_idx']
    crown_ids = batch['crown_id']
    species_names = batch['species']
    sites = batch['site']
```
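
Metadata fields make per-group analysis straightforward. As a small example, the sketch below tallies how many samples each site contributes, reusing the `dataset` defined above (and assuming string fields collate to sequences of strings, which is PyTorch's default behavior):

```python
from collections import Counter

# Count samples per site across the whole dataset
site_counts = Counter()
for batch in DataLoader(dataset, batch_size=32):
    site_counts.update(batch['site'])

print(site_counts.most_common())
```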

## Multi-GPU Training

For distributed training with PyTorch Lightning:

```python
import pytorch_lightning as pl
from neon_tree_classification.core.datamodule import NeonCrownDataModule

# Configure DataModule
datamodule = NeonCrownDataModule(
    csv_path="path/to/dataset.csv",
    hdf5_path="path/to/dataset.h5",
    modalities=["rgb"],
    batch_size=32  # Per-GPU batch size
)

# Create trainer with multi-GPU support
trainer = pl.Trainer(
    devices=4,        # Number of GPUs
    strategy='ddp',   # Distributed Data Parallel
    precision=16,     # Mixed precision training
    max_epochs=100
)

# model is your own LightningModule (see the sketch below)
trainer.fit(model, datamodule=datamodule)
```
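
The `model` above is any `LightningModule`. A minimal hypothetical sketch, just to make the example self-contained (the class name and architecture are placeholders, not part of this package):

```python
import torch
import pytorch_lightning as pl


class SimpleRGBClassifier(pl.LightningModule):
    """Hypothetical baseline: flatten RGB crowns and classify with an MLP."""

    def __init__(self, num_classes: int, rgb_size=(128, 128)):
        super().__init__()
        in_features = 3 * rgb_size[0] * rgb_size[1]
        self.net = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(in_features, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, num_classes),
        )
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        logits = self(batch["rgb"])
        return self.loss(logits, batch["species_idx"])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


model = SimpleRGBClassifier(num_classes=10)  # set to the number of species in your split
```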

## Custom Training Loop

Example of a custom training loop without PyTorch Lightning:

```python
import torch
from scripts.get_dataloaders import get_dataloaders

# Get dataloaders
train_loader, test_loader = get_dataloaders(
    config='large',
    modalities=['rgb'],
    batch_size=64
)

# Your model (any torch.nn.Module mapping RGB images to class logits)
model = YourModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

# Training loop
for epoch in range(100):
    model.train()
    for batch in train_loader:
        rgb = batch['rgb'].cuda()
        labels = batch['species_idx'].cuda()

        optimizer.zero_grad()
        outputs = model(rgb)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            rgb = batch['rgb'].cuda()
            labels = batch['species_idx'].cuda()
            outputs = model(rgb)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100. * correct / total
    print(f'Epoch {epoch}: Accuracy = {accuracy:.2f}%')
```

## Performance Tips

1. **Use larger batch sizes**: HDF5 compression keeps the dataset compact, so larger batches fit comfortably in memory
2. **Increase num_workers**: More workers can significantly speed up data loading
3. **Enable pin_memory**: Speeds up CPU-to-GPU transfer
4. **Use persistent_workers**: Reduces worker initialization overhead (tips 3 and 4 are shown in the DataLoader sketch below)

```python
train_loader, test_loader = get_dataloaders(
    config='large',
    modalities=['rgb'],
    batch_size=256,  # Larger batch size
    num_workers=16,  # More workers (adjust based on CPU cores)
)
```
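
If `get_dataloaders` doesn't expose these options in your version, tips 3 and 4 can be applied when building a `DataLoader` yourself. A sketch using the standard PyTorch arguments, assuming a `NeonCrownDataset` as in the sections above:

```python
from torch.utils.data import DataLoader

# Standard PyTorch DataLoader knobs from the tips above
train_loader = DataLoader(
    dataset,                  # a NeonCrownDataset instance
    batch_size=256,
    shuffle=True,
    num_workers=16,           # adjust based on CPU cores
    pin_memory=True,          # faster CPU-to-GPU transfer
    persistent_workers=True,  # keep workers alive between epochs
)
```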