Skip to content

Commit 9634951

Browse files
Refactor validation logging and loss normalization (#1126)
1 parent 736811d commit 9634951

2 files changed

Lines changed: 20 additions & 2 deletions

File tree

micro_sam/training/joint_sam_trainer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def _validate_impl(self, forward_context):
136136

137137
val_iteration = 0
138138
metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
139+
mask_loss_val, iou_loss_val, unetr_loss_val = 0.0, 0.0, 0.0
139140

140141
with torch.no_grad():
141142
for x, y in self.val_loader:
@@ -155,17 +156,23 @@ def _validate_impl(self, forward_context):
155156

156157
loss_val += loss.item()
157158
metric_val += metric.item() + (unetr_metric.item() / 3)
159+
mask_loss_val += mask_loss.item()
160+
iou_loss_val += iou_regression_loss.item()
158161
model_iou_val += model_iou.item()
162+
unetr_loss_val += unetr_loss.item()
159163
val_iteration += 1
160164

161165
loss_val /= len(self.val_loader)
162166
metric_val /= len(self.val_loader)
167+
mask_loss_val /= len(self.val_loader)
168+
iou_loss_val /= len(self.val_loader)
163169
model_iou_val /= len(self.val_loader)
170+
unetr_loss_val /= len(self.val_loader)
164171

165172
if self.logger is not None:
166173
self.logger.log_validation(
167174
self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
168-
mask_loss, iou_regression_loss, model_iou_val, unetr_loss
175+
mask_loss_val, iou_loss_val, model_iou_val, unetr_loss_val
169176
)
170177

171178
return metric_val

micro_sam/training/sam_trainer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ def _compute_loss(self, batched_outputs, y_one_hot):
134134
- The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
135135
"""
136136
mask_loss, iou_regression_loss = 0.0, 0.0
137+
batch_size = len(batched_outputs)
138+
if batch_size == 0:
139+
raise RuntimeError("Got empty batch outputs in loss computation.")
137140

138141
# Loop over the batch.
139142
for batch_output, targets in zip(batched_outputs, y_one_hot):
@@ -163,6 +166,9 @@ def _compute_loss(self, batched_outputs, y_one_hot):
163166
mask_loss = mask_loss + torch.mean(dice_scores)
164167
iou_regression_loss = iou_regression_loss + iou_score
165168

169+
# Normalize by batch size so that loss/metric are comparable across batch sizes.
170+
mask_loss = mask_loss / batch_size
171+
iou_regression_loss = iou_regression_loss / batch_size
166172
loss = mask_loss + iou_regression_loss
167173

168174
return loss, mask_loss, iou_regression_loss
@@ -448,6 +454,7 @@ def _validate_impl(self, forward_context):
448454

449455
val_iteration = 0
450456
metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
457+
mask_loss_val, iou_loss_val = 0.0, 0.0
451458

452459
with torch.no_grad():
453460
for x, y in self.val_loader:
@@ -459,19 +466,23 @@ def _validate_impl(self, forward_context):
459466

460467
loss_val += loss.item()
461468
metric_val += metric.item()
469+
mask_loss_val += mask_loss.item()
470+
iou_loss_val += iou_regression_loss.item()
462471
model_iou_val += model_iou.item()
463472
val_iteration += 1
464473

465474
loss_val /= len(self.val_loader)
466475
metric_val /= len(self.val_loader)
476+
mask_loss_val /= len(self.val_loader)
477+
iou_loss_val /= len(self.val_loader)
467478
model_iou_val /= len(self.val_loader)
468479
print()
469480
print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")
470481

471482
if self.logger is not None:
472483
self.logger.log_validation(
473484
self._iteration, metric_val, loss_val, x, y,
474-
sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
485+
sampled_binary_y, mask_loss_val, iou_loss_val, model_iou_val
475486
)
476487

477488
return metric_val

0 commit comments

Comments
 (0)