Skip to content

Commit 6db53f0

Browse files
committed
begin metal implementation
1 parent 64d06d4 commit 6db53f0

11 files changed

Lines changed: 409 additions & 3 deletions

File tree

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
build/
2+
build-*/
23
dist/
34
*.o
45
*.obj
@@ -14,6 +15,13 @@ pipx_home/
1415
pipx_logs/
1516
t81lib.egg-info/
1617

18+
# CMake artifacts
19+
CMakeCache.txt
20+
CMakeFiles/
21+
cmake_install.cmake
22+
CTestTestfile.cmake
23+
Makefile
24+
1725
# Python runtime artifacts
1826
__pycache__/
1927
*.py[cod]

AGENTS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,5 @@ This file helps AI agents discover and understand how to work with this reposito
6363
- Hardened the SIMD detection helpers in `include/t81/core/detail/simd.hpp` with CPUID/xgetbv fallbacks, documented the `add_trytes_*` overflow semantics, and made NEON runtime checks opt-out via `T81_DISABLE_NEON`.
6464
- Added the `compression-first` GGUF export profile (metadata + CLI flags), plus `scripts/gguf_benchmark.py` and CLI docs that walk FP16 to ternary GGUF before/after measurements.
6565
- Added `examples/ternary_phi3_ptq_qat_demo.ipynb` to showcase Phi-3-mini PTQ/QAT size, latency, and perplexity comparisons in one compact notebook.
66+
- Added Metal pack/quantize kernels (`src/linalg/pack_kernel.metal`, `src/linalg/pack_metal.mm`) plus `include/t81/linalg/pack_gpu.hpp` and Python binding dispatch so PTQ packing can run on Apple Metal when enabled.
67+
- Documented GGUF helper APIs (`read_gguf`, `repack_gguf`, `dequantize_gguf`) plus the experimental TQ1_1 note in the GGUF and Python docs.

docs/python-api.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ This page is the landing spot for the auto-generated Python reference. It is pro
1313
| `t81.convert` / `t81.gguf` | call the conversion/GGUF helpers programmatically | `from t81 import convert, gguf` |
1414
| `t81.hardware` | explore ternary hardware emulation helpers | `from t81 import hardware` |
1515

16+
## GGUF helpers (quick reference)
17+
18+
The `t81.gguf` module exposes streaming and compatibility utilities beyond the CLI wrappers:
19+
20+
- `t81.gguf.write_gguf` to emit GGUF bundles from converted models.
21+
- `t81.gguf.read_gguf` to stream tensor payloads and metadata without loading the full file.
22+
- `t81.gguf.repack_gguf` to re-quantize existing float16/float32 GGUF bundles into TQ1_0/TQ2_0.
23+
- `t81.gguf.dequantize_gguf` (plus `t81.dequantize_gguf_to_float`) to rewrite ternary bundles into float GGUF files for broader runtime compatibility.
24+
1625
## Generating the docs
1726

1827
1. Install the tooling (ideally in a virtual environment):

docs/python-cookbook.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,22 @@ t81-gguf --input model.t81 --validate
6363
```
6464

6565
This recipe shows how Python experiments (scripts, notebooks) complement the CLI docs in `docs/references/cli-usage.md`.
66+
67+
## 4. Inspect, repack, or dequantize GGUF bundles in Python
68+
69+
```python
70+
import numpy as np
71+
from t81 import gguf
72+
73+
# Stream metadata and tensors without loading the full file into RAM.
74+
payload, metadata = gguf.read_gguf("model-tq1.gguf", return_metadata=True)
75+
print(metadata.get("general.architecture"))
76+
77+
# Repack a float GGUF into ternary (float tensors only).
78+
gguf.repack_gguf("model-f16.gguf", "model-tq1.gguf", quant="TQ1_0", threshold=0.45)
79+
80+
# Convert a ternary GGUF back to float for runtimes without TQ support.
81+
gguf.dequantize_gguf("model-tq1.gguf", "model-f16.gguf", dtype=np.float16)
82+
```
83+
84+
Use `dequantize_gguf_to_float` when you always want float32 output, and set `T81_ENABLE_TQ1_1=1` before using the experimental `tq1_1-draft` profile.

docs/references/gguf.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,13 @@ t81 convert meta-llama/Llama-3.2-3B-Instruct llama3.2-3b-t81.gguf \
1818
### Export profiles
1919

2020
For a no-knobs compression-first export, use the `compression-first` profile via the CLI (`--gguf-profile` or `--profile`). It stamps `t81.profile=compression-first` in metadata and pins the GGUF quant scheme to TQ1_0 for maximum compression.
21+
22+
### Experimental TQ1_1 profile
23+
24+
`tq1_1-draft` is available for header-size testing only. It requires `T81_ENABLE_TQ1_1=1` and writes payloads that are not yet loadable by llama.cpp, so use it for experiments rather than production GGUF bundles.
25+
26+
### Repacking + dequantizing existing GGUF files
27+
28+
`t81.gguf.repack_gguf` re-quantizes an existing GGUF file (float tensors only) and preserves the metadata, so you can take a float32 or float16 bundle and emit a ternary one without running the full conversion pipeline. For compatibility with runtimes that do not support ternary types, `t81.gguf.dequantize_gguf` (and the convenience `t81.dequantize_gguf_to_float`) converts TQ1_0/TQ2_0 payloads into float32 or float16 GGUF files.
29+
30+
If you need to inspect a GGUF without loading everything into RAM, `t81.gguf.read_gguf` streams metadata and tensor payloads from the file handle and can return raw bytes instead of dequantized tensors.

include/t81/linalg/pack_gpu.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
#include <cstdint>
5+
#include <span>
6+
7+
#ifndef T81LIB_USE_METAL
8+
#define T81LIB_USE_METAL 0
9+
#endif
10+
11+
// Host-callable entry points for the Metal pack/quantize path.
// NOTE(review): only declarations are visible here; the definitions presumably
// live in src/linalg/pack_metal.mm — confirm.
namespace t81::linalg::detail {
12+
// The declarations below are compiled only when T81LIB_USE_METAL is nonzero;
// this header defaults the macro to 0 when the build does not define it.
13+
#if T81LIB_USE_METAL
// Quantizes the floats in `src` into `dst` using `threshold`.
// NOTE(review): presumably writes one trit (-1/0/+1) per input element, matching
// the CPU fallback loop in python/bindings.cpp, whose caller passes equally sized
// spans and treats a thrown std::exception as "fall back to the CPU path" —
// confirm both against the definition.
14+
void metal_quantize_to_trits(std::span<const float> src,
15+
std::span<std::int8_t> dst,
16+
float threshold);
17+
// Packs a `rows` x `cols` float matrix from `src` into bytes in `dst`.
// NOTE(review): the visible caller passes rows * cols floats in `src` and the
// whole packed output array in `dst`; the required `dst` size and the byte
// layout are not stated here — confirm against the definition.
18+
void metal_pack_dense_matrix(std::span<const float> src,
19+
std::span<std::uint8_t> dst,
20+
int rows,
21+
int cols,
22+
float threshold);
23+
#endif
24+
25+
} // namespace t81::linalg::detail

python/CMakeLists.txt

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ endif()
6464

6565
if(USE_METAL)
6666
target_sources(t81lib_python PRIVATE
67-
${CMAKE_CURRENT_SOURCE_DIR}/../src/linalg/gemm_metal.mm)
67+
${CMAKE_CURRENT_SOURCE_DIR}/../src/linalg/gemm_metal.mm
68+
${CMAKE_CURRENT_SOURCE_DIR}/../src/linalg/pack_metal.mm)
6869

6970
set(METAL_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/../src/linalg/gemm_kernel.metal)
7071
set(METAL_AIR ${CMAKE_CURRENT_BINARY_DIR}/gemm_kernel.air)
@@ -79,6 +80,20 @@ if(USE_METAL)
7980
add_custom_target(gemm_metal_shader DEPENDS ${METAL_LIB})
8081
add_dependencies(t81lib_python gemm_metal_shader)
8182
target_compile_definitions(t81lib_python PRIVATE GEMM_METAL_LIBRARY_PATH=\"${METAL_LIB}\")
83+
84+
set(PACK_METAL_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/../src/linalg/pack_kernel.metal)
85+
set(PACK_METAL_AIR ${CMAKE_CURRENT_BINARY_DIR}/pack_kernel.air)
86+
set(PACK_METAL_LIB ${CMAKE_CURRENT_BINARY_DIR}/pack_kernel.metallib)
87+
88+
add_custom_command(OUTPUT ${PACK_METAL_LIB}
89+
COMMAND xcrun metal -c -o ${PACK_METAL_AIR} ${PACK_METAL_SOURCE}
90+
COMMAND xcrun metallib -o ${PACK_METAL_LIB} ${PACK_METAL_AIR}
91+
DEPENDS ${PACK_METAL_SOURCE}
92+
COMMENT "Compiling Metal pack shader")
93+
94+
add_custom_target(pack_metal_shader DEPENDS ${PACK_METAL_LIB})
95+
add_dependencies(t81lib_python pack_metal_shader)
96+
target_compile_definitions(t81lib_python PRIVATE PACK_METAL_LIBRARY_PATH=\"${PACK_METAL_LIB}\")
8297
endif()
8398
8499
set_target_properties(t81lib_python PROPERTIES

python/bindings.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <t81/io/format.hpp>
2525
#include <t81/linalg/gemm.hpp>
2626
#include <t81/linalg/gemm_gpu.hpp>
27+
#include <t81/linalg/pack_gpu.hpp>
2728
#include <t81/tensor_metadata.hpp>
2829
#include <t81/sparse/simple.hpp>
2930
#include <t81/t81lib.hpp>
@@ -426,6 +427,17 @@ namespace {
426427
const std::size_t total = static_cast<std::size_t>(std::max<py::ssize_t>(info.size, 0));
427428
const auto src = static_cast<const float *>(info.ptr);
428429
const auto dst = static_cast<std::int8_t *>(output.request().ptr);
430+
#if T81LIB_USE_METAL
431+
if (t81::linalg::detail::metal_available()) {
432+
try {
433+
std::span<const float> src_span{src, total};
434+
std::span<std::int8_t> dst_span{dst, total};
435+
t81::linalg::detail::metal_quantize_to_trits(src_span, dst_span, threshold);
436+
return output;
437+
} catch (const std::exception &) {
438+
}
439+
}
440+
#endif
429441
for (std::size_t index = 0; index < total; ++index) {
430442
dst[index] = quantize_trit(src[index], threshold);
431443
}
@@ -463,8 +475,25 @@ namespace {
463475
py::array_t<std::uint8_t> packed(
464476
{static_cast<std::size_t>(rows), static_cast<std::size_t>(limbs_per_row), limb_bytes});
465477
const auto *src = static_cast<const float *>(info.ptr);
466-
auto *dst = static_cast<std::uint8_t *>(packed.request().ptr);
478+
auto packed_info = packed.request(true);
479+
auto *dst = static_cast<std::uint8_t *>(packed_info.ptr);
467480
const std::size_t row_stride = static_cast<std::size_t>(limbs_per_row) * limb_bytes;
481+
#if T81LIB_USE_METAL
482+
if (t81::linalg::detail::metal_available()) {
483+
try {
484+
const std::size_t total_src =
485+
static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols);
486+
const std::size_t total_dst =
487+
static_cast<std::size_t>(std::max<py::ssize_t>(packed_info.size, 0));
488+
std::span<const float> src_span{src, total_src};
489+
std::span<std::uint8_t> dst_span{dst, total_dst};
490+
t81::linalg::detail::metal_pack_dense_matrix(src_span, dst_span,
491+
rows, cols, threshold);
492+
return packed;
493+
} catch (const std::exception &) {
494+
}
495+
}
496+
#endif
468497
for (int row = 0; row < rows; ++row) {
469498
const auto *row_ptr =
470499
src + static_cast<std::size_t>(row) * static_cast<std::size_t>(cols);

src/linalg/pack_kernel.metal

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#include <metal_stdlib>
2+
using namespace metal;
3+
4+
// Uniforms read by pack_dense_matrix_kernel through its `constant` buffer(2) argument.
// NOTE(review): the host side must supply a struct with this exact field order
// and field widths — presumably done in src/linalg/pack_metal.mm; confirm.
struct PackParams {
5+
uint rows;            // Matrix rows; the kernel handles rows * limbs_per_row limbs in total.
6+
uint cols;            // Columns per row: used as the row stride into src and as the valid-column bound.
7+
uint limbs_per_row;   // Limbs emitted per row; splits the thread index into (row, limb).
8+
uint trits_per_limb;  // Column distance between the starts of consecutive limbs.
9+
uint limb_bytes;      // Output bytes per limb; each byte encodes three consecutive columns.
// NOTE(review): the kernel always reads 3 * limb_bytes columns per limb, whatever
// trits_per_limb says — presumably trits_per_limb == 3 * limb_bytes; confirm.
10+
float threshold;      // Forwarded unchanged to quantize_trit.
11+
};
12+
13+
// Uniforms read by quantize_trits_kernel through its `constant` buffer(2) argument.
// NOTE(review): the host side must supply a struct with this exact field order
// and field widths — presumably done in src/linalg/pack_metal.mm; confirm.
struct QuantParams {
14+
uint count;       // Number of elements to process; threads with an index >= count do nothing.
15+
float threshold;  // Forwarded unchanged to quantize_trit.
16+
};
17+
18+
// Maps a float onto a trit (-1, 0 or +1).
//
// The input is first clamped to [-1, 1]. The result is +1 when the clamped
// value is at or above `threshold`, -1 when it is at or below `-threshold`,
// and 0 otherwise. The positive comparison is evaluated first, so it wins
// whenever both would match (for example value 0 with threshold 0 yields +1).
static inline int quantize_trit(float value, float threshold) {
    const float bounded = clamp(value, -1.0f, 1.0f);
    return (bounded >= threshold) ? 1
                                  : ((bounded <= -threshold) ? -1 : 0);
}
28+
29+
// Element-wise quantization: thread `gid` converts src[gid] into a trit and
// stores it in dst[gid] as a signed byte.
//
//   src    (buffer 0)  input floats
//   dst    (buffer 1)  output trits, one char per element
//   params (buffer 2)  element count and quantization threshold
//
// Threads whose index is at or beyond params.count write nothing.
kernel void quantize_trits_kernel(device const float *src [[buffer(0)]],
                                  device char *dst [[buffer(1)]],
                                  constant QuantParams &params [[buffer(2)]],
                                  uint gid [[thread_position_in_grid]]) {
    if (gid < params.count) {
        dst[gid] = static_cast<char>(quantize_trit(src[gid], params.threshold));
    }
}
39+
40+
// Quantizes and packs a row-major float matrix, one limb per thread.
//
//   src    (buffer 0)  rows * cols input floats, row stride = params.cols
//   dst    (buffer 1)  packed output, params.limb_bytes bytes per limb
//   params (buffer 2)  matrix shape, limb geometry and threshold
//
// Thread `gid` owns limb (gid % limbs_per_row) of row (gid / limbs_per_row).
// Each output byte holds three consecutive columns as balanced-ternary digits
// t0 + 3*t1 + 9*t2, shifted by +13 so the stored value lies in 0..26. Columns
// at or beyond params.cols contribute a zero trit.
kernel void pack_dense_matrix_kernel(device const float *src [[buffer(0)]],
                                     device uchar *dst [[buffer(1)]],
                                     constant PackParams &params [[buffer(2)]],
                                     uint gid [[thread_position_in_grid]]) {
    // Bounds check comes first so the divisions below never see a zero divisor
    // when there are no limbs at all.
    if (gid >= params.rows * params.limbs_per_row) {
        return;
    }

    const uint row_start = (gid / params.limbs_per_row) * params.cols;
    const uint first_col = (gid % params.limbs_per_row) * params.trits_per_limb;
    // gid == row * limbs_per_row + limb, so this is the limb's first output byte.
    const uint out_base = gid * params.limb_bytes;

    for (uint byte_idx = 0; byte_idx < params.limb_bytes; ++byte_idx) {
        int packed = 13;  // bias that maps the balanced range -13..13 onto 0..26
        int weight = 1;   // 1, 3, 9 for the three trits of this byte
        uint col = first_col + byte_idx * 3u;

        for (uint digit = 0; digit < 3u; ++digit) {
            if (col < params.cols) {
                packed += weight * quantize_trit(src[row_start + col], params.threshold);
            }
            ++col;
            weight *= 3;
        }

        dst[out_base + byte_idx] = static_cast<uchar>(packed);
    }
}

0 commit comments

Comments
 (0)