Skip to content

Commit e81e69e

Browse files
author
echo_stone
committed
Fix: address distributed PR hw-native-sys#307 feedback
- validate distributed buffer metadata and simplify output verification
- support explicit device selection in run_example.py and ci.sh for CI
- shrink treduce examples to 4 ranks, remove stale config, and guard invalid rank/root values
- rename the per-rank helper to distributed_worker.py and document buffer layout

Made-with: Cursor
1 parent f4147c6 commit e81e69e

13 files changed

Lines changed: 320 additions & 70 deletions

File tree

ci.sh

Lines changed: 134 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,17 @@ while [[ $# -gt 0 ]]; do
1616
shift 2
1717
;;
1818
-d|--device)
19-
DEVICE_RANGE="$2"
20-
shift 2
19+
shift
20+
DEVICE_VALUES=()
21+
while [[ $# -gt 0 && "$1" != -* ]]; do
22+
DEVICE_VALUES+=("$1")
23+
shift
24+
done
25+
if [[ ${#DEVICE_VALUES[@]} -eq 0 ]]; then
26+
echo "Missing value for --device"
27+
exit 1
28+
fi
29+
DEVICE_RANGE=$(IFS=,; echo "${DEVICE_VALUES[*]}")
2130
;;
2231
-r|--runtime)
2332
RUNTIME="$2"
@@ -78,15 +87,22 @@ if [[ -n "$RUNTIME" ]]; then
7887
fi
7988
fi
8089

81-
# Parse device range (e.g., "5-8" or "5")
82-
if [[ "$DEVICE_RANGE" == *-* ]]; then
83-
IFS='-' read -r DEV_START DEV_END <<< "$DEVICE_RANGE"
84-
DEVICES=()
85-
for ((i=DEV_START; i<=DEV_END; i++)); do
86-
DEVICES+=("$i")
87-
done
90+
# Parse device spec (e.g., "5-8", "5", or "0,1,3,5")
91+
DEVICES=()
92+
if [[ -z "$DEVICE_RANGE" ]]; then
93+
DEVICES=("0")
8894
else
89-
DEVICES=("${DEVICE_RANGE:-0}")
95+
IFS=',' read -r -a DEVICE_ITEMS <<< "$DEVICE_RANGE"
96+
for item in "${DEVICE_ITEMS[@]}"; do
97+
if [[ "$item" == *-* ]]; then
98+
IFS='-' read -r DEV_START DEV_END <<< "$item"
99+
for ((i=DEV_START; i<=DEV_END; i++)); do
100+
DEVICES+=("$i")
101+
done
102+
else
103+
DEVICES+=("$item")
104+
fi
105+
done
90106
fi
91107
NUM_DEVICES=${#DEVICES[@]}
92108

@@ -199,13 +215,48 @@ pin_pto_isa_on_failure() {
199215
return 0 # Pinned, caller should retry
200216
}
201217

218+
# get_task_device_count <kernel_config.py>
# Prints how many devices (ranks) a task needs, taken from the "nranks" entry
# of the DISTRIBUTED_CONFIG dict in the given kernel config module.
# Always prints exactly one integer >= 1. Prints 1 when the module has no
# DISTRIBUTED_CONFIG dict, when "nranks" is not an integer, when the module
# cannot be loaded, or when the python interpreter itself fails to run.
# Diagnostics go to stderr so callers can safely capture stdout.
get_task_device_count() {
    local kernel_config="$1"
    local count

    # NOTE: loading the config executes it as Python code; only point this at
    # trusted, in-repo kernel_config.py files.
    count="$(python - "$kernel_config" <<'PY'
import contextlib
import importlib.util
import sys

path = sys.argv[1]
nranks = 1
try:
    spec = importlib.util.spec_from_file_location("kernel_config", path)
    mod = importlib.util.module_from_spec(spec)
    # stdout carries only the final count back to the shell, so anything the
    # config prints while being imported is diverted to stderr.
    with contextlib.redirect_stdout(sys.stderr):
        spec.loader.exec_module(mod)
    dist = getattr(mod, "DISTRIBUTED_CONFIG", None)
    if isinstance(dist, dict):
        nranks = int(dist.get("nranks", 1))
except Exception as exc:
    print("get_task_device_count: %s: %s" % (path, exc), file=sys.stderr)
    nranks = 1
print(max(nranks, 1))
PY
    )" || count=""

    # Interpreter missing or crashed before printing: fall back to one device
    # instead of handing an empty string to the numeric comparisons downstream.
    if [[ ! "$count" =~ ^[0-9]+$ ]]; then
        echo "get_task_device_count: could not read device count from ${kernel_config}, assuming 1" >&2
        count=1
    fi

    echo "$count"
}
238+
239+
# format_device_spec <count>
# Prints the device argument for a task that needs <count> devices, using the
# global DEVICES array: its first entry when <count> is 1 or less, otherwise
# its first <count> entries joined with commas.
format_device_spec() {
    local wanted="$1"

    # Single-device (or smaller) requests just get the first configured device.
    [[ "$wanted" -le 1 ]] && { echo "${DEVICES[0]}"; return 0; }

    # A function-local IFS makes the quoted [*] slice expand comma-joined.
    local IFS=,
    echo "${DEVICES[*]:0:$wanted}"
}
251+
202252
# ---- Discover all tasks ----
203253
EXAMPLES_DIR="examples"
204254
DEVICE_TESTS_DIR="tests/device_tests"
205255

206256
declare -a HW_TASK_NAMES=()
207257
declare -a HW_TASK_DIRS=()
208258
declare -a HW_TASK_PLATS=()
259+
declare -a HW_TASK_DEVICE_COUNTS=()
209260
declare -a SIM_TASK_NAMES=()
210261
declare -a SIM_TASK_DIRS=()
211262
declare -a SIM_TASK_PLATS=()
@@ -245,18 +296,22 @@ while IFS= read -r -d '' example_dir; do
245296
SIM_TASK_DIRS+=("${example_dir}")
246297
SIM_TASK_PLATS+=("${PLATFORM}")
247298
else
299+
required_devices="$(get_task_device_count "$kernel_config")"
248300
HW_TASK_NAMES+=("example:${example_name}")
249301
HW_TASK_DIRS+=("${example_dir}")
250302
HW_TASK_PLATS+=("${PLATFORM}")
303+
HW_TASK_DEVICE_COUNTS+=("${required_devices}")
251304
fi
252305
elif [[ "$OS" == "Darwin" ]]; then
253306
SIM_TASK_NAMES+=("example:${example_name}")
254307
SIM_TASK_DIRS+=("${example_dir}")
255308
SIM_TASK_PLATS+=("${example_arch}sim")
256309
else
310+
required_devices="$(get_task_device_count "$kernel_config")"
257311
HW_TASK_NAMES+=("example:${example_name}")
258312
HW_TASK_DIRS+=("${example_dir}")
259313
HW_TASK_PLATS+=("${example_arch}")
314+
HW_TASK_DEVICE_COUNTS+=("${required_devices}")
260315
SIM_TASK_NAMES+=("example:${example_name}")
261316
SIM_TASK_DIRS+=("${example_dir}")
262317
SIM_TASK_PLATS+=("${example_arch}sim")
@@ -299,6 +354,7 @@ if [[ -d "$DEVICE_TESTS_DIR" ]]; then
299354
HW_TASK_NAMES+=("device_test:${test_name}")
300355
HW_TASK_DIRS+=("${test_dir}")
301356
HW_TASK_PLATS+=("${PLATFORM:-${test_arch}}")
357+
HW_TASK_DEVICE_COUNTS+=("$(get_task_device_count "$kernel_config")")
302358
done < <(find "$DEVICE_TESTS_DIR" -mindepth 1 -type d -print0 | sort -z)
303359
else
304360
echo "Skipping device tests (hardware platforms only)"
@@ -314,7 +370,7 @@ MAX_RETRIES=3
314370
# Log naming: ${safe_name}_${platform}_attempt${attempt}.log
315371
# Result format: name|platform|PASS/FAIL|device:X|attempt:N|Xs
316372
run_task() {
317-
local name="$1" dir="$2" platform="$3" attempt="$4" device_id="$5" print_log_on_fail="${6:-true}"
373+
local name="$1" dir="$2" platform="$3" attempt="$4" device_spec="$5" print_log_on_fail="${6:-true}" required_devices="${7:-1}"
318374
local safe_name="${name//[:\/]/_}"
319375
local task_log="${LOG_DIR}/${safe_name}_${platform}_attempt${attempt}.log"
320376
local start_time=$SECONDS
@@ -323,10 +379,16 @@ run_task() {
323379
cmd=(python examples/scripts/run_example.py
324380
-k "${dir}/kernels" -g "${dir}/golden.py"
325381
-p "$platform" --clone-protocol "$CLONE_PROTOCOL" "${commit_flag[@]}")
326-
[[ -n "$device_id" ]] && cmd+=(-d "$device_id")
382+
if [[ -n "$device_spec" ]]; then
383+
if [[ "$required_devices" -gt 1 ]]; then
384+
cmd+=(--devices "$device_spec" --nranks "$required_devices")
385+
else
386+
cmd+=(-d "$device_spec")
387+
fi
388+
fi
327389

328390
# Progress to stdout (not captured in log)
329-
echo "[${platform}${device_id:+:dev${device_id}}] Running: $name (attempt $attempt)"
391+
echo "[${platform}${device_spec:+:dev${device_spec}}] Running: $name (attempt $attempt)"
330392

331393
# Command output to log file only
332394
"${cmd[@]}" > "$task_log" 2>&1
@@ -336,21 +398,46 @@ run_task() {
336398
local status
337399
if [[ $rc -eq 0 ]]; then
338400
status="PASS"
339-
echo "[${platform}${device_id:+:dev${device_id}}] PASS: $name (${elapsed}s)"
401+
echo "[${platform}${device_spec:+:dev${device_spec}}] PASS: $name (${elapsed}s)"
340402
else
341403
status="FAIL"
342-
echo "[${platform}${device_id:+:dev${device_id}}] FAIL: $name (${elapsed}s)"
404+
echo "[${platform}${device_spec:+:dev${device_spec}}] FAIL: $name (${elapsed}s)"
343405
if [[ "$print_log_on_fail" == "true" ]]; then
344406
echo "--- LOG: $name (attempt $attempt) ---"
345407
cat "$task_log"
346408
echo "--- END ---"
347409
fi
348410
fi
349-
echo "${name}|${platform}|${status}|device:${device_id:-sim}|attempt:${attempt}|${elapsed}s" \
411+
echo "${name}|${platform}|${status}|device:${device_spec:-sim}|attempt:${attempt}|${elapsed}s" \
350412
>> "$RESULTS_FILE"
351413
return $rc
352414
}
353415

416+
# ---- HW multi-device executor ----
# run_hw_multidevice_tasks <attempt> <idx1> <idx2> ...
# Runs the listed HW tasks one after another. Each task is given the number of
# devices recorded for it in HW_TASK_DEVICE_COUNTS, formatted by
# format_device_spec. A task that needs more devices than NUM_DEVICES is not
# run: it is recorded as FAIL with device:insufficient in RESULTS_FILE.
# Sets HW_MULTI_FAILURES to array of failed indices after return.
run_hw_multidevice_tasks() {
    local attempt="$1"; shift
    # Every loop variable is local so the caller's own variables (notably the
    # top-level "idx" loop counter) are not overwritten by this function.
    local idx required_devices platform name device_spec
    # Safe to reset here even when the caller passed "${HW_MULTI_FAILURES[@]}":
    # the indices were already expanded into the positional parameters.
    HW_MULTI_FAILURES=()

    for idx in "$@"; do
        required_devices="${HW_TASK_DEVICE_COUNTS[$idx]}"
        platform="${HW_TASK_PLATS[$idx]}"
        name="${HW_TASK_NAMES[$idx]}"

        if [[ "$required_devices" -gt "$NUM_DEVICES" ]]; then
            echo "[${platform}] FAIL: $name requires ${required_devices} devices, only ${NUM_DEVICES} available"
            echo "${name}|${platform}|FAIL|device:insufficient|attempt:${attempt}|0s" >> "$RESULTS_FILE"
            HW_MULTI_FAILURES+=("$idx")
            continue
        fi

        device_spec="$(format_device_spec "$required_devices")"
        if ! run_task "$name" "${HW_TASK_DIRS[$idx]}" "$platform" "$attempt" "$device_spec" "true" "$required_devices"; then
            HW_MULTI_FAILURES+=("$idx")
        fi
    done
}
440+
354441
# ---- SIM executor ----
355442
# run_sim_tasks <attempt> <idx1> <idx2> ...
356443
# Sets SIM_FAILURES to array of failed indices after return.
@@ -429,7 +516,7 @@ run_hw_tasks() {
429516

430517
IFS=':' read -r idx attempt <<< "$entry"
431518

432-
if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false"; then
519+
if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false" "1"; then
433520
echo "${idx}|PASS" >> "$hw_marker"
434521
else
435522
next=$((attempt + 1))
@@ -473,12 +560,36 @@ fi
473560

474561
# HW phase
475562
if [[ ${#HW_TASK_NAMES[@]} -gt 0 ]]; then
476-
ALL_HW=($(seq 0 $((${#HW_TASK_NAMES[@]} - 1))))
477-
echo "---- HW: ${#ALL_HW[@]} tasks on ${NUM_DEVICES} devices ----"
478-
run_hw_tasks "${ALL_HW[@]}"
479-
if [[ ${#HW_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
480-
echo "[CI] Retrying ${#HW_FAILURES[@]} HW failures with pinned PTO-ISA"
481-
run_hw_tasks "${HW_FAILURES[@]}"
563+
ALL_HW_SINGLE=()
564+
ALL_HW_MULTI=()
565+
for idx in $(seq 0 $((${#HW_TASK_NAMES[@]} - 1))); do
566+
if [[ "${HW_TASK_DEVICE_COUNTS[$idx]}" -gt 1 ]]; then
567+
ALL_HW_MULTI+=("$idx")
568+
else
569+
ALL_HW_SINGLE+=("$idx")
570+
fi
571+
done
572+
573+
echo "---- HW: ${#ALL_HW_SINGLE[@]} single-device tasks, ${#ALL_HW_MULTI[@]} multi-device tasks on ${NUM_DEVICES} devices ----"
574+
575+
HW_MULTI_FAILURES=()
576+
if [[ ${#ALL_HW_MULTI[@]} -gt 0 ]]; then
577+
run_hw_multidevice_tasks 0 "${ALL_HW_MULTI[@]}"
578+
if [[ ${#HW_MULTI_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
579+
echo "[CI] Retrying ${#HW_MULTI_FAILURES[@]} multi-device HW failures with pinned PTO-ISA"
580+
run_hw_multidevice_tasks 1 "${HW_MULTI_FAILURES[@]}"
581+
fi
582+
fi
583+
584+
HW_SINGLE_FAILURES=()
585+
if [[ ${#ALL_HW_SINGLE[@]} -gt 0 ]]; then
586+
run_hw_tasks "${ALL_HW_SINGLE[@]}"
587+
HW_SINGLE_FAILURES=("${HW_FAILURES[@]}")
588+
if [[ ${#HW_SINGLE_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
589+
echo "[CI] Retrying ${#HW_SINGLE_FAILURES[@]} HW failures with pinned PTO-ISA"
590+
run_hw_tasks "${HW_SINGLE_FAILURES[@]}"
591+
HW_SINGLE_FAILURES=("${HW_FAILURES[@]}")
592+
fi
482593
fi
483594
fi
484595

examples/a2a3/aicpu_build_graph/treduce_distributed/golden.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
TREDUCE_COUNT = 256
13-
NRANKS = 8
13+
NRANKS = 4
1414

1515
__outputs__ = ["output"]
1616

examples/a2a3/aicpu_build_graph/treduce_distributed/kernels/aiv/treduce_kernel.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#endif
3232

3333
static constexpr size_t TREDUCE_COUNT = 256;
34+
static constexpr int kMaxSupportedRanks = 16;
3435

3536
template <typename T>
3637
AICORE inline __gm__ T *CommRemotePtr(
@@ -70,15 +71,19 @@ void kernel_entry(__gm__ int64_t* args) {
7071
TASSIGN(accTile, 0x0);
7172
TASSIGN(recvTile, 0x10000);
7273

74+
if (nranks <= 0 || nranks > kMaxSupportedRanks || root < 0 || root >= nranks) {
75+
pipe_barrier(PIPE_ALL);
76+
return;
77+
}
78+
7379
if (my_rank == root) {
7480
Global outputG(output, shape, stride);
75-
Global tensors[16];
76-
int actual_nranks = (nranks > 16) ? 16 : nranks;
77-
for (int i = 0; i < actual_nranks; ++i) {
81+
Global tensors[kMaxSupportedRanks];
82+
for (int i = 0; i < nranks; ++i) {
7883
__gm__ float* remoteInput = CommRemotePtr(commCtx, input, i);
7984
tensors[i] = Global(remoteInput, shape, stride);
8085
}
81-
pto::comm::ParallelGroup<Global> pg(tensors, actual_nranks, root);
86+
pto::comm::ParallelGroup<Global> pg(tensors, nranks, root);
8287
pto::comm::TREDUCE(pg, outputG, accTile, recvTile,
8388
pto::comm::ReduceOp::Sum);
8489
}

examples/a2a3/aicpu_build_graph/treduce_distributed/kernels/kernel_config.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,26 @@
3333
"PTO_AICPU_BUILD_GRAPH_BUILD_MODE": "1",
3434
}
3535

36+
# Distributed layout contract consumed by DistributedCodeRunner/worker:
37+
# - win_sync_prefix reserves a small header at the front of each rank's RDMA
38+
# window before any placement="window" buffers are laid out.
39+
# - buffers declares runtime allocation metadata:
40+
# * count is the element count, not byte size.
41+
# * placement="window": buffer lives in the shared RDMA window and may be
42+
# accessed by remote ranks.
43+
# * placement="device": buffer uses regular device_malloc and is local-only.
44+
# - inputs/outputs control which buffers are loaded from .bin files and which
45+
# are copied back after execution.
46+
# - args defines the orchestration/kernel uint64_t* args order.
3647
DISTRIBUTED_CONFIG = {
37-
"nranks": 8,
48+
"nranks": 4,
3849
"root": 0,
39-
"comm_include_dirs": ["tests/npu/a2a3/comm/st/testcase"],
4050
"win_sync_prefix": 256,
4151
"buffers": [
52+
# Root rank reads every rank's input through CommRemotePtr(...), so the
53+
# input buffer must be placed in the shared RDMA window.
4254
{"name": "input", "dtype": "float32", "count": 256, "placement": "window"},
55+
# The output is produced and consumed locally on the root rank only.
4356
{"name": "output", "dtype": "float32", "count": 256, "placement": "device"},
4457
],
4558
"inputs": ["input"],

examples/a2a3/host_build_graph/treduce_distributed/golden.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
TREDUCE_COUNT = 256
13-
NRANKS = 8
13+
NRANKS = 4
1414

1515
__outputs__ = ["output"]
1616

examples/a2a3/host_build_graph/treduce_distributed/kernels/aiv/treduce_kernel.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#endif
3232

3333
static constexpr size_t TREDUCE_COUNT = 256;
34+
static constexpr int kMaxSupportedRanks = 16;
3435

3536
template <typename T>
3637
AICORE inline __gm__ T *CommRemotePtr(
@@ -41,6 +42,7 @@ AICORE inline __gm__ T *CommRemotePtr(
4142
return (__gm__ T *)(ctx->windowsIn[pe] + offset);
4243
}
4344

45+
4446
extern "C" __aicore__ __attribute__((always_inline))
4547
void kernel_entry(__gm__ int64_t* args) {
4648
__gm__ float* input = reinterpret_cast<__gm__ float*>(args[0]);
@@ -70,15 +72,19 @@ void kernel_entry(__gm__ int64_t* args) {
7072
TASSIGN(accTile, 0x0);
7173
TASSIGN(recvTile, 0x10000);
7274

75+
if (nranks <= 0 || nranks > kMaxSupportedRanks || root < 0 || root >= nranks) {
76+
pipe_barrier(PIPE_ALL);
77+
return;
78+
}
79+
7380
if (my_rank == root) {
7481
Global outputG(output, shape, stride);
75-
Global tensors[16];
76-
int actual_nranks = (nranks > 16) ? 16 : nranks;
77-
for (int i = 0; i < actual_nranks; ++i) {
82+
Global tensors[kMaxSupportedRanks];
83+
for (int i = 0; i < nranks; ++i) {
7884
__gm__ float* remoteInput = CommRemotePtr(commCtx, input, i);
7985
tensors[i] = Global(remoteInput, shape, stride);
8086
}
81-
pto::comm::ParallelGroup<Global> pg(tensors, actual_nranks, root);
87+
pto::comm::ParallelGroup<Global> pg(tensors, nranks, root);
8288
pto::comm::TREDUCE(pg, outputG, accTile, recvTile,
8389
pto::comm::ReduceOp::Sum);
8490
}

0 commit comments

Comments
 (0)