@@ -31,7 +31,9 @@ def mean_intersection_over_union(y_true, y_pred, class_names):
3131 print ('IoU for {} is: {:.2f}%' .format (class_names [c ], iou * 100 ))
3232 total_iou += iou
3333
34- print ('\n Mean IoU is: {:.2f}%' .format (100 * total_iou / len (class_names )))
34+ mIOU = 100 * total_iou / len (class_names )
35+ print ('\n Mean IoU is: {:.2f}%' .format (mIOU ))
36+ return mIOU
3537
3638def main ():
3739 args = argument_parser ()
@@ -96,13 +98,16 @@ def main():
9698 prediction = model .predict_on_batch (X_preprocessed )
9799 prediction_max = np .argmax (prediction , axis = - 1 )
98100
101+ this_m_iou = mean_intersection_over_union (y_preprocessed , prediction , CLASS_NAMES )
102+
99103 feature_image = X_preprocessed .reshape (IMAGE_WIDTH , IMAGE_HEIGHT , IMAGE_DEPTH )
100104 label_image = y_preprocessed_max .reshape (IMAGE_WIDTH , IMAGE_HEIGHT )
101105 prediction_image = prediction_max .reshape (IMAGE_WIDTH , IMAGE_HEIGHT )
102106
103107 plot_feature_label_prediction_path = os .path .join ("plots" , "predictions" , "prediction_event_{}.pdf" .format (count ))
104108 plot_feature_label_prediction (feature_image , label_image , prediction_image ,
105- 'Feature' , 'Label' , 'Model prediction' , CLASS_NAMES , plot_feature_label_prediction_path )
109+ 'Feature' , 'Label' , 'Prediction (mIOU: {:.1f})' .format (this_m_iou ),
110+ CLASS_NAMES , plot_feature_label_prediction_path )
106111
107112 # Calculate Statistics
108113 samples = np .zeros ((NUM_TESTING , IMAGE_WIDTH , IMAGE_HEIGHT , IMAGE_DEPTH ))
@@ -118,7 +123,7 @@ def main():
118123 count += 1
119124
120125 predictions = model .predict_on_batch (samples )
121- mean_intersection_over_union (targets , predictions , CLASS_NAMES )
126+ mIOU = mean_intersection_over_union (targets , predictions , CLASS_NAMES )
122127
123128 # Print the test accuracy
124129 score = model .evaluate (samples , targets , verbose = 0 )
0 commit comments