Skip to content

Commit 7a8bafd

Browse files
author
echo_stone
committed
Add: distributed comm harness and treduce examples
- add backend-agnostic `comm_*` host APIs plus a2a3/a5 hardware and sim implementations so distributed runs share one communication abstraction - add Python bindings, distributed runner orchestration, and per-rank worker support to drive multi-rank examples through `run_example.py` - add distributed treduce examples for all three runtimes and fold in the PR #307 review fixes for CI-friendly rank counts, explicit device selection, and stronger validation Made-with: Cursor
1 parent e4348eb commit 7a8bafd

30 files changed

Lines changed: 3505 additions & 26 deletions

File tree

ci.sh

Lines changed: 134 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,17 @@ while [[ $# -gt 0 ]]; do
1616
shift 2
1717
;;
1818
-d|--device)
19-
DEVICE_RANGE="$2"
20-
shift 2
19+
shift
20+
DEVICE_VALUES=()
21+
while [[ $# -gt 0 && "$1" != -* ]]; do
22+
DEVICE_VALUES+=("$1")
23+
shift
24+
done
25+
if [[ ${#DEVICE_VALUES[@]} -eq 0 ]]; then
26+
echo "Missing value for --device"
27+
exit 1
28+
fi
29+
DEVICE_RANGE=$(IFS=,; echo "${DEVICE_VALUES[*]}")
2130
;;
2231
-r|--runtime)
2332
RUNTIME="$2"
@@ -78,15 +87,22 @@ if [[ -n "$RUNTIME" ]]; then
7887
fi
7988
fi
8089

81-
# Parse device range (e.g., "5-8" or "5")
82-
if [[ "$DEVICE_RANGE" == *-* ]]; then
83-
IFS='-' read -r DEV_START DEV_END <<< "$DEVICE_RANGE"
84-
DEVICES=()
85-
for ((i=DEV_START; i<=DEV_END; i++)); do
86-
DEVICES+=("$i")
87-
done
90+
# Parse device spec (e.g., "5-8", "5", or "0,1,3,5")
91+
DEVICES=()
92+
if [[ -z "$DEVICE_RANGE" ]]; then
93+
DEVICES=("0")
8894
else
89-
DEVICES=("${DEVICE_RANGE:-0}")
95+
IFS=',' read -r -a DEVICE_ITEMS <<< "$DEVICE_RANGE"
96+
for item in "${DEVICE_ITEMS[@]}"; do
97+
if [[ "$item" == *-* ]]; then
98+
IFS='-' read -r DEV_START DEV_END <<< "$item"
99+
for ((i=DEV_START; i<=DEV_END; i++)); do
100+
DEVICES+=("$i")
101+
done
102+
else
103+
DEVICES+=("$item")
104+
fi
105+
done
90106
fi
91107
NUM_DEVICES=${#DEVICES[@]}
92108

@@ -199,13 +215,48 @@ pin_pto_isa_on_failure() {
199215
return 0 # Pinned, caller should retry
200216
}
201217

218+
get_task_device_count() {
219+
local kernel_config="$1"
220+
python - "$kernel_config" <<'PY'
221+
import importlib.util
222+
import sys
223+
224+
path = sys.argv[1]
225+
spec = importlib.util.spec_from_file_location("kernel_config", path)
226+
mod = importlib.util.module_from_spec(spec)
227+
spec.loader.exec_module(mod)
228+
dist = getattr(mod, "DISTRIBUTED_CONFIG", None)
229+
nranks = 1
230+
if isinstance(dist, dict):
231+
try:
232+
nranks = int(dist.get("nranks", 1))
233+
except (TypeError, ValueError):
234+
nranks = 1
235+
print(max(nranks, 1))
236+
PY
237+
}
238+
239+
format_device_spec() {
240+
local count="$1"
241+
if [[ "$count" -le 1 ]]; then
242+
echo "${DEVICES[0]}"
243+
return 0
244+
fi
245+
246+
local selected=("${DEVICES[@]:0:$count}")
247+
local joined
248+
joined=$(IFS=,; echo "${selected[*]}")
249+
echo "$joined"
250+
}
251+
202252
# ---- Discover all tasks ----
203253
EXAMPLES_DIR="examples"
204254
DEVICE_TESTS_DIR="tests/device_tests"
205255

206256
declare -a HW_TASK_NAMES=()
207257
declare -a HW_TASK_DIRS=()
208258
declare -a HW_TASK_PLATS=()
259+
declare -a HW_TASK_DEVICE_COUNTS=()
209260
declare -a SIM_TASK_NAMES=()
210261
declare -a SIM_TASK_DIRS=()
211262
declare -a SIM_TASK_PLATS=()
@@ -245,18 +296,22 @@ while IFS= read -r -d '' example_dir; do
245296
SIM_TASK_DIRS+=("${example_dir}")
246297
SIM_TASK_PLATS+=("${PLATFORM}")
247298
else
299+
required_devices="$(get_task_device_count "$kernel_config")"
248300
HW_TASK_NAMES+=("example:${example_name}")
249301
HW_TASK_DIRS+=("${example_dir}")
250302
HW_TASK_PLATS+=("${PLATFORM}")
303+
HW_TASK_DEVICE_COUNTS+=("${required_devices}")
251304
fi
252305
elif [[ "$OS" == "Darwin" ]]; then
253306
SIM_TASK_NAMES+=("example:${example_name}")
254307
SIM_TASK_DIRS+=("${example_dir}")
255308
SIM_TASK_PLATS+=("${example_arch}sim")
256309
else
310+
required_devices="$(get_task_device_count "$kernel_config")"
257311
HW_TASK_NAMES+=("example:${example_name}")
258312
HW_TASK_DIRS+=("${example_dir}")
259313
HW_TASK_PLATS+=("${example_arch}")
314+
HW_TASK_DEVICE_COUNTS+=("${required_devices}")
260315
SIM_TASK_NAMES+=("example:${example_name}")
261316
SIM_TASK_DIRS+=("${example_dir}")
262317
SIM_TASK_PLATS+=("${example_arch}sim")
@@ -299,6 +354,7 @@ if [[ -d "$DEVICE_TESTS_DIR" ]]; then
299354
HW_TASK_NAMES+=("device_test:${test_name}")
300355
HW_TASK_DIRS+=("${test_dir}")
301356
HW_TASK_PLATS+=("${PLATFORM:-${test_arch}}")
357+
HW_TASK_DEVICE_COUNTS+=("$(get_task_device_count "$kernel_config")")
302358
done < <(find "$DEVICE_TESTS_DIR" -mindepth 1 -type d -print0 | sort -z)
303359
else
304360
echo "Skipping device tests (hardware platforms only)"
@@ -314,7 +370,7 @@ MAX_RETRIES=3
314370
# Log naming: ${safe_name}_${platform}_attempt${attempt}.log
315371
# Result format: name|platform|PASS/FAIL|device:X|attempt:N|Xs
316372
run_task() {
317-
local name="$1" dir="$2" platform="$3" attempt="$4" device_id="$5" print_log_on_fail="${6:-true}"
373+
local name="$1" dir="$2" platform="$3" attempt="$4" device_spec="$5" print_log_on_fail="${6:-true}" required_devices="${7:-1}"
318374
local safe_name="${name//[:\/]/_}"
319375
local task_log="${LOG_DIR}/${safe_name}_${platform}_attempt${attempt}.log"
320376
local start_time=$SECONDS
@@ -323,10 +379,16 @@ run_task() {
323379
cmd=(python examples/scripts/run_example.py
324380
-k "${dir}/kernels" -g "${dir}/golden.py"
325381
-p "$platform" --clone-protocol "$CLONE_PROTOCOL" "${commit_flag[@]}")
326-
[[ -n "$device_id" ]] && cmd+=(-d "$device_id")
382+
if [[ -n "$device_spec" ]]; then
383+
if [[ "$required_devices" -gt 1 ]]; then
384+
cmd+=(--devices "$device_spec" --nranks "$required_devices")
385+
else
386+
cmd+=(-d "$device_spec")
387+
fi
388+
fi
327389

328390
# Progress to stdout (not captured in log)
329-
echo "[${platform}${device_id:+:dev${device_id}}] Running: $name (attempt $attempt)"
391+
echo "[${platform}${device_spec:+:dev${device_spec}}] Running: $name (attempt $attempt)"
330392

331393
# Command output to log file only
332394
"${cmd[@]}" > "$task_log" 2>&1
@@ -336,21 +398,46 @@ run_task() {
336398
local status
337399
if [[ $rc -eq 0 ]]; then
338400
status="PASS"
339-
echo "[${platform}${device_id:+:dev${device_id}}] PASS: $name (${elapsed}s)"
401+
echo "[${platform}${device_spec:+:dev${device_spec}}] PASS: $name (${elapsed}s)"
340402
else
341403
status="FAIL"
342-
echo "[${platform}${device_id:+:dev${device_id}}] FAIL: $name (${elapsed}s)"
404+
echo "[${platform}${device_spec:+:dev${device_spec}}] FAIL: $name (${elapsed}s)"
343405
if [[ "$print_log_on_fail" == "true" ]]; then
344406
echo "--- LOG: $name (attempt $attempt) ---"
345407
cat "$task_log"
346408
echo "--- END ---"
347409
fi
348410
fi
349-
echo "${name}|${platform}|${status}|device:${device_id:-sim}|attempt:${attempt}|${elapsed}s" \
411+
echo "${name}|${platform}|${status}|device:${device_spec:-sim}|attempt:${attempt}|${elapsed}s" \
350412
>> "$RESULTS_FILE"
351413
return $rc
352414
}
353415

416+
run_hw_multidevice_tasks() {
417+
local attempt="$1"; shift
418+
local indices=("$@")
419+
HW_MULTI_FAILURES=()
420+
421+
for idx in "${indices[@]}"; do
422+
local required_devices="${HW_TASK_DEVICE_COUNTS[$idx]}"
423+
local platform="${HW_TASK_PLATS[$idx]}"
424+
local name="${HW_TASK_NAMES[$idx]}"
425+
426+
if [[ "$required_devices" -gt "$NUM_DEVICES" ]]; then
427+
echo "[${platform}] FAIL: $name requires ${required_devices} devices, only ${NUM_DEVICES} available"
428+
echo "${name}|${platform}|FAIL|device:insufficient|attempt:${attempt}|0s" >> "$RESULTS_FILE"
429+
HW_MULTI_FAILURES+=("$idx")
430+
continue
431+
fi
432+
433+
local device_spec
434+
device_spec="$(format_device_spec "$required_devices")"
435+
if ! run_task "$name" "${HW_TASK_DIRS[$idx]}" "$platform" "$attempt" "$device_spec" "true" "$required_devices"; then
436+
HW_MULTI_FAILURES+=("$idx")
437+
fi
438+
done
439+
}
440+
354441
# ---- SIM executor ----
355442
# run_sim_tasks <attempt> <idx1> <idx2> ...
356443
# Sets SIM_FAILURES to array of failed indices after return.
@@ -429,7 +516,7 @@ run_hw_tasks() {
429516

430517
IFS=':' read -r idx attempt <<< "$entry"
431518

432-
if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false"; then
519+
if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false" "1"; then
433520
echo "${idx}|PASS" >> "$hw_marker"
434521
else
435522
next=$((attempt + 1))
@@ -473,12 +560,36 @@ fi
473560

474561
# HW phase
475562
if [[ ${#HW_TASK_NAMES[@]} -gt 0 ]]; then
476-
ALL_HW=($(seq 0 $((${#HW_TASK_NAMES[@]} - 1))))
477-
echo "---- HW: ${#ALL_HW[@]} tasks on ${NUM_DEVICES} devices ----"
478-
run_hw_tasks "${ALL_HW[@]}"
479-
if [[ ${#HW_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
480-
echo "[CI] Retrying ${#HW_FAILURES[@]} HW failures with pinned PTO-ISA"
481-
run_hw_tasks "${HW_FAILURES[@]}"
563+
ALL_HW_SINGLE=()
564+
ALL_HW_MULTI=()
565+
for idx in $(seq 0 $((${#HW_TASK_NAMES[@]} - 1))); do
566+
if [[ "${HW_TASK_DEVICE_COUNTS[$idx]}" -gt 1 ]]; then
567+
ALL_HW_MULTI+=("$idx")
568+
else
569+
ALL_HW_SINGLE+=("$idx")
570+
fi
571+
done
572+
573+
echo "---- HW: ${#ALL_HW_SINGLE[@]} single-device tasks, ${#ALL_HW_MULTI[@]} multi-device tasks on ${NUM_DEVICES} devices ----"
574+
575+
HW_MULTI_FAILURES=()
576+
if [[ ${#ALL_HW_MULTI[@]} -gt 0 ]]; then
577+
run_hw_multidevice_tasks 0 "${ALL_HW_MULTI[@]}"
578+
if [[ ${#HW_MULTI_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
579+
echo "[CI] Retrying ${#HW_MULTI_FAILURES[@]} multi-device HW failures with pinned PTO-ISA"
580+
run_hw_multidevice_tasks 1 "${HW_MULTI_FAILURES[@]}"
581+
fi
582+
fi
583+
584+
HW_SINGLE_FAILURES=()
585+
if [[ ${#ALL_HW_SINGLE[@]} -gt 0 ]]; then
586+
run_hw_tasks "${ALL_HW_SINGLE[@]}"
587+
HW_SINGLE_FAILURES=("${HW_FAILURES[@]}")
588+
if [[ ${#HW_SINGLE_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
589+
echo "[CI] Retrying ${#HW_SINGLE_FAILURES[@]} HW failures with pinned PTO-ISA"
590+
run_hw_tasks "${HW_SINGLE_FAILURES[@]}"
591+
HW_SINGLE_FAILURES=("${HW_FAILURES[@]}")
592+
fi
482593
fi
483594
fi
484595

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
Golden script for distributed TREDUCE.
3+
4+
Each rank r contributes input[i] = i + r * 100 for i in [0, 256).
5+
Root rank reduces (Sum) all inputs.
6+
7+
Expected output on root:
8+
output[i] = sum_{r=0}^{nranks-1} (i + r * 100)
9+
= nranks * i + 100 * nranks * (nranks - 1) / 2
10+
"""
11+
12+
TREDUCE_COUNT = 256
13+
NRANKS = 4
14+
15+
__outputs__ = ["output"]
16+
17+
RTOL = 1e-5
18+
ATOL = 1e-5
19+
20+
21+
def generate_distributed_inputs(rank: int, nranks: int, root: int,
22+
comm_ctx=None) -> list:
23+
"""Each rank generates its own input; output is allocated on all ranks."""
24+
input_data = [float(i + rank * 100) for i in range(TREDUCE_COUNT)]
25+
output_data = [0.0] * TREDUCE_COUNT
26+
return [
27+
("input", input_data),
28+
("output", output_data),
29+
("nranks", nranks),
30+
("root", root),
31+
]
32+
33+
34+
def compute_golden(tensors: dict, params: dict) -> None:
35+
"""Compute expected output for the root rank."""
36+
nranks = params.get("nranks", NRANKS)
37+
output = tensors["output"]
38+
for i in range(TREDUCE_COUNT):
39+
output[i] = float(
40+
nranks * i + 100 * nranks * (nranks - 1) // 2)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/**
2+
* TREDUCE kernel for simpler's kernel_entry signature.
3+
*
4+
* Performs collective reduce (Sum) across multiple NPU ranks using PTO comm
5+
* instructions. Each rank's input data resides in an RDMA window;
6+
* the root rank gathers and sums all inputs into the output buffer.
7+
*
8+
* PTO communication instructions access remote data through GVA addresses
9+
* (windowsIn[]) via MTE2 DMA over HCCS; no bound stream is required.
10+
*
11+
* args layout (all uint64_t, cast as needed):
12+
* args[0] = __gm__ float* input (device addr in RDMA window)
13+
* args[1] = __gm__ float* output (device addr, regular allocation)
14+
* args[2] = int nranks
15+
* args[3] = int root
16+
* args[4] = __gm__ CommDeviceContext* ctx (device addr)
17+
*/
18+
19+
#include <cstdint>
20+
#include <pto/pto-inst.hpp>
21+
#include "pto/comm/comm_types.hpp"
22+
#include "pto/comm/pto_comm_inst.hpp"
23+
#include "common/comm_context.h"
24+
25+
#ifndef __gm__
26+
#define __gm__
27+
#endif
28+
29+
#ifndef __aicore__
30+
#define __aicore__ [aicore]
31+
#endif
32+
33+
static constexpr size_t TREDUCE_COUNT = 256;
34+
static constexpr int kMaxSupportedRanks = 16;
35+
36+
template <typename T>
37+
AICORE inline __gm__ T *CommRemotePtr(
38+
__gm__ CommDeviceContext *ctx, __gm__ T *localPtr, int pe)
39+
{
40+
uint64_t localBase = ctx->windowsIn[ctx->rankId];
41+
uint64_t offset = (uint64_t)localPtr - localBase;
42+
return (__gm__ T *)(ctx->windowsIn[pe] + offset);
43+
}
44+
45+
extern "C" __aicore__ __attribute__((always_inline))
46+
void kernel_entry(__gm__ int64_t* args) {
47+
__gm__ float* input = reinterpret_cast<__gm__ float*>(args[0]);
48+
__gm__ float* output = reinterpret_cast<__gm__ float*>(args[1]);
49+
int nranks = static_cast<int>(args[2]);
50+
int root = static_cast<int>(args[3]);
51+
__gm__ CommDeviceContext* commCtx =
52+
reinterpret_cast<__gm__ CommDeviceContext*>(args[4]);
53+
54+
using ShapeDyn = pto::Shape<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC,
55+
pto::DYNAMIC, pto::DYNAMIC>;
56+
using StrideDyn = pto::Stride<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC,
57+
pto::DYNAMIC, pto::DYNAMIC>;
58+
using Global = pto::GlobalTensor<float, ShapeDyn, StrideDyn,
59+
pto::Layout::ND>;
60+
using TileData = pto::Tile<pto::TileType::Vec, float, 1, TREDUCE_COUNT,
61+
pto::BLayout::RowMajor, -1, -1>;
62+
63+
int my_rank = static_cast<int>(commCtx->rankId);
64+
65+
ShapeDyn shape(1, 1, 1, 1, TREDUCE_COUNT);
66+
StrideDyn stride(TREDUCE_COUNT, TREDUCE_COUNT, TREDUCE_COUNT,
67+
TREDUCE_COUNT, 1);
68+
69+
TileData accTile(1, TREDUCE_COUNT);
70+
TileData recvTile(1, TREDUCE_COUNT);
71+
TASSIGN(accTile, 0x0);
72+
TASSIGN(recvTile, 0x10000);
73+
74+
if (nranks <= 0 || nranks > kMaxSupportedRanks || root < 0 || root >= nranks) {
75+
pipe_barrier(PIPE_ALL);
76+
return;
77+
}
78+
79+
if (my_rank == root) {
80+
Global outputG(output, shape, stride);
81+
Global tensors[kMaxSupportedRanks];
82+
for (int i = 0; i < nranks; ++i) {
83+
__gm__ float* remoteInput = CommRemotePtr(commCtx, input, i);
84+
tensors[i] = Global(remoteInput, shape, stride);
85+
}
86+
pto::comm::ParallelGroup<Global> pg(tensors, nranks, root);
87+
pto::comm::TREDUCE(pg, outputG, accTile, recvTile,
88+
pto::comm::ReduceOp::Sum);
89+
}
90+
91+
pipe_barrier(PIPE_ALL);
92+
}

0 commit comments

Comments
 (0)