@@ -134,6 +134,9 @@ def _compute_loss(self, batched_outputs, y_one_hot):
134134 - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
135135 """
136136 mask_loss , iou_regression_loss = 0.0 , 0.0
137+ batch_size = len (batched_outputs )
138+ if batch_size == 0 :
139+ raise RuntimeError ("Got empty batch outputs in loss computation." )
137140
138141 # Loop over the batch.
139142 for batch_output , targets in zip (batched_outputs , y_one_hot ):
@@ -163,6 +166,9 @@ def _compute_loss(self, batched_outputs, y_one_hot):
163166 mask_loss = mask_loss + torch .mean (dice_scores )
164167 iou_regression_loss = iou_regression_loss + iou_score
165168
169+ # Normalize by batch size so that loss/metric are comparable across batch sizes.
170+ mask_loss = mask_loss / batch_size
171+ iou_regression_loss = iou_regression_loss / batch_size
166172 loss = mask_loss + iou_regression_loss
167173
168174 return loss , mask_loss , iou_regression_loss
@@ -448,6 +454,7 @@ def _validate_impl(self, forward_context):
448454
449455 val_iteration = 0
450456 metric_val , loss_val , model_iou_val = 0.0 , 0.0 , 0.0
457+ mask_loss_val , iou_loss_val = 0.0 , 0.0
451458
452459 with torch .no_grad ():
453460 for x , y in self .val_loader :
@@ -459,19 +466,23 @@ def _validate_impl(self, forward_context):
459466
460467 loss_val += loss .item ()
461468 metric_val += metric .item ()
469+ mask_loss_val += mask_loss .item ()
470+ iou_loss_val += iou_regression_loss .item ()
462471 model_iou_val += model_iou .item ()
463472 val_iteration += 1
464473
465474 loss_val /= len (self .val_loader )
466475 metric_val /= len (self .val_loader )
476+ mask_loss_val /= len (self .val_loader )
477+ iou_loss_val /= len (self .val_loader )
467478 model_iou_val /= len (self .val_loader )
468479 print ()
469480 print (f"The Average Dice Score for the Current Epoch is { 1 - metric_val } " )
470481
471482 if self .logger is not None :
472483 self .logger .log_validation (
473484 self ._iteration , metric_val , loss_val , x , y ,
474- sampled_binary_y , mask_loss , iou_regression_loss , model_iou_val
485+ sampled_binary_y , mask_loss_val , iou_loss_val , model_iou_val
475486 )
476487
477488 return metric_val
0 commit comments