
Commit 2816b58

refactor: make data type mappings and shared CUDA headers device-aware (#38)
* refactor: make `TypeMap`, `IsFP16`, `IsBFloat16`, and `DispatchFunc` device-aware
* refactor: make `cuda/` shared headers self-contained and include-order-independent
* fix: update call sites to device-aware `TypeMap`, `IsFP16`/`IsBFloat16`, and `DispatchFunc`
* chore: format files with `clang-format`
* fix: update `cuda/swiglu` kernels to use device-aware type predicates
* fix: replace per-instance `blasHandle_t` with a static singleton in `Blas`
* fix: restore kernel headers for `moore/add` and `moore/swiglu` to use `clang-format off` and `clang-format on`
* fix: use absolute includes, consistent include guards, and formatted comments
* refactor: extract `GetOptimalBlockSize` logic into shared `ComputeOptimalBlockSize`
* fix: include `<musa_fp16.h>` in `polyfills.cuh` before `hrcp` macro to prevent collision
* chore: add blank lines between `using` type alias declarations in `device_.h`
* chore: add TODO comments for potential performance and concurrency issues
* fix: move `clang-format` guards to wrap only CUDA headers in `iluvatar/device_.h`
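The central change, visible in the new `device_.h` headers and the updated call sites below, is that `TypeMap` is now keyed on a `(Device::Type, DataType)` pair, while `IsFP16`/`IsBFloat16` and `DispatchFunc` take the device as their first template argument. A minimal, self-contained sketch of the device-keyed mapping pattern follows; the enum values, the `std::uint16_t` placeholder, and this `IsFP16` definition are illustrative stand-ins, not the repository's actual code.

#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for the repository's Device::Type and DataType enums.
enum class DeviceType { kCpu, kCambricon };
enum class DataType { kFloat16, kFloat32 };

// Primary template: maps a (device, dtype) pair to a concrete element type.
template <DeviceType D, DataType DT>
struct TypeMap;

// Each backend header contributes specializations for the types it supports.
template <>
struct TypeMap<DeviceType::kCpu, DataType::kFloat32> {
  using type = float;
};
template <>
struct TypeMap<DeviceType::kCpu, DataType::kFloat16> {
  using type = std::uint16_t;  // placeholder for a software Float16 wrapper
};

// A device-aware predicate can then be built on top of the map.
template <DeviceType D, typename T>
constexpr bool IsFP16 =
    std::is_same_v<T, typename TypeMap<D, DataType::kFloat16>::type>;

static_assert(IsFP16<DeviceType::kCpu, std::uint16_t>);
static_assert(!IsFP16<DeviceType::kCpu, float>);

Each backend registers its own specializations in its `device_.h`, so shared code can resolve element types and predicates without hard-coding a device.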
1 parent: f17e37c

45 files changed

Lines changed: 683 additions & 524 deletions

Note: large commits have some content hidden by default, so only a subset of the 45 changed files is shown below.

src/cambricon/device_.h

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#ifndef INFINI_OPS_CAMBRICON_DEVICE__H_
+#define INFINI_OPS_CAMBRICON_DEVICE__H_
+
+#include "bang_bf16.h"
+#include "bang_fp16.h"
+#include "data_type.h"
+#include "device.h"
+
+namespace infini::ops {
+
+template <>
+struct TypeMap<Device::Type::kCambricon, DataType::kFloat16> {
+  using type = __half;
+};
+
+template <>
+struct TypeMap<Device::Type::kCambricon, DataType::kBFloat16> {
+  using type = __bang_bfloat16;
+};
+
+} // namespace infini::ops
+
+#endif
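With these specializations in place, Cambricon code can recover the concrete element type for a `DataType` tag at compile time. A brief usage sketch; the alias name is illustrative and not part of this commit:

#include <type_traits>

#include "cambricon/device_.h"

namespace infini::ops {

// Illustrative only: resolve the fp16 element type on the Cambricon backend.
using CambriconHalf = TypeMap<Device::Type::kCambricon, DataType::kFloat16>::type;
static_assert(std::is_same_v<CambriconHalf, __half>);

} // namespace infini::ops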

src/cambricon/rms_norm/kernel.mlu

Lines changed: 57 additions & 57 deletions
@@ -5,9 +5,9 @@ __nram__ char nram_buffer[NRAM_MAX_SIZE];
 namespace infini::ops {
 
 template <typename T, typename TW>
-__mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
-                            size_t *shape, ptrdiff_t *output_strides,
-                            ptrdiff_t *input_strides, float epsilon,
+__mlu_global__ void RmsNorm(const T* input, const TW* weight, T* output,
+                            size_t* shape, ptrdiff_t* output_strides,
+                            ptrdiff_t* input_strides, float epsilon,
                             int num_dims, int norm_dim_size) {
   // Calculate problem dimensions.
   int batch_volume = 1;
@@ -40,11 +40,11 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
   constexpr int reduce_buffer_size = 128 / sizeof(float);
 
   // NRAM buffer allocation with dynamic sizing.
-  float *reduction_buffer = (float *)nram_buffer;
-  T *input_cache = (T *)(reduction_buffer + reduce_buffer_size);
-  TW *weight_cache = (TW *)(input_cache + max_batch_size);
-  float *float_buffer = (float *)(weight_cache + max_batch_size);
-  float *weight_float_buffer = (float *)(float_buffer + max_batch_size);
+  float* reduction_buffer = (float*)nram_buffer;
+  T* input_cache = (T*)(reduction_buffer + reduce_buffer_size);
+  TW* weight_cache = (TW*)(input_cache + max_batch_size);
+  float* float_buffer = (float*)(weight_cache + max_batch_size);
+  float* weight_float_buffer = (float*)(float_buffer + max_batch_size);
 
   // Process vectors assigned to current core.
   for (int task_idx = 0; task_idx < actual_tasks; ++task_idx) {
@@ -69,7 +69,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
     __memcpy(input_cache, input + input_offset, vector_size * sizeof(T),
              GDRAM2NRAM);
     if constexpr (std::is_same<T, __half>::value) {
-      __bang_half2float(float_buffer, reinterpret_cast<half *>(input_cache),
+      __bang_half2float(float_buffer, reinterpret_cast<half*>(input_cache),
                         vector_size);
     } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
       __bang_bfloat162float(float_buffer, input_cache, vector_size);
@@ -99,7 +99,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
                current_batch * sizeof(T), GDRAM2NRAM);
 
       if constexpr (std::is_same<T, __half>::value) {
-        __bang_half2float(float_buffer, reinterpret_cast<half *>(input_cache),
+        __bang_half2float(float_buffer, reinterpret_cast<half*>(input_cache),
                           current_batch);
       } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
         __bang_bfloat162float(float_buffer, input_cache, current_batch);
@@ -137,7 +137,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
     __memcpy(weight_cache, weight, vector_size * sizeof(TW), GDRAM2NRAM);
 
     if constexpr (std::is_same<T, __half>::value) {
-      __bang_half2float(float_buffer, reinterpret_cast<half *>(input_cache),
+      __bang_half2float(float_buffer, reinterpret_cast<half*>(input_cache),
                         vector_size);
     } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
       __bang_bfloat162float(float_buffer, input_cache, vector_size);
@@ -148,7 +148,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
 
     if constexpr (std::is_same<TW, __half>::value) {
       __bang_half2float(weight_float_buffer,
-                        reinterpret_cast<half *>(weight_cache), vector_size);
+                        reinterpret_cast<half*>(weight_cache), vector_size);
     } else if constexpr (std::is_same<TW, __bang_bfloat16>::value) {
       __bang_bfloat162float(weight_float_buffer, weight_cache, vector_size);
     } else {
@@ -161,7 +161,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
     __bang_mul_scalar(float_buffer, float_buffer, inv_rms, vector_size);
 
     if constexpr (std::is_same<T, __half>::value) {
-      __bang_float2half(reinterpret_cast<half *>(input_cache), float_buffer,
+      __bang_float2half(reinterpret_cast<half*>(input_cache), float_buffer,
                         vector_size);
     } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
       __bang_float2bfloat16(input_cache, float_buffer, vector_size);
@@ -188,7 +188,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
                current_batch * sizeof(TW), GDRAM2NRAM);
 
       if constexpr (std::is_same<T, __half>::value) {
-        __bang_half2float(float_buffer, reinterpret_cast<half *>(input_cache),
+        __bang_half2float(float_buffer, reinterpret_cast<half*>(input_cache),
                           current_batch);
       } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
         __bang_bfloat162float(float_buffer, input_cache, current_batch);
@@ -199,7 +199,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
 
      if constexpr (std::is_same<TW, __half>::value) {
        __bang_half2float(weight_float_buffer,
-                         reinterpret_cast<half *>(weight_cache),
+                         reinterpret_cast<half*>(weight_cache),
                          current_batch);
      } else if constexpr (std::is_same<TW, __bang_bfloat16>::value) {
        __bang_bfloat162float(weight_float_buffer, weight_cache,
@@ -214,7 +214,7 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
      __bang_mul_scalar(float_buffer, float_buffer, inv_rms, current_batch);
 
      if constexpr (std::is_same<T, __half>::value) {
-       __bang_float2half(reinterpret_cast<half *>(input_cache), float_buffer,
+       __bang_float2half(reinterpret_cast<half*>(input_cache), float_buffer,
                          current_batch);
      } else if constexpr (std::is_same<T, __bang_bfloat16>::value) {
        __bang_float2bfloat16(input_cache, float_buffer, current_batch);
@@ -234,10 +234,10 @@ __mlu_global__ void RmsNorm(const T *input, const TW *weight, T *output,
 }
 
 template <typename T, typename TW>
-void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count,
-                  cnrtQueue_t queue, void *y, const void *x, const void *w,
-                  const size_t *shape, const ptrdiff_t *y_strides,
-                  const ptrdiff_t *x_strides, float eps, int ndim) {
+void RmsNormUnion(void* workspace, int core_per_cluster, int cluster_count,
+                  cnrtQueue_t queue, void* y, const void* x, const void* w,
+                  const size_t* shape, const ptrdiff_t* y_strides,
+                  const ptrdiff_t* x_strides, float eps, int ndim) {
   cnrtDim3_t kernel_dim;
   cnrtFunctionType_t kernel_type;
 
@@ -263,23 +263,23 @@ void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count,
   }
 
   // Prepare device pointers.
-  auto y_ = reinterpret_cast<T *>(y);
-  auto x_ = reinterpret_cast<const T *>(x);
-  auto w_ = reinterpret_cast<const TW *>(w);
-  char *tmp_device = reinterpret_cast<char *>(workspace);
-  char *tmp_stride = tmp_device + ndim * sizeof(size_t);
-  size_t *mlu_shape = (size_t *)tmp_device;
-  ptrdiff_t *mlu_x_strides = (ptrdiff_t *)tmp_stride;
-  ptrdiff_t *mlu_y_strides = mlu_x_strides + ndim;
+  auto y_ = reinterpret_cast<T*>(y);
+  auto x_ = reinterpret_cast<const T*>(x);
+  auto w_ = reinterpret_cast<const TW*>(w);
+  char* tmp_device = reinterpret_cast<char*>(workspace);
+  char* tmp_stride = tmp_device + ndim * sizeof(size_t);
+  size_t* mlu_shape = (size_t*)tmp_device;
+  ptrdiff_t* mlu_x_strides = (ptrdiff_t*)tmp_stride;
+  ptrdiff_t* mlu_y_strides = mlu_x_strides + ndim;
 
   // Copy shape and stride information to device.
-  CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast<size_t *>(shape),
+  CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast<size_t*>(shape),
                              ndim * sizeof(size_t), queue,
                              cnrtMemcpyHostToDev)); // const not supported
-  CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast<ptrdiff_t *>(x_strides),
+  CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast<ptrdiff_t*>(x_strides),
                              ndim * sizeof(ptrdiff_t), queue,
                              cnrtMemcpyHostToDev));
-  CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast<ptrdiff_t *>(y_strides),
+  CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast<ptrdiff_t*>(y_strides),
                              ndim * sizeof(ptrdiff_t), queue,
                              cnrtMemcpyHostToDev));
 
@@ -289,44 +289,44 @@ void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count,
   cnrtQueueSync(queue);
 }
 
-template void RmsNormUnion<__half, __half>(void *, int, int, cnrtQueue_t,
-                                           void *, const void *, const void *,
-                                           const size_t *, const ptrdiff_t *,
-                                           const ptrdiff_t *, float, int);
+template void RmsNormUnion<__half, __half>(void*, int, int, cnrtQueue_t, void*,
+                                           const void*, const void*,
+                                           const size_t*, const ptrdiff_t*,
+                                           const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<__half, __bang_bfloat16>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
-template void RmsNormUnion<__half, float>(void *, int, int, cnrtQueue_t, void *,
-                                          const void *, const void *,
-                                          const size_t *, const ptrdiff_t *,
-                                          const ptrdiff_t *, float, int);
+template void RmsNormUnion<__half, float>(void*, int, int, cnrtQueue_t, void*,
+                                          const void*, const void*,
+                                          const size_t*, const ptrdiff_t*,
+                                          const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<__bang_bfloat16, __half>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<__bang_bfloat16, __bang_bfloat16>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<__bang_bfloat16, float>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
-template void RmsNormUnion<float, __half>(void *, int, int, cnrtQueue_t, void *,
-                                          const void *, const void *,
-                                          const size_t *, const ptrdiff_t *,
-                                          const ptrdiff_t *, float, int);
+template void RmsNormUnion<float, __half>(void*, int, int, cnrtQueue_t, void*,
+                                          const void*, const void*,
+                                          const size_t*, const ptrdiff_t*,
+                                          const ptrdiff_t*, float, int);
 
 template void RmsNormUnion<float, __bang_bfloat16>(
-    void *, int, int, cnrtQueue_t, void *, const void *, const void *,
-    const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int);
+    void*, int, int, cnrtQueue_t, void*, const void*, const void*,
+    const size_t*, const ptrdiff_t*, const ptrdiff_t*, float, int);
 
-template void RmsNormUnion<float, float>(void *, int, int, cnrtQueue_t, void *,
-                                         const void *, const void *,
-                                         const size_t *, const ptrdiff_t *,
-                                         const ptrdiff_t *, float, int);
+template void RmsNormUnion<float, float>(void*, int, int, cnrtQueue_t, void*,
+                                         const void*, const void*,
+                                         const size_t*, const ptrdiff_t*,
+                                         const ptrdiff_t*, float, int);
 
 } // namespace infini::ops

src/cambricon/rms_norm/rms_norm.h

Lines changed: 10 additions & 8 deletions
@@ -5,17 +5,18 @@
 #include <cstdint>
 #include <vector>
 
-#include "../common.h"
+#include "cambricon/common.h"
+#include "cambricon/device_.h"
 #include "base/rms_norm.h"
 
 namespace infini::ops {
 
 // TODO: Remove forward declaration.
 template <typename T, typename Tw>
-void RmsNormUnion(void *workspace, int core_per_cluster, int cluster_count,
-                  cnrtQueue_t queue, void *y, const void *x, const void *w,
-                  const size_t *shape, const ptrdiff_t *y_strides,
-                  const ptrdiff_t *x_strides, float eps, int ndim);
+void RmsNormUnion(void* workspace, int core_per_cluster, int cluster_count,
+                  cnrtQueue_t queue, void* y, const void* x, const void* w,
+                  const size_t* shape, const ptrdiff_t* y_strides,
+                  const ptrdiff_t* x_strides, float eps, int ndim);
 
 template <>
 class Operator<RmsNorm, Device::Type::kCambricon> : public RmsNorm {
@@ -33,6 +34,7 @@ class Operator<RmsNorm, Device::Type::kCambricon> : public RmsNorm {
     auto workspace{workspace_ ? workspace_ : default_workspace_};
 
     DispatchFunc<
+        Device::Type::kCambricon,
         List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>,
         List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>(
         {input.dtype(), weight.dtype()},
@@ -41,8 +43,8 @@ class Operator<RmsNorm, Device::Type::kCambricon> : public RmsNorm {
          using WeightT = typename decltype(weight_tag)::type;
 
          RmsNormUnion<InputT, WeightT>(
-              workspace, core_per_cluster, cluster_count, queue,
-              out.data(), input.data(), weight.data(), out_shape_.data(),
+              workspace, core_per_cluster, cluster_count, queue, out.data(),
+              input.data(), weight.data(), out_shape_.data(),
              out_strides_.data(), input_strides_.data(), eps, ndim_);
        },
        "CambriconRmsNorm::operator() - output dispatch");
@@ -54,7 +56,7 @@ class Operator<RmsNorm, Device::Type::kCambricon> : public RmsNorm {
    return ndim_ * (sizeof(size_t) + 2 * sizeof(ptrdiff_t));
  }
 
-  void *default_workspace_{nullptr};
+  void* default_workspace_{nullptr};
  int core_per_cluster = 0;
  int cluster_count = 0;
 };
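For readers unfamiliar with the tag-lambda pattern used by `DispatchFunc` above: a runtime `DataType` is mapped onto a compile-time type, and the callback receives a tag whose nested `type` names the element type. A self-contained, single-list sketch of that idea with hypothetical names; the repository's `DispatchFunc` additionally takes the device as its first template argument and, as the call above shows, supports multiple dtype lists.

#include <stdexcept>

// Hypothetical stand-ins; not the repository's definitions.
enum class DataType { kFloat16, kFloat32 };

template <typename T>
struct TypeTag {
  using type = T;
};

// Map a runtime dtype onto a compile-time type and hand a tag to the callback.
template <typename Fn>
void Dispatch(DataType dtype, Fn&& fn) {
  switch (dtype) {
    case DataType::kFloat32:
      fn(TypeTag<float>{});
      return;
    case DataType::kFloat16:
      fn(TypeTag<unsigned short>{});  // placeholder for the backend's fp16 type
      return;
  }
  throw std::runtime_error("Dispatch: unsupported dtype");
}

// Usage mirrors the call sites in this commit:
//   Dispatch(out.dtype(), [&](auto tag) {
//     using T = typename decltype(tag)::type;
//     // ... run the kernel instantiated for T ...
//   });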

src/common/constexpr_map.h

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ struct ConstexprMap {
       : data_(data) {}
 
   constexpr Value at(Key key) const {
-    for (const auto &pr : data_) {
+    for (const auto& pr : data_) {
      if (pr.first == key) return pr.second;
    }
    // TODO(lzm): change to logging.
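For context, `ConstexprMap::at` is a linear scan over a fixed array of key/value pairs that can run at compile time. A self-contained sketch of the same lookup pattern with hypothetical data; the real class takes its array via a constructor, and its behavior on a missing key is the subject of the TODO above.

#include <array>
#include <cstddef>
#include <utility>

// Minimal compile-time key/value lookup in the spirit of ConstexprMap.
template <typename Key, typename Value, std::size_t N>
struct MiniMap {
  std::array<std::pair<Key, Value>, N> data_;

  constexpr Value at(Key key) const {
    for (const auto& pr : data_) {
      if (pr.first == key) return pr.second;
    }
    return Value{};  // placeholder behavior for a missing key
  }
};

constexpr MiniMap<int, int, 2> kBlockSizes{
    std::array<std::pair<int, int>, 2>{{{16, 2}, {32, 4}}}};
static_assert(kBlockSizes.at(32) == 4);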

src/cpu/add/add.h

Lines changed: 4 additions & 3 deletions
@@ -20,7 +20,7 @@ class Operator<Add, Device::Type::kCpu> : public Add,
 
   void operator()(const Tensor input, const Tensor other,
                   Tensor out) const override {
-    DispatchFunc<AllTypes>(
+    DispatchFunc<Device::Type::kCpu, AllTypes>(
        out_type_,
        [&](auto tag) {
          using T = typename decltype(tag)::type;
@@ -32,8 +32,9 @@ class Operator<Add, Device::Type::kCpu> : public Add,
  private:
  template <typename T>
  void Compute(const Tensor input, const Tensor other, Tensor out) const {
-    using ComputeType =
-        std::conditional_t<IsBFloat16<T> || IsFP16<T>, float, T>;
+    using ComputeType = std::conditional_t<IsBFloat16<Device::Type::kCpu, T> ||
+                                               IsFP16<Device::Type::kCpu, T>,
+                                           float, T>;
 
    const auto* input_ptr = static_cast<const T*>(input.data());
    const auto* other_ptr = static_cast<const T*>(other.data());
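The new `ComputeType` alias above promotes half-precision element types to `float` for the arithmetic and converts back on store. A self-contained sketch of that promotion pattern; `HalfLike` is a hypothetical stand-in for the repository's `Float16`/`BFloat16`, and the real `Compute()` reads and writes through the tensors shown above.

#include <cstddef>
#include <type_traits>

// Hypothetical narrow element type standing in for Float16/BFloat16.
struct HalfLike {
  float v;  // a real implementation would store 16 bits
  explicit operator float() const { return v; }
};

template <typename T>
void AddElementwise(const T* a, const T* b, T* out, std::size_t n) {
  // Promote narrow types to float for the arithmetic, as Compute() above does.
  constexpr bool is_narrow = std::is_same_v<T, HalfLike>;
  using ComputeType = std::conditional_t<is_narrow, float, T>;
  for (std::size_t i = 0; i < n; ++i) {
    auto sum = static_cast<ComputeType>(a[i]) + static_cast<ComputeType>(b[i]);
    if constexpr (is_narrow) {
      out[i] = HalfLike{sum};  // narrow back on store
    } else {
      out[i] = sum;
    }
  }
}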

src/cpu/caster_.h

Lines changed: 5 additions & 2 deletions
@@ -4,6 +4,7 @@
 #include <type_traits>
 
 #include "caster.h"
+#include "cpu/device_.h"
 
 namespace infini::ops {
 
@@ -21,8 +22,10 @@ struct Caster<Device::Type::kCpu> {
       return std::forward<Src>(x);
     }
 
-    constexpr bool src_is_custom = IsBFloat16<PureSrc> || IsFP16<PureSrc>;
-    constexpr bool dst_is_custom = IsBFloat16<PureDst> || IsFP16<PureDst>;
+    constexpr bool src_is_custom = IsBFloat16<Device::Type::kCpu, PureSrc> ||
+                                   IsFP16<Device::Type::kCpu, PureSrc>;
+    constexpr bool dst_is_custom = IsBFloat16<Device::Type::kCpu, PureDst> ||
+                                   IsFP16<Device::Type::kCpu, PureDst>;
 
     if constexpr (!src_is_custom && !dst_is_custom) {
       return static_cast<PureDst>(std::forward<Src>(x));

src/cpu/causal_softmax/causal_softmax.h

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ class Operator<CausalSoftmax, Device::Type::kCpu> : public CausalSoftmax,
   Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {}
 
   void operator()(const Tensor input, Tensor out) const override {
-    DispatchFunc<AllFloatTypes>(
+    DispatchFunc<Device::Type::kCpu, AllFloatTypes>(
        out.dtype(),
        [&](auto tag) {
          using T = typename decltype(tag)::type;

src/cpu/device_.h

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+#ifndef INFINI_OPS_CPU_DEVICE__H_
+#define INFINI_OPS_CPU_DEVICE__H_
+
+#include "data_type.h"
+#include "device.h"
+
+namespace infini::ops {
+
+template <>
+struct TypeMap<Device::Type::kCpu, DataType::kFloat16> {
+  using type = Float16;
+};
+
+template <>
+struct TypeMap<Device::Type::kCpu, DataType::kBFloat16> {
+  using type = BFloat16;
+};
+
+} // namespace infini::ops
+
+#endif

src/cpu/gemm/gemm.h

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class Operator<Gemm, Device::Type::kCpu> : public Gemm,
   void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
                   std::optional<float> beta, std::optional<int> trans_a,
                   std::optional<int> trans_b, Tensor c) const override {
-    DispatchFunc<AllFloatTypes>(
+    DispatchFunc<Device::Type::kCpu, AllFloatTypes>(
        c.dtype(),
        [&](auto tag) {
          using T = typename decltype(tag)::type;

src/cpu/rms_norm/rms_norm.h

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ class Operator<RmsNorm, Device::Type::kCpu> : public RmsNorm,
 
   void operator()(const Tensor input, const Tensor weight, float eps,
                   Tensor out) const override {
-    DispatchFunc<AllFloatTypes>(
+    DispatchFunc<Device::Type::kCpu, AllFloatTypes>(
        out.dtype(),
        [&](auto tag) {
          using T = typename decltype(tag)::type;
