|
24 | 24 | #include <t81/io/format.hpp> |
25 | 25 | #include <t81/linalg/gemm.hpp> |
26 | 26 | #include <t81/linalg/gemm_gpu.hpp> |
| 27 | +#include <t81/linalg/pack_gpu.hpp> |
27 | 28 | #include <t81/tensor_metadata.hpp> |
28 | 29 | #include <t81/sparse/simple.hpp> |
29 | 30 | #include <t81/t81lib.hpp> |
@@ -426,6 +427,17 @@ namespace { |
426 | 427 | const std::size_t total = static_cast<std::size_t>(std::max<py::ssize_t>(info.size, 0)); |
427 | 428 | const auto src = static_cast<const float *>(info.ptr); |
428 | 429 | const auto dst = static_cast<std::int8_t *>(output.request().ptr); |
| 430 | +#if T81LIB_USE_METAL |
| 431 | + if (t81::linalg::detail::metal_available()) { |
| 432 | + try { |
| 433 | + std::span<const float> src_span{src, total}; |
| 434 | + std::span<std::int8_t> dst_span{dst, total}; |
| 435 | + t81::linalg::detail::metal_quantize_to_trits(src_span, dst_span, threshold); |
| 436 | + return output; |
| 437 | + } catch (const std::exception &) { |
| 438 | + } |
| 439 | + } |
| 440 | +#endif |
429 | 441 | for (std::size_t index = 0; index < total; ++index) { |
430 | 442 | dst[index] = quantize_trit(src[index], threshold); |
431 | 443 | } |
@@ -463,8 +475,25 @@ namespace { |
463 | 475 | py::array_t<std::uint8_t> packed( |
464 | 476 | {static_cast<std::size_t>(rows), static_cast<std::size_t>(limbs_per_row), limb_bytes}); |
465 | 477 | const auto *src = static_cast<const float *>(info.ptr); |
466 | | - auto *dst = static_cast<std::uint8_t *>(packed.request().ptr); |
| 478 | + auto packed_info = packed.request(true); |
| 479 | + auto *dst = static_cast<std::uint8_t *>(packed_info.ptr); |
467 | 480 | const std::size_t row_stride = static_cast<std::size_t>(limbs_per_row) * limb_bytes; |
| 481 | +#if T81LIB_USE_METAL |
| 482 | + if (t81::linalg::detail::metal_available()) { |
| 483 | + try { |
| 484 | + const std::size_t total_src = |
| 485 | + static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols); |
| 486 | + const std::size_t total_dst = |
| 487 | + static_cast<std::size_t>(std::max<py::ssize_t>(packed_info.size, 0)); |
| 488 | + std::span<const float> src_span{src, total_src}; |
| 489 | + std::span<std::uint8_t> dst_span{dst, total_dst}; |
| 490 | + t81::linalg::detail::metal_pack_dense_matrix(src_span, dst_span, |
| 491 | + rows, cols, threshold); |
| 492 | + return packed; |
| 493 | + } catch (const std::exception &) { |
| 494 | + } |
| 495 | + } |
| 496 | +#endif |
468 | 497 | for (int row = 0; row < rows; ++row) { |
469 | 498 | const auto *row_ptr = |
470 | 499 | src + static_cast<std::size_t>(row) * static_cast<std::size_t>(cols); |
|
0 commit comments