2 changes: 1 addition & 1 deletion mipcandy/common/optim/loss.py
@@ -89,7 +89,7 @@ def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1,

     def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
         outputs = outputs.sigmoid()
-        labels = labels.float()
+        labels = labels.to(dtype=outputs.dtype)
         bce = nn.functional.binary_cross_entropy(outputs, labels)
         dice = soft_dice(outputs, labels, smooth=self.smooth)
         metrics = {"soft dice": dice.item(), "bce loss": bce.item()}
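Note on the cast above: hard-coding labels.float() would pin the labels to float32, while a forward pass run under torch.amp.autocast can hand back float16 outputs, so following outputs.dtype keeps both operands of the BCE consistent. A minimal sketch of the behavior (shapes and values are illustrative, not from mipcandy):

import torch

# Under autocast the model's outputs may be float16, while integer masks
# from a dataset default to int64; the cast tracks whatever dtype arrives.
outputs = torch.rand(2, 1, 4, 4).to(torch.float16)  # as if produced under autocast
labels = torch.randint(0, 2, (2, 1, 4, 4))          # integer segmentation mask
labels = labels.to(dtype=outputs.dtype)             # follows the output dtype
assert labels.dtype == outputs.dtype                # holds for fp32 and fp16 alike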
4 changes: 3 additions & 1 deletion mipcandy/presets/segmentation.py
@@ -121,7 +121,9 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT
         outputs = list(torch.unbind(outputs, dim=1))
         labels = self.prepare_deep_supervision_targets(labels, [m.shape[2:] for m in outputs])
         loss, metrics = toolbox.criterion(outputs, labels)
-        loss.backward()
+        self._do_backward(loss, toolbox)
+        if toolbox.scaler:
+            toolbox.scaler.unscale_(toolbox.optimizer)
         nn.utils.clip_grad_norm_(toolbox.model.parameters(), 12)
         return loss.item(), metrics

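Ordering note: gradients produced by a scaled backward pass are multiplied by the loss scale, so clipping them directly would apply a threshold that drifts with the scale. Calling unscale_ first restores true gradient magnitudes, and GradScaler.step later recognizes that unscaling already happened and does not unscale twice. A self-contained sketch of the pattern with a throwaway model (all names illustrative, assuming a torch version where torch.amp.GradScaler takes a device-type string):

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
use_cuda = torch.cuda.is_available()
scaler = torch.amp.GradScaler("cuda" if use_cuda else "cpu")

loss = model(torch.randn(2, 4)).mean()
scaler.scale(loss).backward()                     # gradients are scaled here
scaler.unscale_(optimizer)                        # back to true magnitudes
nn.utils.clip_grad_norm_(model.parameters(), 12)  # clip against the real norm
scaler.step(optimizer)                            # no second unscale inside step
scaler.update()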
43 changes: 35 additions & 8 deletions mipcandy/training.py
@@ -50,6 +50,7 @@ class TrainerToolbox(object):
     scheduler: optim.lr_scheduler.LRScheduler
     criterion: nn.Module
     ema: nn.Module | None = None
+    scaler: torch.amp.GradScaler | None = None


@dataclass
@@ -85,11 +86,14 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: Trainer
                                      **training_arguments) -> None:
         if self._unrecoverable:
             return
-        torch.save({
+        state_dicts = {
             "optimizer": toolbox.optimizer.state_dict(),
             "scheduler": toolbox.scheduler.state_dict(),
             "criterion": toolbox.criterion.state_dict()
-        }, f"{self.experiment_folder()}/state_dicts.pth")
+        }
+        if toolbox.scaler:
+            state_dicts["scaler"] = toolbox.scaler.state_dict()
+        torch.save(state_dicts, f"{self.experiment_folder()}/state_dicts.pth")
         with open(f"{self.experiment_folder()}/state_orb.json", "w") as f:
             dump({"tracker": asdict(tracker), "training_arguments": training_arguments}, f)

@@ -116,6 +120,9 @@ def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_m
         toolbox.optimizer.load_state_dict(state_dicts["optimizer"])
         toolbox.scheduler.load_state_dict(state_dicts["scheduler"])
         toolbox.criterion.load_state_dict(state_dicts["criterion"])
+        if "scaler" in state_dicts:
+            toolbox.scaler = torch.amp.GradScaler(self._device_type())
+            toolbox.scaler.load_state_dict(state_dicts["scaler"])
         return toolbox

     def recover_from(self, experiment_id: str) -> Self:
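Persisting the scaler alongside the optimizer and scheduler keeps recovery faithful: the GradScaler state dict carries the current loss scale and its growth bookkeeping, so a resumed run continues from the adapted scale rather than the default 65536.0, which could immediately overflow and waste steps shrinking back down. A small round-trip sketch (again assuming torch.amp.GradScaler accepts a device-type string):

import torch

scaler = torch.amp.GradScaler("cpu")
state = scaler.state_dict()        # carries the scale and growth bookkeeping

restored = torch.amp.GradScaler("cpu")
restored.load_state_dict(state)    # resumes with the saved scale
assert restored.get_scale() == scaler.get_scale()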
@@ -391,6 +398,12 @@ def empty_cache(self) -> None:

     # Training methods

+    def _do_backward(self, loss: torch.Tensor, toolbox: TrainerToolbox) -> None:
+        if toolbox.scaler:
+            toolbox.scaler.scale(loss).backward()
+        else:
+            loss.backward()
+
     def sanity_check(self, template_model: nn.Module, example_shape: AmbiguousShape) -> SanityCheckResult:
         try:
             return sanity_check(template_model, example_shape, device=self._device)
@@ -402,12 +415,23 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT
             str, float]]:
         raise NotImplementedError

+    def _device_type(self) -> str:
+        return self._device.type if isinstance(self._device, torch.device) else str(self._device).split(":")[0]
+
     def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
             str, float]]:
         toolbox.optimizer.zero_grad()
-        loss, metrics = self.backward(images, labels, toolbox)
-        toolbox.optimizer.step()
-        toolbox.scheduler.step()
+        with torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None):
+            loss, metrics = self.backward(images, labels, toolbox)
+        if toolbox.scaler:
+            old_scale = toolbox.scaler.get_scale()
+            toolbox.scaler.step(toolbox.optimizer)
+            toolbox.scaler.update()
+            if old_scale <= toolbox.scaler.get_scale():
+                toolbox.scheduler.step()
+        else:
+            toolbox.optimizer.step()
+            toolbox.scheduler.step()
         if toolbox.ema:
             toolbox.ema.update_parameters(toolbox.model)
         return loss, metrics
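The scale comparison above is the usual way to detect a skipped step: GradScaler.step silently skips optimizer.step when it finds inf or NaN gradients, and the following update then shrinks the scale. A decreased scale therefore signals a skipped step, and advancing the scheduler in that case would let the LR schedule run ahead of the optimizer. A condensed, self-contained sketch of the same control flow (toy model and data, names illustrative):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
scaler = torch.amp.GradScaler(device)

for batch in torch.randn(8, 2, 4).to(device):
    optimizer.zero_grad()
    with torch.amp.autocast(device, enabled=device == "cuda"):
        loss = model(batch).mean()
    scaler.scale(loss).backward()
    old_scale = scaler.get_scale()
    scaler.step(optimizer)               # skipped internally on inf/NaN grads
    scaler.update()                      # shrinks the scale after a skip
    if old_scale <= scaler.get_scale():  # scale kept or grew -> a real step
        scheduler.step()                 # only advance the schedule then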
@@ -440,7 +464,7 @@ def train_epoch(self, toolbox: TrainerToolbox) -> dict[str, list[float]]:
     def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, compile_model: bool = True,
               ema: bool = True, seed: int | None = None, early_stop_tolerance: int = 5,
               val_score_prediction: bool = True, val_score_prediction_degree: int = 5, save_preview: bool = True,
-              preview_quality: float = .75) -> None:
+              preview_quality: float = .75, amp: bool = False) -> None:
         training_arguments = self.filter_train_params(**locals())
         self.init_experiment()
         if note:
@@ -468,6 +492,9 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
         toolbox = (self.load_toolbox if self.recovery() else self.build_toolbox)(
            num_epochs, example_shape, compile_model, ema
         )
+        if amp and not toolbox.scaler:
+            toolbox.scaler = torch.amp.GradScaler(self._device_type())
+            self.log("Mixed precision training enabled")
         checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth"
         es_tolerance = early_stop_tolerance
         self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note,
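A note on _device_type: GradScaler and autocast expect a bare device type such as "cuda" or "cpu", not a fully qualified device like "cuda:1", which is presumably why the helper strips the index. A sketch of that distinction (the helper below mirrors the one added in this diff, simplified and standalone):

import torch

def device_type(device: torch.device | str) -> str:
    # Reduce "cuda:1" / torch.device("cuda", 1) to the bare type "cuda".
    return device.type if isinstance(device, torch.device) else str(device).split(":")[0]

assert device_type(torch.device("cuda", 1)) == "cuda"
assert device_type("cpu") == "cpu"
scaler = torch.amp.GradScaler(device_type("cuda:0"))  # "cuda", as GradScaler expects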
@@ -550,7 +577,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
     def filter_train_params(**kwargs) -> dict[str, Setting]:
         return {k: v for k, v in kwargs.items() if k in (
             "note", "num_checkpoints", "compile_model", "ema", "seed", "early_stop_tolerance", "val_score_prediction",
-            "val_score_prediction_degree", "save_preview", "preview_quality"
+            "val_score_prediction_degree", "save_preview", "preview_quality", "amp"
         )}

def train_with_settings(self, num_epochs: int, **kwargs) -> None:
@@ -580,7 +607,7 @@ def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float
         worst_score = float("+inf")
         metrics = {}
         num_cases = len(self._validation_dataloader)
-        with torch.no_grad(), Progress(
+        with torch.no_grad(), torch.amp.autocast(self._device_type(), enabled=toolbox.scaler is not None), Progress(
                 *Progress.get_default_columns(), SpinnerColumn(), console=self._console
         ) as progress:
             task = progress.add_task(f"Validating", total=num_cases)
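Running validation under autocast as well keeps the evaluated numerics consistent with how the model trains and reduces memory during inference; no GradScaler involvement is needed here because no gradients flow under no_grad. A minimal sketch of the combination (toy model, names illustrative):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 1).to(device).eval()

with torch.no_grad(), torch.amp.autocast(device, enabled=device == "cuda"):
    scores = model(torch.randn(16, 4, device=device))  # fp16 on CUDA, fp32 otherwise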