Skip to content

Commit a52e426

Browse files
committed
fix ruff
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 2c4bdf4 commit a52e426

9 files changed

Lines changed: 542 additions & 199 deletions

File tree

include/infinicore/ops/simple_gla_attention.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,23 @@
22

33
#include "../device.hpp"
44
#include "../graph/graph.hpp"
5+
#include "../tensor.hpp"
56
#include "common/op.hpp"
67

78
namespace infinicore::op {
89

10+
INFINICORE_GRAPH_OP_CLASS(SimpleGlaAttention,
11+
Tensor,
12+
const Tensor &,
13+
const Tensor &,
14+
const Tensor &,
15+
const Tensor &,
16+
float);
17+
918
// Simple GLA (recurrent linear) attention with per-head decay.
1019
// Shapes: q, k, v [B, T, H, D], g_gamma [H] (log-decay per head).
1120
// Recurrence: gate = exp(g_gamma); S = S * gate + outer(k_t, v_t); o_t = (q_t * scale) @ S.
12-
// Returns [B, T, H, D].
13-
class SimpleGlaAttention {
14-
public:
15-
using schema = void (*)(Tensor & out, const Tensor &q, const Tensor &k, const Tensor &v,
16-
const Tensor &g_gamma, float scale);
17-
static void execute(Tensor & out, const Tensor &q, const Tensor &k, const Tensor &v,
18-
const Tensor &g_gamma, float scale);
19-
static common::OpDispatcher<schema> &dispatcher();
20-
};
21-
21+
// Returns [B, T, H, D]. Implementation lives in InfiniOP (CPU reference; NVIDIA uses simple_gla_prefill kernels).
2222
Tensor simple_gla_attention(const Tensor &q,
2323
const Tensor &k,
2424
const Tensor &v,
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#ifndef __INFINIOP_SIMPLE_GLA_ATTENTION_API_H__
2+
#define __INFINIOP_SIMPLE_GLA_ATTENTION_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
// Full-sequence Simple GLA attention forward (reference CPU + NVIDIA via prefill kernels).
7+
// q, k, v: [B, T, H, D] (F32/F16/BF16), g_gamma: [H] (F32), out: [B, T, H, D] (same dtype as q)
8+
typedef struct InfiniopDescriptor *infiniopSimpleGLAAttentionDescriptor_t;
9+
10+
__INFINI_C __export infiniStatus_t infiniopCreateSimpleGLAAttentionDescriptor(
11+
infiniopHandle_t handle,
12+
infiniopSimpleGLAAttentionDescriptor_t *desc_ptr,
13+
infiniopTensorDescriptor_t out_desc,
14+
infiniopTensorDescriptor_t q_desc,
15+
infiniopTensorDescriptor_t k_desc,
16+
infiniopTensorDescriptor_t v_desc,
17+
infiniopTensorDescriptor_t g_gamma_desc);
18+
19+
__INFINI_C __export infiniStatus_t infiniopGetSimpleGLAAttentionWorkspaceSize(
20+
infiniopSimpleGLAAttentionDescriptor_t desc,
21+
size_t *size);
22+
23+
__INFINI_C __export infiniStatus_t infiniopSimpleGLAAttention(
24+
infiniopSimpleGLAAttentionDescriptor_t desc,
25+
void *workspace,
26+
size_t workspace_size,
27+
void *out,
28+
void const *q,
29+
void const *k,
30+
void const *v,
31+
void const *g_gamma,
32+
float scale,
33+
void *stream);
34+
35+
__INFINI_C __export infiniStatus_t infiniopDestroySimpleGLAAttentionDescriptor(
36+
infiniopSimpleGLAAttentionDescriptor_t desc);
37+
38+
#endif

python/infinicore/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@
108108
from infinicore.ops.rearrange import rearrange
109109
from infinicore.ops.reciprocal import reciprocal
110110
from infinicore.ops.scatter import scatter
111-
from infinicore.ops.sinh import sinh
112111
from infinicore.ops.simple_gla_attention import simple_gla_attention
113112
from infinicore.ops.simple_gla_decode_step import simple_gla_decode_step
114113
from infinicore.ops.simple_gla_prefill import simple_gla_prefill
114+
from infinicore.ops.sinh import sinh
115115
from infinicore.ops.squeeze import squeeze
116116
from infinicore.ops.sum import sum
117117
from infinicore.ops.take import take

src/infinicore/ops/simple_gla_attention/simple_gla_attention.cc

Lines changed: 19 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -1,166 +1,27 @@
11
#include "infinicore/ops/simple_gla_attention.hpp"
2-
3-
#include "../../../utils.h"
42
#include "../../utils.hpp"
5-
#include "infinicore/context/context.hpp"
6-
#include <cmath>
7-
#include <cstring>
8-
#include <stdexcept>
9-
#include <vector>
103

114
namespace infinicore::op {
125

13-
namespace {
14-
15-
// Read one element from tensor at flat index, convert to float.
16-
template <typename T>
17-
inline float read_float(const std::byte *ptr, size_t idx) {
18-
return static_cast<float>(*reinterpret_cast<const T *>(ptr + idx * sizeof(T)));
19-
}
20-
21-
inline float read_float_at(const std::byte *ptr, size_t idx, DataType dtype) {
22-
switch (dtype) {
23-
case DataType::F32:
24-
return read_float<float>(ptr, idx);
25-
case DataType::F16:
26-
return _f16_to_f32(*reinterpret_cast<const fp16_t *>(ptr + idx * 2));
27-
case DataType::BF16:
28-
return _bf16_to_f32(*reinterpret_cast<const bf16_t *>(ptr + idx * 2));
29-
default:
30-
throw std::runtime_error("simple_gla_attention: unsupported dtype (need F32, F16, or BF16)");
31-
}
32-
}
33-
34-
// Write one float to tensor at flat index.
35-
inline void write_float_at(std::byte *ptr, size_t idx, DataType dtype, float val) {
36-
switch (dtype) {
37-
case DataType::F32:
38-
*reinterpret_cast<float *>(ptr + idx * 4) = val;
39-
break;
40-
case DataType::F16:
41-
*reinterpret_cast<fp16_t *>(ptr + idx * 2) = _f32_to_f16(val);
42-
break;
43-
case DataType::BF16:
44-
*reinterpret_cast<bf16_t *>(ptr + idx * 2) = _f32_to_bf16(val);
45-
break;
46-
default:
47-
throw std::runtime_error("simple_gla_attention: unsupported dtype (need F32, F16, or BF16)");
48-
}
49-
}
50-
51-
void simple_gla_attention_cpu_impl(Tensor &out,
52-
const Tensor &q,
53-
const Tensor &k,
54-
const Tensor &v,
55-
const Tensor &g_gamma,
56-
float scale) {
57-
const auto &q_shape = q->shape();
58-
const size_t B = q_shape[0];
59-
const size_t T = q_shape[1];
60-
const size_t H = q_shape[2];
61-
const size_t D = q_shape[3];
62-
63-
INFINICORE_ASSERT(k->shape() == q_shape && v->shape() == q_shape);
64-
INFINICORE_ASSERT(g_gamma->shape().size() == 1 && g_gamma->shape()[0] == H);
65-
66-
const DataType dtype = q->dtype();
67-
const std::byte *q_ptr = q->data();
68-
const std::byte *k_ptr = k->data();
69-
const std::byte *v_ptr = v->data();
70-
const std::byte *g_ptr = g_gamma->data();
71-
std::byte *out_ptr = out->data();
72-
73-
// Contiguous layout (B, T, H, D): index (b,t,h,d) = b*T*H*D + t*H*D + h*D + d
74-
const size_t stride_b = T * H * D;
75-
const size_t stride_t = H * D;
76-
const size_t stride_h = D;
77-
78-
// Gate (H,) in float
79-
std::vector<float> gate(H);
80-
for (size_t h = 0; h < H; ++h) {
81-
gate[h] = std::exp(read_float_at(g_ptr, h, g_gamma->dtype()));
82-
}
83-
84-
// State S: (B, H, D, D) in float, row-major
85-
std::vector<float> S(B * H * D * D, 0.f);
86-
87-
for (size_t t = 0; t < T; ++t) {
88-
const size_t t_offset = t * stride_t;
89-
90-
// 1. S = S * gate + outer(k_t, v_t)
91-
// k_t (b,h,d_k), v_t (b,h,d_v) -> kv(b,h,d_k,d_v) = k_t(b,h,d_k) * v_t(b,h,d_v)
92-
for (size_t b = 0; b < B; ++b) {
93-
const size_t b_offset = b * stride_b + t_offset;
94-
for (size_t h = 0; h < H; ++h) {
95-
const float g = gate[h];
96-
float *S_bh = S.data() + (b * H + h) * (D * D);
97-
98-
// Scale S by gate
99-
for (size_t i = 0; i < D * D; ++i) {
100-
S_bh[i] *= g;
101-
}
102-
103-
// Add outer(k_t, v_t)
104-
for (size_t dk = 0; dk < D; ++dk) {
105-
size_t qk_idx = b_offset + h * stride_h + dk;
106-
float k_val = read_float_at(k_ptr, qk_idx, dtype);
107-
for (size_t dv = 0; dv < D; ++dv) {
108-
size_t qv_idx = b_offset + h * stride_h + dv;
109-
float v_val = read_float_at(v_ptr, qv_idx, dtype);
110-
S_bh[dk * D + dv] += k_val * v_val;
111-
}
112-
}
113-
}
114-
}
115-
116-
// 2. o_t = (q_t * scale) @ S -> (B, H, D) for each (b,h): o[b,h,:] = scale * (q_t[b,h,:] @ S[b,h,:,:])
117-
for (size_t b = 0; b < B; ++b) {
118-
const size_t b_offset = b * stride_b + t_offset;
119-
for (size_t h = 0; h < H; ++h) {
120-
const float *S_bh = S.data() + (b * H + h) * (D * D);
121-
for (size_t dv = 0; dv < D; ++dv) {
122-
float acc = 0.f;
123-
for (size_t dk = 0; dk < D; ++dk) {
124-
size_t q_idx = b_offset + h * stride_h + dk;
125-
float q_val = read_float_at(q_ptr, q_idx, dtype) * scale;
126-
acc += q_val * S_bh[dk * D + dv];
127-
}
128-
size_t out_idx = b_offset + h * stride_h + dv;
129-
write_float_at(out_ptr, out_idx, dtype, acc);
130-
}
131-
}
132-
}
133-
}
134-
}
135-
136-
void simple_gla_attention_cpu_calculate(Tensor &out, const Tensor &q, const Tensor &k,
137-
const Tensor &v, const Tensor &g_gamma, float scale) {
138-
simple_gla_attention_cpu_impl(out, q, k, v, g_gamma, scale);
139-
}
140-
141-
static bool register_cpu = []() {
142-
SimpleGlaAttention::dispatcher().registerDevice(Device::Type::CPU, &simple_gla_attention_cpu_calculate,
143-
false);
144-
return true;
145-
}();
146-
147-
} // namespace
148-
149-
common::OpDispatcher<SimpleGlaAttention::schema> &SimpleGlaAttention::dispatcher() {
150-
static common::OpDispatcher<SimpleGlaAttention::schema> dispatcher_;
151-
return dispatcher_;
152-
}
153-
154-
void SimpleGlaAttention::execute(Tensor &out, const Tensor &q, const Tensor &k, const Tensor &v,
155-
const Tensor &g_gamma, float scale) {
156-
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(q, k, v, g_gamma);
157-
infinicore::context::setDevice(q->device());
158-
auto device_type = infinicore::context::getDevice().getType();
159-
auto func = dispatcher().lookup(device_type);
160-
if (func == nullptr) {
161-
throw std::runtime_error("simple_gla_attention: no implementation for device type " + std::to_string(static_cast<int>(device_type)));
162-
}
163-
func(out, q, k, v, g_gamma, scale);
6+
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SimpleGlaAttention);
7+
8+
SimpleGlaAttention::SimpleGlaAttention(Tensor out,
9+
const Tensor &q,
10+
const Tensor &k,
11+
const Tensor &v,
12+
const Tensor &g_gamma,
13+
float scale) {
14+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, g_gamma);
15+
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, q, k, v, g_gamma, scale);
16+
}
17+
18+
void SimpleGlaAttention::execute(Tensor out,
19+
const Tensor &q,
20+
const Tensor &k,
21+
const Tensor &v,
22+
const Tensor &g_gamma,
23+
float scale) {
24+
INFINICORE_GRAPH_OP_RECORD_OR_RUN(SimpleGlaAttention, out, q, k, v, g_gamma, scale);
16425
}
16526

16627
Tensor simple_gla_attention(const Tensor &q,
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#include "infinicore/ops/simple_gla_attention.hpp"
2+
3+
#include "../infiniop_impl.hpp"
4+
#include "infinicore/context/context.hpp"
5+
6+
#include <infiniop/ops/simple_gla_attention.h>
7+
8+
namespace infinicore::op::simple_gla_attention_impl::infiniop {
9+
10+
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SimpleGLAAttention, 64);
11+
12+
struct PlannedMeta {
13+
std::shared_ptr<Descriptor> descriptor;
14+
graph::GraphTensor workspace;
15+
graph::GraphTensor out;
16+
graph::GraphTensor q;
17+
graph::GraphTensor k;
18+
graph::GraphTensor v;
19+
graph::GraphTensor g;
20+
float scale;
21+
};
22+
23+
static void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &g, float scale) {
24+
size_t key = hash_combine(out, q, k, v, g, static_cast<size_t>(scale * 1000000.0f));
25+
26+
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
27+
Descriptor, descriptor, SimpleGLAAttention,
28+
key, out->desc(), q->desc(), k->desc(), v->desc(), g->desc());
29+
30+
size_t workspace_size = 0;
31+
INFINICORE_CHECK_ERROR(
32+
infiniopGetSimpleGLAAttentionWorkspaceSize(descriptor->desc, &workspace_size));
33+
34+
thread_local common::OpCache<size_t, Tensor> workspace_caches(8 /*capacity*/);
35+
auto device__ = context::getDevice();
36+
auto &cache__ = workspace_caches.getCache(device__);
37+
38+
Tensor workspace;
39+
if (auto cached = cache__.get(workspace_size); cached.has_value()) {
40+
workspace = *cached;
41+
} else {
42+
workspace = Tensor::empty({workspace_size}, DataType::U8, device__);
43+
cache__.put(workspace_size, workspace);
44+
}
45+
46+
return new PlannedMeta{
47+
descriptor,
48+
graph::GraphTensor(workspace),
49+
graph::GraphTensor(out),
50+
graph::GraphTensor(q),
51+
graph::GraphTensor(k),
52+
graph::GraphTensor(v),
53+
graph::GraphTensor(g),
54+
scale,
55+
};
56+
}
57+
58+
static void run(void *planned_meta) {
59+
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
60+
INFINICORE_CHECK_ERROR(
61+
infiniopSimpleGLAAttention(
62+
p->descriptor->desc,
63+
p->workspace->data(),
64+
p->workspace->numel(),
65+
p->out->data(),
66+
p->q->data(),
67+
p->k->data(),
68+
p->v->data(),
69+
p->g->data(),
70+
p->scale,
71+
context::getStream()));
72+
}
73+
74+
static void cleanup(void **planned_meta_ptr) {
75+
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
76+
*planned_meta_ptr = nullptr;
77+
}
78+
79+
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SimpleGlaAttention, &plan, &run, &cleanup);
80+
81+
} // namespace infinicore::op::simple_gla_attention_impl::infiniop

src/infinicore/ops/simple_gla_attention/simple_gla_attention_nvidia.cc

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)