Skip to content

Commit dd154df

Browse files
committed
optimize training pipeline and fix bugs in training code
- Update Comet logger to auto-detect config from .comet.config file
- Add setup guard to prevent duplicate DataModule initialization
- Implement timestamp-based output directory organization
- Fix class weights calculation to use species indices correctly
- Enhance logging with training progress indicators

Performance improvements:
- Eliminate duplicate setup() calls during training initialization
- Reduce redundant dataset setup logging
1 parent 3a16bd5 commit dd154df

8 files changed

Lines changed: 479 additions & 1191 deletions

File tree

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ lightning_logs/
1515
results_temp_dir/
1616
.comet.config
1717

18+
# Training outputs
19+
outputs/
20+
test_outputs*/
21+
*.ckpt
22+
1823
# Python packaging
1924
*.egg-info/
2025
build/

examples/train.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"""
1515

1616
import argparse
17+
import os
18+
from datetime import datetime
1719
import lightning as L
1820
from lightning.pytorch.callbacks import (
1921
ModelCheckpoint,
@@ -23,6 +25,10 @@
2325
from lightning.pytorch.loggers import TensorBoardLogger
2426
import torch
2527

28+
# Optimize CUDA performance for Tensor Cores
29+
if torch.cuda.is_available():
30+
torch.set_float32_matmul_precision("medium")
31+
2632
try:
2733
from lightning.pytorch.loggers import CometLogger
2834

@@ -39,6 +45,7 @@
3945

4046

4147
def main():
48+
4249
parser = argparse.ArgumentParser(description="Train NEON tree species classifier")
4350

4451
# Data arguments
@@ -100,29 +107,39 @@ def main():
100107
parser.add_argument(
101108
"--num_workers", type=int, default=4, help="Number of data loader workers"
102109
)
110+
parser.add_argument(
111+
"--distributed", action="store_true", help="Enable distributed training"
112+
)
103113

104114
# Logging arguments
105115
parser.add_argument(
106116
"--logger", type=str, default="tensorboard", choices=["tensorboard", "comet"]
107117
)
108-
parser.add_argument("--project_name", type=str, default="neon-tree-classification")
109118
parser.add_argument(
110-
"--experiment_name",
119+
"--output_dir",
111120
type=str,
112-
help="Experiment name (auto-generated if not provided)",
121+
help="Directory to save logs, checkpoints, and results (auto-generated if not provided)",
113122
)
114123

115124
args = parser.parse_args()
116125

117-
# Set up experiment name
118-
if args.experiment_name is None:
119-
args.experiment_name = (
120-
f"{args.modality}_{args.model_type}_{args.lr}_{args.batch_size}"
121-
)
126+
# Set up experiment name (auto-generate)
127+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
128+
experiment_name = (
129+
f"{args.modality}_{args.model_type}_{args.lr}_{args.batch_size}_{timestamp}"
130+
)
131+
132+
# Set up output directory (organize by modality and timestamp)
133+
if args.output_dir is None:
134+
args.output_dir = f"./outputs/{args.modality}_{timestamp}"
122135

123136
print(f"🌲 Training {args.modality.upper()} classifier: {args.model_type}")
124137
print(f"📁 Data: {args.data_dir}")
125-
print(f"🧪 Experiment: {args.experiment_name}")
138+
print(f"🧪 Experiment: {experiment_name}")
139+
print(f"💾 Output directory: {args.output_dir}")
140+
141+
# Create output directory if it doesn't exist
142+
os.makedirs(args.output_dir, exist_ok=True)
126143

127144
# Create data module
128145
datamodule = NeonCrownDataModule(
@@ -192,12 +209,10 @@ def main():
192209
"CometML not available. Install with: pip install comet-ml"
193210
)
194211
logger = CometLogger(
195-
project_name=args.project_name,
196-
experiment_name=args.experiment_name,
197-
save_dir="lightning_logs",
212+
save_dir=args.output_dir,
198213
)
199214
else:
200-
logger = TensorBoardLogger(save_dir="lightning_logs", name=args.experiment_name)
215+
logger = TensorBoardLogger(save_dir=args.output_dir, name=experiment_name)
201216

202217
# Set up callbacks
203218
callbacks = [
@@ -228,6 +243,7 @@ def main():
228243
# datamodule.setup() # Already called above for class detection
229244

230245
# Get class weights for imbalanced datasets
246+
print("⚖️ Calculating class weights...")
231247
class_weights = datamodule.get_class_weights()
232248
if class_weights is not None:
233249
classifier.class_weights = class_weights

neon_tree_classification/core/datamodule.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def __init__(
142142
self.label_to_idx = None
143143
self.idx_to_label = None
144144
self.num_classes = None
145+
self._setup_done = False # Guard to prevent duplicate setup
145146

146147
def _create_default_transforms(self) -> Dict[str, Callable]:
147148
"""Create default transform functions."""
@@ -158,6 +159,11 @@ def setup(self, stage: Optional[str] = None) -> None:
158159
Args:
159160
stage: 'fit', 'validate', 'test', or 'predict'
160161
"""
162+
# Guard against duplicate setup
163+
if self._setup_done:
164+
print("⚡ DataModule already set up, skipping duplicate setup")
165+
return
166+
161167
if stage is None or stage in ["fit", "validate"]:
162168
# Create full dataset to analyze splits
163169
full_dataset = NeonCrownDataset(
@@ -209,6 +215,9 @@ def setup(self, stage: Optional[str] = None) -> None:
209215
print(f" Test samples: {len(self.test_dataset)}")
210216
print(f" Num classes: {self.num_classes}")
211217

218+
# Mark setup as complete
219+
self._setup_done = True
220+
212221
def _create_label_mapping(self, dataset: NeonCrownDataset) -> None:
213222
"""Create mapping between string labels and integer indices."""
214223
species_list = dataset.get_species_list()
@@ -392,20 +401,31 @@ def get_class_weights(self) -> torch.Tensor:
392401
if self.train_dataset is None:
393402
raise RuntimeError("Must call setup() before getting class weights")
394403

395-
# Count samples per class in training set
404+
print("🔄 Calculating class weights...")
405+
406+
# More efficient: count from the original data instead of loading samples
407+
# Get species from the training split's underlying data
396408
species_counts = {}
397-
for i in range(len(self.train_dataset)):
398-
sample = self.train_dataset[i]
399-
species = sample["species"]
400-
species_counts[species] = species_counts.get(species, 0) + 1
409+
410+
# Access the pandas DataFrame directly from train dataset
411+
train_data = self.train_dataset.data
412+
for _, row in train_data.iterrows():
413+
species = row["species"]
414+
# Convert to index using our mapping
415+
species_idx = self.label_to_idx[species]
416+
species_counts[species_idx] = species_counts.get(species_idx, 0) + 1
417+
418+
print(f"📊 Found {len(species_counts)} classes in training set")
401419

402420
# Calculate weights (inverse frequency)
403421
total_samples = sum(species_counts.values())
404422
weights = []
405-
for species in sorted(species_counts.keys()):
406-
weight = total_samples / (len(species_counts) * species_counts[species])
423+
# Sort by species index to maintain consistent ordering
424+
for species_idx in sorted(species_counts.keys()):
425+
weight = total_samples / (len(species_counts) * species_counts[species_idx])
407426
weights.append(weight)
408427

428+
print(f"✅ Class weights calculated successfully")
409429
return torch.tensor(weights, dtype=torch.float32)
410430

411431
def get_species_mapping(self) -> Dict[str, int]:

neon_tree_classification/models/lightning_modules.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,7 @@ def configure_optimizers(self):
183183

184184
# Scheduler
185185
if self.hparams.scheduler == "plateau":
186-
scheduler = ReduceLROnPlateau(
187-
optimizer, mode="min", factor=0.5, patience=5, verbose=True
188-
)
186+
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
189187
return {
190188
"optimizer": optimizer,
191189
"lr_scheduler": {

0 commit comments

Comments (0)