Skip to content

Commit 37127f0

Browse files
committed
gemm_gpu update
1 parent 035938c commit 37127f0

2 files changed

Lines changed: 21 additions & 3 deletions

File tree

include/t81/linalg/gemm_gpu.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ inline constexpr bool backend_available(Backend backend) noexcept {
5757
return false;
5858
}
5959

60+
#if !defined(T81LIB_DOXYGEN)
61+
Backend get_current_backend() noexcept;
62+
#endif
63+
6064
#if T81LIB_USE_METAL
6165
bool metal_available() noexcept;
6266
#endif
@@ -105,6 +109,15 @@ void addcmul(const TensorMetadata &input,
105109
TensorMetadata &out,
106110
Backend backend = Backend::Auto);
107111

112+
#if !defined(T81LIB_DOXYGEN)
113+
void gemm_ternary(const TensorMetadata &A,
114+
const TensorMetadata &B,
115+
TensorMetadata &C,
116+
float alpha,
117+
float beta,
118+
Backend backend = Backend::Auto);
119+
#endif
120+
108121
#if T81LIB_USE_CUDA
109122
void cuda_gemm_ternary(std::span<const core::limb> A,
110123
std::span<const core::limb> B,
@@ -154,7 +167,7 @@ inline void gemm_ternary_dispatch(std::span<const core::limb> A,
154167
int K,
155168
float alpha,
156169
float beta,
157-
Backend backend) {
170+
Backend backend = Backend::Auto) {
158171
if (M < 0 || N < 0 || K < 0) {
159172
throw std::invalid_argument("gemm_ternary dimensions must be non-negative");
160173
}

python/bindings.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,14 +756,19 @@ PYBIND11_MODULE(t81lib, module) {
756756
const int K_limbs = K / core::limb::TRITS;
757757
const auto a_handle = handle_for_packed_object(A_obj, M, K_limbs);
758758
const auto b_handle = handle_for_packed_object(B_obj, K_limbs, N);
759-
const auto c_handle = extract_tensor_handle(C_obj);
759+
auto c_handle = extract_tensor_handle(C_obj);
760760
if (static_cast<int>(c_handle.meta.sizes.size()) != 2 ||
761761
static_cast<int>(c_handle.meta.sizes[0]) != M ||
762762
static_cast<int>(c_handle.meta.sizes[1]) != N) {
763763
throw py::value_error("C tensor shape must match (M, N)");
764764
}
765765
t81::linalg::detail::gemm_ternary(
766-
a_handle.meta, b_handle.meta, c_handle.meta, alpha, beta, Backend::Auto);
766+
a_handle.meta,
767+
b_handle.meta,
768+
c_handle.meta,
769+
alpha,
770+
beta,
771+
t81::linalg::Backend::Auto);
767772
},
768773
py::arg("A"),
769774
py::arg("B"),

0 commit comments

Comments
 (0)