Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fd152bf
[Docs] C++ trace pipeline design (runtime-tag pairing, ABI)
YWHyuk Jun 24, 2026
be967cf
[TOGSim] C++ trace pipeline: front end, runtime, loader, bridge, Core…
YWHyuk Jun 24, 2026
78f77bc
[TOGSim] Per-iteration tag pairing for multi-tile-K and conv
YWHyuk Jun 24, 2026
b189df4
[TOGSim] Work-item outlining and ABI v12 dispatch
YWHyuk Jun 24, 2026
03b7f11
[TOGSim] SRAM-capacity and SA weight-buffer throttle for the trace path
YWHyuk Jun 24, 2026
05770cb
[Tooling] TOGSim trace timeline (Perfetto) and the trace emits it needs
YWHyuk Jun 24, 2026
a23690a
[TOGSim] Make the C++ trace path the default and stabilize it
YWHyuk Jun 24, 2026
76a2862
[TOGSim] Make the trace runtime test self-contained
YWHyuk Jun 24, 2026
4558e65
[Frontend] Trace cache-safe replay and compile-race fixes
YWHyuk Jun 24, 2026
9033945
[TOGSim] Redesign trace-bridge dependency, barrier, SRAM-version, and…
YWHyuk Jun 24, 2026
2146ee5
[Frontend] Guard MLIR tile sizing against symbolic dims
YWHyuk Jun 22, 2026
cf2950c
[Frontend] Emit symbolic loop bounds and dynamic memref dims
YWHyuk Jun 22, 2026
5743a20
[Frontend] Make the kernel meta import-safe under dynamic shape
YWHyuk Jun 22, 2026
a3a8c57
[Frontend] Skip compile-time Spike validation for dynamic-shape kernels
YWHyuk Jun 22, 2026
5164d86
[Frontend] Sample per-tile cycles on a one-tile copy (dynamic shape)
YWHyuk Jun 22, 2026
6201031
[Frontend] Emit a dynamic-shape trace producer (shape_args loop bounds)
YWHyuk Jun 23, 2026
7d985f1
[TOGSim] Pass the runtime shape to the trace producer via the attribu…
YWHyuk Jun 23, 2026
479d407
[TOGSim] Functional output for dynamic shape (shape-agnostic Spike bi…
YWHyuk Jun 23, 2026
e2841c2
[Test] Dynamic-shape elementwise add on the trace path
YWHyuk Jun 23, 2026
c2c0db4
[Frontend] Tidy dynamic-shape detection and drop dead code
YWHyuk Jun 23, 2026
d86e9cd
[Frontend] Consolidate symbolic-dim guards into is_symbolic_dim
YWHyuk Jun 23, 2026
e8cb0d2
[Docs] Dynamic-shape implementation plan (storage; drop before merge)
YWHyuk Jun 24, 2026
82e9255
[Frontend] Make aligned axis-split symbolic-aware (detection layer)
YWHyuk Jun 24, 2026
50f6550
[Frontend] Modularize buffer-to-memref type construction
YWHyuk Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions AsmParser/tog_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# DEPRECATED (timing path): legacy ONNX Tile-Operation-Graph producer. Builds
# the TOG and serializes it to ONNX for the C++ TileGraphParser. Superseded by
# the C++ trace pipeline (PyTorchSimFrontend/mlir/passes/build_skeleton.py +
# lower_to_emitc.py + cycle_table.py -> a compiled trace .so). Kept live so the
# current pipeline does not break; to be retired once the trace pipeline (P3+)
# stabilizes. See docs/design/togsim_cpp_trace.md.
import os
import sys
import importlib.util
Expand Down
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Note: `TOGSIM_CONFIG` is **overwritten** while inside a `with TOGSimulator(confi
Located under `configs/*.yml`:

- `num_cores`, `core_freq_mhz`, `num_systolic_array_per_core`
- `sa_weight_buffer_depth` (per-SA resident weight slots; **must be > 0** — the simulator errors on 0. Raise it to effectively disable the preload run-ahead throttle. Defaults to 2 if the key is absent.)
- `vpu_num_lanes`, `vpu_spad_size_kb_per_lane`, `vpu_vector_length_bits`
- `dram_type` (`ramulator2` | `simple`), `dram_channels`, `dram_freq_mhz`, `ramulator_config_path`
- `icnt_type` (`simple` | `booksim`), `icnt_latency_cycles`, `icnt_freq_mhz`, `icnt_config_path`
Expand Down
214 changes: 162 additions & 52 deletions PyTorchSimFrontend/extension_codecache.py

Large diffs are not rendered by default.

147 changes: 119 additions & 28 deletions PyTorchSimFrontend/mlir/axis_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,43 +29,130 @@ def _as_int(x):
return None


# --- symbolic-aware boundary arithmetic ------------------------------------
# These reduce EXACTLY to the integer case when their operands are concrete, so
# static axis splitting is unchanged; they additionally accept symbolic size
# expressions (e.g. a flattened reshape extent E = M*N with divisor N), where a
# boundary that is a genuine product of dims divides the extent by construction.
# A dynamic dim symbol is created integer/positive, so sympy proves the
# divisibility (Mod(M*N, N) -> 0) and the quotient (cancel(M*N/N) -> M).

def _divides(d, E):
"""True iff d divides E. For concrete ints this is `E % d == 0`."""
di, Ei = _as_int(d), _as_int(E)
if di is not None and Ei is not None:
return di != 0 and Ei % di == 0
try:
return bool(sympy.simplify(sympy.Mod(E, d)) == 0)
except Exception:
return False


def _eq(a, b):
"""Provable equality of two size exprs (structural for ints)."""
ai, bi = _as_int(a), _as_int(b)
if ai is not None and bi is not None:
return ai == bi
try:
return bool(sympy.simplify(a - b) == 0)
except Exception:
return a == b


def _gt1(x):
"""True iff x is a non-trivial boundary (> 1). A symbolic dim is assumed > 1."""
xi = _as_int(x)
if xi is not None:
return xi > 1
return not _eq(x, sympy.Integer(1))


def _proper(b, E):
"""True iff b is a proper interior divisor of E: 1 < b < E and b | E."""
bi, Ei = _as_int(b), _as_int(E)
if bi is not None and Ei is not None:
return 1 < bi < Ei and Ei % bi == 0
return _gt1(b) and not _eq(b, E) and _divides(b, E)


def _quotient(a, b):
"""a / b as an exact int (concrete) or simplified sympy expr (symbolic)."""
ai, bi = _as_int(a), _as_int(b)
if ai is not None and bi is not None:
return ai // bi
return sympy.cancel(a / b)


def _as_size(x):
"""Wrap a concrete int as sympy.Integer; pass a sympy expr through unchanged
(preserving its integer/positive assumptions)."""
xi = _as_int(x)
return sympy.Integer(xi) if xi is not None else x


def _ordered_chain(boundaries, E):
"""Order the proper divisors of E into a divisibility chain [1, ..., E], else None.

Generalises the old `_is_chain` + numeric `sorted`: orders by the divisibility
partial order (b_i precedes b_j iff b_i | b_j) rather than by numeric value, so
symbolic boundaries (suffix-products of dims, e.g. N | M*N) chain correctly. For
concrete ints this yields exactly the old ascending divisibility chain. Returns
None when the boundaries do not form a TOTAL divisibility chain (the
incompatible-radix / misaligned case), so the axis is left unsplit.
"""
bs = []
for b in boundaries:
if _proper(b, E) and not any(_eq(b, x) for x in bs):
bs.append(b)
ordered = []
remaining = list(bs)
while remaining:
# the divisibility-minimum is the unique element that divides all others.
mins = [b for b in remaining
if all(_divides(b, o) for o in remaining if not _eq(b, o))]
if len(mins) != 1:
return None # no unique minimum -> incomparable -> not a chain
ordered.append(mins[0])
remaining = [o for o in remaining if not _eq(o, mins[0])]
chain = [sympy.Integer(1)] + ordered + [_as_size(E)]
for i in range(len(chain) - 1):
if not _divides(chain[i], chain[i + 1]):
return None
return chain


def collect_boundaries(exprs, var_to_axis, var_ranges):
"""{axis_index: set(boundary cut points)} for the given index expressions.

A FloorDiv(v, k) contributes boundary k; ModularIndexing(v, k, m) contributes
k and k*m. Only aligned terms count (boundary divides the var extent). Shared
by find_split_plan (fused LoopBody) and graph_copy (operand loaders).
by find_split_plan (fused LoopBody) and graph_copy (operand loaders). Boundaries
and extents may be symbolic (dynamic reshape); divisibility is checked via
`_divides`, so a symbolic divisor that is a genuine factor of the extent counts.
"""
import collections
bset = collections.defaultdict(set)
for expr in exprs:
for fd in expr.atoms(FloorDiv):
base, div = fd.args
k = _as_int(div)
if base in var_to_axis and k and k > 1:
E = _as_int(var_ranges.get(base))
if E and E % k == 0:
bset[var_to_axis[base]].add(k)
if base in var_to_axis and _gt1(div):
E = var_ranges.get(base)
if E is not None and _divides(div, E):
bset[var_to_axis[base]].add(div)
for mi in expr.atoms(ModularIndexing):
base, div, mod = mi.args
k, m = _as_int(div), _as_int(mod)
if base in var_to_axis and k and m:
E = _as_int(var_ranges.get(base))
if E and E % (k * m) == 0:
if base in var_to_axis:
E = var_ranges.get(base)
km = div * mod
if E is not None and _divides(km, E):
ax = var_to_axis[base]
if k > 1:
bset[ax].add(k)
if k * m < E:
bset[ax].add(k * m)
if _gt1(div):
bset[ax].add(div)
if _proper(km, E):
bset[ax].add(km)
return bset


def _is_chain(boundaries, E):
"""True iff [1, sorted(boundaries in (1,E)), E] is a divisibility chain."""
chain = [1] + sorted(b for b in boundaries if 1 < b < E) + [E]
return all(chain[i + 1] % chain[i] == 0 for i in range(len(chain) - 1))


def find_split_plan(nodes):
"""Inspect a group of scheduler nodes and return {axis_index: boundaries}.

Expand All @@ -80,13 +167,14 @@ def find_split_plan(nodes):
collected boundaries for an axis do NOT form a divisibility chain (e.g.
floor-by-2 and mod-by-3 on extent 6), the radices are incompatible -> the axis
is left unsplit (its floor/mod stays for the misaligned/recompile path).
Boundaries/extents may be symbolic (see _ordered_chain).

axis_index is positional in the group's iteration space, so the same plan
applies to every fused node sharing that space.
"""
import collections
bset = collections.defaultdict(set) # axis -> set of boundary cut points
ext_of = {} # axis -> extent
ext_of = {} # axis -> extent (int or symbolic)
for n in nodes:
body = getattr(n, "_body", None)
if body is None:
Expand All @@ -95,14 +183,17 @@ def find_split_plan(nodes):
nb = collect_boundaries(body.indexing_exprs.values(), var_to_axis, body.var_ranges)
for ax, bs in nb.items():
bset[ax] |= bs
ext_of[ax] = _as_int(body.var_ranges[body.iter_vars[ax]])
ext_of[ax] = body.var_ranges[body.iter_vars[ax]]

plan = {}
for ax, bs in bset.items():
E = ext_of[ax]
E = ext_of.get(ax)
if E is None:
continue
# require a real, divisibility-chain split (incompatible radices -> skip).
if E and any(1 < b < E for b in bs) and _is_chain(bs, E):
plan[ax] = [1] + sorted(b for b in bs if 1 < b < E) + [E]
chain = _ordered_chain(bs, E)
if chain is not None and len(chain) > 2:
plan[ax] = chain

# A split may push the per-axis index rank past 4. The resulting >4D logical tile
# is peeled into <=4D physical descriptors by the decompose-transfer pass (an
Expand Down Expand Up @@ -143,15 +234,15 @@ def build_split_body(node, plan, prefix="z"):
subs = [] # (symbol, extent, significance) low->high
expr = sympy.Integer(0)
for i in range(len(bounds) - 1):
seg_ext = bounds[i + 1] // bounds[i]
seg_ext = _quotient(bounds[i + 1], bounds[i])
nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1
subs.append((nv, seg_ext, bounds[i]))
expr = expr + nv * bounds[i]
# iteration nest: most-significant (outermost) dim first.
for nv, seg_ext, _sig in reversed(subs):
iter_vars.append(nv)
var_ranges[nv] = sympy.Integer(seg_ext)
index_size.append(sympy.Integer(seg_ext))
var_ranges[nv] = _as_size(seg_ext)
index_size.append(_as_size(seg_ext))
index_args.append(expr)
else:
nv = sympy_index_symbol(f"{prefix}{ctr}"); ctr += 1
Expand Down
7 changes: 5 additions & 2 deletions PyTorchSimFrontend/mlir/mlir_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __str__(self) -> str:
def make_run_fn(
self, input_tensors: torch.Tensor, output_tensors: torch.Tensor
) -> Callable[[], None]:
from PyTorchSimFrontend.extension_codecache import CustomAsyncCompile
from PyTorchSimFrontend.extension_codecache import CustomAsyncCompile, get_header
custom_async_compile = CustomAsyncCompile()

# Check already cached result.
Expand All @@ -80,12 +80,15 @@ def cached_run_fn(*args, autotune_subprocess_timeout_sec=None, **kwargs):
return cached_run_fn

# Run a candidate code
_headers = get_header(self.source_code)
_header_kwargs = {} if _headers is None else {
"global_var_header": _headers[0], "gem5_global_var_header": _headers[1]}
run_method = custom_async_compile.mlir(
self.source_code, vectorlane_size=self.extra_args["vector_lane"],
loop_size=self.extra_args["loop_size"], spad_info=self.extra_args["spad_info"],
vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"],
origins=self.extra_args["origins"], silent_mode=True,
autotune=self.extra_args['autotune'])
autotune=self.extra_args['autotune'], **_header_kwargs)

args = [
tensor
Expand Down
88 changes: 68 additions & 20 deletions PyTorchSimFrontend/mlir/mlir_caller_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,32 @@ def get_argv_idx(self):
self.arg_use_count += 1
return self.arg_use_count-1

def _is_var(self, flag):
return MLIRKernelArgs.is_mlir_arg_var(flag)

@staticmethod
def _is_symbol(numel):
"""A numel that is a size SYMBOL (e.g. 's52'), not a concrete value. Concrete
sizes may also be strings here (the meta stringifies sympy.Integer, e.g.
'128'); those are numeric, a symbol is not."""
return isinstance(numel, str) and not numel.isdigit()

def _numel_c_expr(self, numel):
"""C expression for an arg's element count. Dynamic shape: a size SYMBOL is
the runtime extent, read into `N_<symbol>` from its size buffer (see
generate_args_define); a concrete numel (int or numeric string) is a literal."""
return f"N_{numel}" if self._is_symbol(numel) else str(numel)

def _assign_argv_indices(self):
"""Assign each loaded/dumped arg an argv slot in arg_attributes order, the
same order Simulator.dump_args writes the .raw paths. Size (VAR) args get a
slot too (they are kernel inputs)."""
for arg_name, arg_attribute in self.arg_attributes:
flag = arg_attribute[0]
if (self.is_in_arg(flag) or self.is_out_arg(flag) or self._is_var(flag)) \
and arg_name not in self.load_args:
self.load_args[arg_name] = self.get_argv_idx()

def write_header(self):
self.writeline('#include <stdio.h>')
self.writeline('#include <stdlib.h>')
Expand All @@ -56,12 +82,12 @@ def is_inout_arg(self, value):

def load_arg(self):
for arg_name, arg_attribute in self.arg_attributes:
if self.is_in_arg(arg_attribute[0]):
argv_idx = self.get_argv_idx() if arg_name not in self.load_args else self.load_args[arg_name]
self.load_args[arg_name] = argv_idx
# VAR (size) args are loaded in generate_args_define (before the tensor
# buffers they size); skip them here.
if self.is_in_arg(arg_attribute[0]) and not self._is_var(arg_attribute[0]):
argv_idx = self.load_args[arg_name]
ctype = DTYPE_TO_C[arg_attribute[1]]
elem_count = arg_attribute[2]
size_expr = f'({elem_count}ULL * sizeof({ctype}))'
size_expr = f'((uint64_t)({self._numel_c_expr(arg_attribute[2])}) * sizeof({ctype}))'

self.writeline(f'if(load_arg(c_{arg_name}, {size_expr}, argv[{argv_idx}]) == -1){self.open_bracket}')
with self.code.indent():
Expand All @@ -71,10 +97,9 @@ def load_arg(self):
def dump_arg(self):
for arg_name, arg_attribute in self.arg_attributes:
if self.is_out_arg(arg_attribute[0]):
argv_idx = self.get_argv_idx() if not self.is_inout_arg(arg_attribute[0]) else self.load_args[arg_name]
argv_idx = self.load_args[arg_name]
ctype = DTYPE_TO_C[arg_attribute[1]]
elem_count = arg_attribute[2]
size_expr = f'({elem_count}ULL * sizeof({ctype}))'
size_expr = f'((uint64_t)({self._numel_c_expr(arg_attribute[2])}) * sizeof({ctype}))'
self.writeline(f'if(dump_arg(c_{arg_name}, {size_expr}, argv[{argv_idx}]) == -1){self.open_bracket}')
with self.code.indent():
self.writeline(f'return -1{self.ending}')
Expand All @@ -93,30 +118,53 @@ def generate_args_define(self):
name_set = set()
if self.validation:
self.writeline(f"int* padding = malloc(0x100000ULL * sizeof(int)){self.ending}")
for arg_name, (_, arg_type, arg_size, arg_sizes, arg_stride) in self.arg_attributes:
if not arg_name in name_set:
if torch.is_floating_point(torch.tensor([], dtype=arg_type)):
bits = torch.finfo(arg_type).bits
elif arg_type == torch.bool:
bits = 8
else:
bits = torch.iinfo(arg_type).bits
buffer_size = int(math.ceil(arg_size * bits // 8 / 64) * 64) * 2 # Round up to 64 bytes + Add some padding for safety
self.writeline(f'{DTYPE_TO_C[arg_type]}* c_{arg_name} = malloc({buffer_size}ULL){self.ending}')
name_set.add(arg_name)
# Dynamic shape: handle size (VAR) args first -- malloc, load from argv, and
# read the runtime extent into N_<name>, BEFORE the tensor buffers, which are
# sized from it.
for arg_name, (flag, arg_type, arg_size, _, _) in self.arg_attributes:
if not self._is_var(flag) or arg_name in name_set:
continue
ctype = DTYPE_TO_C[arg_type]
self.writeline(f'{ctype}* c_{arg_name} = malloc(64ULL){self.ending}')
if self.validation:
self.writeline(f'if(load_arg(c_{arg_name}, sizeof(int64_t), argv[{self.load_args[arg_name]}]) == -1){self.open_bracket}')
with self.code.indent():
self.writeline(f'return -1{self.ending}')
self.writeline(self.closed_bracket)
self.writeline(f'int64_t N_{arg_name} = ((int64_t*)c_{arg_name})[0]{self.ending}')
name_set.add(arg_name)
for arg_name, (flag, arg_type, arg_size, arg_sizes, arg_stride) in self.arg_attributes:
if self._is_var(flag) or arg_name in name_set:
continue
if torch.is_floating_point(torch.tensor([], dtype=arg_type)):
bits = torch.finfo(arg_type).bits
elif arg_type == torch.bool:
bits = 8
else:
bits = torch.iinfo(arg_type).bits
ctype = DTYPE_TO_C[arg_type]
if self._is_symbol(arg_size):
# runtime extent: round bytes up to 64 and double, computed in C.
nbytes = f"(N_{arg_size} * {bits} / 8)"
buffer_size = f"((({nbytes} + 63) / 64) * 64) * 2"
else:
buffer_size = f"{int(math.ceil(int(arg_size) * bits // 8 / 64) * 64) * 2}ULL" # round up to 64 bytes + safety pad
self.writeline(f'{ctype}* c_{arg_name} = malloc({buffer_size}){self.ending}')
name_set.add(arg_name)
self.writeline(self.newline)

def generate_main(self):
self.writeline(f'{self.newline}int main(int argc, char *argv[]) {self.open_bracket}{self.newline}')
with self.code.indent():
if self.validation:
self._assign_argv_indices() # argv slots in arg order (incl. size args)
self.generate_args_define()
self.load_arg()
self.writeline(self.newline)
else:
self.generate_args_define()

func_arguments = [f"c_{arg_name}, c_{arg_name}, 0, {arg_shape}, 1" for arg_name, (_, arg_type, arg_shape, _, _) in self.arg_attributes]
func_arguments = [f"c_{arg_name}, c_{arg_name}, 0, {self._numel_c_expr(arg_shape)}, 1" for arg_name, (_, arg_type, arg_shape, _, _) in self.arg_attributes]
self.writeline(f"wrapper_{self.kernel_name}({', '.join(func_arguments)}){self.ending}{self.newline}")

if self.validation:
Expand Down
Loading