From 7094dd0be9165af2b9d8b1ede45a88225fa30ea0 Mon Sep 17 00:00:00 2001 From: PKUZHOU <751722308@qq.com> Date: Sun, 12 Apr 2026 16:33:00 +0800 Subject: [PATCH 1/2] add distributed (sync+async) execution --- .../async_completion_demo/golden.py | 73 +++ .../kernels/aiv/kernel_consumer.cpp | 66 ++ .../kernels/aiv/kernel_producer.cpp | 90 +++ .../kernels/aiv/kernel_producer_async.cpp | 89 +++ .../kernels/kernel_config.py | 56 ++ .../async_demo_orchestration.cpp | 142 +++++ .../async_notify_demo/golden.py | 70 +++ .../kernels/aiv/kernel_consumer.cpp | 72 +++ .../kernels/aiv/kernel_notify_wait.cpp | 46 ++ .../kernels/aiv/kernel_producer_notify.cpp | 103 ++++ .../kernels/kernel_config.py | 50 ++ .../async_notify_orchestration.cpp | 88 +++ examples/scripts/README.md | 117 +++- examples/scripts/code_runner.py | 499 ++++++++++++++- examples/scripts/run_example.py | 65 +- python/bindings/CMakeLists.txt | 1 + python/bindings/dist_worker_bind.h | 122 +++- python/bindings/task_interface.cpp | 35 +- python/task_interface.py | 268 +++++++++ python/worker.py | 566 +++++++++++++++--- .../include/aicore/pto_async_backend_kernel.h | 90 +++ .../platform/include/common/comm_context.h | 30 + src/a2a3/platform/include/host/comm.h | 102 ++++ src/a2a3/platform/onboard/host/CMakeLists.txt | 34 ++ src/a2a3/platform/onboard/host/comm_hccl.cpp | 525 ++++++++++++++++ .../platform/onboard/host/device_runner.cpp | 10 + .../onboard/host/pto_runtime_c_api.cpp | 101 ++++ src/a2a3/platform/sim/host/CMakeLists.txt | 2 + src/a2a3/platform/sim/host/comm_sim.cpp | 205 +++++++ .../platform/sim/host/pto_runtime_c_api.cpp | 53 ++ .../aicpu/aicpu_executor.cpp | 93 ++- .../host/runtime_maker.cpp | 53 ++ .../orchestration/pto_orchestration_api.h | 126 ++++ .../runtime/pto_async_kernel_api.h | 22 + .../runtime/pto_async_wait.h | 359 +++++++++++ .../runtime/pto_cq_kernel_api.h | 136 +++++ .../runtime/pto_cq_types.h | 47 ++ .../runtime/pto_notify_kernel_api.h | 41 ++ .../runtime/pto_runtime2.cpp | 27 + 
.../runtime/pto_runtime2.h | 10 + .../runtime/pto_runtime2_types.h | 9 + .../runtime/pto_scheduler.h | 4 +- .../runtime/pto_shared_memory.h | 9 + .../runtime/pto_sq_kernel_api.h | 201 +++++++ .../runtime/pto_types.h | 64 +- .../runtime/runtime.cpp | 12 + .../runtime/runtime.h | 8 + .../include/aicore/pto_async_backend_kernel.h | 92 +++ .../onboard/host/pto_runtime_c_api.cpp | 34 ++ .../platform/sim/host/pto_runtime_c_api.cpp | 34 ++ .../host/runtime_maker.cpp | 8 + .../dist_chip_bootstrap_channel.cpp | 131 ++++ .../distributed/dist_chip_bootstrap_channel.h | 69 +++ src/common/distributed/dist_orchestrator.cpp | 20 +- src/common/distributed/dist_orchestrator.h | 12 +- src/common/distributed/dist_tensormap.cpp | 10 +- src/common/distributed/dist_tensormap.h | 17 +- src/common/distributed/dist_types.cpp | 10 +- src/common/distributed/dist_types.h | 16 +- src/common/task_interface/tensor_arg.h | 12 + src/common/worker/chip_worker.cpp | 134 +++++ src/common/worker/chip_worker.h | 32 + src/common/worker/pto_runtime_c_api.h | 15 +- .../test_chip_bootstrap_channel.py | 93 +++ 64 files changed, 5596 insertions(+), 134 deletions(-) create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_consumer.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer_async.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/kernel_config.py create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/orchestration/async_demo_orchestration.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/golden.py create mode 100644 
examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_consumer.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_notify_wait.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_producer_notify.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/kernel_config.py create mode 100644 examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/orchestration/async_notify_orchestration.cpp create mode 100644 src/a2a3/platform/include/aicore/pto_async_backend_kernel.h create mode 100644 src/a2a3/platform/include/common/comm_context.h create mode 100644 src/a2a3/platform/include/host/comm.h create mode 100644 src/a2a3/platform/onboard/host/comm_hccl.cpp create mode 100644 src/a2a3/platform/sim/host/comm_sim.cpp create mode 100644 src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_kernel_api.h create mode 100644 src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_wait.h create mode 100644 src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_kernel_api.h create mode 100644 src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_types.h create mode 100644 src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_notify_kernel_api.h create mode 100644 src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_sq_kernel_api.h create mode 100644 src/a5/platform/include/aicore/pto_async_backend_kernel.h create mode 100644 src/common/distributed/dist_chip_bootstrap_channel.cpp create mode 100644 src/common/distributed/dist_chip_bootstrap_channel.h create mode 100644 tests/ut/py/test_dist_worker/test_chip_bootstrap_channel.py diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py new file mode 100644 index 000000000..4d6dad803 --- /dev/null +++ 
b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py @@ -0,0 +1,73 @@ +""" +Golden script for async_completion_demo. + +Single-card / sim path keeps the original producer-consumer pipeline: + producer: out[i] = in[i] * 2.0 + consumer: result[i] = out[i] + 1.0 + +Hardware 2-card path validates `out` and `result`: + each rank TGET_ASYNCs the peer rank's `in` into local `out`, then the + normal consumer computes `result = out + 1`. +""" + +import ctypes +import torch + +__outputs__ = ["result", "out"] + +RTOL = 1e-5 +ATOL = 1e-5 + + +def generate_inputs(params: dict) -> list: + SIZE = 128 * 128 + + inp = torch.full((SIZE,), 3.0, dtype=torch.float32) + out = torch.zeros(SIZE, dtype=torch.float32) + result = torch.zeros(SIZE, dtype=torch.float32) + event_handle_output = torch.zeros(4, dtype=torch.int32) + + return [ + ("in", inp), + ("out", out), + ("result", result), + ("event_handle_output", event_handle_output), + ("size_in", ctypes.c_int64(inp.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ("size_result", ctypes.c_int64(result.nbytes)), + ("size_event_handle_output", ctypes.c_int64(event_handle_output.nbytes)), + ("SIZE", ctypes.c_int64(SIZE)), + ] + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + del comm_ctx + del nranks + del root + + size = 128 * 128 + inp = [float(i % 251) / 10.0 for i in range(size)] + out = [0.0] * size + result = [0.0] * size + + return [ + ("in", inp), + ("out", out), + ("result", result), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + if "in" in tensors: + inp = torch.as_tensor(tensors["in"]) + tensors["result"][:] = inp * 2.0 + 1.0 + tensors["out"][:] = inp * 2.0 + return + + out = tensors["out"] + result = tensors["result"] + for i in range(len(out)): + value = float(i % 251) / 10.0 + out[i] = value + result[i] = value + 1.0 diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_consumer.cpp 
b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_consumer.cpp new file mode 100644 index 000000000..f206bf3a3 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_consumer.cpp @@ -0,0 +1,66 @@ +/** + * Async Completion Demo - Consumer Kernel (func_id=1) + * + * Implements: result[i] = src[i] + 1.0 + * + * This kernel executes as a normal run-to-completion task. It depends on the + * producer's output tensor; the scheduler only dispatches it after the + * producer's deferred completion (event flag) is resolved. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(src) — input tensor struct pointer (producer's output) + * args[1] = &Tensor(result) — output tensor struct pointer + */ + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* src_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* result_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + + __gm__ float* src = reinterpret_cast<__gm__ float*>(src_tensor->buffer.addr) + src_tensor->start_offset; + __gm__ float* result = reinterpret_cast<__gm__ float*>(result_tensor->buffer.addr) + result_tensor->start_offset; + + constexpr int kTRows_ = 128; + constexpr int kTCols_ = 128; + constexpr int vRows = 128; + constexpr int vCols = 128; + + using DynShapeDim5 = Shape<1, 1, 1, vRows, vCols>; + using DynStridDim5 = Stride<1, 1, 1, kTCols_, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData srcTile(vRows, vCols); + TileData dstTile(vRows, vCols); + TASSIGN(srcTile, 0x0); + TASSIGN(dstTile, 0x10000); + + GlobalData srcGlobal(src); + GlobalData dstGlobal(result); + + TLOAD(srcTile, srcGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + 
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TADDS(dstTile, srcTile, 1.0f); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(dstGlobal, dstTile); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer.cpp new file mode 100644 index 000000000..d97f3133f --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer.cpp @@ -0,0 +1,90 @@ +/** + * Async Completion Demo - Simulated Producer Kernel (func_id=0) + * + * Implements: out[i] = in[i] * 2.0 + * + * After storing the output, writes 1 to a GM completion flag, then registers + * the completion via the CQ. The scheduler reads the CQ after this kernel + * returns and polls the flag address. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(in) — input tensor struct pointer + * args[1] = &Tensor(out) — output tensor struct pointer + * args[2] = event_flag_gm_addr — GM flag addr (pre-allocated by golden.py) + * args[3] = cq_addr — completion queue (appended by submit_deferred) + */ + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include "pto_async_kernel_api.h" + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* in_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + uint64_t event_flag_addr = static_cast(args[2]); + uint64_t cq_addr = static_cast(args[3]); + + __gm__ float* in_data = reinterpret_cast<__gm__ float*>(in_tensor->buffer.addr) + in_tensor->start_offset; + __gm__ float* out_data = reinterpret_cast<__gm__ 
float*>(out_tensor->buffer.addr) + out_tensor->start_offset; + + constexpr int kTRows_ = 128; + constexpr int kTCols_ = 128; + constexpr int vRows = 128; + constexpr int vCols = 128; + + using DynShapeDim5 = Shape<1, 1, 1, vRows, vCols>; + using DynStridDim5 = Stride<1, 1, 1, kTCols_, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData inTile(vRows, vCols); + TileData outTile(vRows, vCols); + TASSIGN(inTile, 0x0); + TASSIGN(outTile, 0x10000); + + GlobalData inGlobal(in_data); + GlobalData outGlobal(out_data); + + TLOAD(inTile, inGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // out = in + in = in * 2.0 + TADD(outTile, inTile, inTile); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(outGlobal, outTile); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + + // Signal async completion: write non-zero flag to GM + volatile __gm__ int32_t* flag = reinterpret_cast( + static_cast(event_flag_addr)); +#if defined(SINGLE_CACHE_LINE) && defined(DSB_DDR) + dcci((__gm__ int32_t*)flag, SINGLE_CACHE_LINE); + *flag = 1; + dcci((__gm__ int32_t*)flag, SINGLE_CACHE_LINE); + dsb(DSB_DDR); +#else + *flag = 1; +#endif + + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + pto2_cq_reset(cq); + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, event_flag_addr); + pto2_cq_flush(cq); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer_async.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer_async.cpp new file mode 100644 index 000000000..2c7110597 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer_async.cpp @@ -0,0 +1,89 @@ +/** + * Async Completion Demo - Hardware 2P SDMA TGET Producer Kernel (func_id=2) + * + * Implements: + * 1. 
Read peer rank's input buffer via TGET_ASYNC into local out + * 2. Register the async event in the CQ + * 3. Return immediately so the runtime completes the task asynchronously + * + * This kernel is only compiled for real hardware (a2a3), not for simulation. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(in) — input tensor struct pointer + * args[1] = &Tensor(out) — output tensor struct pointer + * args[2] = CommDeviceContext* — distributed communication context + * args[3] = sdma_context_addr — SDMA async context + * args[4] = cq_addr — completion queue (appended by submit_deferred) + */ + +#include +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include +#include "pto/comm/pto_comm_inst.hpp" +#include "pto/npu/comm/async/sdma/sdma_types.hpp" +#include "pto/common/pto_tile.hpp" + +#include "common/comm_context.h" +#include "tensor.h" + +using namespace pto; + +#include "pto_async_kernel_api.h" + +template +AICORE inline __gm__ T* CommRemotePtr(__gm__ CommDeviceContext* ctx, __gm__ T* local_ptr, + int peer_rank) { + uint64_t local_base = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)local_ptr - local_base; + return (__gm__ T*)(ctx->windowsIn[peer_rank] + offset); +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* in_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ CommDeviceContext* comm_ctx = + reinterpret_cast<__gm__ CommDeviceContext*>(args[2]); + uint64_t sdma_context = static_cast(args[3]); + uint64_t cq_addr = static_cast(args[4]); + + __gm__ float* in_data = reinterpret_cast<__gm__ float*>(in_tensor->buffer.addr) + in_tensor->start_offset; + __gm__ float* out_data = reinterpret_cast<__gm__ float*>(out_tensor->buffer.addr) + out_tensor->start_offset; + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + 
pto2_cq_reset(cq); + + int my_rank = static_cast(comm_ctx->rankId); + int nranks = static_cast(comm_ctx->rankNum); + if (nranks != 2) { + pipe_barrier(PIPE_ALL); + return; + } + int peer_rank = 1 - my_rank; + + constexpr int kTotalElems = 128 * 128; + + using FlatShape = Shape<1, 1, 1, 1, kTotalElems>; + using FlatStride = Stride; + using FlatGlobalData = GlobalTensor; + FlatGlobalData outGlobalFlat(out_data); + __gm__ float* remote_in_data = CommRemotePtr(comm_ctx, in_data, peer_rank); + FlatGlobalData remoteInGlobalFlat(remote_in_data); + + using ScratchTile = pto::Tile; + ScratchTile scratchTile; + TASSIGN(scratchTile, 0x20000); + + __gm__ uint8_t* context = reinterpret_cast<__gm__ uint8_t*>(static_cast(sdma_context)); + + auto desc = pto2_remote_copy_tget_descriptor(outGlobalFlat, remoteInGlobalFlat, scratchTile, context); + uint64_t tag = pto2_send_request_entry(PTO2_ENGINE_SDMA, PTO2_SQ_ID_AUTO, desc); + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, tag); + + pto2_cq_flush(cq); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/kernel_config.py b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/kernel_config.py new file mode 100644 index 000000000..8b34c541b --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/kernel_config.py @@ -0,0 +1,56 @@ +""" +Async Completion Demo - Kernel and Orchestration Configuration + +Two hardware cards use the existing deferred-completion producer API to +demonstrate a real 2P TGET_ASYNC remote read. The legacy single-card / sim +path stays available for local debugging. 
+""" + +import os +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "async_demo_orchestration.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +_platform = os.environ.get("PTO_PLATFORM", "a2a3sim") + +KERNELS = [ + {"func_id": 0, "source": str(_KERNELS_ROOT / "aiv" / "kernel_producer.cpp"), "core_type": "aiv"}, + {"func_id": 1, "source": str(_KERNELS_ROOT / "aiv" / "kernel_consumer.cpp"), "core_type": "aiv"}, +] + +if _platform == "a2a3": + KERNELS.append( + {"func_id": 2, "source": str(_KERNELS_ROOT / "aiv" / "kernel_producer_async.cpp"), "core_type": "aiv"}, + ) + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "orch_thread_num": 1, + "block_dim": 3, + "rounds": 1, +} + +if _platform == "a2a3": + RUNTIME_ENV = { + "PTO2_ENABLE_REMOTE_COPY_ASYNC": "1", + } + + DISTRIBUTED_CONFIG = { + "nranks": 2, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + {"name": "in", "dtype": "float32", "count": 128 * 128, "placement": "window"}, + {"name": "out", "dtype": "float32", "count": 128 * 128, "placement": "window"}, + {"name": "result", "dtype": "float32", "count": 128 * 128, "placement": "device"}, + ], + "inputs": ["in"], + "outputs": ["out", "result"], + "args": ["in", "out", "result", "deviceCtx"], + } diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/orchestration/async_demo_orchestration.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/orchestration/async_demo_orchestration.cpp new file mode 100644 index 000000000..8b06f3798 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/orchestration/async_demo_orchestration.cpp @@ -0,0 +1,142 @@ +/** + * Async Completion Demo - Device-side orchestration (CQ model) + * + * Two execution modes share this file: + * + * 1. 
Single-card / sim mode (legacy demo): + * t0 (producer): out = in * 2.0 [deferred completion via CQ] + * t1 (consumer): result = out + 1.0 [run-to-completion] + * + * 2. Two-card hardware mode: + * both ranks submit one deferred producer task that TGET_ASYNCs the peer + * rank's input buffer into local out, then run the normal consumer on out. + * + * CQ model: + * Orchestration marks t0 as complete_in_future and passes a CQ address. + * The producer kernel decides at runtime what completions it needs and writes + * them into the completion queue. The scheduler reads the CQ after the kernel + * returns and registers completions dynamically. + */ + +#include +#include + +#include "common/comm_context.h" +#include "pto_orchestration_api.h" + +#define ARG_PTR_IN 0 +#define ARG_PTR_OUT 1 +#define ARG_PTR_RESULT 2 +#define ARG_PTR_EVENT_HANDLE_OUTPUT 3 + +#define ARG_SIZE_IN 4 +#define ARG_SIZE_OUT 5 +#define ARG_SIZE_RESULT 6 +#define ARG_SIZE_EVENT_HANDLE_OUTPUT 7 + +#define ARG_SIZE 8 + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { + int arg_count = orch_args.tensor_count() + orch_args.scalar_count(); + return PTO2OrchestrationConfig{ + .expected_arg_count = (arg_count >= 9) ? 9 : 4, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { + int arg_count = orch_args.tensor_count() + orch_args.scalar_count(); + + if (arg_count == 4) { + void *in_ptr = nullptr; + void *out_ptr = nullptr; + void *result_ptr = nullptr; + auto *comm_ctx = reinterpret_cast(static_cast( + orch_args.tensor_count() == 0 ? 
orch_args.scalar(3) : orch_args.scalar(0))); + if (orch_args.tensor_count() == 0) { + in_ptr = reinterpret_cast(static_cast(orch_args.scalar(0))); + out_ptr = reinterpret_cast(static_cast(orch_args.scalar(1))); + result_ptr = reinterpret_cast(static_cast(orch_args.scalar(2))); + } else { + in_ptr = orch_args.tensor(0).data_as(); + out_ptr = orch_args.tensor(1).data_as(); + result_ptr = orch_args.tensor(2).data_as(); + } + int my_rank = (int)comm_ctx->rankId; + + uint32_t shapes[1] = {128 * 128}; + Tensor ext_in = make_tensor_external(in_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_out = make_tensor_external(out_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32); + + uint64_t remote_copy_context = pto2_rt_get_remote_copy_context(); + uint64_t cq = pto2_rt_alloc_cq(); + if (remote_copy_context == 0 || cq == 0) { + LOG_ERROR("async_demo 2P: rank %d failed to get remote-copy context or CQ (ctx=0x%lx, cq=0x%lx)", + my_rank, remote_copy_context, cq); + return; + } + + Arg params_producer; + params_producer.add_input(ext_in); + params_producer.add_output(ext_out); + params_producer.add_scalar((uint64_t)(uintptr_t)comm_ctx); + params_producer.add_scalar(remote_copy_context); + pto2_rt_submit_aiv_task_deferred(2, params_producer, cq); + + Arg params_consumer; + params_consumer.add_input(ext_out); + params_consumer.add_output(ext_result); + pto2_rt_submit_aiv_task(1, params_consumer); + + LOG_INFO("async_demo 2P: rank %d submitted TGET_ASYNC producer with CQ", my_rank); + return; + } + + void *in_ptr = orch_args.tensor(ARG_PTR_IN).data_as(); + void *out_ptr = orch_args.tensor(ARG_PTR_OUT).data_as(); + void *result_ptr = orch_args.tensor(ARG_PTR_RESULT).data_as(); + uint64_t event_handle_output_gm = reinterpret_cast(orch_args.tensor(ARG_PTR_EVENT_HANDLE_OUTPUT).data_as()); + int SIZE = static_cast(orch_args.scalar(4) & 0x7FFFFFFF); + + uint64_t remote_copy_context = pto2_rt_get_remote_copy_context(); + 
uint64_t cq = pto2_rt_alloc_cq(); + + LOG_INFO("async_demo: SIZE=%d, event_handle_output=0x%lx, remote_copy_context=0x%lx, cq=0x%lx", + SIZE, event_handle_output_gm, remote_copy_context, cq); + + uint32_t shapes[1] = {(uint32_t)SIZE}; + Tensor ext_in = make_tensor_external(in_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_out = make_tensor_external(out_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32); + + if (remote_copy_context != 0) { + // HW mode: kernel issues async SDMA request and puts event.handle directly in CQ entry. + Arg params_producer; + params_producer.add_input(ext_in); + params_producer.add_output(ext_out); + params_producer.add_scalar(remote_copy_context); + pto2_rt_submit_aiv_task_deferred(2, params_producer, cq); + + LOG_INFO("async_demo: HW mode - submitted async SDMA producer (func_id=2) with CQ"); + } else { + Arg params_producer; + params_producer.add_input(ext_in); + params_producer.add_output(ext_out); + params_producer.add_scalar(event_handle_output_gm); + pto2_rt_submit_aiv_task_deferred(0, params_producer, cq); + + LOG_INFO("async_demo: Sim mode - submitted producer (func_id=0) with CQ"); + } + + // t1 (consumer): result = out + 1.0 — normal run-to-completion + Arg params_consumer; + params_consumer.add_input(ext_out); + params_consumer.add_output(ext_result); + pto2_rt_submit_aiv_task(1, params_consumer); +} + +} // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/golden.py b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/golden.py new file mode 100644 index 000000000..1d8ffa872 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/golden.py @@ -0,0 +1,70 @@ +""" +Golden script for async_notify_demo. + +Two hardware ranks each produce `out = in * 2` and TNOTIFY the peer. +The consumer is launch-gated on the local notification counter >= 1. 
+When the consumer runs, it reads notify_counter (must be 1) and computes +`result = out + notify_counter = in*2 + 1`. +""" + +import torch + +__outputs__ = ["result", "out"] + +RTOL = 1e-5 +ATOL = 1e-5 + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + del rank + del nranks + del root + del comm_ctx + + size = 128 * 128 + inp = [float(i % 251) / 10.0 for i in range(size)] + out = [0.0] * size + result = [0.0] * size + notify_counter = [0] + + return [ + ("in", inp), + ("out", out), + ("result", result), + ("notify_counter", notify_counter), + ] + + +def generate_inputs(params: dict) -> list: + del params + + size = 128 * 128 + inp = torch.tensor([float(i % 251) / 10.0 for i in range(size)], dtype=torch.float32) + out = torch.zeros(size, dtype=torch.float32) + result = torch.zeros(size, dtype=torch.float32) + notify_counter = torch.zeros(1, dtype=torch.int32) + + return [ + ("in", inp), + ("out", out), + ("result", result), + ("notify_counter", notify_counter), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + del params + + if "in" in tensors: + inp = torch.as_tensor(tensors["in"]) + tensors["out"][:] = inp * 2.0 + tensors["result"][:] = tensors["out"] + 1.0 + return + + out = tensors["out"] + result = tensors["result"] + for i in range(len(out)): + value = float(i % 251) / 10.0 + out[i] = value * 2.0 + result[i] = out[i] + 1.0 diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_consumer.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_consumer.cpp new file mode 100644 index 000000000..28380969a --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_consumer.cpp @@ -0,0 +1,72 @@ +/** + * Async Notify Demo - Consumer Kernel (func_id=1) + * + * Implements: result[i] = src[i] + notify_counter[0] + * + * Depends on NotifyWait completing (via dummy tensor), guaranteeing + * the local 
notification counter >= 1 before this kernel runs. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(dummy_notify) — input (dependency token from NotifyWait) + * args[1] = &Tensor(src) — input tensor struct pointer (producer's output) + * args[2] = &Tensor(result) — output tensor struct pointer + * args[3] = notify_counter_addr — local notify counter (window memory) + */ + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + // args[0] = dummy_notify tensor (dependency token, unused) + __gm__ Tensor* src_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ Tensor* result_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ int32_t* notify_counter = reinterpret_cast<__gm__ int32_t*>(args[3]); + + __gm__ float* src = + reinterpret_cast<__gm__ float*>(src_tensor->buffer.addr) + src_tensor->start_offset; + __gm__ float* result = + reinterpret_cast<__gm__ float*>(result_tensor->buffer.addr) + result_tensor->start_offset; + + constexpr int kTRows_ = 128; + constexpr int kTCols_ = 128; + constexpr int vRows = 128; + constexpr int vCols = 128; + + using DynShapeDim5 = Shape<1, 1, 1, vRows, vCols>; + using DynStridDim5 = Stride<1, 1, 1, kTCols_, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData srcTile(vRows, vCols); + TileData dstTile(vRows, vCols); + TASSIGN(srcTile, 0x0); + TASSIGN(dstTile, 0x10000); + + GlobalData srcGlobal(src); + GlobalData dstGlobal(result); + + TLOAD(srcTile, srcGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + float notify_value = static_cast(*notify_counter); + TADDS(dstTile, srcTile, notify_value); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(dstGlobal, dstTile); + 
set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_notify_wait.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_notify_wait.cpp new file mode 100644 index 000000000..de5588f7a --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_notify_wait.cpp @@ -0,0 +1,46 @@ +/** + * NotifyWait Kernel — register notification counter as CQ condition (func_id=2) + * + * Trivial deferred-completion kernel: registers a COUNTER wait condition + * for the notification counter, then returns immediately. The scheduler + * polls the counter via the CQ mechanism and completes this task once + * *notify_counter >= expected_value. + * + * Kernel args layout: + * args[0] = &Tensor(dummy_notify) — output (dependency token for downstream) + * args[1] = notify_counter_addr — scalar (GM int32* to poll) + * args[2] = expected_value — scalar (threshold) + * args[3] = cq_addr — scalar (auto-appended by deferred submit) + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include +#include "tensor.h" +#include "pto_async_kernel_api.h" + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + uint64_t notify_counter_addr = static_cast(args[1]); + uint32_t expected_value = static_cast(args[2]); + uint64_t cq_addr = static_cast(args[3]); + + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + pto2_cq_reset(cq); + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, + notify_counter_addr, expected_value); + // Flush CQ writes from AICore data cache to GM so the AICPU scheduler + // can read them. pto2_cq_flush's #if-defined guards don't fire because + // the constants are C++ enums, not macros — call intrinsics directly. 
+ dcci((__gm__ int32_t*)cq, cache_line_t::ENTIRE_DATA_CACHE, dcci_dst_t::CACHELINE_OUT); + dsb(DSB_DDR); + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_producer_notify.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_producer_notify.cpp new file mode 100644 index 000000000..9ac07aa4a --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_producer_notify.cpp @@ -0,0 +1,103 @@ +/** + * Async Notify Demo - Hardware 2P Notify Producer Kernel (func_id=0) + * + * Implements: + * 1. Local compute: out[i] = in[i] * 2.0 + * 2. Notify peer rank via TNOTIFY(AtomicAdd) on the peer's window counter + * 3. Return normally (run-to-completion, no deferred completion) + * + * Rank 1 inserts a deliberate delay before notifying. This makes missing + * launch-gating on the consumer side visible in the example output. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(in) — input tensor struct pointer + * args[1] = &Tensor(out) — output tensor struct pointer + * args[2] = notify_counter_addr — local notify counter (window memory) + * args[3] = CommDeviceContext* — distributed communication context + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include +#include "pto/common/pto_tile.hpp" + +#include "common/comm_context.h" +#include "tensor.h" + +using namespace pto; + +#include "pto_async_kernel_api.h" + +template +AICORE inline __gm__ T* CommRemotePtr(__gm__ CommDeviceContext* ctx, __gm__ T* local_ptr, + int peer_rank) { + uint64_t local_base = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)local_ptr - local_base; + return (__gm__ T*)(ctx->windowsIn[peer_rank] + offset); +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* in_tensor = reinterpret_cast<__gm__ 
Tensor*>(args[0]); + __gm__ Tensor* out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int32_t* local_counter = reinterpret_cast<__gm__ int32_t*>(args[2]); + __gm__ CommDeviceContext* comm_ctx = reinterpret_cast<__gm__ CommDeviceContext*>(args[3]); + + __gm__ float* in_data = + reinterpret_cast<__gm__ float*>(in_tensor->buffer.addr) + in_tensor->start_offset; + __gm__ float* out_data = + reinterpret_cast<__gm__ float*>(out_tensor->buffer.addr) + out_tensor->start_offset; + + int my_rank = static_cast(comm_ctx->rankId); + int nranks = static_cast(comm_ctx->rankNum); + if (nranks != 2) { + pipe_barrier(PIPE_ALL); + return; + } + int peer_rank = 1 - my_rank; + + constexpr int kTRows_ = 128; + constexpr int kTCols_ = 128; + constexpr int vRows = 128; + constexpr int vCols = 128; + + using DynShapeDim5 = Shape<1, 1, 1, vRows, vCols>; + using DynStridDim5 = Stride<1, 1, 1, kTCols_, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData inTile(vRows, vCols); + TileData outTile(vRows, vCols); + TASSIGN(inTile, 0x0); + TASSIGN(outTile, 0x10000); + + GlobalData inGlobal(in_data); + GlobalData outGlobal(out_data); + + TLOAD(inTile, inGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TADD(outTile, inTile, inTile); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(outGlobal, outTile); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + + if (my_rank == 1) { + for (volatile int i = 0; i < 2000000; ++i) { + } + } + + __gm__ int32_t* remote_counter = CommRemotePtr(comm_ctx, local_counter, peer_rank); + pto2_send_notification(remote_counter, 1, PTO2NotifyOp::AtomicAdd); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/kernel_config.py b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/kernel_config.py new file mode 100644 index 000000000..d0a3eda0a --- /dev/null +++ 
b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/kernel_config.py @@ -0,0 +1,50 @@ +""" +Async Notify Demo - Kernel and Orchestration Configuration + +Two hardware cards use TNOTIFY(AtomicAdd) for inter-rank notification. +The consumer depends on a deferred NotifyWait task that polls the +local notification counter >= 1 via the CQ mechanism. +""" + +import os +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent +_platform = os.environ.get("PTO_PLATFORM", "a2a3sim") + +if _platform != "a2a3": + raise RuntimeError("async_notify_demo currently requires PTO_PLATFORM=a2a3") + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "async_notify_orchestration.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + {"func_id": 0, "source": str(_KERNELS_ROOT / "aiv" / "kernel_producer_notify.cpp"), "core_type": "aiv"}, + {"func_id": 1, "source": str(_KERNELS_ROOT / "aiv" / "kernel_consumer.cpp"), "core_type": "aiv"}, + {"func_id": 2, "source": str(_KERNELS_ROOT / "aiv" / "kernel_notify_wait.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "orch_thread_num": 1, + "block_dim": 3, + "rounds": 1, +} + +DISTRIBUTED_CONFIG = { + "nranks": 2, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + {"name": "in", "dtype": "float32", "count": 128 * 128, "placement": "window"}, + {"name": "out", "dtype": "float32", "count": 128 * 128, "placement": "device"}, + {"name": "result", "dtype": "float32", "count": 128 * 128, "placement": "device"}, + {"name": "notify_counter", "dtype": "int32", "count": 1, "placement": "window"}, + ], + "inputs": ["in", "notify_counter"], + "outputs": ["out", "result"], + "args": ["in", "out", "result", "notify_counter", "deviceCtx"], +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/orchestration/async_notify_orchestration.cpp 
b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/orchestration/async_notify_orchestration.cpp new file mode 100644 index 000000000..f128c95d2 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/orchestration/async_notify_orchestration.cpp @@ -0,0 +1,88 @@ +/** + * Async Notify Demo - Device-side orchestration + * + * Two-card hardware mode: + * t0 (producer, func_id=0): out = in * 2, then TNOTIFY(AtomicAdd) the + * peer's window counter. Completes normally (RTC). + * t1 (notify_wait, func_id=2, deferred): registers notification counter + * condition (counter >= 1) via CQ, returns immediately. + * Produces dummy_notify tensor for dependency chain. + * t2 (consumer, func_id=1): result = out + notify_counter. + * Depends on both producer (via ext_out) and notify_wait + * (via dummy_notify), ensuring counter >= 1 before reading. + * + * The notify counter is pre-zeroed by the distributed runner input loader. + */ + +#include + +#include "common/comm_context.h" +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{ + .expected_arg_count = 5, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { + int arg_count = orch_args.tensor_count() + orch_args.scalar_count(); + + if (arg_count != 5) { + LOG_ERROR("async_notify_demo: expected 5 args, got %d", arg_count); + return; + } + + void *in_ptr = nullptr; + void *out_ptr = nullptr; + void *result_ptr = nullptr; + void *notify_counter_ptr = nullptr; + auto *comm_ctx = reinterpret_cast(static_cast( + orch_args.tensor_count() == 0 ? 
orch_args.scalar(4) : orch_args.scalar(0))); + if (orch_args.tensor_count() == 0) { + in_ptr = reinterpret_cast(static_cast(orch_args.scalar(0))); + out_ptr = reinterpret_cast(static_cast(orch_args.scalar(1))); + result_ptr = reinterpret_cast(static_cast(orch_args.scalar(2))); + notify_counter_ptr = reinterpret_cast(static_cast(orch_args.scalar(3))); + } else { + in_ptr = orch_args.tensor(0).data_as(); + out_ptr = orch_args.tensor(1).data_as(); + result_ptr = orch_args.tensor(2).data_as(); + notify_counter_ptr = orch_args.tensor(3).data_as(); + } + int my_rank = (int)comm_ctx->rankId; + + uint32_t shapes[1] = {128 * 128}; + Tensor ext_in = make_tensor_external(in_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_out = make_tensor_external(out_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32); + + // Producer: normal run-to-completion task (sends TNOTIFY to peer) + Arg params_producer; + params_producer.add_input(ext_in); + params_producer.add_output(ext_out); + params_producer.add_scalar((uint64_t)(uintptr_t)notify_counter_ptr); + params_producer.add_scalar((uint64_t)(uintptr_t)comm_ctx); + pto2_rt_submit_aiv_task(0, params_producer); + + // Returns a dependency token tensor for downstream tasks. + Tensor notify_token = pto2_rt_submit_notification_wait_task(2, (uint64_t)(uintptr_t)notify_counter_ptr, 1); + + // Consumer: depends on producer (via ext_out) and notify_wait (via token). 
+ Arg params_consumer; + params_consumer.add_input(notify_token); + params_consumer.add_input(ext_out); + params_consumer.add_output(ext_result); + params_consumer.add_scalar((uint64_t)(uintptr_t)notify_counter_ptr); + pto2_rt_submit_aiv_task(1, params_consumer); + + LOG_INFO("async_notify_demo: rank %d producer=RTC, notify_wait=deferred(counter=0x%lx), consumer=RTC", + my_rank, (uint64_t)(uintptr_t)notify_counter_ptr); +} + +} // extern "C" diff --git a/examples/scripts/README.md b/examples/scripts/README.md index ce5ac5c8c..fb2bb46af 100644 --- a/examples/scripts/README.md +++ b/examples/scripts/README.md @@ -42,6 +42,32 @@ python examples/scripts/run_example.py \ -p a2a3sim ``` +#### Running Distributed (Multi-Rank) Tests + +Distributed examples are auto-detected when `kernel_config.py` contains a `DISTRIBUTED_CONFIG` dictionary. No separate script is needed — `run_example.py` handles it automatically: + +```bash +# Simulation (no hardware required, 8 ranks by default from kernel_config) +python examples/scripts/run_example.py \ + -k path/to/distributed_test/kernels \ + -g path/to/distributed_test/golden.py \ + -p a2a3sim + +# Hardware platform — pick specific devices (nranks inferred from device count) +python examples/scripts/run_example.py \ + -k path/to/distributed_test/kernels \ + -g path/to/distributed_test/golden.py \ + -p a2a3 --devices 0,1,2,3,4,5,6,7 + +# Hardware platform — non-contiguous devices +python examples/scripts/run_example.py \ + -k path/to/distributed_test/kernels \ + -g path/to/distributed_test/golden.py \ + -p a2a3 --devices 2,4,5,7 +``` + +The framework runs distributed examples through the mainline L3 Python worker API. 
`create_code_runner(...)` always returns the unified `CodeRunner`; when `kernel_config.py` contains `DISTRIBUTED_CONFIG`, `CodeRunner` switches into distributed mode, performs rank-local `comm_*` bootstrap inside each chip child, then submits one group CHIP task with `Task(orch=...)`, `WorkerPayload`, and `args_list=[...]`. On simulation (`a2a3sim`), ranks communicate via POSIX shared memory; on hardware (`a2a3`), they use HCCL over RDMA. + ## Command Line Arguments ### `run_example.py` Parameters @@ -56,6 +82,7 @@ python examples/scripts/run_example.py \ | `--verbose` | `-v` | Enable verbose output (equivalent to `--log-level debug`) | False | | `--silent` | | Enable silent mode (equivalent to `--log-level error`) | False | | `--log-level` | | Set log level: `error`, `warn`, `info`, `debug` | `info` | +| `--nranks` | | Number of ranks for distributed tests | From `DISTRIBUTED_CONFIG` | | `--clone-protocol` | | Git protocol for cloning pto-isa: `ssh` or `https` | `ssh` | ### Platform Description @@ -162,7 +189,54 @@ ORCHESTRATION = { } ``` -### 3. `golden.py` Format +### 3. Distributed `kernel_config.py` Format + +To make a test distributed, add a `DISTRIBUTED_CONFIG` dictionary alongside the standard `KERNELS` and `ORCHESTRATION` fields: + +```python +DISTRIBUTED_CONFIG = { + "nranks": 8, # Number of ranks + "root": 0, # Root rank for collective ops + "comm_include_dirs": [...], # Extra include dirs for kernel compilation + "win_sync_prefix": 256, # Bytes reserved before window buffers + "buffers": [ + {"name": "input", "dtype": "float32", "count": 256, "placement": "window"}, + {"name": "output", "dtype": "float32", "count": 256, "placement": "device"}, + ], + "inputs": ["input"], # Buffers to load from .bin files + "outputs": ["output"], # Buffers to save after execution + "args": ["input", "output", "nranks", "root", "deviceCtx"], +} +``` + +- **`placement: "window"`** — Buffer is allocated in the RDMA window region (accessible by all ranks). 
+- **`placement: "device"`** — Buffer is allocated via `device_malloc` (local to each rank). +- **`args`** — Tokens passed as orchestration function arguments. Special tokens: `nranks`, `root`, `deviceCtx` (pointer to `CommDeviceContext`). + +### 4. Distributed `golden.py` Format + +The golden script for distributed tests uses `generate_distributed_inputs` instead of `generate_inputs`: + +```python +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + """Return a list of (name, data) tuples for this rank.""" + input_data = [float(i + rank * 100) for i in range(256)] + output_data = [0.0] * 256 + return [ + ("input", input_data), + ("output", output_data), + ] + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected output for the root rank (in-place).""" + nranks = params.get("nranks", 8) + output = tensors["output"] + for i in range(256): + output[i] = float(nranks * i + 100 * nranks * (nranks - 1) // 2) +``` + +### 5. Standard `golden.py` Format ```python import torch @@ -375,6 +449,25 @@ TEST PASSED ============================================================ ``` +### Distributed Test Success Example + +``` +[INFO] Detected DISTRIBUTED_CONFIG — using distributed runner +[INFO] === Phase 1: Building runtime === +... +[INFO] === Launching 8 workers === +[INFO] Rank 0: OK +[INFO] Rank 1: OK +... 
+[INFO] Rank 7: OK +[INFO] VERIFY PASSED: output — 256 elements correct +[INFO] Sample: [2800.0, 2808.0, 2816.0, 2824.0, 2832.0] + +============================================================ +TEST PASSED +============================================================ +``` + ### Failure Example ```text @@ -388,8 +481,9 @@ TEST FAILED: Output 'f' does not match golden ## Reference Examples -- **Hardware Example**: [examples/a2a3/host_build_graph/vector_example/](../a2a3/host_build_graph/vector_example/) -- **Simulation Example**: [examples/a2a3/host_build_graph/vector_example/](../a2a3/host_build_graph/vector_example/) +- **Single-Card Example**: [examples/a2a3/host_build_graph/vector_example/](../a2a3/host_build_graph/vector_example/) +- **Async Completion Demo** (2-card, deferred RDMA read): [examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/](../a2a3/tensormap_and_ringbuffer/async_completion_demo/) +- **Async Notify Demo** (2-card, TNOTIFY launch gating): [examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/](../a2a3/tensormap_and_ringbuffer/async_notify_demo/) ## FAQ @@ -533,6 +627,23 @@ runner = create_code_runner( runner.run() # Execute test ``` +### Distributed Programmatic Usage + +```python +from code_runner import create_code_runner + +runner = create_code_runner( + kernels_dir="path/to/distributed_test/kernels", + golden_path="path/to/distributed_test/golden.py", + platform="a2a3sim", + nranks=8, +) + +runner.run_all() # compile, prepare data, launch workers, verify +``` + +Internally this follows the same host-side shape as `tests/ut/py/test_dist_worker` and `tests/st/test_worker_api.py`: build the parent callable once, initialize `Worker(level=3)`, and submit a group CHIP task from a parent orchestration function. 
+ ## Related Documentation - [Main Project README](../../README.md) diff --git a/examples/scripts/code_runner.py b/examples/scripts/code_runner.py index 5a8631c30..e1b103aa5 100644 --- a/examples/scripts/code_runner.py +++ b/examples/scripts/code_runner.py @@ -57,9 +57,12 @@ def compute_golden(tensors: dict, params: dict) -> None: import importlib.util import logging import os +import shutil +import struct import sys import time from contextlib import contextmanager +from multiprocessing.shared_memory import SharedMemory from pathlib import Path from typing import Any, Optional @@ -405,7 +408,9 @@ def _ensure_pto_isa_root_locked( logger.warning(f"pto-isa cloned but missing include directory: {include_dir}") return None - return str(clone_path.resolve()) + resolved_root = str(clone_path.resolve()) + os.environ["PTO_ISA_ROOT"] = resolved_root + return resolved_root def _kernel_config_runtime_env(kernel_config_module, kernels_dir: Path) -> dict[str, str]: @@ -471,6 +476,20 @@ class CodeRunner: platform: Platform name ("a2a3" for hardware, "a2a3sim" for simulation, default: "a2a3") """ + DTYPE_FORMAT = { + "float32": ("f", 4), + "float64": ("d", 8), + "int32": ("i", 4), + "int64": ("q", 8), + "uint32": ("I", 4), + "uint64": ("Q", 8), + "float16": ("e", 2), + "int16": ("h", 2), + "uint16": ("H", 2), + "int8": ("b", 1), + "uint8": ("B", 1), + } + def __init__( # noqa: PLR0913 self, kernels_dir: str, @@ -485,6 +504,8 @@ def __init__( # noqa: PLR0913 repeat_rounds: Optional[int] = None, clone_protocol: str = "ssh", skip_golden: bool = False, + nranks: Optional[int] = None, + device_ids: Optional[list[int]] = None, ): # Setup logging if not already configured (e.g., when used directly, not via run_example.py) _setup_logging_if_needed() @@ -536,6 +557,33 @@ def __init__( # noqa: PLR0913 self.runtime_name = runtime_config.get("runtime", "host_build_graph") self.repeat_rounds = repeat_rounds if repeat_rounds is not None else runtime_config.get("rounds", 1) + 
self._is_distributed = hasattr(self._kernel_config, "DISTRIBUTED_CONFIG") + if self._is_distributed: + dist_cfg = getattr(self._kernel_config, "DISTRIBUTED_CONFIG", {}) + self.nranks = nranks if nranks is not None else dist_cfg.get("nranks", 8) + self.root = dist_cfg.get("root", 0) + if self.nranks <= 0: + raise ValueError(f"Distributed nranks must be positive, got {self.nranks}") + if self.root < 0 or self.root >= self.nranks: + raise ValueError(f"Distributed root must be in [0, {self.nranks}), got {self.root}") + if device_ids is None: + self.device_ids = list(range(self.nranks)) + else: + if len(device_ids) != self.nranks: + raise ValueError(f"Expected {self.nranks} device ids, got {len(device_ids)}: {device_ids}") + self.device_ids = list(device_ids) + self.orch_func = self.orchestration["function_name"] + self._dist_run_dir = ( + self.project_root / "build" / "distributed" / "runs" / f"run_{os.getpid()}_{time.time_ns()}" + ) + self.build_dir = self._dist_run_dir / "cache" + self.artifact_dir = self._dist_run_dir / "artifacts" + self._dist_example_input_shms: list[SharedMemory] = [] + self._dist_example_output_shms: list[SharedMemory] = [] + self._dist_example_output_artifacts: list[dict[str, Path]] = [] + self._dist_example_inputs_by_rank = [] + self._dist_example_outputs_by_rank = [] + def _load_kernel_config(self): """Load kernel_config.py from kernels directory.""" config_path = self.kernels_dir / "kernel_config.py" @@ -706,6 +754,11 @@ def run(self) -> None: # noqa: PLR0912, PLR0915 - Run via ChipWorker - Compare with golden """ + if self._is_distributed: + if not self.run_all(): + raise RuntimeError("Distributed run failed") + return + # Import runtime modules (deferred import to avoid top-level dependency) from elf_parser import extract_text_section # noqa: PLC0415 from kernel_compiler import KernelCompiler # noqa: PLC0415 @@ -751,9 +804,10 @@ def run(self) -> None: # noqa: PLR0912, PLR0915 runtime_base_dir = os.path.join(self.project_root, "src", arch, 
"runtime", self.runtime_name) - # Read include_dirs from build_config.py for kernel compilation + # Start from the standard runtime + common + platform include roots used + # by orchestration compilation, then add runtime-specific aicore extras. build_config_path = os.path.join(runtime_base_dir, "build_config.py") - runtime_include_dirs = [] + runtime_include_dirs = kernel_compiler.get_orchestration_include_dirs(self.runtime_name) if os.path.isfile(build_config_path): import importlib.util # noqa: PLC0415 @@ -763,10 +817,9 @@ def run(self) -> None: # noqa: PLR0912, PLR0915 spec.loader.exec_module(bc_module) aicore_cfg = bc_module.BUILD_CONFIG.get("aicore", {}) for p in aicore_cfg.get("include_dirs", []): - runtime_include_dirs.append(os.path.join(runtime_base_dir, p)) - else: - runtime_include_dirs.append(os.path.join(runtime_base_dir, "runtime")) - runtime_include_dirs.append(os.path.join(self.project_root, "src", "common", "task_interface")) + inc_dir = os.path.join(runtime_base_dir, p) + if inc_dir not in runtime_include_dirs: + runtime_include_dirs.append(inc_dir) def _build_runtime(): return builder.get_binaries(self.runtime_name, build=self.build_runtime) @@ -905,6 +958,23 @@ def _compile_one_kernel(kernel): logger.info(f"=== All {total_cases} cases passed ===") logger.info("=" * 60) + def run_all(self, skip_compile: bool = False, skip_verify: bool = False) -> bool: + if self._is_distributed: + try: + if not skip_compile: + self._compile_dist_example_artifacts() + self._prepare_dist_example_data() + success = self._run_distributed() + if success and not skip_verify: + success = self._verify_distributed() + return success + finally: + self._cleanup_dist_example_staging() + + del skip_compile, skip_verify + self.run() + return True + def _compare_with_golden( self, outputs: dict[str, torch.Tensor], @@ -944,6 +1014,400 @@ def _compare_with_golden( matched = torch.isclose(actual, expected, rtol=self.rtol, atol=self.atol).sum().item() logger.info(f" {name}: PASS 
({matched}/{actual.numel()} elements matched)") + def _dist_example_buffer_config(self, name: str): + dist = getattr(self._kernel_config, "DISTRIBUTED_CONFIG", {}) + for buf_cfg in dist.get("buffers", []): + if buf_cfg["name"] == name: + return buf_cfg + raise ValueError(f"Buffer '{name}' not found in DISTRIBUTED_CONFIG['buffers']") + + def _chip_buffer_dtype_to_task_dtype(self, dtype: str): + from task_interface import DataType # noqa: PLC0415 + + mapping = { + "float32": DataType.FLOAT32, + "float16": DataType.FLOAT16, + "int32": DataType.INT32, + "int16": DataType.INT16, + "int8": DataType.INT8, + "uint8": DataType.UINT8, + "int64": DataType.INT64, + } + if dtype not in mapping: + raise ValueError(f"Unsupported distributed buffer dtype: {dtype}") + return mapping[dtype] + + def _chip_runtime_artifact_paths(self): + return { + "host": self.artifact_dir / "libhost_runtime.so", + "aicpu": self.artifact_dir / "libaicpu_kernel.so", + "aicore": self.artifact_dir / "aicore_kernel.o", + } + + def _chip_orch_artifact_name(self): + return Path(self.orchestration["source"]).stem + ".so" + + def _chip_kernel_artifact_name(self, kernel_cfg): + return Path(kernel_cfg["source"]).stem + ".bin" + + def _build_chip_callable(self): + from task_interface import ChipCallable, CoreCallable # noqa: PLC0415 + + orch_binary = (self.artifact_dir / self._chip_orch_artifact_name()).read_bytes() + children = [] + for kernel_cfg in self.kernels: + binary = (self.artifact_dir / self._chip_kernel_artifact_name(kernel_cfg)).read_bytes() + children.append((kernel_cfg["func_id"], CoreCallable.build(kernel_cfg.get("signature", []), binary))) + return ChipCallable.build( + self.orchestration.get("signature", []), + self.orch_func, + orch_binary, + children, + ) + + def _buffer_nbytes(self, buf_cfg: dict) -> int: + fmt = self.DTYPE_FORMAT.get(buf_cfg["dtype"]) + if fmt is None: + raise ValueError(f"Unsupported dtype '{buf_cfg['dtype']}' for buffer '{buf_cfg['name']}'") + return int(buf_cfg["count"]) 
* fmt[1] + + def _cleanup_dist_example_staging(self) -> None: + for shm in self._dist_example_input_shms: + try: + shm.close() + finally: + try: + shm.unlink() + except FileNotFoundError: + pass + for shm in self._dist_example_output_shms: + try: + shm.close() + finally: + try: + shm.unlink() + except FileNotFoundError: + pass + self._dist_example_input_shms.clear() + self._dist_example_output_shms.clear() + self._dist_example_output_artifacts.clear() + self._dist_example_inputs_by_rank.clear() + self._dist_example_outputs_by_rank.clear() + + def _prepare_dist_example_data(self) -> None: + from worker import HostBufferStaging # noqa: PLC0415 + + self._cleanup_dist_example_staging() + golden = self._golden_module + if not hasattr(golden, "generate_distributed_inputs"): + raise AttributeError( + "Distributed examples must define generate_distributed_inputs(rank, nranks, root, comm_ctx=None)" + ) + + self._dist_example_inputs_by_rank: list[dict[str, HostBufferStaging]] = [] + self._dist_example_outputs_by_rank: list[dict[str, HostBufferStaging]] = [] + + dist_cfg = getattr(self._kernel_config, "DISTRIBUTED_CONFIG", {}) + input_names = set(dist_cfg.get("inputs", [])) + output_names = set(dist_cfg.get("outputs", [])) + + for rank in range(self.nranks): + rank_dir = self.artifact_dir / f"rank_{rank}" + rank_dir.mkdir(parents=True, exist_ok=True) + + inputs = golden.generate_distributed_inputs(rank, self.nranks, self.root) + input_map: dict[str, HostBufferStaging] = {} + output_map: dict[str, HostBufferStaging] = {} + + for name, data in inputs: + buf_cfg = self._dist_example_buffer_config(name) + fmt = self.DTYPE_FORMAT.get(buf_cfg["dtype"]) + if fmt is None: + raise ValueError(f"Unsupported dtype '{buf_cfg['dtype']}' for buffer '{name}'") + if isinstance(data, (list, tuple)): + raw = struct.pack(f"<{len(data)}{fmt[0]}", *data) + else: + raw = bytes(data) + if len(raw) != self._buffer_nbytes(buf_cfg): + raise ValueError(f"Distributed input '{name}' size mismatch: got 
{len(raw)}, expected {self._buffer_nbytes(buf_cfg)}") + if name in input_names: + shm = SharedMemory(create=True, size=len(raw)) + assert shm.buf is not None + shm.buf[:len(raw)] = raw + self._dist_example_input_shms.append(shm) + input_map[name] = HostBufferStaging(name=name, shm_name=shm.name, size=len(raw)) + (rank_dir / f"{name}.bin").write_bytes(raw) + + for name in output_names: + buf_cfg = self._dist_example_buffer_config(name) + size = self._buffer_nbytes(buf_cfg) + shm = SharedMemory(create=True, size=size) + assert shm.buf is not None + if size: + shm.buf[:size] = b"\0" * size + self._dist_example_output_shms.append(shm) + output_map[name] = HostBufferStaging(name=name, shm_name=shm.name, size=size) + + self._dist_example_inputs_by_rank.append(input_map) + self._dist_example_outputs_by_rank.append(output_map) + self._dist_example_output_artifacts.append( + {name: rank_dir / f"{name}.bin" for name in output_names} + ) + + logger.info(f"Prepared distributed data for {self.nranks} ranks in {self.artifact_dir}") + + def _build_chip_bootstrap_config(self, rank: int): + from worker import ChipBootstrapConfig, ChipBufferSpec, ChipCommBootstrapConfig # noqa: PLC0415 + + dist = getattr(self._kernel_config, "DISTRIBUTED_CONFIG", {}) + buffers = [] + total_window = int(dist.get("win_sync_prefix", 0)) + for buf in dist.get("buffers", []): + spec = ChipBufferSpec( + name=buf["name"], + dtype=buf["dtype"], + count=int(buf["count"]), + placement=buf["placement"], + nbytes=self._buffer_nbytes(buf), + load_from_host=buf["name"] in dist.get("inputs", []), + store_to_host=buf["name"] in dist.get("outputs", []), + ) + buffers.append(spec) + if spec.placement == "window": + total_window += spec.nbytes + return ChipBootstrapConfig( + comm=ChipCommBootstrapConfig( + rank=rank, + nranks=self.nranks, + rootinfo_path=str(self.artifact_dir / "rootinfo.bin"), + win_sync_prefix=int(dist.get("win_sync_prefix", 0)), + window_size=total_window, + ), + buffers=buffers, + 
host_inputs=list(self._dist_example_inputs_by_rank[rank].values()), + host_outputs=list(self._dist_example_outputs_by_rank[rank].values()), + ) + + def _make_chip_task_args(self, chip_context): + from task_interface import ChipStorageTaskArgs, scalar_to_uint64 # noqa: PLC0415 + + dist = getattr(self._kernel_config, "DISTRIBUTED_CONFIG", {}) + buf_cfg_by_name = {buf["name"]: buf for buf in dist.get("buffers", [])} + args = ChipStorageTaskArgs() + for tok in dist.get("args", []): + if tok == "nranks": + args.add_scalar(scalar_to_uint64(self.nranks)) + elif tok == "root": + args.add_scalar(scalar_to_uint64(self.root)) + elif tok == "rank": + args.add_scalar(scalar_to_uint64(chip_context.rank)) + elif tok == "deviceCtx": + args.add_scalar(scalar_to_uint64(chip_context.device_ctx)) + else: + buf_cfg = buf_cfg_by_name.get(tok) + if buf_cfg is None: + raise ValueError(f"Unknown distributed arg token: {tok}") + tensor_arg = chip_context.buffer_tensors.get(tok) + if tensor_arg is None: + from task_interface import make_device_tensor_arg # noqa: PLC0415 + + tensor_arg = make_device_tensor_arg( + chip_context.buffer_ptrs[tok], + (int(buf_cfg["count"]),), + self._chip_buffer_dtype_to_task_dtype(buf_cfg["dtype"]), + ) + args.add_tensor(tensor_arg) + return args + + def _compile_dist_example_artifacts(self): + from elf_parser import extract_text_section # noqa: PLC0415 + from kernel_compiler import KernelCompiler # noqa: PLC0415 + from runtime_builder import RuntimeBuilder # noqa: PLC0415 + + if self.build_dir.exists(): + shutil.rmtree(self.build_dir) + if self.artifact_dir.exists(): + shutil.rmtree(self.artifact_dir) + self.artifact_dir.mkdir(parents=True, exist_ok=True) + self.build_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Using distributed run directory: {self._dist_run_dir}") + + pto_isa_root = _ensure_pto_isa_root(verbose=True, commit=self.pto_isa_commit, clone_protocol=self.clone_protocol) + if pto_isa_root is None: + raise EnvironmentError("PTO_ISA_ROOT 
could not be resolved.") + + builder = RuntimeBuilder(platform=self.platform) + kernel_compiler = KernelCompiler(self.platform) + + logger.info("=== Phase 1: Building runtime ===") + runtime_bins = builder.get_binaries(self.runtime_name, build=self.build_runtime) + + logger.info("=== Phase 2: Compiling orchestration ===") + orch_binary = kernel_compiler.compile_orchestration( + self.runtime_name, + str(Path(self.orchestration["source"]).resolve()), + extra_include_dirs=kernel_compiler.get_orchestration_include_dirs(self.runtime_name), + build_dir=str(self.build_dir), + ) + + logger.info("=== Phase 3: Compiling kernels ===") + extra_includes = kernel_compiler.get_orchestration_include_dirs(self.runtime_name) + for d in getattr(self._kernel_config, "DISTRIBUTED_CONFIG", {}).get("comm_include_dirs", []): + p = Path(pto_isa_root) / d if not os.path.isabs(d) else Path(d) + extra_includes.append(str(p)) + kernel_bins = {} + for k in self.kernels: + src = k["source"] if os.path.isabs(k["source"]) else str(self.kernels_dir / k["source"]) + incore_o = kernel_compiler.compile_incore( + src, + core_type=k.get("core_type", "aiv"), + pto_isa_root=pto_isa_root, + extra_include_dirs=extra_includes, + build_dir=str(self.build_dir), + ) + kernel_bins[k["func_id"]] = (k, incore_o if self.platform.endswith("sim") else extract_text_section(incore_o)) + + def save(name, data): + path = self.artifact_dir / name + path.write_bytes(data) + logger.info(f" {name}: {len(data)} bytes") + + save("libhost_runtime.so", runtime_bins.host_path.read_bytes()) + save("libaicpu_kernel.so", runtime_bins.aicpu_path.read_bytes()) + save("aicore_kernel.o", runtime_bins.aicore_path.read_bytes()) + save(self._chip_orch_artifact_name(), orch_binary) + for _, (kcfg, data) in kernel_bins.items(): + save(self._chip_kernel_artifact_name(kcfg), data) + self._chip_callable = self._build_chip_callable() + logger.info(f"All artifacts saved to {self.artifact_dir}") + + def _dump_dist_example_outputs(self) -> None: + 
for rank, outputs in enumerate(self._dist_example_outputs_by_rank): + for name, staging in outputs.items(): + shm = SharedMemory(name=staging.shm_name) + try: + assert shm.buf is not None + self._dist_example_output_artifacts[rank][name].write_bytes(bytes(shm.buf[:staging.size])) + finally: + shm.close() + + def _run_distributed(self): + from task_interface import WorkerPayload, WorkerType # noqa: PLC0415 + from worker import Task, Worker # noqa: PLC0415 + + if not self._dist_example_inputs_by_rank or not self._dist_example_outputs_by_rank: + raise RuntimeError("Distributed data is not prepared. Call run_all() or _prepare_dist_example_data() first.") + + rootinfo_file = self.artifact_dir / "rootinfo.bin" + for f in self.artifact_dir.glob("barrier_*.ready"): + f.unlink() + if rootinfo_file.exists(): + rootinfo_file.unlink() + + if not hasattr(self, "_chip_callable"): + self._chip_callable = self._build_chip_callable() + + runtime_paths = self._chip_runtime_artifact_paths() + chip_bootstrap_configs = [self._build_chip_bootstrap_config(rank) for rank, _ in enumerate(self.device_ids)] + run_env = _kernel_config_runtime_env(self._kernel_config, self.kernels_dir) + success = True + worker = None + try: + with _temporary_env(run_env): + worker = Worker( + level=3, + device_ids=self.device_ids, + num_sub_workers=0, + platform=self.platform, + runtime=self.runtime_name, + host_path=str(runtime_paths["host"]), + aicpu_path=str(runtime_paths["aicpu"]), + aicore_path=str(runtime_paths["aicore"]), + chip_bootstrap_configs=chip_bootstrap_configs, + ) + worker.init() + rank_args = [self._make_chip_task_args(ctx) for ctx in worker.chip_contexts] + + payload = WorkerPayload() + payload.worker_type = WorkerType.CHIP + payload.callable = self._chip_callable.buffer_ptr() + payload.block_dim = int(getattr(self._kernel_config, "RUNTIME_CONFIG", {}).get("block_dim", 1)) + payload.aicpu_thread_num = int(getattr(self._kernel_config, "RUNTIME_CONFIG", {}).get("aicpu_thread_num", 1)) + + def 
orch_fn(w, args_list): + w.submit( + WorkerType.CHIP, + payload, + args_list=[arg.__ptr__() for arg in args_list], + outputs=[], + ) + + worker.run(Task(orch=orch_fn, args=rank_args)) + self._dump_dist_example_outputs() + except Exception: + logger.exception("Distributed worker execution failed") + success = False + finally: + if worker is not None: + worker.close() + for f in self.artifact_dir.glob("barrier_*.ready"): + f.unlink() + + print() + print(f"=== ALL {self.nranks} RANKS COMPLETED ===" if success else f"=== DISTRIBUTED RUN FAILED ({self.nranks} ranks) ===") + return success + + def _verify_distributed(self): + dist = getattr(self._kernel_config, "DISTRIBUTED_CONFIG", {}) + output_names = dist.get("outputs", []) + buf_map = {b["name"]: b for b in dist.get("buffers", [])} + + seed_dir = self.artifact_dir / f"rank_{self.root}" + seed_outputs = {} + for name in output_names: + path = seed_dir / f"{name}.bin" + if not path.exists(): + logger.error(f"Output file not found: {path}") + return False + raw = path.read_bytes() + dtype = buf_map.get(name, {}).get("dtype", "float32") + fmt_char, elem_sz = self.DTYPE_FORMAT.get(dtype, ("f", 4)) + count = len(raw) // elem_sz + seed_outputs[name] = list(struct.unpack(f"<{count}{fmt_char}", raw)) + + expected_outputs = {n: v.copy() for n, v in seed_outputs.items()} + self._golden_module.compute_golden(expected_outputs, {"nranks": self.nranks, "root": self.root}) + + rtol = getattr(self._golden_module, "RTOL", 1e-5) + atol = getattr(self._golden_module, "ATOL", 1e-5) + all_ok = True + for rank in range(self.nranks): + rank_dir = self.artifact_dir / f"rank_{rank}" + for name in output_names: + path = rank_dir / f"{name}.bin" + if not path.exists(): + logger.error(f"Output file not found: {path}") + all_ok = False + continue + raw = path.read_bytes() + dtype = buf_map.get(name, {}).get("dtype", "float32") + fmt_char, elem_sz = self.DTYPE_FORMAT.get(dtype, ("f", 4)) + count = len(raw) // elem_sz + actual = 
list(struct.unpack(f"<{count}{fmt_char}", raw)) + expected = expected_outputs[name] + mismatches = 0 + for i, (a, e) in enumerate(zip(actual, expected)): + if abs(a - e) > atol + rtol * abs(e): + if mismatches < 3: + logger.error(f" rank {rank} {name}[{i}]: got {a}, expected {e}") + mismatches += 1 + if mismatches: + logger.error(f"VERIFY FAILED: rank {rank} {name} — {mismatches}/{len(actual)} mismatches") + all_ok = False + else: + logger.info(f"VERIFY PASSED: rank {rank} {name} — {len(actual)} elements correct") + print("\n=== VERIFICATION PASSED ===\n" if all_ok else "\n=== VERIFICATION FAILED ===\n") + return all_ok def create_code_runner( # noqa: PLR0913 kernels_dir, @@ -958,8 +1422,25 @@ def create_code_runner( # noqa: PLR0913 repeat_rounds=None, clone_protocol="ssh", skip_golden=False, + nranks=None, + device_ids=None, ): - """Factory: creates a CodeRunner based on kernel_config.""" + """Factory: creates the example runner for the given kernel_config.""" + effective_device_ids = None if device_ids is None else list(device_ids) + effective_nranks = nranks + kernels_dir_path = Path(kernels_dir).resolve() + kernel_config = _load_module_from_path(kernels_dir_path / "kernel_config.py", f"kernel_config_factory_{os.getpid()}") + if hasattr(kernel_config, "DISTRIBUTED_CONFIG"): + dist_cfg = getattr(kernel_config, "DISTRIBUTED_CONFIG", {}) + if effective_device_ids is not None: + effective_nranks = len(effective_device_ids) + else: + effective_nranks = nranks if nranks is not None else dist_cfg.get("nranks", 8) + base_device = 0 if device_id is None else device_id + effective_device_ids = [base_device + i for i in range(effective_nranks)] + if nranks is not None and nranks != effective_nranks: + raise ValueError(f"--nranks={nranks} conflicts with device list ({effective_nranks} devices)") + return CodeRunner( kernels_dir=kernels_dir, golden_path=golden_path, @@ -973,4 +1454,6 @@ def create_code_runner( # noqa: PLR0913 repeat_rounds=repeat_rounds, 
clone_protocol=clone_protocol, skip_golden=skip_golden, + nranks=effective_nranks, + device_ids=effective_device_ids, ) diff --git a/examples/scripts/run_example.py b/examples/scripts/run_example.py index 89ab84199..2cc946ea6 100644 --- a/examples/scripts/run_example.py +++ b/examples/scripts/run_example.py @@ -81,6 +81,35 @@ def _wait_for_new_device_log(log_dir, pre_run_logs, timeout=15, interval=0.5): return None +def _parse_device_spec(spec): + """Expand a device spec like '4-7' or '0,1,3,5' into device ids.""" + if spec is None: + return None + + spec = spec.strip() + if not spec: + raise ValueError("Device spec must not be empty") + + device_ids = [] + for item in spec.split(","): + item = item.strip() + if not item: + continue + if "-" in item: + start_str, end_str = item.split("-", 1) + start = int(start_str) + end = int(end_str) + if end < start: + raise ValueError(f"Invalid device range '{item}': end < start") + device_ids.extend(range(start, end + 1)) + else: + device_ids.append(int(item)) + + if not device_ids: + raise ValueError("Device spec must contain at least one device") + + return device_ids + def main(): # noqa: PLR0912 parser = argparse.ArgumentParser( description="Run PTO runtime test with kernel config and golden script", @@ -201,10 +230,33 @@ def compute_golden(tensors: dict, params: dict) -> None: help="Compile runtime from source instead of using pre-built binaries", ) + parser.add_argument( + "--nranks", + type=int, + default=None, + help="Override number of ranks for distributed tests (default: from kernel_config)" + ) + + parser.add_argument( + "--device-range", + type=str, + default=None, + help="Explicit device range for distributed tests (e.g., 4-7)" + ) + + parser.add_argument( + "--devices", + type=str, + default=None, + help="Explicit distributed device list, supports comma lists/ranges (e.g., 0,1,3,5 or 4-7)" + ) + args = parser.parse_args() if args.all and args.case: parser.error("--all and --case are mutually exclusive") + if 
args.device_range and args.devices: + parser.error("--device-range and --devices are mutually exclusive") # Determine log level from arguments log_level_str = None @@ -253,6 +305,12 @@ def compute_golden(tensors: dict, params: dict) -> None: try: from code_runner import create_code_runner # noqa: PLC0415 + selected_device_ids = None + if args.devices is not None: + selected_device_ids = _parse_device_spec(args.devices) + elif args.device_range is not None: + selected_device_ids = _parse_device_spec(args.device_range) + runner = create_code_runner( kernels_dir=str(args.kernels), golden_path=str(args.golden), @@ -266,6 +324,8 @@ def compute_golden(tensors: dict, params: dict) -> None: repeat_rounds=args.rounds, clone_protocol=args.clone_protocol, skip_golden=args.skip_golden, + nranks=args.nranks, + device_ids=selected_device_ids, ) # Snapshot existing device logs before the run so we can identify the @@ -277,7 +337,10 @@ def compute_golden(tensors: dict, params: dict) -> None: if device_log_dir.exists(): pre_run_device_logs = set(device_log_dir.glob("*.log")) - runner.run() + success = runner.run_all() + if not success: + logger.error("TEST FAILED") + return 1 logger.info("=" * 60) logger.info("TEST PASSED") logger.info("=" * 60) diff --git a/python/bindings/CMakeLists.txt b/python/bindings/CMakeLists.txt index aee68ac64..4f6b3e3ce 100644 --- a/python/bindings/CMakeLists.txt +++ b/python/bindings/CMakeLists.txt @@ -18,6 +18,7 @@ list(TRANSFORM BINDING_SOURCES PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") set(DIST_SRC ${CMAKE_SOURCE_DIR}/src/common/distributed) set(DIST_SOURCES + ${DIST_SRC}/dist_chip_bootstrap_channel.cpp ${DIST_SRC}/dist_types.cpp ${DIST_SRC}/dist_tensormap.cpp ${DIST_SRC}/dist_ring.cpp diff --git a/python/bindings/dist_worker_bind.h b/python/bindings/dist_worker_bind.h index 0bf961343..66e9db7c8 100644 --- a/python/bindings/dist_worker_bind.h +++ b/python/bindings/dist_worker_bind.h @@ -26,6 +26,7 @@ #include #include "dist_chip_process.h" +#include 
"dist_chip_bootstrap_channel.h" #include "dist_orchestrator.h" #include "dist_sub_worker.h" #include "dist_types.h" @@ -50,6 +51,19 @@ inline void bind_dist_worker(nb::module_ &m) { .value("COMPLETED", TaskState::COMPLETED) .value("CONSUMED", TaskState::CONSUMED); + nb::class_(m, "DistTensorKey") + .def(nb::init<>()) + .def( + "__init__", + [](DistTensorKey *self, uint64_t base_ptr, int32_t worker_index) { + new (self) DistTensorKey{worker_index, base_ptr}; + }, + nb::arg("base_ptr"), + nb::arg("worker_index") = -1 + ) + .def_rw("worker_index", &DistTensorKey::worker_index) + .def_rw("base_ptr", &DistTensorKey::base_ptr); + // --- WorkerPayload --- nb::class_(m, "WorkerPayload") .def(nb::init<>()) @@ -85,12 +99,35 @@ inline void bind_dist_worker(nb::module_ &m) { .def(nb::init<>()) .def( "__init__", - [](DistInputSpec *self, uint64_t base_ptr) { - new (self) DistInputSpec{base_ptr}; + [](DistInputSpec *self, uint64_t base_ptr, int32_t worker_index) { + new (self) DistInputSpec{DistTensorKey{worker_index, base_ptr}}; }, - nb::arg("base_ptr") + nb::arg("base_ptr"), + nb::arg("worker_index") = -1 + ) + .def_rw("key", &DistInputSpec::key) + .def_prop_rw( + "base_ptr", + [](const DistInputSpec &self) { + return self.key.base_ptr; + }, + [](DistInputSpec &self, uint64_t base_ptr) { + self.key.base_ptr = base_ptr; + } ) - .def_rw("base_ptr", &DistInputSpec::base_ptr); + .def_prop_rw( + "worker_index", + [](const DistInputSpec &self) { + return self.key.worker_index; + }, + [](DistInputSpec &self, int32_t worker_index) { + self.key.worker_index = worker_index; + } + ); + + nb::enum_(m, "DistOutputOwnership") + .value("ALLOCATED", DistOutputOwnership::ALLOCATED) + .value("EXTERNAL", DistOutputOwnership::EXTERNAL); // --- DistOutputSpec --- nb::class_(m, "DistOutputSpec") @@ -98,11 +135,54 @@ inline void bind_dist_worker(nb::module_ &m) { .def( "__init__", [](DistOutputSpec *self, size_t size) { - new (self) DistOutputSpec{size}; + new (self) 
DistOutputSpec{DistOutputOwnership::ALLOCATED, size, DistTensorKey{}, nullptr}; }, nb::arg("size") ) - .def_rw("size", &DistOutputSpec::size); + .def_static( + "external", + [](uint64_t ptr, size_t size, int32_t worker_index) { + DistOutputSpec spec; + spec.ownership = DistOutputOwnership::EXTERNAL; + spec.size = size; + spec.key = DistTensorKey{worker_index, ptr}; + spec.external_ptr = reinterpret_cast(ptr); + return spec; + }, + nb::arg("ptr"), + nb::arg("size"), + nb::arg("worker_index") = -1 + ) + .def_rw("ownership", &DistOutputSpec::ownership) + .def_rw("size", &DistOutputSpec::size) + .def_rw("key", &DistOutputSpec::key) + .def_prop_rw( + "base_ptr", + [](const DistOutputSpec &self) { + return self.key.base_ptr; + }, + [](DistOutputSpec &self, uint64_t base_ptr) { + self.key.base_ptr = base_ptr; + } + ) + .def_prop_rw( + "worker_index", + [](const DistOutputSpec &self) { + return self.key.worker_index; + }, + [](DistOutputSpec &self, int32_t worker_index) { + self.key.worker_index = worker_index; + } + ) + .def_prop_rw( + "ptr", + [](const DistOutputSpec &self) { + return reinterpret_cast(self.external_ptr); + }, + [](DistOutputSpec &self, uint64_t ptr) { + self.external_ptr = reinterpret_cast(ptr); + } + ); // --- DistSubmitOutput --- nb::class_(m, "DistSubmitOutput") @@ -144,6 +224,36 @@ inline void bind_dist_worker(nb::module_ &m) { // Python can use this constant to allocate mailboxes of the right size. 
m.attr("DIST_SUB_MAILBOX_SIZE") = static_cast(DIST_SUB_MAILBOX_SIZE); + nb::enum_(m, "ChipBootstrapMailboxState") + .value("IDLE", ChipBootstrapMailboxState::IDLE) + .value("SUCCESS", ChipBootstrapMailboxState::SUCCESS) + .value("ERROR", ChipBootstrapMailboxState::ERROR); + + nb::class_(m, "DistChipBootstrapChannel") + .def( + "__init__", + [](DistChipBootstrapChannel *self, uint64_t mailbox_ptr, size_t max_buffer_count) { + new (self) DistChipBootstrapChannel(reinterpret_cast(mailbox_ptr), max_buffer_count); + }, + nb::arg("mailbox_ptr"), nb::arg("max_buffer_count"), + "Wrap a chip-bootstrap mailbox pointer. max_buffer_count must match the bootstrap buffer list length." + ) + .def("reset", &DistChipBootstrapChannel::reset) + .def( + "write_success", &DistChipBootstrapChannel::write_success, nb::arg("device_ctx"), nb::arg("local_window_base"), + nb::arg("actual_window_size"), nb::arg("buffer_ptrs") + ) + .def("write_error", &DistChipBootstrapChannel::write_error, nb::arg("error_code"), nb::arg("message")) + .def_prop_ro("state", &DistChipBootstrapChannel::state) + .def_prop_ro("error_code", &DistChipBootstrapChannel::error_code) + .def_prop_ro("device_ctx", &DistChipBootstrapChannel::device_ctx) + .def_prop_ro("local_window_base", &DistChipBootstrapChannel::local_window_base) + .def_prop_ro("actual_window_size", &DistChipBootstrapChannel::actual_window_size) + .def_prop_ro("buffer_ptrs", &DistChipBootstrapChannel::buffer_ptrs) + .def_prop_ro("error_message", &DistChipBootstrapChannel::error_message); + + m.attr("DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE") = static_cast(DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE); + // --- DistChipProcess --- // Fork + host_runtime.so init are managed from Python (Worker.__init__). // This class handles dispatch/poll via the chip mailbox (4096 bytes). 
diff --git a/python/bindings/task_interface.cpp b/python/bindings/task_interface.cpp index 50cb2bb07..f9aa76c7f 100644 --- a/python/bindings/task_interface.cpp +++ b/python/bindings/task_interface.cpp @@ -74,6 +74,11 @@ NB_MODULE(_task_interface, m) { // --- Constants --- m.attr("CONTINUOUS_TENSOR_MAX_DIMS") = CONTINUOUS_TENSOR_MAX_DIMS; + m.attr("CHIP_STORAGE_TASK_ARGS_SIZE") = static_cast(sizeof(ChipStorageTaskArgs)); + + nb::enum_(m, "TensorStorageType") + .value("HOST", TensorStorageType::HOST) + .value("DEVICE", TensorStorageType::DEVICE); // --- ContinuousTensor --- nb::class_(m, "ContinuousTensor") @@ -81,7 +86,7 @@ NB_MODULE(_task_interface, m) { .def_static( "make", - [](uint64_t data, nb::tuple shapes, DataType dtype) -> ContinuousTensor { + [](uint64_t data, nb::tuple shapes, DataType dtype, bool device_resident) -> ContinuousTensor { size_t n = nb::len(shapes); if (n > CONTINUOUS_TENSOR_MAX_DIMS) throw std::invalid_argument("shapes length exceeds CONTINUOUS_TENSOR_MAX_DIMS"); @@ -89,11 +94,12 @@ NB_MODULE(_task_interface, m) { arg.data = data; arg.dtype = dtype; arg.ndims = static_cast(n); + arg.storage = device_resident ? TensorStorageType::DEVICE : TensorStorageType::HOST; for (size_t i = 0; i < n; ++i) arg.shapes[i] = nb::cast(shapes[i]); return arg; }, - nb::arg("data"), nb::arg("shapes"), nb::arg("dtype"), + nb::arg("data"), nb::arg("shapes"), nb::arg("dtype"), nb::arg("device_resident") = false, "Create a ContinuousTensor from a data pointer, shape tuple, and dtype." ) @@ -150,6 +156,16 @@ NB_MODULE(_task_interface, m) { } ) + .def_prop_rw( + "device_resident", + [](const ContinuousTensor &self) -> bool { + return self.is_device_resident(); + }, + [](ContinuousTensor &self, bool device_resident) { + self.storage = device_resident ? 
TensorStorageType::DEVICE : TensorStorageType::HOST; + } + ) + .def( "nbytes", [](const ContinuousTensor &self) -> uint64_t { @@ -165,7 +181,8 @@ NB_MODULE(_task_interface, m) { if (i) os << ", "; os << self.shapes[i]; } - os << "), dtype=" << get_dtype_name(self.dtype) << ")"; + os << "), dtype=" << get_dtype_name(self.dtype) + << ", device_resident=" << (self.is_device_resident() ? "True" : "False") << ")"; return os.str(); }); @@ -611,7 +628,17 @@ NB_MODULE(_task_interface, m) { ) .def_prop_ro("device_id", &ChipWorker::device_id) .def_prop_ro("initialized", &ChipWorker::initialized) - .def_prop_ro("device_set", &ChipWorker::device_set); + .def_prop_ro("device_set", &ChipWorker::device_set) + .def("device_malloc", &ChipWorker::device_malloc, nb::arg("size")) + .def("device_free", &ChipWorker::device_free, nb::arg("dev_ptr")) + .def("copy_to_device", &ChipWorker::copy_to_device, nb::arg("dev_ptr"), nb::arg("host_ptr"), nb::arg("size")) + .def("copy_from_device", &ChipWorker::copy_from_device, nb::arg("host_ptr"), nb::arg("dev_ptr"), nb::arg("size")) + .def("comm_init", &ChipWorker::comm_init, nb::arg("rank"), nb::arg("nranks"), nb::arg("device_id"), nb::arg("rootinfo_path")) + .def("comm_alloc_windows", &ChipWorker::comm_alloc_windows, nb::arg("comm_handle"), nb::arg("win_size")) + .def("comm_get_local_window_base", &ChipWorker::comm_get_local_window_base, nb::arg("comm_handle")) + .def("comm_get_window_size", &ChipWorker::comm_get_window_size, nb::arg("comm_handle")) + .def("comm_barrier", &ChipWorker::comm_barrier, nb::arg("comm_handle")) + .def("comm_destroy", &ChipWorker::comm_destroy, nb::arg("comm_handle")); bind_dist_worker(m); } diff --git a/python/task_interface.py b/python/task_interface.py index 1dfd327e5..0b3cf59a9 100644 --- a/python/task_interface.py +++ b/python/task_interface.py @@ -16,27 +16,39 @@ from task_interface import DataType, ContinuousTensor, ChipStorageTaskArgs, make_tensor_arg """ +from dataclasses import dataclass +from 
multiprocessing.shared_memory import SharedMemory + +# CHIP_STORAGE_TASK_ARGS_SIZE is the authoritative C++ sizeof(ChipStorageTaskArgs), +# exported for mailbox memcpy paths in the L3 chip-process worker flow. from _task_interface import ( # pyright: ignore[reportMissingImports] + CHIP_STORAGE_TASK_ARGS_SIZE, CONTINUOUS_TENSOR_MAX_DIMS, + DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE, DIST_CHIP_MAILBOX_SIZE, DIST_SUB_MAILBOX_SIZE, ArgDirection, CallConfig, ChipCallable, ChipStorageTaskArgs, + ChipBootstrapMailboxState, ContinuousTensor, CoreCallable, DataType, + DistChipBootstrapChannel, DistChipProcess, DistInputSpec, + DistOutputOwnership, DistOutputSpec, DistSubmitOutput, DistSubmitResult, + DistTensorKey, DistSubWorker, DistWorker, DynamicTaskArgs, TaggedTaskArgs, TaskState, + TensorStorageType, TensorArgType, WorkerPayload, WorkerType, @@ -50,10 +62,12 @@ "DataType", "get_element_size", "get_dtype_name", + "CHIP_STORAGE_TASK_ARGS_SIZE", "CONTINUOUS_TENSOR_MAX_DIMS", "ContinuousTensor", "ChipStorageTaskArgs", "TensorArgType", + "TensorStorageType", "DynamicTaskArgs", "TaggedTaskArgs", "ArgDirection", @@ -64,23 +78,41 @@ "arg_direction_name", "torch_dtype_to_datatype", "make_tensor_arg", + "make_device_tensor_arg", "scalar_to_uint64", # Distributed runtime "WorkerType", "TaskState", "WorkerPayload", + "DistTensorKey", "DistInputSpec", + "DistOutputOwnership", "DistOutputSpec", "DistSubmitOutput", "DistSubmitResult", "DistSubWorker", + "DistChipBootstrapChannel", "DistChipProcess", "DistWorker", + "ChipBootstrapMailboxState", "DIST_SUB_MAILBOX_SIZE", "DIST_CHIP_MAILBOX_SIZE", + "DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE", + "ChipBootstrapResult", ] +@dataclass +class ChipBootstrapResult: + """Parent-visible reply from per-chip bootstrap.""" + + comm_handle: int + device_ctx: int + local_window_base: int + actual_window_size: int + buffer_ptrs: list[int] + + # Lazy-loaded torch dtype → DataType map (avoids importing torch at module load) _TORCH_DTYPE_MAP = None @@ -126,6 +158,23 @@ def 
make_tensor_arg(tensor) -> ContinuousTensor: return ContinuousTensor.make(tensor.data_ptr(), shapes, dt) +def make_device_tensor_arg(ptr: int, shape, dtype) -> ContinuousTensor: + """Create a device-resident ``ContinuousTensor`` from an external device pointer. + + Args: + ptr: Device or window pointer already valid in the target chip process. + shape: Iterable of tensor dimensions. + dtype: Either ``DataType`` or a ``torch.dtype`` supported by ``make_tensor_arg``. + """ + _ensure_torch_map() + if not isinstance(dtype, DataType): + dtype = _TORCH_DTYPE_MAP.get(dtype) # pyright: ignore[reportOptionalMemberAccess] + if dtype is None: + raise ValueError(f"Unsupported dtype for ContinuousTensor: {dtype}") + shapes = tuple(int(s) for s in shape) + return ContinuousTensor.make(int(ptr), shapes, dtype, device_resident=True) + + def scalar_to_uint64(value) -> int: """Convert a scalar value to ``uint64``. @@ -221,6 +270,225 @@ def run(self, callable, args, config=None, **kwargs): setattr(config, k, v) self._impl.run(callable, args, config) + def run_raw(self, callable, args, *, block_dim=1, aicpu_thread_num=3, enable_profiling=False): + """Run a callable using raw pointer arguments.""" + self._impl.run_raw(int(callable), int(args), int(block_dim), int(aicpu_thread_num), bool(enable_profiling)) + + def device_malloc(self, size): + """Allocate device memory in the current device context.""" + return int(self._impl.device_malloc(int(size))) + + def device_free(self, dev_ptr): + """Free device memory allocated by ``device_malloc()``.""" + self._impl.device_free(int(dev_ptr)) + + def copy_to_device(self, dev_ptr, host_ptr, size): + """Copy bytes from a host pointer into a device pointer.""" + self._impl.copy_to_device(int(dev_ptr), int(host_ptr), int(size)) + + def copy_from_device(self, host_ptr, dev_ptr, size): + """Copy bytes from a device pointer into a host pointer.""" + self._impl.copy_from_device(int(host_ptr), int(dev_ptr), int(size)) + + def comm_init(self, rank, 
nranks, device_id, rootinfo_path): + """Create a communicator in the current chip child.""" + return int(self._impl.comm_init(int(rank), int(nranks), int(device_id), str(rootinfo_path))) + + def comm_alloc_windows(self, comm_handle, win_size): + """Allocate the communicator-owned window and return the device context.""" + return int(self._impl.comm_alloc_windows(int(comm_handle), int(win_size))) + + def comm_get_local_window_base(self, comm_handle): + """Return the local base address of the communicator window.""" + return int(self._impl.comm_get_local_window_base(int(comm_handle))) + + def comm_get_window_size(self, comm_handle): + """Return the actual communicator window size.""" + return int(self._impl.comm_get_window_size(int(comm_handle))) + + def comm_destroy(self, comm_handle): + """Destroy a communicator previously created by ``comm_init()``.""" + self._impl.comm_destroy(int(comm_handle)) + + def bootstrap( + self, + device_id, + *, + comm_rank=-1, + comm_nranks=0, + rootinfo_path="", + window_size=0, + win_sync_prefix=0, + buffer_sizes, + buffer_placements, + input_blobs, + ): + """Bootstrap per-chip runtime state before the first task submission. + + This optional handshake extends plain ``init()`` with communicator setup, + window/device buffer allocation, initial H2D staging, and a bootstrap + reply that the parent process can use to build task arguments. 
+ """ + buffer_sizes = [int(size) for size in buffer_sizes] + buffer_placements = [str(placement) for placement in buffer_placements] + input_blobs = list(input_blobs) + + if len(buffer_sizes) != len(buffer_placements): + raise ValueError("buffer_sizes and buffer_placements must have the same length") + if len(buffer_sizes) != len(input_blobs): + raise ValueError("input_blobs length must match buffer_sizes") + + enable_comm = int(comm_rank) >= 0 + comm_handle = 0 + device_ctx = 0 + local_window_base = 0 + actual_window_size = 0 + owned_device_ptrs: list[int] = [] + buffer_ptrs: list[int] = [] + + try: + if enable_comm: + if int(comm_nranks) <= 0: + raise ValueError("comm_nranks must be positive when comm bootstrap is enabled") + if not str(rootinfo_path): + raise ValueError("rootinfo_path is required when comm bootstrap is enabled") + comm_handle = self.comm_init(comm_rank, comm_nranks, device_id, rootinfo_path) + + if not self.device_set: + self.set_device(int(device_id)) + elif self.device_id != int(device_id): + raise ValueError("ChipWorker already bound to a different device") + + if enable_comm: + device_ctx = self.comm_alloc_windows(comm_handle, window_size) + local_window_base = self.comm_get_local_window_base(comm_handle) + actual_window_size = self.comm_get_window_size(comm_handle) + + win_offset = int(win_sync_prefix) + for size, placement, blob in zip(buffer_sizes, buffer_placements, input_blobs, strict=True): + ptr = 0 + if placement == "window": + if not enable_comm: + raise ValueError("window placement requires comm bootstrap") + ptr = local_window_base + win_offset + win_offset += size + elif placement == "device": + ptr = self.device_malloc(size) + owned_device_ptrs.append(ptr) + else: + raise ValueError(f"Unsupported buffer placement: {placement}") + + buffer_ptrs.append(ptr) + + if blob is not None: + if not isinstance(blob, bytes): + raise ValueError("input blobs must be bytes or None") + if len(blob) != size: + raise ValueError("input blob size 
must match buffer size") + if size > 0: + import ctypes as _ct + + host_buf = _ct.create_string_buffer(blob, size) + self.copy_to_device(ptr, _ct.addressof(host_buf), size) + + if enable_comm: + self.comm_barrier(comm_handle) + except Exception: + for ptr in owned_device_ptrs: + try: + self.device_free(ptr) + except Exception: + pass + if comm_handle != 0: + try: + self.comm_destroy(comm_handle) + except Exception: + pass + raise + + return { + "comm_handle": comm_handle, + "device_ctx": device_ctx, + "local_window_base": local_window_base, + "actual_window_size": actual_window_size, + "buffer_ptrs": buffer_ptrs, + } + + def shutdown_bootstrap(self, *, comm_handle=0, buffer_ptrs, buffer_placements): + """Release per-chip runtime state previously created by ``bootstrap()``.""" + buffer_ptrs = [int(ptr) for ptr in buffer_ptrs] + buffer_placements = [str(placement) for placement in buffer_placements] + if len(buffer_ptrs) != len(buffer_placements): + raise ValueError("buffer_ptrs and buffer_placements must have the same length") + for ptr, placement in zip(buffer_ptrs, buffer_placements, strict=True): + if placement == "device" and ptr != 0: + self.device_free(ptr) + if int(comm_handle) != 0: + self.comm_destroy(int(comm_handle)) + + @staticmethod + def _read_bootstrap_input_bytes(shm_name: str, size: int) -> bytes: + shm = SharedMemory(name=shm_name) + try: + if size == 0: + return b"" + assert shm.buf is not None + return bytes(shm.buf[:size]) + finally: + shm.close() + + def bootstrap_context(self, device_id, chip_bootstrap_config) -> ChipBootstrapResult: + """Bootstrap a chip child from a typed bootstrap config.""" + comm_cfg = getattr(chip_bootstrap_config, "comm", None) + input_blobs = [] + for buf in chip_bootstrap_config.buffers: + if buf.load_from_host: + staged = chip_bootstrap_config.input_staging(buf.name) + input_blobs.append(self._read_bootstrap_input_bytes(staged.shm_name, staged.size)) + else: + input_blobs.append(None) + reply = self.bootstrap( + 
device_id, + comm_rank=comm_cfg.rank if comm_cfg is not None else -1, + comm_nranks=comm_cfg.nranks if comm_cfg is not None else 0, + rootinfo_path=comm_cfg.rootinfo_path if comm_cfg is not None else "", + window_size=comm_cfg.window_size if comm_cfg is not None else 0, + win_sync_prefix=comm_cfg.win_sync_prefix if comm_cfg is not None else 0, + buffer_sizes=[buf.nbytes for buf in chip_bootstrap_config.buffers], + buffer_placements=[buf.placement for buf in chip_bootstrap_config.buffers], + input_blobs=input_blobs, + ) + return ChipBootstrapResult( + comm_handle=int(reply["comm_handle"]), + device_ctx=int(reply["device_ctx"]), + local_window_base=int(reply["local_window_base"]), + actual_window_size=int(reply["actual_window_size"]), + buffer_ptrs=[int(ptr) for ptr in reply["buffer_ptrs"]], + ) + + def shutdown_bootstrap_context(self, chip_bootstrap_config, *, comm_handle=0, buffer_ptrs): + """Release resources created by ``bootstrap_context``.""" + self.shutdown_bootstrap( + comm_handle=comm_handle, + buffer_ptrs=buffer_ptrs, + buffer_placements=[buf.placement for buf in chip_bootstrap_config.buffers], + ) + + def copy_device_to_bytes(self, dev_ptr, size) -> bytes: + """Copy a device buffer into a Python bytes object.""" + size = int(size) + if size == 0: + return b"" + import ctypes as _ct + + host_buf = _ct.create_string_buffer(size) + self._impl.copy_from_device(_ct.addressof(host_buf), int(dev_ptr), size) + return host_buf.raw[:size] + + def comm_barrier(self, comm_handle): + """Synchronize all ranks in the current communicator.""" + self._impl.comm_barrier(int(comm_handle)) + @property def device_id(self): return self._impl.device_id diff --git a/python/worker.py b/python/worker.py index 1d52fb060..a483caa9a 100644 --- a/python/worker.py +++ b/python/worker.py @@ -28,12 +28,19 @@ def my_orch(w, args): w.run(Task(orch=my_orch, args=my_args)) w.close() + + # L3 chip bootstrap extension: keep run/submit standard, pass optional + # per-chip bootstrap metadata 
through Worker.init(). + w = Worker(level=3, device_ids=[8, 9], + chip_bootstrap_configs=[...]) """ import ctypes import os +import signal import struct import sys +import time from dataclasses import dataclass, field from multiprocessing.shared_memory import SharedMemory from pathlib import Path @@ -45,9 +52,13 @@ def my_orch(w, args): sys.path.insert(0, _SCRIPTS) from task_interface import ( # noqa: E402 + DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE, DIST_CHIP_MAILBOX_SIZE, DIST_SUB_MAILBOX_SIZE, + ChipBootstrapMailboxState, ChipWorker, + DataType, + DistChipBootstrapChannel, DistChipProcess, DistInputSpec, DistOutputSpec, @@ -55,7 +66,7 @@ def my_orch(w, args): DistWorker, WorkerPayload, WorkerType, - _ChipWorker, + make_device_tensor_arg, ) # --------------------------------------------------------------------------- @@ -75,6 +86,153 @@ class Task: args: Any = field(default=None) +@dataclass +class ChipBufferSpec: + """Per-chip buffer contract used by the optional L3 chip bootstrap path.""" + + name: str + dtype: str + count: int + placement: str + nbytes: int + load_from_host: bool = False + store_to_host: bool = False + + def make_tensor_arg(self, ptr: int) -> Any: + return make_device_tensor_arg(ptr, (self.count,), _buffer_dtype_to_task_dtype(self.dtype)) + + +@dataclass +class HostBufferStaging: + """Named shared-memory staging region prepared by the parent process.""" + + name: str + shm_name: str + size: int + + +@dataclass +class ChipCommBootstrapConfig: + """Optional communicator bootstrap for a chip child.""" + + rank: int + nranks: int + rootinfo_path: str + window_size: int + win_sync_prefix: int = 0 + + +@dataclass +class ChipBootstrapConfig: + """Worker-side chip child bootstrap input.""" + + comm: Optional[ChipCommBootstrapConfig] = None + buffers: list[ChipBufferSpec] = field(default_factory=list) + host_inputs: list[HostBufferStaging] = field(default_factory=list) + host_outputs: list[HostBufferStaging] = field(default_factory=list) + + def 
input_staging(self, name: str) -> HostBufferStaging: + for staging in self.host_inputs: + if staging.name == name: + return staging + raise KeyError(f"Missing staged host input for chip buffer '{name}'") + + def output_staging(self, name: str) -> HostBufferStaging: + for staging in self.host_outputs: + if staging.name == name: + return staging + raise KeyError(f"Missing staged host output for chip buffer '{name}'") + + +@dataclass +class ChipBootstrapReply: + """Child -> parent bootstrap reply carried over the bootstrap mailbox.""" + + device_ctx: int + local_window_base: int + actual_window_size: int + buffer_ptrs: list[int] + + +@dataclass +class ChipBootstrapState: + """Child-local chip bootstrap state kept alive for the chip process lifetime.""" + + bootstrap_config: ChipBootstrapConfig + comm_handle: Optional[int] + bootstrap_reply: ChipBootstrapReply + + @property + def buffers(self) -> list[ChipBufferSpec]: + return self.bootstrap_config.buffers + + @property + def buffer_ptrs(self) -> dict[str, int]: + return { + buf.name: int(ptr) + for buf, ptr in zip(self.bootstrap_config.buffers, self.bootstrap_reply.buffer_ptrs, strict=True) + } + + @property + def device_ctx(self) -> int: + return self.bootstrap_reply.device_ctx + + @property + def local_window_base(self) -> int: + return self.bootstrap_reply.local_window_base + + @property + def actual_window_size(self) -> int: + return self.bootstrap_reply.actual_window_size + + def input_staging(self, name: str) -> HostBufferStaging: + return self.bootstrap_config.input_staging(name) + + def output_staging(self, name: str) -> HostBufferStaging: + return self.bootstrap_config.output_staging(name) + + +@dataclass +class ChipContext: + """Parent-visible chip bootstrap result used to build task args.""" + + bootstrap_config: ChipBootstrapConfig + device_id: int + bootstrap_reply: ChipBootstrapReply + buffer_tensors: dict[str, Any] + + @property + def rank(self) -> int: + if self.bootstrap_config.comm is None: + raise 
AttributeError("ChipContext.rank is only available when comm bootstrap is configured") + return self.bootstrap_config.comm.rank + + @property + def nranks(self) -> int: + if self.bootstrap_config.comm is None: + raise AttributeError("ChipContext.nranks is only available when comm bootstrap is configured") + return self.bootstrap_config.comm.nranks + + @property + def device_ctx(self) -> int: + return self.bootstrap_reply.device_ctx + + @property + def local_window_base(self) -> int: + return self.bootstrap_reply.local_window_base + + @property + def actual_window_size(self) -> int: + return self.bootstrap_reply.actual_window_size + + @property + def buffer_ptrs(self) -> dict[str, int]: + return { + buf.name: int(ptr) + for buf, ptr in zip(self.bootstrap_config.buffers, self.bootstrap_reply.buffer_ptrs, strict=True) + } + + # --------------------------------------------------------------------------- # Mailbox helpers (shared with host_worker) # --------------------------------------------------------------------------- @@ -124,6 +282,83 @@ def _sub_worker_loop(buf, registry: dict) -> None: _CHIP_OFF_ARGS = 64 +def _write_shared_memory_bytes(shm_name: str, data: bytes, expected_size: int) -> None: + if len(data) != expected_size: + raise ValueError(f"shared-memory staging size mismatch: got {len(data)}, expected {expected_size}") + shm = SharedMemory(name=shm_name) + try: + assert shm.buf is not None + if expected_size: + shm.buf[:expected_size] = data + finally: + shm.close() + + +_DIST_DTYPE_MAP = { + "float32": DataType.FLOAT32, + "float16": DataType.FLOAT16, + "bfloat16": DataType.BFLOAT16, + "int64": DataType.INT64, + "int32": DataType.INT32, + "int16": DataType.INT16, + "int8": DataType.INT8, + "uint8": DataType.UINT8, +} + + +def _buffer_dtype_to_task_dtype(dtype: str) -> DataType: + key = str(dtype).lower() + if key not in _DIST_DTYPE_MAP: + raise ValueError(f"Unsupported chip buffer dtype: {dtype}") + return _DIST_DTYPE_MAP[key] + + +def 
_materialize_buffer_tensors( + chip_bootstrap_config: ChipBootstrapConfig, buffer_ptrs: list[int] +) -> dict[str, Any]: + buffer_tensors: dict[str, Any] = {} + for buf_cfg, ptr in zip(chip_bootstrap_config.buffers, buffer_ptrs, strict=True): + name = buf_cfg.name + buffer_tensors[name] = buf_cfg.make_tensor_arg(ptr) + return buffer_tensors + + +def _enrich_chip_context( + chip_bootstrap_config: ChipBootstrapConfig, device_id: int, reply: ChipBootstrapReply +) -> ChipContext: + return ChipContext( + bootstrap_config=chip_bootstrap_config, + device_id=device_id, + bootstrap_reply=reply, + buffer_tensors=_materialize_buffer_tensors(chip_bootstrap_config, reply.buffer_ptrs), + ) + + +def _write_chip_bootstrap_reply(bootstrap_channel: DistChipBootstrapChannel, reply: ChipBootstrapReply) -> None: + bootstrap_channel.write_success( + reply.device_ctx, + reply.local_window_base, + reply.actual_window_size, + reply.buffer_ptrs, + ) + + +def _run_chip_bootstrap( + cw: ChipWorker, device_id: int, chip_bootstrap_config: ChipBootstrapConfig +) -> ChipBootstrapState: + bootstrap = cw.bootstrap_context(device_id, chip_bootstrap_config) + return ChipBootstrapState( + bootstrap_config=chip_bootstrap_config, + comm_handle=bootstrap.comm_handle, + bootstrap_reply=ChipBootstrapReply( + device_ctx=bootstrap.device_ctx, + local_window_base=bootstrap.local_window_base, + actual_window_size=bootstrap.actual_window_size, + buffer_ptrs=list(bootstrap.buffer_ptrs), + ), + ) + + def _chip_process_loop( buf: memoryview, host_lib_path: str, @@ -132,15 +367,38 @@ def _chip_process_loop( aicore_path: str, sim_context_lib_path: str = "", args_size: int = 1712, + bootstrap_mailbox_ptr: Optional[int] = None, + bootstrap_buffer_count: int = 0, + chip_bootstrap_config: Optional[ChipBootstrapConfig] = None, ) -> None: """Runs in forked child process. 
Loads host_runtime.so in own address space.""" import traceback as _tb # noqa: PLC0415 + cw: Optional[ChipWorker] = None + chip_context: Optional[ChipBootstrapState] = None + bootstrap_channel = ( + DistChipBootstrapChannel(bootstrap_mailbox_ptr, bootstrap_buffer_count) + if bootstrap_mailbox_ptr is not None + else None + ) try: - cw = _ChipWorker() + cw = ChipWorker() cw.init(host_lib_path, aicpu_path, aicore_path, sim_context_lib_path) - cw.set_device(device_id) + if chip_bootstrap_config is not None: + chip_context = _run_chip_bootstrap(cw, device_id, chip_bootstrap_config) + if bootstrap_channel is not None: + _write_chip_bootstrap_reply(bootstrap_channel, chip_context.bootstrap_reply) + elif bootstrap_channel is not None: + cw.set_device(device_id) + bootstrap_channel.write_success(0, 0, 0, []) + else: + cw.set_device(device_id) except Exception: + if bootstrap_channel is not None: + try: + bootstrap_channel.write_error(1, _tb.format_exc()) + except Exception: # noqa: BLE001 + pass _tb.print_exc() struct.pack_into("i", buf, _CHIP_OFF_ERROR, 99) return @@ -165,12 +423,37 @@ def _chip_process_loop( error = 0 try: - cw.run_raw(callable_ptr, heap_args_ptr, block_dim, aicpu_tn, bool(profiling)) + cw.run_raw( + callable_ptr, + heap_args_ptr, + block_dim=block_dim, + aicpu_thread_num=aicpu_tn, + enable_profiling=bool(profiling), + ) + if chip_context is not None: + if chip_context.comm_handle is not None: + cw.comm_barrier(chip_context.comm_handle) + for buf_cfg in chip_context.buffers: + if not buf_cfg.store_to_host: + continue + ptr = chip_context.buffer_ptrs[buf_cfg.name] + staged = chip_context.output_staging(buf_cfg.name) + _write_shared_memory_bytes( + staged.shm_name, + cw.copy_device_to_bytes(ptr, buf_cfg.nbytes), + staged.size, + ) except Exception: # noqa: BLE001 error = 1 struct.pack_into("i", buf, _CHIP_OFF_ERROR, error) struct.pack_into("i", buf, _CHIP_OFF_STATE, _TASK_DONE) elif state == _SHUTDOWN: + if chip_context is not None: + 
cw.shutdown_bootstrap_context( + chip_context.bootstrap_config, + comm_handle=chip_context.comm_handle or 0, + buffer_ptrs=[chip_context.buffer_ptrs[buf.name] for buf in chip_context.buffers], + ) cw.finalize() break @@ -214,11 +497,25 @@ def __init__(self, level: int, **config) -> None: # Level-3 internals self._dist_worker: Optional[DistWorker] = None self._dist_chip_procs: list[DistChipProcess] = [] + self._chip_contexts: list[ChipContext] = [] + self._chip_bootstrap_shms: list[SharedMemory] = [] self._chip_shms: list[SharedMemory] = [] self._chip_pids: list[int] = [] self._dist_sub_workers: list[DistSubWorker] = [] - self._shms: list[SharedMemory] = [] - self._pids: list[int] = [] + self._subworker_shms: list[SharedMemory] = [] + self._subworker_pids: list[int] = [] + + def _resolve_chip_bootstrap_configs(self, device_ids: list[int]) -> Optional[list[ChipBootstrapConfig]]: + configs = self._config.get("chip_bootstrap_configs") + if configs is None: + return None + if not isinstance(configs, list): + raise TypeError("chip_bootstrap_configs must be a list of ChipBootstrapConfig") + if any(not isinstance(cfg, ChipBootstrapConfig) for cfg in configs): + raise TypeError("chip_bootstrap_configs items must be ChipBootstrapConfig") + if len(configs) != len(device_ids): + raise ValueError("chip bootstrap config length must match device_ids") + return configs # ------------------------------------------------------------------ # Callable registration (before init) @@ -240,92 +537,212 @@ def init(self) -> None: if self._initialized: raise RuntimeError("Worker already initialized") - if self.level == 2: - self._init_level2() - elif self.level == 3: - self._init_level3() - else: - raise ValueError(f"Worker: level {self.level} not yet supported") + try: + if self.level == 2: + self._init_level2() + elif self.level == 3: + self._init_level3() + else: + raise ValueError(f"Worker: level {self.level} not yet supported") + except Exception: + if self.level == 3: + 
self._cleanup_level3_resources() + raise self._initialized = True + def _wait_for_pid_exit(self, pid: int, timeout_s: float = 2.0) -> None: + deadline = time.monotonic() + timeout_s + while time.monotonic() < deadline: + waited_pid, _ = os.waitpid(pid, os.WNOHANG) + if waited_pid == pid: + return + time.sleep(0.05) + + try: + os.kill(pid, signal.SIGTERM) + except ProcessLookupError: + return + os.waitpid(pid, 0) + + def _cleanup_level3_resources(self) -> None: + if self._dist_worker: + self._dist_worker.close() + self._dist_worker = None + + for sw in self._dist_sub_workers: + sw.shutdown() + for shm in self._subworker_shms: + buf = shm.buf + if buf is not None: + struct.pack_into("i", buf, _OFF_STATE, _SHUTDOWN) + for pid in self._subworker_pids: + self._wait_for_pid_exit(pid) + for shm in self._subworker_shms: + shm.close() + shm.unlink() + + for cp in self._dist_chip_procs: + cp.shutdown() + for shm in self._chip_shms: + buf = shm.buf + if buf is not None: + struct.pack_into("i", buf, _CHIP_OFF_STATE, _SHUTDOWN) + for pid in self._chip_pids: + self._wait_for_pid_exit(pid) + for shm in self._chip_shms: + shm.close() + shm.unlink() + for shm in self._chip_bootstrap_shms: + shm.close() + shm.unlink() + + self._subworker_shms.clear() + self._subworker_pids.clear() + self._chip_bootstrap_shms.clear() + self._chip_shms.clear() + self._chip_pids.clear() + self._dist_sub_workers.clear() + self._dist_chip_procs.clear() + self._chip_contexts.clear() + def _init_level2(self) -> None: + device_id = self._config.get("device_id", 0) + host_lib_path, aicpu_path, aicore_path, sim_ctx_path = self._resolve_runtime_binaries() + + self._chip_worker = ChipWorker() + self._chip_worker.init( + host_lib_path, + aicpu_path, + aicore_path, + sim_ctx_path, + ) + self._chip_worker.set_device(device_id) + + def _resolve_runtime_binaries(self) -> tuple[str, str, str, str]: + explicit = ( + self._config.get("host_path"), + self._config.get("aicpu_path"), + self._config.get("aicore_path"), + ) 
+ if all(explicit): + return ( + str(explicit[0]), + str(explicit[1]), + str(explicit[2]), + str(self._config.get("sim_context_path", "")), + ) + from runtime_builder import RuntimeBuilder # noqa: PLC0415 platform = self._config["platform"] runtime = self._config["runtime"] - device_id = self._config.get("device_id", 0) - builder = RuntimeBuilder(platform) binaries = builder.get_binaries(runtime, build=False) - - self._chip_worker = ChipWorker() - self._chip_worker.init( + return ( str(binaries.host_path), str(binaries.aicpu_path), str(binaries.aicore_path), str(binaries.sim_context_path) if hasattr(binaries, "sim_context_path") else "", ) - self._chip_worker.set_device(device_id) def _init_level3(self) -> None: device_ids = self._config.get("device_ids", []) n_sub = self._config.get("num_sub_workers", 0) + chip_bootstrap_configs = self._resolve_chip_bootstrap_configs(device_ids) # 1. Allocate mailboxes for _ in range(n_sub): shm = SharedMemory(create=True, size=DIST_SUB_MAILBOX_SIZE) assert shm.buf is not None struct.pack_into("i", shm.buf, _OFF_STATE, _IDLE) - self._shms.append(shm) + self._subworker_shms.append(shm) # 2. Fork SubWorker processes (MUST be before any C++ threads) registry = self._callable_registry for i in range(n_sub): pid = os.fork() if pid == 0: - buf = self._shms[i].buf + buf = self._subworker_shms[i].buf assert buf is not None _sub_worker_loop(buf, registry) os._exit(0) else: - self._pids.append(pid) + self._subworker_pids.append(pid) # 3. 
Fork ChipWorker processes (only if device_ids provided) if device_ids: - from runtime_builder import RuntimeBuilder # noqa: PLC0415 - from task_interface import ChipStorageTaskArgs as _CSA # noqa: PLC0415 - - platform = self._config["platform"] - runtime = self._config["runtime"] - builder = RuntimeBuilder(platform) - binaries = builder.get_binaries(runtime, build=False) - - # Determine args_size (sizeof ChipStorageTaskArgs) - _objs = [_CSA() for _ in range(5)] - _ptrs = [o.__ptr__() for o in _objs] - args_size = min(abs(_ptrs[i + 1] - _ptrs[i]) for i in range(len(_ptrs) - 1)) - del _objs, _ptrs - - host_lib_path = str(binaries.host_path) - aicpu_path = str(binaries.aicpu_path) - aicore_path = str(binaries.aicore_path) - sim_ctx_path = str(binaries.sim_context_path) if hasattr(binaries, "sim_context_path") else "" - - for dev_id in device_ids: + from task_interface import CHIP_STORAGE_TASK_ARGS_SIZE # noqa: PLC0415 + + # Mailbox transport memcpy's a fixed-size ChipStorageTaskArgs blob. + # Use the binding-exported C++ sizeof(...) instead of inferring from + # Python object addresses, which is not layout-safe. 
+ args_size = int(CHIP_STORAGE_TASK_ARGS_SIZE) + + host_lib_path, aicpu_path, aicore_path, sim_ctx_path = self._resolve_runtime_binaries() + pending_bootstrap_channels: list[tuple[DistChipBootstrapChannel, int, int]] = [] + + for chip_index, dev_id in enumerate(device_ids): shm = SharedMemory(create=True, size=DIST_CHIP_MAILBOX_SIZE) assert shm.buf is not None struct.pack_into("i", shm.buf, _CHIP_OFF_STATE, _IDLE) self._chip_shms.append(shm) + chip_bootstrap_config = None + bootstrap_mailbox_ptr = None + bootstrap_buffer_count = 0 + bootstrap_channel = None + if chip_bootstrap_configs is not None: + chip_bootstrap_config = chip_bootstrap_configs[chip_index] + bootstrap_shm = SharedMemory(create=True, size=DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE) + self._chip_bootstrap_shms.append(bootstrap_shm) + bootstrap_channel = DistChipBootstrapChannel( + _mailbox_addr(bootstrap_shm), len(chip_bootstrap_config.buffers) + ) + bootstrap_channel.reset() + bootstrap_mailbox_ptr = _mailbox_addr(bootstrap_shm) + bootstrap_buffer_count = len(chip_bootstrap_config.buffers) + pid = os.fork() if pid == 0: buf = shm.buf assert buf is not None - _chip_process_loop(buf, host_lib_path, dev_id, aicpu_path, aicore_path, sim_ctx_path, args_size) + _chip_process_loop( + buf, + host_lib_path, + dev_id, + aicpu_path, + aicore_path, + sim_ctx_path, + args_size, + bootstrap_mailbox_ptr, + bootstrap_buffer_count, + chip_bootstrap_config, + ) os._exit(0) else: self._chip_pids.append(pid) + if bootstrap_channel is not None: + pending_bootstrap_channels.append((bootstrap_channel, chip_index, dev_id)) + + for bootstrap_channel, chip_index, dev_id in pending_bootstrap_channels: + deadline = time.monotonic() + 30.0 + while bootstrap_channel.state == ChipBootstrapMailboxState.IDLE and time.monotonic() < deadline: + time.sleep(0.01) + if bootstrap_channel.state == ChipBootstrapMailboxState.ERROR: + raise RuntimeError( + f"chip bootstrap failed on device {dev_id}: {bootstrap_channel.error_message}" + ) + if 
bootstrap_channel.state != ChipBootstrapMailboxState.SUCCESS: + raise RuntimeError(f"chip bootstrap timed out on device {dev_id}") + reply = ChipBootstrapReply( + device_ctx=int(bootstrap_channel.device_ctx), + local_window_base=int(bootstrap_channel.local_window_base), + actual_window_size=int(bootstrap_channel.actual_window_size), + buffer_ptrs=[int(ptr) for ptr in bootstrap_channel.buffer_ptrs], + ) + if chip_bootstrap_configs is not None: + self._chip_contexts.append(_enrich_chip_context(chip_bootstrap_configs[chip_index], dev_id, reply)) # 4. Create DistWorker and wire chip processes + sub workers dw = DistWorker(3) @@ -337,7 +754,7 @@ def _init_level3(self) -> None: self._dist_chip_procs.append(cp) dw.add_chip_process(cp) - for shm in self._shms: + for shm in self._subworker_shms: sw = DistSubWorker(_mailbox_addr(shm)) self._dist_sub_workers.append(sw) dw.add_sub_worker(sw) @@ -395,8 +812,22 @@ def submit( ): """Submit a task. If args_list has >1 entries, submits a group task.""" assert self._dist_worker is not None - in_specs = [DistInputSpec(p) for p in (inputs or [])] - out_specs = [DistOutputSpec(s) for s in (outputs or [])] + in_specs = [] + for inp in inputs or []: + if isinstance(inp, tuple): + in_specs.append(DistInputSpec(inp[1], inp[0])) + else: + in_specs.append(DistInputSpec(inp)) + + out_specs = [] + for out in outputs or []: + if isinstance(out, dict): + ptr = int(out["ptr"]) + size = int(out.get("size", 0)) + worker_index = int(out.get("worker_index", -1)) + out_specs.append(DistOutputSpec.external(ptr, size, worker_index)) + else: + out_specs.append(DistOutputSpec(out)) if args_list and len(args_list) > 1: return self._dist_worker.submit_group(worker_type, payload, args_list, in_specs, out_specs) return self._dist_worker.submit(worker_type, payload, in_specs, out_specs) @@ -411,48 +842,27 @@ def scope(self): # ------------------------------------------------------------------ def close(self) -> None: - if not self._initialized: + if ( + not 
self._initialized + and not self._chip_pids + and not self._subworker_pids + and not self._chip_shms + and not self._subworker_shms + ): return if self.level == 2: if self._chip_worker: self._chip_worker.finalize() else: - if self._dist_worker: - self._dist_worker.close() - self._dist_worker = None - - # Shutdown SubWorker processes - for sw in self._dist_sub_workers: - sw.shutdown() - for shm in self._shms: - buf = shm.buf - assert buf is not None - struct.pack_into("i", buf, _OFF_STATE, _SHUTDOWN) - for pid in self._pids: - os.waitpid(pid, 0) - for shm in self._shms: - shm.close() - shm.unlink() - - # Shutdown ChipWorker processes - for cp in self._dist_chip_procs: - cp.shutdown() - for pid in self._chip_pids: - os.waitpid(pid, 0) - for shm in self._chip_shms: - shm.close() - shm.unlink() - - self._shms.clear() - self._pids.clear() - self._chip_shms.clear() - self._chip_pids.clear() - self._dist_sub_workers.clear() - self._dist_chip_procs.clear() + self._cleanup_level3_resources() self._initialized = False + @property + def chip_contexts(self) -> list[ChipContext]: + return list(self._chip_contexts) + def __enter__(self) -> "Worker": return self diff --git a/src/a2a3/platform/include/aicore/pto_async_backend_kernel.h b/src/a2a3/platform/include/aicore/pto_async_backend_kernel.h new file mode 100644 index 000000000..732570011 --- /dev/null +++ b/src/a2a3/platform/include/aicore/pto_async_backend_kernel.h @@ -0,0 +1,90 @@ +/** + * A2/A3 async backend helpers for AICore kernels. + * + * This header is the platform/backend implementation layer behind the generic + * PTO async kernel API in runtime/. Runtime headers should call these helpers + * rather than hard-code PTO-ISA engine details directly. 
+ */ + +#ifndef SRC_A2A3_PLATFORM_INCLUDE_AICORE_PTO_ASYNC_BACKEND_KERNEL_H_ +#define SRC_A2A3_PLATFORM_INCLUDE_AICORE_PTO_ASYNC_BACKEND_KERNEL_H_ + +#include + +#include +#include +#include + +using PTO2BackendAsyncSession = pto::comm::AsyncSession; +using PTO2BackendAsyncEvent = pto::comm::AsyncEvent; + +inline constexpr uint32_t pto2_backend_remote_copy_default_block_bytes() { + return pto::comm::sdma::kDefaultSdmaBlockBytes; +} + +template +inline __aicore__ PTO2BackendAsyncSession pto2_backend_remote_copy_open( + uint32_t sq_id, + ScratchTile &scratch, + __gm__ uint8_t *context, + uint32_t sync_id, + uint32_t block_bytes, + uint32_t block_offset, + uint32_t repeat_times) +{ + PTO2BackendAsyncSession session; + pto::comm::sdma::SdmaBaseConfig base_config{ + block_bytes != 0 ? block_bytes : pto::comm::sdma::kDefaultSdmaBlockBytes, + block_offset, + repeat_times, + }; + pto::comm::BuildAsyncSession( + scratch, context, session, sync_id, base_config, sq_id); + return session; +} + +template +inline __aicore__ PTO2BackendAsyncEvent pto2_backend_remote_copy_put( + GlobalDstData &dst, + GlobalSrcData &src, + const PTO2BackendAsyncSession &session) +{ + return pto::comm::TPUT_ASYNC(dst, src, session); +} + +template +inline __aicore__ PTO2BackendAsyncEvent pto2_backend_remote_copy_get( + GlobalDstData &dst, + GlobalSrcData &src, + const PTO2BackendAsyncSession &session) +{ + return pto::comm::TGET_ASYNC(dst, src, session); +} + +inline __aicore__ bool pto2_backend_async_event_valid(const PTO2BackendAsyncEvent &event) { + return event.valid(); +} + +inline __aicore__ uint32_t pto2_backend_async_event_engine(const PTO2BackendAsyncEvent &event) { + return static_cast(event.engine); +} + +inline __aicore__ uint64_t pto2_backend_async_event_handle(const PTO2BackendAsyncEvent &event) { + return event.handle; +} + +inline __aicore__ void pto2_backend_send_notification( + volatile __gm__ int32_t *remote_counter_addr, + int32_t value, + uint32_t op) +{ + pto::comm::NotifyOp 
notify_op = + op == 0 ? pto::comm::NotifyOp::Set : pto::comm::NotifyOp::AtomicAdd; + pto::comm::Signal signal((__gm__ int32_t *)remote_counter_addr); + pto::comm::TNOTIFY(signal, value, notify_op); +#if defined(PIPE_ALL) + pipe_barrier(PIPE_ALL); +#endif +} + +#endif // SRC_A2A3_PLATFORM_INCLUDE_AICORE_PTO_ASYNC_BACKEND_KERNEL_H_ diff --git a/src/a2a3/platform/include/common/comm_context.h b/src/a2a3/platform/include/common/comm_context.h new file mode 100644 index 000000000..d3b74c8bd --- /dev/null +++ b/src/a2a3/platform/include/common/comm_context.h @@ -0,0 +1,30 @@ +/** + * CommDeviceContext — device-side distributed communication context. + * + * This struct is the ABI contract between host (comm_hccl.cpp / comm_sim.cpp) + * and device kernels. PTO communication instructions (TREDUCE, TGET, TPUT) + * access remote data through the GVA addresses in windowsIn[]/windowsOut[] + * via MTE2 DMA. + * + * On HCCL MESH topology the struct layout matches what HCCL returns directly. + * On RING topology the host builds it by extracting remote RDMA addresses + * from HcclOpResParam's remoteRes array. + * On simulation the host fills it with malloc'd pointers. + */ + +#pragma once + +#include + +static constexpr uint32_t COMM_MAX_RANK_NUM = 64; + +struct CommDeviceContext { + uint64_t workSpace; + uint64_t workSpaceSize; + + uint32_t rankId; + uint32_t rankNum; + uint64_t winSize; + uint64_t windowsIn[COMM_MAX_RANK_NUM]; + uint64_t windowsOut[COMM_MAX_RANK_NUM]; +}; diff --git a/src/a2a3/platform/include/host/comm.h b/src/a2a3/platform/include/host/comm.h new file mode 100644 index 000000000..2e03c4b7f --- /dev/null +++ b/src/a2a3/platform/include/host/comm.h @@ -0,0 +1,102 @@ +/** + * Backend-neutral distributed communication C API. + * + * Provides five primitives for multi-rank communication: init, allocate + * shared windows, query local window base, barrier, and destroy. 
+ * + * Implementations: + * onboard/host/comm_hccl.cpp — HCCL backend (links CANN hccl/hccl_fwk) + * sim/host/comm_sim.cpp — malloc-based simulation + * + * All functions are compiled into libhost_runtime.so. The linker selects + * the implementation at build time (onboard vs sim), with no runtime + * dispatch or virtual functions. + */ + +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct CommHandle_* CommHandle; + +/** + * Initialize a communicator for the given rank. + * + * On the HCCL backend this performs ACL init, RootInfo exchange (rank 0 + * writes the file, others wait), and HcclCommInitRootInfo. + * + * @param rank This process's rank (0-based). + * @param nranks Total number of ranks. + * @param device_id Physical device ID used by this process. + * @param rootinfo_path Filesystem path used to exchange root info between + * ranks (rank 0 writes, others read). + * @return Opaque handle, or NULL on failure. + */ +CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path); + +/** + * Allocate RDMA / shared-memory windows and populate the device context. + * + * On HCCL this calls HcclAllocComResourceByTiling and extracts per-rank + * window addresses (MESH or RING topology). On sim it mallocs a shared + * region and partitions it. + * + * @param h Handle from comm_init(). + * @param win_size Window size hint (bytes per rank). The backend + * may allocate more; actual size is stored in the + * returned device context. + * @param device_ctx_out Receives a device pointer to a CommDeviceContext + * struct that can be passed to device kernels. + * @return 0 on success, non-zero on failure. + */ +int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t* device_ctx_out); + +/** + * Get the base address of this rank's local window. + * + * Window buffers allocated via comm_alloc_windows() are contiguous per + * rank. This returns the start of the local rank's region. 
+ * + * @param h Handle from comm_init(). + * @param base_out Receives the device-pointer base address. + * @return 0 on success, non-zero on failure. + */ +int comm_get_local_window_base(CommHandle h, uint64_t* base_out); + +/** + * Get the actual per-rank window size allocated by the backend. + * + * @param h Handle from comm_init(). + * @param size_out Receives the actual per-rank window size in bytes. + * @return 0 on success, non-zero on failure. + */ +int comm_get_window_size(CommHandle h, size_t* size_out); + +/** + * Synchronize all ranks. + * + * Blocks until every rank in the communicator has called comm_barrier(). + * + * @param h Handle from comm_init(). + * @return 0 on success, non-zero on failure. + */ +int comm_barrier(CommHandle h); + +/** + * Destroy the communicator and release all resources. + * + * After this call the handle is invalid. + * + * @param h Handle from comm_init(). + * @return 0 on success, non-zero on failure. + */ +int comm_destroy(CommHandle h); + +#ifdef __cplusplus +} +#endif diff --git a/src/a2a3/platform/onboard/host/CMakeLists.txt b/src/a2a3/platform/onboard/host/CMakeLists.txt index 12c86f4fd..86df40e25 100644 --- a/src/a2a3/platform/onboard/host/CMakeLists.txt +++ b/src/a2a3/platform/onboard/host/CMakeLists.txt @@ -27,6 +27,18 @@ else() message(FATAL_ERROR "MUST set CUSTOM_INCLUDE_DIRS to build Host runtime") endif() +set(PTO_ISA_INCLUDE_DIR "") +if(DEFINED ENV{PTO_ISA_ROOT} AND EXISTS "$ENV{PTO_ISA_ROOT}/include") + set(PTO_ISA_INCLUDE_DIR "$ENV{PTO_ISA_ROOT}/include") +elseif(EXISTS "${CMAKE_SOURCE_DIR}/examples/scripts/_deps/pto-isa/include") + set(PTO_ISA_INCLUDE_DIR "${CMAKE_SOURCE_DIR}/examples/scripts/_deps/pto-isa/include") +endif() + +if(PTO_ISA_INCLUDE_DIR) + list(APPEND CMAKE_CUSTOM_INCLUDE_DIRS "${PTO_ISA_INCLUDE_DIR}") + message(STATUS "Using PTO ISA include dir: ${PTO_ISA_INCLUDE_DIR}") +endif() + # Build complete source list: core host sources + sources from CUSTOM_SOURCE_DIRS set(HOST_RUNTIME_SOURCES "") 
list(APPEND HOST_RUNTIME_SOURCES @@ -38,6 +50,7 @@ list(APPEND HOST_RUNTIME_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/host_log.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/unified_log_host.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/performance_collector.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/comm_hccl.cpp" ) if(DEFINED CUSTOM_SOURCE_DIRS) foreach(SRC_DIR ${CUSTOM_SOURCE_DIRS}) @@ -90,6 +103,25 @@ target_link_directories(host_runtime ${ASCEND_HOME_PATH}/runtime/lib64 ) +# CANN 9.x exposes the working non-V2 HCCL entry points through libhcomm. +# Link it explicitly so comm_hccl.cpp can follow the same initialization path +# as the pto-isa communication tests. +find_library(HCOMM_LIB NAMES hcomm PATHS ${ASCEND_HOME_PATH}/lib64 NO_DEFAULT_PATH) +if(HCOMM_LIB) + set(HCCL_LINK_TARGETS hcomm) + message(STATUS "Using HCCL library: hcomm") +else() + message(FATAL_ERROR "libhcomm not found under ${ASCEND_HOME_PATH}/lib64") +endif() + +# Optionally link nnopbase (provides aclCreateTensor/aclDestroyTensor for SdmaWorkspaceManager) +find_library(NNOPBASE_LIB NAMES nnopbase PATHS ${ASCEND_HOME_PATH}/lib64 NO_DEFAULT_PATH) +if(NNOPBASE_LIB) + set(NNOPBASE_LINK nnopbase) +else() + set(NNOPBASE_LINK "") +endif() + # Link against CANN runtime libraries # ascend_hal is dynamically loaded at runtime via dlopen in device_runner # when performance profiling is enabled @@ -97,6 +129,8 @@ target_link_libraries(host_runtime PRIVATE runtime ascendcl + ${HCCL_LINK_TARGETS} + ${NNOPBASE_LINK} dl ) diff --git a/src/a2a3/platform/onboard/host/comm_hccl.cpp b/src/a2a3/platform/onboard/host/comm_hccl.cpp new file mode 100644 index 000000000..536bd0d30 --- /dev/null +++ b/src/a2a3/platform/onboard/host/comm_hccl.cpp @@ -0,0 +1,525 @@ +/** + * HCCL backend for the comm_* distributed communication API. + * + * Implements the five functions declared in host/comm.h using Ascend + * HCCL (bundled with CANN). 
Handles both MESH and RING topologies + * when extracting per-rank RDMA window addresses. + */ + +#include "host/comm.h" +#include "common/comm_context.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "hccl/hccl_comm.h" +#include "hccl/hccl_types.h" + +using CommTopo = uint32_t; + +// Internal HCCL helpers are exported by libhcomm on CANN 9.x. The public +// HCCL APIs below intentionally use the standard, non-V2 entry points to match +// the working pto-isa initialization sequence. +extern "C" HcclResult HcclAllocComResourceByTiling(HcclComm comm, void* stream, + void* mc2Tiling, void** commContext); +extern "C" HcclResult HcomGetCommHandleByGroup(const char* group, HcclComm* commHandle); +extern "C" HcclResult HcomGetL0TopoTypeEx(const char* group, CommTopo* topoType, + uint32_t isSetDevice); + +static inline HcclResult hccl_get_root_info(HcclRootInfo* ri) + { return HcclGetRootInfo(ri); } +static inline HcclResult hccl_comm_init_root_info(uint32_t n, const HcclRootInfo* ri, uint32_t r, HcclComm* c) + { return HcclCommInitRootInfo(n, ri, r, c); } +static inline HcclResult hccl_get_comm_name(HcclComm c, char* name) + { return HcclGetCommName(c, name); } +static inline HcclResult hccl_barrier(HcclComm c, aclrtStream s) + { return HcclBarrier(c, s); } +static inline HcclResult hccl_comm_destroy(HcclComm c) + { return HcclCommDestroy(c); } +static inline HcclResult hccl_alloc_com_resource(HcclComm c, void* s, void* t, void** ctx) + { return HcclAllocComResourceByTiling(c, s, t, ctx); } +static inline HcclResult hccl_get_comm_handle_by_group(const char* g, HcclComm* c) + { return HcomGetCommHandleByGroup(g, c); } +static inline HcclResult hccl_get_l0_topo_type_ex(const char* g, CommTopo* t, uint32_t f) + { return HcomGetL0TopoTypeEx(g, t, f); } + +static constexpr uint32_t COMM_IS_NOT_SET_DEVICE = 0; +static constexpr uint32_t COMM_TOPO_MESH = 0b1u; + +using rtStream_t = void*; +static constexpr 
int32_t RT_STREAM_PRIORITY_DEFAULT = 0; +extern "C" int32_t rtSetDevice(int32_t device); +extern "C" int32_t rtStreamCreate(rtStream_t* stream, int32_t priority); +extern "C" int32_t rtStreamDestroy(rtStream_t stream); + +// ============================================================================ +// HCCL tiling structures (required by HcclAllocComResourceByTiling) +// ============================================================================ + +namespace { + +static constexpr uint32_t MAX_CC_TILING_NUM = 8U; +static constexpr uint32_t GROUP_NAME_SIZE = 128U; +static constexpr uint32_t ALG_CONFIG_SIZE = 128U; + +struct Mc2InitTilingInner { + uint32_t version; + uint32_t mc2HcommCnt; + uint32_t offset[MAX_CC_TILING_NUM]; + uint8_t debugMode; + uint8_t preparePosition; + uint16_t queueNum; + uint16_t commBlockNum; + uint8_t devType; + char reserved[17]; +}; + +struct Mc2cCTilingInner { + uint8_t skipLocalRankCopy; + uint8_t skipBufferWindowCopy; + uint8_t stepSize; + uint8_t version; + char reserved[9]; + uint8_t commEngine; + uint8_t srcDataType; + uint8_t dstDataType; + char groupName[GROUP_NAME_SIZE]; + char algConfig[ALG_CONFIG_SIZE]; + uint32_t opType; + uint32_t reduceType; +}; + +struct Mc2CommConfigV2 { + Mc2InitTilingInner init; + Mc2cCTilingInner inner; +}; + +// HCCL compat structs for RING topology parsing +struct HcclSignalInfo { + uint64_t resId; + uint64_t addr; + uint32_t devId; + uint32_t tsId; + uint32_t rankId; + uint32_t flag; +}; + +struct HcclStreamInfo { + int32_t streamIds; + uint32_t sqIds; + uint32_t cqIds; + uint32_t logicCqids; +}; + +struct ListCommon { + uint64_t nextHost; + uint64_t preHost; + uint64_t nextDevice; + uint64_t preDevice; +}; + +static constexpr uint32_t COMPAT_LOCAL_NOTIFY_MAX_NUM = 64; +static constexpr uint32_t COMPAT_LOCAL_STREAM_MAX_NUM = 19; +static constexpr uint32_t COMPAT_AICPU_OP_NOTIFY_MAX_NUM = 2; + +struct LocalResInfoV2 { + uint32_t streamNum; + uint32_t signalNum; + HcclSignalInfo 
localSignals[COMPAT_LOCAL_NOTIFY_MAX_NUM]; + HcclStreamInfo streamInfo[COMPAT_LOCAL_STREAM_MAX_NUM]; + HcclStreamInfo mainStreamInfo; + HcclSignalInfo aicpuOpNotify[COMPAT_AICPU_OP_NOTIFY_MAX_NUM]; + ListCommon nextTagRes; +}; + +struct AlgoTopoInfo { + uint32_t userRank; + uint32_t userRankSize; + int32_t deviceLogicId; + bool isSingleMeshAggregation; + uint32_t deviceNumPerAggregation; + uint32_t superPodNum; + uint32_t devicePhyId; + uint32_t topoType; + uint32_t deviceType; + uint32_t serverNum; + uint32_t meshAggregationRankSize; + uint32_t multiModuleDiffDeviceNumMode; + uint32_t multiSuperPodDiffServerNumMode; + uint32_t realUserRank; + bool isDiffDeviceModule; + bool isDiffDeviceType; + uint32_t gcdDeviceNumPerAggregation; + uint32_t moduleNum; + uint32_t isUsedRdmaRankPairNum; + uint64_t isUsedRdmaRankPair; + uint32_t pairLinkCounterNum; + uint64_t pairLinkCounter; + uint32_t nicNum; + uint64_t nicList; + uint64_t complanRankLength; + uint64_t complanRank; + uint64_t bridgeRankNum; + uint64_t bridgeRank; + uint64_t serverAndsuperPodRankLength; + uint64_t serverAndsuperPodRank; +}; + +struct HcclOpConfig { + uint8_t deterministic; + uint8_t retryEnable; + uint8_t highPerfEnable; + uint8_t padding[5]; + uint8_t linkTimeOut[8]; + uint64_t notifyWaitTime; + uint32_t retryHoldTime; + uint32_t retryIntervalTime; + bool interXLinkDisable; + uint32_t floatOverflowMode; + uint32_t multiQpThreshold; +}; + +struct RemoteResPtr { + uint64_t nextHostPtr; + uint64_t nextDevicePtr; +}; + +struct HcclMC2WorkSpace { + uint64_t workspace; + uint64_t workspaceSize; +}; + +struct HcclRankRelationResV2 { + uint32_t remoteUsrRankId; + uint32_t remoteWorldRank; + uint64_t windowsIn; + uint64_t windowsOut; + uint64_t windowsExp; + ListCommon nextTagRes; +}; + +struct HcclOpResParamHead { + uint32_t localUsrRankId; + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; 
+}; + +struct HcclOpResParam { + HcclMC2WorkSpace mc2WorkSpace; + uint32_t localUsrRankId; + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; + uint32_t rWinStart; + uint32_t rWinOffset; + uint64_t version; + LocalResInfoV2 localRes; + AlgoTopoInfo topoInfo; + HcclOpConfig config; + uint64_t hostStateInfo; + uint64_t aicpuStateInfo; + uint64_t lockAddr; + uint32_t rsv[16]; + uint32_t notifysize; + uint32_t remoteResNum; + RemoteResPtr remoteRes[1]; +}; + +} // anonymous namespace + +// ============================================================================ +// Internal state +// ============================================================================ + +struct CommHandle_ { + int rank; + int nranks; + std::string rootinfo_path; + + rtStream_t stream = nullptr; + HcclComm hccl_comm = nullptr; + + CommDeviceContext host_ctx{}; + CommDeviceContext* device_ctx = nullptr; + bool owns_device_ctx = false; +}; + +// ============================================================================ +// Helpers +// ============================================================================ + +static bool wait_for_file(const std::string& path, int timeout_sec = 120) { + for (int i = 0; i < timeout_sec * 10; ++i) { + std::ifstream f(path, std::ios::binary); + if (f.good()) { + auto sz = f.seekg(0, std::ios::end).tellg(); + if (sz >= static_cast(HCCL_ROOT_INFO_BYTES)) return true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return false; +} + +static void file_barrier(const std::string& dir, int rank, int nranks, const std::string& tag) { + std::string my_marker = dir + "/barrier_" + tag + "_" + std::to_string(rank) + ".ready"; + { std::ofstream(my_marker) << "1"; } + + for (int r = 0; r < nranks; ++r) { + std::string marker = dir + "/barrier_" + tag + "_" + std::to_string(r) + ".ready"; + while (true) { + std::ifstream f(marker); 
+ if (f.good()) break; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } +} + +// ============================================================================ +// API implementation +// ============================================================================ + +extern "C" CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path) { + auto* h = new (std::nothrow) CommHandle_{}; + if (!h) return nullptr; + + h->rank = rank; + h->nranks = nranks; + h->rootinfo_path = rootinfo_path; + + // ACL init + constexpr int kAclRepeatInit = 100002; + aclError aRet = aclInit(nullptr); + if (aRet != ACL_SUCCESS && static_cast(aRet) != kAclRepeatInit) { + fprintf(stderr, "[comm rank %d] aclInit failed: %d\n", rank, (int)aRet); + delete h; + return nullptr; + } + + if (rank == 0) { + int32_t rtRet = rtSetDevice(device_id); + if (rtRet != 0) { + fprintf(stderr, "[comm rank %d] rtSetDevice(%d) failed: %d\n", + rank, device_id, rtRet); + delete h; + return nullptr; + } + } + + // HCCL requires an ACL runtime context bound to the physical device. + // This cannot be inferred from rank because distributed runs may map + // ranks to arbitrary device lists (for example devices=[2,4,5,7]). 
+ aRet = aclrtSetDevice(device_id); + if (aRet != ACL_SUCCESS) { + fprintf(stderr, "[comm rank %d] aclrtSetDevice(%d) failed: %d\n", + rank, device_id, (int)aRet); + delete h; + return nullptr; + } + + // RootInfo exchange + HcclRootInfo rootInfo{}; + if (rank == 0) { + HcclResult hret = hccl_get_root_info(&rootInfo); + if (hret != HCCL_SUCCESS) { + fprintf(stderr, "[comm rank 0] HcclGetRootInfo failed: %d\n", (int)hret); + delete h; + return nullptr; + } + std::ofstream fout(rootinfo_path, std::ios::binary); + fout.write(rootInfo.internal, HCCL_ROOT_INFO_BYTES); + fout.close(); + } else { + if (!wait_for_file(rootinfo_path)) { + fprintf(stderr, "[comm rank %d] Timeout waiting for rootinfo\n", rank); + delete h; + return nullptr; + } + std::ifstream fin(rootinfo_path, std::ios::binary); + fin.read(rootInfo.internal, HCCL_ROOT_INFO_BYTES); + } + + std::string barrier_dir = h->rootinfo_path; + auto last_slash = barrier_dir.rfind('/'); + if (last_slash != std::string::npos) { + barrier_dir = barrier_dir.substr(0, last_slash); + } + file_barrier(barrier_dir, h->rank, h->nranks, "rootinfo_ready"); + + // Create stream for HCCL operations + rtStreamCreate(&h->stream, RT_STREAM_PRIORITY_DEFAULT); + + // Init communicator + HcclResult hret = hccl_comm_init_root_info( + static_cast(nranks), &rootInfo, static_cast(rank), &h->hccl_comm); + if (hret != HCCL_SUCCESS) { + fprintf(stderr, "[comm rank %d] HcclCommInitRootInfo failed: %d\n", rank, (int)hret); + if (h->stream) rtStreamDestroy(h->stream); + delete h; + return nullptr; + } + + return h; +} + +extern "C" int comm_alloc_windows(CommHandle h, size_t /*win_size*/, uint64_t* device_ctx_out) { + if (!h || !device_ctx_out) return -1; + + char group[128] = {}; + HcclResult hret = hccl_get_comm_name(h->hccl_comm, group); + if (hret != HCCL_SUCCESS) return -1; + + CommTopo topoType = 0; + hret = hccl_get_l0_topo_type_ex(group, &topoType, COMM_IS_NOT_SET_DEVICE); + if (hret != HCCL_SUCCESS) return -1; + + HcclComm commHandle = 
nullptr; + hret = hccl_get_comm_handle_by_group(group, &commHandle); + if (hret != HCCL_SUCCESS) return -1; + + // File barrier so all ranks have completed HcclCommInitRootInfo + std::string barrier_dir = h->rootinfo_path; + auto last_slash = barrier_dir.rfind('/'); + if (last_slash != std::string::npos) { + barrier_dir = barrier_dir.substr(0, last_slash); + } + file_barrier(barrier_dir, h->rank, h->nranks, "hccl_init"); + + // Tiling configuration for HcclAllocComResourceByTiling + Mc2CommConfigV2 tiling{}; + memset(&tiling, 0, sizeof(tiling)); + tiling.init.version = 100U; + tiling.init.mc2HcommCnt = 1U; + tiling.init.commBlockNum = 48U; + tiling.init.devType = 4U; + tiling.init.offset[0] = static_cast( + reinterpret_cast(&tiling.inner) - reinterpret_cast(&tiling.init)); + tiling.inner.opType = 18U; + tiling.inner.commEngine = 3U; + tiling.inner.version = 1U; + strncpy(tiling.inner.groupName, group, GROUP_NAME_SIZE - 1); + strncpy(tiling.inner.algConfig, "BatchWrite=level0:fullmesh", ALG_CONFIG_SIZE - 1); + + void* ctxPtr = nullptr; + hret = hccl_alloc_com_resource(commHandle, h->stream, &tiling, &ctxPtr); + if (hret != HCCL_SUCCESS || ctxPtr == nullptr) return -1; + + // Extract CommDeviceContext (topology-dependent) + aclError aRet; + if (topoType == COMM_TOPO_MESH) { + h->device_ctx = reinterpret_cast(ctxPtr); + aRet = aclrtMemcpy(&h->host_ctx, sizeof(h->host_ctx), + h->device_ctx, sizeof(h->host_ctx), ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + } else { + // RING topology: parse HcclOpResParam structure on device + auto* rawCtx = reinterpret_cast(ctxPtr); + + HcclOpResParamHead head{}; + const size_t headOff = offsetof(HcclOpResParam, localUsrRankId); + aRet = aclrtMemcpy(&head, sizeof(head), rawCtx + headOff, sizeof(head), + ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + + const size_t remoteResOff = offsetof(HcclOpResParam, remoteRes); + const size_t remoteResBytes = head.rankSize * sizeof(RemoteResPtr); + 
std::vector<RemoteResPtr> remoteResArr(head.rankSize); + aRet = aclrtMemcpy(remoteResArr.data(), remoteResBytes, + rawCtx + remoteResOff, remoteResBytes, ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + + memset(&h->host_ctx, 0, sizeof(h->host_ctx)); + + uint64_t wsFields[2] = {0, 0}; + aclrtMemcpy(wsFields, sizeof(wsFields), rawCtx, sizeof(wsFields), ACL_MEMCPY_DEVICE_TO_HOST); + h->host_ctx.workSpace = wsFields[0]; + h->host_ctx.workSpaceSize = wsFields[1]; + h->host_ctx.rankId = head.localUsrRankId; + h->host_ctx.rankNum = head.rankSize; + h->host_ctx.winSize = head.winSize; + + for (uint32_t i = 0; i < head.rankSize; ++i) { + if (i == head.localUsrRankId) { + h->host_ctx.windowsIn[i] = head.localWindowsIn; + continue; + } + uint64_t devPtr = remoteResArr[i].nextDevicePtr; + if (devPtr == 0) return -1; + + HcclRankRelationResV2 remoteInfo{}; + aRet = aclrtMemcpy(&remoteInfo, sizeof(remoteInfo), + reinterpret_cast<void*>(devPtr), sizeof(remoteInfo), + ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + h->host_ctx.windowsIn[i] = remoteInfo.windowsIn; + } + + void* newDevMem = nullptr; + aRet = aclrtMalloc(&newDevMem, sizeof(CommDeviceContext), ACL_MEM_MALLOC_HUGE_FIRST); + if (aRet != ACL_SUCCESS) return -1; + + aRet = aclrtMemcpy(newDevMem, sizeof(CommDeviceContext), + &h->host_ctx, sizeof(CommDeviceContext), ACL_MEMCPY_HOST_TO_DEVICE); + if (aRet != ACL_SUCCESS) { + aclrtFree(newDevMem); + return -1; + } + h->device_ctx = reinterpret_cast<CommDeviceContext*>(newDevMem); + h->owns_device_ctx = true; + } + + *device_ctx_out = reinterpret_cast<uint64_t>(h->device_ctx); + return 0; +} + +extern "C" int comm_get_local_window_base(CommHandle h, uint64_t* base_out) { + if (!h || !base_out) return -1; + *base_out = h->host_ctx.windowsIn[h->rank]; + return 0; +} + +extern "C" int comm_get_window_size(CommHandle h, size_t* size_out) { + if (!h || !size_out) return -1; + *size_out = static_cast<size_t>(h->host_ctx.winSize); + return 0; +} + +extern "C" int comm_barrier(CommHandle h) { + if (!h) 
return -1; + hccl_barrier(h->hccl_comm, (aclrtStream)h->stream); + aclrtSynchronizeStream((aclrtStream)h->stream); + return 0; +} + +extern "C" int comm_destroy(CommHandle h) { + if (!h) return -1; + + if (h->owns_device_ctx && h->device_ctx) { + aclrtFree(h->device_ctx); + } + if (h->stream) rtStreamDestroy(h->stream); + if (h->hccl_comm) hccl_comm_destroy(h->hccl_comm); + + // NOTE: Do NOT call aclrtResetDevice / aclFinalize here. + // Device lifecycle is owned by DeviceRunner (static singleton) whose + // destructor frees all tracked device memory before resetting the device. + // Resetting early would invalidate pointers still held by MemoryAllocator. + + delete h; + return 0; +} diff --git a/src/a2a3/platform/onboard/host/device_runner.cpp b/src/a2a3/platform/onboard/host/device_runner.cpp index fa92ef0f4..2876f5fbc 100644 --- a/src/a2a3/platform/onboard/host/device_runner.cpp +++ b/src/a2a3/platform/onboard/host/device_runner.cpp @@ -23,6 +23,7 @@ #include #include #include +#include "acl/acl.h" // Include HAL constants from CANN (header only, library loaded dynamically) #include "ascend_hal.h" @@ -638,10 +639,19 @@ int DeviceRunner::finalize() { // Free all remaining allocations (including handshake buffer and binGmAddr) mem_alloc_.finalize(); + int saved_device_id = device_id_; device_id_ = -1; worker_count_ = 0; aicore_kernel_binary_.clear(); + // Reset device and finalize ACL AFTER all device memory is freed. + // This was previously done in comm_destroy, but that ran before the + // static DeviceRunner destructor, causing rtFree failures (107000). 
+ if (saved_device_id >= 0) { + aclrtResetDevice(saved_device_id); + aclFinalize(); + } + LOG_INFO("DeviceRunner finalized"); return 0; } diff --git a/src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp b/src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp index eff903896..1ccf0e13d 100644 --- a/src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp +++ b/src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp @@ -20,6 +20,7 @@ #include "callable.h" #include "task_args.h" +#include #include #include @@ -27,6 +28,13 @@ #include "device_runner.h" // NOLINT(build/include_subdir) #include "runtime.h" // NOLINT(build/include_subdir) +#if __has_include("pto/npu/comm/async/sdma/sdma_workspace_manager.hpp") && __has_include("acl/acl.h") +#include "pto/npu/comm/async/sdma/sdma_workspace_manager.hpp" +#define PTO2_PLATFORM_HAS_SDMA_WORKSPACE_MANAGER 1 +#else +#define PTO2_PLATFORM_HAS_SDMA_WORKSPACE_MANAGER 0 +#endif + extern "C" { /* =========================================================================== @@ -96,6 +104,61 @@ static void remove_kernel_binary_wrapper(int func_id) { } catch (...) 
{} } +static bool env_flag_enabled(const char *primary_name, const char *legacy_name) { + const char *value = std::getenv(primary_name); + if (value == nullptr && legacy_name != nullptr) { + value = std::getenv(legacy_name); + } + return value != nullptr && value[0] == '1' && value[1] == '\0'; +} + +#ifdef PTO2_RUNTIME_HAS_ASYNC_HOST_API +#if PTO2_PLATFORM_HAS_SDMA_WORKSPACE_MANAGER +static pto::comm::sdma::SdmaWorkspaceManager &sdma_workspace_manager() { + static pto::comm::sdma::SdmaWorkspaceManager manager; + return manager; +} +#endif + +static PTO2AsyncContextInitStatus init_async_context(PTO2AsyncCapability capability, uint64_t *addr) { + if (addr == nullptr) { + return PTO2AsyncContextInitStatus::ERROR; + } + *addr = 0; + + switch (capability) { + case PTO2AsyncCapability::REMOTE_COPY: +#if PTO2_PLATFORM_HAS_SDMA_WORKSPACE_MANAGER + if (!env_flag_enabled("PTO2_ENABLE_REMOTE_COPY_ASYNC", "PTO2_ENABLE_SDMA")) { + return PTO2AsyncContextInitStatus::SKIPPED; + } + if (!sdma_workspace_manager().Init()) { + return PTO2AsyncContextInitStatus::ERROR; + } + *addr = reinterpret_cast(sdma_workspace_manager().GetWorkspaceAddr()); + return *addr != 0 ? 
PTO2AsyncContextInitStatus::READY : PTO2AsyncContextInitStatus::ERROR; +#else + return PTO2AsyncContextInitStatus::SKIPPED; +#endif + } + + return PTO2AsyncContextInitStatus::SKIPPED; +} + +static void destroy_async_context(PTO2AsyncCapability capability, uint64_t addr) { + (void)addr; + switch (capability) { + case PTO2AsyncCapability::REMOTE_COPY: +#if PTO2_PLATFORM_HAS_SDMA_WORKSPACE_MANAGER + sdma_workspace_manager().Finalize(); + break; +#else + break; +#endif + } +} +#endif + /* =========================================================================== * Public C API (resolved by ChipWorker via dlsym) * =========================================================================== */ @@ -121,6 +184,40 @@ int set_device(DeviceContextHandle ctx, int device_id) { } } +void *device_malloc_ctx(DeviceContextHandle ctx, size_t size) { + if (ctx == NULL) return NULL; + try { + return static_cast(ctx)->allocate_tensor(size); + } catch (...) { + return NULL; + } +} + +void device_free_ctx(DeviceContextHandle ctx, void *dev_ptr) { + if (ctx == NULL || dev_ptr == NULL) return; + try { + static_cast(ctx)->free_tensor(dev_ptr); + } catch (...) {} +} + +int copy_to_device_ctx(DeviceContextHandle ctx, void *dev_ptr, const void *host_ptr, size_t size) { + if (ctx == NULL || dev_ptr == NULL || host_ptr == NULL) return -1; + try { + return static_cast(ctx)->copy_to_device(dev_ptr, host_ptr, size); + } catch (...) { + return -1; + } +} + +int copy_from_device_ctx(DeviceContextHandle ctx, void *host_ptr, const void *dev_ptr, size_t size) { + if (ctx == NULL || host_ptr == NULL || dev_ptr == NULL) return -1; + try { + return static_cast(ctx)->copy_from_device(host_ptr, dev_ptr, size); + } catch (...) 
{ + return -1; + } +} + int run_runtime( DeviceContextHandle ctx, RuntimeHandle runtime, const void *callable, const void *args, int block_dim, int aicpu_thread_num, int device_id, const uint8_t *aicpu_binary, size_t aicpu_size, const uint8_t *aicore_binary, @@ -142,6 +239,10 @@ int run_runtime( r->host_api.copy_from_device = copy_from_device; r->host_api.upload_kernel_binary = upload_kernel_binary_wrapper; r->host_api.remove_kernel_binary = remove_kernel_binary_wrapper; +#ifdef PTO2_RUNTIME_HAS_ASYNC_HOST_API + r->host_api.init_async_context = init_async_context; + r->host_api.destroy_async_context = destroy_async_context; +#endif LOG_DEBUG("About to call init_runtime_impl, r=%p", (void *)r); int rc = init_runtime_impl( diff --git a/src/a2a3/platform/sim/host/CMakeLists.txt b/src/a2a3/platform/sim/host/CMakeLists.txt index 8432536fd..a24769939 100644 --- a/src/a2a3/platform/sim/host/CMakeLists.txt +++ b/src/a2a3/platform/sim/host/CMakeLists.txt @@ -44,6 +44,7 @@ list(APPEND HOST_RUNTIME_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/unified_log_host.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/performance_collector.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../aicpu/platform_aicpu_affinity.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/comm_sim.cpp" ) if(DEFINED CUSTOM_SOURCE_DIRS) @@ -89,6 +90,7 @@ target_link_libraries(host_runtime PRIVATE pthread dl + rt ) # Allow undefined symbols from libcpu_sim_context.so (loaded with RTLD_GLOBAL at runtime). diff --git a/src/a2a3/platform/sim/host/comm_sim.cpp b/src/a2a3/platform/sim/host/comm_sim.cpp new file mode 100644 index 000000000..d1bdfd412 --- /dev/null +++ b/src/a2a3/platform/sim/host/comm_sim.cpp @@ -0,0 +1,205 @@ +/** + * Simulation backend for the comm_* distributed communication API. + * + * Uses POSIX shared memory (shm_open + mmap) so that multiple *processes* + * (one per rank, spawned by the L3 Worker distributed path) share the same RDMA + * window region. 
Synchronization primitives (barrier counters) live in + * the shared region itself, using GCC __atomic builtins which are safe + * on lock-free-capable types in mmap'd memory. + * + * Shared memory layout (page-aligned header + per-rank windows): + * [ SharedHeader (4096 bytes) ][ rank-0 window ][ rank-1 window ] ... + */ + +#include "host/comm.h" +#include "common/comm_context.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static constexpr size_t HEADER_SIZE = 4096; + +namespace { + +struct SharedHeader { + volatile int nranks; + volatile int alloc_done; + volatile int ready_count; + volatile int barrier_count; + volatile int barrier_phase; + volatile int destroy_count; + size_t per_rank_win_size; +}; + +std::string make_shm_name(const char* rootinfo_path) { + size_t h = std::hash{}(rootinfo_path ? rootinfo_path : "default"); + char buf[64]; + std::snprintf(buf, sizeof(buf), "/simpler_comm_%zx", h); + return buf; +} + +} // anonymous namespace + +// ============================================================================ +// Per-handle state (process-local) +// ============================================================================ + +struct CommHandle_ { + int rank; + int nranks; + std::string shm_name; + + void* mmap_base = nullptr; + size_t mmap_size = 0; + bool is_creator = false; + + CommDeviceContext host_ctx{}; +}; + +// ============================================================================ +// API implementation +// ============================================================================ + +extern "C" CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path) { + auto* h = new (std::nothrow) CommHandle_{}; + if (!h) return nullptr; + (void)device_id; + + h->rank = rank; + h->nranks = nranks; + h->shm_name = make_shm_name(rootinfo_path); + return h; +} + +extern "C" int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t* device_ctx_out) { + if 
(!h || !device_ctx_out) return -1; + + size_t total = HEADER_SIZE + win_size * static_cast<size_t>(h->nranks); + + int fd = shm_open(h->shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR, 0600); + if (fd >= 0) { + h->is_creator = true; + if (ftruncate(fd, static_cast<off_t>(total)) != 0) { + std::perror("comm_sim: ftruncate"); + close(fd); + shm_unlink(h->shm_name.c_str()); + return -1; + } + } else if (errno == EEXIST) { + fd = shm_open(h->shm_name.c_str(), O_RDWR, 0600); + if (fd < 0) { std::perror("comm_sim: shm_open"); return -1; } + + // Wait for creator to finish ftruncate by checking file size + for (int i = 0; i < 5000; ++i) { + struct stat st; + if (fstat(fd, &st) == 0 && static_cast<size_t>(st.st_size) >= total) break; + usleep(1000); + } + } else { + std::perror("comm_sim: shm_open O_EXCL"); + return -1; + } + + void* base = mmap(nullptr, total, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + close(fd); + if (base == MAP_FAILED) { std::perror("comm_sim: mmap"); return -1; } + + h->mmap_base = base; + h->mmap_size = total; + + auto* hdr = static_cast<SharedHeader*>(base); + + if (h->is_creator) { + hdr->per_rank_win_size = win_size; + hdr->ready_count = 0; + hdr->barrier_count = 0; + hdr->barrier_phase = 0; + hdr->destroy_count = 0; + __atomic_store_n(&hdr->nranks, h->nranks, __ATOMIC_RELEASE); + __atomic_store_n(&hdr->alloc_done, 1, __ATOMIC_RELEASE); + } else { + while (__atomic_load_n(&hdr->alloc_done, __ATOMIC_ACQUIRE) == 0) { + usleep(100); + } + } + + auto* win_base = static_cast<char*>(base) + HEADER_SIZE; + + auto& ctx = h->host_ctx; + ctx.workSpace = 0; + ctx.workSpaceSize = 0; + ctx.rankId = static_cast<uint32_t>(h->rank); + ctx.rankNum = static_cast<uint32_t>(h->nranks); + ctx.winSize = win_size; + for (int i = 0; i < h->nranks; ++i) { + ctx.windowsIn[i] = reinterpret_cast<uint64_t>( + win_base + static_cast<size_t>(i) * win_size); + } + + *device_ctx_out = reinterpret_cast<uint64_t>(&h->host_ctx); + + __atomic_add_fetch(&hdr->ready_count, 1, __ATOMIC_ACQ_REL); + while (__atomic_load_n(&hdr->ready_count, __ATOMIC_ACQUIRE) < h->nranks) { + 
usleep(100); + } + + return 0; +} + +extern "C" int comm_get_local_window_base(CommHandle h, uint64_t* base_out) { + if (!h || !base_out) return -1; + *base_out = h->host_ctx.windowsIn[h->rank]; + return 0; +} + +extern "C" int comm_get_window_size(CommHandle h, size_t* size_out) { + if (!h || !size_out) return -1; + *size_out = static_cast(h->host_ctx.winSize); + return 0; +} + +extern "C" int comm_barrier(CommHandle h) { + if (!h || !h->mmap_base) return -1; + + auto* hdr = static_cast(h->mmap_base); + int phase = __atomic_load_n(&hdr->barrier_phase, __ATOMIC_ACQUIRE); + int arrived = __atomic_add_fetch(&hdr->barrier_count, 1, __ATOMIC_ACQ_REL); + + if (arrived == h->nranks) { + __atomic_store_n(&hdr->barrier_count, 0, __ATOMIC_RELEASE); + __atomic_add_fetch(&hdr->barrier_phase, 1, __ATOMIC_ACQ_REL); + } else { + while (__atomic_load_n(&hdr->barrier_phase, __ATOMIC_ACQUIRE) == phase) { + usleep(50); + } + } + + return 0; +} + +extern "C" int comm_destroy(CommHandle h) { + if (!h) return -1; + + if (h->mmap_base) { + auto* hdr = static_cast(h->mmap_base); + int gone = __atomic_add_fetch(&hdr->destroy_count, 1, __ATOMIC_ACQ_REL); + + munmap(h->mmap_base, h->mmap_size); + h->mmap_base = nullptr; + + if (gone >= h->nranks) { + shm_unlink(h->shm_name.c_str()); + } + } + + delete h; + return 0; +} diff --git a/src/a2a3/platform/sim/host/pto_runtime_c_api.cpp b/src/a2a3/platform/sim/host/pto_runtime_c_api.cpp index 37028f27d..7dd025a2b 100644 --- a/src/a2a3/platform/sim/host/pto_runtime_c_api.cpp +++ b/src/a2a3/platform/sim/host/pto_runtime_c_api.cpp @@ -98,6 +98,21 @@ static void remove_kernel_binary_wrapper(int func_id) { } catch (...) 
{} } +#ifdef PTO2_RUNTIME_HAS_ASYNC_HOST_API +static PTO2AsyncContextInitStatus init_async_context(PTO2AsyncCapability capability, uint64_t *addr) { + (void)capability; + if (addr != nullptr) { + *addr = 0; + } + return PTO2AsyncContextInitStatus::SKIPPED; +} + +static void destroy_async_context(PTO2AsyncCapability capability, uint64_t addr) { + (void)capability; + (void)addr; +} +#endif + /* =========================================================================== * Public C API (resolved by ChipWorker via dlsym) * =========================================================================== */ @@ -121,6 +136,40 @@ int set_device(DeviceContextHandle ctx, int device_id) { return 0; } +void *device_malloc_ctx(DeviceContextHandle ctx, size_t size) { + if (ctx == NULL) return NULL; + try { + return static_cast(ctx)->allocate_tensor(size); + } catch (...) { + return NULL; + } +} + +void device_free_ctx(DeviceContextHandle ctx, void *dev_ptr) { + if (ctx == NULL || dev_ptr == NULL) return; + try { + static_cast(ctx)->free_tensor(dev_ptr); + } catch (...) {} +} + +int copy_to_device_ctx(DeviceContextHandle ctx, void *dev_ptr, const void *host_ptr, size_t size) { + if (ctx == NULL || dev_ptr == NULL || host_ptr == NULL) return -1; + try { + return static_cast(ctx)->copy_to_device(dev_ptr, host_ptr, size); + } catch (...) { + return -1; + } +} + +int copy_from_device_ctx(DeviceContextHandle ctx, void *host_ptr, const void *dev_ptr, size_t size) { + if (ctx == NULL || host_ptr == NULL || dev_ptr == NULL) return -1; + try { + return static_cast(ctx)->copy_from_device(host_ptr, dev_ptr, size); + } catch (...) 
{ + return -1; + } +} + int run_runtime( DeviceContextHandle ctx, RuntimeHandle runtime, const void *callable, const void *args, int block_dim, int aicpu_thread_num, int device_id, const uint8_t *aicpu_binary, size_t aicpu_size, const uint8_t *aicore_binary, @@ -141,6 +190,10 @@ int run_runtime( r->host_api.copy_from_device = copy_from_device; r->host_api.upload_kernel_binary = upload_kernel_binary_wrapper; r->host_api.remove_kernel_binary = remove_kernel_binary_wrapper; +#ifdef PTO2_RUNTIME_HAS_ASYNC_HOST_API + r->host_api.init_async_context = init_async_context; + r->host_api.destroy_async_context = destroy_async_context; +#endif int rc = init_runtime_impl( r, reinterpret_cast(callable), reinterpret_cast(args) diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp index 1f8f1601e..c23bb9e3c 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp @@ -34,6 +34,7 @@ #include "pto_runtime2.h" #include "pto_runtime2_types.h" #include "pto_shared_memory.h" +#include "pto_async_wait.h" // Performance profiling headers #include "aicpu/performance_collector_aicpu.h" @@ -85,6 +86,7 @@ constexpr int32_t STALL_DUMP_WAIT_MAX = 4; constexpr int32_t STALL_DUMP_CORE_MAX = 8; constexpr int32_t PROGRESS_VERBOSE_THRESHOLD = 10; // log every completion for the first N tasks constexpr int32_t PROGRESS_LOG_INTERVAL = 250; // log every N completions after threshold +constexpr int32_t MAX_DEFERRED_RELEASES = 256; static PTO2Runtime *rt{nullptr}; @@ -370,8 +372,9 @@ struct AicpuExecutor { template void check_running_cores_for_completion( int32_t thread_idx, Handshake *hank, int32_t &completed_this_turn, int32_t &cur_thread_completed, - bool &made_progress, PTO2TaskSlotState *deferred_release_slot_states[], int32_t &deferred_release_count, - PTO2LocalReadyBuffer *local_bufs + bool &made_progress, bool 
&fatal_error, int32_t &fatal_error_code, + PTO2TaskSlotState *deferred_release_slot_states[], int32_t &deferred_release_count, + PTO2LocalReadyBuffer *local_bufs, PTO2AsyncWaitList &async_wait_list #if PTO2_PROFILING , bool profiling_enabled, uint32_t &phase_complete_count @@ -415,6 +418,16 @@ struct AicpuExecutor { // Completion: increment atomic counter, trigger task-level completion on last subtask bool mixed_complete = rt->scheduler.on_subtask_complete(slot_state); if (mixed_complete) { + int32_t registration_error = PTO2_ERROR_NONE; + if (async_wait_list.register_deferred(slot_state, thread_idx, + registration_error)) { + // Deferred completion is now tracked by async_wait_list. + } else { + if (registration_error != PTO2_ERROR_NONE) { + fatal_error = true; + fatal_error_code = registration_error; + return; + } #if PTO2_SCHED_PROFILING PTO2CompletionStats cstats = rt->scheduler.on_mixed_task_complete(slot_state, thread_idx, local_bufs); @@ -428,7 +441,7 @@ struct AicpuExecutor { phase_complete_count++; #endif #endif - if (deferred_release_count < 256) { + if (deferred_release_count < MAX_DEFERRED_RELEASES) { deferred_release_slot_states[deferred_release_count++] = &slot_state; } else { DEV_ALWAYS("Thread %d: release", thread_idx); @@ -449,6 +462,7 @@ struct AicpuExecutor { } deferred_release_slot_states[deferred_release_count++] = &slot_state; } + } } tracker.change_core_state(bit_pos); #if PTO2_PROFILING @@ -491,7 +505,7 @@ struct AicpuExecutor { CT == CoreType::AIC ? "AIC" : "AIV", core_id, expected_reg_task_id, mixed_complete ? 
1 : 0 ); cur_thread_completed++; - if (mixed_complete) { + if (mixed_complete && slot_state.payload != nullptr && !slot_state.payload->complete_in_future) { completed_this_turn++; } made_progress = true; @@ -1318,6 +1332,8 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime *runtime, int32_t threa PTO2TaskSlotState *deferred_release_slot_states[256]; int32_t deferred_release_count = 0; + PTO2AsyncWaitList async_wait_list; + bool cores_released = false; #if PTO2_PROFILING @@ -1332,7 +1348,7 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime *runtime, int32_t threa uint64_t _t0_phase = _t0; #endif int32_t task_count = 0; - if (!tracker.has_any_running_cores()) { + if (!tracker.has_any_running_cores() && async_wait_list.count == 0) { bool orch_done = orchestrator_done_; if (orch_done) { // Check for orchestrator fatal error — exit immediately @@ -1386,16 +1402,51 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime *runtime, int32_t threa // Sched time = finish_ts - dispatch_ts; recording finish_ts here at loop start reduces // tail overhead (time from AICore done to AICPU recording finish). 
+ // Invariant: previous iteration fully consumed local_bufs + always_assert(local_bufs[0].count == 0 && local_bufs[1].count == 0); + + // Phase 0: Poll async completion conditions (deferred-completion tasks) + int32_t async_completed_this_turn = 0; + if (async_wait_list.count > 0) { + PTO2AsyncPollResult poll_result = async_wait_list.poll_and_complete( + &rt->scheduler, local_bufs, + deferred_release_slot_states, deferred_release_count, MAX_DEFERRED_RELEASES +#if PTO2_SCHED_PROFILING + , thread_idx +#endif + ); + if (poll_result.error_code != PTO2_ERROR_NONE) { + int32_t failed_task = -1; + if (poll_result.failed_slot_state != nullptr + && poll_result.failed_slot_state->task != nullptr) { + failed_task = static_cast( + poll_result.failed_slot_state->task->task_id.local()); + } + DEV_ERROR("Thread %d: async poll failed for task %d with error code %d", + thread_idx, failed_task, poll_result.error_code); + pto2_record_scheduler_error(header, thread_idx, poll_result.error_code); + emergency_shutdown(runtime); + completed_.store(true, std::memory_order_release); + return -1; + } + async_completed_this_turn = poll_result.completed; + if (async_completed_this_turn > 0) { + made_progress = true; + } + } + // Phase 1: Check running cores for completion, process and move to idle - int32_t completed_this_turn = 0; + int32_t completed_this_turn = async_completed_this_turn; + bool fatal_error = false; + int32_t fatal_error_code = PTO2_ERROR_NONE; // Check AIC running cores bool try_completed = false; if (tracker.has_running_cores()) { try_completed = true; check_running_cores_for_completion( - thread_idx, hank, completed_this_turn, cur_thread_completed, made_progress, - deferred_release_slot_states, deferred_release_count, local_bufs + thread_idx, hank, completed_this_turn, cur_thread_completed, made_progress, fatal_error, + fatal_error_code, deferred_release_slot_states, deferred_release_count, local_bufs, async_wait_list #if PTO2_PROFILING , profiling_enabled, 
phase_complete_count @@ -1406,14 +1457,22 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime *runtime, int32_t threa fanin_edges_total, fanin_max_degree, sched_complete_perf_cycle #endif ); + if (fatal_error) { + DEV_ERROR("Thread %d: async registration failed with error code %d", + thread_idx, fatal_error_code); + pto2_record_scheduler_error(header, thread_idx, fatal_error_code); + emergency_shutdown(runtime); + completed_.store(true, std::memory_order_release); + return -1; + } } // Check AIV running cores if (tracker.has_running_cores()) { try_completed = true; check_running_cores_for_completion( - thread_idx, hank, completed_this_turn, cur_thread_completed, made_progress, - deferred_release_slot_states, deferred_release_count, local_bufs + thread_idx, hank, completed_this_turn, cur_thread_completed, made_progress, fatal_error, + fatal_error_code, deferred_release_slot_states, deferred_release_count, local_bufs, async_wait_list #if PTO2_PROFILING , profiling_enabled, phase_complete_count @@ -1424,6 +1483,14 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime *runtime, int32_t threa fanin_edges_total, fanin_max_degree, sched_complete_perf_cycle #endif ); + if (fatal_error) { + DEV_ERROR("Thread %d: async registration failed with error code %d", + thread_idx, fatal_error_code); + pto2_record_scheduler_error(header, thread_idx, fatal_error_code); + emergency_shutdown(runtime); + completed_.store(true, std::memory_order_release); + return -1; + } } if (completed_this_turn > 0) { #if PTO2_SCHED_PROFILING @@ -1645,6 +1712,7 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime *runtime, int32_t threa "PTO2 stall: no progress for %d iterations, completed=%d total=%d (last progress at %d)", idle_iterations, c, task_count, last_progress_count ); + async_wait_list.dump(thread_idx, STALL_DUMP_WAIT_MAX); // Scan all task slots to find truly stuck tasks using scheduler state PTO2SchedulerState *sched = &rt->scheduler; PTO2SharedMemoryHeader 
*sm_header_diag = static_cast(sm_base); @@ -2092,6 +2160,11 @@ int32_t AicpuExecutor::run(Runtime *runtime) { // With multi-ring, slot_states are per-ring inside the scheduler. runtime->set_pto2_slot_states_ptr(nullptr); + // Pass host-side async engine contexts to device PTO2Runtime. + for (int e = 0; e < PTO2_NUM_ASYNC_ENGINES; e++) { + rt->async_context_addrs[e] = runtime->get_async_context_addr(static_cast(e)); + } + orch_func_ = orch_func; orch_bind_runtime_ = bind_runtime_func; orch_args_cached_ = &args; diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp index c07279056..7f14573f3 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp @@ -69,6 +69,48 @@ static uint64_t parse_env_uint64(const char *name, uint64_t min_val, bool requir return static_cast(val); } +static void init_optional_async_context(Runtime *runtime, PTO2AsyncCapability capability) { + if (runtime == nullptr || runtime->host_api.init_async_context == nullptr) { + return; + } + + uint64_t addr = 0; + PTO2AsyncContextInitStatus status = runtime->host_api.init_async_context(capability, &addr); + PTO2AsyncEngine engine = pto2_async_capability_default_engine(capability); + + if (status == PTO2AsyncContextInitStatus::READY) { + runtime->set_async_context_addr(engine, addr); + LOG_INFO( + "%s async context initialized via platform backend: addr=0x%lx", + pto2_async_capability_name(capability), static_cast(addr) + ); + return; + } + + runtime->set_async_context_addr(engine, 0); + if (status == PTO2AsyncContextInitStatus::ERROR) { + LOG_WARN( + "%s async context initialization failed, continuing without async backend support", + pto2_async_capability_name(capability) + ); + } +} + +static void destroy_optional_async_context(Runtime *runtime, PTO2AsyncCapability capability) { + if (runtime == nullptr || 
runtime->host_api.destroy_async_context == nullptr) { + return; + } + + PTO2AsyncEngine engine = pto2_async_capability_default_engine(capability); + uint64_t addr = runtime->get_async_context_addr(engine); + if (addr == 0) { + return; + } + + runtime->host_api.destroy_async_context(capability, addr); + runtime->set_async_context_addr(engine, 0); +} + /** * Initialize a pre-allocated runtime for device orchestration. * @@ -135,6 +177,14 @@ extern "C" int init_runtime_impl(Runtime *runtime, const ChipCallable *callable, int64_t t_args_start = _now_ms(); for (int i = 0; i < tensor_count; i++) { ContinuousTensor t = orch_args->tensor(i); + if (t.is_device_resident()) { + // External/bootstrap-provided device buffers are already valid in the + // target chip context, so runtime_maker must preserve the pointer + // instead of allocating/copying a second device buffer. + LOG_INFO(" Tensor %d: reusing device-resident pointer %p (%zu bytes)", i, t.data_as(), t.nbytes()); + device_args.add_tensor(t); + continue; + } void *host_ptr = reinterpret_cast(static_cast(t.data)); size_t size = static_cast(t.nbytes()); @@ -255,6 +305,8 @@ extern "C" int init_runtime_impl(Runtime *runtime, const ChipCallable *callable, runtime->set_pto2_gm_sm_ptr(sm_ptr); runtime->record_tensor_pair(nullptr, sm_ptr, static_cast(sm_size)); + init_optional_async_context(runtime, PTO2AsyncCapability::REMOTE_COPY); + // Set up device orchestration state runtime->set_orch_built_on_host(false); runtime->set_orch_args(device_args); @@ -356,6 +408,7 @@ extern "C" int validate_runtime_impl(Runtime *runtime) { // Cleanup device tensors LOG_INFO("=== Cleaning Up ==="); + destroy_optional_async_context(runtime, PTO2AsyncCapability::REMOTE_COPY); for (int i = 0; i < tensor_pair_count; i++) { if (tensor_pairs[i].dev_ptr != nullptr) { runtime->host_api.device_free(tensor_pairs[i].dev_ptr); diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h 
b/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h index cf752ef2d..91895ac20 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h @@ -70,6 +70,18 @@ inline Tensor make_tensor_external( }; } +/** + * Create a TensorCreateInfo for runtime-allocated output memory. + */ +inline TensorCreateInfo make_tensor( + const uint32_t shapes[], uint32_t ndims, DataType dtype = DataType::FLOAT32, bool manual_dep = false, + int32_t version = 0 +) { + TensorCreateInfo info(shapes, ndims, dtype, manual_dep); + info.version = version; + return info; +} + // Convert ContinuousTensor to Tensor static_assert( CONTINUOUS_TENSOR_MAX_DIMS == RUNTIME_MAX_TENSOR_DIMS, "ContinuousTensor and runtime max dims must match" @@ -115,6 +127,8 @@ void pto2_framework_bind_runtime(PTO2Runtime *rt); */ typedef struct PTO2RuntimeOps { TaskOutputTensors (*submit_task)(PTO2Runtime *rt, const MixedKernels &mixed_kernels, const Arg &args); + uint64_t (*get_async_context)(PTO2Runtime *rt, PTO2AsyncEngine engine); + uint64_t (*alloc_cq)(PTO2Runtime *rt); void (*scope_begin)(PTO2Runtime *rt); void (*scope_end)(PTO2Runtime *rt); void (*orchestration_done)(PTO2Runtime *rt); @@ -205,6 +219,118 @@ static inline TaskOutputTensors pto2_rt_submit_aiv_task(int32_t kernel_id, const return rt->ops->submit_task(rt, mk, args); } +static inline uint64_t pto2_rt_get_async_context(PTO2AsyncEngine engine) { + PTO2Runtime *rt = pto2_current_runtime(); + return rt->ops->get_async_context(rt, engine); +} + +static inline uint64_t pto2_rt_get_async_context(PTO2Runtime *rt, PTO2AsyncEngine engine) { + return rt->ops->get_async_context(rt, engine); +} + +static inline uint64_t pto2_rt_get_async_context(PTO2AsyncCapability capability) { + return pto2_rt_get_async_context(pto2_async_capability_default_engine(capability)); +} + +static inline uint64_t 
pto2_rt_get_async_context(PTO2Runtime *rt, PTO2AsyncCapability capability) { + return pto2_rt_get_async_context(rt, pto2_async_capability_default_engine(capability)); +} + +static inline uint64_t pto2_rt_get_remote_copy_context() { + return pto2_rt_get_async_context(PTO2AsyncCapability::REMOTE_COPY); +} + +static inline uint64_t pto2_rt_get_remote_copy_context(PTO2Runtime *rt) { + return pto2_rt_get_async_context(rt, PTO2AsyncCapability::REMOTE_COPY); +} + +// Compatibility alias for existing A2/A3 orchestration code. New code should +// use pto2_rt_get_remote_copy_context() so the orchestration layer does not +// depend on a specific transport backend name such as SDMA. +static inline uint64_t pto2_rt_get_sdma_context() { return pto2_rt_get_remote_copy_context(); } + +static inline uint64_t pto2_rt_get_sdma_context(PTO2Runtime *rt) { + return pto2_rt_get_remote_copy_context(rt); +} + +static inline uint64_t pto2_rt_alloc_cq() { + PTO2Runtime *rt = pto2_current_runtime(); + return rt->ops->alloc_cq(rt); +} + +static inline uint64_t pto2_rt_alloc_cq(PTO2Runtime *rt) { return rt->ops->alloc_cq(rt); } + +static inline TaskOutputTensors pto2_rt_submit_aiv_task_deferred(int32_t kernel_id, Arg &args, uint64_t cq_addr) { + args.complete_in_future = true; + args.cq_addr = cq_addr; + args.add_scalar(cq_addr); + return pto2_rt_submit_aiv_task(kernel_id, args); +} + +static inline TaskOutputTensors +pto2_rt_submit_aiv_task_deferred(PTO2Runtime *rt, int32_t kernel_id, Arg &args, uint64_t cq_addr) { + args.complete_in_future = true; + args.cq_addr = cq_addr; + args.add_scalar(cq_addr); + MixedKernels mk; + mk.aiv0_kernel_id = kernel_id; + return rt->ops->submit_task(rt, mk, args); +} + +static inline TaskOutputTensors pto2_rt_submit_aic_task_deferred(int32_t kernel_id, Arg &args, uint64_t cq_addr) { + args.complete_in_future = true; + args.cq_addr = cq_addr; + args.add_scalar(cq_addr); + PTO2Runtime *rt = pto2_current_runtime(); + MixedKernels mk; + mk.aic_kernel_id = 
kernel_id; + return rt->ops->submit_task(rt, mk, args); +} + +static inline TaskOutputTensors pto2_rt_submit_task_deferred( + const MixedKernels &mixed_kernels, Arg &args, uint64_t cq_addr +) { + args.complete_in_future = true; + args.cq_addr = cq_addr; + args.add_scalar(cq_addr); + return pto2_rt_submit_task(mixed_kernels, args); +} + +/** + * Submit a notification-wait deferred task and return a dependency token. + * + * Encapsulates the boilerplate for creating a NotifyWait task: + * 1. Allocate a CQ + * 2. Create a 1-element dummy output tensor (dependency token) + * 3. Submit a deferred AIV task with (counter_addr, expected_value, cq_addr) + * + * The returned token tensor should be added as an input to any downstream + * task that depends on the notification completing. + * + * @param kernel_id func_id of the NotifyWait kernel + * @param counter_addr GM address of the notification counter (int32*) + * @param expected_value threshold: task completes when *counter >= expected + * @return dependency token tensor (add as input to downstream tasks) + */ +static inline Tensor pto2_rt_submit_notification_wait_task( + int32_t kernel_id, + uint64_t counter_addr, + uint32_t expected_value) { + uint64_t cq_addr = pto2_rt_alloc_cq(); + always_assert(cq_addr != 0 && "pto2_rt_submit_notification_wait_task: failed to allocate CQ"); + + uint32_t dummy_shape[1] = { 1 }; + TensorCreateInfo token_info = make_tensor(dummy_shape, 1, DataType::INT32); + + Arg params; + params.add_output(token_info); + params.add_scalar(counter_addr); + params.add_scalar(static_cast(expected_value)); + TaskOutputTensors outputs = pto2_rt_submit_aiv_task_deferred(kernel_id, params, cq_addr); + + return outputs.get_ref(0); +} + static inline void pto2_rt_scope_begin() { PTO2Runtime *rt = pto2_current_runtime(); rt->ops->scope_begin(rt); diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_kernel_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_kernel_api.h new 
file mode 100644 index 000000000..672ed37c8 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_kernel_api.h @@ -0,0 +1,22 @@ +/** + * PTO Async Kernel API — unified device-facing async helper entry. + * + * This header intentionally aggregates only AICore-side async helpers: + * - completion queue data layout and write helpers + * - send queue/session helpers + * - notify helpers + * + * It does not include scheduler/AICPU-side polling logic such as + * pto_async_wait.h. That boundary is kept explicit so device inline APIs + * and runtime completion management do not get mixed into one layer. + */ + +#ifndef PTO_ASYNC_KERNEL_API_H +#define PTO_ASYNC_KERNEL_API_H + +#include "pto_cq_types.h" +#include "pto_cq_kernel_api.h" +#include "pto_sq_kernel_api.h" +#include "pto_notify_kernel_api.h" + +#endif // PTO_ASYNC_KERNEL_API_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_wait.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_wait.h new file mode 100644 index 000000000..2d661a8f8 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_wait.h @@ -0,0 +1,359 @@ +/** + * PTO Runtime2 - Async Completion Wait List + * + * Lightweight watch-list abstraction for deferred task completion. + * + * The scheduler polls two logical protocols described in docs/runtime_async.md: + * - CQ protocol: poll *counter_addr >= expected_value (unified COUNTER type) + * - Notification protocol: poll a GM counter until it reaches expected_value + * + * All completion conditions use a single COUNTER type. Hardware event flags + * (e.g. SDMA completion flags) are the special case where expected_value = 1. + * + * The scheduler polls this list each iteration (Phase 0) and triggers + * on_mixed_task_complete for tasks whose conditions are all satisfied. 
+ * + * Design reference: docs/runtime_async.md + */ + +#ifndef PTO_ASYNC_WAIT_H +#define PTO_ASYNC_WAIT_H + +#include +#include "pto_runtime2_types.h" +#include "pto_scheduler.h" + +extern void cache_invalidate_range(const void* addr, size_t size); + +inline constexpr int32_t PTO2_MAX_ASYNC_WAITS = 64; + +enum class PTO2CompletionPollState : uint8_t { + PENDING = 0, + READY = 1, + FAILED = 2, +}; + +struct PTO2CompletionPollResult { + PTO2CompletionPollState state{PTO2CompletionPollState::PENDING}; + int32_t error_code{PTO2_ERROR_NONE}; +}; + +struct PTO2CompletionCondition { + PTO2AsyncEngine engine{PTO2_ASYNC_ENGINE_SDMA}; + bool satisfied{false}; + volatile uint32_t* counter_addr{nullptr}; + uint32_t expected_value{0}; + + PTO2CompletionPollResult test() const { + if (satisfied) { + return {PTO2CompletionPollState::READY, PTO2_ERROR_NONE}; + } + if (counter_addr == nullptr) { + return {PTO2CompletionPollState::FAILED, PTO2_ERROR_ASYNC_COMPLETION_INVALID}; + } + return {*counter_addr >= expected_value ? 
PTO2CompletionPollState::READY + : PTO2CompletionPollState::PENDING, + PTO2_ERROR_NONE}; + } +}; + +template +#if PTO2_SCHED_PROFILING +static inline PTO2CompletionStats pto2_complete_task( +#else +static inline void pto2_complete_task( +#endif + PTO2SchedulerState* sched, + PTO2TaskSlotState& slot_state, + PTO2LocalReadyBuffer* local_bufs, + PTO2TaskSlotState** deferred_release_slot_states, + int32_t& deferred_release_count +#if PTO2_SCHED_PROFILING + , int thread_idx +#endif + ) { +#if PTO2_SCHED_PROFILING + PTO2CompletionStats stats = sched->on_mixed_task_complete(slot_state, thread_idx, local_bufs); +#else + sched->on_mixed_task_complete(slot_state, local_bufs); +#endif + deferred_release_slot_states[deferred_release_count++] = &slot_state; +#if PTO2_SCHED_PROFILING + return stats; +#endif +} + +// ============================================================================= +// Async Wait Entry (one per deferred task) +// ============================================================================= + +struct PTO2AsyncWaitEntry { + PTO2TaskSlotState* slot_state{nullptr}; + PTO2CompletionCondition conditions[PTO2_MAX_COMPLETIONS_PER_TASK]; + int32_t condition_count{0}; + int32_t waiting_completion_count{0}; +}; + +struct PTO2AsyncPollResult { + int32_t completed{0}; + int32_t error_code{PTO2_ERROR_NONE}; + PTO2TaskSlotState* failed_slot_state{nullptr}; +}; + +// ============================================================================= +// Name helpers (used by dump / diagnostics) +// ============================================================================= + +inline const char* pto2_async_engine_name(PTO2AsyncEngine engine) { + switch (engine) { + case PTO2_ASYNC_ENGINE_SDMA: return "SDMA"; + case PTO2_ASYNC_ENGINE_ROCE: return "ROCE"; + case PTO2_ASYNC_ENGINE_URMA: return "URMA"; + case PTO2_ASYNC_ENGINE_CCU: return "CCU"; + default: return "UNKNOWN"; + } +} + +// ============================================================================= +// Async 
Wait List (managed by scheduler thread) +// ============================================================================= + +struct PTO2AsyncWaitList { + PTO2AsyncWaitEntry entries[PTO2_MAX_ASYNC_WAITS]; + int32_t count{0}; + + /** + * Find or create an entry for the given slot_state. + * Returns pointer to the entry, or nullptr if full. + */ + PTO2AsyncWaitEntry* find_or_create(PTO2TaskSlotState* slot_state) { + for (int32_t i = 0; i < count; i++) { + if (entries[i].slot_state == slot_state) { + return &entries[i]; + } + } + if (count >= PTO2_MAX_ASYNC_WAITS) { + return nullptr; + } + PTO2AsyncWaitEntry& e = entries[count++]; + e.slot_state = slot_state; + e.condition_count = 0; + e.waiting_completion_count = 0; + return &e; + } + + bool add_counter(PTO2TaskSlotState* slot_state, + volatile uint32_t* counter_addr, + uint32_t expected_value, + PTO2AsyncEngine engine = PTO2_ASYNC_ENGINE_SDMA) { + PTO2AsyncWaitEntry* entry = find_or_create(slot_state); + if (!entry || counter_addr == nullptr + || entry->condition_count >= PTO2_MAX_COMPLETIONS_PER_TASK) { + return false; + } + PTO2CompletionCondition& cond = entry->conditions[entry->condition_count++]; + cond.engine = engine; + cond.satisfied = false; + cond.counter_addr = counter_addr; + cond.expected_value = expected_value; + entry->waiting_completion_count++; + return true; + } + + /** + * Poll all entries. For each satisfied condition, decrement waiting_completion_count. + * When an entry's count reaches zero, call on_mixed_task_complete and add to + * deferred_release. Remove completed entries by swap-with-last. + * + * Returns the number of tasks that completed this call. 
+ */ + template + PTO2AsyncPollResult poll_and_complete( + PTO2SchedulerState* sched, + PTO2LocalReadyBuffer* local_bufs, + PTO2TaskSlotState** deferred_release_slot_states, + int32_t& deferred_release_count, + int32_t deferred_release_capacity +#if PTO2_SCHED_PROFILING + , int thread_idx +#endif + ) { + PTO2AsyncPollResult result; + for (int32_t i = count - 1; i >= 0; --i) { + PTO2AsyncWaitEntry& entry = entries[i]; + + for (int32_t c = 0; c < entry.condition_count; c++) { + PTO2CompletionCondition& cond = entry.conditions[c]; + if (!cond.satisfied) { + // All current counter writers (SDMA engine flags, TNOTIFY + // RDMA atomics) bypass AICPU data cache. Invalidation is + // needed so the poll reads the true GM value. For any + // hypothetical CPU-written counter this is a harmless no-op. + if (cond.counter_addr) { + cache_invalidate_range( + reinterpret_cast(const_cast(cond.counter_addr)), + sizeof(uint32_t)); + } + PTO2CompletionPollResult poll = cond.test(); + if (poll.state == PTO2CompletionPollState::FAILED) { + result.error_code = poll.error_code; + result.failed_slot_state = entry.slot_state; + return result; + } + if (poll.state == PTO2CompletionPollState::READY) { + cond.satisfied = true; + entry.waiting_completion_count--; + } + } + } + + if (entry.waiting_completion_count <= 0) { + if (deferred_release_count >= deferred_release_capacity) { + result.error_code = PTO2_ERROR_ASYNC_WAIT_OVERFLOW; + result.failed_slot_state = entry.slot_state; + return result; + } +#if PTO2_SCHED_PROFILING + auto stats = pto2_complete_task( + sched, + *entry.slot_state, + local_bufs, + deferred_release_slot_states, + deferred_release_count, + thread_idx + ); + (void)stats; +#else + pto2_complete_task( + sched, + *entry.slot_state, + local_bufs, + deferred_release_slot_states, + deferred_release_count + ); +#endif + result.completed++; + + // Swap-remove: replace with last entry + int32_t last = count - 1; + if (i != last) { + entries[i] = entries[last]; + } + count = last; + 
} + } + return result; + } + /** + * Register deferred completions for a task from its CQ. + * + * Reads the kernel-written PTO2CompletionQueue and registers each entry + * as a COUNTER wait condition. Returns true when at least one condition + * was registered (task is now tracked by the wait list). On error, + * error_code is set to a non-zero PTO2_ERROR_* value. + */ + bool register_deferred(PTO2TaskSlotState& slot_state, + int32_t thread_idx, int32_t& error_code) { + (void)thread_idx; + error_code = PTO2_ERROR_NONE; + PTO2TaskPayload* payload = slot_state.payload; + if (payload == nullptr || !payload->complete_in_future) return false; + + if (payload->cq_addr == 0) { +#ifdef DEV_ERROR + DEV_ERROR("Thread %d: complete_in_future=true but no CQ entries for task %d", + thread_idx, + static_cast(slot_state.task->task_id.local())); +#endif + error_code = PTO2_ERROR_ASYNC_COMPLETION_INVALID; + return false; + } + + volatile PTO2CompletionQueue* cq = reinterpret_cast( + static_cast(payload->cq_addr)); + // AICore kernel flushes its cache (dcci) before returning, but the + // AICPU may still hold a stale cache line for this CQ. Invalidate + // before reading so we see the kernel's writes. 
+ cache_invalidate_range( + const_cast(reinterpret_cast(cq)), + sizeof(PTO2CompletionQueue)); + int32_t cq_count = cq->count; + if (cq_count <= 0) { +#ifdef DEV_ALWAYS + DEV_ALWAYS("Thread %d: task %d CQ addr=0x%lx count=0, completing immediately", + thread_idx, + static_cast(slot_state.task->task_id.local()), + payload->cq_addr); +#endif + return false; + } + if (cq_count > PTO2_CQ_MAX_ENTRIES) { +#ifdef DEV_ERROR + DEV_ERROR("Thread %d: CQ count=%d exceeds max %d for task %d", + thread_idx, cq_count, PTO2_CQ_MAX_ENTRIES, + static_cast(slot_state.task->task_id.local())); +#endif + error_code = PTO2_ERROR_ASYNC_COMPLETION_INVALID; + return false; + } +#ifdef DEV_ALWAYS + DEV_ALWAYS("Thread %d: task %d reading CQ addr=0x%lx count=%d", + thread_idx, static_cast(slot_state.task->task_id.local()), + payload->cq_addr, cq_count); +#endif + for (int32_t i = 0; i < cq_count; ++i) { + const volatile PTO2CQEntry& entry = cq->entries[i]; +#ifdef DEV_ALWAYS + DEV_ALWAYS("Thread %d: task %d CQ[%d] engine=%s(%d) addr=0x%lx expected=%u", + thread_idx, + static_cast(slot_state.task->task_id.local()), + i, + pto2_async_engine_name(static_cast(entry.engine)), + static_cast(entry.engine), + entry.addr, + entry.expected_value); +#endif + volatile uint32_t* counter_addr = reinterpret_cast( + static_cast(entry.addr)); + if (!add_counter(&slot_state, counter_addr, entry.expected_value, + static_cast(entry.engine))) { + error_code = PTO2_ERROR_ASYNC_REGISTRATION_FAILED; + return false; + } + } + return true; + } + + /** + * Dump wait list state for stall diagnostics. + */ + void dump(int32_t thread_idx, int32_t max_entries = 4) const { +#ifdef DEV_ALWAYS + DEV_ALWAYS("Thread %d: async_wait_list pending entries=%d", thread_idx, count); + int32_t dump_count = count < max_entries ? 
count : max_entries; + for (int32_t i = 0; i < dump_count; ++i) { + const PTO2AsyncWaitEntry& entry = entries[i]; + int32_t task_id = -1; + if (entry.slot_state != nullptr && entry.slot_state->task != nullptr) { + task_id = static_cast(entry.slot_state->task->task_id.local()); + } + DEV_ALWAYS("Thread %d: async_wait[%d] task=%d waiting=%d conditions=%d", + thread_idx, i, task_id, entry.waiting_completion_count, entry.condition_count); + for (int32_t c = 0; c < entry.condition_count; ++c) { + const PTO2CompletionCondition& cond = entry.conditions[c]; + uint32_t value = cond.counter_addr == nullptr ? 0 : *cond.counter_addr; + DEV_ALWAYS("Thread %d: cond[%d] engine=%s satisfied=%d counter_addr=0x%lx value=%u expected=%u", + thread_idx, c, pto2_async_engine_name(cond.engine), + cond.satisfied ? 1 : 0, + static_cast(reinterpret_cast(cond.counter_addr)), + value, cond.expected_value); + } + } +#else + (void)thread_idx; + (void)max_entries; +#endif + } +}; + +#endif // PTO_ASYNC_WAIT_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_kernel_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_kernel_api.h new file mode 100644 index 000000000..03d594ac5 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_kernel_api.h @@ -0,0 +1,136 @@ +/** + * PTO CQ Kernel API — inline functions for AICore kernels. + * + * These are NOT AICPU function calls. They are structured GM writes + * that the AICPU scheduler reads after the kernel returns. + * + * All overloads follow the (ENGINE, QUEUE, data...) parameter convention, + * symmetric with pto2_send_request_entry(ENGINE, SQ_ID, desc) in the SQ API. 
+ * + * Usage in kernel code: + * + * auto* cq = pto2_cq_get(args[CQ_ARG_IDX]); + * pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, tag); // flag: expected=1 + * pto2_cq_flush(); + */ + +#ifndef PTO_CQ_KERNEL_API_H +#define PTO_CQ_KERNEL_API_H + +#include "pto_cq_types.h" + +// Requires __gm__ and __aicore__ to be defined before including this header. +// Kernel sources should define them (or include PTO-ISA headers) first. + +// Unified engine constants — shared by SQ and CQ APIs. +// Must match PTO2AsyncEngine in pto_types.h. +#define PTO2_ENGINE_SDMA 0 +#define PTO2_ENGINE_ROCE 1 +#define PTO2_ENGINE_URMA 2 +#define PTO2_ENGINE_CCU 3 + +// Completion type constants (must match PTO2CompletionType in pto_types.h) +#define PTO2_CQ_COMPLETION_COUNTER 0 + +inline __aicore__ void pto2_cq_writeback_gm_line(volatile __gm__ void* addr) { + __gm__ int32_t* gm_addr = (__gm__ int32_t*)addr; +#if defined(SINGLE_CACHE_LINE) && defined(CACHELINE_OUT) + dcci(gm_addr, SINGLE_CACHE_LINE, CACHELINE_OUT); +#elif defined(SINGLE_CACHE_LINE) + dcci(gm_addr, SINGLE_CACHE_LINE); +#endif +#if defined(DSB_DDR) + dsb(DSB_DDR); +#endif +} + +/** + * Obtain the completion queue pointer from a kernel scalar arg. + */ +inline __aicore__ volatile __gm__ PTO2CompletionQueue* pto2_cq_get(uint64_t addr) { + return reinterpret_cast( + static_cast(addr)); +} + +/** + * Reset the CQ header before the kernel appends completion entries. + * + * Runtime-owned CQ buffers may be reused across tasks, so kernels should + * explicitly republish an empty header before the first append. + */ +inline __aicore__ void pto2_cq_reset(volatile __gm__ PTO2CompletionQueue* cq) { + // Republish the header line even when the queue was already zeroed in a + // reused runtime buffer. Some hardware paths were observed to require an + // explicit header-state transition before the subsequent count increment + // became visible to the AICPU scheduler. 
+ cq->count = -1; + pto2_cq_writeback_gm_line(&cq->count); + cq->count = 0; + pto2_cq_writeback_gm_line(&cq->count); +} + +/** + * Register one expected completion condition in the CQ. + * + * All completion conditions are COUNTER type: the scheduler polls + * *addr >= expected_value. Hardware flags (SDMA event flags) are + * the special case where expected_value = 1 (flag goes 0 → non-zero). + * + * Parameter order: (ENGINE, QUEUE, addr, expected) — symmetric with SQ API. + * Each call appends an entry and increments cq->count. + * The caller must ensure total calls per task <= PTO2_CQ_MAX_ENTRIES. + */ +inline __aicore__ void pto2_save_expected_completion( + uint32_t engine, + volatile __gm__ PTO2CompletionQueue* cq, + uint64_t addr, + uint32_t expected_value) +{ + int32_t idx = cq->count; + volatile __gm__ PTO2CQEntry* entry = + const_cast(&cq->entries[idx]); + entry->engine = engine; + entry->completion_type = PTO2_CQ_COMPLETION_COUNTER; + entry->addr = addr; + entry->expected_value = expected_value; + pto2_cq_writeback_gm_line(entry); + + cq->count = idx + 1; + pto2_cq_writeback_gm_line(&cq->count); +} + +/** + * Simplified overload for hardware flags: (ENGINE, CQ, tag). + * + * Registers a COUNTER condition with expected_value=1. + * Equivalent to polling *tag_addr >= 1 (i.e. flag != 0). + * Symmetric with pto2_send_request_entry(ENGINE, SQ_ID, desc). + */ +inline __aicore__ void pto2_save_expected_completion( + uint32_t engine, + volatile __gm__ PTO2CompletionQueue* cq, + uint64_t tag) +{ + pto2_save_expected_completion(engine, cq, tag, 1); +} + +/** + * Final flush before kernel returns. Ensures all CQ writes + * are visible to the AICPU scheduler. + * + * Uses CCE compiler built-in enum constants (cache_line_t, dcci_dst_t, + * dsb_mode_t, pipe_t) which are available when compiling for AICore + * via the bisheng/CCE toolchain. Previous #if-defined guards broke + * because these are C++ enums, not preprocessor macros. 
+ */ +inline __aicore__ void pto2_cq_flush() { + pipe_barrier(PIPE_ALL); +} + +inline __aicore__ void pto2_cq_flush(volatile __gm__ PTO2CompletionQueue* cq) { + dcci((__gm__ int32_t*)cq, cache_line_t::ENTIRE_DATA_CACHE, dcci_dst_t::CACHELINE_OUT); + dsb(DSB_DDR); + pipe_barrier(PIPE_ALL); +} + +#endif // PTO_CQ_KERNEL_API_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_types.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_types.h new file mode 100644 index 000000000..e0c571a8a --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_types.h @@ -0,0 +1,47 @@ +/** + * PTO Completion Queue Types — shared between AICore kernels and AICPU runtime. + * + * This header must remain simple and C-compatible. AICore compilation + * environments have restricted standard library access. + */ + +#ifndef PTO_CQ_TYPES_H +#define PTO_CQ_TYPES_H + +#include + +#define PTO2_CQ_MAX_ENTRIES 64 + +/** + * Single CQ entry written by a kernel via pto2_save_expected_completion(). + * The scheduler reads these after the worker returns. + */ +struct PTO2CQEntry { + uint32_t engine; // PTO2AsyncEngine value + int32_t completion_type; // PTO2CompletionType value + uint64_t addr; // completion token (flag/handle/counter GM address) + uint32_t expected_value; // for COUNTER completions + uint32_t _pad; +}; + +/** + * Per-task completion queue. + * + * Allocated by the runtime and passed to the kernel as a scalar arg. + * The kernel calls pto2_save_expected_completion() to append entries + * and increment `count`. The scheduler reads the CQ after all + * subtasks have returned and creates completion conditions accordingly. + * + * Memory ordering contract: + * - Kernel writes entries[i] fields BEFORE incrementing count. + * - Kernel flushes caches (dcci+dsb on HW) before returning. + * - Scheduler reads only after detecting task_status==0, + * which implies all kernel writes are visible. 
+ */ +struct PTO2CompletionQueue { + volatile int32_t count; // entries written so far (kernel increments) + int32_t _pad; + PTO2CQEntry entries[PTO2_CQ_MAX_ENTRIES]; +}; + +#endif // PTO_CQ_TYPES_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_notify_kernel_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_notify_kernel_api.h new file mode 100644 index 000000000..6d5e6a5e1 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_notify_kernel_api.h @@ -0,0 +1,41 @@ +/** + * PTO Notify Kernel API — notification counter abstraction for AICore kernels. + * + * This wraps PTO-ISA TNOTIFY and maps the local counter wait condition onto + * the runtime's existing COUNTER deferred-completion path. + * + * Requires: + * - PTO-ISA headers included before this header + * - __gm__ and __aicore__ defined before this header + */ + +#ifndef PTO_NOTIFY_KERNEL_API_H +#define PTO_NOTIFY_KERNEL_API_H + +#include "pto_cq_kernel_api.h" +#include "aicore/pto_async_backend_kernel.h" + +enum class PTO2NotifyOp : uint32_t { + Set = 0, + AtomicAdd = 1, +}; + +inline __aicore__ void pto2_send_notification( + volatile __gm__ int32_t* remote_counter_addr, + int32_t value = 1, + PTO2NotifyOp op = PTO2NotifyOp::AtomicAdd) +{ + pto2_backend_send_notification(remote_counter_addr, value, static_cast(op)); +} + +inline __aicore__ void pto2_save_expected_notification_counter( + volatile __gm__ PTO2CompletionQueue* cq, + volatile __gm__ int32_t* local_counter_addr, + uint32_t expected_value) +{ + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, + (uint64_t)local_counter_addr, + expected_value); +} + +#endif // PTO_NOTIFY_KERNEL_API_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp index 8085ed63d..4d62cbb3b 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp +++ 
b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp @@ -18,6 +18,7 @@ */ #include "pto_runtime2.h" +#include "pto_async_wait.h" #include #include @@ -45,6 +46,21 @@ static TaskOutputTensors alloc_tensors_impl(PTO2Runtime *rt, const Arg &args) { return pto2_alloc_tensors(&rt->orchestrator, args); } +static uint64_t get_async_context_impl(PTO2Runtime *rt, PTO2AsyncEngine engine) { + if (engine >= PTO2_NUM_ASYNC_ENGINES) return 0; + return rt->async_context_addrs[engine]; +} + +static uint64_t alloc_cq_impl(PTO2Runtime *rt) { + if (!rt->cq_pool || rt->cq_pool_next >= rt->cq_pool_size) { + return 0; + } + int32_t idx = rt->cq_pool_next++; + PTO2CompletionQueue *cq = &rt->cq_pool[idx]; + memset(cq, 0, sizeof(PTO2CompletionQueue)); + return reinterpret_cast(cq); +} + void pto2_rt_scope_begin(PTO2Runtime *rt) { pto2_scope_begin(&rt->orchestrator); } void pto2_rt_scope_end(PTO2Runtime *rt) { pto2_scope_end(&rt->orchestrator); } @@ -181,6 +197,8 @@ void pto2_set_tensor_data( static const PTO2RuntimeOps s_runtime_ops = { .submit_task = submit_task_impl, + .get_async_context = get_async_context_impl, + .alloc_cq = alloc_cq_impl, .scope_begin = pto2_rt_scope_begin, .scope_end = pto2_rt_scope_end, .orchestration_done = pto2_rt_orchestration_done, @@ -259,6 +277,10 @@ PTO2Runtime *pto2_runtime_create_custom( // Connect orchestrator to scheduler (for simulated mode) pto2_orchestrator_set_scheduler(&rt->orchestrator, &rt->scheduler); + rt->cq_pool_size = PTO2_MAX_ASYNC_WAITS; + rt->cq_pool = static_cast(calloc(rt->cq_pool_size, sizeof(PTO2CompletionQueue))); + rt->cq_pool_next = 0; + return rt; } @@ -292,6 +314,10 @@ PTO2Runtime *pto2_runtime_create_from_sm( pto2_orchestrator_set_scheduler(&rt->orchestrator, &rt->scheduler); + rt->cq_pool_size = PTO2_MAX_ASYNC_WAITS; + rt->cq_pool = static_cast(calloc(rt->cq_pool_size, sizeof(PTO2CompletionQueue))); + rt->cq_pool_next = 0; + return rt; } @@ -309,6 +335,7 @@ void pto2_runtime_destroy(PTO2Runtime *rt) { 
pto2_sm_destroy(rt->sm_handle); } + free(rt->cq_pool); free(rt); } diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.h index 779b75143..bbe86340e 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.h @@ -67,6 +67,8 @@ typedef struct PTO2Runtime PTO2Runtime; // forward declare for ops signatures struct PTO2RuntimeOps { TaskOutputTensors (*submit_task)(PTO2Runtime *rt, const MixedKernels &mixed_kernels, const Arg &args); + uint64_t (*get_async_context)(PTO2Runtime *rt, PTO2AsyncEngine engine); + uint64_t (*alloc_cq)(PTO2Runtime *rt); void (*scope_begin)(PTO2Runtime *rt); void (*scope_end)(PTO2Runtime *rt); void (*orchestration_done)(PTO2Runtime *rt); @@ -111,6 +113,14 @@ struct PTO2Runtime { // Mode PTO2RuntimeMode mode; + // Per-engine async context addresses (0 = not available). + uint64_t async_context_addrs[PTO2_NUM_ASYNC_ENGINES]{}; + + // Per-task completion queues for deferred completion. 
+ PTO2CompletionQueue *cq_pool{nullptr}; + int32_t cq_pool_size{0}; + int32_t cq_pool_next{0}; + // Statistics int64_t total_cycles; }; diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h index 247f09fed..d5d8c72da 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h @@ -83,6 +83,9 @@ // Scheduler errors (100+): detected in scheduler threads #define PTO2_ERROR_SCHEDULER_TIMEOUT 100 +#define PTO2_ERROR_ASYNC_COMPLETION_INVALID 101 +#define PTO2_ERROR_ASYNC_WAIT_OVERFLOW 102 +#define PTO2_ERROR_ASYNC_REGISTRATION_FAILED 103 // ============================================================================= // Configuration Constants @@ -367,6 +370,10 @@ struct PTO2TaskPayload { Tensor tensors[MAX_TENSOR_ARGS]; // === Cache lines 35-50 (1024B) — scalars === uint64_t scalars[MAX_SCALAR_ARGS]; + // Async/deferred-completion metadata. Kept after tensors/scalars so the + // existing hot-path tensor/scalar layout stays unchanged. + bool complete_in_future{false}; + uint64_t cq_addr{0}; // Layout verification (size checks that don't need offsetof). 
static_assert(sizeof(Tensor) == 128, "Tensor must be 2 cache lines"); @@ -386,6 +393,8 @@ struct PTO2TaskPayload { init(const Arg &args, TaskOutputTensors &result, void *base_addr, uint64_t offsets[], uint64_t buffer_sizes[]) { tensor_count = args.tensor_count(); scalar_count = args.scalar_count(); + complete_in_future = args.complete_in_future; + cq_addr = args.cq_addr; // int32_t out_idx = 0; for (int32_t i = 0; i < args.tensor_count(); i++) { diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h index 0c3f5a0ff..e7ddc3516 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h @@ -36,6 +36,8 @@ #include "pto_runtime2_types.h" #include "pto_shared_memory.h" +struct PTO2SchedulerState; + #if PTO2_SCHED_PROFILING #include "aicpu/device_time.h" #define PTO2_SCHED_CYCLE_START() uint64_t _st0 = get_sys_cnt_aicpu(), _st1 @@ -566,8 +568,6 @@ struct PTO2SchedulerState { int32_t new_refcount = slot_state.fanin_refcount.fetch_add(1, std::memory_order_acq_rel) + 1; if (new_refcount == slot_state.fanin_count) { - // Local-first: try per-CoreType thread-local buffer before global queue - // Route by active_mask: AIC-containing tasks → buf[0], AIV-only → buf[1] PTO2ResourceShape shape = pto2_active_mask_to_shape(slot_state.active_mask); if (!local_bufs || !local_bufs[static_cast(shape)].try_push(&slot_state)) { ready_queues[static_cast(shape)].push(&slot_state); diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_shared_memory.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_shared_memory.h index b7d75180f..1829e4dba 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_shared_memory.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_shared_memory.h @@ -124,6 +124,15 @@ static_assert( "PTO2SharedMemoryHeader must be aligned to cache line 
(PTO2_ALIGN_SIZE)" ); +static inline void pto2_record_scheduler_error( + PTO2SharedMemoryHeader* header, int32_t thread_idx, int32_t error_code) { + if (header == nullptr) return; + header->sched_error_bitmap.fetch_or(1u << thread_idx, std::memory_order_acq_rel); + header->sched_error_code.store(error_code, std::memory_order_release); + header->sched_error_thread.store(thread_idx, std::memory_order_release); + header->orch_error_code.store(error_code, std::memory_order_release); +} + // ============================================================================= // Shared Memory Handle // ============================================================================= diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_sq_kernel_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_sq_kernel_api.h new file mode 100644 index 000000000..1d4d657a6 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_sq_kernel_api.h @@ -0,0 +1,201 @@ +/** + * PTO SQ Kernel API — generic async remote-copy abstraction for AICore kernels. 
+ * + * Two usage paths, both ending with CQ registration: + * + * Path 1 — High-level (send_request_entry, one-stop): + * + * auto desc = pto2_remote_copy_descriptor(dst, src, scratch, context); + * uint64_t tag = pto2_send_request_entry(PTO2_ENGINE_SDMA, sq_id, desc); + * pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, tag); + * pto2_cq_flush(); + * + * Path 2 — Low-level (sq_open + direct ISA instruction): + * + * auto session = pto2_sq_open(PTO2_ENGINE_SDMA, sq_id, scratch, context); + * PTO2AsyncEvent event = pto2_backend_remote_copy_put(dst, src, session); + * pto2_save_expected_completion(cq, event); + * pto2_cq_flush(); + * + * Layering: + * send_request_entry = sq_open + ISA instruction (syntactic sugar) + * sq_open = session management (BuildAsyncSession wrapper) + * + * Requires: + * - PTO-ISA headers included before this header + * - __gm__ and __aicore__ defined before this header + * - HW build only (uses PTO-ISA async instructions) + */ + +#ifndef PTO_SQ_KERNEL_API_H +#define PTO_SQ_KERNEL_API_H + +#include "pto_cq_types.h" +#include "pto_cq_kernel_api.h" +#include "aicore/pto_async_backend_kernel.h" + +// SQ engine types — aliases for the unified PTO2_ENGINE_* constants +#define PTO2_SQ_ENGINE_SDMA PTO2_ENGINE_SDMA +// #define PTO2_SQ_ENGINE_CCU PTO2_ENGINE_CCU // future +// #define PTO2_SQ_ENGINE_URMA PTO2_ENGINE_URMA // future + +#define PTO2_SQ_ID_AUTO UINT32_MAX + +using PTO2AsyncSession = PTO2BackendAsyncSession; +using PTO2AsyncEvent = PTO2BackendAsyncEvent; + +struct PTO2RemoteCopyBaseConfig { + uint32_t block_bytes{0}; + uint32_t block_offset{0}; + uint32_t repeat_times{1}; +}; + +// ============================================================================ +// pto2_sq_open — build async session for a hardware engine queue +// +// This is the foundation layer. Both send_request_entry (high-level) +// and direct ISA usage (low-level) go through this to obtain a session. 
+// ============================================================================ + +template +inline __aicore__ PTO2AsyncSession pto2_sq_open( + uint32_t sq_type, + uint32_t sq_id, + ScratchTile& scratch, + __gm__ uint8_t* context, + uint32_t sync_id = 0, + const PTO2RemoteCopyBaseConfig& base_config = {}) +{ + (void)sq_type; + return pto2_backend_remote_copy_open( + sq_id, scratch, context, sync_id, + base_config.block_bytes, + base_config.block_offset, + base_config.repeat_times); +} + +// ============================================================================ +// pto2_save_expected_completion — AsyncEvent overload +// +// Accepts a PTO-ISA AsyncEvent directly, auto-extracting engine and handle. +// For the low-level path where the user calls ISA instructions directly. +// ============================================================================ + +inline __aicore__ void pto2_save_expected_completion( + volatile __gm__ PTO2CompletionQueue* cq, + const PTO2AsyncEvent& event) +{ + pto2_save_expected_completion( + pto2_backend_async_event_engine(event), + cq, + pto2_backend_async_event_handle(event)); +} + +enum class PTO2RemoteCopyRequestOp : uint32_t { + Put = 0, + Get = 1, +}; + +// ============================================================================ +// SDMA descriptor + factories (for high-level path) +// ============================================================================ + +template +struct PTO2RemoteCopyDescriptor { + GlobalDstData& dst; + GlobalSrcData& src; + ScratchTile& scratch; + __gm__ uint8_t* context; + uint32_t sync_id; + PTO2RemoteCopyBaseConfig base_config; + PTO2RemoteCopyRequestOp op; +}; + +template +inline __aicore__ PTO2RemoteCopyDescriptor +pto2_remote_copy_descriptor( + GlobalDstData& dst, + GlobalSrcData& src, + ScratchTile& scratch, + __gm__ uint8_t* context, + uint32_t sync_id = 0, + const PTO2RemoteCopyBaseConfig& base_config = {}) +{ + return {dst, src, scratch, context, sync_id, base_config, + 
PTO2RemoteCopyRequestOp::Put}; +} + +template +inline __aicore__ PTO2RemoteCopyDescriptor +pto2_remote_copy_tget_descriptor( + GlobalDstData& dst, + GlobalSrcData& src, + ScratchTile& scratch, + __gm__ uint8_t* context, + uint32_t sync_id = 0, + const PTO2RemoteCopyBaseConfig& base_config = {}) +{ + return {dst, src, scratch, context, sync_id, base_config, + PTO2RemoteCopyRequestOp::Get}; +} + +using PTO2SdmaRequestOp = PTO2RemoteCopyRequestOp; +template +using PTO2SdmaDescriptor = PTO2RemoteCopyDescriptor; + +template +inline __aicore__ PTO2SdmaDescriptor +pto2_sdma_descriptor( + GlobalDstData& dst, + GlobalSrcData& src, + ScratchTile& scratch, + __gm__ uint8_t* context, + uint32_t sync_id = 0, + const PTO2RemoteCopyBaseConfig& base_config = {}) +{ + return pto2_remote_copy_descriptor(dst, src, scratch, context, sync_id, base_config); +} + +template +inline __aicore__ PTO2SdmaDescriptor +pto2_sdma_tget_descriptor( + GlobalDstData& dst, + GlobalSrcData& src, + ScratchTile& scratch, + __gm__ uint8_t* context, + uint32_t sync_id = 0, + const PTO2RemoteCopyBaseConfig& base_config = {}) +{ + return pto2_remote_copy_tget_descriptor(dst, src, scratch, context, sync_id, base_config); +} + +// ============================================================================ +// pto2_send_request_entry — high-level, sugar over sq_open + async ISA op +// +// Original design: tag = pto2_send_request_entry(SQ_TYPE, SQ_ID, descriptor) +// Internally: sq_open(session params from desc) → async ISA op → tag +// ============================================================================ + +template +inline __aicore__ uint64_t pto2_send_request_entry( + uint32_t sq_type, + uint32_t sq_id, + PTO2RemoteCopyDescriptor& desc) +{ + PTO2AsyncSession session = pto2_sq_open( + sq_type, sq_id, desc.scratch, desc.context, + desc.sync_id, desc.base_config); + if (!session.valid) return 0; + + PTO2AsyncEvent event; + if (desc.op == PTO2RemoteCopyRequestOp::Get) { + event = 
pto2_backend_remote_copy_get(desc.dst, desc.src, session); + } else { + event = pto2_backend_remote_copy_put(desc.dst, desc.src, session); + } + return pto2_backend_async_event_valid(event) + ? pto2_backend_async_event_handle(event) + : 0; +} + +#endif // PTO_SQ_KERNEL_API_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h index 429cc2f02..a92b06588 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h @@ -33,6 +33,7 @@ #endif #include "pto_submit_types.h" // NOLINT(build/include_subdir) -- PTO2LaunchSpec +#include "pto_cq_types.h" // NOLINT(build/include_subdir) #include "task_args.h" // NOLINT(build/include_subdir) -- TaskArgs base class #include "tensor.h" // NOLINT(build/include_subdir) #include "tensor_arg.h" // NOLINT(build/include_subdir) -- canonical TensorArgType definition @@ -43,6 +44,54 @@ #define PTO2_MAX_OUTPUTS 16 // Maximum outputs per task #define PTO2_MAX_INPUTS 16 // Maximum inputs per task #define PTO2_MAX_INOUTS 8 // Maximum in-out args per task +#define PTO2_MAX_COMPLETIONS_PER_TASK PTO2_CQ_MAX_ENTRIES + +typedef enum { + PTO2_ASYNC_ENGINE_SDMA = 0, + PTO2_ASYNC_ENGINE_ROCE = 1, + PTO2_ASYNC_ENGINE_URMA = 2, + PTO2_ASYNC_ENGINE_CCU = 3, + PTO2_NUM_ASYNC_ENGINES = 4 +} PTO2AsyncEngine; + +/** + * Runtime-level async capabilities exposed to orchestration. + * + * These are semantic capabilities, not hard-coded hardware backends. + * A platform may implement the same capability with different engines. + */ +enum class PTO2AsyncCapability : uint32_t { + REMOTE_COPY = 0, +}; + +/** + * Host-side initialization result for an async capability context. 
+ */ +enum class PTO2AsyncContextInitStatus : int32_t { + ERROR = -1, + SKIPPED = 0, + READY = 1, +}; + +inline constexpr PTO2AsyncEngine pto2_async_capability_default_engine(PTO2AsyncCapability capability) { + switch (capability) { + case PTO2AsyncCapability::REMOTE_COPY: + return PTO2_ASYNC_ENGINE_SDMA; + } + return PTO2_ASYNC_ENGINE_SDMA; +} + +inline constexpr const char *pto2_async_capability_name(PTO2AsyncCapability capability) { + switch (capability) { + case PTO2AsyncCapability::REMOTE_COPY: + return "REMOTE_COPY"; + } + return "UNKNOWN"; +} + +enum class PTO2CompletionType : int32_t { + COUNTER = 0, +}; // ============================================================================= // Task Output Tensors (return value from submit) @@ -130,11 +179,24 @@ struct Arg : TaskArgs + static uint64_t pack_scalar(T value) { + static_assert(sizeof(T) <= sizeof(uint64_t), "pack_scalar: type must fit in 8 bytes"); + static_assert(std::is_trivially_copyable_v, "pack_scalar: type must be trivially copyable"); + uint64_t packed = 0; + memcpy(&packed, &value, sizeof(T)); + return packed; + } void reset() { clear(); has_error = false; error_msg = nullptr; + complete_in_future = false; + cq_addr = 0; } void set_error(const char *msg) { @@ -223,7 +285,7 @@ struct Arg : TaskArgs #include #include // for fprintf, printf @@ -38,6 +40,7 @@ #include "common/perf_profiling.h" #include "common/platform_config.h" #include "pto2_dispatch_payload.h" +#include "pto_types.h" #include "task_args.h" // ============================================================================= @@ -125,6 +128,8 @@ struct HostApi { int (*copy_from_device)(void *host_ptr, const void *dev_ptr, size_t size); uint64_t (*upload_kernel_binary)(int func_id, const uint8_t *bin_data, size_t bin_size); void (*remove_kernel_binary)(int func_id); + PTO2AsyncContextInitStatus (*init_async_context)(PTO2AsyncCapability capability, uint64_t *addr); + void (*destroy_async_context)(PTO2AsyncCapability capability, uint64_t 
addr); }; /** @@ -193,6 +198,7 @@ class Runtime { void *pto2_gm_sm_ptr_; // GM pointer to PTO2 shared memory (device) void *pto2_gm_heap_ptr_; // GM heap for orchestrator output buffers (device) void *pto2_slot_states_ptr_; // Pointer to PTO2TaskSlotState array (scheduler-private, for profiling) + uint64_t async_context_addrs_[PTO2_NUM_ASYNC_ENGINES]; ChipStorageTaskArgs orch_args_storage_; // Copy of args for device // Device orchestration SO binary (for dlopen on AICPU thread 3) @@ -246,6 +252,8 @@ class Runtime { void set_pto2_gm_sm_ptr(void *p); void set_pto2_gm_heap(void *p); void set_pto2_slot_states_ptr(void *p); + void set_async_context_addr(PTO2AsyncEngine engine, uint64_t addr); + uint64_t get_async_context_addr(PTO2AsyncEngine engine) const; void set_orch_args(const ChipStorageTaskArgs &args); // Device orchestration SO binary (for dlopen on AICPU thread 3) diff --git a/src/a5/platform/include/aicore/pto_async_backend_kernel.h b/src/a5/platform/include/aicore/pto_async_backend_kernel.h new file mode 100644 index 000000000..78fc1fef1 --- /dev/null +++ b/src/a5/platform/include/aicore/pto_async_backend_kernel.h @@ -0,0 +1,92 @@ +/** + * A5 async backend helpers for AICore kernels. + * + * This is currently a stub backend so the generic runtime async headers can + * compile without depending on A2/A3-specific transport implementations. 
+ */ + +#ifndef SRC_A5_PLATFORM_INCLUDE_AICORE_PTO_ASYNC_BACKEND_KERNEL_H_ +#define SRC_A5_PLATFORM_INCLUDE_AICORE_PTO_ASYNC_BACKEND_KERNEL_H_ + +#include + +struct PTO2BackendAsyncSession { + bool valid{false}; +}; + +struct PTO2BackendAsyncEvent { + uint32_t engine{0}; + uint64_t handle{0}; + + bool valid() const { return false; } +}; + +inline constexpr uint32_t pto2_backend_remote_copy_default_block_bytes() { return 0; } + +template +inline __aicore__ PTO2BackendAsyncSession pto2_backend_remote_copy_open( + uint32_t sq_id, + ScratchTile &scratch, + __gm__ uint8_t *context, + uint32_t sync_id, + uint32_t block_bytes, + uint32_t block_offset, + uint32_t repeat_times) +{ + (void)sq_id; + (void)scratch; + (void)context; + (void)sync_id; + (void)block_bytes; + (void)block_offset; + (void)repeat_times; + return {}; +} + +template +inline __aicore__ PTO2BackendAsyncEvent pto2_backend_remote_copy_put( + GlobalDstData &dst, + GlobalSrcData &src, + const PTO2BackendAsyncSession &session) +{ + (void)dst; + (void)src; + (void)session; + return {}; +} + +template +inline __aicore__ PTO2BackendAsyncEvent pto2_backend_remote_copy_get( + GlobalDstData &dst, + GlobalSrcData &src, + const PTO2BackendAsyncSession &session) +{ + (void)dst; + (void)src; + (void)session; + return {}; +} + +inline __aicore__ bool pto2_backend_async_event_valid(const PTO2BackendAsyncEvent &event) { + return event.valid(); +} + +inline __aicore__ uint32_t pto2_backend_async_event_engine(const PTO2BackendAsyncEvent &event) { + return event.engine; +} + +inline __aicore__ uint64_t pto2_backend_async_event_handle(const PTO2BackendAsyncEvent &event) { + return event.handle; +} + +inline __aicore__ void pto2_backend_send_notification( + volatile __gm__ int32_t *remote_counter_addr, + int32_t value, + uint32_t op) +{ + (void)remote_counter_addr; + (void)value; + (void)op; +} + +#endif // SRC_A5_PLATFORM_INCLUDE_AICORE_PTO_ASYNC_BACKEND_KERNEL_H_ diff --git a/src/a5/platform/onboard/host/pto_runtime_c_api.cpp 
b/src/a5/platform/onboard/host/pto_runtime_c_api.cpp index eff903896..1cc2ae639 100644 --- a/src/a5/platform/onboard/host/pto_runtime_c_api.cpp +++ b/src/a5/platform/onboard/host/pto_runtime_c_api.cpp @@ -121,6 +121,40 @@ int set_device(DeviceContextHandle ctx, int device_id) { } } +void *device_malloc_ctx(DeviceContextHandle ctx, size_t size) { + if (ctx == NULL) return NULL; + try { + return static_cast(ctx)->allocate_tensor(size); + } catch (...) { + return NULL; + } +} + +void device_free_ctx(DeviceContextHandle ctx, void *dev_ptr) { + if (ctx == NULL || dev_ptr == NULL) return; + try { + static_cast(ctx)->free_tensor(dev_ptr); + } catch (...) {} +} + +int copy_to_device_ctx(DeviceContextHandle ctx, void *dev_ptr, const void *host_ptr, size_t size) { + if (ctx == NULL || dev_ptr == NULL || host_ptr == NULL) return -1; + try { + return static_cast(ctx)->copy_to_device(dev_ptr, host_ptr, size); + } catch (...) { + return -1; + } +} + +int copy_from_device_ctx(DeviceContextHandle ctx, void *host_ptr, const void *dev_ptr, size_t size) { + if (ctx == NULL || host_ptr == NULL || dev_ptr == NULL) return -1; + try { + return static_cast(ctx)->copy_from_device(host_ptr, dev_ptr, size); + } catch (...) { + return -1; + } +} + int run_runtime( DeviceContextHandle ctx, RuntimeHandle runtime, const void *callable, const void *args, int block_dim, int aicpu_thread_num, int device_id, const uint8_t *aicpu_binary, size_t aicpu_size, const uint8_t *aicore_binary, diff --git a/src/a5/platform/sim/host/pto_runtime_c_api.cpp b/src/a5/platform/sim/host/pto_runtime_c_api.cpp index 37028f27d..e3b4265a8 100644 --- a/src/a5/platform/sim/host/pto_runtime_c_api.cpp +++ b/src/a5/platform/sim/host/pto_runtime_c_api.cpp @@ -121,6 +121,40 @@ int set_device(DeviceContextHandle ctx, int device_id) { return 0; } +void *device_malloc_ctx(DeviceContextHandle ctx, size_t size) { + if (ctx == NULL) return NULL; + try { + return static_cast(ctx)->allocate_tensor(size); + } catch (...) 
{ + return NULL; + } +} + +void device_free_ctx(DeviceContextHandle ctx, void *dev_ptr) { + if (ctx == NULL || dev_ptr == NULL) return; + try { + static_cast(ctx)->free_tensor(dev_ptr); + } catch (...) {} +} + +int copy_to_device_ctx(DeviceContextHandle ctx, void *dev_ptr, const void *host_ptr, size_t size) { + if (ctx == NULL || dev_ptr == NULL || host_ptr == NULL) return -1; + try { + return static_cast(ctx)->copy_to_device(dev_ptr, host_ptr, size); + } catch (...) { + return -1; + } +} + +int copy_from_device_ctx(DeviceContextHandle ctx, void *host_ptr, const void *dev_ptr, size_t size) { + if (ctx == NULL || host_ptr == NULL || dev_ptr == NULL) return -1; + try { + return static_cast(ctx)->copy_from_device(host_ptr, dev_ptr, size); + } catch (...) { + return -1; + } +} + int run_runtime( DeviceContextHandle ctx, RuntimeHandle runtime, const void *callable, const void *args, int block_dim, int aicpu_thread_num, int device_id, const uint8_t *aicpu_binary, size_t aicpu_size, const uint8_t *aicore_binary, diff --git a/src/a5/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp b/src/a5/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp index c07279056..3aef97b45 100644 --- a/src/a5/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp +++ b/src/a5/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp @@ -135,6 +135,14 @@ extern "C" int init_runtime_impl(Runtime *runtime, const ChipCallable *callable, int64_t t_args_start = _now_ms(); for (int i = 0; i < tensor_count; i++) { ContinuousTensor t = orch_args->tensor(i); + if (t.is_device_resident()) { + // External/bootstrap-provided device buffers are already valid in the + // target chip context, so runtime_maker must preserve the pointer + // instead of allocating/copying a second device buffer. 
+ LOG_INFO(" Tensor %d: reusing device-resident pointer %p (%zu bytes)", i, t.data_as(), t.nbytes()); + device_args.add_tensor(t); + continue; + } void *host_ptr = reinterpret_cast(static_cast(t.data)); size_t size = static_cast(t.nbytes()); diff --git a/src/common/distributed/dist_chip_bootstrap_channel.cpp b/src/common/distributed/dist_chip_bootstrap_channel.cpp new file mode 100644 index 000000000..7f70730ff --- /dev/null +++ b/src/common/distributed/dist_chip_bootstrap_channel.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ + +#include "dist_chip_bootstrap_channel.h" + +#include +#include +#include + +DistChipBootstrapChannel::DistChipBootstrapChannel(void *mailbox_ptr, size_t max_buffer_count) : + mailbox_(mailbox_ptr), + max_buffer_count_(max_buffer_count) { + if (mailbox_ptr == nullptr) throw std::invalid_argument("DistChipBootstrapChannel: null mailbox_ptr"); + if (max_buffer_count > DIST_CHIP_BOOTSTRAP_PTR_CAPACITY) { + throw std::invalid_argument("DistChipBootstrapChannel: buffer count exceeds mailbox capacity"); + } +} + +ChipBootstrapMailboxState DistChipBootstrapChannel::read_state() const { + volatile int32_t *ptr = reinterpret_cast(base() + OFF_STATE); + int32_t value = 0; +#if defined(__aarch64__) + __asm__ volatile("ldar %w0, [%1]" : "=r"(value) : "r"(ptr) : "memory"); +#elif defined(__x86_64__) + value = *ptr; + __asm__ volatile("" ::: "memory"); +#else + __atomic_load(ptr, &value, __ATOMIC_ACQUIRE); +#endif + return static_cast(value); +} + +void DistChipBootstrapChannel::write_state(ChipBootstrapMailboxState state) { + volatile int32_t *ptr = reinterpret_cast(base() + OFF_STATE); + int32_t value = static_cast(state); +#if defined(__aarch64__) + __asm__ volatile("stlr %w0, [%1]" : : "r"(value), "r"(ptr) : "memory"); +#elif defined(__x86_64__) + __asm__ volatile("" ::: "memory"); + *ptr = value; +#else + __atomic_store(ptr, &value, __ATOMIC_RELEASE); +#endif +} + +void DistChipBootstrapChannel::reset() { + std::memset(base(), 0, DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE); + write_state(ChipBootstrapMailboxState::IDLE); +} + +void DistChipBootstrapChannel::write_success( + uint64_t device_ctx, uint64_t local_window_base, uint64_t actual_window_size, const std::vector &buffer_ptrs +) { + if (buffer_ptrs.size() > max_buffer_count_) { + throw std::invalid_argument("DistChipBootstrapChannel::write_success: buffer count exceeds configured capacity"); + } + reset(); + 
int32_t error_code = 0; + int32_t buffer_count = static_cast(buffer_ptrs.size()); + std::memcpy(base() + OFF_ERROR_CODE, &error_code, sizeof(error_code)); + std::memcpy(base() + OFF_BUFFER_COUNT, &buffer_count, sizeof(buffer_count)); + std::memcpy(base() + OFF_DEVICE_CTX, &device_ctx, sizeof(device_ctx)); + std::memcpy(base() + OFF_LOCAL_WINDOW_BASE, &local_window_base, sizeof(local_window_base)); + std::memcpy(base() + OFF_ACTUAL_WINDOW_SIZE, &actual_window_size, sizeof(actual_window_size)); + if (!buffer_ptrs.empty()) { + std::memcpy(base() + OFF_BUFFER_PTRS, buffer_ptrs.data(), buffer_ptrs.size() * sizeof(uint64_t)); + } + write_state(ChipBootstrapMailboxState::SUCCESS); +} + +void DistChipBootstrapChannel::write_error(int32_t error_code, const std::string &message) { + reset(); + std::memcpy(base() + OFF_ERROR_CODE, &error_code, sizeof(error_code)); + const size_t max_copy = std::max(1, error_msg_capacity()) - 1; + const size_t copy_size = std::min(max_copy, message.size()); + std::memcpy(base() + error_msg_offset(), message.data(), copy_size); + base()[error_msg_offset() + static_cast(copy_size)] = '\0'; + write_state(ChipBootstrapMailboxState::ERROR); +} + +ChipBootstrapMailboxState DistChipBootstrapChannel::state() const { return read_state(); } + +int32_t DistChipBootstrapChannel::error_code() const { + int32_t value = 0; + std::memcpy(&value, base() + OFF_ERROR_CODE, sizeof(value)); + return value; +} + +uint64_t DistChipBootstrapChannel::device_ctx() const { + uint64_t value = 0; + std::memcpy(&value, base() + OFF_DEVICE_CTX, sizeof(value)); + return value; +} + +uint64_t DistChipBootstrapChannel::local_window_base() const { + uint64_t value = 0; + std::memcpy(&value, base() + OFF_LOCAL_WINDOW_BASE, sizeof(value)); + return value; +} + +uint64_t DistChipBootstrapChannel::actual_window_size() const { + uint64_t value = 0; + std::memcpy(&value, base() + OFF_ACTUAL_WINDOW_SIZE, sizeof(value)); + return value; +} + +std::vector 
DistChipBootstrapChannel::buffer_ptrs() const { + int32_t count = 0; + std::memcpy(&count, base() + OFF_BUFFER_COUNT, sizeof(count)); + if (count < 0 || static_cast(count) > max_buffer_count_) { + throw std::runtime_error("DistChipBootstrapChannel: invalid buffer count in mailbox"); + } + std::vector values(static_cast(count)); + if (!values.empty()) { + std::memcpy(values.data(), base() + OFF_BUFFER_PTRS, values.size() * sizeof(uint64_t)); + } + return values; +} + +std::string DistChipBootstrapChannel::error_message() const { + const char *msg = base() + error_msg_offset(); + return std::string(msg, strnlen(msg, error_msg_capacity())); +} diff --git a/src/common/distributed/dist_chip_bootstrap_channel.h b/src/common/distributed/dist_chip_bootstrap_channel.h new file mode 100644 index 000000000..b04688d0a --- /dev/null +++ b/src/common/distributed/dist_chip_bootstrap_channel.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ + +#pragma once + +#include +#include +#include +#include + +static constexpr size_t DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE = 4096; +static constexpr size_t DIST_CHIP_BOOTSTRAP_HEADER_SIZE = 64; +static constexpr size_t DIST_CHIP_BOOTSTRAP_ERROR_MSG_SIZE = 1024; +static constexpr size_t DIST_CHIP_BOOTSTRAP_PTR_CAPACITY = + (DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE - DIST_CHIP_BOOTSTRAP_HEADER_SIZE - DIST_CHIP_BOOTSTRAP_ERROR_MSG_SIZE) / + sizeof(uint64_t); + +enum class ChipBootstrapMailboxState : int32_t { + IDLE = 0, + SUCCESS = 1, + ERROR = 2, +}; + +class DistChipBootstrapChannel { +public: + DistChipBootstrapChannel(void *mailbox_ptr, size_t max_buffer_count); + + void reset(); + void write_success(uint64_t device_ctx, uint64_t local_window_base, uint64_t actual_window_size, + const std::vector &buffer_ptrs); + void write_error(int32_t error_code, const std::string &message); + + ChipBootstrapMailboxState state() const; + int32_t error_code() const; + uint64_t device_ctx() const; + uint64_t local_window_base() const; + uint64_t actual_window_size() const; + std::vector buffer_ptrs() const; + std::string error_message() const; + +private: + void *mailbox_; + size_t max_buffer_count_; + + static constexpr ptrdiff_t OFF_STATE = 0; + static constexpr ptrdiff_t OFF_ERROR_CODE = 4; + static constexpr ptrdiff_t OFF_BUFFER_COUNT = 8; + static constexpr ptrdiff_t OFF_DEVICE_CTX = 16; + static constexpr ptrdiff_t OFF_LOCAL_WINDOW_BASE = 24; + static constexpr ptrdiff_t OFF_ACTUAL_WINDOW_SIZE = 32; + static constexpr ptrdiff_t OFF_BUFFER_PTRS = 64; + + char *base() const { return static_cast(mailbox_); } + ptrdiff_t error_msg_offset() const { return OFF_BUFFER_PTRS + static_cast(max_buffer_count_ * 8); } + size_t error_msg_capacity() const { + return DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE - static_cast(error_msg_offset()); + } + + ChipBootstrapMailboxState read_state() const; 
+ void write_state(ChipBootstrapMailboxState state); +}; diff --git a/src/common/distributed/dist_orchestrator.cpp b/src/common/distributed/dist_orchestrator.cpp index bfbe467b1..7a1843e90 100644 --- a/src/common/distributed/dist_orchestrator.cpp +++ b/src/common/distributed/dist_orchestrator.cpp @@ -63,13 +63,23 @@ DistSubmitResult DistOrchestrator::submit_group( s.output_bufs.reserve(output_specs.size()); s.output_sizes.reserve(output_specs.size()); + s.output_ownerships.reserve(output_specs.size()); s.output_keys.reserve(output_specs.size()); for (const DistOutputSpec &spec : output_specs) { - void *buf = spec.size > 0 ? ::operator new(spec.size) : nullptr; + void *buf = nullptr; + DistTensorKey key = spec.key; + if (spec.ownership == DistOutputOwnership::ALLOCATED) { + buf = spec.size > 0 ? ::operator new(spec.size) : nullptr; + key = DistTensorKey{-1, reinterpret_cast(buf)}; + } else { + buf = spec.external_ptr; + } s.output_bufs.push_back(buf); s.output_sizes.push_back(spec.size); + s.output_ownerships.push_back(static_cast(spec.ownership)); result.outputs.push_back({buf, spec.size}); + s.output_keys.push_back(key); } // --- Step 3: TensorMap lookup — collect producer slots --- @@ -77,7 +87,7 @@ DistSubmitResult DistOrchestrator::submit_group( std::vector producers; producers.reserve(inputs.size()); for (const DistInputSpec &inp : inputs) { - DistTaskSlot prod = tensormap_->lookup(inp.base_ptr); + DistTaskSlot prod = tensormap_->lookup(inp.key); if (prod != DIST_INVALID_SLOT) { bool found = false; for (DistTaskSlot p : producers) { @@ -92,10 +102,8 @@ DistSubmitResult DistOrchestrator::submit_group( // --- Step 4: TensorMap insert — register outputs --- for (size_t i = 0; i < output_specs.size(); ++i) { - if (s.output_bufs[i]) { - uint64_t key = reinterpret_cast(s.output_bufs[i]); - tensormap_->insert(key, slot); - s.output_keys.push_back(key); + if (s.output_keys[i].base_ptr != 0) { + tensormap_->insert(s.output_keys[i], slot); } } diff --git 
a/src/common/distributed/dist_orchestrator.h b/src/common/distributed/dist_orchestrator.h index f1c8da0e4..504862454 100644 --- a/src/common/distributed/dist_orchestrator.h +++ b/src/common/distributed/dist_orchestrator.h @@ -49,11 +49,19 @@ // --------------------------------------------------------------------------- struct DistInputSpec { - uint64_t base_ptr; // tensor base address for TensorMap lookup + DistTensorKey key; +}; + +enum class DistOutputOwnership : int32_t { + ALLOCATED = 0, + EXTERNAL = 1, }; struct DistOutputSpec { - size_t size; // bytes to allocate for this output + DistOutputOwnership ownership{DistOutputOwnership::ALLOCATED}; + size_t size{0}; + DistTensorKey key{}; + void *external_ptr{nullptr}; }; struct DistSubmitOutput { diff --git a/src/common/distributed/dist_tensormap.cpp b/src/common/distributed/dist_tensormap.cpp index eb844dfed..ffe6a993e 100644 --- a/src/common/distributed/dist_tensormap.cpp +++ b/src/common/distributed/dist_tensormap.cpp @@ -11,16 +11,16 @@ #include "dist_tensormap.h" -DistTaskSlot DistTensorMap::lookup(uint64_t base_ptr) const { - auto it = map_.find(base_ptr); +DistTaskSlot DistTensorMap::lookup(const DistTensorKey &key) const { + auto it = map_.find(key); if (it == map_.end()) return DIST_INVALID_SLOT; return it->second; } -void DistTensorMap::insert(uint64_t base_ptr, DistTaskSlot producer) { map_[base_ptr] = producer; } +void DistTensorMap::insert(const DistTensorKey &key, DistTaskSlot producer) { map_[key] = producer; } -void DistTensorMap::erase_task_outputs(const std::vector &keys) { - for (uint64_t key : keys) +void DistTensorMap::erase_task_outputs(const std::vector &keys) { + for (const DistTensorKey &key : keys) map_.erase(key); } diff --git a/src/common/distributed/dist_tensormap.h b/src/common/distributed/dist_tensormap.h index 9b2b73c0b..f8d6a92c1 100644 --- a/src/common/distributed/dist_tensormap.h +++ b/src/common/distributed/dist_tensormap.h @@ -28,6 +28,7 @@ #pragma once #include +#include 
#include #include @@ -37,19 +38,27 @@ class DistTensorMap { public: // Look up the producer for tensor base_ptr. // Returns DIST_INVALID_SLOT when not found. - DistTaskSlot lookup(uint64_t base_ptr) const; + DistTaskSlot lookup(const DistTensorKey &key) const; // Register base_ptr → producer mapping. // Overwrites any existing entry (re-use of the same buffer by a new producer). - void insert(uint64_t base_ptr, DistTaskSlot producer); + void insert(const DistTensorKey &key, DistTaskSlot producer); // Remove all entries whose key appears in 'keys'. // Called when a producer task transitions to CONSUMED. - void erase_task_outputs(const std::vector &keys); + void erase_task_outputs(const std::vector &keys); // Number of entries currently tracked. int32_t size() const; private: - std::unordered_map map_; + struct DistTensorKeyHash { + size_t operator()(const DistTensorKey &key) const { + size_t h1 = std::hash{}(key.worker_index); + size_t h2 = std::hash{}(key.base_ptr); + return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); + } + }; + + std::unordered_map map_; }; diff --git a/src/common/distributed/dist_types.cpp b/src/common/distributed/dist_types.cpp index f3267dbf8..42537c723 100644 --- a/src/common/distributed/dist_types.cpp +++ b/src/common/distributed/dist_types.cpp @@ -11,6 +11,8 @@ #include "dist_types.h" +#include "dist_orchestrator.h" + // ============================================================================= // DistTaskSlotState // ============================================================================= @@ -25,10 +27,14 @@ void DistTaskSlotState::reset() { fanout_total = 0; } fanout_released.store(0, std::memory_order_relaxed); - for (void *p : output_bufs) - ::operator delete(p); + for (size_t i = 0; i < output_bufs.size(); ++i) { + if (output_ownerships[i] == static_cast(DistOutputOwnership::ALLOCATED)) { + ::operator delete(output_bufs[i]); + } + } output_bufs.clear(); output_sizes.clear(); + output_ownerships.clear(); 
output_keys.clear(); fanin_producers.clear(); payload = WorkerPayload{}; diff --git a/src/common/distributed/dist_types.h b/src/common/distributed/dist_types.h index a8ea03675..06a955a6e 100644 --- a/src/common/distributed/dist_types.h +++ b/src/common/distributed/dist_types.h @@ -46,6 +46,19 @@ static constexpr int32_t DIST_INVALID_SLOT = -1; using DistTaskSlot = int32_t; +struct DistTensorKey { + // base_ptr alone is not a stable distributed identity: different workers may + // legally expose the same virtual address. Include worker_index so TensorMap + // can distinguish external buffers coming from different chip/sub-worker + // address spaces. + int32_t worker_index{-1}; + uint64_t base_ptr{0}; + + bool operator==(const DistTensorKey &other) const { + return worker_index == other.worker_index && base_ptr == other.base_ptr; + } +}; + // ============================================================================= // WorkerType // ============================================================================= @@ -110,9 +123,10 @@ struct DistTaskSlotState { // --- Output buffers (malloced by orch, freed when CONSUMED) --- std::vector output_bufs; // one entry per output std::vector output_sizes; + std::vector output_ownerships; // DistOutputOwnership per output_bufs[i]; controls reset()-time free // --- TensorMap keys registered by this task (for cleanup on CONSUMED) --- - std::vector output_keys; + std::vector output_keys; // --- Producer tasks this task depends on (for deferred release) --- // When this task reaches COMPLETED, the Scheduler releases one fanout ref diff --git a/src/common/task_interface/tensor_arg.h b/src/common/task_interface/tensor_arg.h index 04a2002d4..94c7b8558 100644 --- a/src/common/task_interface/tensor_arg.h +++ b/src/common/task_interface/tensor_arg.h @@ -23,11 +23,21 @@ constexpr int CONTINUOUS_TENSOR_MAX_DIMS = 5; +enum class TensorStorageType : uint8_t { + HOST = 0, // data points to host memory; runtime performs H2D before execution 
+ DEVICE = 1, // data already points to device/window memory valid in the target chip context +}; + struct ContinuousTensor { uint64_t data; // Host/device memory address uint32_t shapes[CONTINUOUS_TENSOR_MAX_DIMS]; // Shape per dim (element count) uint32_t ndims; // Number of dimensions (1..5) DataType dtype; // DataType : uint8_t + // Storage kind tells the host runtime whether this tensor needs the usual + // device_malloc + copy_to_device path, or whether the pointer can be reused + // directly as an external device/window buffer. + TensorStorageType storage{TensorStorageType::HOST}; + uint8_t reserved[6]{}; [[nodiscard]] uint64_t nbytes() const { uint64_t total = 1; @@ -36,6 +46,8 @@ struct ContinuousTensor { return total * get_element_size(dtype); } + [[nodiscard]] bool is_device_resident() const { return storage == TensorStorageType::DEVICE; } + template T *data_as() const { return reinterpret_cast(static_cast(data)); diff --git a/src/common/worker/chip_worker.cpp b/src/common/worker/chip_worker.cpp index 8f566d8ff..7801cfa27 100644 --- a/src/common/worker/chip_worker.cpp +++ b/src/common/worker/chip_worker.cpp @@ -36,6 +36,15 @@ T load_symbol(void *handle, const char *name) { return reinterpret_cast(sym); } +template +T load_optional_symbol(void *handle, const char *name) { + dlerror(); + void *sym = dlsym(handle, name); + const char *err = dlerror(); + if (err) return nullptr; + return reinterpret_cast(sym); +} + // Process-wide singleton: libcpu_sim_context.so is loaded once with // RTLD_GLOBAL so that PTO ISA kernel SOs can find pto_cpu_sim_* symbols // via dlsym(RTLD_DEFAULT, ...). Never dlclosed. 
@@ -109,9 +118,19 @@ void ChipWorker::init( create_device_context_fn_ = load_symbol(handle, "create_device_context"); destroy_device_context_fn_ = load_symbol(handle, "destroy_device_context"); set_device_fn_ = load_symbol(handle, "set_device"); + device_malloc_ctx_fn_ = load_symbol(handle, "device_malloc_ctx"); + device_free_ctx_fn_ = load_symbol(handle, "device_free_ctx"); + copy_to_device_ctx_fn_ = load_symbol(handle, "copy_to_device_ctx"); + copy_from_device_ctx_fn_ = load_symbol(handle, "copy_from_device_ctx"); get_runtime_size_fn_ = load_symbol(handle, "get_runtime_size"); run_runtime_fn_ = load_symbol(handle, "run_runtime"); finalize_device_fn_ = load_symbol(handle, "finalize_device"); + comm_init_fn_ = load_optional_symbol(handle, "comm_init"); + comm_alloc_windows_fn_ = load_optional_symbol(handle, "comm_alloc_windows"); + comm_get_local_window_base_fn_ = load_optional_symbol(handle, "comm_get_local_window_base"); + comm_get_window_size_fn_ = load_optional_symbol(handle, "comm_get_window_size"); + comm_barrier_fn_ = load_optional_symbol(handle, "comm_barrier"); + comm_destroy_fn_ = load_optional_symbol(handle, "comm_destroy"); } catch (...) 
{ dlclose(handle); throw; @@ -172,9 +191,19 @@ void ChipWorker::finalize() { create_device_context_fn_ = nullptr; destroy_device_context_fn_ = nullptr; set_device_fn_ = nullptr; + device_malloc_ctx_fn_ = nullptr; + device_free_ctx_fn_ = nullptr; + copy_to_device_ctx_fn_ = nullptr; + copy_from_device_ctx_fn_ = nullptr; get_runtime_size_fn_ = nullptr; run_runtime_fn_ = nullptr; finalize_device_fn_ = nullptr; + comm_init_fn_ = nullptr; + comm_alloc_windows_fn_ = nullptr; + comm_get_local_window_base_fn_ = nullptr; + comm_get_window_size_fn_ = nullptr; + comm_barrier_fn_ = nullptr; + comm_destroy_fn_ = nullptr; runtime_buf_.clear(); aicpu_binary_.clear(); aicore_binary_.clear(); @@ -205,3 +234,108 @@ void ChipWorker::run(const void *callable, const void *args, const CallConfig &c throw std::runtime_error("run_runtime failed with code " + std::to_string(rc)); } } + +uint64_t ChipWorker::device_malloc(size_t size) { + if (!device_set_) { + throw std::runtime_error("ChipWorker device not set; call set_device() first"); + } + void *ptr = device_malloc_ctx_fn_(device_ctx_, size); + if (ptr == nullptr) { + throw std::runtime_error("device_malloc_ctx failed"); + } + return reinterpret_cast(ptr); +} + +void ChipWorker::device_free(uint64_t dev_ptr) { + if (!device_set_) { + throw std::runtime_error("ChipWorker device not set; call set_device() first"); + } + device_free_ctx_fn_(device_ctx_, reinterpret_cast(dev_ptr)); +} + +void ChipWorker::copy_to_device(uint64_t dev_ptr, uint64_t host_ptr, size_t size) { + if (!device_set_) { + throw std::runtime_error("ChipWorker device not set; call set_device() first"); + } + int rc = copy_to_device_ctx_fn_(device_ctx_, reinterpret_cast(dev_ptr), reinterpret_cast(host_ptr), size); + if (rc != 0) { + throw std::runtime_error("copy_to_device_ctx failed with code " + std::to_string(rc)); + } +} + +void ChipWorker::copy_from_device(uint64_t host_ptr, uint64_t dev_ptr, size_t size) { + if (!device_set_) { + throw std::runtime_error("ChipWorker 
device not set; call set_device() first"); + } + int rc = copy_from_device_ctx_fn_(device_ctx_, reinterpret_cast(host_ptr), reinterpret_cast(dev_ptr), size); + if (rc != 0) { + throw std::runtime_error("copy_from_device_ctx failed with code " + std::to_string(rc)); + } +} + +uint64_t ChipWorker::comm_init(int rank, int nranks, int device_id, const std::string &rootinfo_path) { + if (comm_init_fn_ == nullptr) { + throw std::runtime_error("comm_init is not available in this runtime"); + } + void *handle = comm_init_fn_(rank, nranks, device_id, rootinfo_path.c_str()); + if (handle == nullptr) { + throw std::runtime_error("comm_init failed"); + } + return reinterpret_cast(handle); +} + +uint64_t ChipWorker::comm_alloc_windows(uint64_t comm_handle, size_t win_size) { + if (comm_alloc_windows_fn_ == nullptr) { + throw std::runtime_error("comm_alloc_windows is not available in this runtime"); + } + uint64_t device_ctx = 0; + int rc = comm_alloc_windows_fn_(reinterpret_cast(comm_handle), win_size, &device_ctx); + if (rc != 0) { + throw std::runtime_error("comm_alloc_windows failed with code " + std::to_string(rc)); + } + return device_ctx; +} + +uint64_t ChipWorker::comm_get_local_window_base(uint64_t comm_handle) { + if (comm_get_local_window_base_fn_ == nullptr) { + throw std::runtime_error("comm_get_local_window_base is not available in this runtime"); + } + uint64_t base = 0; + int rc = comm_get_local_window_base_fn_(reinterpret_cast(comm_handle), &base); + if (rc != 0) { + throw std::runtime_error("comm_get_local_window_base failed with code " + std::to_string(rc)); + } + return base; +} + +size_t ChipWorker::comm_get_window_size(uint64_t comm_handle) { + if (comm_get_window_size_fn_ == nullptr) { + throw std::runtime_error("comm_get_window_size is not available in this runtime"); + } + size_t win_size = 0; + int rc = comm_get_window_size_fn_(reinterpret_cast(comm_handle), &win_size); + if (rc != 0) { + throw std::runtime_error("comm_get_window_size failed with code " 
+ std::to_string(rc)); + } + return win_size; +} + +void ChipWorker::comm_barrier(uint64_t comm_handle) { + if (comm_barrier_fn_ == nullptr) { + throw std::runtime_error("comm_barrier is not available in this runtime"); + } + int rc = comm_barrier_fn_(reinterpret_cast(comm_handle)); + if (rc != 0) { + throw std::runtime_error("comm_barrier failed with code " + std::to_string(rc)); + } +} + +void ChipWorker::comm_destroy(uint64_t comm_handle) { + if (comm_destroy_fn_ == nullptr) { + throw std::runtime_error("comm_destroy is not available in this runtime"); + } + int rc = comm_destroy_fn_(reinterpret_cast(comm_handle)); + if (rc != 0) { + throw std::runtime_error("comm_destroy failed with code " + std::to_string(rc)); + } +} diff --git a/src/common/worker/chip_worker.h b/src/common/worker/chip_worker.h index 09e37b5ff..7ff1ee694 100644 --- a/src/common/worker/chip_worker.h +++ b/src/common/worker/chip_worker.h @@ -57,6 +57,18 @@ class ChipWorker : public IWorker { // Direct invocation (used by Python wrapper and internal tests). 
void run(const void *callable, const void *args, const CallConfig &config); + uint64_t device_malloc(size_t size); + void device_free(uint64_t dev_ptr); + void copy_to_device(uint64_t dev_ptr, uint64_t host_ptr, size_t size); + void copy_from_device(uint64_t host_ptr, uint64_t dev_ptr, size_t size); + + uint64_t comm_init(int rank, int nranks, int device_id, const std::string &rootinfo_path); + uint64_t comm_alloc_windows(uint64_t comm_handle, size_t win_size); + uint64_t comm_get_local_window_base(uint64_t comm_handle); + size_t comm_get_window_size(uint64_t comm_handle); + void comm_barrier(uint64_t comm_handle); + void comm_destroy(uint64_t comm_handle); + int device_id() const { return device_id_; } bool initialized() const { return initialized_; } bool device_set() const { return device_set_; } @@ -65,19 +77,39 @@ class ChipWorker : public IWorker { using CreateDeviceContextFn = void *(*)(); using DestroyDeviceContextFn = void (*)(void *); using SetDeviceFn = int (*)(void *, int); + using DeviceMallocCtxFn = void *(*)(void *, size_t); + using DeviceFreeCtxFn = void (*)(void *, void *); + using CopyToDeviceCtxFn = int (*)(void *, void *, const void *, size_t); + using CopyFromDeviceCtxFn = int (*)(void *, void *, const void *, size_t); using GetRuntimeSizeFn = size_t (*)(); using RunRuntimeFn = int (*)( void *, void *, const void *, const void *, int, int, int, const uint8_t *, size_t, const uint8_t *, size_t, int ); using FinalizeDeviceFn = int (*)(void *); + using CommInitFn = void *(*)(int, int, int, const char *); + using CommAllocWindowsFn = int (*)(void *, size_t, uint64_t *); + using CommGetLocalWindowBaseFn = int (*)(void *, uint64_t *); + using CommGetWindowSizeFn = int (*)(void *, size_t *); + using CommBarrierFn = int (*)(void *); + using CommDestroyFn = int (*)(void *); void *lib_handle_ = nullptr; CreateDeviceContextFn create_device_context_fn_ = nullptr; DestroyDeviceContextFn destroy_device_context_fn_ = nullptr; SetDeviceFn set_device_fn_ = 
nullptr; + DeviceMallocCtxFn device_malloc_ctx_fn_ = nullptr; + DeviceFreeCtxFn device_free_ctx_fn_ = nullptr; + CopyToDeviceCtxFn copy_to_device_ctx_fn_ = nullptr; + CopyFromDeviceCtxFn copy_from_device_ctx_fn_ = nullptr; GetRuntimeSizeFn get_runtime_size_fn_ = nullptr; RunRuntimeFn run_runtime_fn_ = nullptr; FinalizeDeviceFn finalize_device_fn_ = nullptr; + CommInitFn comm_init_fn_ = nullptr; + CommAllocWindowsFn comm_alloc_windows_fn_ = nullptr; + CommGetLocalWindowBaseFn comm_get_local_window_base_fn_ = nullptr; + CommGetWindowSizeFn comm_get_window_size_fn_ = nullptr; + CommBarrierFn comm_barrier_fn_ = nullptr; + CommDestroyFn comm_destroy_fn_ = nullptr; void *device_ctx_ = nullptr; std::vector runtime_buf_; diff --git a/src/common/worker/pto_runtime_c_api.h b/src/common/worker/pto_runtime_c_api.h index 382806aff..1d89573e3 100644 --- a/src/common/worker/pto_runtime_c_api.h +++ b/src/common/worker/pto_runtime_c_api.h @@ -17,7 +17,8 @@ * * Public API — resolved by ChipWorker via dlsym: * create_device_context, destroy_device_context, - * get_runtime_size, set_device, run_runtime, finalize_device + * get_runtime_size, set_device, run_runtime, finalize_device, + * device_malloc_ctx, device_free_ctx, copy_to_device_ctx, copy_from_device_ctx * * Memory management: caller allocates a buffer of get_runtime_size() bytes * and passes it to run_runtime(). Error codes: 0 = success, negative = error. @@ -59,6 +60,18 @@ size_t get_runtime_size(void); /** Set the target device. Must be called before the first run_runtime(). */ int set_device(DeviceContextHandle ctx, int device_id); +/** Allocate device memory in the given device context. */ +void *device_malloc_ctx(DeviceContextHandle ctx, size_t size); + +/** Free device memory previously allocated in the given device context. */ +void device_free_ctx(DeviceContextHandle ctx, void *dev_ptr); + +/** Copy host memory to a device pointer within the given device context. 
*/ +int copy_to_device_ctx(DeviceContextHandle ctx, void *dev_ptr, const void *host_ptr, size_t size); + +/** Copy device memory to a host pointer within the given device context. */ +int copy_from_device_ctx(DeviceContextHandle ctx, void *host_ptr, const void *dev_ptr, size_t size); + /** * Build the task graph, execute on device, copy results back, and clean up. * diff --git a/tests/ut/py/test_dist_worker/test_chip_bootstrap_channel.py b/tests/ut/py/test_dist_worker/test_chip_bootstrap_channel.py new file mode 100644 index 000000000..d736975e1 --- /dev/null +++ b/tests/ut/py/test_dist_worker/test_chip_bootstrap_channel.py @@ -0,0 +1,93 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ----------------------------------------------------------------------------------------------------------- + +import ctypes +import os +from multiprocessing.shared_memory import SharedMemory +from pathlib import Path +import sys + +import pytest + +ROOT = Path(__file__).resolve().parents[4] +sys.path.insert(0, str(ROOT / "python")) +sys.path.insert(0, str(ROOT / "examples" / "scripts")) + +from code_runner import CodeRunner, create_code_runner # noqa: E402 +from task_interface import ( # noqa: E402 + DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE, + ChipBootstrapMailboxState, + DistChipBootstrapChannel, +) +from worker import Worker # noqa: E402 + + +def _mailbox_addr(shm: SharedMemory) -> int: + assert shm.buf is not None + return ctypes.addressof(ctypes.c_char.from_buffer(shm.buf)) + + +class TestDistChipBootstrapChannel: + def test_success_roundtrip(self): + shm = SharedMemory(create=True, size=DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE) + try: + ch = DistChipBootstrapChannel(_mailbox_addr(shm), 2) + ch.reset() + ch.write_success(11, 22, 33, [44, 55]) + + assert ch.state == ChipBootstrapMailboxState.SUCCESS + assert ch.error_code == 0 + assert ch.device_ctx == 11 + assert ch.local_window_base == 22 + assert ch.actual_window_size == 33 + assert ch.buffer_ptrs == [44, 55] + finally: + shm.close() + shm.unlink() + + def test_error_roundtrip(self): + shm = SharedMemory(create=True, size=DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE) + try: + ch = DistChipBootstrapChannel(_mailbox_addr(shm), 1) + ch.reset() + ch.write_error(7, "bootstrap failed") + + assert ch.state == ChipBootstrapMailboxState.ERROR + assert ch.error_code == 7 + assert "bootstrap failed" in ch.error_message + finally: + shm.close() + shm.unlink() + + +class TestDistributedWorkerApi: + def test_create_code_runner_returns_unified_code_runner_for_distributed(self): + kernels_dir = ROOT / "examples" / "a2a3" / "tensormap_and_ringbuffer" / "async_notify_demo" / "kernels" + golden_path = ROOT / "examples" / "a2a3" / 
"tensormap_and_ringbuffer" / "async_notify_demo" / "golden.py" + + old_platform = os.environ.get("PTO_PLATFORM") + os.environ["PTO_PLATFORM"] = "a2a3" + try: + runner = create_code_runner( + kernels_dir=str(kernels_dir), + golden_path=str(golden_path), + platform="a2a3", + nranks=2, + device_ids=[0, 1], + ) + finally: + if old_platform is None: + os.environ.pop("PTO_PLATFORM", None) + else: + os.environ["PTO_PLATFORM"] = old_platform + + assert isinstance(runner, CodeRunner) + assert runner._is_distributed is True + assert runner.nranks == 2 + assert runner.device_ids == [0, 1] From 12b4cc58b52ea2c8055561c8dd5df31022b5a349 Mon Sep 17 00:00:00 2001 From: PKUZHOU <751722308@qq.com> Date: Sun, 12 Apr 2026 17:49:58 +0800 Subject: [PATCH 2/2] add ffn_tp example --- .../ffn_tp_parallel/golden.py | 50 ++++ .../kernels/aic/kernel_local_linear.cpp | 101 +++++++ .../kernels/aiv/kernel_allreduce_sum.cpp | 119 ++++++++ .../ffn_tp_parallel/kernels/kernel_config.py | 55 ++++ .../orchestration/allreduce_sum_orch.cpp | 43 +++ .../kernels/orchestration/ffn_local_orch.cpp | 36 +++ .../kernels/orchestration/host_orch.py | 25 ++ examples/scripts/README.md | 46 +++ examples/scripts/code_runner.py | 270 +++++++++++++++--- 9 files changed, 709 insertions(+), 36 deletions(-) create mode 100644 examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/golden.py create mode 100644 examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/aic/kernel_local_linear.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/aiv/kernel_allreduce_sum.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/kernel_config.py create mode 100644 examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/allreduce_sum_orch.cpp create mode 100644 examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/ffn_local_orch.cpp create mode 100644 
examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/host_orch.py diff --git a/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/golden.py b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/golden.py new file mode 100644 index 000000000..834ddfe61 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/golden.py @@ -0,0 +1,50 @@ +"""Golden script for the ffn_tp_parallel distributed example.""" + +import torch + +M = 64 +K_SHARD = 64 +N = 64 + +__outputs__ = ["y"] +RTOL = 1e-4 +ATOL = 1e-4 + + +def _make_rank_inputs(rank: int): + x = ( + torch.arange(M * K_SHARD, dtype=torch.float32).reshape(M, K_SHARD) + + torch.tensor(float(rank) * 0.25, dtype=torch.float32) + ) / 32.0 + w = ( + torch.arange(K_SHARD * N, dtype=torch.float32).reshape(K_SHARD, N) + + torch.tensor(float(rank + 1) * 0.5, dtype=torch.float32) + ) / 48.0 + return x, w + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, comm_ctx=None) -> list: + del root + del comm_ctx + + x_shard, w_shard = _make_rank_inputs(rank) + zeros = torch.zeros(M * N, dtype=torch.float32) + mailbox = torch.zeros(nranks * M * N, dtype=torch.float32) + notify_counter = torch.zeros(1, dtype=torch.int32) + return [ + ("x_shard", x_shard.flatten().tolist()), + ("w_shard", w_shard.flatten().tolist()), + ("partial_local", zeros.tolist()), + ("partial_window", mailbox.tolist()), + ("y", zeros.tolist()), + ("notify_counter", notify_counter.tolist()), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + nranks = int(params.get("nranks", 2)) + expected = torch.zeros((M, N), dtype=torch.float32) + for rank in range(nranks): + x_shard, w_shard = _make_rank_inputs(rank) + expected += torch.matmul(x_shard, w_shard) + tensors["y"][:] = expected.flatten() diff --git a/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/aic/kernel_local_linear.cpp b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/aic/kernel_local_linear.cpp new file 
mode 100644 index 000000000..a706deac2 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/aic/kernel_local_linear.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +#include + +#include +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE constexpr inline T CeilAlign(T num_1, T num_2) { + if (num_2 == 0) { + return 0; + } + return (num_1 + num_2 - 1) / num_2 * num_2; +} + +static __aicore__ void local_linear_impl(__gm__ Tensor *x_tensor, __gm__ Tensor *w_tensor, __gm__ Tensor *out_tensor) { + __gm__ float *x_ptr = reinterpret_cast<__gm__ float *>(x_tensor->buffer.addr) + x_tensor->start_offset; + __gm__ float *w_ptr = reinterpret_cast<__gm__ float *>(w_tensor->buffer.addr) + w_tensor->start_offset; + __gm__ float *out_ptr = reinterpret_cast<__gm__ float *>(out_tensor->buffer.addr) + out_tensor->start_offset; + + constexpr int TILE = 64; + constexpr int block_align = C0_SIZE_BYTE / sizeof(float); + constexpr int M = CeilAlign(TILE, 16); + constexpr int K = CeilAlign(TILE, block_align); + constexpr int N = CeilAlign(TILE, block_align); + + using GlobalData = + GlobalTensor, Stride<1 * TILE 
* TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using TileMatA = Tile; + using TileMatB = Tile; + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + GlobalData x_global(x_ptr); + GlobalData w_global(w_ptr); + GlobalData out_global(out_ptr); + + TileMatA x_mat; + TileMatB w_mat; + TASSIGN(x_mat, 0x0); + TASSIGN(w_mat, 0x20000); + + LeftTile x_tile; + RightTile w_tile; + AccTile out_tile; + TASSIGN(x_tile, 0x0); + TASSIGN(w_tile, 0x0); + TASSIGN(out_tile, 0x0); + + TLOAD(x_mat, x_global); + TLOAD(w_mat, w_global); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(x_tile, x_mat); + TMOV(w_tile, w_mat); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(out_tile, x_tile, w_tile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(out_global, out_tile); + + set_flag(PIPE_FIX, PIPE_S, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *x_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *w_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *out_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); + local_linear_impl(x_tensor, w_tensor, out_tensor); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/aiv/kernel_allreduce_sum.cpp b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/aiv/kernel_allreduce_sum.cpp new file mode 100644 index 000000000..31c17a128 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/aiv/kernel_allreduce_sum.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +#include + +#include +#include +#include + +#include "common/comm_context.h" +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE inline __gm__ T *CommRemotePtr(__gm__ CommDeviceContext *ctx, __gm__ T *local_ptr, int peer_rank) { + uint64_t local_base = ctx->windowsIn[ctx->rankId]; + uint64_t offset = reinterpret_cast(local_ptr) - local_base; + return reinterpret_cast<__gm__ T *>(ctx->windowsIn[peer_rank] + offset); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *partial_local_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *partial_window_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *y_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); + __gm__ Tensor *notify_counter_tensor = reinterpret_cast<__gm__ Tensor *>(args[3]); + __gm__ CommDeviceContext *comm_ctx = reinterpret_cast<__gm__ CommDeviceContext *>(args[4]); + + __gm__ float *partial_local_ptr = + reinterpret_cast<__gm__ float *>(partial_local_tensor->buffer.addr) + partial_local_tensor->start_offset; + __gm__ float *partial_window_ptr = + reinterpret_cast<__gm__ float *>(partial_window_tensor->buffer.addr) + partial_window_tensor->start_offset; + __gm__ float *y_ptr = reinterpret_cast<__gm__ float *>(y_tensor->buffer.addr) + y_tensor->start_offset; + __gm__ int32_t *notify_counter_ptr = + 
reinterpret_cast<__gm__ int32_t *>(notify_counter_tensor->buffer.addr) + notify_counter_tensor->start_offset; + + constexpr int kRows = 64; + constexpr int kCols = 64; + constexpr int kElemsPerPartial = kRows * kCols; + + using MatrixGlobal = GlobalTensor, Stride<1, 1, 1, kCols, 1>>; + using MatrixTile = Tile; + + int my_rank = static_cast(comm_ctx->rankId); + int nranks = static_cast(comm_ctx->rankNum); + + MatrixGlobal partial_local_global(partial_local_ptr); + + MatrixTile sum_tile(kRows, kCols); + MatrixTile tmp_tile(kRows, kCols); + MatrixTile staging_tile(kRows, kCols); + TASSIGN(sum_tile, 0x0); + TASSIGN(tmp_tile, 0x10000); + TASSIGN(staging_tile, 0x20000); + + TLOAD(sum_tile, partial_local_global); + pipe_barrier(PIPE_ALL); + + // First publish this rank's local partial into every peer's mailbox slot. + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) { + continue; + } + __gm__ float *remote_mailbox_base = CommRemotePtr(comm_ctx, partial_window_ptr, peer); + __gm__ float *remote_slot_ptr = remote_mailbox_base + my_rank * kElemsPerPartial; + MatrixGlobal remote_slot(remote_slot_ptr); + pto::comm::TPUT(remote_slot, partial_local_global, staging_tile); + } + pipe_barrier(PIPE_ALL); + + // Only notify peers after the TPUT sequence above has been issued. + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) { + continue; + } + __gm__ int32_t *remote_counter = CommRemotePtr(comm_ctx, notify_counter_ptr, peer); + pto::comm::Signal remote_signal(remote_counter); + pto::comm::TNOTIFY(remote_signal, 1, pto::comm::NotifyOp::AtomicAdd); + } + pipe_barrier(PIPE_ALL); + + pto::comm::Signal local_counter(notify_counter_ptr); + pto::comm::TWAIT(local_counter, nranks - 1, pto::comm::WaitCmp::GE); + pipe_barrier(PIPE_ALL); + + // After all peers have published, accumulate the mailbox slots that were + // written into this rank's local comm window. 
+ for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) { + continue; + } + __gm__ float *mailbox_slot_ptr = partial_window_ptr + peer * kElemsPerPartial; + MatrixGlobal mailbox_slot(mailbox_slot_ptr); + TLOAD(tmp_tile, mailbox_slot); + pipe_barrier(PIPE_ALL); + TADD(sum_tile, sum_tile, tmp_tile); + pipe_barrier(PIPE_ALL); + } + + MatrixGlobal y_global(y_ptr); + TSTORE(y_global, sum_tile); + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/kernel_config.py b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/kernel_config.py new file mode 100644 index 000000000..ea80144de --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/kernel_config.py @@ -0,0 +1,55 @@ +"""Kernel config for the distributed ffn_tp_parallel example.""" + +import os +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent +_platform = os.environ.get("PTO_PLATFORM", "a2a3sim") +_DIST_NRANKS = 2 + +if _platform != "a2a3": + raise RuntimeError("ffn_tp_parallel currently requires PTO_PLATFORM=a2a3") + +KERNELS = [ + { + "func_id": 0, + "name": "LOCAL_LINEAR", + "source": str(_KERNELS_ROOT / "aic" / "kernel_local_linear.cpp"), + "core_type": "aic", + }, + { + "func_id": 1, + "name": "ALLREDUCE_SUM", + "source": str(_KERNELS_ROOT / "aiv" / "kernel_allreduce_sum.cpp"), + "core_type": "aiv", + }, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "orch_thread_num": 1, + "block_dim": 3, + "rounds": 1, +} + +DISTRIBUTED_CONFIG = { + "nranks": _DIST_NRANKS, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + {"name": "x_shard", "dtype": "float32", "count": 64 * 64, "placement": "device"}, + {"name": "w_shard", "dtype": "float32", "count": 64 * 64, "placement": "device"}, + {"name": "partial_local", "dtype": "float32", "count": 64 * 64, "placement": "device"}, + {"name": "partial_window", "dtype": "float32", "count": _DIST_NRANKS * 64 * 64, 
"placement": "window"}, + {"name": "y", "dtype": "float32", "count": 64 * 64, "placement": "device"}, + {"name": "notify_counter", "dtype": "int32", "count": 1, "placement": "window"}, + ], + "inputs": ["x_shard", "w_shard", "partial_window", "y", "notify_counter"], + "outputs": ["y"], +} + +DISTRIBUTED_HOST_ORCH = { + "source": str(_KERNELS_ROOT / "orchestration" / "host_orch.py"), + "function_name": "distributed_orch", +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/allreduce_sum_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/allreduce_sum_orch.cpp new file mode 100644 index 000000000..604103746 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/allreduce_sum_orch.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ----------------------------------------------------------------------------------------------------------- + */ + +#include + +#include "common/comm_context.h" +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{.expected_arg_count = 5}; +} + +__attribute__((visibility("default"))) void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { + Tensor partial_local = from_tensor_arg(orch_args.tensor(0)); + Tensor partial_window = from_tensor_arg(orch_args.tensor(1)); + Tensor y = from_tensor_arg(orch_args.tensor(2)); + Tensor notify_counter = from_tensor_arg(orch_args.tensor(3)); + auto *comm_ctx = reinterpret_cast(static_cast(orch_args.scalar(0))); + + Arg params; + params.add_input(partial_local); + params.add_inout(partial_window); + params.add_output(y); + params.add_inout(notify_counter); + params.add_scalar((uint64_t)(uintptr_t)comm_ctx); + // Keep publish, notify, wait, and accumulation in one device kernel so + // the peer notification cannot race ahead of the corresponding TPUT. + pto2_rt_submit_aiv_task(1, params); +} + +} // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/ffn_local_orch.cpp b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/ffn_local_orch.cpp new file mode 100644 index 000000000..54416670d --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/ffn_local_orch.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +#include + +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +aicpu_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{.expected_arg_count = 3}; +} + +__attribute__((visibility("default"))) void aicpu_orchestration_entry(const ChipStorageTaskArgs &orch_args) { + Tensor x_shard = from_tensor_arg(orch_args.tensor(0)); + Tensor w_shard = from_tensor_arg(orch_args.tensor(1)); + Tensor partial_local = from_tensor_arg(orch_args.tensor(2)); + + Arg params; + params.add_input(x_shard); + params.add_input(w_shard); + params.add_output(partial_local); + pto2_rt_submit_aic_task(0, params); +} + +} // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/host_orch.py b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/host_orch.py new file mode 100644 index 000000000..1d1641300 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/kernels/orchestration/host_orch.py @@ -0,0 +1,25 @@ +"""Host-side L3 orchestration for the distributed ffn_tp_parallel example.""" + +from pathlib import Path + +_ORCH_ROOT = Path(__file__).parent + +DISTRIBUTED_TASKS = [ + { + "name": "ffn_local", + "source": str(_ORCH_ROOT / "ffn_local_orch.cpp"), + "function_name": "aicpu_orchestration_entry", + "args": ["x_shard", "w_shard", "partial_local"], + }, + { + "name": "allreduce_sum", + "source": str(_ORCH_ROOT / "allreduce_sum_orch.cpp"), + "function_name": 
"aicpu_orchestration_entry", + "args": ["partial_local", "partial_window", "y", "notify_counter", "deviceCtx"], + }, +] + + +def distributed_orch(ctx) -> None: + ctx.submit_task("ffn_local", outputs=["partial_local"]) + ctx.submit_task("allreduce_sum", inputs=["partial_local"], outputs=["y"]) diff --git a/examples/scripts/README.md b/examples/scripts/README.md index fb2bb46af..80abbc591 100644 --- a/examples/scripts/README.md +++ b/examples/scripts/README.md @@ -236,6 +236,51 @@ def compute_golden(tensors: dict, params: dict) -> None: output[i] = float(nranks * i + 100 * nranks * (nranks - 1) // 2) ``` +### 4.1 Optional Multi-Task Distributed Host Orchestration + +For examples that want to express multiple L3 tasks connected by TensorMap +dependencies, `kernel_config.py` may point at a separate host-orchestration +Python file: + +```python +DISTRIBUTED_HOST_ORCH = { + "source": str(KERNELS_DIR / "orchestration" / "host_orch.py"), + "function_name": "distributed_orch", +} +``` + +Then `host_orch.py` defines the task registry and submit order: + +```python +from pathlib import Path + +_ORCH_ROOT = Path(__file__).parent + +DISTRIBUTED_TASKS = [ + { + "name": "task_a", + "source": str(_ORCH_ROOT / "task_a.cpp"), + "function_name": "aicpu_orchestration_entry", + "args": ["input", "tmp"], + }, + { + "name": "task_b", + "source": str(_ORCH_ROOT / "task_b.cpp"), + "function_name": "aicpu_orchestration_entry", + "args": ["tmp", "output", "deviceCtx"], + }, +] + +def distributed_orch(ctx): + ctx.submit_task("task_a", outputs=["tmp"]) + ctx.submit_task("task_b", inputs=["tmp"], outputs=["output"]) +``` + +`ctx.submit_task(...)` submits one CHIP group task across all ranks. The +runner expands named buffers into distributed input/output specs with +`worker_index=rank`, so the second task can depend on the first purely through +L3 TensorMap lookup/insert. + ### 5. 
Standard `golden.py` Format ```python @@ -484,6 +529,7 @@ TEST FAILED: Output 'f' does not match golden - **Single-Card Example**: [examples/a2a3/host_build_graph/vector_example/](../a2a3/host_build_graph/vector_example/) - **Async Completion Demo** (2-card, deferred RDMA read): [examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/](../a2a3/tensormap_and_ringbuffer/async_completion_demo/) - **Async Notify Demo** (2-card, TNOTIFY launch gating): [examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/](../a2a3/tensormap_and_ringbuffer/async_notify_demo/) +- **FFN TP Parallel** (multi-task L3 `FFN -> AllReduce`): [examples/a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/](../a2a3/tensormap_and_ringbuffer/ffn_tp_parallel/) ## FAQ diff --git a/examples/scripts/code_runner.py b/examples/scripts/code_runner.py index e1b103aa5..e818e9142 100644 --- a/examples/scripts/code_runner.py +++ b/examples/scripts/code_runner.py @@ -458,6 +458,33 @@ def _temporary_env(env_updates: dict[str, str]): os.environ[k] = prev +class DistributedOrchContext: + """Host-side helper passed to distributed_orch(ctx) in multi-task examples.""" + + def __init__(self, runner, worker, task_payloads, task_rank_args): + self._runner = runner + self._worker = worker + self._task_payloads = task_payloads + self._task_rank_args = task_rank_args + self.nranks = runner.nranks + self.chip_contexts = list(worker.chip_contexts) + + def submit_task(self, task_name: str, inputs: Optional[list[str]] = None, outputs: Optional[list[str]] = None): + payload = self._task_payloads.get(task_name) + if payload is None: + raise KeyError(f"Unknown distributed task: {task_name}") + rank_args = self._task_rank_args.get(task_name) + if rank_args is None: + raise KeyError(f"Missing rank args for distributed task: {task_name}") + return self._worker.submit( + payload.worker_type, + payload, + inputs=self._runner._task_input_specs(inputs or []), + outputs=self._runner._task_output_specs(outputs or []), + 
args_list=[arg.__ptr__() for arg in rank_args], + ) + + class CodeRunner: """ Simplified test runner that loads kernel config and golden script. @@ -529,7 +556,14 @@ def __init__( # noqa: PLR0913 # Extract kernel configuration self.kernels = self._kernel_config.KERNELS - self.orchestration = self._kernel_config.ORCHESTRATION + self.orchestration = getattr(self._kernel_config, "ORCHESTRATION", None) + self._distributed_host_orch_module = self._load_distributed_host_orch_module() + self._distributed_tasks = getattr( + self._distributed_host_orch_module, + "DISTRIBUTED_TASKS", + getattr(self._kernel_config, "DISTRIBUTED_TASKS", None), + ) + self._distributed_orch = self._load_distributed_orch_callable() # Extract golden configuration — determine which cases to run all_cases = getattr(self._golden_module, "ALL_CASES", {"Default": {}}) @@ -572,7 +606,6 @@ def __init__( # noqa: PLR0913 if len(device_ids) != self.nranks: raise ValueError(f"Expected {self.nranks} device ids, got {len(device_ids)}: {device_ids}") self.device_ids = list(device_ids) - self.orch_func = self.orchestration["function_name"] self._dist_run_dir = ( self.project_root / "build" / "distributed" / "runs" / f"run_{os.getpid()}_{time.time_ns()}" ) @@ -583,6 +616,10 @@ def __init__( # noqa: PLR0913 self._dist_example_output_artifacts: list[dict[str, Path]] = [] self._dist_example_inputs_by_rank = [] self._dist_example_outputs_by_rank = [] + self._worker_chip_contexts = [] + self._validate_distributed_task_config() + elif self.orchestration is None: + raise AttributeError("kernel_config.py must define ORCHESTRATION for non-distributed examples") def _load_kernel_config(self): """Load kernel_config.py from kernels directory.""" @@ -591,6 +628,33 @@ def _load_kernel_config(self): raise FileNotFoundError(f"kernel_config.py not found in {self.kernels_dir}\nExpected: {config_path}") return _load_module_from_path(config_path, f"kernel_config_{id(self)}") + def _load_distributed_host_orch_module(self): + cfg = 
getattr(self._kernel_config, "DISTRIBUTED_HOST_ORCH", None) + if cfg is None: + return None + if not isinstance(cfg, dict): + raise TypeError("DISTRIBUTED_HOST_ORCH must be a dict when provided") + source = cfg.get("source") + if not isinstance(source, str) or not source: + raise KeyError("DISTRIBUTED_HOST_ORCH must define non-empty 'source'") + source_path = Path(source) + if not source_path.is_absolute(): + source_path = (self.kernels_dir / source_path).resolve() + if not source_path.exists(): + raise FileNotFoundError(f"Distributed host orchestration file not found: {source_path}") + return _load_module_from_path(source_path, f"distributed_host_orch_{id(self)}") + + def _load_distributed_orch_callable(self): + cfg = getattr(self._kernel_config, "DISTRIBUTED_HOST_ORCH", None) + func_name = "distributed_orch" + if cfg is not None: + if not isinstance(cfg, dict): + raise TypeError("DISTRIBUTED_HOST_ORCH must be a dict when provided") + func_name = str(cfg.get("function_name", func_name)) + if self._distributed_host_orch_module is not None: + return getattr(self._distributed_host_orch_module, func_name, None) + return getattr(self._kernel_config, "distributed_orch", None) + def _load_golden_module(self): """Load golden.py script.""" if not self.golden_path.exists(): @@ -599,15 +663,40 @@ def _load_golden_module(self): module = _load_module_from_path(self.golden_path, f"golden_{id(self)}") # Validate required functions - if not hasattr(module, "generate_inputs"): - raise AttributeError(f"golden.py must define generate_inputs(params) function\nFile: {self.golden_path}") if not hasattr(module, "compute_golden"): raise AttributeError( f"golden.py must define compute_golden(tensors, params) function\nFile: {self.golden_path}" ) + if not hasattr(module, "generate_inputs") and not hasattr(module, "generate_distributed_inputs"): + raise AttributeError( + f"golden.py must define generate_inputs(params) or generate_distributed_inputs(rank, nranks, root, comm_ctx=None)\nFile: 
{self.golden_path}" + ) return module + def _validate_distributed_task_config(self) -> None: + if self._distributed_tasks is None: + if self.orchestration is None: + raise AttributeError("Distributed examples must define ORCHESTRATION or DISTRIBUTED_TASKS") + return + if not isinstance(self._distributed_tasks, list) or not self._distributed_tasks: + raise TypeError("DISTRIBUTED_TASKS must be a non-empty list") + seen = set() + for task in self._distributed_tasks: + if not isinstance(task, dict): + raise TypeError("Each DISTRIBUTED_TASKS item must be a dict") + for key in ("name", "source", "function_name", "args"): + if key not in task: + raise KeyError(f"DISTRIBUTED_TASKS item is missing '{key}'") + name = str(task["name"]) + if name in seen: + raise ValueError(f"Duplicate distributed task name: {name}") + seen.add(name) + if not isinstance(task["args"], list): + raise TypeError(f"DISTRIBUTED_TASKS[{name}].args must be a list") + if self._distributed_orch is not None and not callable(self._distributed_orch): + raise TypeError("distributed_orch must be callable when provided") + def _identify_outputs(self, tensors: dict[str, torch.Tensor]) -> tuple[dict, dict]: """ Separate inputs and outputs from tensor dict using __outputs__. 
@@ -1037,6 +1126,17 @@ def _chip_buffer_dtype_to_task_dtype(self, dtype: str): raise ValueError(f"Unsupported distributed buffer dtype: {dtype}") return mapping[dtype] + def _dist_task_configs(self) -> list[dict[str, Any]]: + if self._distributed_tasks is None: + return [] + return list(self._distributed_tasks) + + def _dist_task_config(self, task_name: str) -> dict[str, Any]: + for task in self._dist_task_configs(): + if task["name"] == task_name: + return task + raise KeyError(f"Unknown distributed task: {task_name}") + def _chip_runtime_artifact_paths(self): return { "host": self.artifact_dir / "libhost_runtime.so", @@ -1044,23 +1144,28 @@ def _chip_runtime_artifact_paths(self): "aicore": self.artifact_dir / "aicore_kernel.o", } - def _chip_orch_artifact_name(self): + def _chip_orch_artifact_name(self, task_name: Optional[str] = None): + if task_name is not None: + return f"{task_name}.so" + assert self.orchestration is not None return Path(self.orchestration["source"]).stem + ".so" def _chip_kernel_artifact_name(self, kernel_cfg): return Path(kernel_cfg["source"]).stem + ".bin" - def _build_chip_callable(self): + def _build_chip_callable(self, task_name: Optional[str] = None): from task_interface import ChipCallable, CoreCallable # noqa: PLC0415 - orch_binary = (self.artifact_dir / self._chip_orch_artifact_name()).read_bytes() + orch_cfg = self._dist_task_config(task_name) if task_name is not None else self.orchestration + assert orch_cfg is not None + orch_binary = (self.artifact_dir / self._chip_orch_artifact_name(task_name)).read_bytes() children = [] for kernel_cfg in self.kernels: binary = (self.artifact_dir / self._chip_kernel_artifact_name(kernel_cfg)).read_bytes() children.append((kernel_cfg["func_id"], CoreCallable.build(kernel_cfg.get("signature", []), binary))) return ChipCallable.build( - self.orchestration.get("signature", []), - self.orch_func, + orch_cfg.get("signature", []), + orch_cfg["function_name"], orch_binary, children, ) @@ -1188,13 
+1293,13 @@ def _build_chip_bootstrap_config(self, rank: int): host_outputs=list(self._dist_example_outputs_by_rank[rank].values()), ) - def _make_chip_task_args(self, chip_context): + def _make_chip_task_args(self, chip_context, arg_tokens: Optional[list[str]] = None): from task_interface import ChipStorageTaskArgs, scalar_to_uint64 # noqa: PLC0415 dist = getattr(self._kernel_config, "DISTRIBUTED_CONFIG", {}) buf_cfg_by_name = {buf["name"]: buf for buf in dist.get("buffers", [])} args = ChipStorageTaskArgs() - for tok in dist.get("args", []): + for tok in arg_tokens or dist.get("args", []): if tok == "nranks": args.add_scalar(scalar_to_uint64(self.nranks)) elif tok == "root": @@ -1219,6 +1324,50 @@ def _make_chip_task_args(self, chip_context): args.add_tensor(tensor_arg) return args + def _task_input_specs(self, buffer_names: list[str]) -> list[tuple[int, int]]: + specs = [] + for rank, chip_context in enumerate(self._worker_chip_contexts): + for name in buffer_names: + specs.append((rank, int(chip_context.buffer_ptrs[name]))) + return specs + + def _task_output_specs(self, buffer_names: list[str]) -> list[dict[str, int]]: + specs: list[dict[str, int]] = [] + for rank, chip_context in enumerate(self._worker_chip_contexts): + for name in buffer_names: + buf_cfg = self._dist_example_buffer_config(name) + specs.append( + { + "ptr": int(chip_context.buffer_ptrs[name]), + "size": self._buffer_nbytes(buf_cfg), + "worker_index": rank, + } + ) + return specs + + def _build_dist_task_payloads(self): + from task_interface import WorkerPayload, WorkerType # noqa: PLC0415 + + runtime_cfg = getattr(self._kernel_config, "RUNTIME_CONFIG", {}) + payloads = {} + for task in self._dist_task_configs(): + payload = WorkerPayload() + payload.worker_type = WorkerType.CHIP + payload.callable = self._chip_callables[task["name"]].buffer_ptr() + payload.block_dim = int(runtime_cfg.get("block_dim", 1)) + payload.aicpu_thread_num = int(runtime_cfg.get("aicpu_thread_num", 1)) + 
payloads[task["name"]] = payload + return payloads + + def _build_dist_task_rank_args(self): + task_rank_args = {} + for task in self._dist_task_configs(): + task_rank_args[task["name"]] = [ + self._make_chip_task_args(chip_context, list(task["args"])) + for chip_context in self._worker_chip_contexts + ] + return task_rank_args + def _compile_dist_example_artifacts(self): from elf_parser import extract_text_section # noqa: PLC0415 from kernel_compiler import KernelCompiler # noqa: PLC0415 @@ -1243,12 +1392,23 @@ def _compile_dist_example_artifacts(self): runtime_bins = builder.get_binaries(self.runtime_name, build=self.build_runtime) logger.info("=== Phase 2: Compiling orchestration ===") - orch_binary = kernel_compiler.compile_orchestration( - self.runtime_name, - str(Path(self.orchestration["source"]).resolve()), - extra_include_dirs=kernel_compiler.get_orchestration_include_dirs(self.runtime_name), - build_dir=str(self.build_dir), - ) + orchestration_binaries = {} + if self._dist_task_configs(): + for task in self._dist_task_configs(): + orchestration_binaries[task["name"]] = kernel_compiler.compile_orchestration( + self.runtime_name, + str(Path(task["source"]).resolve()), + extra_include_dirs=kernel_compiler.get_orchestration_include_dirs(self.runtime_name), + build_dir=str(self.build_dir), + ) + else: + assert self.orchestration is not None + orchestration_binaries["__single__"] = kernel_compiler.compile_orchestration( + self.runtime_name, + str(Path(self.orchestration["source"]).resolve()), + extra_include_dirs=kernel_compiler.get_orchestration_include_dirs(self.runtime_name), + build_dir=str(self.build_dir), + ) logger.info("=== Phase 3: Compiling kernels ===") extra_includes = kernel_compiler.get_orchestration_include_dirs(self.runtime_name) @@ -1275,10 +1435,20 @@ def save(name, data): save("libhost_runtime.so", runtime_bins.host_path.read_bytes()) save("libaicpu_kernel.so", runtime_bins.aicpu_path.read_bytes()) save("aicore_kernel.o", 
runtime_bins.aicore_path.read_bytes()) - save(self._chip_orch_artifact_name(), orch_binary) + if self._dist_task_configs(): + for task_name, orch_binary in orchestration_binaries.items(): + save(self._chip_orch_artifact_name(task_name), orch_binary) + else: + save(self._chip_orch_artifact_name(), orchestration_binaries["__single__"]) for _, (kcfg, data) in kernel_bins.items(): save(self._chip_kernel_artifact_name(kcfg), data) - self._chip_callable = self._build_chip_callable() + if self._dist_task_configs(): + self._chip_callables = { + task["name"]: self._build_chip_callable(task["name"]) + for task in self._dist_task_configs() + } + else: + self._chip_callable = self._build_chip_callable() logger.info(f"All artifacts saved to {self.artifact_dir}") def _dump_dist_example_outputs(self) -> None: @@ -1292,7 +1462,6 @@ def _dump_dist_example_outputs(self) -> None: shm.close() def _run_distributed(self): - from task_interface import WorkerPayload, WorkerType # noqa: PLC0415 from worker import Task, Worker # noqa: PLC0415 if not self._dist_example_inputs_by_rank or not self._dist_example_outputs_by_rank: @@ -1304,7 +1473,13 @@ def _run_distributed(self): if rootinfo_file.exists(): rootinfo_file.unlink() - if not hasattr(self, "_chip_callable"): + if self._dist_task_configs(): + if not hasattr(self, "_chip_callables"): + self._chip_callables = { + task["name"]: self._build_chip_callable(task["name"]) + for task in self._dist_task_configs() + } + elif not hasattr(self, "_chip_callable"): self._chip_callable = self._build_chip_callable() runtime_paths = self._chip_runtime_artifact_paths() @@ -1326,23 +1501,45 @@ def _run_distributed(self): chip_bootstrap_configs=chip_bootstrap_configs, ) worker.init() - rank_args = [self._make_chip_task_args(ctx) for ctx in worker.chip_contexts] - - payload = WorkerPayload() - payload.worker_type = WorkerType.CHIP - payload.callable = self._chip_callable.buffer_ptr() - payload.block_dim = int(getattr(self._kernel_config, "RUNTIME_CONFIG", 
{}).get("block_dim", 1)) - payload.aicpu_thread_num = int(getattr(self._kernel_config, "RUNTIME_CONFIG", {}).get("aicpu_thread_num", 1)) - - def orch_fn(w, args_list): - w.submit( - WorkerType.CHIP, - payload, - args_list=[arg.__ptr__() for arg in args_list], - outputs=[], + self._worker_chip_contexts = list(worker.chip_contexts) + if self._dist_task_configs(): + if self._distributed_orch is None: + raise AttributeError( + "Distributed multi-task examples must define distributed_orch(ctx) in kernel_config.py" + ) + orch_ctx = DistributedOrchContext( + self, + worker, + self._build_dist_task_payloads(), + self._build_dist_task_rank_args(), + ) + + def orch_fn(_worker, ctx): + self._distributed_orch(ctx) + + worker.run(Task(orch=orch_fn, args=orch_ctx)) + else: + from task_interface import WorkerPayload, WorkerType # noqa: PLC0415 + + rank_args = [self._make_chip_task_args(ctx) for ctx in worker.chip_contexts] + + payload = WorkerPayload() + payload.worker_type = WorkerType.CHIP + payload.callable = self._chip_callable.buffer_ptr() + payload.block_dim = int(getattr(self._kernel_config, "RUNTIME_CONFIG", {}).get("block_dim", 1)) + payload.aicpu_thread_num = int( + getattr(self._kernel_config, "RUNTIME_CONFIG", {}).get("aicpu_thread_num", 1) ) - worker.run(Task(orch=orch_fn, args=rank_args)) + def orch_fn(w, args_list): + w.submit( + WorkerType.CHIP, + payload, + args_list=[arg.__ptr__() for arg in args_list], + outputs=[], + ) + + worker.run(Task(orch=orch_fn, args=rank_args)) self._dump_dist_example_outputs() except Exception: logger.exception("Distributed worker execution failed") @@ -1350,6 +1547,7 @@ def orch_fn(w, args_list): finally: if worker is not None: worker.close() + self._worker_chip_contexts = [] for f in self.artifact_dir.glob("barrier_*.ready"): f.unlink()