
Commit 8c159ad

refactor(pr66-simplify): correct rstd_out semantic name + clarity fixes
Author: zhangyue (committed)
Post-merge /simplify review findings applied:

- **`AddRmsNorm` param rename** (`src/base/add_rms_norm.h` + 3 Ascend kernels + test): `rstd_out` → `residual_out`. The slot actually holds `xOut` (the `input + other` residual sum) per `aclnnAddRmsNorm`'s API; the internal `rstd_tensor_` reciprocal-std buffer is private, so the prior name was misleading.
- **Generator shim for `apply_rotary_pos_emb`** (`scripts/generate_wrappers.py`): rename the `head_size`-as-`rotary_dim` positional forward to a named local `rotary_dim_shim`, with a comment noting that the legacy shim assumes full rotary (`rotary_dim == head_size`).
- **`kernel_sincos_cache.h` leak comment**: TODO → FIXME, with a call-out of the impact on persistent workers. The actual fix remains blocked on the undocumented input-address index layout of `aclnnRopeWithSinCosCache`.

Skipped findings: reviewer false positives on `src/base/rotary_embedding.h` members (all consumed by kernels) and on `max_seq_len_` (used in the constructor body). Larger refactors (`UploadCosSinCache` + `IndexSelect` helpers, ~100 lines of copy-paste) are deferred to a follow-up PR.
1 parent 3fc0e8d commit 8c159ad

7 files changed: 68 additions & 60 deletions

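For reference on the headline rename, a minimal host-side sketch of the operator's semantics (a scalar-loop illustration; the function and variable names are mine, not from the codebase). `residual_out` carries the `input + other` sum that downstream layers consume, while the reciprocal standard deviation never leaves the kernel:

#include <cmath>
#include <cstddef>

// Reference semantics, per row of length `dim`:
//   residual_out = input + other                         (exposed; was `rstd_out`)
//   rstd         = 1 / sqrt(mean(residual_out^2) + eps)  (internal only)
//   out          = residual_out * rstd * weight
void add_rms_norm_ref(const float* input, const float* other,
                      const float* weight, float eps, std::size_t rows,
                      std::size_t dim, float* out, float* residual_out) {
  for (std::size_t r = 0; r < rows; ++r) {
    float sum_sq = 0.0f;
    for (std::size_t d = 0; d < dim; ++d) {
      const float x = input[r * dim + d] + other[r * dim + d];
      residual_out[r * dim + d] = x;  // first output: the residual sum
      sum_sq += x * x;
    }
    const float rstd = 1.0f / std::sqrt(sum_sq / dim + eps);
    for (std::size_t d = 0; d < dim; ++d) {
      out[r * dim + d] = residual_out[r * dim + d] * rstd * weight[d];
    }
  }
}
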
scripts/generate_wrappers.py

Lines changed: 7 additions & 2 deletions
@@ -342,9 +342,14 @@ def _generate_apply_rotary_pos_emb_shim():
     py::object positions = torch.attr("arange")(
         num_tokens, py::arg("dtype") = torch.attr("int64"),
         py::arg("device") = cos.attr("device"));
+    // Legacy `apply_rotary_pos_emb` has no `rotary_dim` param; it assumes
+    // full rotation (`rotary_dim == head_size`) — partial rotary is not
+    // supported through this shim. Callers needing partial rotary must
+    // invoke `rotary_embedding` directly with the correct `rotary_dim`.
+    const int64_t rotary_dim_shim = head_size;
     self_module.attr("rotary_embedding")(
-        positions, query, key, cos_sin_cache, head_size,
-        py::int_(head_size), is_neox_style, query_out, key_out,
+        positions, query, key, cos_sin_cache, head_size, rotary_dim_shim,
+        is_neox_style, query_out, key_out,
         /*pre_gathered=*/true,
         py::arg("implementation_index") = implementation_index,
         py::arg("stream") = stream);

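To make the shim comment concrete, here is a hypothetical caller-side fragment (not generated code; `ops`, the literal sizes, and `implementation_index=0` are assumptions for illustration) showing how partial rotary bypasses the shim and calls `rotary_embedding` directly with `rotary_dim < head_size`:

#include <cstdint>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical direct call for partial rotary; the argument order mirrors
// the generated wrapper above. Only `rotary_dim` differs from the shim path.
void call_partial_rotary(py::module_ ops, py::object positions,
                         py::object query, py::object key,
                         py::object cos_sin_cache, py::object query_out,
                         py::object key_out, py::object stream) {
  const int64_t head_size = 128;  // illustrative
  const int64_t rotary_dim = 64;  // rotate only part of each head
  ops.attr("rotary_embedding")(positions, query, key, cos_sin_cache,
                               head_size, rotary_dim,
                               /*is_neox_style=*/true, query_out, key_out,
                               /*pre_gathered=*/true,
                               py::arg("implementation_index") = 0,
                               py::arg("stream") = stream);
}
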
src/ascend/add_rms_norm/kernel.h

Lines changed: 15 additions & 15 deletions
@@ -24,14 +24,14 @@ template <>
 class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
  public:
   Operator(const Tensor input, const Tensor other, const Tensor weight,
-           float eps, Tensor out, Tensor rstd_out)
-      : AddRmsNorm(input, other, weight, eps, out, rstd_out),
+           float eps, Tensor out, Tensor residual_out)
+      : AddRmsNorm(input, other, weight, eps, out, residual_out),
         input_cache_(input),
         other_cache_(other),
         weight_cache_(weight),
         out_cache_(out),
-        rstd_out_cache_(rstd_out) {
-    // Alpha scalar for `aclnnAdd` (`rstd_out = input + 1.0 * other`).
+        residual_out_cache_(residual_out) {
+    // Alpha scalar for `aclnnAdd` (`residual_out = input + 1.0 * other`).
     alpha_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT);
 
     // `aclnnRmsNorm` writes `rstd` as a required side output. Size is
@@ -49,32 +49,32 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
     other_cache_.release();
     weight_cache_.release();
     out_cache_.release();
-    rstd_out_cache_.release();
+    residual_out_cache_.release();
 
     // `rstd_tensor_` leaks with `norm_exec_` at shutdown (see `64c367c`).
     if (alpha_) aclDestroyScalar(alpha_);
   }
 
   void operator()(const Tensor input, const Tensor other, const Tensor weight,
-                  float eps, Tensor out, Tensor rstd_out) const override {
+                  float eps, Tensor out, Tensor residual_out) const override {
     auto t_input = input_cache_.get(const_cast<void*>(input.data()));
     auto t_other = other_cache_.get(const_cast<void*>(other.data()));
     auto t_weight = weight_cache_.get(const_cast<void*>(weight.data()));
     auto t_out = out_cache_.get(out.data());
-    auto t_rstd_out = rstd_out_cache_.get(rstd_out.data());
+    auto t_residual_out = residual_out_cache_.get(residual_out.data());
     auto stream = static_cast<aclrtStream>(stream_);
 
-    // Step 1: `rstd_out = input + other`.
+    // Step 1: `residual_out = input + other`.
     if (!add_exec_) {
-      aclnnAddGetWorkspaceSize(t_input, t_other, alpha_, t_rstd_out, &add_ws_,
-                               &add_exec_);
+      aclnnAddGetWorkspaceSize(t_input, t_other, alpha_, t_residual_out,
+                               &add_ws_, &add_exec_);
       aclSetAclOpExecutorRepeatable(add_exec_);
     } else {
       aclSetInputTensorAddr(add_exec_, 0, t_input,
                             const_cast<void*>(input.data()));
       aclSetInputTensorAddr(add_exec_, 1, t_other,
                             const_cast<void*>(other.data()));
-      aclSetOutputTensorAddr(add_exec_, 0, t_rstd_out, rstd_out.data());
+      aclSetOutputTensorAddr(add_exec_, 0, t_residual_out, residual_out.data());
     }
     auto& add_arena = ascend::GetWorkspacePool().Ensure(stream, add_ws_);
     aclnnAdd(add_arena.buf, add_ws_, add_exec_, stream);
@@ -92,13 +92,13 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
       aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf);
     }
 
-    // Step 2: `out = rms_norm(rstd_out, weight, eps)`.
+    // Step 2: `out = rms_norm(residual_out, weight, eps)`.
     if (!norm_exec_) {
-      aclnnRmsNormGetWorkspaceSize(t_rstd_out, t_weight, eps, t_out,
+      aclnnRmsNormGetWorkspaceSize(t_residual_out, t_weight, eps, t_out,
                                    rstd_tensor_, &norm_ws_, &norm_exec_);
       aclSetAclOpExecutorRepeatable(norm_exec_);
     } else {
-      aclSetInputTensorAddr(norm_exec_, 0, t_rstd_out, rstd_out.data());
+      aclSetInputTensorAddr(norm_exec_, 0, t_residual_out, residual_out.data());
       aclSetInputTensorAddr(norm_exec_, 1, t_weight,
                             const_cast<void*>(weight.data()));
       aclSetOutputTensorAddr(norm_exec_, 0, t_out, out.data());
@@ -117,7 +117,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
 
   mutable ascend::AclTensorCache out_cache_;
 
-  mutable ascend::AclTensorCache rstd_out_cache_;
+  mutable ascend::AclTensorCache residual_out_cache_;
 
   float alpha_storage_ = 1.0f;

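Both steps above pull scratch memory from `ascend::GetWorkspacePool().Ensure(stream, ws)`. The pool itself is outside this diff; the following is a hedged, grow-only reconstruction of what `Ensure` plausibly does (the `Arena` and `WorkspacePool` shapes are assumptions, only the call site comes from the diff):

#include <cstdint>
#include <unordered_map>
#include <acl/acl.h>  // aclrtMalloc, aclrtFree, aclrtStream

// Assumed per-stream, grow-only workspace arena behind
// `GetWorkspacePool().Ensure(stream, ws)`. A real pool would also need to
// avoid freeing a buffer still referenced by in-flight launches.
struct Arena {
  void* buf = nullptr;
  uint64_t size = 0;
};

class WorkspacePool {
 public:
  // Return an arena for `stream` with at least `size` bytes; grow but never
  // shrink, so repeated launches reuse one allocation.
  Arena& Ensure(aclrtStream stream, uint64_t size) {
    Arena& a = arenas_[stream];
    if (size > a.size) {
      if (a.buf) aclrtFree(a.buf);
      aclrtMalloc(&a.buf, size, ACL_MEM_MALLOC_HUGE_FIRST);
      a.size = size;
    }
    return a;
  }

 private:
  std::unordered_map<aclrtStream, Arena> arenas_;
};
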
src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 7 additions & 7 deletions
@@ -29,14 +29,14 @@ namespace infini::ops {
 
 // Custom AscendC fused `AddRmsNorm` kernel (implementation index 2).
 //
-// A single-kernel implementation that computes `rstd_out = input + other`
-// followed by `out = rms_norm(rstd_out, weight, eps)` in one launch,
+// A single-kernel implementation that computes `residual_out = input + other`
+// followed by `out = rms_norm(residual_out, weight, eps)` in one launch,
 // avoiding the decomposed `aclnnAdd` + `aclnnRmsNorm` calls (index 0) or
 // the fused `aclnnAddRmsNorm` call (index 1). Migrated from the custom
 // `RmsNorm` kernel (index 1 of `RmsNorm`).
 //
 // Select via `implementation_index=2` in Python:
-//   `infini.ops.add_rms_norm(input, other, weight, eps, out, rstd_out,
+//   `infini.ops.add_rms_norm(input, other, weight, eps, out, residual_out,
 //    implementation_index=2, stream=s)`.
 //
 // Requirements:
@@ -49,8 +49,8 @@ template <>
 class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
  public:
   Operator(const Tensor input, const Tensor other, const Tensor weight,
-           float eps, Tensor out, Tensor rstd_out)
-      : AddRmsNorm(input, other, weight, eps, out, rstd_out) {
+           float eps, Tensor out, Tensor residual_out)
+      : AddRmsNorm(input, other, weight, eps, out, residual_out) {
     // Dtype size in bytes.
     dtype_size_ = (input.dtype() == DataType::kFloat16) ? 2 : 4;
 
@@ -96,7 +96,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
   }
 
   void operator()(const Tensor input, const Tensor other, const Tensor weight,
-                  float eps, Tensor out, Tensor rstd_out) const override {
+                  float eps, Tensor out, Tensor residual_out) const override {
     auto stream = static_cast<aclrtStream>(stream_);
 
     // Determine `float32` `weight` pointer.
@@ -144,7 +144,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
     // Launch custom AscendC kernel.
     aclrtlaunch_add_rms_norm(block_dim, stream, const_cast<void*>(input.data()),
                              const_cast<void*>(other.data()), weight_fp32,
-                             out.data(), rstd_out.data(), total_rows_,
+                             out.data(), residual_out.data(), total_rows_,
                              static_cast<int64_t>(dim_), dim_length_align_,
                              former_num, former_length, tail_length, eps,
                              dtype_size_);

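The launch parameters `former_num`, `former_length`, and `tail_length` suggest a row split across the `block_dim` AI cores. The tiling computation is not part of this diff; the following is a hedged guess at the split (the interpretation of the three names is an assumption):

#include <cstdint>

// Assumed row tiling: the first `former_num` cores take one extra row each
// (`former_length` rows), the remaining cores take `tail_length` rows, so
// former_num * former_length + (block_dim - former_num) * tail_length
// covers all of total_rows.
struct RowTiling {
  int64_t former_num;
  int64_t former_length;
  int64_t tail_length;
};

RowTiling SplitRows(int64_t total_rows, int64_t block_dim) {
  RowTiling t;
  t.tail_length = total_rows / block_dim;
  t.former_num = total_rows % block_dim;
  t.former_length = t.tail_length + (t.former_num > 0 ? 1 : 0);
  return t;
}
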
src/ascend/add_rms_norm/kernel_fused.h

Lines changed: 14 additions & 14 deletions
@@ -15,25 +15,25 @@ namespace infini::ops {
 
 // Fused implementation via `aclnnAddRmsNorm` (implementation index 1).
 //
-// Computes `rstd_out = input + other` and `out = rms_norm(rstd_out, weight,
-// eps)` in a single CANN launch. The fused API has higher host-side launch
-// overhead (~200 us) compared to the decomposed `aclnnAdd` + `aclnnRmsNorm`
-// path (~39 us), but may offer better NPU-side efficiency for large tensors
-// where kernel fusion reduces memory traffic.
+// Computes `residual_out = input + other` and `out = rms_norm(residual_out,
+// weight, eps)` in a single CANN launch. The fused API has higher host-side
+// launch overhead (~200 us) compared to the decomposed `aclnnAdd` +
+// `aclnnRmsNorm` path (~39 us), but may offer better NPU-side efficiency for
+// large tensors where kernel fusion reduces memory traffic.
 //
 // Select via `implementation_index=1` in Python:
 //   infini.ops.add_rms_norm(..., implementation_index=1, stream=s)
 template <>
 class Operator<AddRmsNorm, Device::Type::kAscend, 1> : public AddRmsNorm {
  public:
   Operator(const Tensor input, const Tensor other, const Tensor weight,
-           float eps, Tensor out, Tensor rstd_out)
-      : AddRmsNorm(input, other, weight, eps, out, rstd_out),
+           float eps, Tensor out, Tensor residual_out)
+      : AddRmsNorm(input, other, weight, eps, out, residual_out),
         input_cache_(input),
         other_cache_(other),
         weight_cache_(weight),
         out_cache_(out),
-        rstd_out_cache_(rstd_out) {
+        residual_out_cache_(residual_out) {
     // `aclnnAddRmsNorm` requires `rstdOut` to have the same ndim as `input`,
     // with the last `weight.ndim()` dimensions set to 1. For example:
     // `input` (2, 32, 128), `weight` (128) -> `rstdOut` (2, 32, 1).
@@ -68,25 +68,25 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 1> : public AddRmsNorm {
     other_cache_.release();
     weight_cache_.release();
     out_cache_.release();
-    rstd_out_cache_.release();
+    residual_out_cache_.release();
 
     // `rstd_tensor_` leaks with the executor at shutdown (see `64c367c`).
     if (rstd_data_) aclrtFree(rstd_data_);
   }
 
   void operator()(const Tensor input, const Tensor other, const Tensor weight,
-                  float eps, Tensor out, Tensor rstd_out) const override {
+                  float eps, Tensor out, Tensor residual_out) const override {
     auto t_input = input_cache_.get(const_cast<void*>(input.data()));
     auto t_other = other_cache_.get(const_cast<void*>(other.data()));
    auto t_weight = weight_cache_.get(const_cast<void*>(weight.data()));
     auto t_out = out_cache_.get(out.data());
-    auto t_rstd_out = rstd_out_cache_.get(rstd_out.data());
+    auto t_residual_out = residual_out_cache_.get(residual_out.data());
     auto stream = static_cast<aclrtStream>(stream_);
 
     if (!executor_) {
       aclnnAddRmsNormGetWorkspaceSize(
           t_input, t_other, t_weight, static_cast<double>(eps), t_out,
-          rstd_tensor_, t_rstd_out, &ws_size_, &executor_);
+          rstd_tensor_, t_residual_out, &ws_size_, &executor_);
       aclSetAclOpExecutorRepeatable(executor_);
     } else {
       aclSetInputTensorAddr(executor_, 0, t_input,
@@ -97,7 +97,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 1> : public AddRmsNorm {
                             const_cast<void*>(weight.data()));
       aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
       // `rstd` at output index 1 has a stable address — no update needed.
-      aclSetOutputTensorAddr(executor_, 2, t_rstd_out, rstd_out.data());
+      aclSetOutputTensorAddr(executor_, 2, t_residual_out, residual_out.data());
     }
 
     auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
@@ -113,7 +113,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 1> : public AddRmsNorm {
 
   mutable ascend::AclTensorCache out_cache_;
 
-  mutable ascend::AclTensorCache rstd_out_cache_;
+  mutable ascend::AclTensorCache residual_out_cache_;
 
   std::vector<int64_t> fused_rstd_shape_;

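A small self-contained sketch of the `rstdOut` shape rule quoted in the constructor comment above (the helper name is mine, not from the codebase): keep `input`'s ndim and collapse the trailing `weight.ndim()` dimensions to 1.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Shape of `aclnnAddRmsNorm`'s required `rstdOut` side output, per the rule
// quoted above: same ndim as `input`, last `weight_ndim` dims set to 1.
std::vector<int64_t> FusedRstdShape(const std::vector<int64_t>& input_shape,
                                    std::size_t weight_ndim) {
  assert(weight_ndim <= input_shape.size());
  std::vector<int64_t> shape = input_shape;
  for (std::size_t i = shape.size() - weight_ndim; i < shape.size(); ++i) {
    shape[i] = 1;
  }
  return shape;
}

// Example: FusedRstdShape({2, 32, 128}, 1) -> {2, 32, 1}.
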
src/ascend/rotary_embedding/kernel_sincos_cache.h

Lines changed: 11 additions & 9 deletions
@@ -124,17 +124,19 @@ class Operator<RotaryEmbedding, Device::Type::kAscend, 2>
     auto t_q_out = q_out_cache_.get(const_cast<void*>(q_out.data()));
     auto t_k_out = k_out_cache_.get(const_cast<void*>(k_out.data()));
 
-    // Fresh executor each call: `aclnnRopeWithSinCosCache`'s public header
-    // hides four `REG_OP` attrs (see
-    // `aclnn_rope_with_sin_cos_cache_hidden_attrs` memory). The official
+    // FIXME: per-call unbounded executor leak. `aclnnRopeWithSinCosCache`'s
+    // public header hides four `REG_OP` attrs (see
+    // `aclnn_rope_with_sin_cos_cache_hidden_attrs` memory), so the official
     // `aclSetInputTensorAddr` index numbering for this kernel is not
-    // documented, so we cannot safely reuse a Repeatable executor across
-    // calls. The async stream consumes the executor after enqueue, so
-    // destroying it synchronously here would race with the launch — we
-    // leak for now.
+    // documented — we cannot safely reuse a Repeatable executor across calls.
+    // The async stream consumes the executor after enqueue, so destroying it
+    // synchronously here races with the launch (SIGABRT). Long-running
+    // persistent workers (e.g. vLLM decode) accumulate one executor per
+    // forward step until the runtime tears down.
     //
-    // TODO: cache + set Repeatable once the input-address index layout is
-    // confirmed for this kernel.
+    // Resolve by obtaining the input-address index layout from the CANN team
+    // (or deriving it from the binary) and switching to the cached-executor
+    // pattern used in `kernel.h` / `kernel_atb.h`.
     uint64_t ws_size = 0;
     aclOpExecutor* executor = nullptr;

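For contrast, a condensed sketch of the cached-executor pattern the FIXME points at. The pattern is lifted from the `kernel.h` diff above; `MyOpGetWorkspaceSize` and `MyOp` are hypothetical stand-ins for an aclnn op whose input-address index layout is documented, which is exactly what `aclnnRopeWithSinCosCache` lacks:

#include <cstdint>
#include <acl/acl.h>
#include <aclnn/acl_meta.h>  // aclTensor, aclOpExecutor, aclnnStatus

// Hypothetical two-phase aclnn op with a documented index layout; stand-ins
// only, these declarations do not exist in CANN.
extern "C" aclnnStatus MyOpGetWorkspaceSize(aclTensor* in, aclTensor* out,
                                            uint64_t* ws,
                                            aclOpExecutor** exec);
extern "C" aclnnStatus MyOp(void* ws_buf, uint64_t ws, aclOpExecutor* exec,
                            aclrtStream stream);

void LaunchCached(aclTensor* t_in, void* in_addr, aclTensor* t_out,
                  void* out_addr, aclrtStream stream) {
  static aclOpExecutor* exec = nullptr;  // cached across calls
  static uint64_t ws = 0;
  if (!exec) {
    // First call: build the executor once and mark it reusable.
    MyOpGetWorkspaceSize(t_in, t_out, &ws, &exec);
    aclSetAclOpExecutorRepeatable(exec);
  } else {
    // Later calls: patch only the device addresses. This is safe only
    // because the op's input/output indices (0 and 0 here) are documented.
    aclSetInputTensorAddr(exec, 0, t_in, in_addr);
    aclSetOutputTensorAddr(exec, 0, t_out, out_addr);
  }
  auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws);
  MyOp(arena.buf, ws, exec, stream);
}

Until the index layout for `aclnnRopeWithSinCosCache` is confirmed, applying this pattern there risks patching the wrong operand slot, which would corrupt launches rather than merely leak.
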
src/base/add_rms_norm.h

Lines changed: 5 additions & 4 deletions
@@ -11,7 +11,7 @@ namespace infini::ops {
 class AddRmsNorm : public Operator<AddRmsNorm> {
  public:
   AddRmsNorm(const Tensor input, const Tensor other, const Tensor weight,
-             float eps, Tensor out, Tensor rstd_out)
+             float eps, Tensor out, Tensor residual_out)
       : input_shape_{input.shape()},
         eps_{eps},
         dim_{input.size(-1)},
@@ -22,13 +22,14 @@ class AddRmsNorm : public Operator<AddRmsNorm> {
            "`AddRmsNorm`: `input` and `other` must have the same dtype.");
     assert(input.dtype() == out.dtype() &&
            "`AddRmsNorm`: `input` and `out` must have the same dtype.");
-    assert(input.dtype() == rstd_out.dtype() &&
-           "`AddRmsNorm`: `input` and `rstd_out` must have the same dtype.");
+    assert(
+        input.dtype() == residual_out.dtype() &&
+        "`AddRmsNorm`: `input` and `residual_out` must have the same dtype.");
   }
 
   virtual void operator()(const Tensor input, const Tensor other,
                           const Tensor weight, float eps, Tensor out,
-                          Tensor rstd_out) const = 0;
+                          Tensor residual_out) const = 0;
 
  protected:
   Tensor::Shape input_shape_;

tests/test_add_rms_norm.py

Lines changed: 9 additions & 9 deletions
@@ -47,43 +47,43 @@ def test_add_rms_norm(
     other = randn_strided(shape, strides, dtype=dtype, device=device)
     weight = randn_strided(weight_shape, None, dtype=dtype, device=device)
     out = empty_strided(shape, strides, dtype=dtype, device=device)
-    rstd_out = empty_strided(shape, strides, dtype=dtype, device=device)
+    residual_out = empty_strided(shape, strides, dtype=dtype, device=device)
 
     return Payload(
         lambda *args, **kwargs: _add_rms_norm(
             *args, **kwargs, implementation_index=implementation_index
         ),
         _torch_add_rms_norm,
         (input, other, weight),
-        {"eps": eps, "out": out, "rstd_out": rstd_out},
+        {"eps": eps, "out": out, "residual_out": residual_out},
         rtol=rtol,
         atol=atol,
     )
 
 
 def _add_rms_norm(
-    input, other, weight, *, eps=1e-6, out=None, rstd_out=None, implementation_index=0
+    input, other, weight, *, eps=1e-6, out=None, residual_out=None, implementation_index=0
 ):
     infini.ops.add_rms_norm(
         input,
         other,
         weight,
         eps,
         out,
-        rstd_out,
+        residual_out,
         implementation_index=implementation_index,
         stream=get_stream(input.device),
     )
 
     # Concatenate both outputs into a single flat tensor for `allclose` comparison.
-    return torch.cat([out.contiguous().flatten(), rstd_out.contiguous().flatten()])
+    return torch.cat([out.contiguous().flatten(), residual_out.contiguous().flatten()])
 
 
-def _torch_add_rms_norm(input, other, weight, *, eps=1e-6, out=None, rstd_out=None):
+def _torch_add_rms_norm(input, other, weight, *, eps=1e-6, out=None, residual_out=None):
     x_sum = input + other
 
-    if rstd_out is not None:
-        rstd_out.copy_(x_sum)
+    if residual_out is not None:
+        residual_out.copy_(x_sum)
 
     rms = torch.sqrt(
         torch.mean(x_sum.float() * x_sum.float(), dim=-1, keepdim=True) + eps
@@ -93,4 +93,4 @@ def _torch_add_rms_norm(input, other, weight, *, eps=1e-6, out=None, rstd_out=No
     if out is not None:
         out.copy_(y)
 
-    return torch.cat([out.contiguous().flatten(), rstd_out.contiguous().flatten()])
+    return torch.cat([out.contiguous().flatten(), residual_out.contiguous().flatten()])
