-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathOnnxExportExtensions.cs
More file actions
167 lines (155 loc) · 12.3 KB
/
OnnxExportExtensions.cs
File metadata and controls
167 lines (155 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Google.Protobuf;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;
namespace Microsoft.ML
{
    /// <summary>
    /// Extension methods on <see cref="ModelOperationsCatalog"/> that export an <see cref="ITransformer"/> to the ONNX format.
    /// </summary>
    public static class OnnxExportExtensions
    {
        /// <summary>
        /// Shared implementation behind every overload in this class: applies <paramref name="transform"/> to
        /// <paramref name="inputData"/>, extracts the ONNX-savable pipe and converts it into a <see cref="ModelProto"/>.
        /// </summary>
        /// <param name="env">The host environment; used to open the "ONNX conversion" channel.</param>
        /// <param name="ctx">The ONNX context the model is built in.</param>
        /// <param name="transform">The <see cref="ITransformer"/> to convert.</param>
        /// <param name="inputData">
        /// The input of <paramref name="transform"/>. The callers in this class always pass an <see cref="EmptyDataView"/>
        /// built from the input schema, so only schema information is relied upon.
        /// </param>
        /// <param name="outputColumnNamesToKeep">
        /// Names of the output columns to keep. When <see langword="null"/>, no output column is dropped. Otherwise every
        /// column of the converted pipe whose name is not listed is dropped (an empty array therefore drops all of them).
        /// Listed names that do not occur in the output schema are ignored.
        /// </param>
        /// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
        private static ModelProto ConvertToOnnxProtobufCore(IHostEnvironment env, OnnxContextImpl ctx, ITransformer transform, IDataView inputData, string[] outputColumnNamesToKeep = null)
        {
            var outputData = transform.Transform(inputData);
            using (var ch = env.Start("ONNX conversion"))
            {
                SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out LinkedList<ITransformCanSaveOnnx> transforms);

                // Callers name the columns to keep, but ConvertTransformListToOnnxModel expects the columns to drop,
                // so invert the list against the sink schema. A set gives O(1) membership tests instead of a linear
                // scan of the array for every schema column.
                var outputColumnNamesToDrop = new HashSet<string>();
                if (outputColumnNamesToKeep != null)
                {
                    var namesToKeep = new HashSet<string>(outputColumnNamesToKeep);
                    for (int i = 0; i < sink.Schema.Count; ++i)
                    {
                        string columnName = sink.Schema[i].Name;
                        if (!namesToKeep.Contains(columnName))
                        {
                            outputColumnNamesToDrop.Add(columnName);
                        }
                    }
                }

                // NOTE(review): the null argument presumably is the set of input columns to drop — confirm against SaveOnnxCommand.
                return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, outputColumnNamesToDrop);
            }
        }

        /// <summary>
        /// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnxProtobuf(ModelOperationsCatalog, ITransformer, IDataView, string[])"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputData">The input of the specified transform. Only its schema is used; no rows are read.</param>
        /// <param name="outputColumns">List of output columns we want to keep. <see langword="null"/> keeps every output column.</param>
        /// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
        [BestFriend]
        internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, string[] outputColumns = null) =>
            ConvertToOnnxProtobuf(catalog, transform, inputData.Schema, outputColumns);

        /// <summary>
        /// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnxProtobuf(ModelOperationsCatalog, ITransformer, IDataView, int)"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputData">The input of the specified transform. Only its schema is used; no rows are read.</param>
        /// <param name="opSetVersion">The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12</param>
        /// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
        [BestFriend]
        internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion) =>
            ConvertToOnnxProtobuf(catalog, transform, inputData.Schema, opSetVersion);

        /// <summary>
        /// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnxProtobuf(ModelOperationsCatalog, ITransformer, DataViewSchema, string[])"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputSchema">The schema of the input to the transformer.</param>
        /// <param name="outputColumns">List of output columns we want to keep. <see langword="null"/> keeps every output column.</param>
        /// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
        [BestFriend]
        internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, DataViewSchema inputSchema, string[] outputColumns = null)
        {
            var env = catalog.GetEnvironment();
            var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
            // The conversion only needs schema information, so an empty view over the schema stands in for real data.
            return ConvertToOnnxProtobufCore(env, ctx, transform, new EmptyDataView(env, inputSchema), outputColumns);
        }

        /// <summary>
        /// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnxProtobuf(ModelOperationsCatalog, ITransformer, DataViewSchema, int)"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputSchema">The schema of the input to the transformer.</param>
        /// <param name="opSetVersion">The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12</param>
        /// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
        [BestFriend]
        internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, DataViewSchema inputSchema, int opSetVersion)
        {
            var env = catalog.GetEnvironment();
            var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable, opSetVersion);
            // The conversion only needs schema information, so an empty view over the schema stands in for real data.
            return ConvertToOnnxProtobufCore(env, ctx, transform, new EmptyDataView(env, inputSchema));
        }

        /// <summary>
        /// Converts the specified <see cref="ITransformer"/> to ONNX format and writes the resulting ONNX model
        /// (equivalent to the converted ML.NET model) to <paramref name="stream"/> as Protobuf.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, IDataView, Stream)"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputData">The input of the specified transform. Only its schema is used; no rows are read.</param>
        /// <param name="stream">The stream to write the protobuf model to.</param>
        public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream) =>
            ConvertToOnnxProtobuf(catalog, transform, inputData).WriteTo(stream);

        /// <summary>
        /// Converts the specified <see cref="ITransformer"/> to ONNX format using the given OpSet version and writes the
        /// resulting ONNX model (equivalent to the converted ML.NET model) to <paramref name="stream"/> as Protobuf.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, IDataView, int, Stream)"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputData">The input of the specified transform. Only its schema is used; no rows are read.</param>
        /// <param name="opSetVersion">The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12</param>
        /// <param name="stream">The stream to write the protobuf model to.</param>
        public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion, Stream stream) =>
            ConvertToOnnxProtobuf(catalog, transform, inputData, opSetVersion).WriteTo(stream);

        /// <summary>
        /// Converts the specified <see cref="ITransformer"/> to ONNX format, keeping only the listed output columns, and
        /// writes the resulting ONNX model (equivalent to the converted ML.NET model) to <paramref name="stream"/> as Protobuf.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, IDataView, Stream, string[])"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputData">The input of the specified transform. Only its schema is used; no rows are read.</param>
        /// <param name="stream">The stream to write the protobuf model to.</param>
        /// <param name="outputColumns">
        /// List of output columns we want to keep. Every other output column is dropped; <see langword="null"/> keeps all of them.
        /// </param>
        public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream, params string[] outputColumns) =>
            ConvertToOnnxProtobuf(catalog, transform, inputData, outputColumns).WriteTo(stream);

        /// <summary>
        /// Converts the specified <see cref="ITransformer"/> to ONNX format and writes the resulting ONNX model
        /// (equivalent to the converted ML.NET model) to <paramref name="stream"/> as Protobuf.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, DataViewSchema, Stream)"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputSchema">The schema of the input to the transformer.</param>
        /// <param name="stream">The stream to write the protobuf model to.</param>
        public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, DataViewSchema inputSchema, Stream stream) =>
            ConvertToOnnxProtobuf(catalog, transform, inputSchema).WriteTo(stream);

        /// <summary>
        /// Converts the specified <see cref="ITransformer"/> to ONNX format using the given OpSet version and writes the
        /// resulting ONNX model (equivalent to the converted ML.NET model) to <paramref name="stream"/> as Protobuf.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, DataViewSchema, int, Stream)"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputSchema">The schema of the input to the transformer.</param>
        /// <param name="opSetVersion">The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12</param>
        /// <param name="stream">The stream to write the protobuf model to.</param>
        public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, DataViewSchema inputSchema, int opSetVersion, Stream stream) =>
            ConvertToOnnxProtobuf(catalog, transform, inputSchema, opSetVersion).WriteTo(stream);

        /// <summary>
        /// Converts the specified <see cref="ITransformer"/> to ONNX format, keeping only the listed output columns, and
        /// writes the resulting ONNX model (equivalent to the converted ML.NET model) to <paramref name="stream"/> as Protobuf.
        /// </summary>
        /// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, DataViewSchema, Stream, string[])"/> attached to.</param>
        /// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
        /// <param name="inputSchema">The schema of the input to the transformer.</param>
        /// <param name="stream">The stream to write the protobuf model to.</param>
        /// <param name="outputColumns">
        /// List of output columns we want to keep. Every other output column is dropped; <see langword="null"/> keeps all of them.
        /// </param>
        public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, DataViewSchema inputSchema, Stream stream, params string[] outputColumns) =>
            ConvertToOnnxProtobuf(catalog, transform, inputSchema, outputColumns).WriteTo(stream);
    }
}