Skip to content

Commit baa8ea3

Browse files
committed
feat: add synchronous communication ops
1 parent 33f371d commit baa8ea3

16 files changed

Lines changed: 1625 additions & 1 deletion

File tree

docs/PTO_IR_manual.md

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7712,6 +7712,293 @@ pto.trap
77127712

77137713
---
77147714

7715+
### 4.21 Communication Operations
7716+
7717+
This section documents PTO communication primitives. PTOAS currently exposes:
7718+
7719+
- Synchronous point-to-point ops: `pto.tput`, `pto.tget`
7720+
- Synchronous signal ops: `pto.tnotify`, `pto.twait`, `pto.ttest`
7721+
- Synchronous collective ops: `pto.tbroadcast`, `pto.comm_tgather`, `pto.comm_tscatter`, `pto.treduce`
7722+
- Asynchronous communication/session ops: `pto.build_async_session`, `pto.tput_async`, `pto.tget_async`, `pto.wait_async_event`, `pto.test_async_event`
7723+
7724+
##### `pto.build_async_session` - Create Async DMA Session
7725+
7726+
**Summary:** Creates an async DMA session handle used by `pto.tput_async` and `pto.tget_async`.
7727+
7728+
**Arguments:**
7729+
7730+
| Name | Type | Description |
7731+
|------|------|-------------|
7732+
| `scratch` | `pto.tile_buf` / local memref | Local scratch/staging buffer used by the async runtime |
7733+
| `workspace` | `!pto.ptr<...>` / GM memref | Global workspace backing the async session |
7734+
| `sync_id` | optional `i32` attr | Session synchronization ID |
7735+
| `block_bytes` | optional `i64` attr | Communication block size in bytes |
7736+
| `comm_block_offset` | optional `i64` attr | Per-block GM offset in bytes |
7737+
| `queue_num` | optional `i32` attr | Queue count hint |
7738+
| `channel_group_idx` | optional `i64` attr | Communication channel-group selector |
7739+
7740+
**Results:** `!pto.async_session`
7741+
7742+
**Constraints & Verification:**
7743+
7744+
- `scratch` must be tile-like local storage.
7745+
- `workspace` must be a GM pointer/memref.
7746+
- Optional attrs are forwarded as session configuration and must use the declared integer types.
7747+
7748+
**Basic Example:**
7749+
7750+
```mlir
7751+
%session = pto.build_async_session(%scratch, %workspace : !pto.tile_buf<loc=vec, dtype=i8, rows=1, cols=256, v_row=1, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.ptr<i8>) {sync_id = 0 : i32} -> !pto.async_session
7752+
```
7753+
7754+
---
7755+
7756+
##### `pto.tput_async` - Asynchronous Remote Write
7757+
7758+
**Summary:** Starts an asynchronous remote write from local GM to remote GM and returns an async event handle.
7759+
7760+
**Arguments:**
7761+
7762+
| Name | Type | Description |
7763+
|------|------|-------------|
7764+
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer |
7765+
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer |
7766+
| `session` | `!pto.async_session` | Async DMA session |
7767+
7768+
**Results:** `!pto.async_event`
7769+
7770+
**Constraints & Verification:**
7771+
7772+
- `dst` / `src` must be GM-shaped values with identical element type and static shape.
7773+
- Current lowering only supports flat contiguous logical-1D transfers for async GM operands.
7774+
- `session` must come from `pto.build_async_session`.
7775+
7776+
**Basic Example:**
7777+
7778+
```mlir
7779+
%event = pto.tput_async(%dst, %src, %session : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.async_session) -> !pto.async_event
7780+
```
7781+
7782+
---
7783+
7784+
##### `pto.tget_async` - Asynchronous Remote Read
7785+
7786+
**Summary:** Starts an asynchronous remote read from remote GM to local GM and returns an async event handle.
7787+
7788+
**Arguments:**
7789+
7790+
| Name | Type | Description |
7791+
|------|------|-------------|
7792+
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local destination buffer |
7793+
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote source buffer |
7794+
| `session` | `!pto.async_session` | Async DMA session |
7795+
7796+
**Results:** `!pto.async_event`
7797+
7798+
**Constraints & Verification:**
7799+
7800+
- Same operand constraints as `pto.tput_async`.
7801+
- `session` must be compatible with the transfer workspace and staging configuration.
7802+
7803+
**Basic Example:**
7804+
7805+
```mlir
7806+
%event = pto.tget_async(%dst, %src, %session : !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.async_session) -> !pto.async_event
7807+
```
7808+
7809+
---
7810+
7811+
##### `pto.wait_async_event` / `pto.test_async_event` - Async Event Completion
7812+
7813+
**Summary:** Consume an async event produced by `pto.tput_async` / `pto.tget_async`.
7814+
7815+
**Arguments:**
7816+
7817+
| Op | Operands | Result | Description |
7818+
|----|----------|--------|-------------|
7819+
| `pto.wait_async_event` | `event`, `session` | `i1` | Blocking wait for completion |
7820+
| `pto.test_async_event` | `event`, `session` | `i1` | Non-blocking completion test |
7821+
7822+
**Constraints & Verification:**
7823+
7824+
- `event` must have type `!pto.async_event`.
7825+
- `session` must have type `!pto.async_session`.
7826+
- The event/session pair is expected to come from the same async communication flow.
7827+
7828+
**Basic Example:**
7829+
7830+
```mlir
7831+
%done0 = pto.wait_async_event(%event0, %session : !pto.async_event, !pto.async_session) -> i1
7832+
%done1 = pto.test_async_event(%event1, %session : !pto.async_event, !pto.async_session) -> i1
7833+
```
7834+
7835+
---
7836+
7837+
##### `pto.tput` - Synchronous Remote Write
7838+
7839+
**Summary:** Lowers to `pto::comm::TPUT(...)` and copies data from local GM to remote GM through a VEC staging tile.
7840+
7841+
**Arguments:**
7842+
7843+
| Name | Type | Description |
7844+
|------|------|-------------|
7845+
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer |
7846+
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer |
7847+
| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile |
7848+
| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer |
7849+
| `atomicType` | `#pto.atomic_type<...>` | Atomic mode, default `atomic_none` |
7850+
7851+
**Constraints & Verification:**
7852+
7853+
- `dst` / `src` must be GM-shaped values with positive static shapes.
7854+
- `dst` and `src` must have the same element type and static shape.
7855+
- `ping` / `pong` must be local VEC tile-like values whose element type matches `src`.
7856+
7857+
**Basic Example:**
7858+
7859+
```mlir
7860+
pto.tput %dst, %src, %ping {atomicType = #pto.atomic_type<atomic_none>} :
7861+
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
7862+
7863+
pto.tput %dst, %src, %ping, %pong {atomicType = #pto.atomic_type<atomic_add>} :
7864+
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
7865+
```
7866+
7867+
---
7868+
7869+
##### `pto.tget` - Synchronous Remote Read
7870+
7871+
**Summary:** Lowers to `pto::comm::TGET(...)` and copies data from remote GM to local GM through a VEC staging tile.
7872+
7873+
**Arguments:**
7874+
7875+
| Name | Type | Description |
7876+
|------|------|-------------|
7877+
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local destination buffer |
7878+
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote source buffer |
7879+
| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile |
7880+
| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer |
7881+
7882+
**Constraints & Verification:**
7883+
7884+
- Same GM/global-like and staging constraints as `pto.tput`.
7885+
- `dst` and `src` must have the same element type and static shape.
7886+
7887+
**Basic Example:**
7888+
7889+
```mlir
7890+
pto.tget %dst, %src, %ping :
7891+
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
7892+
```
7893+
7894+
---
7895+
7896+
##### `pto.tnotify` / `pto.twait` / `pto.ttest` - Communication Signal Ops
7897+
7898+
**Summary:** Lower to `pto::comm::TNOTIFY/TWAIT/TTEST` for GM `i32` signal buffers.
7899+
7900+
**Arguments:**
7901+
7902+
| Op | Operands | Attributes | Result |
7903+
|----|----------|------------|--------|
7904+
| `pto.tnotify` | `signal`, `value` | `notifyOp = #pto.notify_op<atomic_add/set>` | none |
7905+
| `pto.twait` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<eq/ne/gt/ge/lt/le>` | none |
7906+
| `pto.ttest` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<eq/ne/gt/ge/lt/le>` | `i1` |
7907+
7908+
**Constraints & Verification:**
7909+
7910+
- `signal` must be a GM-shaped value with element type `i32`.
7911+
- `value` / `cmpValue` must be signless integer scalars.
7912+
7913+
**Basic Example:**
7914+
7915+
```mlir
7916+
pto.tnotify %sig, %v {notifyOp = #pto.notify_op<set>} : !pto.partition_tensor_view<1xi32>, i32
7917+
pto.twait %sig, %v {cmp = #pto.wait_cmp<ge>} : !pto.partition_tensor_view<1xi32>, i32
7918+
%ok = pto.ttest %sig, %v {cmp = #pto.wait_cmp<eq>} : !pto.partition_tensor_view<1xi32>, i32 -> i1
7919+
```
7920+
7921+
---
7922+
7923+
##### `pto.tbroadcast` - Collective Broadcast
7924+
7925+
**Summary:** Lowers to `pto::comm::TBROADCAST(...)`.
7926+
7927+
**Arguments:**
7928+
7929+
| Name | Type | Description |
7930+
|------|------|-------------|
7931+
| `src` | GM-shaped value | Root source buffer |
7932+
| `ping` / `pong` | local VEC tile-like values | Staging tiles |
7933+
| `group` | variadic GM-shaped values | Parallel group members |
7934+
| `root` | `i32` attr | Root rank index inside `group` |
7935+
7936+
**Constraints & Verification:**
7937+
7938+
- `group` must be non-empty and all members must have identical types.
7939+
- `src` must have the same type as each `group` member.
7940+
- `root` must be in range `[0, group.size)`.
7941+
7942+
**Basic Example:**
7943+
7944+
```mlir
7945+
pto.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes = array<i32: 1, 1, 0, 3>} :
7946+
!pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>
7947+
```
7948+
7949+
---
7950+
7951+
##### `pto.comm_tgather` - Collective Gather
7952+
7953+
**Summary:** Communication collective that lowers to `pto::comm::TGATHER(...)`. This op is distinct from tile-level `pto.tgather`.
7954+
7955+
**Arguments:** `dst`, `ping`, optional `pong`, variadic `group`, `root`
7956+
7957+
**Constraints & Verification:**
7958+
7959+
- `group` must be non-empty and all members must have identical types.
7960+
- `dst` element type must match the group element type.
7961+
- `ping` / `pong` must be local VEC tile-like values with matching element type.
7962+
7963+
---
7964+
7965+
##### `pto.comm_tscatter` - Collective Scatter
7966+
7967+
**Summary:** Communication collective that lowers to `pto::comm::TSCATTER(...)`. This op is distinct from tile-level `pto.tscatter`.
7968+
7969+
**Arguments:** `src`, `ping`, optional `pong`, variadic `group`, `root`
7970+
7971+
**Constraints & Verification:**
7972+
7973+
- `group` must be non-empty and all members must have identical types.
7974+
- `src` element type must match the group element type.
7975+
- `ping` / `pong` must be local VEC tile-like values with matching element type.
7976+
7977+
---
7978+
7979+
##### `pto.treduce` - Collective Reduce
7980+
7981+
**Summary:** Lowers to `pto::comm::TREDUCE(...)`.
7982+
7983+
**Arguments:**
7984+
7985+
| Name | Type | Description |
7986+
|------|------|-------------|
7987+
| `dst` | GM-shaped value | Root destination buffer |
7988+
| `acc` | local VEC tile-like value | Accumulation tile |
7989+
| `recvPing` / `recvPong` | local VEC tile-like values | Receive staging tiles |
7990+
| `group` | variadic GM-shaped values | Parallel group members |
7991+
| `reduceOp` | `#pto.reduce_op<sum/max/min>` | Reduction mode |
7992+
| `root` | `i32` attr | Root rank index inside `group` |
7993+
7994+
**Constraints & Verification:**
7995+
7996+
- `group` must be non-empty and all members must have identical types.
7997+
- `dst` element type must match the group element type.
7998+
- `acc` and `recvPing` / `recvPong` must be local VEC tile-like values whose element type matches `dst`.
7999+
8000+
---
8001+
77158002
## 5. Operation Summary Table
77168003

77178004
| Category | Count | Pipeline |

include/PTO/IR/PTOAttrs.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,41 @@ def PTO_AtomicTypeAttr : EnumAttr<PTO_Dialect, PTO_AtomicTypeEnum, "atomic_type"
394394
let summary = "TSTORE atomic type attribute";
395395
}
396396

397+
def PTO_NotifyOpEnum : PTO_I32Enum<
398+
"NotifyOp", "PTO communication notify op", [
399+
I32EnumAttrCase<"AtomicAdd", 0, "atomic_add">,
400+
I32EnumAttrCase<"Set", 1, "set">
401+
]>;
402+
403+
def PTO_NotifyOpAttr : EnumAttr<PTO_Dialect, PTO_NotifyOpEnum, "notify_op"> {
404+
let summary = "communication notify operation attribute";
405+
}
406+
407+
def PTO_WaitCmpEnum : PTO_I32Enum<
408+
"WaitCmp", "PTO communication wait/test compare", [
409+
I32EnumAttrCase<"EQ", 0, "eq">,
410+
I32EnumAttrCase<"NE", 1, "ne">,
411+
I32EnumAttrCase<"GT", 2, "gt">,
412+
I32EnumAttrCase<"GE", 3, "ge">,
413+
I32EnumAttrCase<"LT", 4, "lt">,
414+
I32EnumAttrCase<"LE", 5, "le">
415+
]>;
416+
417+
def PTO_WaitCmpAttr : EnumAttr<PTO_Dialect, PTO_WaitCmpEnum, "wait_cmp"> {
418+
let summary = "communication wait/test comparison attribute";
419+
}
420+
421+
def PTO_ReduceOpEnum : PTO_I32Enum<
422+
"ReduceOp", "PTO communication reduce operation", [
423+
I32EnumAttrCase<"Sum", 0, "sum">,
424+
I32EnumAttrCase<"Max", 1, "max">,
425+
I32EnumAttrCase<"Min", 2, "min">
426+
]>;
427+
428+
def PTO_ReduceOpAttr : EnumAttr<PTO_Dialect, PTO_ReduceOpEnum, "reduce_op"> {
429+
let summary = "communication reduce operation attribute";
430+
}
431+
397432
def PTO_ReluPreModeEnum : PTO_I32Enum<
398433
"ReluPreMode", "PTO TSTORE relu pre mode", [
399434
I32EnumAttrCase<"NoRelu", 0, "no_relu">,

0 commit comments

Comments
 (0)