@@ -5,9 +5,9 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE];
 namespace infini::ops {
 
 template <typename T, typename TW>
-__mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
-                            size_t *shape, ptrdiff_t *output_strides,
-                            ptrdiff_t *input_strides, float epsilon,
+__mlu_global__ void RmsNorm(const T* input, const TW* weight, T* output,
+                            size_t* shape, ptrdiff_t* output_strides,
+                            ptrdiff_t* input_strides, float epsilon,
                             int num_dims, int norm_dim_size) {
     // Calculate problem dimensions.
     int batch_volume = 1;
@@ -40,11 +40,11 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
     constexpr int reduce_buffer_size = 128 / sizeof(float);
 
     // NRAM buffer allocation with dynamic sizing.
-    float *reduction_buffer = (float *)nram_buffer;
-    T *input_cache = (T *)(reduction_buffer + reduce_buffer_size);
-    TW *weight_cache = (TW *)(input_cache + max_batch_size);
-    float *float_buffer = (float *)(weight_cache + max_batch_size);
-    float *weight_float_buffer = (float *)(float_buffer + max_batch_size);
+    float* reduction_buffer = (float*)nram_buffer;
+    T* input_cache = (T*)(reduction_buffer + reduce_buffer_size);
+    TW* weight_cache = (TW*)(input_cache + max_batch_size);
+    float* float_buffer = (float*)(weight_cache + max_batch_size);
+    float* weight_float_buffer = (float*)(float_buffer + max_batch_size);
 
     // Process vectors assigned to current core.
     for (int task_idx = 0; task_idx < actual_tasks; ++task_idx) {
@@ -69,7 +69,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
         __memcpy(input_cache, input + input_offset, vector_size * sizeof(T),
                  GDRAM2NRAM);
         if constexpr (std::is_same<T, __half>::value) {
-            __bang_half2float(float_buffer, reinterpret_cast<half *>(input_cache),
+            __bang_half2float(float_buffer, reinterpret_cast<half*>(input_cache),
                               vector_size);
         } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
             __bang_bfloat162float(float_buffer, input_cache, vector_size);
@@ -99,7 +99,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
                      current_batch * sizeof(T), GDRAM2NRAM);
 
             if constexpr (std::is_same<T, __half>::value) {
-                __bang_half2float(float_buffer, reinterpret_cast<half *>(input_cache),
+                __bang_half2float(float_buffer, reinterpret_cast<half*>(input_cache),
                                   current_batch);
             } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
                 __bang_bfloat162float(float_buffer, input_cache, current_batch);
@@ -137,7 +137,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
         __memcpy(weight_cache, weight, vector_size * sizeof(TW), GDRAM2NRAM);
 
         if constexpr (std::is_same<T, __half>::value) {
-            __bang_half2float(float_buffer, reinterpret_cast<half *>(input_cache),
+            __bang_half2float(float_buffer, reinterpret_cast<half*>(input_cache),
                               vector_size);
         } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
             __bang_bfloat162float(float_buffer, input_cache, vector_size);
@@ -148,7 +148,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
 
         if constexpr (std::is_same<TW, __half>::value) {
             __bang_half2float(weight_float_buffer,
-                              reinterpret_cast<half *>(weight_cache), vector_size);
+                              reinterpret_cast<half*>(weight_cache), vector_size);
         } else if constexpr (std::is_same<TW, __bang_bfloat16>::value) {
             __bang_bfloat162float(weight_float_buffer, weight_cache, vector_size);
         } else {
@@ -161,7 +161,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
         __bang_mul_scalar(float_buffer, float_buffer, inv_rms, vector_size);
 
         if constexpr (std::is_same<T, __half>::value) {
-            __bang_float2half(reinterpret_cast<half *>(input_cache), float_buffer,
+            __bang_float2half(reinterpret_cast<half*>(input_cache), float_buffer,
                               vector_size);
         } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
             __bang_float2bfloat16(input_cache, float_buffer, vector_size);
@@ -188,7 +188,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
                      current_batch * sizeof(TW), GDRAM2NRAM);
 
             if constexpr (std::is_same<T, __half>::value) {
-                __bang_half2float(float_buffer, reinterpret_cast<half *>(input_cache),
+                __bang_half2float(float_buffer, reinterpret_cast<half*>(input_cache),
                                   current_batch);
             } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
                 __bang_bfloat162float(float_buffer, input_cache, current_batch);
@@ -199,7 +199,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
 
             if constexpr (std::is_same<TW, __half>::value) {
                 __bang_half2float(weight_float_buffer,
-                                  reinterpret_cast<half *>(weight_cache),
+                                  reinterpret_cast<half*>(weight_cache),
                                   current_batch);
             } else if constexpr (std::is_same<TW, __bang_bfloat16>::value) {
                 __bang_bfloat162float(weight_float_buffer, weight_cache,
@@ -214,7 +214,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
             __bang_mul_scalar(float_buffer, float_buffer, inv_rms, current_batch);
 
             if constexpr (std::is_same<T, __half>::value) {
-                __bang_float2half(reinterpret_cast<half *>(input_cache), float_buffer,
+                __bang_float2half(reinterpret_cast<half*>(input_cache), float_buffer,
                                   current_batch);
             } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
                 __bang_float2bfloat16(input_cache, float_buffer, current_batch);
@@ -234,10 +234,10 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
 }
 
 template <typename T, typename TW>
-void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count,
-                  cnrtQueue_t queue, void *y, const void *x, const void *w,
-                  const size_t *shape, const ptrdiff_t *y_strides,
-                  const ptrdiff_t *x_strides, float eps, int ndim) {
+void RmsNormUnion(void* workspace, int core_per_cluster, int cluster_count,
+                  cnrtQueue_t queue, void* y, const void* x, const void* w,
+                  const size_t* shape, const ptrdiff_t* y_strides,
+                  const ptrdiff_t* x_strides, float eps, int ndim) {
     cnrtDim3_t kernel_dim;
     cnrtFunctionType_t kernel_type;
 
@@ -263,23 +263,23 @@ void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count,
     }
 
     // Prepare device pointers.
-    auto y_ = reinterpret_cast<T *>(y);
-    auto x_ = reinterpret_cast<const T *>(x);
-    auto w_ = reinterpret_cast<const TW *>(w);
-    char *tmp_device = reinterpret_cast<char *>(workspace);
-    char *tmp_stride = tmp_device + ndim * sizeof(size_t);
-    size_t *mlu_shape = (size_t *)tmp_device;
-    ptrdiff_t *mlu_x_strides = (ptrdiff_t *)tmp_stride;
-    ptrdiff_t *mlu_y_strides = mlu_x_strides + ndim;
+    auto y_ = reinterpret_cast<T*>(y);
+    auto x_ = reinterpret_cast<const T*>(x);
+    auto w_ = reinterpret_cast<const TW*>(w);
+    char* tmp_device = reinterpret_cast<char*>(workspace);
+    char* tmp_stride = tmp_device + ndim * sizeof(size_t);
+    size_t* mlu_shape = (size_t*)tmp_device;
+    ptrdiff_t* mlu_x_strides = (ptrdiff_t*)tmp_stride;
+    ptrdiff_t* mlu_y_strides = mlu_x_strides + ndim;
 
     // Copy shape and stride information to device.
-    CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast<size_t *>(shape),
+    CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast<size_t*>(shape),
                                ndim * sizeof(size_t), queue,
                                cnrtMemcpyHostToDev)); // const not supported
-    CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast<ptrdiff_t *>(x_strides),
+    CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast<ptrdiff_t*>(x_strides),
                                ndim * sizeof(ptrdiff_t), queue,
                                cnrtMemcpyHostToDev));
-    CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast<ptrdiff_t *>(y_strides),
+    CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast<ptrdiff_t*>(y_strides),
                                ndim * sizeof(ptrdiff_t), queue,
                                cnrtMemcpyHostToDev));
 
@@ -289,44 +289,44 @@ void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count,
     cnrtQueueSync(queue);
 }
 
-template void RmsNormUnion<__half, __half>(void *, int, int, cnrtQueue_t,
-                                           void *, const void *, const void *,
-                                           const size_t *, const ptrdiff_t *,
-                                           const ptrdiff_t *, float, int);
+template void RmsNormUnion<__half, __half>(void*, int, int, cnrtQueue_t, void*,
+                                           const void*, const void*,
+                                           const size_t*, const ptrdiff_t*,
+                                           const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<__half, __bang_bfloat16>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
-template void RmsNormUnion<__half, float>(void *, int, int, cnrtQueue_t, void *,
-                                          const void *, const void *,
-                                          const size_t *, const ptrdiff_t *,
-                                          const ptrdiff_t *, float, int);
+template void RmsNormUnion<__half, float>(void*, int, int, cnrtQueue_t, void*,
+                                          const void*, const void*,
+                                          const size_t*, const ptrdiff_t*,
+                                          const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<__bang_bfloat16, __half>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<__bang_bfloat16, __bang_bfloat16>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<__bang_bfloat16, float>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
-template void RmsNormUnion<float, __half>(void *, int, int, cnrtQueue_t, void *,
-                                          const void *, const void *,
-                                          const size_t *, const ptrdiff_t *,
-                                          const ptrdiff_t *, float, int);
+template void RmsNormUnion<float, __half>(void*, int, int, cnrtQueue_t, void*,
+                                          const void*, const void*,
+                                          const size_t*, const ptrdiff_t*,
+                                          const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<float, __bang_bfloat16>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
-template void RmsNormUnion<float, float>(void *, int, int, cnrtQueue_t, void *,
-                                         const void *, const void *,
-                                         const size_t *, const ptrdiff_t *,
-                                         const ptrdiff_t *, float, int);
+template void RmsNormUnion<float, float>(void*, int, int, cnrtQueue_t, void*,
+                                         const void*, const void*,
+                                         const size_t*, const ptrdiff_t*,
+                                         const ptrdiff_t*, float, int);
 
 } // namespace infini::ops
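
For context, below is a minimal host-side sketch of how the RmsNormUnion launcher touched by this patch might be invoked. The function name launch_rms_norm_fp16, the tensor shape, the launch parameters, and the cnrtMalloc-based workspace allocation are illustrative assumptions, not part of this change; what the sketch relies on from the patch itself is the RmsNormUnion signature and the fact that it stages the shape array plus two stride arrays through the caller-provided workspace and syncs the queue before returning.

// Hypothetical usage sketch (not part of this patch): run RmsNorm over a
// contiguous fp16 tensor of shape [batch, hidden], normalizing the last dim.
void launch_rms_norm_fp16(cnrtQueue_t queue, void* y, const void* x,
                          const void* w, size_t batch, size_t hidden) {
    const int ndim = 2;
    size_t shape[2] = {batch, hidden};
    ptrdiff_t strides[2] = {(ptrdiff_t)hidden, 1}; // row-major, contiguous

    // RmsNormUnion copies the shape and both stride arrays into this
    // workspace, so it needs ndim size_t's plus 2 * ndim ptrdiff_t's.
    void* workspace = nullptr;
    CNRT_CHECK(cnrtMalloc(&workspace,
                          ndim * sizeof(size_t) + 2 * ndim * sizeof(ptrdiff_t)));

    // Core/cluster counts are placeholders; real code would query the device.
    infini::ops::RmsNormUnion<__half, __half>(
        workspace, /*core_per_cluster=*/4, /*cluster_count=*/8, queue, y, x, w,
        shape, /*y_strides=*/strides, /*x_strides=*/strides, /*eps=*/1e-5f,
        ndim);

    // The launcher calls cnrtQueueSync before returning, so the workspace
    // can be freed immediately.
    CNRT_CHECK(cnrtFree(workspace));
}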