Skip to content

Commit 5cecde9

Browse files
committed
optimize tensorrt, better argmax impl
1 parent b361d6b commit 5cecde9

6 files changed

Lines changed: 320 additions & 43 deletions

File tree

tensorrt/batch_stream.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class BatchStream : public IBatchStream
4646

4747
readDataFile(dataFile, dataRoot);
4848
mSampleSize = std::accumulate(
49-
mDims.d, mDims.d + mDims.nbDims, 1, std::multiplies<int64_t>()) * sizeof(float);
49+
mDims.d, mDims.d + mDims.nbDims, 1, std::multiplies<int64_t>());
5050
mData.resize(mSampleSize * mBatchSize);
5151
}
5252

@@ -140,7 +140,7 @@ class BatchStream : public IBatchStream
140140
Dims3 mDims{};
141141
std::vector<string> mPaths;
142142
std::vector<float> mData;
143-
int mSampleSize{0};
143+
int64_t mSampleSize{0};
144144
};
145145

146146

tensorrt/plugins/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11

22

3-
43
add_library (custom_plugin SHARED argmax_plugin.cu)
54
target_compile_features (custom_plugin PRIVATE cuda_std_14)
65
target_include_directories (custom_plugin PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

tensorrt/plugins/argmax_plugin.cu

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,27 @@ bool ArgMaxPlugin::supportsFormatCombination(
129129
ss << "ArgMaxPlugin accepts only two input, but here pos is " << pos;
130130
CHECK(pos < 2, ss.str());
131131

132-
bool typeOk = inOut[0].desc.type == DataType::kFLOAT;
133-
typeOk = typeOk || inOut[0].desc.type == DataType::kHALF;
134-
typeOk = typeOk || inOut[0].desc.type == DataType::kBF16;
135-
typeOk = typeOk || inOut[0].desc.type == DataType::kINT8;
136-
// here support int8, and enqueue() will recieve int8 input
137-
// or it will drop back to float/half to call enqueue()
132+
// pos=1 is the output tensor: always INT32 kLINEAR
133+
if (pos == 1) {
134+
return inOut[1].desc.type == DataType::kINT32
135+
&& inOut[1].desc.format == PluginFormat::kLINEAR;
136+
}
137+
138+
// pos=0 is the input tensor
139+
auto typ = inOut[0].desc.type;
140+
auto fmt = inOut[0].desc.format;
141+
142+
bool typeOk = typ == DataType::kFLOAT
143+
|| typ == DataType::kHALF
144+
|| typ == DataType::kBF16
145+
|| typ == DataType::kINT8;
138146

139-
bool formatOK = inOut[0].desc.format == PluginFormat::kLINEAR;
147+
bool formatOK = fmt == PluginFormat::kLINEAR
148+
|| fmt == PluginFormat::kHWC
149+
|| (fmt == PluginFormat::kHWC4 && typ == DataType::kINT8)
150+
|| (fmt == PluginFormat::kHWC4 && typ == DataType::kFLOAT)
151+
|| (fmt == PluginFormat::kHWC4 && typ == DataType::kHALF)
152+
|| (fmt == PluginFormat::kHWC8 && typ == DataType::kHALF);
140153

141154
return formatOK && typeOk;
142155
}
@@ -170,7 +183,7 @@ int32_t ArgMaxPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs,
170183
size_t ArgMaxPlugin::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
171184
DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
172185
{
173-
return sizeof(mAxis);
186+
return 0;
174187
}
175188

176189

@@ -200,29 +213,39 @@ int32_t ArgMaxPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDes
200213

201214
// cout << "type is: " << static_cast<int32_t>(type) << endl;
202215

216+
auto fmt = inputDesc[0].format;
217+
bool is_nhwc = (fmt == nvinfer1::TensorFormat::kHWC);
218+
bool is_nhwc4 = (fmt == nvinfer1::TensorFormat::kHWC4);
219+
bool is_nhwc8 = (fmt == nvinfer1::TensorFormat::kHWC8);
220+
int32_t n_spatial = n_size * m_size; // N * H * W, used by NHWC kernels
221+
203222
if (type == nvinfer1::DataType::kFLOAT) {
204223
const float* ptr_inp = static_cast<const float*>(inputs[0]);
205224
int32_t* ptr_out = static_cast<int32_t*>(outputs[0]);
206-
argMaxFunc<float>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
207-
// cout << "type is: fp32" << endl;
225+
if (is_nhwc4) argMaxHWC4FP32Func(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
226+
else if (is_nhwc) argMaxHWCFunc<float>(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
227+
else argMaxNCHWFunc<float>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
208228

209229
} else if (type == nvinfer1::DataType::kHALF) {
210230
const __half* ptr_inp = static_cast<const __half*>(inputs[0]);
211231
int32_t* ptr_out = static_cast<int32_t*>(outputs[0]);
212-
argMaxFunc<__half>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
213-
// cout << "type is: fp16" << endl;
232+
if (is_nhwc4) argMaxHWC4FP16Func(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
233+
else if (is_nhwc8) argMaxHWC8FP16Func(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
234+
else if (is_nhwc) argMaxHWCFunc<__half>(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
235+
else argMaxNCHWFunc<__half>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
214236

215237
} else if (type == nvinfer1::DataType::kBF16) {
216-
// cout << "type is: bf16" << endl;
217238
const __nv_bfloat16* ptr_inp = static_cast<const __nv_bfloat16*>(inputs[0]);
218239
int32_t* ptr_out = static_cast<int32_t*>(outputs[0]);
219-
argMaxFunc<__nv_bfloat16>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
240+
if (is_nhwc) argMaxHWCFunc<__nv_bfloat16>(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
241+
else argMaxNCHWFunc<__nv_bfloat16>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
220242

221243
} else if (type == nvinfer1::DataType::kINT8) {
222244
const int8_t* ptr_inp = static_cast<const int8_t*>(inputs[0]);
223245
int32_t* ptr_out = static_cast<int32_t*>(outputs[0]);
224-
argMaxFunc<int8_t>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
225-
// cout << "type is: int8" << endl;
246+
if (is_nhwc4) argMaxHWC4Int8Func(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
247+
else if (is_nhwc) argMaxHWCFunc<int8_t>(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
248+
else argMaxNCHWFunc<int8_t>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
226249

227250
} else {
228251
cout << "type is: other" << endl;
@@ -292,7 +315,7 @@ IPluginV3* ArgMaxPluginCreator::createPlugin(char const* name, PluginFieldCollec
292315
const string fieldName(fc->fields[i].name);
293316
if (fieldName == "dim")
294317
{
295-
mAxis = *static_cast<int32_t const*>(fc->fields[i].data);
318+
mAxis = *static_cast<int64_t const*>(fc->fields[i].data);
296319
}
297320
}
298321
auto plugin = new ArgMaxPlugin(mAxis);

0 commit comments

Comments (0)