Skip to content

Commit 5a622bb

Browse files
luisquintanillaCESARDELATORRE
authored andcommitted
Improve ONNX Object Detection Flow (#531)
* Initial commit * Refactored for flow * Renamed image read method for clarity * Rearranged methods based on the order used * Removing unused variables * restored old variable names * Updated based on feedback * Added comments based on feedback
1 parent e2e301c commit 5a622bb

6 files changed

Lines changed: 245 additions & 232 deletions

File tree

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
using Microsoft.ML.Data;
1+
using System.Collections.Generic;
2+
using System.IO;
3+
using System.Linq;
4+
using Microsoft.ML.Data;
25

3-
namespace ObjectDetection
6+
namespace ObjectDetection.DataStructures
47
{
58
public class ImageNetData
69
{
@@ -9,11 +12,13 @@ public class ImageNetData
912

1013
[LoadColumn(1)]
1114
public string Label;
12-
}
1315

14-
public class ImageNetDataProbability : ImageNetData
15-
{
16-
public string PredictedLabel;
17-
public float Probability { get; set; }
16+
public static IEnumerable<ImageNetData> ReadFromFile(string imageFolder)
17+
{
18+
return Directory
19+
.GetFiles(imageFolder)
20+
.Where(filePath => Path.GetExtension(filePath) != ".md")
21+
.Select(filePath => new ImageNetData { ImagePath = filePath, Label = Path.GetFileName(filePath) });
22+
}
1823
}
1924
}
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using Microsoft.ML.Data;
22

3-
namespace ObjectDetection
3+
namespace ObjectDetection.DataStructures
44
{
55
public class ImageNetPrediction
66
{
7-
[ColumnName(OnnxModelScorer.TinyYoloModelSettings.ModelOutput)]
7+
[ColumnName("grid")]
88
public float[] PredictedLabels;
99
}
1010
}

samples/csharp/getting-started/DeepLearning_ObjectDetection_Onnx/ObjectDetectionConsoleApp/OnnxModelScorer.cs

Lines changed: 19 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using System;
22
using System.Collections.Generic;
3-
using System.Drawing;
4-
using System.Drawing.Drawing2D;
5-
using System.IO;
63
using System.Linq;
74
using Microsoft.ML;
5+
using Microsoft.ML.Data;
6+
using ObjectDetection.DataStructures;
7+
using ObjectDetection.YoloParser;
88

99
namespace ObjectDetection
1010
{
@@ -15,13 +15,12 @@ class OnnxModelScorer
1515
private readonly MLContext mlContext;
1616

1717
private IList<YoloBoundingBox> _boundingBoxes = new List<YoloBoundingBox>();
18-
private readonly YoloWinMlParser _parser = new YoloWinMlParser();
1918

20-
public OnnxModelScorer(string imagesFolder, string modelLocation)
19+
public OnnxModelScorer(string imagesFolder, string modelLocation, MLContext mlContext)
2120
{
2221
this.imagesFolder = imagesFolder;
2322
this.modelLocation = modelLocation;
24-
mlContext = new MLContext();
23+
this.mlContext = mlContext;
2524
}
2625

2726
public struct ImageNetSettings
@@ -32,7 +31,7 @@ public struct ImageNetSettings
3231

3332
public struct TinyYoloModelSettings
3433
{
35-
// for checking TIny yolo2 Model input and output parameter names,
34+
// for checking Tiny yolo2 Model input and output parameter names,
3635
//you can use tools like Netron,
3736
// which is installed by Visual Studio AI Tools
3837

@@ -43,139 +42,46 @@ public struct TinyYoloModelSettings
4342
public const string ModelOutput = "grid";
4443
}
4544

46-
public void Score()
47-
{
48-
var model = LoadModel(modelLocation);
49-
50-
PredictDataUsingModel(imagesFolder, model);
51-
}
52-
53-
private PredictionEngine<ImageNetData, ImageNetPrediction> LoadModel(string modelLocation)
45+
private ITransformer LoadModel(string modelLocation)
5446
{
5547
Console.WriteLine("Read model");
5648
Console.WriteLine($"Model location: {modelLocation}");
5749
Console.WriteLine($"Default parameters: image size=({ImageNetSettings.imageWidth},{ImageNetSettings.imageHeight})");
5850

59-
var data = CreateEmptyDataView();
51+
// Create IDataView from empty list to obtain input data schema
52+
var data = mlContext.Data.LoadFromEnumerable(new List<ImageNetData>());
6053

54+
// Define scoring pipeline
6155
var pipeline = mlContext.Transforms.LoadImages(outputColumnName: "image", imageFolder: "", inputColumnName: nameof(ImageNetData.ImagePath))
6256
.Append(mlContext.Transforms.ResizeImages(outputColumnName: "image", imageWidth: ImageNetSettings.imageWidth, imageHeight: ImageNetSettings.imageHeight, inputColumnName: "image"))
6357
.Append(mlContext.Transforms.ExtractPixels(outputColumnName: "image"))
6458
.Append(mlContext.Transforms.ApplyOnnxModel(modelFile: modelLocation, outputColumnNames: new[] { TinyYoloModelSettings.ModelOutput }, inputColumnNames: new[] { TinyYoloModelSettings.ModelInput }));
6559

60+
// Fit scoring pipeline
6661
var model = pipeline.Fit(data);
6762

68-
var predictionEngine = mlContext.Model.CreatePredictionEngine<ImageNetData, ImageNetPrediction>(model);
69-
70-
return predictionEngine;
63+
return model;
7164
}
7265

73-
protected void PredictDataUsingModel(string imagesFolder, PredictionEngine<ImageNetData, ImageNetPrediction> model)
66+
private IEnumerable<float[]> PredictDataUsingModel(IDataView testData, ITransformer model)
7467
{
7568
Console.WriteLine($"Images location: {imagesFolder}");
7669
Console.WriteLine("");
7770
Console.WriteLine("=====Identify the objects in the images=====");
7871
Console.WriteLine("");
7972

80-
var testData = GetImagesData(imagesFolder);
81-
82-
foreach (var sample in testData)
83-
{
84-
var probs = model.Predict(sample).PredictedLabels;
85-
_boundingBoxes = _parser.ParseOutputs(probs);
86-
var filteredBoxes = _parser.FilterBoundingBoxes(_boundingBoxes, 5, .5F);
73+
IDataView scoredData = model.Transform(testData);
8774

75+
IEnumerable<float[]> probabilities = scoredData.GetColumn<float[]>(TinyYoloModelSettings.ModelOutput);
8876

89-
var outputDirectory = Path.Combine(Directory.GetParent(sample.ImagePath).FullName, "output");
90-
var filename = new FileInfo(sample.ImagePath).Name;
91-
92-
DrawBoundingBox(imagesFolder, outputDirectory, filename, filteredBoxes);
93-
94-
Console.WriteLine(".....The objects in the image {0} are detected as below....", sample.Label);
95-
foreach (var box in filteredBoxes)
96-
{
97-
Console.WriteLine(box.Label + " and its Confidence score: " + box.Confidence);
98-
}
99-
Console.WriteLine("");
100-
}
77+
return probabilities;
10178
}
10279

103-
private static IEnumerable<ImageNetData> GetImagesData(string folder)
104-
{
105-
List<ImageNetData> imagesList = new List<ImageNetData>();
106-
string[] filePaths = Directory.GetFiles(folder).Where(filePath => Path.GetExtension(filePath) != ".md").ToArray();
107-
foreach (var filePath in filePaths)
108-
{
109-
ImageNetData imagedata = new ImageNetData { ImagePath = filePath, Label = Path.GetFileName(filePath) };
110-
imagesList.Add(imagedata);
111-
}
112-
return imagesList;
113-
}
114-
115-
private IDataView CreateEmptyDataView()
80+
public IEnumerable<float[]> Score(IDataView data)
11681
{
117-
//Create empty DataView. We just need the schema to call fit()
118-
List<ImageNetData> list = new List<ImageNetData>();
119-
IEnumerable<ImageNetData> enumerableData = list;
120-
var dv = mlContext.Data.LoadFromEnumerable(enumerableData);
121-
return dv;
122-
}
82+
var model = LoadModel(modelLocation);
12383

124-
public void DrawBoundingBox(string inputImageLocation, string outputImageLocation, string imageName, IList<YoloBoundingBox> filteredBoundingBoxes)
125-
{
126-
Image image = Image.FromFile(Path.Combine(inputImageLocation, imageName));
127-
128-
var originalImageHeight = image.Height;
129-
var originalImageWidth = image.Width;
130-
131-
foreach (var box in filteredBoundingBoxes)
132-
{
133-
// Get Bounding Box Dimensions
134-
var x = (uint)Math.Max(box.Dimensions.X, 0);
135-
var y = (uint)Math.Max(box.Dimensions.Y, 0);
136-
var width = (uint)Math.Min(originalImageWidth - x, box.Dimensions.Width);
137-
var height = (uint)Math.Min(originalImageHeight - y, box.Dimensions.Height);
138-
139-
// Resize To Image
140-
x = (uint)originalImageWidth * x / 416;
141-
y = (uint)originalImageHeight * y / 416;
142-
width = (uint)originalImageWidth * width / 416;
143-
height = (uint)originalImageHeight * height / 416;
144-
145-
// Bounding Box Text
146-
string text = $"{box.Label} ({(box.Confidence * 100).ToString("0")}%)";
147-
148-
using (Graphics thumbnailGraphic = Graphics.FromImage(image))
149-
{
150-
thumbnailGraphic.CompositingQuality = CompositingQuality.HighQuality;
151-
thumbnailGraphic.SmoothingMode = SmoothingMode.HighQuality;
152-
thumbnailGraphic.InterpolationMode = InterpolationMode.HighQualityBicubic;
153-
154-
// Define Text Options
155-
Font drawFont = new Font("Arial", 12, FontStyle.Bold);
156-
SizeF size = thumbnailGraphic.MeasureString(text, drawFont);
157-
SolidBrush fontBrush = new SolidBrush(Color.Black);
158-
Point atPoint = new Point((int)x, (int)y - (int)size.Height - 1);
159-
160-
// Define BoundingBox options
161-
Pen pen = new Pen(box.BoxColor, 3.2f);
162-
SolidBrush colorBrush = new SolidBrush(box.BoxColor);
163-
164-
// Draw text on image
165-
thumbnailGraphic.FillRectangle(colorBrush, (int)x, (int)(y - size.Height - 1), (int)size.Width, (int)size.Height);
166-
thumbnailGraphic.DrawString(text, drawFont, fontBrush, atPoint);
167-
168-
// Draw bounding box on image
169-
thumbnailGraphic.DrawRectangle(pen, x, y, width, height);
170-
}
171-
}
172-
173-
if (!Directory.Exists(outputImageLocation))
174-
{
175-
Directory.CreateDirectory(outputImageLocation);
176-
}
177-
178-
image.Save(Path.Combine(outputImageLocation, imageName));
84+
return PredictDataUsingModel(data, model);
17985
}
18086
}
18187
}

samples/csharp/getting-started/DeepLearning_ObjectDetection_Onnx/ObjectDetectionConsoleApp/Program.cs

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
using System;
22
using System.IO;
3+
using System.Collections.Generic;
4+
using System.Drawing;
5+
using System.Drawing.Drawing2D;
6+
using System.Linq;
7+
using Microsoft.ML;
8+
using ObjectDetection.YoloParser;
9+
using ObjectDetection.DataStructures;
310

411
namespace ObjectDetection
512
{
@@ -11,11 +18,41 @@ public static void Main()
1118
string assetsPath = GetAbsolutePath(assetsRelativePath);
1219
var modelFilePath = Path.Combine(assetsPath, "Model", "TinyYolo2_model.onnx");
1320
var imagesFolder = Path.Combine(assetsPath, "images");
21+
var outputFolder = Path.Combine(assetsPath, "images", "output");
22+
23+
// Initialize MLContext
24+
MLContext mlContext = new MLContext();
1425

1526
try
1627
{
17-
var modelScorer = new OnnxModelScorer(imagesFolder, modelFilePath);
18-
modelScorer.Score();
28+
// Load Data
29+
IEnumerable<ImageNetData> images = ImageNetData.ReadFromFile(imagesFolder);
30+
IDataView imageDataView = mlContext.Data.LoadFromEnumerable(images);
31+
32+
// Create instance of model scorer
33+
var modelScorer = new OnnxModelScorer(imagesFolder, modelFilePath, mlContext);
34+
35+
// Use model to score data
36+
IEnumerable<float[]> probabilities = modelScorer.Score(imageDataView);
37+
38+
// Post-process model output
39+
YoloWinMlParser parser = new YoloWinMlParser();
40+
41+
var boundingBoxes =
42+
probabilities
43+
.Select(probability => parser.ParseOutputs(probability))
44+
.Select(boxes => parser.FilterBoundingBoxes(boxes, 5, .5F));
45+
46+
// Draw bounding boxes for detected objects in each of the images
47+
for (var i = 0; i < images.Count(); i++)
48+
{
49+
string imageFileName = images.ElementAt(i).Label;
50+
IList<YoloBoundingBox> detectedObjects = boundingBoxes.ElementAt(i);
51+
52+
DrawBoundingBox(imagesFolder, outputFolder, imageFileName, detectedObjects);
53+
54+
LogDetectedObjects(imageFileName, detectedObjects);
55+
}
1956
}
2057
catch (Exception ex)
2158
{
@@ -35,6 +72,75 @@ public static string GetAbsolutePath(string relativePath)
3572

3673
return fullPath;
3774
}
75+
76+
private static void DrawBoundingBox(string inputImageLocation, string outputImageLocation, string imageName, IList<YoloBoundingBox> filteredBoundingBoxes)
77+
{
78+
Image image = Image.FromFile(Path.Combine(inputImageLocation, imageName));
79+
80+
var originalImageHeight = image.Height;
81+
var originalImageWidth = image.Width;
82+
83+
foreach (var box in filteredBoundingBoxes)
84+
{
85+
// Get Bounding Box Dimensions
86+
var x = (uint)Math.Max(box.Dimensions.X, 0);
87+
var y = (uint)Math.Max(box.Dimensions.Y, 0);
88+
var width = (uint)Math.Min(originalImageWidth - x, box.Dimensions.Width);
89+
var height = (uint)Math.Min(originalImageHeight - y, box.Dimensions.Height);
90+
91+
// Resize To Image
92+
x = (uint)originalImageWidth * x / OnnxModelScorer.ImageNetSettings.imageWidth;
93+
y = (uint)originalImageHeight * y / OnnxModelScorer.ImageNetSettings.imageHeight;
94+
width = (uint)originalImageWidth * width / OnnxModelScorer.ImageNetSettings.imageWidth;
95+
height = (uint)originalImageHeight * height / OnnxModelScorer.ImageNetSettings.imageHeight;
96+
97+
// Bounding Box Text
98+
string text = $"{box.Label} ({(box.Confidence * 100).ToString("0")}%)";
99+
100+
using (Graphics thumbnailGraphic = Graphics.FromImage(image))
101+
{
102+
thumbnailGraphic.CompositingQuality = CompositingQuality.HighQuality;
103+
thumbnailGraphic.SmoothingMode = SmoothingMode.HighQuality;
104+
thumbnailGraphic.InterpolationMode = InterpolationMode.HighQualityBicubic;
105+
106+
// Define Text Options
107+
Font drawFont = new Font("Arial", 12, FontStyle.Bold);
108+
SizeF size = thumbnailGraphic.MeasureString(text, drawFont);
109+
SolidBrush fontBrush = new SolidBrush(Color.Black);
110+
Point atPoint = new Point((int)x, (int)y - (int)size.Height - 1);
111+
112+
// Define BoundingBox options
113+
Pen pen = new Pen(box.BoxColor, 3.2f);
114+
SolidBrush colorBrush = new SolidBrush(box.BoxColor);
115+
116+
// Draw text on image
117+
thumbnailGraphic.FillRectangle(colorBrush, (int)x, (int)(y - size.Height - 1), (int)size.Width, (int)size.Height);
118+
thumbnailGraphic.DrawString(text, drawFont, fontBrush, atPoint);
119+
120+
// Draw bounding box on image
121+
thumbnailGraphic.DrawRectangle(pen, x, y, width, height);
122+
}
123+
}
124+
125+
if (!Directory.Exists(outputImageLocation))
126+
{
127+
Directory.CreateDirectory(outputImageLocation);
128+
}
129+
130+
image.Save(Path.Combine(outputImageLocation, imageName));
131+
}
132+
133+
private static void LogDetectedObjects(string imageName, IList<YoloBoundingBox> boundingBoxes)
134+
{
135+
Console.WriteLine($".....The objects in the image {imageName} are detected as below....");
136+
137+
foreach (var box in boundingBoxes)
138+
{
139+
Console.WriteLine($"{box.Label} and its Confidence score: {box.Confidence}");
140+
}
141+
142+
Console.WriteLine("");
143+
}
38144
}
39145
}
40146

samples/csharp/getting-started/DeepLearning_ObjectDetection_Onnx/ObjectDetectionConsoleApp/YoloParser/YoloBoundingBox.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
using ObjectDetection.YoloParser;
2-
using System.Drawing;
1+
using System.Drawing;
32

4-
namespace ObjectDetection
3+
namespace ObjectDetection.YoloParser
54
{
6-
class YoloBoundingBox
5+
public class BoundingBoxDimensions : DimensionsBase { }
6+
7+
public class YoloBoundingBox
78
{
89
public BoundingBoxDimensions Dimensions { get; set; }
910

@@ -18,6 +19,5 @@ public RectangleF Rect
1819

1920
public Color BoxColor { get; set; }
2021
}
21-
22-
class BoundingBoxDimensions : DimensionsBase { }
22+
2323
}

0 commit comments

Comments
 (0)