1 | 1 | #include "infinicore/ops/simple_gla_attention.hpp" |
2 | | - |
3 | | -#include "../../../utils.h" |
4 | 2 | #include "../../utils.hpp" |
5 | | -#include "infinicore/context/context.hpp" |
6 | | -#include <cmath> |
7 | | -#include <cstring> |
8 | | -#include <stdexcept> |
9 | | -#include <vector> |
10 | 3 |
11 | 4 | namespace infinicore::op { |
12 | 5 |
13 | | -namespace { |
14 | | - |
15 | | -// Read one element from tensor at flat index, convert to float. |
16 | | -template <typename T> |
17 | | -inline float read_float(const std::byte *ptr, size_t idx) { |
18 | | - return static_cast<float>(*reinterpret_cast<const T *>(ptr + idx * sizeof(T))); |
19 | | -} |
20 | | - |
21 | | -inline float read_float_at(const std::byte *ptr, size_t idx, DataType dtype) { |
22 | | - switch (dtype) { |
23 | | - case DataType::F32: |
24 | | - return read_float<float>(ptr, idx); |
25 | | - case DataType::F16: |
26 | | - return _f16_to_f32(*reinterpret_cast<const fp16_t *>(ptr + idx * 2)); |
27 | | - case DataType::BF16: |
28 | | - return _bf16_to_f32(*reinterpret_cast<const bf16_t *>(ptr + idx * 2)); |
29 | | - default: |
30 | | - throw std::runtime_error("simple_gla_attention: unsupported dtype (need F32, F16, or BF16)"); |
31 | | - } |
32 | | -} |
33 | | - |
34 | | -// Write one float to tensor at flat index. |
35 | | -inline void write_float_at(std::byte *ptr, size_t idx, DataType dtype, float val) { |
36 | | - switch (dtype) { |
37 | | - case DataType::F32: |
38 | | - *reinterpret_cast<float *>(ptr + idx * 4) = val; |
39 | | - break; |
40 | | - case DataType::F16: |
41 | | - *reinterpret_cast<fp16_t *>(ptr + idx * 2) = _f32_to_f16(val); |
42 | | - break; |
43 | | - case DataType::BF16: |
44 | | - *reinterpret_cast<bf16_t *>(ptr + idx * 2) = _f32_to_bf16(val); |
45 | | - break; |
46 | | - default: |
47 | | - throw std::runtime_error("simple_gla_attention: unsupported dtype (need F32, F16, or BF16)"); |
48 | | - } |
49 | | -} |
50 | | - |
51 | | -void simple_gla_attention_cpu_impl(Tensor &out, |
52 | | - const Tensor &q, |
53 | | - const Tensor &k, |
54 | | - const Tensor &v, |
55 | | - const Tensor &g_gamma, |
56 | | - float scale) { |
57 | | - const auto &q_shape = q->shape(); |
58 | | - const size_t B = q_shape[0]; |
59 | | - const size_t T = q_shape[1]; |
60 | | - const size_t H = q_shape[2]; |
61 | | - const size_t D = q_shape[3]; |
62 | | - |
63 | | - INFINICORE_ASSERT(k->shape() == q_shape && v->shape() == q_shape); |
64 | | - INFINICORE_ASSERT(g_gamma->shape().size() == 1 && g_gamma->shape()[0] == H); |
65 | | - |
66 | | - const DataType dtype = q->dtype(); |
67 | | - const std::byte *q_ptr = q->data(); |
68 | | - const std::byte *k_ptr = k->data(); |
69 | | - const std::byte *v_ptr = v->data(); |
70 | | - const std::byte *g_ptr = g_gamma->data(); |
71 | | - std::byte *out_ptr = out->data(); |
72 | | - |
73 | | - // Contiguous layout (B, T, H, D): index (b,t,h,d) = b*T*H*D + t*H*D + h*D + d |
74 | | - const size_t stride_b = T * H * D; |
75 | | - const size_t stride_t = H * D; |
76 | | - const size_t stride_h = D; |
77 | | - |
78 | | - // Gate (H,) in float |
79 | | - std::vector<float> gate(H); |
80 | | - for (size_t h = 0; h < H; ++h) { |
81 | | - gate[h] = std::exp(read_float_at(g_ptr, h, g_gamma->dtype())); |
82 | | - } |
83 | | - |
84 | | - // State S: (B, H, D, D) in float, row-major |
85 | | - std::vector<float> S(B * H * D * D, 0.f); |
86 | | - |
87 | | - for (size_t t = 0; t < T; ++t) { |
88 | | - const size_t t_offset = t * stride_t; |
89 | | - |
90 | | - // 1. S = S * gate + outer(k_t, v_t) |
91 | | - // k_t (b,h,d_k), v_t (b,h,d_v) -> kv(b,h,d_k,d_v) = k_t(b,h,d_k) * v_t(b,h,d_v) |
92 | | - for (size_t b = 0; b < B; ++b) { |
93 | | - const size_t b_offset = b * stride_b + t_offset; |
94 | | - for (size_t h = 0; h < H; ++h) { |
95 | | - const float g = gate[h]; |
96 | | - float *S_bh = S.data() + (b * H + h) * (D * D); |
97 | | - |
98 | | - // Scale S by gate |
99 | | - for (size_t i = 0; i < D * D; ++i) { |
100 | | - S_bh[i] *= g; |
101 | | - } |
102 | | - |
103 | | - // Add outer(k_t, v_t) |
104 | | - for (size_t dk = 0; dk < D; ++dk) { |
105 | | - size_t qk_idx = b_offset + h * stride_h + dk; |
106 | | - float k_val = read_float_at(k_ptr, qk_idx, dtype); |
107 | | - for (size_t dv = 0; dv < D; ++dv) { |
108 | | - size_t qv_idx = b_offset + h * stride_h + dv; |
109 | | - float v_val = read_float_at(v_ptr, qv_idx, dtype); |
110 | | - S_bh[dk * D + dv] += k_val * v_val; |
111 | | - } |
112 | | - } |
113 | | - } |
114 | | - } |
115 | | - |
116 | | - // 2. o_t = (q_t * scale) @ S -> (B, H, D) for each (b,h): o[b,h,:] = scale * (q_t[b,h,:] @ S[b,h,:,:]) |
117 | | - for (size_t b = 0; b < B; ++b) { |
118 | | - const size_t b_offset = b * stride_b + t_offset; |
119 | | - for (size_t h = 0; h < H; ++h) { |
120 | | - const float *S_bh = S.data() + (b * H + h) * (D * D); |
121 | | - for (size_t dv = 0; dv < D; ++dv) { |
122 | | - float acc = 0.f; |
123 | | - for (size_t dk = 0; dk < D; ++dk) { |
124 | | - size_t q_idx = b_offset + h * stride_h + dk; |
125 | | - float q_val = read_float_at(q_ptr, q_idx, dtype) * scale; |
126 | | - acc += q_val * S_bh[dk * D + dv]; |
127 | | - } |
128 | | - size_t out_idx = b_offset + h * stride_h + dv; |
129 | | - write_float_at(out_ptr, out_idx, dtype, acc); |
130 | | - } |
131 | | - } |
132 | | - } |
133 | | - } |
134 | | -} |
135 | | - |
136 | | -void simple_gla_attention_cpu_calculate(Tensor &out, const Tensor &q, const Tensor &k, |
137 | | - const Tensor &v, const Tensor &g_gamma, float scale) { |
138 | | - simple_gla_attention_cpu_impl(out, q, k, v, g_gamma, scale); |
139 | | -} |
140 | | - |
141 | | -static bool register_cpu = []() { |
142 | | - SimpleGlaAttention::dispatcher().registerDevice(Device::Type::CPU, &simple_gla_attention_cpu_calculate, |
143 | | - false); |
144 | | - return true; |
145 | | -}(); |
146 | | - |
147 | | -} // namespace |
148 | | - |
149 | | -common::OpDispatcher<SimpleGlaAttention::schema> &SimpleGlaAttention::dispatcher() { |
150 | | - static common::OpDispatcher<SimpleGlaAttention::schema> dispatcher_; |
151 | | - return dispatcher_; |
152 | | -} |
153 | | - |
154 | | -void SimpleGlaAttention::execute(Tensor &out, const Tensor &q, const Tensor &k, const Tensor &v, |
155 | | - const Tensor &g_gamma, float scale) { |
156 | | - INFINICORE_ASSERT_TENSORS_SAME_DEVICE(q, k, v, g_gamma); |
157 | | - infinicore::context::setDevice(q->device()); |
158 | | - auto device_type = infinicore::context::getDevice().getType(); |
159 | | - auto func = dispatcher().lookup(device_type); |
160 | | - if (func == nullptr) { |
161 | | - throw std::runtime_error("simple_gla_attention: no implementation for device type " + std::to_string(static_cast<int>(device_type))); |
162 | | - } |
163 | | - func(out, q, k, v, g_gamma, scale); |
| 6 | +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SimpleGlaAttention); |
| 7 | + |
| 8 | +SimpleGlaAttention::SimpleGlaAttention(Tensor out, |
| 9 | + const Tensor &q, |
| 10 | + const Tensor &k, |
| 11 | + const Tensor &v, |
| 12 | + const Tensor &g_gamma, |
| 13 | + float scale) { |
| 14 | + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, g_gamma); |
| 15 | + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, q, k, v, g_gamma, scale); |
| 16 | +} |
| 17 | + |
| 18 | +void SimpleGlaAttention::execute(Tensor out, |
| 19 | + const Tensor &q, |
| 20 | + const Tensor &k, |
| 21 | + const Tensor &v, |
| 22 | + const Tensor &g_gamma, |
| 23 | + float scale) { |
| 24 | + INFINICORE_GRAPH_OP_RECORD_OR_RUN(SimpleGlaAttention, out, q, k, v, g_gamma, scale); |
164 | 25 | } |
165 | 26 |
166 | 27 | Tensor simple_gla_attention(const Tensor &q, |
|
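For reference, the CPU path deleted above implemented the plain gated-linear-attention recurrence sketched below. This is a reading of the removed loop, written per batch `b` and head `h`, with `q_t`, `k_t`, `v_t` in R^D taken as column vectors; the symbols are editorial, not names from the codebase:

```latex
\begin{aligned}
\gamma_h &= \exp\!\big(g_\gamma[h]\big), \qquad S_0 = 0 \in \mathbb{R}^{D\times D},\\
S_t &= \gamma_h\, S_{t-1} + k_t\, v_t^{\top},\\
o_t &= \mathrm{scale}\cdot q_t^{\top} S_t .
\end{aligned}
```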
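Below is a minimal float32 sketch of the same recurrence for a single (batch, head) pair, which can serve as a golden reference when checking backends registered through the new graph-op dispatcher. It is an illustration only: `simple_gla_reference`, the flat `std::vector` buffers, and the contiguous `[T, D]` layout are assumptions of the sketch, not part of the infinicore API.

```cpp
// Reference sketch of the gated-linear-attention recurrence above, for one
// (batch, head) pair. Illustration only -- plain std::vector instead of
// infinicore Tensors; all names here are hypothetical, not library API.
#include <cmath>
#include <cstddef>
#include <vector>

// q, k, v: [T, D] row-major; g_gamma: per-head log-gate; returns out: [T, D].
std::vector<float> simple_gla_reference(const std::vector<float> &q,
                                        const std::vector<float> &k,
                                        const std::vector<float> &v,
                                        float g_gamma, float scale,
                                        std::size_t T, std::size_t D) {
    const float gate = std::exp(g_gamma);
    std::vector<float> S(D * D, 0.f); // state, S[dk * D + dv], starts at zero
    std::vector<float> out(T * D, 0.f);
    for (std::size_t t = 0; t < T; ++t) {
        // S = gate * S + outer(k_t, v_t)
        for (std::size_t dk = 0; dk < D; ++dk) {
            for (std::size_t dv = 0; dv < D; ++dv) {
                S[dk * D + dv] = gate * S[dk * D + dv]
                               + k[t * D + dk] * v[t * D + dv];
            }
        }
        // o_t = scale * q_t^T S
        for (std::size_t dv = 0; dv < D; ++dv) {
            float acc = 0.f;
            for (std::size_t dk = 0; dk < D; ++dk) {
                acc += scale * q[t * D + dk] * S[dk * D + dv];
            }
            out[t * D + dv] = acc;
        }
    }
    return out;
}
```

Numerically this mirrors the removed loop: the state is rescaled by the per-head gate before the rank-1 `k v^T` update, and the query scaling is applied inside the dot product, so F32 results should match the previous CPU path up to rounding.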