|
1 | 1 | #include "../../../devices/metax/metax_common.h" |
2 | 2 | #include "../../../devices/metax/metax_handle.h" |
3 | 3 | #include "../../../devices/metax/metax_kernel_common.h" |
| 4 | + |
4 | 5 | #include "binary_cross_entropy_with_logits_metax.h" |
5 | | -#include <mc_runtime.h> |
| 6 | + |
6 | 7 | #include <type_traits> |
7 | 8 |
|
8 | 9 | namespace op::bce_with_logits::metax { |
@@ -191,7 +192,7 @@ infiniStatus_t Descriptor::calculate( |
191 | 192 | const void *pos_weight, |
192 | 193 | void *stream) const { |
193 | 194 |
|
194 | | - mcStream_t custream = (mcStream_t)stream; |
| 195 | + hcStream_t custream = (hcStream_t)stream; |
195 | 196 | size_t n = _info.num_elements; |
196 | 197 |
|
197 | 198 | // F16/BF16 + 归约需要 float workspace |
@@ -219,7 +220,7 @@ infiniStatus_t Descriptor::calculate( |
219 | 220 | case INFINI_DTYPE_F32: { |
220 | 221 | // 如果是规约操作,计算前需将输出位置清零 |
221 | 222 | if (_reduction != INFINIOP_REDUCTION_NONE) { |
222 | | - mcMemsetAsync(out, 0, sizeof(float), custream); |
| 223 | + hcMemsetAsync(out, 0, sizeof(float), custream); |
223 | 224 | } |
224 | 225 |
|
225 | 226 | bce_logits_kernel<float, float><<<grid, block, 0, custream>>>( |
@@ -255,7 +256,7 @@ infiniStatus_t Descriptor::calculate( |
255 | 256 | out_raw = out; |
256 | 257 | } else { |
257 | 258 | workspace_f = static_cast<float *>(workspace); |
258 | | - mcMemsetAsync(workspace_f, 0, sizeof(float), custream); |
| 259 | + hcMemsetAsync(workspace_f, 0, sizeof(float), custream); |
259 | 260 | out_raw = workspace_f; |
260 | 261 | } |
261 | 262 |
|
@@ -294,7 +295,7 @@ infiniStatus_t Descriptor::calculate( |
294 | 295 | out_raw = out; |
295 | 296 | } else { |
296 | 297 | workspace_f = static_cast<float *>(workspace); |
297 | | - mcMemsetAsync(workspace_f, 0, sizeof(float), custream); |
| 298 | + hcMemsetAsync(workspace_f, 0, sizeof(float), custream); |
298 | 299 | out_raw = workspace_f; |
299 | 300 | } |
300 | 301 |
|
@@ -324,8 +325,8 @@ infiniStatus_t Descriptor::calculate( |
324 | 325 | return INFINI_STATUS_BAD_TENSOR_DTYPE; |
325 | 326 | } |
326 | 327 |
|
327 | | - mcError_t err = mcGetLastError(); |
328 | | - if (err != mcSuccess) { |
| 328 | + hcError_t err = hcGetLastError(); |
| 329 | + if (err != hcSuccess) { |
329 | 330 | return INFINI_STATUS_INTERNAL_ERROR; |
330 | 331 | } |
331 | 332 | return INFINI_STATUS_SUCCESS; |
|
0 commit comments