@@ -129,14 +129,27 @@ bool ArgMaxPlugin::supportsFormatCombination(
     ss << " ArgMaxPlugin accepts only two input, but here pos is " << pos;
     CHECK(pos < 2, ss.str());
 
-    bool typeOk = inOut[0].desc.type == DataType::kFLOAT;
-    typeOk = typeOk || inOut[0].desc.type == DataType::kHALF;
-    typeOk = typeOk || inOut[0].desc.type == DataType::kBF16;
-    typeOk = typeOk || inOut[0].desc.type == DataType::kINT8;
-    // here support int8, and enqueue() will recieve int8 input
-    // or it will drop back to float/half to call enqueue()
+    // pos=1 is the output tensor: always INT32 kLINEAR
+    if (pos == 1) {
+        return inOut[1].desc.type == DataType::kINT32
+            && inOut[1].desc.format == PluginFormat::kLINEAR;
+    }
+
+    // pos=0 is the input tensor
+    auto typ = inOut[0].desc.type;
+    auto fmt = inOut[0].desc.format;
+
+    bool typeOk = typ == DataType::kFLOAT
+        || typ == DataType::kHALF
+        || typ == DataType::kBF16
+        || typ == DataType::kINT8;
 
-    bool formatOK = inOut[0].desc.format == PluginFormat::kLINEAR;
+    bool formatOK = fmt == PluginFormat::kLINEAR
+        || fmt == PluginFormat::kHWC
+        || (fmt == PluginFormat::kHWC4 && typ == DataType::kINT8)
+        || (fmt == PluginFormat::kHWC4 && typ == DataType::kFLOAT)
+        || (fmt == PluginFormat::kHWC4 && typ == DataType::kHALF)
+        || (fmt == PluginFormat::kHWC8 && typ == DataType::kHALF);
 
     return formatOK && typeOk;
 }
@@ -170,7 +183,7 @@ int32_t ArgMaxPlugin::getOutputShapes(DimsExprs const* inputs, int32_t nbInputs,
 size_t ArgMaxPlugin::getWorkspaceSize(DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
     DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
 {
-    return sizeof(mAxis);
+    return 0;
 }
 
 
@@ -200,29 +213,39 @@ int32_t ArgMaxPlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc
 
     // cout << "type is: " << static_cast<int32_t>(type) << endl;
 
+    auto fmt = inputDesc[0].format;
+    bool is_nhwc = (fmt == nvinfer1::TensorFormat::kHWC);
+    bool is_nhwc4 = (fmt == nvinfer1::TensorFormat::kHWC4);
+    bool is_nhwc8 = (fmt == nvinfer1::TensorFormat::kHWC8);
+    int32_t n_spatial = n_size * m_size; // N * H * W, used by NHWC kernels
+
     if (type == nvinfer1::DataType::kFLOAT) {
         const float* ptr_inp = static_cast<const float*>(inputs[0]);
         int32_t* ptr_out = static_cast<int32_t*>(outputs[0]);
-        argMaxFunc<float>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
-        // cout << "type is: fp32" << endl;
+        if (is_nhwc4) argMaxHWC4FP32Func(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
+        else if (is_nhwc) argMaxHWCFunc<float>(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
+        else argMaxNCHWFunc<float>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
 
     } else if (type == nvinfer1::DataType::kHALF) {
         const __half* ptr_inp = static_cast<const __half*>(inputs[0]);
         int32_t* ptr_out = static_cast<int32_t*>(outputs[0]);
-        argMaxFunc<__half>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
-        // cout << "type is: fp16" << endl;
+        if (is_nhwc4) argMaxHWC4FP16Func(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
+        else if (is_nhwc8) argMaxHWC8FP16Func(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
+        else if (is_nhwc) argMaxHWCFunc<__half>(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
+        else argMaxNCHWFunc<__half>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
 
     } else if (type == nvinfer1::DataType::kBF16) {
-        // cout << "type is: bf16" << endl;
         const __nv_bfloat16* ptr_inp = static_cast<const __nv_bfloat16*>(inputs[0]);
         int32_t* ptr_out = static_cast<int32_t*>(outputs[0]);
-        argMaxFunc<__nv_bfloat16>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
+        if (is_nhwc) argMaxHWCFunc<__nv_bfloat16>(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
+        else argMaxNCHWFunc<__nv_bfloat16>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
 
     } else if (type == nvinfer1::DataType::kINT8) {
         const int8_t* ptr_inp = static_cast<const int8_t*>(inputs[0]);
         int32_t* ptr_out = static_cast<int32_t*>(outputs[0]);
-        argMaxFunc<int8_t>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
-        // cout << "type is: int8" << endl;
+        if (is_nhwc4) argMaxHWC4Int8Func(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
+        else if (is_nhwc) argMaxHWCFunc<int8_t>(ptr_inp, ptr_out, n_spatial, dimsize, &stream);
+        else argMaxNCHWFunc<int8_t>(ptr_inp, ptr_out, n_size, dimsize, m_size, &stream);
 
     } else {
         cout << " type is: other" << endl;
@@ -292,7 +315,7 @@ IPluginV3* ArgMaxPluginCreator::createPlugin(char const* name, PluginFieldCollection
         const string fieldName(fc->fields[i].name);
         if (fieldName == " dim")
         {
-            mAxis = *static_cast<int32_t const*>(fc->fields[i].data);
+            mAxis = *static_cast<int64_t const*>(fc->fields[i].data);
         }
     }
     auto plugin = new ArgMaxPlugin(mAxis);
0 commit comments