Skip to content

Commit c25096a

Browse files
committed
feat: add cambricon causal_softmax op
1 parent a334495 commit c25096a

4 files changed

Lines changed: 386 additions & 3 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#ifndef INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H
2+
#define INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H
3+
4+
#include "base/causal_softmax.h"
5+
#include "cambricon/common.h"
6+
#include "cambricon/data_type_.h"
7+
8+
namespace infini::ops {
9+
10+
// TODO: Remove forward declaration.
// Host-side launcher implemented in the .mlu translation unit: runs the
// causal-softmax kernel for T on `queue` over a (batch, seq, total_seq)
// tensor pair addressed via the given byte-element strides.
// `workspace` is currently unused (workspace_size_in_bytes() returns 0).
template <typename T>
void CausalSoftmaxUnion(void *workspace, int core_per_cluster,
                        int cluster_count, cnrtQueue_t queue, void *y,
                        const void *x, size_t batch_size_, size_t seq_len_,
                        size_t total_seq_len_, ptrdiff_t y_stride_b,
                        ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
                        ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
                        ptrdiff_t x_stride_j);
19+
20+
template <>
21+
class Operator<CausalSoftmax, Device::Type::kCambricon> : public CausalSoftmax {
22+
public:
23+
Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {
24+
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
25+
&cluster_count);
26+
}
27+
void operator()(const Tensor input, Tensor out) const override {
28+
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
29+
auto workspace{workspace_ ? workspace_ : default_workspace_};
30+
ptrdiff_t y_stride_b = ndim_ == 3 ? out_strides_[0] : 1;
31+
ptrdiff_t y_stride_i = ndim_ == 3 ? out_strides_[1] : out_strides_[0];
32+
ptrdiff_t y_stride_j = ndim_ == 3 ? out_strides_[2] : out_strides_[1];
33+
ptrdiff_t x_stride_b = ndim_ == 3 ? input_strides_[0] : 1;
34+
ptrdiff_t x_stride_i = ndim_ == 3 ? input_strides_[1] : input_strides_[0];
35+
ptrdiff_t x_stride_j = ndim_ == 3 ? input_strides_[2] : input_strides_[1];
36+
37+
DispatchFunc<
38+
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>(
39+
{static_cast<int64_t>(input.dtype())},
40+
[&](auto input_tag) {
41+
using InputT = infini::ops::TypeMapType<Device::Type::kCambricon, ListGet<0>(input_tag)>;
42+
CausalSoftmaxUnion<InputT>(
43+
workspace, core_per_cluster, cluster_count, queue, out.data(),
44+
input.data(), batch_size_, seq_len_, total_seq_len_, y_stride_b,
45+
y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j);
46+
},
47+
"CambriconCausalSoftmax::operator() - output dispatch");
48+
}
49+
50+
std::size_t workspace_size_in_bytes() const override { return 0; }
51+
52+
~Operator() {}
53+
54+
void *default_workspace_{nullptr};
55+
int core_per_cluster = 0;
56+
int cluster_count = 0;
57+
};
58+
59+
} // namespace infini::ops
60+
61+
#endif
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#include "causal_softmax.h"
2+
3+
// Static on-chip NRAM scratch buffer shared by every kernel in this
// translation unit; each helper carves its own staging regions out of it.
__nram__ char nram_buffer[NRAM_MAX_SIZE];
// Bytes reserved for one staging region. NOTE(review): a quarter of NRAM —
// the helpers below derive all chunk sizes from this; confirm sizing against
// the combined src/dst layout used in the kernel.
const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 4;
5+
6+
namespace infini::ops {
7+
8+
// One elementwise pass of the softmax pipeline over a strided GDRAM row.
//
// is_exp_phase == true : reads `input`, writes exp(input - scalar) to
//                        `output` (scalar is the row maximum).
// is_exp_phase == false: reads `output` in place and scales it by `scalar`
//                        (scalar is 1.0f / row sum); `input` is unused.
//
// All arithmetic is done in float on NRAM. 16-bit types are staged in a
// separate NRAM region and widened/narrowed around the float math.
// NOTE(review): the same `stride` is used for both the strided load and the
// strided store — callers must pass matching input/output inner strides.
template <typename T>
__mlu_func__ void ProcessSoftmaxStep(const T *input, T *output, float scalar,
                                     int num_elements, int stride,
                                     bool is_exp_phase) {
  constexpr bool is_half = std::is_same_v<T, __half>;
  constexpr bool is_bfloat16 = std::is_same_v<T, __bang_bfloat16>;
  constexpr bool is_float = !is_half && !is_bfloat16;

  // 16-bit types need two regions (float scratch + raw staging), so the
  // per-iteration chunk is half as large as for float.
  const int chunk_size =
      SRC_MAX_SIZE /
      ((is_half || is_bfloat16) ? (2 * sizeof(float)) : sizeof(float));
  float *float_buffer = (float *)nram_buffer;
  // Raw 16-bit staging area, placed directly after the float scratch.
  T *temp_buffer =
      is_float ? nullptr : (T *)(nram_buffer + chunk_size * sizeof(float));

  // Common stride configurations (bytes between consecutive row elements).
  const int src_stride = stride * sizeof(T);
  const int dst_stride = stride * sizeof(T);

  int processed = 0;
  while (processed < num_elements) {
    int curr_batch = std::min(chunk_size, num_elements - processed);

    // Gather `curr_batch` strided elements from GDRAM into NRAM.
    if constexpr (is_float) {
      __memcpy(
          float_buffer, (is_exp_phase ? input : output) + processed * stride,
          sizeof(float), GDRAM2NRAM, sizeof(float), src_stride, curr_batch - 1);
    } else {
      __memcpy(temp_buffer,
               (is_exp_phase ? input : output) + processed * stride, sizeof(T),
               GDRAM2NRAM, sizeof(T), src_stride, curr_batch - 1);

      // Widen the staged 16-bit values to float for the vector math.
      if constexpr (is_half) {
        __bang_half2float(float_buffer, reinterpret_cast<half *>(temp_buffer),
                          curr_batch);
      } else if constexpr (is_bfloat16) {
        __bang_bfloat162float(float_buffer, temp_buffer, curr_batch);
      }
    }

    // Common processing for all types.
    if (is_exp_phase) {
      __bang_sub_scalar(float_buffer, float_buffer, scalar,
                        curr_batch); // scalar is max_val
      __bang_active_exphp(float_buffer, float_buffer, curr_batch);
    } else {
      __bang_mul_scalar(float_buffer, float_buffer, scalar,
                        curr_batch); // scalar is 1.0f/sum_val
    }

    // Scatter the chunk back to GDRAM, narrowing 16-bit types first.
    if constexpr (is_float) {
      __memcpy(output + processed * stride, float_buffer, sizeof(float),
               NRAM2GDRAM, dst_stride, sizeof(float), curr_batch - 1);
    } else {
      if constexpr (is_half) {
        __bang_float2half(reinterpret_cast<half *>(temp_buffer), float_buffer,
                          curr_batch);
      } else if constexpr (is_bfloat16) {
        __bang_float2bfloat16(temp_buffer, float_buffer, curr_batch);
      }

      __memcpy(output + processed * stride, temp_buffer, sizeof(T), NRAM2GDRAM,
               dst_stride, sizeof(T), curr_batch - 1);
    }

    processed += curr_batch;
  }
}
76+
77+
// Device kernel: row-wise causal softmax.
// The (batch, row) index space is statically partitioned across tasks; row i
// of a batch has `total_seq_len - seq_len + i + 1` valid entries, and all
// later (future) entries are written as zero.
template <typename T>
__mlu_global__ void CausalSoftmax(T *y, const T *x, size_t batch_size,
                                  size_t seq_len, size_t total_seq_len,
                                  ptrdiff_t y_stride_b, ptrdiff_t y_stride_i,
                                  ptrdiff_t y_stride_j, ptrdiff_t x_stride_b,
                                  ptrdiff_t x_stride_i, ptrdiff_t x_stride_j) {
  size_t task_id = taskId;
  size_t task_num = taskDimX * taskDimY;

  // Contiguous block of rows handled by this task.
  size_t total_tasks = batch_size * seq_len;
  size_t tasks_per_core = (total_tasks + task_num - 1) / task_num;
  size_t start = task_id * tasks_per_core;
  size_t end = std::min(start + tasks_per_core, total_tasks);

  // NRAM layout for the reduce helpers: a SRC_MAX_SIZE-byte staging region
  // (`src`) followed by float scratch (`dst`).
  const int max_batch = SRC_MAX_SIZE / sizeof(T);
  T *src = (T *)nram_buffer;
  float *dst = (float *)(nram_buffer + max_batch * sizeof(T));

  for (size_t index = start; index < end; index++) {
    size_t batch = index / seq_len;
    size_t i = (index % seq_len);
    ptrdiff_t y_offset = batch * y_stride_b + i * y_stride_i;
    ptrdiff_t x_offset = batch * x_stride_b + i * x_stride_i;
    T *y_ = y + y_offset;
    const T *x_ = x + x_offset;

    // Calculate the valid sequence length for this position.
    size_t valid_len = total_seq_len - seq_len + i + 1;

    // Zero out future positions (scalar stores straight to GDRAM).
    for (size_t j = valid_len; j < total_seq_len; j++) {
      y_[j * y_stride_j] = (T)0.0f;
    }

    // Calculate max value using optimized reduction.
    // NOTE(review): MaxBatched/SumBatched copy x_/y_ as contiguous runs, so
    // they implicitly assume x_stride_j == 1 and y_stride_j == 1 — confirm
    // callers never pass a non-contiguous innermost dimension.
    float max_val =
        infini::ops::reduce::MaxBatched(x_, src, dst, valid_len, max_batch);

    // Compute `exp(x - max)`.
    // NOTE(review): this call uses x_stride_j for BOTH the read of x_ and the
    // write of y_ (ProcessSoftmaxStep takes a single stride) — confirm x and
    // y always share the same inner stride.
    ProcessSoftmaxStep(x_, y_, max_val, valid_len, x_stride_j, true);

    // Calculate sum of exponentials.
    float sum_val =
        infini::ops::reduce::SumBatched(y_, src, dst, valid_len, max_batch);

    // Normalize by sum.
    ProcessSoftmaxStep(y_, y_, 1.0f / sum_val, valid_len, y_stride_j, false);
  }
}
126+
127+
template <typename T>
128+
void CausalSoftmaxUnion(void *workspace, int core_per_cluster,
129+
int cluster_count, cnrtQueue_t queue, void *y,
130+
const void *x, size_t batch_size_, size_t seq_len_,
131+
size_t total_seq_len_, ptrdiff_t y_stride_b,
132+
ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
133+
ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
134+
ptrdiff_t x_stride_j) {
135+
cnrtDim3_t kernel_dim;
136+
cnrtFunctionType_t kernel_type;
137+
138+
kernel_dim.x = core_per_cluster;
139+
kernel_dim.y = cluster_count;
140+
kernel_dim.z = 1;
141+
kernel_type = cnrtFuncTypeUnion1;
142+
143+
CausalSoftmax<T><<<kernel_dim, kernel_type, queue>>>(
144+
(T *)y, (const T *)x, batch_size_, seq_len_, total_seq_len_, y_stride_b,
145+
y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j);
146+
147+
cnrtQueueSync(queue);
148+
}
149+
150+
// Explicit instantiations for the three dtypes the host operator dispatches
// (fp16, bf16, fp32) so the definitions are emitted in this TU.
template void CausalSoftmaxUnion<__half>(void *, int, int, cnrtQueue_t, void *,
                                         const void *, size_t, size_t, size_t,
                                         ptrdiff_t, ptrdiff_t, ptrdiff_t,
                                         ptrdiff_t, ptrdiff_t, ptrdiff_t);

template void CausalSoftmaxUnion<__bang_bfloat16>(
    void *, int, int, cnrtQueue_t, void *, const void *, size_t, size_t, size_t,
    ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t);

template void CausalSoftmaxUnion<float>(void *, int, int, cnrtQueue_t, void *,
                                        const void *, size_t, size_t, size_t,
                                        ptrdiff_t, ptrdiff_t, ptrdiff_t,
                                        ptrdiff_t, ptrdiff_t, ptrdiff_t);

} // namespace infini::ops

src/cambricon/common.h

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,164 @@ __mlu_func__ void SumInternal(float* dst, float* src, int max_batch) {
3030
}
3131
}
3232

33+
// Reduces `len` NRAM values to a single float sum in result[0].
// Buffer contract for 16-bit types: the raw half/bfloat16 values are staged
// at element offset `len` (data + len) and widened to float into the front
// of the same buffer before the float reduction. The widening writes 4 bytes
// per element while reading 2, and the write pointer stays behind the read
// pointer for every in-range element — NOTE(review): this assumes the
// conversion intrinsic processes elements in increasing order; confirm.
template <typename T>
__mlu_func__ void SumTyped(float *result, T *data, size_t len) {
  if constexpr (std::is_same_v<T, __half>) {
    __bang_half2float((float *)data, reinterpret_cast<half *>(data) + len, len);
    SumInternal(result, (float *)data, len);
  } else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
    __bang_bfloat162float((float *)data, data + len, len);
    SumInternal(result, (float *)data, len);
  } else {
    SumInternal(result, data, len);
  }
}
45+
46+
// Chunked sum of `num_elements` contiguous GDRAM values.
//  source — GDRAM input.
//  src    — NRAM staging buffer (max_batch elements, doubled for 16-bit T).
//  dst    — NRAM float scratch; dst[0] receives each chunk's partial sum.
// Returns the accumulated float total.
template <typename T>
__mlu_func__ float Sum(const T *source, T *src, float *dst, int num_elements,
                       int max_batch) {
  float res = 0.0f;
  // 16-bit values are staged in the upper half of `src` so SumTyped can
  // widen them to float at the front of the buffer.
  int offset = (sizeof(T) == 2 ? max_batch : 0);

  size_t processed = 0;
  while (processed < num_elements) {
    size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);

    // Zero-pad the tail chunk: SumTyped always reduces max_batch elements,
    // and zero is the identity for a sum.
    if (curr_batch < max_batch) {
      __bang_write_value(src, max_batch + offset, 0);
    }

    __memcpy(src + offset, source + processed, curr_batch * sizeof(T),
             GDRAM2NRAM);
    SumTyped(dst, src, max_batch);
    res += dst[0];
    processed += curr_batch;
  }

  return res;
}
69+
70+
template <typename T>
71+
__mlu_func__ float SumBatched(const T *source, T *src, float *dst,
72+
int num_elements, int max_batch) {
73+
constexpr int min_vector_size = 32;
74+
75+
if (num_elements < min_vector_size) {
76+
return Sum(source, src, dst, num_elements, max_batch);
77+
}
78+
79+
float res = 0.0f;
80+
int offset = (sizeof(T) == 2 ? max_batch : 0);
81+
82+
size_t processed = 0;
83+
while (processed < num_elements) {
84+
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
85+
size_t aligned_batch = (curr_batch / batch_size) * batch_size;
86+
size_t remainder = curr_batch % batch_size;
87+
88+
// Ensure NRAM buffer is zeroed.
89+
__bang_write_value(src, max_batch + offset, 0);
90+
91+
// Copy data to NRAM.
92+
__memcpy(src + offset, source + processed, curr_batch * sizeof(T),
93+
GDRAM2NRAM);
94+
95+
if constexpr (std::is_same_v<T, __half>) {
96+
__bang_half2float((float *)(src + offset),
97+
reinterpret_cast<half *>(src) + offset, curr_batch);
98+
} else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
99+
__bang_bfloat162float((float *)(src + offset), src + offset, curr_batch);
100+
}
101+
102+
if (aligned_batch > 0) {
103+
SumInternal(dst, (float *)(src + offset), aligned_batch);
104+
res += dst[0];
105+
}
106+
if (remainder > 0) {
107+
for (size_t i = aligned_batch; i < curr_batch; ++i) {
108+
res += ((float *)(src + offset))[i];
109+
}
110+
}
111+
112+
processed += curr_batch;
113+
}
114+
115+
return res;
116+
}
117+
118+
// Reduces `max_batch` floats in `src` to a single maximum at dst[0]:
// maxpool first collapses the values into batch_size partial maxima, then
// argmax selects the global maximum (value written to dst[0]).
// NOTE(review): the pooling geometry assumes max_batch is a multiple of the
// file-level `batch_size` constant — confirm for all callers.
__mlu_func__ void MaxInternal(float *dst, float *src, int max_batch) {
  __bang_maxpool(dst, src, batch_size, 1, max_batch / batch_size, 1,
                 max_batch / batch_size, 1, 1);
  __bang_argmax(dst, dst, batch_size);
}
123+
124+
// Reduces `len` NRAM values to a single float maximum in result[0].
// Same buffer contract as SumTyped: 16-bit values are staged at element
// offset `len` (data + len) and widened to float into the front of the
// buffer before the float reduction. NOTE(review): the forward in-place
// widening assumes the conversion intrinsic processes elements in
// increasing order; confirm.
template <typename T>
__mlu_func__ void MaxTyped(float *result, T *data, size_t len) {
  if constexpr (std::is_same_v<T, __half>) {
    __bang_half2float((float *)data, reinterpret_cast<half *>(data) + len, len);
    MaxInternal(result, (float *)data, len);
  } else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
    __bang_bfloat162float((float *)data, data + len, len);
    MaxInternal(result, (float *)data, len);
  } else {
    MaxInternal(result, data, len);
  }
}
136+
137+
template <typename T>
138+
__mlu_func__ float Max(const T *source, T *src, float *dst, int num_elements,
139+
int max_batch) {
140+
float max_val = -INFINITY;
141+
int offset = (sizeof(T) == 2 ? max_batch : 0);
142+
143+
size_t processed = 0;
144+
while (processed < num_elements) {
145+
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
146+
147+
if (curr_batch < max_batch) {
148+
__bang_write_value(src, max_batch + offset, 0);
149+
}
150+
151+
__memcpy(src + offset, source + processed, curr_batch * sizeof(T),
152+
GDRAM2NRAM);
153+
MaxTyped(dst, src, max_batch);
154+
max_val = std::max(max_val, dst[0]);
155+
processed += curr_batch;
156+
}
157+
158+
return max_val;
159+
}
160+
161+
template <typename T>
162+
__mlu_func__ float MaxBatched(const T *source, T *src, float *dst,
163+
int num_elements, int max_batch) {
164+
constexpr int min_vector_size = 32;
165+
166+
if (num_elements < min_vector_size) {
167+
return Max(source, src, dst, num_elements, max_batch);
168+
}
169+
170+
float max_val = -INFINITY;
171+
int offset = (sizeof(T) == 2 ? max_batch : 0);
172+
173+
size_t processed = 0;
174+
while (processed < num_elements) {
175+
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
176+
177+
if (curr_batch < max_batch) {
178+
__bang_write_value(src, max_batch + offset, 0);
179+
}
180+
181+
__memcpy(src + offset, source + processed, curr_batch * sizeof(T),
182+
GDRAM2NRAM);
183+
MaxTyped(dst, src, max_batch);
184+
max_val = std::max(max_val, dst[0]);
185+
processed += curr_batch;
186+
}
187+
188+
return max_val;
189+
}
190+
33191
} // namespace infini::ops::reduce
34192

35193
#endif // __BANG__

0 commit comments

Comments
 (0)