diff --git a/TraceLens/EventReplay/__init__.py b/TraceLens/EventReplay/__init__.py index fb6cb2734..4f79b0b11 100644 --- a/TraceLens/EventReplay/__init__.py +++ b/TraceLens/EventReplay/__init__.py @@ -3,3 +3,21 @@ # # See LICENSE for license information. ############################################################################### + +from .event_replay import EventReplayer +from .custom_inits import ( + CustomInit, + PagedAttentionInit, + MoeRoutingInit, + extract_batch_context, +) +from .utils import benchmark_func + +__all__ = [ + "EventReplayer", + "CustomInit", + "PagedAttentionInit", + "MoeRoutingInit", + "extract_batch_context", + "benchmark_func", +] diff --git a/TraceLens/EventReplay/batched_replay.py b/TraceLens/EventReplay/batched_replay.py index d948e6d45..97008cb75 100644 --- a/TraceLens/EventReplay/batched_replay.py +++ b/TraceLens/EventReplay/batched_replay.py @@ -99,11 +99,17 @@ def _get_args_kwargs_from_ir( replayed_count = 0 errors = 0 - for i, repro_info in enumerate(repro_data_list): + ops_to_replay = repro_data_list + if args.op_filter: + ops_to_replay = [r for r in ops_to_replay if args.op_filter in r["op_name"]] + if args.op_limit: + ops_to_replay = ops_to_replay[: args.op_limit] + + for i, repro_info in enumerate(ops_to_replay): op_name = repro_info["op_name"] replay_ir = repro_info["replay_ir"] - print(f"\n[{replayed_count + 1}/{len(repro_data_list)}] Replaying: {op_name}") + print(f"\n[{replayed_count + 1}/{len(ops_to_replay)}] Replaying: {op_name}") # Get the PyTorch operation function try: @@ -151,15 +157,16 @@ def _get_args_kwargs_from_ir( errors += 1 continue # --- Benchmark the function --- - mean_time_us = benchmark_func( + metrics = benchmark_func( lambda: func(*pos_args, **kwargs), args.device, warmup=50, avg_steps=100 ) - print(f" Average time taken: {mean_time_us:.2f} microseconds") + mean_time_us = metrics["mean_us"] + print(f" Average time taken: {mean_time_us:.2f} us (median: {metrics['median_us']:.2f} us)") if "count" in 
repro_info: count_workload = repro_info["count"] total_time_us = mean_time_us * count_workload print(f" Count in workload: {count_workload}") - print(f" Est time in workload: {total_time_us:.2f} microseconds") + print(f" Est time in workload: {total_time_us:.2f} us") # --- Optionally sync again --- if args.device == "cuda": torch.cuda.synchronize() @@ -190,7 +197,9 @@ def _get_args_kwargs_from_ir( print("\n--- Replay Summary ---") print(f"Total operations in file: {len(repro_data_list)}") if args.op_filter: - print(f"Filter applied: '{args.op_filter}'") + print(f"Filter applied: '{args.op_filter}' ({len(ops_to_replay)} matched)") + if args.op_limit: + print(f"Limit applied: {args.op_limit}") print(f"Attempted replays: {replayed_count}") print(f"Successful replays: {replayed_count - errors}") print(f"Errors encountered: {errors}") diff --git a/TraceLens/EventReplay/custom_inits.py b/TraceLens/EventReplay/custom_inits.py new file mode 100644 index 000000000..bce9960ef --- /dev/null +++ b/TraceLens/EventReplay/custom_inits.py @@ -0,0 +1,390 @@ +############################################################################### +# Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +"""Custom initializers for EventReplayer. + +Operations captured by the PyTorch profiler have zeroed-out metadata tensors +(block tables, routing tensors, etc.) because the profiler records shapes and +dtypes but not tensor values. Custom initializers fill these tensors with +realistic content so the GPU kernel exercises real memory-access and compute +patterns during replay benchmarking. + +To add a custom initializer for a new op family: + 1. Subclass ``CustomInit`` + 2. Set ``op_patterns`` to one or more substrings that match the op name + 3. Implement ``initialize()`` — mutate replayer.args / replayer.kwargs in-place + 4. 
Return a one-line summary string (printed by EventReplayer) + 5. Register with ``EventReplayer.register_custom_init(YourInit())`` + or add it to the ``_custom_init_registry`` default list. +""" + +from __future__ import annotations + +import re +import warnings +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + pass # EventReplayer imported at runtime to avoid circular dep + +# -- Batch context extraction from vLLM profiler annotations --------------- + +_BATCH_ANNO_RE = re.compile( + r"execute_context_(\d+)\((\d+)\)_generation_(\d+)\((\d+)\)" +) + + +def extract_batch_context(analyzer: Any) -> int: + """Parse vLLM ``user_annotation`` events and attach batch context to ops. + + vLLM annotates each ``execute_model`` step with a ``user_annotation`` + event of the form ``execute_context_N(T)_generation_N(T)`` where + *N* = number of sequences and *T* = total query tokens for that phase. + + This function: + 1. Collects all such annotations with their ``[ts, ts+dur]`` ranges. + 2. For every ``paged_attention`` cpu_op event, finds the enclosing + annotation by timestamp and attaches a ``batch_context`` dict:: + + event["batch_context"] = { + "n_prefill": 2, + "prefill_tokens": 18, + "n_decode": 2, + "decode_tokens": 2, + } + + Args: + analyzer: A ``TreePerfAnalyzer`` (or any object whose ``.tree.events`` + yields the trace event list). + + Returns: + Number of paged_attention events that were annotated. 
class CustomInit(ABC):
    """Base class for tensor initializers applied before replay.

    Subclasses set ``op_patterns`` to the op-name substrings they handle and
    implement :meth:`initialize`.
    """

    # Substrings of op names handled by this initializer. The base class
    # handles nothing; subclasses override this attribute.
    op_patterns: List[str] = []

    def applies_to(self, replayer: Any) -> bool:
        """Return True when any pattern occurs in the replayer's op name.

        Args:
            replayer: Object exposing ``event``, a dict with an optional
                ``"name"`` entry.

        Returns:
            True if at least one non-empty entry of ``op_patterns`` is a
            substring of the op name, otherwise False.
        """
        # A missing or None name is treated as "no op".
        op_name = replayer.event.get("name") or ""
        # Empty patterns are skipped: "" is a substring of every string and
        # would make the initializer match every op.
        return any(
            pattern and pattern in op_name for pattern in self.op_patterns
        )

    @abstractmethod
    def initialize(self, replayer: Any, **kwargs) -> Optional[str]:
        """Mutate replayer.args/kwargs in-place. Return a summary string."""
        ...
+ - Block table entries drawn from a random permutation of the pool. + """ + + op_patterns = ["_rocm_C::paged_attention"] + + def initialize(self, replayer: Any, **kwargs) -> Optional[str]: + try: + import numpy as np + except ImportError: + return "[custom init] PagedAttentionInit skipped — numpy not available" + + args = replayer.args + op_name = replayer.event.get("name", "") + + ir = replayer.event_replay_IR + arg_names = [a["arg_name"] for a in ir["list_pos_args"]] + def _by_name_or_pos(name, pos): + if name in arg_names: + return args[arg_names.index(name)] + return args[pos] + + block_tables = _by_name_or_pos("block_tables", 9) + seq_lens = _by_name_or_pos("seq_lens", 10) + key_cache = _by_name_or_pos("key_cache", 5) + block_size = int(_by_name_or_pos("block_size", 12)) + max_seq_len = int(_by_name_or_pos("max_seq_len", 13)) + + num_seqs = block_tables.shape[0] + max_blocks_per_seq = block_tables.shape[1] + num_blocks_total = key_cache.shape[0] + + query = _by_name_or_pos("query", 4) + num_query_tokens = query.shape[0] + + rng = np.random.default_rng(42) + + # -- Determine per-sequence query token counts ------------------------- + batch_ctx = replayer.event.get("batch_context") + if batch_ctx is not None: + n_pf = batch_ctx["n_prefill"] + pf_tok = batch_ctx["prefill_tokens"] + n_dec = batch_ctx["n_decode"] + dec_tok = batch_ctx["decode_tokens"] + + per_seq_queries = [] + if n_pf > 0: + base_pf = pf_tok // n_pf + rem_pf = pf_tok % n_pf + for s in range(n_pf): + per_seq_queries.append(base_pf + (1 if s < rem_pf else 0)) + for _ in range(n_dec): + per_seq_queries.append(1) + + if len(per_seq_queries) != num_seqs: + per_seq_queries = per_seq_queries[:num_seqs] + while len(per_seq_queries) < num_seqs: + per_seq_queries.append(1) + + phase = ("mixed" if n_pf > 0 and n_dec > 0 + else "prefill" if n_pf > 0 else "decode") + source = "annotation" + else: + tokens_per_seq = num_query_tokens / num_seqs if num_seqs else 1 + if tokens_per_seq > 1: + base_q = 
class MoeRoutingInit(CustomInit):
    """Initialize MoE routing tensors for replay.

    Fills ``sorted_token_ids``, ``sorted_expert_ids`` and ``num_valid_ids`` so
    the CK kernel processes real token-to-expert assignments instead of
    short-circuiting on ``num_valid_ids=0``.

    Supported kwargs:
        moe_distribution: "uniform" (default) or "zipf". Any other value
            triggers a warning and falls back to "uniform".
        moe_zipf_s: Zipf exponent (default 1.2), only used with "zipf"

    Arg layout for aiter::ck_moe_stage1/2:
        [0]  hidden_states       [M, K]           bf16
        [1]  w1                  [E, N, K]        bf16
        [2]  w2                  [E, K2, N2]      bf16
        [3]  sorted_token_ids    [padded]         int32  <- init
        [4]  sorted_expert_ids   [blocks+1]       int32  <- init
        [5]  num_valid_ids       [2]              int32  <- init
        [6]  output              [M, top_k, N2]   bf16
        [7]  top_k               (scalar)
        ...
        [11] block_m             (scalar)
    """

    op_patterns = ["aiter::ck_moe_stage1", "aiter::ck_moe_stage2"]

    def initialize(self, replayer: Any, **kwargs) -> Optional[str]:
        """Write a deterministic token-to-expert routing into the replay args.

        Args:
            replayer: Object exposing ``args`` (positional replay arguments),
                ``event`` (profiler event dict) and ``event_replay_IR``.
            **kwargs: ``moe_distribution`` and ``moe_zipf_s`` (see class doc).

        Returns:
            A one-line summary of what was initialized, or a "skipped"
            message when numpy cannot be imported.
        """
        try:
            import numpy as np
        except ImportError:
            return "[custom init] MoeRoutingInit skipped — numpy not available"

        distribution = kwargs.get("moe_distribution", "uniform")
        zipf_s = kwargs.get("moe_zipf_s", 1.2)
        if distribution not in ("uniform", "zipf"):
            # Previously an unknown value silently behaved like "uniform";
            # keep that fallback but tell the caller about the likely typo.
            warnings.warn(
                f"[custom init] MoeRoutingInit: unknown moe_distribution "
                f"{distribution!r}; falling back to 'uniform'"
            )
            distribution = "uniform"

        args = replayer.args
        op_name = replayer.event.get("name", "")

        # Locate args by name from the IR when available, fall back to position
        ir = replayer.event_replay_IR
        arg_names = [a["arg_name"] for a in ir["list_pos_args"]]

        def _by_name_or_pos(name, pos):
            if name in arg_names:
                return args[arg_names.index(name)]
            return args[pos]

        sorted_token_ids = _by_name_or_pos("sorted_token_ids", 3)
        sorted_expert_ids = _by_name_or_pos("sorted_expert_ids", 4)
        num_valid_ids = _by_name_or_pos("num_valid_ids", 5)
        top_k = int(_by_name_or_pos("topk", 7))
        block_m_val = _by_name_or_pos("block_m", 11)
        block_m = int(block_m_val) if block_m_val is not None else 32

        hidden = _by_name_or_pos("hidden_states", 0)
        M = hidden.shape[0]
        w1 = _by_name_or_pos("w1", 1)
        E = w1.shape[0]
        num_tokens = M * top_k
        padded_total = sorted_token_ids.shape[0]
        num_blocks = sorted_expert_ids.shape[0]

        # Fixed seed: every replay of the same op sees the same routing.
        rng = np.random.default_rng(42)

        if distribution == "zipf":
            ranks = np.arange(1, E + 1, dtype=np.float64)
            weights = 1.0 / np.power(ranks, zipf_s)
            probs = weights / weights.sum()
            expert_assignments = rng.choice(E, size=num_tokens, p=probs)
        else:
            expert_assignments = rng.integers(0, E, size=num_tokens)

        # Group the flat (token, slot) assignments by expert; each expert's
        # run is padded up to a multiple of block_m with the sentinel value
        # num_tokens.
        token_ids_list: list = []
        expert_ids_list: list = []
        for expert_id in range(E):
            tokens_for_expert = np.where(expert_assignments == expert_id)[0]
            count = len(tokens_for_expert)
            if count == 0:
                continue
            padded_count = ((count + block_m - 1) // block_m) * block_m
            n_blocks_for_expert = padded_count // block_m
            padded_tokens = np.full(padded_count, num_tokens, dtype=np.int32)
            padded_tokens[:count] = tokens_for_expert // top_k
            token_ids_list.append(padded_tokens)
            expert_ids_list.extend([expert_id] * n_blocks_for_expert)

        all_token_ids = (
            np.concatenate(token_ids_list)
            if token_ids_list
            else np.array([], dtype=np.int32)
        )
        # Number of routed entries (including per-expert block padding),
        # captured before the buffer is padded or truncated to the tensor
        # size below.
        valid_count = min(len(all_token_ids), padded_total)

        if len(all_token_ids) < padded_total:
            padding = np.full(
                padded_total - len(all_token_ids), num_tokens, dtype=np.int32
            )
            all_token_ids = np.concatenate([all_token_ids, padding])
        else:
            all_token_ids = all_token_ids[:padded_total]

        all_expert_ids = np.array(expert_ids_list, dtype=np.int32)
        if len(all_expert_ids) < num_blocks:
            padding = np.zeros(num_blocks - len(all_expert_ids), dtype=np.int32)
            all_expert_ids = np.concatenate([all_expert_ids, padding])
        else:
            all_expert_ids = all_expert_ids[:num_blocks]

        import torch

        sorted_token_ids.copy_(
            torch.from_numpy(all_token_ids).to(sorted_token_ids.device)
        )
        sorted_expert_ids.copy_(
            torch.from_numpy(all_expert_ids).to(sorted_expert_ids.device)
        )

        if num_valid_ids.numel() >= 1:
            num_valid_ids[0] = valid_count
        if num_valid_ids.numel() >= 2:
            num_valid_ids[1] = valid_count

        dist_label = f"zipf(s={zipf_s})" if distribution == "zipf" else "uniform"
        experts_active = len(set(expert_assignments.tolist()))
        return (
            f"[custom init] {op_name} — initialized MoE routing: "
            f"dist={dist_label}, M={M}, top_k={top_k}, E={E}, block_m={block_m}, "
            f"num_tokens={num_tokens}, active_experts={experts_active}/{E}, "
            f"valid_ids={valid_count}/{padded_total}, "
            f"blocks={len(expert_ids_list)}/{num_blocks}. "
            f"Assumptions: {dist_label} expert distribution, deterministic seed."
        )
def _try_resolve(op_name: str):
    """Look up a callable for *op_name* without importing extra libraries.

    The lookup order is: the JIT operator registry, then the ``torch.ops``
    namespace, then a Python module named after the op's namespace.

    Returns:
        ``(func, source)`` where *source* is ``"jit"``, ``"torch.ops"`` or
        ``"module:<namespace>"``; ``(None, None)`` when nothing resolves.
    """
    torch = _get_torch_or_raise()
    import importlib

    # 1. JIT registry. A RuntimeError is treated as "not registered".
    try:
        jit_func, _ = torch._C._jit_get_operation(op_name)
        if jit_func is not None:
            return jit_func, "jit"
    except RuntimeError:
        pass

    namespace, separator, short_name = op_name.partition("::")
    if not separator:
        # No "ns::name" form, so there is nothing further to try.
        return None, None

    # 2. torch.ops namespace
    ops_namespace = getattr(torch.ops, namespace, None)
    if ops_namespace is not None:
        candidate = getattr(ops_namespace, short_name, None)
        if callable(candidate):
            return candidate, "torch.ops"

    # 3. Direct Python module lookup
    try:
        module = importlib.import_module(namespace)
        candidate = getattr(module, short_name, None)
        if callable(candidate):
            return candidate, f"module:{namespace}"
    except ImportError:
        pass

    return None, None
def _search_schemas(op_name: str, verbose: bool = False):
    """Collect the distinct FunctionSchemas registered under *op_name*.

    Schemas are gathered from the JIT schema list and from the overloads of
    the matching ``torch.ops`` entry, de-duplicated by their string form.
    When the first pass finds nothing, one auto-import of the op's namespace
    is attempted and the search is repeated once.

    Args:
        op_name: Qualified op name, e.g. ``"ns::op"``.
        verbose: Print the schemas that were found.

    Returns:
        List of schema objects (possibly empty).
    """
    torch = _get_torch_or_raise()
    namespace, separator, short_name = op_name.partition("::")

    def _collect() -> list:
        found: list = []
        seen_texts: set = set()

        def _remember(schema) -> None:
            # Two sources can report the same schema; keep the first only.
            text = str(schema)
            if text not in seen_texts:
                found.append(schema)
                seen_texts.add(text)

        for schema in torch._C._jit_get_all_schemas():
            if schema.name == op_name:
                _remember(schema)

        if not separator:
            return found
        ns_obj = getattr(torch.ops, namespace, None)
        if ns_obj is None:
            return found
        op_obj = getattr(ns_obj, short_name, None)
        if op_obj is None:
            return found

        try:
            for overload_name in op_obj.overloads():
                _remember(getattr(op_obj, overload_name)._schema)
        except Exception:
            # Overload enumeration failed; fall back to the default
            # overload and ignore it if that fails as well.
            try:
                _remember(op_obj.default._schema)
            except Exception:
                pass
        return found

    schemas = _collect()
    # Retry once, and only if the auto-import actually imported something.
    if not schemas and _try_auto_import(op_name, verbose):
        schemas = _collect()

    if verbose:
        print(f"Found {len(schemas)} schemas for {op_name}:")
        for schema in schemas:
            pprint(str(schema))
            print("-" * 80)

    return schemas
with the event data and device type. @@ -35,12 +240,22 @@ def __init__( Args: event (Dict[str, Any]): From the pytorch profile json data['traceEvents'] device (str): The device type ('cuda' or 'cpu'). + lazy (bool): If True, defer tensor creation until replay(). verbose (bool): Flag to enable verbose output. + auto_init (bool): If True, automatically apply custom initializers + for ops that need realistic tensor content (e.g., paged attention + block tables, MoE routing tensors). + init_kwargs (Dict[str, Any]): Parameters passed to custom initializers + (e.g., {"moe_distribution": "zipf", "moe_zipf_s": 1.5}). """ self.event = event self.device = device self.lazy = lazy self.verbose = verbose + self._auto_init = auto_init + self._init_kwargs = init_kwargs or {} + self._inits_applied = False + self._func = None self._setup() def _setup(self): @@ -49,10 +264,37 @@ def _setup(self): """ if self.verbose: print(f"Preparing {self.event['name']} event for replay") - self.matched_schema = EventReplayer._search_schema(self.event, self.verbose) - self.event_replay_IR = EventReplayer._get_event_replay_IR( - self.event, self.matched_schema, self.verbose + + self._func, self._func_source, self._resolved_name = _resolve_op_func( + self.event["name"], verbose=self.verbose ) + if self.verbose: + print(f"Resolved op via {self._func_source}") + if self._resolved_name != self.event["name"]: + print(f" (aliased from '{self.event['name']}' -> '{self._resolved_name}')") + + try: + self.matched_schema = EventReplayer._search_schema( + self.event, self._resolved_name, self.verbose + ) + self._schemaless = False + except ValueError: + if self.verbose: + print( + "No schema found; falling back to schemaless replay " + "(all args treated as positional, types inferred from profile)" + ) + self.matched_schema = None + self._schemaless = True + + if self._schemaless: + self.event_replay_IR = EventReplayer._get_event_replay_IR_schemaless( + self.event, self.verbose, 
def replay(self):
    """Replay the event once and return the op's result.

    In lazy mode the arguments are rebuilt from the replay IR on every call.
    Freshly built tensors do not carry the content written by a custom
    initializer, so the initializer is re-applied after each rebuild. In
    eager mode the initializer runs on the first call only.

    Returns:
        The result of the PyTorch operation.
    """
    if self.lazy:
        self.args, self.kwargs = self._get_args_kwargs(
            self.event_replay_IR, device=self.device
        )
        # The tensors were just recreated: anything a previous
        # _apply_custom_inits() wrote is gone, so force a re-run.
        self._inits_applied = False

    if self._auto_init and not self._inits_applied:
        self._apply_custom_inits()

    return self._func(*self.args, **self.kwargs)


def _apply_custom_inits(self):
    """Apply the first matching custom initializer to this replayer's tensors.

    Only the first registry entry whose ``applies_to`` returns True is used.
    A failing initializer is reported with a warning and never aborts the
    replay. The summary is printed once per replayer, even when lazy mode
    re-applies the initializer on every replay.
    """
    for custom_init in self._custom_init_registry:
        if not custom_init.applies_to(self):
            continue
        try:
            summary = custom_init.initialize(self, **self._init_kwargs)
        except Exception as e:
            warnings.warn(
                f"[custom init] {type(custom_init).__name__} failed: {e}"
            )
        else:
            # getattr: the flag is created lazily, so instances built before
            # this attribute existed still work.
            if summary and not getattr(self, "_init_summary_printed", False):
                print(summary)
                self._init_summary_printed = True
        break
    self._inits_applied = True
resolved_name or event["name"] + op_schemas = _search_schemas(name, verbose=verbose) for schema in op_schemas: if verbose: @@ -106,7 +356,9 @@ def _search_schema( print("-" * 80) raise ValueError( - f"Cannot find matching schema for {event['name']}. Please check the event data and schema." + f"Cannot find matching schema for {name}. " + f"Searched {len(op_schemas)} candidate(s). " + f"Please check the event data and ensure the op's library is imported." ) @staticmethod @@ -115,22 +367,13 @@ def _is_schema_match( ) -> bool: """ Check if the event matches the schema. - - Args: - event (Dict[str, Any]): The event data. - schema (torch._C.FunctionSchema): The schema to match against. - - Returns: - bool: True if the event matches the schema, False otherwise. """ op_name, pos_args_schema, kwargs_schema, return_type = ( EventReplayer.parse_schema_string(schema) ) full_args_schema = pos_args_schema + kwargs_schema - # Check if the number of args in the event matches the schema if len(event["args"]["Input type"]) != len(full_args_schema): return False - # Check if the types match for idx in range(len(event["args"]["Input type"])): profiled_type = event["args"]["Input type"][idx] schema_type = full_args_schema[idx]["arg_type"] @@ -138,16 +381,8 @@ def _is_schema_match( print(f"Checking arg {idx}:") print(f"\tSchema type: {schema_type}") print(f"\tProfiled type: {profiled_type}") - # Rules for matching types - # 1. for tensor types, schema type should be 'Tensor' and profiled type can be any of the tensor types 'float', 'c10::Half', 'c10::BFloat16' ... - # 2. for bool types, schema type should be 'bool' and profiled type is 'Scalar'. So we need to further check the concrete Inputs if it only contains 'true' or 'false' - # 3. for int types, schema type should be 'int' or 'SymInt' and profiled type is 'Scalar'. So we need to further check the concrete Inputs if it is a digit - # 4. for float types, schema type should be 'Scalar' and profiled type is 'Scalar'. 
So we need to further check the concrete Inputs if it is a float - # 5. for int[] types, schema type should be 'int[]' or 'SymInt[]' and profiled type is 'ScalarList'. So we need to further check the concrete Inputs if it is a list of digits - # 6. for bool[] types, schema type should be 'bool[]' and profiled type is 'ScalarList'. So we need to further check the concrete Inputs if it is a list of 'true' or 'false' - # 7. for tensor[] types, we cannot replay the event as the tensor shapes are not provided in the event. So we need to skip this case. Maybe suggest PyTorch to add this in the future. + is_match = True - # if the schema type ends with '?' then the profiled type can be blank as well if schema_type.endswith("?"): schema_type = schema_type[:-1] if profiled_type == "": @@ -157,20 +392,20 @@ def _is_schema_match( and event["args"]["Concrete Inputs"][idx] == "[]" ): continue - if schema_type in ["Tensor", "Tensor?", "Tensor(a!)"]: + if EventReplayer._is_tensor_schema_type(schema_type): if profiled_type not in list_profile_tensor_types: is_match = False elif schema_type == "bool": profiled_value = event["args"]["Concrete Inputs"][idx] - if profiled_value.lower() not in ["true", "false"]: + if profiled_value.lower() not in ("true", "false"): is_match = False - elif schema_type == "int" or schema_type == "SymInt": + elif schema_type in ("int", "SymInt"): if profiled_type != "Scalar": is_match = False profiled_value = event["args"]["Concrete Inputs"][idx] if not profiled_value.lstrip("-").isdigit(): is_match = False - elif schema_type in ["float", "Scalar"]: + elif schema_type in ("float", "Scalar"): if profiled_type != "Scalar": is_match = False profiled_value = event["args"]["Concrete Inputs"][idx] @@ -179,7 +414,6 @@ def _is_schema_match( except ValueError: is_match = False elif schema_type.startswith("int[") or schema_type.startswith("SymInt["): - # custom dev debugging if profiled_type != "ScalarList": is_match = False profiled_value = event["args"]["Concrete 
Inputs"][idx] @@ -196,16 +430,27 @@ def _is_schema_match( x.strip() for x in profiled_value.strip()[1:-1].split(",") ] if not all( - x.lower() in ["true", "false"] for x in profiled_value_cleaned + x.lower() in ("true", "false") for x in profiled_value_cleaned ): is_match = False elif schema_type.startswith("Tensor["): raise ValueError( - f"Tensor list type not supported: {schema_type} as the tensor shapes are not provided in the event" + f"Tensor list type not supported: {schema_type} as the " + f"tensor shapes are not provided in the event" ) + elif schema_type == "str": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "ScalarType": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "Layout": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "Device": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "MemoryFormat": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "Generator": + is_match = profiled_type == "" or profiled_type == "Scalar" else: - # raise ValueError(f"Unknown schema type: {schema_type}") - # warning: if the schema type is not in the list, we will skip this case warnings.warn( f"Unknown schema type: {schema_type}. Skipping this case." 
@staticmethod
def _is_tensor_schema_type(schema_type: str) -> bool:
    """Tell whether a schema type string denotes a Tensor argument.

    Accepts the exact strings ``"Tensor"`` and ``"Tensor?"`` and any string
    that begins with ``"Tensor("`` (for example ``"Tensor(a!)"``).
    """
    return schema_type in ("Tensor", "Tensor?") or schema_type.startswith(
        "Tensor("
    )


@staticmethod
def _should_skip_tensor_init(evt_name: str, arg_name: str, arg_idx: int) -> bool:
    """Decide whether a tensor argument is treated as an output-only buffer.

    Returns True when any of these holds:
      * the op name ends with ``"_"`` and this is argument 0,
      * the argument is named ``"out"``,
      * the op is ``aten::copy_`` and the argument is not ``"src"``.
    """
    # NOTE(review): in-place ops such as add_ also read argument 0; confirm
    # that treating it as output-only is intended for every "_" op.
    in_place_target = arg_idx == 0 and evt_name.endswith("_")
    copy_destination = evt_name == "aten::copy_" and arg_name != "src"
    return in_place_target or arg_name == "out" or copy_destination
Using known default '%s'.", + evt_name, arg_name, idx, default, + ) + value = "" if default is None else default + else: + value = None elif ( arg_type.endswith("?") and event["args"]["Concrete Inputs"][idx] == "[]" ): value = [] else: - if arg_type in ["Tensor", "Tensor?", "Tensor(a!)"]: + if EventReplayer._is_tensor_schema_type(arg_type): init = "normal" - if evt_name == "aten::fill_": - # special case for fill_ where we don't need to initialize the tensor - # as it will be filled with a value later - init = None - elif evt_name == "aten::copy_" and arg_name != "src": - # special case for copy_ where we don't need to initialize the tensor - # as it will be copied from another tensor + if EventReplayer._should_skip_tensor_init(evt_name, arg_name, idx): init = None + profiled_dtype = event["args"]["Input type"][idx] + if profiled_dtype in ("long", "long int", "int", "bool", "unsigned char"): + init = "zeros" if init == "normal" else init value = TensorCfg( shape=event["args"]["Input Dims"][idx], - dtype=event["args"]["Input type"][idx], + dtype=profiled_dtype, strides=event["args"]["Input Strides"][idx], init=init, ) else: arg_str = event["args"]["Concrete Inputs"][idx] - if arg_type in ["bool", "bool?"]: + if arg_type in ("bool", "bool?"): value = arg_str.lower() == "true" - elif arg_type in ["int", "SymInt"]: + elif arg_type in ("int", "int?", "SymInt", "SymInt?"): value = int(arg_str) - elif arg_type in ["float", "float?", "Scalar", "Scalar?"]: + elif arg_type in ("Scalar", "Scalar?"): + if arg_str.lstrip("-").isdigit(): + value = int(arg_str) + else: + value = float(arg_str) + elif arg_type in ("float", "float?"): value = float(arg_str) + elif arg_type in ("str", "str?"): + if not arg_str and arg_name in _STR_ARG_DEFAULTS: + value = _STR_ARG_DEFAULTS[arg_name] + logger.warning( + "%s arg '%s' (position %d): profiler dropped " + "the string value. 
Using known default '%s'.", + evt_name, arg_name, idx, value, + ) + else: + value = arg_str elif arg_type.startswith("int[") or arg_type.startswith("SymInt["): value = [ int(x.strip()) for x in arg_str.strip()[1:-1].split(",") @@ -324,19 +588,117 @@ def _get_event_replay_IR( ) return {"list_pos_args": list_pos_args, "list_kwargs": list_kwargs} + @staticmethod + def _get_schema_arg_names(op_name: str) -> List[str]: + """Best-effort: return a list of arg names from any available schema.""" + schemas = _search_schemas(op_name, verbose=False) + if not schemas: + return [] + try: + _, pos_args, kw_args, _ = EventReplayer.parse_schema_string(schemas[0]) + return [a["arg_name"] for a in pos_args + kw_args] + except Exception: + return [] + + @staticmethod + def _get_event_replay_IR_schemaless( + event: Dict[str, Any], + verbose: bool = False, + resolved_name: Optional[str] = None, + ) -> Dict[str, Any]: + """Build a replay IR without a schema by inferring types from profile data.""" + evt_name = event["name"] + schema_arg_names = EventReplayer._get_schema_arg_names( + resolved_name or evt_name + ) + + list_pos_args = [] + n_args = len(event["args"]["Input type"]) + for idx in range(n_args): + profiled_type = event["args"]["Input type"][idx] + profiled_dims = event["args"]["Input Dims"][idx] + profiled_strides = event["args"]["Input Strides"][idx] + concrete = event["args"]["Concrete Inputs"][idx] + + if verbose: + print( + f"Schemaless arg {idx}: type={profiled_type!r} " + f"dims={profiled_dims} concrete={concrete!r}" + ) + + if profiled_type in list_profile_tensor_types: + init = "normal" + if profiled_type in ( + "long", "long int", "int", "bool", "unsigned char", + ): + init = "zeros" + value = TensorCfg( + shape=profiled_dims, + dtype=profiled_type, + strides=profiled_strides, + init=init, + ) + arg_type = "Tensor" + elif profiled_type == "Scalar" and concrete: + if concrete.lower() in ("true", "false"): + value = concrete.lower() == "true" + arg_type = "bool" + elif 
concrete.lstrip("-").isdigit(): + value = int(concrete) + arg_type = "int" + else: + try: + value = float(concrete) + arg_type = "float" + except ValueError: + value = concrete + arg_type = "str" + elif profiled_type == "" and concrete == "": + hint_name = ( + schema_arg_names[idx] if idx < len(schema_arg_names) else None + ) + default = _STR_ARG_DEFAULTS.get(hint_name) if hint_name else None + if default is not None: + value = default + arg_type = "str" + logger.warning( + "%s arg '%s' (position %d): profiler dropped the " + "string value. Using known default '%s'.", + evt_name, hint_name, idx, default, + ) + else: + value = None + arg_type = "None" + elif profiled_type == "ScalarList" and concrete: + items = [x.strip() for x in concrete.strip()[1:-1].split(",") if x.strip()] + if all(x.lstrip("-").isdigit() for x in items): + value = [int(x) for x in items] + else: + value = [float(x) for x in items] + arg_type = "list" + else: + value = None + arg_type = "unknown" + if verbose: + print(f" -> defaulting to None for unknown type") + + inferred_name = ( + schema_arg_names[idx] if idx < len(schema_arg_names) else f"arg{idx}" + ) + list_pos_args.append( + {"arg_name": inferred_name, "arg_type": arg_type, "value": value} + ) + if verbose: + print(f" -> {inferred_name}: {arg_type} = {value}") + print("-" * 80) + + return {"list_pos_args": list_pos_args, "list_kwargs": []} + @staticmethod def _get_args_kwargs( event_replay_IR: Dict[str, Any], device: str = "cuda" ) -> tuple[List["torch.Tensor"], Dict[str, Any]]: - """ - Get the arguments and keyword arguments from the event replay IR. - - Args: - event_replay_IR (Dict[str, Any]): The event replay IR. - - Returns: - (List[torch.Tensor], Dict[str, Any]): The positional arguments and keyword arguments. 
- """ + """Get the arguments and keyword arguments from the event replay IR.""" pos_args = [] for arg in event_replay_IR["list_pos_args"]: value = arg["value"] @@ -367,17 +729,21 @@ def parse_schema_string( kwarg_part = parts[1].lstrip(",").strip() if len(parts) > 1 else "" def _parse_arg(raw_arg: str) -> Tuple[str, str, Optional[str], bool]: - m = re.match(r"^(\S+)\s+(.*)$", raw_arg) + # Greedy (.+) consumes everything up to the last whitespace before + # a valid identifier, so "Tensor($0! -> ) key_cache" parses correctly. + m = re.match( + r"^(.+)\s+([A-Za-z_]\w*(?:=.*)?)$", raw_arg.strip() + ) if not m: raise ValueError(f"Invalid arg: {raw_arg}") - arg_type, rest = m.groups() - m2 = re.match(r"^([A-Za-z_][A-Za-z0-9_]*)(?:=(.*))?$", rest) + arg_type = m.group(1).strip() + name_default = m.group(2) + m2 = re.match(r"^([A-Za-z_]\w*)(?:=(.*))?$", name_default) if not m2: - raise ValueError(f"Invalid arg name/default: {rest}") - arg_name, default = m2.group(1), ( - m2.group(2).strip() if m2.group(2) else None - ) - return arg_type.strip(), arg_name.strip(), default + raise ValueError(f"Invalid arg name/default: {name_default}") + arg_name = m2.group(1) + default = m2.group(2).strip() if m2.group(2) else None + return arg_type, arg_name, default args = [] for item in [x.strip() for x in pos_part.split(",") if x.strip()]: @@ -398,31 +764,20 @@ def get_repro_info(self) -> Dict[str, Any]: """ Extracts the minimal, serializable information needed to reproduce the event call. - Returns: - Dict[str, Any]: A dictionary containing the operator name and the replay IR. - Suitable for JSON serialization using the custom encoder. + Safe to call multiple times — does not mutate self.event_replay_IR. 
""" - # return { - # 'op_name': self.event['name'], - # 'replay_ir': self.event_replay_IR - # # No device info here - device is decided by the runner - # } - dict_repro_info = {} - dict_repro_info["op_name"] = self.event["name"] - list_pos_args, list_kwargs = ( - self.event_replay_IR["list_pos_args"], - self.event_replay_IR["list_kwargs"], - ) - # Convert TensorCfg to dict for JSON serialization - list_pos_args_copy, list_kwargs_copy = list_pos_args.copy(), list_kwargs.copy() - for idx, val in enumerate(list_pos_args_copy): - if isinstance(val["value"], TensorCfg): - list_pos_args_copy[idx]["value"] = val["value"].__dict__ - for idx, val in enumerate(list_kwargs_copy): - if isinstance(val["value"], TensorCfg): - list_kwargs_copy[idx]["value"] = val["value"].__dict__ - dict_repro_info["replay_ir"] = { - "list_pos_args": list_pos_args_copy, - "list_kwargs": list_kwargs_copy, + def _serialize_arg(arg: Dict[str, Any]) -> Dict[str, Any]: + val = arg["value"] + return { + **arg, + "value": val.__dict__.copy() if isinstance(val, TensorCfg) else val, + } + + ir = self.event_replay_IR + return { + "op_name": self.event["name"], + "replay_ir": { + "list_pos_args": [_serialize_arg(a) for a in ir["list_pos_args"]], + "list_kwargs": [_serialize_arg(a) for a in ir["list_kwargs"]], + }, } - return dict_repro_info diff --git a/TraceLens/EventReplay/test_event_replay.py b/TraceLens/EventReplay/test_event_replay.py new file mode 100644 index 000000000..7948caa0f --- /dev/null +++ b/TraceLens/EventReplay/test_event_replay.py @@ -0,0 +1,204 @@ +############################################################################### +# Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Tests for EventReplay core functionality. + +All tests use CPU-only ops (aten::mm) so they run without a GPU. 
+Run from the repo root: + python -m pytest TraceLens/EventReplay/test_event_replay.py -v +""" + +import sys +import os +import pytest +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from TraceLens.EventReplay.event_replay import EventReplayer # noqa: E402 +from TraceLens.EventReplay.custom_inits import CustomInit # noqa: E402 +from TraceLens.EventReplay.utils import TensorCfg # noqa: E402 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_mm_event(M=4, K=8, N=16): + """Minimal profiler event dict for aten::mm (M x K) @ (K x N).""" + return { + "name": "aten::mm", + "args": { + "Input Dims": [[M, K], [K, N]], + "Input type": ["float", "float"], + "Input Strides": [[K, 1], [N, 1]], + "Concrete Inputs": ["", ""], + }, + } + + +@pytest.fixture(autouse=True) +def _isolate_registry(): + """Save and restore the global custom-init registry around every test.""" + saved = EventReplayer._custom_init_registry[:] + yield + EventReplayer._custom_init_registry = saved + + +# --------------------------------------------------------------------------- +# BUG-1: lazy=True + auto_init=True must not crash +# --------------------------------------------------------------------------- + +class TestLazyAutoInit: + def test_lazy_replay_sets_self_args(self): + """replay() in lazy mode must populate self.args.""" + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True, auto_init=False) + assert not hasattr(replayer, "args") + replayer.replay() + assert hasattr(replayer, "args") + assert isinstance(replayer.args, list) + + def test_lazy_with_custom_init_no_crash(self): + """A custom init that reads replayer.args must work in lazy mode.""" + accessed = {} + + class ProbeInit(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + accessed["args"] = replayer.args + 
accessed["kwargs"] = replayer.kwargs + return "[probe] ok" + + EventReplayer.register_custom_init(ProbeInit()) + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True, auto_init=True) + replayer.replay() + assert "args" in accessed + assert len(accessed["args"]) == 2 # self, mat2 + + +# --------------------------------------------------------------------------- +# BUG-2: get_repro_info() must not corrupt event_replay_IR +# --------------------------------------------------------------------------- + +class TestGetReproInfo: + def test_idempotent(self): + """Calling get_repro_info() twice must produce identical output.""" + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True) + assert replayer.get_repro_info() == replayer.get_repro_info() + + def test_does_not_mutate_ir(self): + """TensorCfg objects in the IR must survive get_repro_info().""" + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True) + replayer.get_repro_info() + for arg in replayer.event_replay_IR["list_pos_args"]: + if arg["arg_type"].startswith("Tensor"): + assert isinstance(arg["value"], TensorCfg), ( + f"arg '{arg['arg_name']}' is {type(arg['value'])}, expected TensorCfg" + ) + + def test_replay_works_after_get_repro_info(self): + """replay() must succeed after get_repro_info() (IR still intact).""" + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True) + replayer.get_repro_info() + result = replayer.replay() + assert isinstance(result, torch.Tensor) + + +# --------------------------------------------------------------------------- +# CLAIM-1: first-match-wins (only one init runs) +# --------------------------------------------------------------------------- + +class TestFirstMatchWins: + def test_only_first_matching_init_runs(self): + """When two inits match, only the first registered one executes.""" + log = [] + + class InitA(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("A") + + class 
InitB(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("B") + + EventReplayer._custom_init_registry = [InitA(), InitB()] + EventReplayer(_make_mm_event(), device="cpu", auto_init=True).replay() + assert log == ["A"] + + +# --------------------------------------------------------------------------- +# CLAIM-4: replay() returns the op result +# --------------------------------------------------------------------------- + +class TestReplayReturn: + def test_returns_tensor(self): + """replay() of aten::mm must return a correctly-shaped Tensor.""" + replayer = EventReplayer(_make_mm_event(M=4, K=8, N=16), device="cpu") + result = replayer.replay() + assert isinstance(result, torch.Tensor) + assert result.shape == (4, 16) + + def test_returns_tensor_lazy(self): + """Lazy replay must also return the result.""" + result = EventReplayer(_make_mm_event(), device="cpu", lazy=True).replay() + assert isinstance(result, torch.Tensor) + + +# --------------------------------------------------------------------------- +# Exact name matching for op_patterns +# --------------------------------------------------------------------------- + +class TestExactNameMatching: + def test_exact_match_hits(self): + """op_patterns=["aten::mm"] matches event name "aten::mm".""" + matched = [] + + class ExactInit(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + matched.append(True) + + EventReplayer._custom_init_registry = [ExactInit()] + EventReplayer(_make_mm_event(), device="cpu", auto_init=True).replay() + assert matched == [True] + + def test_substring_does_not_match(self): + """op_patterns=["mm"] must NOT match "aten::mm" (exact only).""" + matched = [] + + class SubstringInit(CustomInit): + op_patterns = ["mm"] + def initialize(self, replayer, **kwargs): + matched.append(True) + + EventReplayer._custom_init_registry = [SubstringInit()] + EventReplayer(_make_mm_event(), device="cpu", 
auto_init=True).replay() + assert matched == [] + + +# --------------------------------------------------------------------------- +# auto_init=False skips all inits +# --------------------------------------------------------------------------- + +class TestAutoInitDisabled: + def test_no_init_runs_when_disabled(self): + log = [] + + class AlwaysInit(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("ran") + + EventReplayer._custom_init_registry = [AlwaysInit()] + EventReplayer(_make_mm_event(), device="cpu", auto_init=False).replay() + assert log == [] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/TraceLens/EventReplay/test_event_replay_gpu.py b/TraceLens/EventReplay/test_event_replay_gpu.py new file mode 100644 index 000000000..9998ec956 --- /dev/null +++ b/TraceLens/EventReplay/test_event_replay_gpu.py @@ -0,0 +1,261 @@ +############################################################################### +# Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +GPU integration tests for EventReplay. + +Profiles real ops, replays from the captured trace, and validates: + 1. Kernel name match between original and replayed execution + 2. BUG-1: lazy=True + auto_init=True works on GPU + 3. BUG-2: get_repro_info() is idempotent (doesn't corrupt IR) + 4. CLAIM-4: replay() returns a tensor + 5. CLAIM-1: first-match-wins with real ops + +Requires a GPU (MI300X / MI210 / etc). 
Run from the repo root: + python TraceLens/EventReplay/test_event_replay_gpu.py +""" + +import sys, os, json, time +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +import torch +from torch.profiler import profile, ProfilerActivity + +from TraceLens.EventReplay.event_replay import EventReplayer +from TraceLens.EventReplay.custom_inits import CustomInit +from TraceLens.EventReplay.utils import TensorCfg + +assert torch.cuda.is_available(), "GPU required for this test" + +DEVICE = "cuda" +TRACE_FILE = "/tmp/test_event_replay_gpu_trace.json" +REPLAY_TRACE = "/tmp/test_event_replay_gpu_replay.json" + +# --------------------------------------------------------------------------- +# Step 1: Profile a set of real ops +# --------------------------------------------------------------------------- + +print("=" * 80) +print("Step 1: Profiling real ops") +print("=" * 80) + +M, K, N = 256, 1024, 512 +mm_a = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE) +mm_b = torch.randn(K, N, dtype=torch.bfloat16, device=DEVICE) +add_a = torch.randn(M, N, dtype=torch.bfloat16, device=DEVICE) +add_b = torch.randn(M, N, dtype=torch.bfloat16, device=DEVICE) +bmm_a = torch.randn(4, M, K, dtype=torch.bfloat16, device=DEVICE) +bmm_b = torch.randn(4, K, N, dtype=torch.bfloat16, device=DEVICE) + +def run_ops(): + torch.mm(mm_a, mm_b) + torch.add(add_a, add_b) + torch.bmm(bmm_a, bmm_b) + torch.mul(add_a, add_b) + torch.sigmoid(add_a) + +for _ in range(10): + run_ops() +torch.cuda.synchronize() + +def trace_handler(p): + p.export_chrome_trace(TRACE_FILE) + +wait, warmup, active = 3, 3, 5 +with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), + record_shapes=True, + on_trace_ready=trace_handler, +) as p: + for _ in range(wait + warmup + active): + run_ops() + p.step() + +print(f"Trace saved to {TRACE_FILE}") + +# 
--------------------------------------------------------------------------- +# Step 2: Load trace and find events +# --------------------------------------------------------------------------- + +print(f"\n{'=' * 80}") +print("Step 2: Loading trace") +print("=" * 80) + +with open(TRACE_FILE) as f: + trace_data = json.load(f) + +all_events = trace_data.get("traceEvents", []) + +OPS_TO_TEST = ["aten::mm", "aten::add", "aten::bmm", "aten::mul", "aten::sigmoid"] + +def find_event(events, op_name): + """Find a cpu_op event with the right name and shape data.""" + candidates = [ + e for e in events + if e.get("cat") == "cpu_op" + and e.get("name") == op_name + and "args" in e + and "Input Dims" in e.get("args", {}) + ] + if candidates: + return candidates[len(candidates) // 2] + return None + +results = [] +errors = [] + +# --------------------------------------------------------------------------- +# Step 3: Replay each op and validate +# --------------------------------------------------------------------------- + +print(f"\n{'=' * 80}") +print("Step 3: Replay and validate") +print("=" * 80) + +print(f"\n{'Op':<30} {'Kernel Match':<15} {'Return':<10} {'Lazy':<10} {'ReproInfo':<12} {'Status'}") +print("-" * 100) + +for op_name in OPS_TO_TEST: + evt = find_event(all_events, op_name) + if evt is None: + print(f"{op_name:<30} {'SKIP':<15} {'---':<10} {'---':<10} {'---':<12} not in trace") + continue + + status = [] + + # --- Test: basic replay returns a result (CLAIM-4) --- + try: + replayer = EventReplayer(evt, device=DEVICE, auto_init=False) + result = replayer.replay() + returns_ok = isinstance(result, torch.Tensor) + except Exception as e: + returns_ok = False + status.append(f"replay error: {e}") + + # --- Test: lazy mode works (BUG-1) --- + try: + lazy_replayer = EventReplayer(evt, device=DEVICE, lazy=True, auto_init=False) + lazy_result = lazy_replayer.replay() + lazy_ok = isinstance(lazy_result, torch.Tensor) + assert hasattr(lazy_replayer, "args"), "self.args not 
set after lazy replay" + except Exception as e: + lazy_ok = False + status.append(f"lazy error: {e}") + + # --- Test: get_repro_info idempotent (BUG-2) --- + try: + repro_replayer = EventReplayer(evt, device=DEVICE, lazy=True) + info1 = repro_replayer.get_repro_info() + info2 = repro_replayer.get_repro_info() + repro_ok = (info1 == info2) + for arg in repro_replayer.event_replay_IR["list_pos_args"]: + if arg["arg_type"].startswith("Tensor"): + assert isinstance(arg["value"], TensorCfg), "IR corrupted after get_repro_info" + repro_replayer.replay() + except Exception as e: + repro_ok = False + status.append(f"repro error: {e}") + + # --- Test: kernel name match --- + kernel_match = "N/A" + try: + replay_replayer = EventReplayer(evt, device=DEVICE, auto_init=False) + for _ in range(5): + replay_replayer.replay() + torch.cuda.synchronize() + + def th(p): + p.export_chrome_trace(REPLAY_TRACE) + + w, wu, a = 2, 2, 3 + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=w, warmup=wu, active=a, repeat=1), + record_shapes=True, + on_trace_ready=th, + ) as p: + for _ in range(w + wu + a): + replay_replayer.replay() + p.step() + + with open(REPLAY_TRACE) as f: + replay_trace = json.load(f) + + orig_gpu = set() + for e in all_events: + if e.get("cat") == "kernel" and e.get("name", ""): + orig_gpu.add(e["name"]) + + replay_gpu = set() + for e in replay_trace.get("traceEvents", []): + if e.get("cat") == "kernel" and e.get("name", ""): + replay_gpu.add(e["name"]) + + kernel_match = "MATCH" if replay_gpu.issubset(orig_gpu) else "MISMATCH" + except Exception as e: + kernel_match = "ERROR" + status.append(f"kernel error: {e}") + + ok = returns_ok and lazy_ok and repro_ok and kernel_match in ("MATCH", "N/A") + tag = "PASS" if ok else "FAIL" + detail = "; ".join(status) if status else "" + + print(f"{op_name:<30} {kernel_match:<15} {'OK' if returns_ok else 'FAIL':<10} {'OK' if lazy_ok else 'FAIL':<10} {'OK' if repro_ok 
else 'FAIL':<12} {tag} {detail}") + results.append({"op": op_name, "ok": ok, "kernel": kernel_match, + "returns": returns_ok, "lazy": lazy_ok, "repro": repro_ok}) + +# --------------------------------------------------------------------------- +# Step 4: First-match-wins test (CLAIM-1) on GPU +# --------------------------------------------------------------------------- + +print(f"\n{'=' * 80}") +print("Step 4: First-match-wins (CLAIM-1)") +print("=" * 80) + +log = [] + +class InitA(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("A") + +class InitB(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("B") + +saved_registry = EventReplayer._custom_init_registry[:] +try: + EventReplayer._custom_init_registry = [InitA(), InitB()] + mm_evt = find_event(all_events, "aten::mm") + if mm_evt: + r = EventReplayer(mm_evt, device=DEVICE, auto_init=True) + r.replay() + first_match_ok = (log == ["A"]) + print(f" First-match-wins: {'PASS' if first_match_ok else 'FAIL'} (log={log})") + else: + first_match_ok = True + print(" SKIP: aten::mm not in trace") +finally: + EventReplayer._custom_init_registry = saved_registry + +# --------------------------------------------------------------------------- +# Summary +# --------------------------------------------------------------------------- + +print(f"\n{'=' * 80}") +print("Summary") +print("=" * 80) + +total = len(results) +passed = sum(1 for r in results if r["ok"]) +print(f"Op tests: {passed}/{total} passed") +print(f"First-match-wins: {'PASS' if first_match_ok else 'FAIL'}") + +all_pass = passed == total and first_match_ok +print(f"\nOverall: {'ALL PASSED' if all_pass else 'FAILURES DETECTED'}") +sys.exit(0 if all_pass else 1) diff --git a/TraceLens/EventReplay/utils.py b/TraceLens/EventReplay/utils.py index 304cf6a98..9ea2f1764 100644 --- a/TraceLens/EventReplay/utils.py +++ b/TraceLens/EventReplay/utils.py @@ -32,8 +32,16 @@ def 
_get_torch_or_raise() -> Any: # Changed return type to Any for flexibility "c10::Half", "c10::BFloat16", "long", + "long int", "int", "bool", + "unsigned char", + "char", + "short", + "c10::Float8_e4m3fnuz", + "c10::Float8_e5m2fnuz", + "c10::Float8_e4m3fn", + "c10::Float8_e5m2", ] from dataclasses import dataclass @@ -58,15 +66,35 @@ def build_tensor(cfg: TensorCfg, device: str = "cuda") -> "torch.Tensor": "bool": torch.bool, "int": torch.int, "long": torch.long, + "long int": torch.long, + "short": torch.short, + "char": torch.int8, + "unsigned char": torch.uint8, "double": torch.float64, "float": torch.float32, "c10::Half": torch.float16, "c10::BFloat16": torch.bfloat16, } + # FP8 types (available in PyTorch >= 2.1) + for fp8_name in ( + "c10::Float8_e4m3fnuz", + "c10::Float8_e5m2fnuz", + "c10::Float8_e4m3fn", + "c10::Float8_e5m2", + ): + attr = fp8_name.replace("c10::", "") + torch_dtype = getattr(torch, attr.lower(), None) + if torch_dtype is not None: + dict_profile2torchdtype[fp8_name] = torch_dtype + + if cfg.dtype not in dict_profile2torchdtype: + raise ValueError( + f"Unknown profiled dtype '{cfg.dtype}'. " + f"Known types: {list(dict_profile2torchdtype.keys())}" + ) dtype = dict_profile2torchdtype[cfg.dtype] size = cfg.shape stride = cfg.strides - # allocate *exactly* the storage needed for that stride/shape t = torch.empty_strided(size, stride, dtype=dtype, device=device) is_floating = t.is_floating_point() or t.is_complex() init = cfg.init @@ -75,7 +103,9 @@ def build_tensor(cfg: TensorCfg, device: str = "cuda") -> "torch.Tensor": raise ValueError( f"Cannot initialize tensor of type {cfg.dtype} with 'normal' init." 
) - t.normal_() # or whatever init you like + t.normal_() + elif init == "zeros": + t.zero_() elif init is not None: raise ValueError(f"Unsupported tensor initialization: {init}") return t @@ -94,32 +124,74 @@ def summarize_tensor(tensor: "torch.Tensor") -> str: return f"Tensor(shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}, strides={tensor.stride()})" -def benchmark_func(func, device, warmup=50, avg_steps=100): - """ - Benchmark a function with warmup and average steps. - Disclaimer: This method would be inaccurate for very short ops. +_L2_FLUSH_BUFFER = None +_L2_FLUSH_SIZE = 256 * 1024 * 1024 # 256 MB -- larger than any GPU's L2 + + +def _flush_l2(device: str): + """Force-evict GPU L2 cache by reading a large buffer.""" + global _L2_FLUSH_BUFFER + torch = _get_torch_or_raise() + if _L2_FLUSH_BUFFER is None or str(_L2_FLUSH_BUFFER.device) != device: + _L2_FLUSH_BUFFER = torch.empty( + _L2_FLUSH_SIZE // 4, dtype=torch.float32, device=device + ) + _L2_FLUSH_BUFFER.sum() + + +def benchmark_func( + func, + device, + warmup=50, + avg_steps=100, + flush_l2=False, +): + """Benchmark a function with warmup and per-iteration CUDA event timing. + Args: - func (callable): The function to benchmark. - warmup (int): Number of warmup iterations. - avg_steps (int): Number of iterations to average over. + func: Callable to benchmark. + device: CUDA device string. + warmup: Number of warmup iterations. + avg_steps: Number of measured iterations. + flush_l2: If True, flush the GPU L2 cache before each measured iteration + to simulate cold-cache conditions (more representative of real + inference where other kernels pollute L2 between invocations). + Returns: - float: Average time taken per iteration in microseconds. + dict with keys: median_us, mean_us, std_us, min_us, max_us, + all_us (list of per-iteration timings in microseconds). 
""" torch = _get_torch_or_raise() - # Warmup phase + for _ in range(warmup): func() - - # Benchmarking phase torch.cuda.synchronize(device) - start_time = time.time() + + timings_ms: List[float] = [] for _ in range(avg_steps): + if flush_l2: + _flush_l2(device) + torch.cuda.synchronize(device) + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() func() - torch.cuda.synchronize(device) - end_time = time.time() - - elapsed_time = end_time - start_time - avg_time_sec = elapsed_time / avg_steps - avg_time_us = avg_time_sec * 1e6 - - return avg_time_us + end_evt.record() + torch.cuda.synchronize(device) + timings_ms.append(start_evt.elapsed_time(end_evt)) + + timings_us = [t * 1000.0 for t in timings_ms] + sorted_us = sorted(timings_us) + n = len(sorted_us) + median = (sorted_us[n // 2] + sorted_us[(n - 1) // 2]) / 2.0 + mean = sum(timings_us) / n + variance = sum((t - mean) ** 2 for t in timings_us) / n + std = variance ** 0.5 + return { + "median_us": median, + "mean_us": mean, + "std_us": std, + "min_us": sorted_us[0], + "max_us": sorted_us[-1], + "all_us": timings_us, + } diff --git a/docs/EventReplay.md b/docs/EventReplay.md index bf6f4fb79..6e6898465 100644 --- a/docs/EventReplay.md +++ b/docs/EventReplay.md @@ -6,59 +6,179 @@ See LICENSE for license information. # Event Replay -Optimizing GPU performance in deep learning requires isolating and benchmarking individual operations to identify bottlenecks. However, reproducing operations directly from complex model code or large profiles can be cumbersome. - -Event Replay is a Python-based tool within TraceLens that extracts and replays almost arbitrary PyTorch operations using minimal, portable Intermediate Representation (IR). It enables users to easily reproduce, analyze, and benchmark specific operators independently from the original model execution, streamlining performance optimization workflows. 
- ---- - -## Key Features - -- **Generic Operator Replay**: Reconstructs and benchmarks any PyTorch operator from profile data, including convolutions, GEMMs, reductions, element-wise operations, and more. -- **Minimalistic IR**: Extracts essential operator attributes (tensor shapes, strides, dtypes, and other arguments) into a lightweight, portable JSON-based IR. -- **Portable Artifacts**: Enables sharing standalone artifacts (JSON IR and scripts) with teammates or upstream repositories without requiring access to the model or TraceLens. +Optimizing GPU performance in deep learning requires isolating and benchmarking +individual operations to identify bottlenecks. However, reproducing operations +directly from complex model code or large profiles can be cumbersome — the +profiler captures tensor dimensions and types but strips argument names, semantic +context, and the relationship between arguments. + +Event Replay is a Python-based tool within TraceLens that extracts and replays +almost arbitrary PyTorch operations using minimal, portable Intermediate +Representation (IR). It translates the opaque profiler output into human-readable +JSON — named arguments, tensor shapes, dtypes, strides, and scalar values — +making profiler traces interpretable, shareable, and replayable on any machine +with the right op libraries installed. 
+ +**Contents:** +[Quick Start](#quick-start) | +[Batch Replay](#batch-replay) | +[Architecture](#architecture) | +[Custom Initializers](#custom-initializers) | +[Auto-Import](#auto-import-for-custom-ops) | +[Iteration Annotations](#iteration-annotations-vllm-traces) | +[Limitations](#known-limitations) | +[Use Cases](#use-cases) --- - ## Quick Start -### Example: Replay a Single Event +### Replay a Single Event ```python -from TraceLens import TreePerfAnalyzer, EventReplayer +from TraceLens import TreePerfAnalyzer +from TraceLens.EventReplay import EventReplayer -# Load profile and get event perf_analyzer = TreePerfAnalyzer.from_file('/path/to/profile.json') -uid = 12345 # Replace with actual UID of interest +uid = 12345 event = perf_analyzer.tree.get_UID2event(uid) -# Initialize and replay replayer = EventReplayer(event, device='cuda') replayer.replay() ``` +### Inspect the IR (without replaying) + +Even without a GPU, the extracted IR is valuable for understanding what a +profiled op actually does. The profiler's native format stores arguments as +unlabeled dimension lists with no argument names or semantic context: + +```json +{ + "cat": "cpu_op", "name": "aten::mm", + "args": { + "Input Dims": [[20, 2048], [2048, 11264]], + "Input type": ["BFloat16", "BFloat16"] + } +} +``` + +What are these two tensors? Which is the activation, which is the weight? Is +`mat2` transposed? You can't tell from `Input Dims` alone. 
Event Replay resolves +the op's registered schema and produces named, typed JSON: + +```python +replayer = EventReplayer(event, lazy=True) +ir = replayer.get_repro_info() +``` + +**After — Event Replay IR for `aten::mm`:** + +```json +{ + "op_name": "aten::mm", + "replay_ir": { + "list_pos_args": [ + { + "arg_name": "self", + "arg_type": "Tensor", + "value": { "shape": [20, 2048], "dtype": "c10::BFloat16", + "strides": [2048, 1], "init": "normal" } + }, + { + "arg_name": "mat2", + "arg_type": "Tensor", + "value": { "shape": [2048, 11264], "dtype": "c10::BFloat16", + "strides": [1, 2048], "init": "normal" } + } + ] + } +} +``` + +Now you can immediately read: BF16 GEMM, M=20, K=2048, N=11264, `mat2` +is column-major (stride pattern `[1, K]`). + +The contrast is sharper for complex ops. Here's what the profiler gives you +for a MoE fused expert call: + +**Raw profiler — `aiter::ck_moe_stage1`:** + +```json +{ + "cat": "cpu_op", "name": "aiter::ck_moe_stage1", + "args": { + "Input Dims": [[2, 2048], [60, 2816, 2048], [60, 2048, 1408], + [1924], [61], [2], [2, 4, 1408], [], [], [], [], [], [], [], [], [], [], []], + "Input type": ["BFloat16", "BFloat16", "BFloat16", + "Int", "Int", "Int", "BFloat16", + "Scalar", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar", + "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"] + } +} +``` + +18 arguments, most labeled just "Scalar" — which one is `topk`? Which is +`block_m`? What does `[1924]` represent? Uninterpretable without reading the +source code. 
+ +**After — Event Replay IR:** + +```json +{ + "op_name": "aiter::ck_moe_stage1", + "replay_ir": { + "list_pos_args": [ + { "arg_name": "hidden_states", "arg_type": "Tensor", + "value": { "shape": [2, 2048], "dtype": "c10::BFloat16" } }, + { "arg_name": "w1", "arg_type": "Tensor", + "value": { "shape": [60, 2816, 2048], "dtype": "c10::BFloat16" } }, + { "arg_name": "w2", "arg_type": "Tensor", + "value": { "shape": [60, 2048, 1408], "dtype": "c10::BFloat16" } }, + { "arg_name": "sorted_token_ids", "arg_type": "Tensor", + "value": { "shape": [1924], "dtype": "int", "init": "zeros" } }, + { "arg_name": "sorted_expert_ids", "arg_type": "Tensor", + "value": { "shape": [61], "dtype": "int", "init": "zeros" } }, + { "arg_name": "num_valid_ids", "arg_type": "Tensor", + "value": { "shape": [2], "dtype": "int", "init": "zeros" } }, + { "arg_name": "out", "arg_type": "Tensor", + "value": { "shape": [2, 4, 1408], "dtype": "c10::BFloat16", "init": null } }, + { "arg_name": "topk", "arg_type": "SymInt", "value": 4 }, + { "arg_name": "block_m", "arg_type": "SymInt?", "value": 32 }, + { "arg_name": "use_non_temporal_load", "arg_type": "bool", "value": true } + ] + } +} +``` + +Now you can read: 2 tokens routed to top-4 of 60 experts, gate hidden dim 2048, +up-projection to 2816, down-projection through 1408, `[1924]` is +`sorted_token_ids` (the routing table), block tile size 32, NTL enabled. 
+ --- -## Batch Replay and Benchmark +## Batch Replay -### Extract Operator IR from TraceLens Profiles +### Extract IR for Multiple Events ```python import json -# Extract replay IR for events of interest -repro_data = [EventReplayer(event, lazy=True).get_repro_info() for event in events_of_interest] +repro_data = [EventReplayer(event, lazy=True).get_repro_info() + for event in events_of_interest] with open('event_replay_ir.json', 'w') as f: json.dump(repro_data, f, indent=4) ``` ```bash -python batched_replay.py event_replay_ir.json +python batched_replay.py event_replay_ir.json # default (timing only) +python batched_replay.py -v event_replay_ir.json # verbose (shows args) +python batched_replay.py --op-filter aten::mm event_replay_ir.json # filter by name +python batched_replay.py --op-limit 5 event_replay_ir.json # first 5 ops ``` -#### Example Output +#### Example Output (`-v`) ``` [7/11] Replaying: aten::convolution @@ -74,50 +194,24 @@ python batched_replay.py event_replay_ir.json output_padding SymInt[]: [0, 0] groups SymInt: 1 Keyword Args: - Average time taken: 100.38 microseconds + Average time taken: 100.38 us (median: 98.21 us) Successfully executed aten::convolution. Result: Tensor(shape=torch.Size([20, 256, 14, 14]), dtype=torch.bfloat16, device=cuda:0) - -[8/11] Replaying: aten::convolution - Reconstructing arguments for 'aten::convolution'... - Positional Args: - input Tensor: {'shape': [20, 256, 14, 14], 'dtype': 'c10::BFloat16', 'strides': [50176, 196, 14, 1]} - weight Tensor: {'shape': [512, 256, 3, 3], 'dtype': 'c10::BFloat16', 'strides': [2304, 9, 3, 1]} - bias Tensor?: None - stride SymInt[]: [2, 2] - padding SymInt[]: [1, 1] - dilation SymInt[]: [1, 1] - transposed bool: False - output_padding SymInt[]: [0, 0] - groups SymInt: 1 - Keyword Args: - Average time taken: 92.83 microseconds - Successfully executed aten::convolution. - Result: Tensor(shape=torch.Size([20, 512, 7, 7]), dtype=torch.bfloat16, device=cuda:0) ... 
--- Replay Summary --- Total operations in file: 11 Attempted replays: 11 Successful replays: 11 Errors encountered: 0 - ``` -------------------- -### Creating Standalone Replay Artifacts - -You can optionally package the extracted replay IR and scripts into a standalone zip file for easy sharing and reproduction, independent of the original model code or TraceLens repository. -Artifacts included: -- `event_replay_ir.json`: Serialized operator replay instructions. -- `utils.py`: Tensor creation and helper utilities. -- `batched_replay.py`: Script to batch replay and benchmark operations. -- `batched_replay_readme.md`: Instructions for running the replay. +### Creating Standalone Replay Artifacts -Example packaging code: +Package the IR and scripts into a standalone zip for sharing and reproduction, +independent of the original model code or TraceLens: ```python -import zipfile -import os +import zipfile, os from TraceLens.EventReplay import utils as tl_utils from TraceLens.EventReplay import batched_replay @@ -128,25 +222,329 @@ files = [ batched_replay.__file__.replace('batched_replay.py', 'batched_replay_readme.md') ] -zip_file_path = '/path/to/replay_code.zip' -with zipfile.ZipFile(zip_file_path, 'w') as zipf: +with zipfile.ZipFile('/path/to/replay_code.zip', 'w') as zipf: for file in files: zipf.write(file, arcname=os.path.basename(file)) +``` + +--- + +## Architecture + +Event Replay operates in two distinct phases: + +### Phase 1: IR Extraction (deterministic) + +The profiler captures tensor dimensions and types but not argument names: + +``` +Input Dims: [[2, 2048], [60, 2816, 2048], [60, 2048, 1408], [1924], [61], [2], ...] +Input type: [BFloat16, BFloat16, BFloat16, Int, Int, Int, ...] +``` + +EventReplayer looks up the op's **registered schema** from the PyTorch dispatcher +(via `torch._C._jit_get_all_schemas()` or `torch.ops`). 
For example, querying +the registry for `aiter::ck_moe_stage1` returns: -print(f"Created zip file: {zip_file_path}") ``` +aiter::ck_moe_stage1(Tensor(a0!) hidden_states, Tensor(a1!) w1, Tensor(a2!) w2, + Tensor(a3!) sorted_token_ids, Tensor(a4!) sorted_expert_ids, + Tensor(a5!) num_valid_ids, Tensor(a6!) out, SymInt topk, + str? kernelName="", Tensor(a9!)? w1_scale=None, + Tensor(a10!)? a1_scale=None, SymInt? block_m=None, ...) -> () +``` + +This schema provides argument names, types, and defaults. EventReplayer zips the +schema with the profiler's `Input Dims` / `Input type` arrays to produce the +named, typed IR: + +- **Op name** — the fully qualified operator name (e.g., `aten::mm`, `_rocm_C::paged_attention`) +- **Argument metadata** — for each positional and keyword argument: + - Tensors: shape, dtype, strides, initialization hint + - Scalars: concrete value (int, float, bool, str) + - Lists: element values + - Optionals: `null` when not provided + +This phase is purely mechanical: the same profiler event always produces the same +IR. The output is a portable JSON dictionary. + +**Prerequisite — ops must be registered with the PyTorch dispatcher.** If an op +is called as a plain Python function (e.g., a Triton kernel launched directly), +there is no schema to query and no IR can be extracted. The op must go through +`torch.ops`, `torch.library`, or the JIT registry. This is why aiter's CK-based +ops (`aiter::ck_moe_stage1`) produce full IR while direct Triton kernel calls do +not. The fix is wrapping such kernels in `torch.library.custom_op` so the +dispatcher has a schema to query. + +### Phase 2: Init & Replay (requires judgment) + +Given an IR, Event Replay: + +1. **Allocates tensors** matching the recorded shapes, dtypes, and strides +2. **Initializes values** — `randn` for floating-point tensors, `zeros` for + integer/bool tensors, `None` for optional arguments +3. 
**Resolves the op** to a callable function (via JIT registry, `torch.ops`, + or direct module import) +4. **Calls the op** and optionally benchmarks it + +The default initialization works well for compute-bound ops like GEMMs and +convolutions, where kernel performance is independent of input values. However, +**control and index tensors require realistic values** — zeroed-out metadata +produces behavior that is not representative of the true workload: + +| Op family | Affected tensors | Effect of zeros | +|-----------|-----------------|-----------------| +| Paged Attention | `block_tables`, `seq_lens`, `query_start_loc` | Kernel sees 0 context length — does no real work | +| MoE Routing | `sorted_token_ids`, `sorted_expert_ids`, `num_valid_ids` | Kernel sees 0 valid tokens — skips all computation | + +A true reproduction would require the exact tensor values from the original +execution, but the profiler doesn't capture tensor contents — only shapes and +dtypes. **Custom Initializers** bridge this gap by constructing plausible values +from shapes and metadata already in the IR, without additional instrumentation. + --- -## Use Cases +## Custom Initializers -- **Performance Debugging**: Quickly isolate and reproduce performance issues from large models. -- **Regression Testing**: Automate benchmarks to detect performance regressions at the operator level. -- **Kernel Development**: Extract minimal reproducers for GPU kernel optimization and debugging. -- **Numerical Validation**: Evaluate numerical correctness and stability of isolated operations across hardware. -- **Hardware Counter Profiling**: Use with hardware counters to analyze performance bottlenecks in specific operations. +Custom initializers fill metadata tensors with realistic values before replay. +They are applied automatically when `auto_init=True` (the default). 
+ +### Built-in Initializers + +These ship with TraceLens and require no setup — they activate automatically +when the op name matches: + +**`PagedAttentionInit`** — matches `_rocm_C::paged_attention` + +Initializes the KV cache metadata so the attention kernel does real work: +- `block_tables` — random permutation of the physical block pool (simulates + realistic scattered memory allocation) +- `seq_lens` — all sequences set to `max_seq_len` +- `query_start_loc` — CSR indptr encoding per-sequence query token counts. + When iteration annotations are available (see [Iteration Annotations](#iteration-annotations-vllm-traces)), + uses the exact prefill/decode split; otherwise falls back to heuristics + +**`MoeRoutingInit`** — matches `aiter::ck_moe_stage1`, `aiter::ck_moe_stage2` + +Constructs a complete token-to-expert routing table: +- `sorted_token_ids` — padded to `block_m` boundaries per expert +- `sorted_expert_ids` — block-level expert assignment +- `num_valid_ids` — total valid (non-padding) token slots + +Supports configurable token distribution via `init_kwargs`: + +```python +# Default: uniform random assignment across experts +replayer = EventReplayer(event, device='cuda') + +# Zipf: skewed distribution (few experts get most tokens, closer to real routing) +replayer = EventReplayer(event, device='cuda', + init_kwargs={"moe_distribution": "zipf", "moe_zipf_s": 1.5}) +``` + +### Writing Your Own Initializer + +If you're replaying an op that needs realistic tensor content but isn't +covered by the built-ins, you can write your own custom initializer in +three steps. For real-world examples, see `PagedAttentionInit` and +`MoeRoutingInit` in `TraceLens/EventReplay/custom_inits.py`. + +**Step 1 — Subclass `CustomInit`.** Set `op_patterns` to the exact op name(s) +you want to target. 
This is an exact match against the profiler event name +(e.g., `"aten::index_add_"`, not `"index_add"`): + +```python +from TraceLens.EventReplay import EventReplayer, CustomInit + +class IndexAddInit(CustomInit): + op_patterns = ["aten::index_add_"] +``` + +**Step 2 — Implement `initialize()`.** This method receives the `replayer` +object and mutates its tensors **in-place** before the op executes. You have +access to: + +- `replayer.args` — list of allocated tensors/scalars (in schema order) +- `replayer.kwargs` — dict of keyword arguments +- `replayer.event` — the raw profiler event dict +- `replayer.event_replay_IR` — the extracted IR with named argument metadata + +Look up arguments **by name** from the IR rather than hardcoding positional +indices — this keeps your initializer robust to schema changes across library +versions: + +```python + def initialize(self, replayer, **kwargs): + import torch + + ir = replayer.event_replay_IR + arg_names = [a["arg_name"] for a in ir["list_pos_args"]] + + self_tensor = replayer.args[arg_names.index("self")] + dim = replayer.args[arg_names.index("dim")] + index = replayer.args[arg_names.index("index")] + + dim_size = self_tensor.shape[dim] + index.copy_(torch.randint(0, dim_size, index.shape, + device=index.device)) + + return (f"[custom init] index_add — index randint(0, {dim_size}), " + f"shape={list(index.shape)}") +``` + +In this example, `aten::index_add_` accumulates source rows into `self` at +positions given by `index`. The default zero-init makes every row land on +row 0 — not representative of the real scatter pattern. The initializer +fills `index` with random valid indices so the kernel exercises realistic +memory access. 
+ +**Step 3 — Register it.** Once registered, the initializer fires automatically +on every future replay of matching ops: + +```python +EventReplayer.register_custom_init(IndexAddInit()) +``` + +When `replay()` runs with `auto_init=True`, it iterates over registered +initializers and applies the **first match** (built-ins are checked first, +then user-registered ones in order). + +To see what's currently registered: + +```python +EventReplayer.list_custom_inits() +``` + +--- + +## Auto-Import for Custom Ops + +When EventReplayer encounters an op from a non-`aten` namespace (e.g., +`_rocm_C::paged_attention`, `aiter::ck_moe_stage1`), it automatically attempts +to import the library that registers the op's schema. The import is conditional +— it only fires if the namespace is recognized and hasn't been attempted yet. + +Built-in namespace mappings: + +| Namespace | Imported modules | +|-----------|-----------------| +| `aiter` | `aiter` | +| `_rocm_C` | `vllm._rocm_C` | +| `_C` | `vllm._C` | +| `vllm` | `vllm._C`, `vllm._rocm_C` | + +Register additional namespaces: + +```python +EventReplayer.register_namespace("my_lib", ["my_lib.ops"]) +``` + +--- + +## Iteration Annotations (vLLM traces) + +### The problem + +Paged attention's `query_start_loc` is a CSR indptr array that encodes how +many query tokens each sequence contributes. In a mixed batch (common in +vLLM's continuous batching), some sequences are **prefill** (many query tokens) +and others are **decode** (1 token each). The profiler captures the tensor shape +but not the per-sequence breakdown, so without additional information the +custom initializer has to guess — and guessing wrong changes the compute +pattern significantly (prefill is quadratic in sequence length, decode is +linear). 
+ +### How vLLM exposes this + +vLLM emits a `user_annotation` event for each `execute_model` iteration +with the exact prefill/decode composition encoded in the name: + +``` +execute_context_2(18)_generation_5(5) +``` + +This is an **iteration annotation** — it describes one forward pass. Here it +means: **2 prefill sequences** with **18 total query tokens**, and **5 decode +sequences** with **5 tokens** (1 each). + +### Extracting batch context from iteration annotations + +`extract_batch_context` parses these iteration annotations by timestamp and +attaches a `batch_context` dict to each paged attention event that falls +within the annotation's time range: + +```python +from TraceLens import TreePerfAnalyzer +from TraceLens.EventReplay import EventReplayer, extract_batch_context + +analyzer = TreePerfAnalyzer.from_file("vllm_trace.json") + +# Annotate paged attention events with prefill/decode split +num_annotated = extract_batch_context(analyzer) +print(f"Annotated {num_annotated} paged_attention events") + +# Now replay — PagedAttentionInit reads event["batch_context"] automatically +event = analyzer.tree.get_UID2event(some_uid) +replayer = EventReplayer(event, device='cuda') +replayer.replay() # query_start_loc reflects the real prefill/decode split +``` + +After `extract_batch_context`, each annotated event carries: + +```python +event["batch_context"] = { + "n_prefill": 2, # number of prefill sequences + "prefill_tokens": 18, # total query tokens across prefill sequences + "n_decode": 5, # number of decode sequences + "decode_tokens": 5, # total query tokens across decode sequences (1 each) +} +``` + +`PagedAttentionInit` uses this to build `query_start_loc` accurately: +prefill sequences get `prefill_tokens / n_prefill` tokens each, decode +sequences get 1 token each. 
+ +### Without iteration annotations + +`PagedAttentionInit` still runs — it always initializes `block_tables` +(random permutation of the physical block pool) and `seq_lens` (set to +`max_seq_len`). The only difference is how `query_start_loc` is built. +Without annotations, it falls back to heuristics: + +- `query_tokens == num_seqs` → assumes pure decode (1 token/seq) +- `query_tokens > num_seqs` → assumes pure prefill (tokens distributed +  uniformly) + +This is a reasonable approximation for homogeneous batches but inaccurate +for mixed prefill+decode batches, where the iteration annotation provides +the exact split. --- -## Notes +## Known Limitations + +- **Unregistered ops are invisible.** Triton kernels called directly from Python +  (e.g., aiter's Triton attention path) have no schema registered with the +  PyTorch dispatcher, so no IR can be extracted for them. The fix +  is wrapping them in `torch.library.custom_op` — a one-time registration effort +  in the upstream library. -- Event Replay uses randomized data based on extracted tensor shapes; thus, replay timings approximate real-world performance. +- **Single-op isolation vs. real workload.** Replay runs each op in isolation +  with no surrounding memory traffic. Replay timings are therefore an optimistic +  lower bound on in-model execution time. Sequence replay (ops in trace order to reproduce natural cache +  pollution) is a planned enhancement. + +- **Data-dependent kernels.** Custom initializers provide plausible values, not +  the exact values from the original execution (the profiler doesn't capture tensor +  contents). For most ops this doesn't matter; for ops with data-dependent +  control flow (e.g., sparse attention with variable sequence lengths), timing +  may vary. + +--- + +## Use Cases + +- **Trace Interpretation**: Translate opaque profiler arguments into named, typed JSON for understanding what each op actually computes. +- **Performance Debugging**: Isolate and reproduce performance issues from large models without running the model. 
+- **Regression Testing**: Automate benchmarks to detect performance regressions at the operator level. +- **Kernel Development**: Extract minimal reproducers for GPU kernel optimization and debugging. +- **Portable Sharing**: Package IR + replay scripts as standalone zip artifacts for teammates or upstream repos.