Skip to content

Commit f9f5f78

Browse files
committed
Gemm update
1 parent 277a1f7 commit f9f5f78

4 files changed

Lines changed: 86 additions & 66 deletions

File tree

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,17 @@ pip install ".[torch]"
7878

7979
On macOS or other PEP 668-enforced environments, activate a virtualenv before running `pip install ".[torch]"` (or use `python3 -m pip install --user ".[torch]" --break-system-packages` if you understand the risks) so pip can install the extra dependencies without hitting the “externally managed environment” error.
8080

81+
### 2a. CLI-friendly Pipx install
82+
83+
If you prefer shell-level access to `t81-convert`, `t81-gguf`, `t81-qat`, and `t81-dequant`, pipx can install the repo and then inject the torch extras:
84+
85+
```bash
86+
pipx install --python python3 /path/to/your/t81lib/checkout
87+
pipx inject t81lib torch transformers accelerate datasets safetensors
88+
```
89+
90+
Pipx doesn’t understand `.[torch]` when pointing at a local directory, so we first install the package from source and then inject the optional dependencies you need (torch, transformers, accelerate, datasets, safetensors). Once that completes, the CLI helpers will run from `~/.local/bin` with the same requirements as `pip install ".[torch]"`. If you later upgrade the repo checkout, run `pipx uninstall t81lib`, then repeat the install and inject steps.
91+
8192
### 3. Consume as a subproject
8293

8394
```cmake

include/t81/linalg/gemm.hpp

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -63,71 +63,6 @@ namespace t81::linalg {
6363
return low_value + high_value * radix;
6464
}
6565

66-
inline void gemm_ternary_cpu_impl(std::span<const core::limb> A,
67-
std::span<const core::limb> B,
68-
std::span<float> C,
69-
int M,
70-
int N,
71-
int K,
72-
int K_limbs,
73-
float alpha,
74-
float beta) {
75-
if (M == 0 || N == 0) {
76-
return;
77-
}
78-
79-
constexpr int BlockM = 8;
80-
constexpr int BlockN = 8;
81-
constexpr int BlockK = 4;
82-
const std::size_t N_size = static_cast<std::size_t>(N);
83-
const auto *const a_data = A.data();
84-
const auto *const b_data = B.data();
85-
auto *const c_data = C.data();
86-
87-
for (int ib = 0; ib < M; ib += BlockM) {
88-
const int i_end = std::min(M, ib + BlockM);
89-
for (int jb = 0; jb < N; jb += BlockN) {
90-
const int j_end = std::min(N, jb + BlockN);
91-
std::array<std::array<double, BlockN>, BlockM> accum{};
92-
for (int i = ib; i < i_end; ++i) {
93-
const std::size_t row = static_cast<std::size_t>(i) * N_size;
94-
for (int j = jb; j < j_end; ++j) {
95-
const float existing = c_data[row + static_cast<std::size_t>(j)];
96-
accum[i - ib][j - jb] = static_cast<double>(existing) * beta;
97-
}
98-
}
99-
100-
for (int kb = 0; kb < K_limbs; kb += BlockK) {
101-
const int k_end = std::min(K_limbs, kb + BlockK);
102-
for (int k = kb; k < k_end; ++k) {
103-
const std::size_t b_row = static_cast<std::size_t>(k) * N_size;
104-
for (int j = jb; j < j_end; ++j) {
105-
const core::limb b_value = b_data[b_row + static_cast<std::size_t>(j)];
106-
detail::prefetch_read(b_data + b_row + static_cast<std::size_t>(j) + 1);
107-
for (int i = ib; i < i_end; ++i) {
108-
const std::size_t a_index = static_cast<std::size_t>(i) *
109-
static_cast<std::size_t>(K_limbs) +
110-
static_cast<std::size_t>(k);
111-
const core::limb a_value = a_data[a_index];
112-
const double product = detail::multiply_to_double(a_value, b_value);
113-
accum[i - ib][j - jb] += product * static_cast<double>(alpha);
114-
detail::prefetch_read(a_data + a_index + 1);
115-
}
116-
}
117-
}
118-
}
119-
120-
for (int i = ib; i < i_end; ++i) {
121-
const std::size_t row = static_cast<std::size_t>(i) * N_size;
122-
for (int j = jb; j < j_end; ++j) {
123-
c_data[row + static_cast<std::size_t>(j)] =
124-
static_cast<float>(accum[i - ib][j - jb]);
125-
}
126-
}
127-
}
128-
}
129-
}
130-
13166
} // namespace detail
13267

13368
inline void gemm_ternary(std::span<const core::limb> A,

python/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ pybind11_add_module(t81lib_python MODULE bindings.cpp)
2323

2424
target_sources(t81lib_python PRIVATE
2525
${CMAKE_CURRENT_SOURCE_DIR}/../src/t81/core/gguf_quants.cpp
26-
${CMAKE_CURRENT_SOURCE_DIR}/../src/linalg/gemm_dispatch.cpp)
26+
${CMAKE_CURRENT_SOURCE_DIR}/../src/linalg/gemm_dispatch.cpp
27+
${CMAKE_CURRENT_SOURCE_DIR}/../src/linalg/gemm_cpu.cpp)
2728
target_compile_features(t81lib_python PRIVATE cxx_std_20)
2829

2930
#target_link_libraries(t81lib_python PRIVATE t81lib)

src/linalg/gemm_cpu.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#include <algorithm>
2+
#include <array>
3+
4+
#include "t81/linalg/gemm.hpp"
5+
6+
namespace t81::linalg::detail {
7+
8+
void gemm_ternary_cpu_impl(std::span<const core::limb> A,
9+
std::span<const core::limb> B,
10+
std::span<float> C,
11+
int M,
12+
int N,
13+
int K,
14+
int K_limbs,
15+
float alpha,
16+
float beta) {
17+
if (M == 0 || N == 0) {
18+
return;
19+
}
20+
21+
constexpr int BlockM = 8;
22+
constexpr int BlockN = 8;
23+
constexpr int BlockK = 4;
24+
const std::size_t N_size = static_cast<std::size_t>(N);
25+
const auto *const a_data = A.data();
26+
const auto *const b_data = B.data();
27+
auto *const c_data = C.data();
28+
29+
for (int ib = 0; ib < M; ib += BlockM) {
30+
const int i_end = std::min(M, ib + BlockM);
31+
for (int jb = 0; jb < N; jb += BlockN) {
32+
const int j_end = std::min(N, jb + BlockN);
33+
std::array<std::array<double, BlockN>, BlockM> accum{};
34+
for (int i = ib; i < i_end; ++i) {
35+
const std::size_t row = static_cast<std::size_t>(i) * N_size;
36+
for (int j = jb; j < j_end; ++j) {
37+
const float existing = c_data[row + static_cast<std::size_t>(j)];
38+
accum[i - ib][j - jb] = static_cast<double>(existing) * beta;
39+
}
40+
}
41+
42+
for (int kb = 0; kb < K_limbs; kb += BlockK) {
43+
const int k_end = std::min(K_limbs, kb + BlockK);
44+
for (int k = kb; k < k_end; ++k) {
45+
const std::size_t b_row = static_cast<std::size_t>(k) * N_size;
46+
for (int j = jb; j < j_end; ++j) {
47+
const core::limb b_value = b_data[b_row + static_cast<std::size_t>(j)];
48+
detail::prefetch_read(b_data + b_row + static_cast<std::size_t>(j) + 1);
49+
for (int i = ib; i < i_end; ++i) {
50+
const std::size_t a_index = static_cast<std::size_t>(i) *
51+
static_cast<std::size_t>(K_limbs) +
52+
static_cast<std::size_t>(k);
53+
const core::limb a_value = a_data[a_index];
54+
const double product = detail::multiply_to_double(a_value, b_value);
55+
accum[i - ib][j - jb] += product * static_cast<double>(alpha);
56+
detail::prefetch_read(a_data + a_index + 1);
57+
}
58+
}
59+
}
60+
}
61+
62+
for (int i = ib; i < i_end; ++i) {
63+
const std::size_t row = static_cast<std::size_t>(i) * N_size;
64+
for (int j = jb; j < j_end; ++j) {
65+
c_data[row + static_cast<std::size_t>(j)] =
66+
static_cast<float>(accum[i - ib][j - jb]);
67+
}
68+
}
69+
}
70+
}
71+
}
72+
73+
} // namespace t81::linalg::detail

0 commit comments

Comments
 (0)