|
6 | 6 | * LICENSE file in the root directory of this source tree. |
7 | 7 | */ |
8 | 8 |
|
| 9 | +#include <executorch/backends/cadence/generic/operators/op_quantized_conv1d_ncl.h> |
9 | 10 | #include <executorch/backends/cadence/hifi/kernels/kernels.h> |
10 | 11 | #include <executorch/backends/cadence/hifi/operators/operators.h> |
11 | 12 | #include <executorch/runtime/kernel/kernel_includes.h> |
@@ -58,9 +59,9 @@ void xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( |
58 | 59 | WORD32 kernel_width = weight.size(2); |
59 | 60 | WORD32 out_width = out.size(2); |
60 | 61 | WORD32 out_height = 1; |
61 | | - WORD32 x_stride = stride[1]; |
| 62 | + WORD32 x_stride = 1; |
62 | 63 | WORD32 y_stride = stride[0]; |
63 | | - WORD32 x_padding = padding[1]; |
| 64 | + WORD32 x_padding = 0; |
64 | 65 | WORD32 y_padding = padding[0]; |
65 | 66 | WORD32 dilation_height = 1; |
66 | 67 | WORD32 dilation_width = 1; |
@@ -236,8 +237,8 @@ void xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( |
236 | 237 | WORD32 kernel_width = weight.size(2); |
237 | 238 | WORD32 out_width = out.size(2); |
238 | 239 | WORD32 out_height = 1; |
239 | | - WORD32 x_stride = stride[1]; |
240 | | - WORD32 x_padding = padding[1]; |
| 240 | + WORD32 x_stride = stride[0]; |
| 241 | + WORD32 x_padding = padding[0]; |
241 | 242 | WORD32 input_zero_bias = -in_zero_point; |
242 | 243 | WORD32 out_multiplier32 = bias_scale * (1. / output_scale) * 2147483648; |
243 | 244 | WORD32 out_shift32 = 0; |
@@ -345,46 +346,114 @@ void quantized_conv1d_ncl_per_tensor_out( |
345 | 346 | const Tensor& bias, |
346 | 347 | IntArrayRef stride, |
347 | 348 | IntArrayRef padding, |
348 | | - __ET_UNUSED IntArrayRef dilation, |
349 | | - __ET_UNUSED int64_t groups, |
| 349 | + IntArrayRef dilation, |
| 350 | + int64_t groups, |
350 | 351 | int64_t in_zero_point, |
351 | 352 | int64_t weight_zero_point, |
352 | 353 | double bias_scale, |
353 | 354 | double output_scale, |
354 | 355 | int64_t output_zero_point, |
355 | | - __ET_UNUSED int64_t out_multiplier, |
356 | | - __ET_UNUSED int64_t out_shift, |
| 356 | + int64_t out_multiplier, |
| 357 | + int64_t out_shift, |
357 | 358 | Tensor& out) { |
358 | | - ScalarType dtype = out.scalar_type(); |
359 | | - |
360 | | - if (dtype == ScalarType::Char) { |
361 | | - xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( |
| 359 | + // HiFi nnlib kernels only support dilation=1. |
| 360 | + // Fall back to generic implementation for dilation > 1. |
| 361 | + // Note: For 1D convolution, dilation is a single-element array. |
| 362 | + if (dilation[0] != 1) { |
| 363 | + impl::generic::native::quantized_conv1d_ncl_per_tensor_out( |
362 | 364 | ctx, |
363 | 365 | input, |
364 | 366 | weight, |
365 | 367 | bias, |
366 | 368 | stride, |
367 | 369 | padding, |
368 | | - static_cast<int32_t>(in_zero_point), |
369 | | - static_cast<int32_t>(weight_zero_point), |
370 | | - static_cast<float>(bias_scale), |
371 | | - static_cast<float>(output_scale), |
372 | | - static_cast<int32_t>(output_zero_point), |
| 370 | + dilation, |
| 371 | + groups, |
| 372 | + in_zero_point, |
| 373 | + weight_zero_point, |
| 374 | + bias_scale, |
| 375 | + output_scale, |
| 376 | + output_zero_point, |
| 377 | + out_multiplier, |
| 378 | + out_shift, |
373 | 379 | out); |
| 380 | + return; |
| 381 | + } |
| 382 | + |
| 383 | + ScalarType dtype = out.scalar_type(); |
| 384 | + |
| 385 | + if (dtype == ScalarType::Char) { |
| 386 | + // HiFi nnlib conv2d kernel produces incorrect results with stride > 1 |
| 387 | + // on some backends (e.g., Artemis HiFi4). Fall back to generic. |
| 388 | + if (stride[0] > 1) { |
| 389 | + impl::generic::native::quantized_conv1d_ncl_per_tensor_out( |
| 390 | + ctx, |
| 391 | + input, |
| 392 | + weight, |
| 393 | + bias, |
| 394 | + stride, |
| 395 | + padding, |
| 396 | + dilation, |
| 397 | + groups, |
| 398 | + in_zero_point, |
| 399 | + weight_zero_point, |
| 400 | + bias_scale, |
| 401 | + output_scale, |
| 402 | + output_zero_point, |
| 403 | + out_multiplier, |
| 404 | + out_shift, |
| 405 | + out); |
| 406 | + } else { |
| 407 | + xa_opt_quantized_conv1d_ncl_asym8sxsym8s_asym8s( |
| 408 | + ctx, |
| 409 | + input, |
| 410 | + weight, |
| 411 | + bias, |
| 412 | + stride, |
| 413 | + padding, |
| 414 | + static_cast<int32_t>(in_zero_point), |
| 415 | + static_cast<int32_t>(weight_zero_point), |
| 416 | + static_cast<float>(bias_scale), |
| 417 | + static_cast<float>(output_scale), |
| 418 | + static_cast<int32_t>(output_zero_point), |
| 419 | + out); |
| 420 | + } |
374 | 421 | } else if (dtype == ScalarType::Byte) { |
375 | | - xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( |
376 | | - ctx, |
377 | | - input, |
378 | | - weight, |
379 | | - bias, |
380 | | - stride, |
381 | | - padding, |
382 | | - static_cast<int32_t>(in_zero_point), |
383 | | - static_cast<int32_t>(weight_zero_point), |
384 | | - static_cast<float>(bias_scale), |
385 | | - static_cast<float>(output_scale), |
386 | | - static_cast<int32_t>(output_zero_point), |
387 | | - out); |
| 422 | + // HiFi nnlib conv1d_std kernel does not support depthwise (groups > 1). |
| 423 | + // Fall back to generic implementation. |
| 424 | + if (groups > 1) { |
| 425 | + impl::generic::native::quantized_conv1d_ncl_per_tensor_out( |
| 426 | + ctx, |
| 427 | + input, |
| 428 | + weight, |
| 429 | + bias, |
| 430 | + stride, |
| 431 | + padding, |
| 432 | + dilation, |
| 433 | + groups, |
| 434 | + in_zero_point, |
| 435 | + weight_zero_point, |
| 436 | + bias_scale, |
| 437 | + output_scale, |
| 438 | + output_zero_point, |
| 439 | + out_multiplier, |
| 440 | + out_shift, |
| 441 | + out); |
| 442 | + } else { |
| 443 | + xa_opt_quantized_conv1d_ncl_asym8uxsym8u_asym8u( |
| 444 | + ctx, |
| 445 | + input, |
| 446 | + weight, |
| 447 | + bias, |
| 448 | + stride, |
| 449 | + padding, |
| 450 | + static_cast<int32_t>(in_zero_point), |
| 451 | + static_cast<int32_t>(weight_zero_point), |
| 452 | + static_cast<float>(bias_scale), |
| 453 | + static_cast<float>(output_scale), |
| 454 | + static_cast<int32_t>(output_zero_point), |
| 455 | + out); |
| 456 | + } |
388 | 457 | } else { |
389 | 458 | ET_DCHECK_MSG( |
390 | 459 | false, |
|
0 commit comments