Skip to content

Commit 8bf7b50

Browse files
committed
Add: distributed comm harness and treduce examples
- add backend-agnostic `comm_*` host APIs plus a2a3/a5 hardware and sim implementations so distributed runs share one communication abstraction - add Python bindings, distributed runner orchestration, and per-rank worker support to drive multi-rank examples through `run_example.py` - add distributed treduce examples for all three runtimes and fold in the PR hw-native-sys#307 review fixes for CI-friendly rank counts, explicit device selection, and stronger validation Made-with: Cursor
1 parent 915f7b5 commit 8bf7b50

30 files changed

Lines changed: 4129 additions & 24 deletions

File tree

ci.sh

Lines changed: 134 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,17 @@ while [[ $# -gt 0 ]]; do
2424
shift 2
2525
;;
2626
-d|--device)
27-
DEVICE_RANGE="$2"
28-
shift 2
27+
shift
28+
DEVICE_VALUES=()
29+
while [[ $# -gt 0 && "$1" != -* ]]; do
30+
DEVICE_VALUES+=("$1")
31+
shift
32+
done
33+
if [[ ${#DEVICE_VALUES[@]} -eq 0 ]]; then
34+
echo "Missing value for --device"
35+
exit 1
36+
fi
37+
DEVICE_RANGE=$(IFS=,; echo "${DEVICE_VALUES[*]}")
2938
;;
3039
-r|--runtime)
3140
RUNTIME="$2"
@@ -86,15 +95,22 @@ if [[ -n "$RUNTIME" ]]; then
8695
fi
8796
fi
8897

89-
# Parse device range (e.g., "5-8" or "5")
90-
if [[ "$DEVICE_RANGE" == *-* ]]; then
91-
IFS='-' read -r DEV_START DEV_END <<< "$DEVICE_RANGE"
92-
DEVICES=()
93-
for ((i=DEV_START; i<=DEV_END; i++)); do
94-
DEVICES+=("$i")
95-
done
98+
# Parse device spec (e.g., "5-8", "5", or "0,1,3,5")
99+
DEVICES=()
100+
if [[ -z "$DEVICE_RANGE" ]]; then
101+
DEVICES=("0")
96102
else
97-
DEVICES=("${DEVICE_RANGE:-0}")
103+
IFS=',' read -r -a DEVICE_ITEMS <<< "$DEVICE_RANGE"
104+
for item in "${DEVICE_ITEMS[@]}"; do
105+
if [[ "$item" == *-* ]]; then
106+
IFS='-' read -r DEV_START DEV_END <<< "$item"
107+
for ((i=DEV_START; i<=DEV_END; i++)); do
108+
DEVICES+=("$i")
109+
done
110+
else
111+
DEVICES+=("$item")
112+
fi
113+
done
98114
fi
99115
NUM_DEVICES=${#DEVICES[@]}
100116

@@ -201,13 +217,48 @@ pin_pto_isa_on_failure() {
201217
return 0 # Pinned, caller should retry
202218
}
203219

220+
get_task_device_count() {
221+
local kernel_config="$1"
222+
python - "$kernel_config" <<'PY'
223+
import importlib.util
224+
import sys
225+
226+
path = sys.argv[1]
227+
spec = importlib.util.spec_from_file_location("kernel_config", path)
228+
mod = importlib.util.module_from_spec(spec)
229+
spec.loader.exec_module(mod)
230+
dist = getattr(mod, "DISTRIBUTED_CONFIG", None)
231+
nranks = 1
232+
if isinstance(dist, dict):
233+
try:
234+
nranks = int(dist.get("nranks", 1))
235+
except (TypeError, ValueError):
236+
nranks = 1
237+
print(max(nranks, 1))
238+
PY
239+
}
240+
241+
format_device_spec() {
242+
local count="$1"
243+
if [[ "$count" -le 1 ]]; then
244+
echo "${DEVICES[0]}"
245+
return 0
246+
fi
247+
248+
local selected=("${DEVICES[@]:0:$count}")
249+
local joined
250+
joined=$(IFS=,; echo "${selected[*]}")
251+
echo "$joined"
252+
}
253+
204254
# ---- Discover all tasks ----
205255
EXAMPLES_DIR="examples"
206256
DEVICE_TESTS_DIR="tests/st"
207257

208258
declare -a HW_TASK_NAMES=()
209259
declare -a HW_TASK_DIRS=()
210260
declare -a HW_TASK_PLATS=()
261+
declare -a HW_TASK_DEVICE_COUNTS=()
211262
declare -a SIM_TASK_NAMES=()
212263
declare -a SIM_TASK_DIRS=()
213264
declare -a SIM_TASK_PLATS=()
@@ -247,18 +298,22 @@ while IFS= read -r -d '' example_dir; do
247298
SIM_TASK_DIRS+=("${example_dir}")
248299
SIM_TASK_PLATS+=("${PLATFORM}")
249300
else
301+
required_devices="$(get_task_device_count "$kernel_config")"
250302
HW_TASK_NAMES+=("example:${example_name}")
251303
HW_TASK_DIRS+=("${example_dir}")
252304
HW_TASK_PLATS+=("${PLATFORM}")
305+
HW_TASK_DEVICE_COUNTS+=("${required_devices}")
253306
fi
254307
elif [[ "$OS" == "Darwin" ]]; then
255308
SIM_TASK_NAMES+=("example:${example_name}")
256309
SIM_TASK_DIRS+=("${example_dir}")
257310
SIM_TASK_PLATS+=("${example_arch}sim")
258311
else
312+
required_devices="$(get_task_device_count "$kernel_config")"
259313
HW_TASK_NAMES+=("example:${example_name}")
260314
HW_TASK_DIRS+=("${example_dir}")
261315
HW_TASK_PLATS+=("${example_arch}")
316+
HW_TASK_DEVICE_COUNTS+=("${required_devices}")
262317
SIM_TASK_NAMES+=("example:${example_name}")
263318
SIM_TASK_DIRS+=("${example_dir}")
264319
SIM_TASK_PLATS+=("${example_arch}sim")
@@ -301,6 +356,7 @@ if [[ -d "$DEVICE_TESTS_DIR" ]]; then
301356
HW_TASK_NAMES+=("device_test:${test_name}")
302357
HW_TASK_DIRS+=("${test_dir}")
303358
HW_TASK_PLATS+=("${PLATFORM:-${test_arch}}")
359+
HW_TASK_DEVICE_COUNTS+=("$(get_task_device_count "$kernel_config")")
304360
done < <(find "$DEVICE_TESTS_DIR" -mindepth 1 -type d -print0 | sort -z)
305361
else
306362
echo "Skipping device tests (hardware platforms only)"
@@ -317,7 +373,7 @@ MAX_RETRIES=3
317373
# Log naming: ${safe_name}_${platform}_attempt${attempt}.log
318374
# Result format: name|platform|PASS/FAIL|device:X|attempt:N|Xs
319375
run_task() {
320-
local name="$1" dir="$2" platform="$3" attempt="$4" device_id="$5" print_log_on_fail="${6:-true}"
376+
local name="$1" dir="$2" platform="$3" attempt="$4" device_spec="$5" print_log_on_fail="${6:-true}" required_devices="${7:-1}"
321377
local safe_name="${name//[:\/]/_}"
322378
local task_log="${LOG_DIR}/${safe_name}_${platform}_attempt${attempt}.log"
323379
local start_time=$SECONDS
@@ -326,10 +382,16 @@ run_task() {
326382
cmd=(env PYTHONDONTWRITEBYTECODE=1 python examples/scripts/run_example.py
327383
-k "${dir}/kernels" -g "${dir}/golden.py"
328384
-p "$platform" --clone-protocol "$CLONE_PROTOCOL" "${commit_flag[@]}")
329-
[[ -n "$device_id" ]] && cmd+=(-d "$device_id")
385+
if [[ -n "$device_spec" ]]; then
386+
if [[ "$required_devices" -gt 1 ]]; then
387+
cmd+=(--devices "$device_spec" --nranks "$required_devices")
388+
else
389+
cmd+=(-d "$device_spec")
390+
fi
391+
fi
330392

331393
# Progress to stdout (not captured in log)
332-
echo "[${platform}${device_id:+:dev${device_id}}] Running: $name (attempt $attempt)"
394+
echo "[${platform}${device_spec:+:dev${device_spec}}] Running: $name (attempt $attempt)"
333395

334396
# Command output to log file only
335397
if [[ "$platform" == "a5" && -n "$device_id" ]]; then
@@ -352,21 +414,46 @@ run_task() {
352414
local status
353415
if [[ $rc -eq 0 ]]; then
354416
status="PASS"
355-
echo "[${platform}${device_id:+:dev${device_id}}] PASS: $name (${elapsed}s)"
417+
echo "[${platform}${device_spec:+:dev${device_spec}}] PASS: $name (${elapsed}s)"
356418
else
357419
status="FAIL"
358-
echo "[${platform}${device_id:+:dev${device_id}}] FAIL: $name (${elapsed}s)"
420+
echo "[${platform}${device_spec:+:dev${device_spec}}] FAIL: $name (${elapsed}s)"
359421
if [[ "$print_log_on_fail" == "true" ]]; then
360422
echo "--- LOG: $name (attempt $attempt) ---"
361423
cat "$task_log"
362424
echo "--- END ---"
363425
fi
364426
fi
365-
echo "${name}|${platform}|${status}|device:${device_id:-sim}|attempt:${attempt}|${elapsed}s" \
427+
echo "${name}|${platform}|${status}|device:${device_spec:-sim}|attempt:${attempt}|${elapsed}s" \
366428
>> "$RESULTS_FILE"
367429
return $rc
368430
}
369431

432+
run_hw_multidevice_tasks() {
433+
local attempt="$1"; shift
434+
local indices=("$@")
435+
HW_MULTI_FAILURES=()
436+
437+
for idx in "${indices[@]}"; do
438+
local required_devices="${HW_TASK_DEVICE_COUNTS[$idx]}"
439+
local platform="${HW_TASK_PLATS[$idx]}"
440+
local name="${HW_TASK_NAMES[$idx]}"
441+
442+
if [[ "$required_devices" -gt "$NUM_DEVICES" ]]; then
443+
echo "[${platform}] FAIL: $name requires ${required_devices} devices, only ${NUM_DEVICES} available"
444+
echo "${name}|${platform}|FAIL|device:insufficient|attempt:${attempt}|0s" >> "$RESULTS_FILE"
445+
HW_MULTI_FAILURES+=("$idx")
446+
continue
447+
fi
448+
449+
local device_spec
450+
device_spec="$(format_device_spec "$required_devices")"
451+
if ! run_task "$name" "${HW_TASK_DIRS[$idx]}" "$platform" "$attempt" "$device_spec" "true" "$required_devices"; then
452+
HW_MULTI_FAILURES+=("$idx")
453+
fi
454+
done
455+
}
456+
370457
# ---- SIM executor ----
371458
# run_sim_tasks <attempt> <idx1> <idx2> ...
372459
# Sets SIM_FAILURES to array of failed indices after return.
@@ -445,7 +532,7 @@ run_hw_tasks() {
445532

446533
IFS=':' read -r idx attempt <<< "$entry"
447534

448-
if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false"; then
535+
if run_task "${HW_TASK_NAMES[$idx]}" "${HW_TASK_DIRS[$idx]}" "${HW_TASK_PLATS[$idx]}" "$attempt" "$device_id" "false" "1"; then
449536
echo "${idx}|PASS" >> "$hw_marker"
450537
else
451538
next=$((attempt + 1))
@@ -495,12 +582,36 @@ fi
495582

496583
# HW phase
497584
if [[ ${#HW_TASK_NAMES[@]} -gt 0 ]]; then
498-
ALL_HW=($(seq 0 $((${#HW_TASK_NAMES[@]} - 1))))
499-
echo "---- HW: ${#ALL_HW[@]} tasks on ${NUM_DEVICES} devices ----"
500-
run_hw_tasks "${ALL_HW[@]}"
501-
if [[ ${#HW_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
502-
echo "[CI] Retrying ${#HW_FAILURES[@]} HW failures with pinned PTO-ISA"
503-
run_hw_tasks "${HW_FAILURES[@]}"
585+
ALL_HW_SINGLE=()
586+
ALL_HW_MULTI=()
587+
for idx in $(seq 0 $((${#HW_TASK_NAMES[@]} - 1))); do
588+
if [[ "${HW_TASK_DEVICE_COUNTS[$idx]}" -gt 1 ]]; then
589+
ALL_HW_MULTI+=("$idx")
590+
else
591+
ALL_HW_SINGLE+=("$idx")
592+
fi
593+
done
594+
595+
echo "---- HW: ${#ALL_HW_SINGLE[@]} single-device tasks, ${#ALL_HW_MULTI[@]} multi-device tasks on ${NUM_DEVICES} devices ----"
596+
597+
HW_MULTI_FAILURES=()
598+
if [[ ${#ALL_HW_MULTI[@]} -gt 0 ]]; then
599+
run_hw_multidevice_tasks 0 "${ALL_HW_MULTI[@]}"
600+
if [[ ${#HW_MULTI_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
601+
echo "[CI] Retrying ${#HW_MULTI_FAILURES[@]} multi-device HW failures with pinned PTO-ISA"
602+
run_hw_multidevice_tasks 1 "${HW_MULTI_FAILURES[@]}"
603+
fi
604+
fi
605+
606+
HW_SINGLE_FAILURES=()
607+
if [[ ${#ALL_HW_SINGLE[@]} -gt 0 ]]; then
608+
run_hw_tasks "${ALL_HW_SINGLE[@]}"
609+
HW_SINGLE_FAILURES=("${HW_FAILURES[@]}")
610+
if [[ ${#HW_SINGLE_FAILURES[@]} -gt 0 ]] && pin_pto_isa_on_failure; then
611+
echo "[CI] Retrying ${#HW_SINGLE_FAILURES[@]} HW failures with pinned PTO-ISA"
612+
run_hw_tasks "${HW_SINGLE_FAILURES[@]}"
613+
HW_SINGLE_FAILURES=("${HW_FAILURES[@]}")
614+
fi
504615
fi
505616
fi
506617

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
Golden script for distributed TREDUCE.
3+
4+
Each rank r contributes input[i] = i + r * 100 for i in [0, 256).
5+
Root rank reduces (Sum) all inputs.
6+
7+
Expected output on root:
8+
output[i] = sum_{r=0}^{nranks-1} (i + r * 100)
9+
= nranks * i + 100 * nranks * (nranks - 1) / 2
10+
"""
11+
12+
TREDUCE_COUNT = 256
13+
NRANKS = 4
14+
15+
__outputs__ = ["output"]
16+
17+
RTOL = 1e-5
18+
ATOL = 1e-5
19+
20+
21+
def generate_distributed_inputs(rank: int, nranks: int, root: int,
22+
comm_ctx=None) -> list:
23+
"""Each rank generates its own input; output is allocated on all ranks."""
24+
input_data = [float(i + rank * 100) for i in range(TREDUCE_COUNT)]
25+
output_data = [0.0] * TREDUCE_COUNT
26+
return [
27+
("input", input_data),
28+
("output", output_data),
29+
("nranks", nranks),
30+
("root", root),
31+
]
32+
33+
34+
def compute_golden(tensors: dict, params: dict) -> None:
35+
"""Compute expected output for the root rank."""
36+
nranks = params.get("nranks", NRANKS)
37+
output = tensors["output"]
38+
for i in range(TREDUCE_COUNT):
39+
output[i] = float(
40+
nranks * i + 100 * nranks * (nranks - 1) // 2)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/**
2+
* TREDUCE kernel for simpler's kernel_entry signature.
3+
*
4+
* Performs collective reduce (Sum) across multiple NPU ranks using PTO comm
5+
* instructions. Each rank's input data resides in an RDMA window;
6+
* the root rank gathers and sums all inputs into the output buffer.
7+
*
8+
* PTO communication instructions access remote data through GVA addresses
9+
* (windowsIn[]) via MTE2 DMA over HCCS; no bound stream is required.
10+
*
11+
* args layout (all uint64_t, cast as needed):
12+
* args[0] = __gm__ float* input (device addr in RDMA window)
13+
* args[1] = __gm__ float* output (device addr, regular allocation)
14+
* args[2] = int nranks
15+
* args[3] = int root
16+
* args[4] = __gm__ CommDeviceContext* ctx (device addr)
17+
*/
18+
19+
#include <cstdint>
20+
#include <pto/pto-inst.hpp>
21+
#include "pto/comm/comm_types.hpp"
22+
#include "pto/comm/pto_comm_inst.hpp"
23+
#include "common/comm_context.h"
24+
25+
#ifndef __gm__
26+
#define __gm__
27+
#endif
28+
29+
#ifndef __aicore__
30+
#define __aicore__ [aicore]
31+
#endif
32+
33+
static constexpr size_t TREDUCE_COUNT = 256;
34+
static constexpr int kMaxSupportedRanks = 16;
35+
36+
template <typename T>
37+
AICORE inline __gm__ T *CommRemotePtr(
38+
__gm__ CommDeviceContext *ctx, __gm__ T *localPtr, int pe)
39+
{
40+
uint64_t localBase = ctx->windowsIn[ctx->rankId];
41+
uint64_t offset = (uint64_t)localPtr - localBase;
42+
return (__gm__ T *)(ctx->windowsIn[pe] + offset);
43+
}
44+
45+
extern "C" __aicore__ __attribute__((always_inline))
46+
void kernel_entry(__gm__ int64_t* args) {
47+
__gm__ float* input = reinterpret_cast<__gm__ float*>(args[0]);
48+
__gm__ float* output = reinterpret_cast<__gm__ float*>(args[1]);
49+
int nranks = static_cast<int>(args[2]);
50+
int root = static_cast<int>(args[3]);
51+
__gm__ CommDeviceContext* commCtx =
52+
reinterpret_cast<__gm__ CommDeviceContext*>(args[4]);
53+
54+
using ShapeDyn = pto::Shape<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC,
55+
pto::DYNAMIC, pto::DYNAMIC>;
56+
using StrideDyn = pto::Stride<pto::DYNAMIC, pto::DYNAMIC, pto::DYNAMIC,
57+
pto::DYNAMIC, pto::DYNAMIC>;
58+
using Global = pto::GlobalTensor<float, ShapeDyn, StrideDyn,
59+
pto::Layout::ND>;
60+
using TileData = pto::Tile<pto::TileType::Vec, float, 1, TREDUCE_COUNT,
61+
pto::BLayout::RowMajor, -1, -1>;
62+
63+
int my_rank = static_cast<int>(commCtx->rankId);
64+
65+
ShapeDyn shape(1, 1, 1, 1, TREDUCE_COUNT);
66+
StrideDyn stride(TREDUCE_COUNT, TREDUCE_COUNT, TREDUCE_COUNT,
67+
TREDUCE_COUNT, 1);
68+
69+
TileData accTile(1, TREDUCE_COUNT);
70+
TileData recvTile(1, TREDUCE_COUNT);
71+
TASSIGN(accTile, 0x0);
72+
TASSIGN(recvTile, 0x10000);
73+
74+
if (nranks <= 0 || nranks > kMaxSupportedRanks || root < 0 || root >= nranks) {
75+
pipe_barrier(PIPE_ALL);
76+
return;
77+
}
78+
79+
if (my_rank == root) {
80+
Global outputG(output, shape, stride);
81+
Global tensors[kMaxSupportedRanks];
82+
for (int i = 0; i < nranks; ++i) {
83+
__gm__ float* remoteInput = CommRemotePtr(commCtx, input, i);
84+
tensors[i] = Global(remoteInput, shape, stride);
85+
}
86+
pto::comm::ParallelGroup<Global> pg(tensors, nranks, root);
87+
pto::comm::TREDUCE(pg, outputG, accTile, recvTile,
88+
pto::comm::ReduceOp::Sum);
89+
}
90+
91+
pipe_barrier(PIPE_ALL);
92+
}

0 commit comments

Comments
 (0)