Skip to content

Commit a15f659

Browse files
committed
feat: add synchronous communication ops
1 parent 33f371d commit a15f659

16 files changed

Lines changed: 1503 additions & 1 deletion

File tree

docs/PTO_IR_manual.md

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,171 @@ pto.store_scalar %val, %ptr[%offset] : !pto.ptr<f32>, f32
988988

989989
---
990990

991+
##### `pto.tput` - Synchronous Remote Write
992+
993+
**Summary:** Lowers to `pto::comm::TPUT(...)` and copies data from local GM to remote GM through a VEC staging tile.
994+
995+
**Arguments:**
996+
997+
| Name | Type | Description |
998+
|------|------|-------------|
999+
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote destination buffer |
1000+
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local source buffer |
1001+
| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile |
1002+
| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer |
1003+
| `atomicType` | `#pto.atomic_type<...>` | Atomic mode, default `atomic_none` |
1004+
1005+
**Constraints & Verification:**
1006+
1007+
- `dst` / `src` must be GM-shaped values with positive static shapes.
1008+
- `dst` and `src` must have the same element type and static shape.
1009+
- `ping` / `pong` must be local VEC tile-like values whose element type matches `src`.
1010+
1011+
**Basic Example:**
1012+
1013+
```mlir
1014+
pto.tput %dst, %src, %ping {atomicType = #pto.atomic_type<atomic_none>} :
1015+
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
1016+
1017+
pto.tput %dst, %src, %ping, %pong {atomicType = #pto.atomic_type<atomic_add>} :
1018+
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
1019+
```
1020+
1021+
---
1022+
1023+
##### `pto.tget` - Synchronous Remote Read
1024+
1025+
**Summary:** Lowers to `pto::comm::TGET(...)` and copies data from remote GM to local GM through a VEC staging tile.
1026+
1027+
**Arguments:**
1028+
1029+
| Name | Type | Description |
1030+
|------|------|-------------|
1031+
| `dst` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Local destination buffer |
1032+
| `src` | GM memref / `pto.tensor_view` / `pto.partition_tensor_view` | Remote source buffer |
1033+
| `ping` | `pto.tile_buf` / local VEC memref | Required staging tile |
1034+
| `pong` | `pto.tile_buf` / local VEC memref | Optional second staging tile for ping-pong transfer |
1035+
1036+
**Constraints & Verification:**
1037+
1038+
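- Same GM-shaped operand and staging-tile constraints as `pto.tput`: `dst` / `src` must be GM-shaped values with positive static shapes, and `ping` / `pong` must be local VEC tile-like values whose element type matches `src`.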
1039+
- `dst` and `src` must have the same element type and static shape.
1040+
1041+
**Basic Example:**
1042+
1043+
```mlir
1044+
pto.tget %dst, %src, %ping :
1045+
!pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>
1046+
```
1047+
1048+
---
1049+
1050+
##### `pto.tnotify` / `pto.twait` / `pto.ttest` - Communication Signal Ops
1051+
1052+
**Summary:** These ops lower to `pto::comm::TNOTIFY(...)`, `pto::comm::TWAIT(...)`, and `pto::comm::TTEST(...)` respectively, and operate on GM `i32` signal buffers.
1053+
1054+
**Arguments:**
1055+
1056+
| Op | Operands | Attributes | Result |
1057+
|----|----------|------------|--------|
1058+
| `pto.tnotify` | `signal`, `value` | `notifyOp = #pto.notify_op<atomic_add/set>` | none |
1059+
| `pto.twait` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<eq/ne/gt/ge/lt/le>` | none |
1060+
| `pto.ttest` | `signal`, `cmpValue` | `cmp = #pto.wait_cmp<eq/ne/gt/ge/lt/le>` | `i1` |
1061+
1062+
**Constraints & Verification:**
1063+
1064+
- `signal` must be a GM-shaped value with element type `i32`.
1065+
- `value` / `cmpValue` must be signless integer scalars.
1066+
1067+
**Basic Example:**
1068+
1069+
```mlir
1070+
pto.tnotify %sig, %v {notifyOp = #pto.notify_op<set>} : !pto.partition_tensor_view<1xi32>, i32
1071+
pto.twait %sig, %v {cmp = #pto.wait_cmp<ge>} : !pto.partition_tensor_view<1xi32>, i32
1072+
%ok = pto.ttest %sig, %v {cmp = #pto.wait_cmp<eq>} : !pto.partition_tensor_view<1xi32>, i32 -> i1
1073+
```
1074+
1075+
---
1076+
1077+
##### `pto.tbroadcast` - Collective Broadcast
1078+
1079+
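**Summary:** Lowers to `pto::comm::TBROADCAST(...)` and broadcasts the root's GM source buffer (`src`) to all members of `group`, using local VEC staging tiles.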
1080+
1081+
**Arguments:**
1082+
1083+
| Name | Type | Description |
1084+
|------|------|-------------|
1085+
| `src` | GM-shaped value | Root source buffer |
1086+
| `ping` / `pong` | local VEC tile-like values | Staging tiles; `ping` is required, `pong` is optional |
1087+
| `group` | variadic GM-shaped values | Parallel group members |
1088+
| `root` | `i32` attr | Root rank index inside `group` |
1089+
1090+
**Constraints & Verification:**
1091+
1092+
- `group` must be non-empty and all members must have identical types.
1093+
- `src` must have the same type as each `group` member.
1094+
- `root` must be in range `[0, group.size)`.
1095+
1096+
**Basic Example:**
1097+
1098+
```mlir
1099+
pto.tbroadcast %src, %ping, %g0, %g1, %g2 {root = 1, operandSegmentSizes = array<i32: 1, 1, 0, 3>} :
1100+
!pto.partition_tensor_view<128xf32>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>, !pto.partition_tensor_view<128xf32>
1101+
```
1102+
1103+
---
1104+
1105+
##### `pto.comm_tgather` - Collective Gather
1106+
1107+
**Summary:** Communication collective that lowers to `pto::comm::TGATHER(...)`. This op is distinct from tile-level `pto.tgather`.
1108+
1109+
**Arguments:** `dst`, `ping`, optional `pong`, variadic `group`, `root`
1110+
1111+
**Constraints & Verification:**
1112+
1113+
- `group` must be non-empty and all members must have identical types.
1114+
- `dst` element type must match the group element type.
1115+
- `ping` / `pong` must be local VEC tile-like values with matching element type.
1116+
1117+
---
1118+
1119+
##### `pto.comm_tscatter` - Collective Scatter
1120+
1121+
**Summary:** Communication collective that lowers to `pto::comm::TSCATTER(...)`. This op is distinct from tile-level `pto.tscatter`.
1122+
1123+
**Arguments:** `src`, `ping`, optional `pong`, variadic `group`, `root`
1124+
1125+
**Constraints & Verification:**
1126+
1127+
- `group` must be non-empty and all members must have identical types.
1128+
- `src` element type must match the group element type.
1129+
- `ping` / `pong` must be local VEC tile-like values with matching element type.
1130+
1131+
---
1132+
1133+
##### `pto.treduce` - Collective Reduce
1134+
1135+
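**Summary:** Lowers to `pto::comm::TREDUCE(...)` and reduces GM data from the members of `group` into the root destination buffer (`dst`) using the reduction mode given by `reduceOp`.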
1136+
1137+
**Arguments:**
1138+
1139+
| Name | Type | Description |
1140+
|------|------|-------------|
1141+
| `dst` | GM-shaped value | Root destination buffer |
1142+
| `acc` | local VEC tile-like value | Accumulation tile |
1143+
| `recvPing` / `recvPong` | local VEC tile-like values | Receive staging tiles; `recvPing` is required, `recvPong` is optional |
1144+
| `group` | variadic GM-shaped values | Parallel group members |
1145+
| `reduceOp` | `#pto.reduce_op<sum/max/min>` | Reduction mode |
1146+
| `root` | `i32` attr | Root rank index inside `group` |
1147+
1148+
**Constraints & Verification:**
1149+
1150+
- `group` must be non-empty and all members must have identical types.
1151+
- `dst` element type must match the group element type.
1152+
- `acc` and `recvPing` / `recvPong` must be local VEC tile-like values whose element type matches `dst`.
1153+
1154+
---
1155+
9911156
##### `pto.tmov` - Tile Move Between Local Domains
9921157

9931158
**Summary:** Moves data between local memory domains (for example `mat/acc/vec/bias/scaling`) using tile buffers, and supports the same optional parameter families as the `TMOV/TMOV_FP` APIs in `pto-isa`.

include/PTO/IR/PTOAttrs.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,41 @@ def PTO_AtomicTypeAttr : EnumAttr<PTO_Dialect, PTO_AtomicTypeEnum, "atomic_type"
394394
let summary = "TSTORE atomic type attribute";
395395
}
396396

397+
def PTO_NotifyOpEnum : PTO_I32Enum<
398+
"NotifyOp", "PTO communication notify op", [
399+
I32EnumAttrCase<"AtomicAdd", 0, "atomic_add">,
400+
I32EnumAttrCase<"Set", 1, "set">
401+
]>;
402+
403+
def PTO_NotifyOpAttr : EnumAttr<PTO_Dialect, PTO_NotifyOpEnum, "notify_op"> {
404+
let summary = "communication notify operation attribute";
405+
}
406+
407+
def PTO_WaitCmpEnum : PTO_I32Enum<
408+
"WaitCmp", "PTO communication wait/test compare", [
409+
I32EnumAttrCase<"EQ", 0, "eq">,
410+
I32EnumAttrCase<"NE", 1, "ne">,
411+
I32EnumAttrCase<"GT", 2, "gt">,
412+
I32EnumAttrCase<"GE", 3, "ge">,
413+
I32EnumAttrCase<"LT", 4, "lt">,
414+
I32EnumAttrCase<"LE", 5, "le">
415+
]>;
416+
417+
def PTO_WaitCmpAttr : EnumAttr<PTO_Dialect, PTO_WaitCmpEnum, "wait_cmp"> {
418+
let summary = "communication wait/test comparison attribute";
419+
}
420+
421+
def PTO_ReduceOpEnum : PTO_I32Enum<
422+
"ReduceOp", "PTO communication reduce operation", [
423+
I32EnumAttrCase<"Sum", 0, "sum">,
424+
I32EnumAttrCase<"Max", 1, "max">,
425+
I32EnumAttrCase<"Min", 2, "min">
426+
]>;
427+
428+
def PTO_ReduceOpAttr : EnumAttr<PTO_Dialect, PTO_ReduceOpEnum, "reduce_op"> {
429+
let summary = "communication reduce operation attribute";
430+
}
431+
397432
def PTO_ReluPreModeEnum : PTO_I32Enum<
398433
"ReluPreMode", "PTO TSTORE relu pre mode", [
399434
I32EnumAttrCase<"NoRelu", 0, "no_relu">,

include/PTO/IR/PTOOps.td

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,6 +1655,140 @@ def TestAsyncEventOp : PTO_Op<"test_async_event", [
16551655
}];
16561656
}
16571657

1658+
def TPutOp : PTO_Op<"tput", [
1659+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1660+
]> {
1661+
let summary = "Synchronous remote write from local GM to remote GM";
1662+
let arguments = (ins
1663+
PTODpsType:$dst,
1664+
PTODpsType:$src,
1665+
PTODpsType:$ping,
1666+
Optional<PTODpsType>:$pong,
1667+
DefaultValuedAttr<PTO_AtomicTypeAttr, "::mlir::pto::AtomicType::AtomicNone">:$atomicType
1668+
);
1669+
let results = (outs);
1670+
let hasVerifier = 1;
1671+
}
1672+
1673+
def TGetOp : PTO_Op<"tget", [
1674+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1675+
]> {
1676+
let summary = "Synchronous remote read from remote GM to local GM";
1677+
let arguments = (ins
1678+
PTODpsType:$dst,
1679+
PTODpsType:$src,
1680+
PTODpsType:$ping,
1681+
Optional<PTODpsType>:$pong
1682+
);
1683+
let results = (outs);
1684+
let hasVerifier = 1;
1685+
}
1686+
1687+
def TNotifyOp : PTO_Op<"tnotify", [
1688+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1689+
]> {
1690+
let summary = "Send a signal notification to remote GM";
1691+
let arguments = (ins
1692+
PTODpsType:$signal,
1693+
AnySignlessInteger:$value,
1694+
PTO_NotifyOpAttr:$notifyOp
1695+
);
1696+
let results = (outs);
1697+
let hasVerifier = 1;
1698+
}
1699+
1700+
def TWaitOp : PTO_Op<"twait", [
1701+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1702+
]> {
1703+
let summary = "Block until signal(s) satisfy a comparison";
1704+
let arguments = (ins
1705+
PTODpsType:$signal,
1706+
AnySignlessInteger:$cmpValue,
1707+
PTO_WaitCmpAttr:$cmp
1708+
);
1709+
let results = (outs);
1710+
let hasVerifier = 1;
1711+
}
1712+
1713+
def TTestOp : PTO_Op<"ttest", [
1714+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1715+
]> {
1716+
let summary = "Non-blocking signal comparison test";
1717+
let arguments = (ins
1718+
PTODpsType:$signal,
1719+
AnySignlessInteger:$cmpValue,
1720+
PTO_WaitCmpAttr:$cmp
1721+
);
1722+
let results = (outs I1:$result);
1723+
let hasVerifier = 1;
1724+
}
1725+
1726+
def TBroadcastOp : PTO_Op<"tbroadcast", [
1727+
AttrSizedOperandSegments,
1728+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1729+
]> {
1730+
let summary = "Broadcast local GM data to all group members";
1731+
let arguments = (ins
1732+
PTODpsType:$src,
1733+
PTODpsType:$ping,
1734+
Optional<PTODpsType>:$pong,
1735+
Variadic<PTODpsType>:$group,
1736+
I32Attr:$root
1737+
);
1738+
let results = (outs);
1739+
let hasVerifier = 1;
1740+
}
1741+
1742+
def CommTGatherOp : PTO_Op<"comm_tgather", [
1743+
AttrSizedOperandSegments,
1744+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1745+
]> {
1746+
let summary = "Gather remote GM data from a communication group";
1747+
let arguments = (ins
1748+
PTODpsType:$dst,
1749+
PTODpsType:$ping,
1750+
Optional<PTODpsType>:$pong,
1751+
Variadic<PTODpsType>:$group,
1752+
I32Attr:$root
1753+
);
1754+
let results = (outs);
1755+
let hasVerifier = 1;
1756+
}
1757+
1758+
def CommTScatterOp : PTO_Op<"comm_tscatter", [
1759+
AttrSizedOperandSegments,
1760+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1761+
]> {
1762+
let summary = "Scatter local GM data to a communication group";
1763+
let arguments = (ins
1764+
PTODpsType:$src,
1765+
PTODpsType:$ping,
1766+
Optional<PTODpsType>:$pong,
1767+
Variadic<PTODpsType>:$group,
1768+
I32Attr:$root
1769+
);
1770+
let results = (outs);
1771+
let hasVerifier = 1;
1772+
}
1773+
1774+
def TReduceOp : PTO_Op<"treduce", [
1775+
AttrSizedOperandSegments,
1776+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
1777+
]> {
1778+
let summary = "Reduce remote GM data from a communication group";
1779+
let arguments = (ins
1780+
PTODpsType:$dst,
1781+
PTODpsType:$acc,
1782+
PTODpsType:$recvPing,
1783+
Optional<PTODpsType>:$recvPong,
1784+
Variadic<PTODpsType>:$group,
1785+
PTO_ReduceOpAttr:$reduceOp,
1786+
I32Attr:$root
1787+
);
1788+
let results = (outs);
1789+
let hasVerifier = 1;
1790+
}
1791+
16581792
def InitializeL2G2LPipeOp : PTO_Op<"initialize_l2g2l_pipe", [
16591793
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
16601794
]> {

include/pto-c/Dialect/PTO.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,18 @@ MLIR_CAPI_EXPORTED int32_t mlirPTOAccToVecModeAttrGetValue(MlirAttribute attr);
9898
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAReluPreModeAttr(MlirAttribute attr);
9999
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOReluPreModeAttrGet(MlirContext ctx, int32_t value);
100100
MLIR_CAPI_EXPORTED int32_t mlirPTOReluPreModeAttrGetValue(MlirAttribute attr);
101+
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAAtomicTypeAttr(MlirAttribute attr);
102+
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOAtomicTypeAttrGet(MlirContext ctx, int32_t value);
103+
MLIR_CAPI_EXPORTED int32_t mlirPTOAtomicTypeAttrGetValue(MlirAttribute attr);
104+
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsANotifyOpAttr(MlirAttribute attr);
105+
MLIR_CAPI_EXPORTED MlirAttribute mlirPTONotifyOpAttrGet(MlirContext ctx, int32_t value);
106+
MLIR_CAPI_EXPORTED int32_t mlirPTONotifyOpAttrGetValue(MlirAttribute attr);
107+
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAWaitCmpAttr(MlirAttribute attr);
108+
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOWaitCmpAttrGet(MlirContext ctx, int32_t value);
109+
MLIR_CAPI_EXPORTED int32_t mlirPTOWaitCmpAttrGetValue(MlirAttribute attr);
110+
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsAReduceOpAttr(MlirAttribute attr);
111+
MLIR_CAPI_EXPORTED MlirAttribute mlirPTOReduceOpAttrGet(MlirContext ctx, int32_t value);
112+
MLIR_CAPI_EXPORTED int32_t mlirPTOReduceOpAttrGetValue(MlirAttribute attr);
101113
MLIR_CAPI_EXPORTED MlirAttribute mlirPTORoundModeAttrGet(MlirContext ctx, int32_t value);
102114
MLIR_CAPI_EXPORTED bool mlirPTOAttrIsARoundModeAttr(MlirAttribute attr);
103115
MLIR_CAPI_EXPORTED int32_t mlirPTORoundModeAttrGetValue(MlirAttribute attr);

0 commit comments

Comments
 (0)