From e291628b3296f58a3053f92f559606d4f7d6821f Mon Sep 17 00:00:00 2001 From: Adeem Jassani Date: Sat, 28 Mar 2026 18:10:47 +0000 Subject: [PATCH 1/4] Expand EventReplay to support custom ops beyond aten:: Op resolution: - Add _resolve_op_func() with JIT-first resolution (preserves in-place kernel dispatch for aten ops) and torch.ops fallback for custom ops - Add _search_schemas() to collect schemas from both JIT registry and torch.ops namespace overloads Schemaless replay: - Add _get_event_replay_IR_schemaless() that infers argument types directly from profiled data when no schema is available, enabling replay of ops like _C::silu_and_mul and _C::rotary_embedding Type handling: - Fix Scalar type: preserve integer values for integral tensor ops instead of always casting to float - Add str/str?, SymInt?/int?, Generator? support in schema matching - Add _is_tensor_schema_type() for annotated variants like Tensor(a!) - Add _should_skip_tensor_init() generalizing in-place/output detection Dtype support (utils.py): - Add long int, unsigned char, char, short, and FP8 types - Use zeros init for non-floating-point tensors Tested with: - ResNet regression suite (70 aten:: ops) - vLLM Qwen1.5-MoE-A2.7B trace: 7/9 aiter ops, plus _rocm_C::wvSplitK, _C::silu_and_mul, _C::rotary_embedding (requires import vllm._C/_rocm_C) Made-with: Cursor --- TraceLens/EventReplay/event_replay.py | 373 +++++++++++++++++++------- TraceLens/EventReplay/utils.py | 34 ++- 2 files changed, 310 insertions(+), 97 deletions(-) diff --git a/TraceLens/EventReplay/event_replay.py b/TraceLens/EventReplay/event_replay.py index cb0be6a37..3aa00bfca 100644 --- a/TraceLens/EventReplay/event_replay.py +++ b/TraceLens/EventReplay/event_replay.py @@ -21,6 +21,96 @@ ) +def _resolve_op_func(op_name: str): + """ + Resolve an op name (e.g. 'aten::mm', 'vllm::rocm_unquantized_gemm') to a + callable. Tries multiple resolution strategies and returns the first that + yields a non-None callable. + + Returns (func, source_str) or raises RuntimeError. + """ + torch = _get_torch_or_raise() + + # 1. JIT registry first — preserves dispatch behaviour that the original + # profiled run used (important for in-place aten ops like add_). + try: + func, _ = torch._C._jit_get_operation(op_name) + if func is not None: + return func, "jit" + except RuntimeError: + pass + + # 2. torch.ops namespace lookup (most reliable for custom ops that may + # not be registered in the JIT registry). + if "::" in op_name: + ns, func_name = op_name.split("::", 1) + ns_obj = getattr(torch.ops, ns, None) + if ns_obj is not None: + func_obj = getattr(ns_obj, func_name, None) + if callable(func_obj): + return func_obj, "torch.ops" + + raise RuntimeError( + f"Cannot resolve op '{op_name}'. Ensure the library that defines it " + f"is imported (e.g. 'import vllm', 'import aiter')." + ) + + +def _search_schemas(op_name: str, verbose: bool = False): + """ + Return all registered FunctionSchemas for *op_name*. + + Searches both the JIT schema registry and the torch.ops namespace, which + covers aten ops, custom C++ ops, and Python-defined torch.library ops. + """ + torch = _get_torch_or_raise() + schemas: list = [] + seen_strs: set = set() + + # JIT registry + for s in torch._C._jit_get_all_schemas(): + if s.name == op_name: + s_str = str(s) + if s_str not in seen_strs: + schemas.append(s) + seen_strs.add(s_str) + + # torch.ops namespace (catches custom ops not in the JIT list) + if "::" in op_name: + ns, func_name = op_name.split("::", 1) + ns_obj = getattr(torch.ops, ns, None) + if ns_obj is not None: + op_obj = getattr(ns_obj, func_name, None) + if op_obj is not None: + # OpOverloadPacket exposes overloads + try: + for overload_name in op_obj.overloads(): + overload = getattr(op_obj, overload_name) + s = overload._schema + s_str = str(s) + if s_str not in seen_strs: + schemas.append(s) + seen_strs.add(s_str) + except Exception: + # Fallback: try .default directly + try: + s = op_obj.default._schema + s_str = str(s) + if s_str not in seen_strs: + schemas.append(s) + seen_strs.add(s_str) + except Exception: + pass + + if verbose: + print(f"Found {len(schemas)} schemas for {op_name}:") + for s in schemas: + pprint(str(s)) + print("-" * 80) + + return schemas + + class EventReplayer: def __init__( self, @@ -41,6 +131,7 @@ def __init__( self.device = device self.lazy = lazy self.verbose = verbose + self._func = None self._setup() def _setup(self): @@ -49,10 +140,33 @@ def _setup(self): """ if self.verbose: print(f"Preparing {self.event['name']} event for replay") - self.matched_schema = EventReplayer._search_schema(self.event, self.verbose) - self.event_replay_IR = EventReplayer._get_event_replay_IR( - self.event, self.matched_schema, self.verbose - ) + + self._func, self._func_source = _resolve_op_func(self.event["name"]) + if self.verbose: + print(f"Resolved op via {self._func_source}") + + try: + self.matched_schema = EventReplayer._search_schema( + self.event, self.verbose + ) + self._schemaless = False + except ValueError: + if self.verbose: + print( + "No schema found; falling back to schemaless replay " + "(all args treated as positional, types inferred from profile)" + ) + self.matched_schema = None + self._schemaless = True + + if self._schemaless: + self.event_replay_IR = EventReplayer._get_event_replay_IR_schemaless( + self.event, self.verbose + ) + else: + self.event_replay_IR = EventReplayer._get_event_replay_IR( + self.event, self.matched_schema, self.verbose + ) if not self.lazy: if self.verbose: print("setting up args and kwargs") @@ -64,11 +178,6 @@ def replay(self): """ Replay the event using the matched schema and event replay IR. """ - torch = _get_torch_or_raise() - # Get the function from the schema - func, _ = torch._C._jit_get_operation(self.event["name"]) - - # Call the function with the arguments if self.lazy: args, kwargs = EventReplayer._get_args_kwargs( self.event_replay_IR, device=self.device @@ -76,22 +185,13 @@ def replay(self): else: args, kwargs = self.args, self.kwargs - # Call the function with the arguments - func(*args, **kwargs) + self._func(*args, **kwargs) @staticmethod def _search_schema( event: Dict[str, Any], verbose: bool = False ) -> Optional["torch._C.FunctionSchema"]: - torch = _get_torch_or_raise() - all_schemas = torch._C._jit_get_all_schemas() - op_schemas = [s for s in all_schemas if s.name == event["name"]] - # print each schema in separate line - if verbose: - print(f"Found {len(op_schemas)} schemas for {event['name']}:") - for schema in op_schemas: - pprint(str(schema)) - print("-" * 80) + op_schemas = _search_schemas(event["name"], verbose=verbose) for schema in op_schemas: if verbose: @@ -106,7 +206,9 @@ def _search_schema( print("-" * 80) raise ValueError( - f"Cannot find matching schema for {event['name']}. Please check the event data and schema." + f"Cannot find matching schema for {event['name']}. " + f"Searched {len(op_schemas)} candidate(s). " + f"Please check the event data and ensure the op's library is imported." ) @staticmethod @@ -115,22 +217,13 @@ def _is_schema_match( ) -> bool: """ Check if the event matches the schema. - - Args: - event (Dict[str, Any]): The event data. - schema (torch._C.FunctionSchema): The schema to match against. - - Returns: - bool: True if the event matches the schema, False otherwise. """ op_name, pos_args_schema, kwargs_schema, return_type = ( EventReplayer.parse_schema_string(schema) ) full_args_schema = pos_args_schema + kwargs_schema - # Check if the number of args in the event matches the schema if len(event["args"]["Input type"]) != len(full_args_schema): return False - # Check if the types match for idx in range(len(event["args"]["Input type"])): profiled_type = event["args"]["Input type"][idx] schema_type = full_args_schema[idx]["arg_type"] @@ -138,16 +231,9 @@ def _is_schema_match( print(f"Checking arg {idx}:") print(f"\tSchema type: {schema_type}") print(f"\tProfiled type: {profiled_type}") - # Rules for matching types - # 1. for tensor types, schema type should be 'Tensor' and profiled type can be any of the tensor types 'float', 'c10::Half', 'c10::BFloat16' ... - # 2. for bool types, schema type should be 'bool' and profiled type is 'Scalar'. So we need to further check the concrete Inputs if it only contains 'true' or 'false' - # 3. for int types, schema type should be 'int' or 'SymInt' and profiled type is 'Scalar'. So we need to further check the concrete Inputs if it is a digit - # 4. for float types, schema type should be 'Scalar' and profiled type is 'Scalar'. So we need to further check the concrete Inputs if it is a float - # 5. for int[] types, schema type should be 'int[]' or 'SymInt[]' and profiled type is 'ScalarList'. So we need to further check the concrete Inputs if it is a list of digits - # 6. for bool[] types, schema type should be 'bool[]' and profiled type is 'ScalarList'. So we need to further check the concrete Inputs if it is a list of 'true' or 'false' - # 7. for tensor[] types, we cannot replay the event as the tensor shapes are not provided in the event. So we need to skip this case. Maybe suggest PyTorch to add this in the future. + is_match = True - # if the schema type ends with '?' then the profiled type can be blank as well + # Optional types: schema ends with '?' => profiled type can be blank if schema_type.endswith("?"): schema_type = schema_type[:-1] if profiled_type == "": @@ -157,20 +243,20 @@ def _is_schema_match( and event["args"]["Concrete Inputs"][idx] == "[]" ): continue - if schema_type in ["Tensor", "Tensor?", "Tensor(a!)"]: + if EventReplayer._is_tensor_schema_type(schema_type): if profiled_type not in list_profile_tensor_types: is_match = False elif schema_type == "bool": profiled_value = event["args"]["Concrete Inputs"][idx] if profiled_value.lower() not in ["true", "false"]: is_match = False - elif schema_type == "int" or schema_type == "SymInt": + elif schema_type in ("int", "SymInt"): if profiled_type != "Scalar": is_match = False profiled_value = event["args"]["Concrete Inputs"][idx] if not profiled_value.lstrip("-").isdigit(): is_match = False - elif schema_type in ["float", "Scalar"]: + elif schema_type in ("float", "Scalar"): if profiled_type != "Scalar": is_match = False profiled_value = event["args"]["Concrete Inputs"][idx] @@ -179,7 +265,6 @@ def _is_schema_match( except ValueError: is_match = False elif schema_type.startswith("int[") or schema_type.startswith("SymInt["): - # custom dev debugging if profiled_type != "ScalarList": is_match = False profiled_value = event["args"]["Concrete Inputs"][idx] @@ -196,16 +281,27 @@ def _is_schema_match( x.strip() for x in profiled_value.strip()[1:-1].split(",") ] if not all( - x.lower() in ["true", "false"] for x in profiled_value_cleaned + x.lower() in ("true", "false") for x in profiled_value_cleaned ): is_match = False elif schema_type.startswith("Tensor["): raise ValueError( - f"Tensor list type not supported: {schema_type} as the tensor shapes are not provided in the event" + f"Tensor list type not supported: {schema_type} as the " + f"tensor shapes are not provided in the event" ) + elif schema_type == "str": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "ScalarType": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "Layout": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "Device": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "MemoryFormat": + is_match = profiled_type == "Scalar" or profiled_type == "" + elif schema_type == "Generator": + is_match = profiled_type == "" or profiled_type == "Scalar" else: - # raise ValueError(f"Unknown schema type: {schema_type}") - # warning: if the schema type is not in the list, we will skip this case warnings.warn( f"Unknown schema type: {schema_type}. Skipping this case." ) @@ -218,35 +314,42 @@ def _is_schema_match( return False return True + @staticmethod + def _is_tensor_schema_type(schema_type: str) -> bool: + """Check if a schema type string represents a Tensor argument.""" + if schema_type in ("Tensor", "Tensor?"): + return True + # Handles annotated variants like Tensor(a!), Tensor(a), Tensor(b!) + if schema_type.startswith("Tensor("): + return True + return False + + @staticmethod + def _should_skip_tensor_init(evt_name: str, arg_name: str, arg_idx: int) -> bool: + """ + Determine whether a tensor argument is an output-only buffer that + does not need random initialization. + + Generalizes the old aten::fill_ / aten::copy_ special-cases to + any in-place or out-of-place output tensor. + """ + # In-place ops (name ends with '_'): first tensor is the mutated output + if evt_name.endswith("_") and arg_idx == 0: + return True + # Explicit 'out' arguments in .out variants + if arg_name == "out": + return True + # aten::copy_ destination + if evt_name == "aten::copy_" and arg_name != "src": + return True + return False + @staticmethod def _get_event_replay_IR( event: Dict[str, Any], schema: "torch._C.FunctionSchema", verbose: bool = False ) -> Dict[str, Any]: """ Get the event replay IR from the event and schema. - - Args: - event (Dict[str, Any]): The event data. - schema (torch._C.FunctionSchema): The schema to match against. - - Returns: - { - 'pos_args': [ - - dummy_tensor0, - dummy_tensor1, - value0, - value1, - ... - ], - 'kwargs': { - 'arg0': value0, - 'arg1': dummy_tensor0, - 'arg2': dummy_tensor1, - 'arg3': value1, - ... - } - } """ evt_name = event["name"] op_name, pos_args_schema, kwargs_schema, return_type = ( @@ -267,36 +370,45 @@ def _get_event_replay_IR( print(f"Concrete Inputs: {event['args']['Concrete Inputs'][idx]}") if arg_type.endswith("?") and event["args"]["Input type"][idx] == "": - value = None + if arg_type.startswith("str"): + default = full_args_schema[idx].get("default") + value = "" if default is None or default == "None" else default + else: + value = None elif ( arg_type.endswith("?") and event["args"]["Concrete Inputs"][idx] == "[]" ): value = [] else: - if arg_type in ["Tensor", "Tensor?", "Tensor(a!)"]: + if EventReplayer._is_tensor_schema_type(arg_type): init = "normal" - if evt_name == "aten::fill_": - # special case for fill_ where we don't need to initialize the tensor - # as it will be filled with a value later - init = None - elif evt_name == "aten::copy_" and arg_name != "src": - # special case for copy_ where we don't need to initialize the tensor - # as it will be copied from another tensor + if EventReplayer._should_skip_tensor_init(evt_name, arg_name, idx): init = None + profiled_dtype = event["args"]["Input type"][idx] + # Non-floating-point tensors cannot use 'normal' init + if profiled_dtype in ("long", "long int", "int", "bool", "unsigned char"): + init = "zeros" if init == "normal" else init value = TensorCfg( shape=event["args"]["Input Dims"][idx], - dtype=event["args"]["Input type"][idx], + dtype=profiled_dtype, strides=event["args"]["Input Strides"][idx], init=init, ) else: arg_str = event["args"]["Concrete Inputs"][idx] - if arg_type in ["bool", "bool?"]: + if arg_type in ("bool", "bool?"): value = arg_str.lower() == "true" - elif arg_type in ["int", "SymInt"]: + elif arg_type in ("int", "int?", "SymInt", "SymInt?"): value = int(arg_str) - elif arg_type in ["float", "float?", "Scalar", "Scalar?"]: + elif arg_type in ("Scalar", "Scalar?"): + if arg_str.lstrip("-").isdigit(): + value = int(arg_str) + else: + value = float(arg_str) + elif arg_type in ("float", "float?"): value = float(arg_str) + elif arg_type in ("str", "str?"): + value = arg_str elif arg_type.startswith("int[") or arg_type.startswith("SymInt["): value = [ int(x.strip()) for x in arg_str.strip()[1:-1].split(",") @@ -324,18 +436,95 @@ def _get_event_replay_IR( ) return {"list_pos_args": list_pos_args, "list_kwargs": list_kwargs} + @staticmethod + def _get_event_replay_IR_schemaless( + event: Dict[str, Any], verbose: bool = False + ) -> Dict[str, Any]: + """ + Build a replay IR without a schema by inferring types directly from the + profiled data. All arguments are treated as positional. + + Heuristics: + - If Input type is a known tensor dtype -> TensorCfg + - If Input type is 'Scalar' and Concrete Inputs looks like int -> int + - If Input type is 'Scalar' and Concrete Inputs looks like float -> float + - If Input type is 'Scalar' and Concrete Inputs is true/false -> bool + - If Input type is '' and Concrete Inputs is '' -> None + """ + evt_name = event["name"] + list_pos_args = [] + n_args = len(event["args"]["Input type"]) + for idx in range(n_args): + profiled_type = event["args"]["Input type"][idx] + profiled_dims = event["args"]["Input Dims"][idx] + profiled_strides = event["args"]["Input Strides"][idx] + concrete = event["args"]["Concrete Inputs"][idx] + + if verbose: + print( + f"Schemaless arg {idx}: type={profiled_type!r} " + f"dims={profiled_dims} concrete={concrete!r}" + ) + + if profiled_type in list_profile_tensor_types: + init = "normal" + if profiled_type in ( + "long", "long int", "int", "bool", "unsigned char", + ): + init = "zeros" + value = TensorCfg( + shape=profiled_dims, + dtype=profiled_type, + strides=profiled_strides, + init=init, + ) + arg_type = "Tensor" + elif profiled_type == "Scalar" and concrete: + if concrete.lower() in ("true", "false"): + value = concrete.lower() == "true" + arg_type = "bool" + elif concrete.lstrip("-").isdigit(): + value = int(concrete) + arg_type = "int" + else: + try: + value = float(concrete) + arg_type = "float" + except ValueError: + value = concrete + arg_type = "str" + elif profiled_type == "" and concrete == "": + value = None + arg_type = "None" + elif profiled_type == "ScalarList" and concrete: + items = [x.strip() for x in concrete.strip()[1:-1].split(",") if x.strip()] + if all(x.lstrip("-").isdigit() for x in items): + value = [int(x) for x in items] + else: + value = [float(x) for x in items] + arg_type = "list" + else: + value = None + arg_type = "unknown" + if verbose: + print(f" -> defaulting to None for unknown type") + + inferred_name = f"arg{idx}" + list_pos_args.append( + {"arg_name": inferred_name, "arg_type": arg_type, "value": value} + ) + if verbose: + print(f" -> {inferred_name}: {arg_type} = {value}") + print("-" * 80) + + return {"list_pos_args": list_pos_args, "list_kwargs": []} + @staticmethod def _get_args_kwargs( event_replay_IR: Dict[str, Any], device: str = "cuda" ) -> tuple[List["torch.Tensor"], Dict[str, Any]]: """ Get the arguments and keyword arguments from the event replay IR. - - Args: - event_replay_IR (Dict[str, Any]): The event replay IR. - - Returns: - (List[torch.Tensor], Dict[str, Any]): The positional arguments and keyword arguments. """ pos_args = [] for arg in event_replay_IR["list_pos_args"]: @@ -402,18 +591,12 @@ def get_repro_info(self) -> Dict[str, Any]: Dict[str, Any]: A dictionary containing the operator name and the replay IR. Suitable for JSON serialization using the custom encoder. """ - # return { - # 'op_name': self.event['name'], - # 'replay_ir': self.event_replay_IR - # # No device info here - device is decided by the runner - # } dict_repro_info = {} dict_repro_info["op_name"] = self.event["name"] list_pos_args, list_kwargs = ( self.event_replay_IR["list_pos_args"], self.event_replay_IR["list_kwargs"], ) - # Convert TensorCfg to dict for JSON serialization list_pos_args_copy, list_kwargs_copy = list_pos_args.copy(), list_kwargs.copy() for idx, val in enumerate(list_pos_args_copy): if isinstance(val["value"], TensorCfg): diff --git a/TraceLens/EventReplay/utils.py b/TraceLens/EventReplay/utils.py index 304cf6a98..a4d92ebca 100644 --- a/TraceLens/EventReplay/utils.py +++ b/TraceLens/EventReplay/utils.py @@ -32,8 +32,16 @@ def _get_torch_or_raise() -> Any: # Changed return type to Any for flexibility "c10::Half", "c10::BFloat16", "long", + "long int", "int", "bool", + "unsigned char", + "char", + "short", + "c10::Float8_e4m3fnuz", + "c10::Float8_e5m2fnuz", + "c10::Float8_e4m3fn", + "c10::Float8_e5m2", ] from dataclasses import dataclass @@ -58,15 +66,35 @@ def build_tensor(cfg: TensorCfg, device: str = "cuda") -> "torch.Tensor": "bool": torch.bool, "int": torch.int, "long": torch.long, + "long int": torch.long, + "short": torch.short, + "char": torch.int8, + "unsigned char": torch.uint8, "double": torch.float64, "float": torch.float32, "c10::Half": torch.float16, "c10::BFloat16": torch.bfloat16, } + # FP8 types (available in PyTorch >= 2.1) + for fp8_name in ( + "c10::Float8_e4m3fnuz", + "c10::Float8_e5m2fnuz", + "c10::Float8_e4m3fn", + "c10::Float8_e5m2", + ): + attr = fp8_name.replace("c10::", "") + torch_dtype = getattr(torch, attr.lower(), None) + if torch_dtype is not None: + dict_profile2torchdtype[fp8_name] = torch_dtype + + if cfg.dtype not in dict_profile2torchdtype: + raise ValueError( + f"Unknown profiled dtype '{cfg.dtype}'. " + f"Known types: {list(dict_profile2torchdtype.keys())}" + ) dtype = dict_profile2torchdtype[cfg.dtype] size = cfg.shape stride = cfg.strides - # allocate *exactly* the storage needed for that stride/shape t = torch.empty_strided(size, stride, dtype=dtype, device=device) is_floating = t.is_floating_point() or t.is_complex() init = cfg.init @@ -75,7 +103,9 @@ def build_tensor(cfg: TensorCfg, device: str = "cuda") -> "torch.Tensor": raise ValueError( f"Cannot initialize tensor of type {cfg.dtype} with 'normal' init." ) - t.normal_() # or whatever init you like + t.normal_() + elif init == "zeros": + t.zero_() elif init is not None: raise ValueError(f"Unsupported tensor initialization: {init}") return t From b00976c87c7fc58844b9b07df82abc5950fcba69 Mon Sep 17 00:00:00 2001 From: Adeem Jassani Date: Sat, 28 Mar 2026 18:39:50 +0000 Subject: [PATCH 2/4] Add string arg defaults, op aliases, and module resolution String arg defaults (_STR_ARG_DEFAULTS): - When the profiler drops a str arg value, check a known-defaults table keyed by arg name (e.g. kv_cache_dtype -> "auto") - Log a WARNING when a default is used so users know the value was inferred - Recovers _C_cache_ops::reshape_and_cache_flash (1.97% GPU time) Op name aliases (_OP_NAME_ALIASES): - Map trace-recorded names to their runtime-registered names (e.g. _rocm_C::wvSplitK -> _rocm_C::wvSpltK) Python module resolution (3rd strategy): - After JIT and torch.ops, try importlib.import_module(namespace) for JIT-compiled ops like aiter Schema parser fix: - parse_schema_string handles annotated tensor types with spaces like "Tensor($0! -> )" correctly now Tested with vLLM Qwen1.5-MoE-A2.7B trace on MI300X (tw025). Made-with: Cursor --- TraceLens/EventReplay/event_replay.py | 185 +++++++++++++++++++++----- 1 file changed, 152 insertions(+), 33 deletions(-) diff --git a/TraceLens/EventReplay/event_replay.py b/TraceLens/EventReplay/event_replay.py index 3aa00bfca..685190b1a 100644 --- a/TraceLens/EventReplay/event_replay.py +++ b/TraceLens/EventReplay/event_replay.py @@ -9,6 +9,7 @@ from pprint import pprint from typing import Dict, Any, List, Optional, Tuple +import logging import re import warnings import time @@ -20,19 +21,36 @@ list_profile_tensor_types, ) - -def _resolve_op_func(op_name: str): - """ - Resolve an op name (e.g. 'aten::mm', 'vllm::rocm_unquantized_gemm') to a - callable. Tries multiple resolution strategies and returns the first that - yields a non-None callable. - - Returns (func, source_str) or raises RuntimeError. +logger = logging.getLogger(__name__) + +# ── Known defaults for string arguments the profiler drops ────────────── +# The PyTorch profiler records `str` arguments as empty strings. When we +# know the only sensible default we fill it in automatically and warn. +# Key = argument name, Value = default string value. +_STR_ARG_DEFAULTS: Dict[str, str] = { + "kv_cache_dtype": "auto", +} + +# ── Op-name aliases ───────────────────────────────────────────────────── +# Some frameworks profile an op under one namespace but register the +# actual callable under a different one. +# Key = name as it appears in the trace, Value = list of candidates to try. +# NOTE: aiter::paged_attention_v1/v2 are NOT aliasable — the aiter JIT +# wrapper records a different arg layout than the underlying _C:: / _rocm_C:: +# schemas, so arg mapping would fail even if resolution succeeds. +_OP_NAME_ALIASES: Dict[str, List[str]] = { + "_rocm_C::wvSplitK": ["_rocm_C::wvSpltK"], +} + + +def _try_resolve(op_name: str): + """Attempt JIT + torch.ops + module resolution for a single op name. + Returns (func, source_str) or (None, None). """ torch = _get_torch_or_raise() + import importlib - # 1. JIT registry first — preserves dispatch behaviour that the original - # profiled run used (important for in-place aten ops like add_). + # 1. JIT registry — preserves dispatch behaviour for aten ops. try: func, _ = torch._C._jit_get_operation(op_name) if func is not None: @@ -40,16 +58,55 @@ def _resolve_op_func(op_name: str): except RuntimeError: pass - # 2. torch.ops namespace lookup (most reliable for custom ops that may - # not be registered in the JIT registry). if "::" in op_name: ns, func_name = op_name.split("::", 1) + + # 2. torch.ops namespace — custom ops registered via torch.library. ns_obj = getattr(torch.ops, ns, None) if ns_obj is not None: func_obj = getattr(ns_obj, func_name, None) if callable(func_obj): return func_obj, "torch.ops" + # 3. Direct Python module lookup — handles JIT-compiled ops (e.g. + # aiter) that exist as Python callables but aren't registered in + # the torch op registry. + try: + mod = importlib.import_module(ns) + func_obj = getattr(mod, func_name, None) + if callable(func_obj): + return func_obj, f"module:{ns}" + except ImportError: + pass + + return None, None + + +def _resolve_op_func(op_name: str): + """ + Resolve an op name (e.g. 'aten::mm', 'vllm::rocm_unquantized_gemm') to a + callable. Tries multiple resolution strategies: + + 1. JIT registry (preserves original dispatch behaviour). + 2. torch.ops namespace (custom ops registered via torch.library / pybind). + 3. Known aliases from _OP_NAME_ALIASES (handles trace-name mismatches). + + Returns (func, source_str, resolved_name) or raises RuntimeError. + """ + func, source = _try_resolve(op_name) + if func is not None: + return func, source, op_name + + for alias in _OP_NAME_ALIASES.get(op_name, []): + func, source = _try_resolve(alias) + if func is not None: + logger.warning( + "Op '%s' resolved via alias '%s' (%s). " + "The trace recorded a different namespace than the runtime registration.", + op_name, alias, source, + ) + return func, source, alias + raise RuntimeError( f"Cannot resolve op '{op_name}'. Ensure the library that defines it " f"is imported (e.g. 'import vllm', 'import aiter')." @@ -141,13 +198,17 @@ def _setup(self): if self.verbose: print(f"Preparing {self.event['name']} event for replay") - self._func, self._func_source = _resolve_op_func(self.event["name"]) + self._func, self._func_source, self._resolved_name = _resolve_op_func( + self.event["name"] + ) if self.verbose: print(f"Resolved op via {self._func_source}") + if self._resolved_name != self.event["name"]: + print(f" (aliased from '{self.event['name']}' -> '{self._resolved_name}')") try: self.matched_schema = EventReplayer._search_schema( - self.event, self.verbose + self.event, self._resolved_name, self.verbose ) self._schemaless = False except ValueError: @@ -161,7 +222,7 @@ def _setup(self): if self._schemaless: self.event_replay_IR = EventReplayer._get_event_replay_IR_schemaless( - self.event, self.verbose + self.event, self.verbose, resolved_name=self._resolved_name ) else: self.event_replay_IR = EventReplayer._get_event_replay_IR( @@ -189,9 +250,12 @@ def replay(self): @staticmethod def _search_schema( - event: Dict[str, Any], verbose: bool = False + event: Dict[str, Any], + resolved_name: Optional[str] = None, + verbose: bool = False, ) -> Optional["torch._C.FunctionSchema"]: - op_schemas = _search_schemas(event["name"], verbose=verbose) + name = resolved_name or event["name"] + op_schemas = _search_schemas(name, verbose=verbose) for schema in op_schemas: if verbose: @@ -206,7 +270,7 @@ def _search_schema( print("-" * 80) raise ValueError( - f"Cannot find matching schema for {event['name']}. " + f"Cannot find matching schema for {name}. " f"Searched {len(op_schemas)} candidate(s). " f"Please check the event data and ensure the op's library is imported." ) @@ -372,7 +436,15 @@ def _get_event_replay_IR( if arg_type.endswith("?") and event["args"]["Input type"][idx] == "": if arg_type.startswith("str"): default = full_args_schema[idx].get("default") - value = "" if default is None or default == "None" else default + if default is None or default == "None": + default = _STR_ARG_DEFAULTS.get(arg_name) + if default is not None: + logger.warning( + "%s arg '%s' (position %d): profiler dropped " + "the string value. Using known default '%s'.", + evt_name, arg_name, idx, default, + ) + value = "" if default is None else default else: value = None elif ( @@ -408,7 +480,15 @@ def _get_event_replay_IR( elif arg_type in ("float", "float?"): value = float(arg_str) elif arg_type in ("str", "str?"): - value = arg_str + if not arg_str and arg_name in _STR_ARG_DEFAULTS: + value = _STR_ARG_DEFAULTS[arg_name] + logger.warning( + "%s arg '%s' (position %d): profiler dropped " + "the string value. Using known default '%s'.", + evt_name, arg_name, idx, value, + ) + else: + value = arg_str elif arg_type.startswith("int[") or arg_type.startswith("SymInt["): value = [ int(x.strip()) for x in arg_str.strip()[1:-1].split(",") @@ -436,9 +516,23 @@ def _get_event_replay_IR( ) return {"list_pos_args": list_pos_args, "list_kwargs": list_kwargs} + @staticmethod + def _get_schema_arg_names(op_name: str) -> List[str]: + """Best-effort: return a list of arg names from any available schema.""" + schemas = _search_schemas(op_name, verbose=False) + if not schemas: + return [] + try: + _, pos_args, kw_args, _ = EventReplayer.parse_schema_string(schemas[0]) + return [a["arg_name"] for a in pos_args + kw_args] + except Exception: + return [] + @staticmethod def _get_event_replay_IR_schemaless( - event: Dict[str, Any], verbose: bool = False + event: Dict[str, Any], + verbose: bool = False, + resolved_name: Optional[str] = None, ) -> Dict[str, Any]: """ Build a replay IR without a schema by inferring types directly from the @@ -449,9 +543,13 @@ def _get_event_replay_IR_schemaless( - If Input type is 'Scalar' and Concrete Inputs looks like int -> int - If Input type is 'Scalar' and Concrete Inputs looks like float -> float - If Input type is 'Scalar' and Concrete Inputs is true/false -> bool - - If Input type is '' and Concrete Inputs is '' -> None + - If Input type is '' and Concrete Inputs is '' -> check _STR_ARG_DEFAULTS """ evt_name = event["name"] + schema_arg_names = EventReplayer._get_schema_arg_names( + resolved_name or evt_name + ) + list_pos_args = [] n_args = len(event["args"]["Input type"]) for idx in range(n_args): @@ -494,8 +592,22 @@ def _get_event_replay_IR_schemaless( value = concrete arg_type = "str" elif profiled_type == "" and concrete == "": - value = None - arg_type = "None" + # Likely a dropped str arg — check known defaults + hint_name = ( + schema_arg_names[idx] if idx < len(schema_arg_names) else None + ) + default = _STR_ARG_DEFAULTS.get(hint_name) if hint_name else None + if default is not None: + value = default + arg_type = "str" + logger.warning( + "%s arg '%s' (position %d): profiler dropped the " + "string value. Using known default '%s'.", + evt_name, hint_name, idx, default, + ) + else: + value = None + arg_type = "None" elif profiled_type == "ScalarList" and concrete: items = [x.strip() for x in concrete.strip()[1:-1].split(",") if x.strip()] if all(x.lstrip("-").isdigit() for x in items): @@ -509,7 +621,9 @@ def _get_event_replay_IR_schemaless( if verbose: print(f" -> defaulting to None for unknown type") - inferred_name = f"arg{idx}" + inferred_name = ( + schema_arg_names[idx] if idx < len(schema_arg_names) else f"arg{idx}" + ) list_pos_args.append( {"arg_name": inferred_name, "arg_type": arg_type, "value": value} ) @@ -556,17 +670,22 @@ def parse_schema_string( kwarg_part = parts[1].lstrip(",").strip() if len(parts) > 1 else "" def _parse_arg(raw_arg: str) -> Tuple[str, str, Optional[str], bool]: - m = re.match(r"^(\S+)\s+(.*)$", raw_arg) + # Match type (may contain spaces, e.g. "Tensor($0! -> )") then name[=default]. + # Greedy (.+) consumes everything up to the last whitespace before + # a valid identifier, so "Tensor($0! -> ) key_cache" parses correctly. + m = re.match( + r"^(.+)\s+([A-Za-z_]\w*(?:=.*)?)$", raw_arg.strip() + ) if not m: raise ValueError(f"Invalid arg: {raw_arg}") - arg_type, rest = m.groups() - m2 = re.match(r"^([A-Za-z_][A-Za-z0-9_]*)(?:=(.*))?$", rest) + arg_type = m.group(1).strip() + name_default = m.group(2) + m2 = re.match(r"^([A-Za-z_]\w*)(?:=(.*))?$", name_default) if not m2: - raise ValueError(f"Invalid arg name/default: {rest}") - arg_name, default = m2.group(1), ( - m2.group(2).strip() if m2.group(2) else None - ) - return arg_type.strip(), arg_name.strip(), default + raise ValueError(f"Invalid arg name/default: {name_default}") + arg_name = m2.group(1) + default = m2.group(2).strip() if m2.group(2) else None + return arg_type, arg_name, default args = [] for item in [x.strip() for x in pos_part.split(",") if x.strip()]: From a2bd6226555aad07ec0db046150b64bbd1097ac7 Mon Sep 17 00:00:00 2001 From: Jassani Date: Tue, 28 Apr 2026 14:36:06 -0400 Subject: [PATCH 3/4] Add custom initializers, auto-import, and updated docs for EventReplay Made-with: Cursor --- TraceLens/EventReplay/__init__.py | 18 + TraceLens/EventReplay/custom_inits.py | 389 +++++++++++++++++++++ TraceLens/EventReplay/event_replay.py | 284 +++++++++------- TraceLens/EventReplay/utils.py | 82 +++-- docs/EventReplay.md | 473 ++++++++++++++++++++++---- 5 files changed, 1048 insertions(+), 198 deletions(-) create mode 100644 TraceLens/EventReplay/custom_inits.py diff --git a/TraceLens/EventReplay/__init__.py b/TraceLens/EventReplay/__init__.py index fb6cb2734..4f79b0b11 100644 --- a/TraceLens/EventReplay/__init__.py +++ b/TraceLens/EventReplay/__init__.py @@ -3,3 +3,21 @@ # # See LICENSE for license information. ############################################################################### + +from .event_replay import EventReplayer +from .custom_inits import ( + CustomInit, + PagedAttentionInit, + MoeRoutingInit, + extract_batch_context, +) +from .utils import benchmark_func + +__all__ = [ + "EventReplayer", + "CustomInit", + "PagedAttentionInit", + "MoeRoutingInit", + "extract_batch_context", + "benchmark_func", +] diff --git a/TraceLens/EventReplay/custom_inits.py b/TraceLens/EventReplay/custom_inits.py new file mode 100644 index 000000000..9ded92540 --- /dev/null +++ b/TraceLens/EventReplay/custom_inits.py @@ -0,0 +1,389 @@ +############################################################################### +# Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +"""Custom initializers for EventReplayer. + +Operations captured by the PyTorch profiler have zeroed-out metadata tensors +(block tables, routing tensors, etc.) because the profiler records shapes and +dtypes but not tensor values. Custom initializers fill these tensors with +realistic content so the GPU kernel exercises real memory-access and compute +patterns during replay benchmarking. + +To add a custom initializer for a new op family: + 1. Subclass ``CustomInit`` + 2. Set ``op_patterns`` to one or more substrings that match the op name + 3. Implement ``initialize()`` — mutate replayer.args / replayer.kwargs in-place + 4. Return a one-line summary string (printed by EventReplayer) + 5. Register with ``EventReplayer.register_custom_init(YourInit())`` + or add it to the ``_custom_init_registry`` default list. +""" + +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + pass # EventReplayer imported at runtime to avoid circular dep + +# -- Batch context extraction from vLLM profiler annotations --------------- + +_BATCH_ANNO_RE = re.compile( + r"execute_context_(\d+)\((\d+)\)_generation_(\d+)\((\d+)\)" +) + + +def extract_batch_context(analyzer: Any) -> int: + """Parse vLLM ``user_annotation`` events and attach batch context to ops. + + vLLM annotates each ``execute_model`` step with a ``user_annotation`` + event of the form ``execute_context_N(T)_generation_N(T)`` where + *N* = number of sequences and *T* = total query tokens for that phase. + + This function: + 1. Collects all such annotations with their ``[ts, ts+dur]`` ranges. + 2. For every ``paged_attention`` cpu_op event, finds the enclosing + annotation by timestamp and attaches a ``batch_context`` dict:: + + event["batch_context"] = { + "n_prefill": 2, + "prefill_tokens": 18, + "n_decode": 2, + "decode_tokens": 2, + } + + Args: + analyzer: A ``TreePerfAnalyzer`` (or any object whose ``.tree.events`` + yields the trace event list). + + Returns: + Number of paged_attention events that were annotated. + """ + annotations = [] + for e in analyzer.tree.events: + cat = e.get("cat") or "" + if cat != "user_annotation": + continue + m = _BATCH_ANNO_RE.search(e.get("name", "")) + if not m: + continue + ts = e.get("ts", 0) + dur = e.get("dur", 0) + annotations.append({ + "ts": ts, + "end": ts + dur, + "n_prefill": int(m.group(1)), + "prefill_tokens": int(m.group(2)), + "n_decode": int(m.group(3)), + "decode_tokens": int(m.group(4)), + }) + + if not annotations: + return 0 + + annotations.sort(key=lambda a: a["ts"]) + + annotated = 0 + for e in analyzer.tree.events: + name = e.get("name", "") + if "paged_attention" not in name: + continue + if not e.get("args", {}).get("Input Dims"): + continue + ts = e.get("ts", 0) + for a in annotations: + if a["ts"] <= ts <= a["end"]: + e["batch_context"] = { + "n_prefill": a["n_prefill"], + "prefill_tokens": a["prefill_tokens"], + "n_decode": a["n_decode"], + "decode_tokens": a["decode_tokens"], + } + annotated += 1 + break + + return annotated + + +class CustomInit(ABC): + """Base class for tensor initializers applied before replay.""" + + op_patterns: List[str] = [] + + def applies_to(self, replayer: Any) -> bool: + op_name = replayer.event.get("name", "") + return any(pat in op_name for pat in self.op_patterns) + + @abstractmethod + def initialize(self, replayer: Any, **kwargs) -> Optional[str]: + """Mutate replayer.args/kwargs in-place. Return a summary string.""" + ... + + +class PagedAttentionInit(CustomInit): + """Initialize block_tables, seq_lens, and query_start_loc for paged attention. + + When ``batch_context`` is present on the event (attached by + :func:`extract_batch_context`), uses the exact prefill/decode split from + vLLM's profiler annotations. Otherwise falls back to heuristics: + - query_tokens == num_seqs → decode (1 token/seq) + - query_tokens > num_seqs → prefill (tokens distributed uniformly) + + In all cases: + - ``seq_lens`` set to ``max_seq_len`` for every sequence. + - Block table entries drawn from a random permutation of the pool. + """ + + op_patterns = ["paged_attention"] + + def initialize(self, replayer: Any, **kwargs) -> Optional[str]: + try: + import numpy as np + except ImportError: + return "[custom init] PagedAttentionInit skipped — numpy not available" + + args = replayer.args + op_name = replayer.event.get("name", "") + + ir = replayer.event_replay_IR + arg_names = [a["arg_name"] for a in ir["list_pos_args"]] + def _by_name_or_pos(name, pos): + if name in arg_names: + return args[arg_names.index(name)] + return args[pos] + + block_tables = _by_name_or_pos("block_tables", 9) + seq_lens = _by_name_or_pos("seq_lens", 10) + key_cache = _by_name_or_pos("key_cache", 5) + block_size = int(_by_name_or_pos("block_size", 12)) + max_seq_len = int(_by_name_or_pos("max_seq_len", 13)) + + num_seqs = block_tables.shape[0] + max_blocks_per_seq = block_tables.shape[1] + num_blocks_total = key_cache.shape[0] + + query = _by_name_or_pos("query", 4) + num_query_tokens = query.shape[0] + + rng = np.random.default_rng(42) + + # -- Determine per-sequence query token counts ------------------------- + batch_ctx = replayer.event.get("batch_context") + if batch_ctx is not None: + n_pf = batch_ctx["n_prefill"] + pf_tok = batch_ctx["prefill_tokens"] + n_dec = batch_ctx["n_decode"] + dec_tok = batch_ctx["decode_tokens"] + + per_seq_queries = [] + if n_pf > 0: + base_pf = pf_tok // n_pf + rem_pf = pf_tok % n_pf + for s in range(n_pf): + per_seq_queries.append(base_pf + (1 if s < rem_pf else 0)) + for _ in range(n_dec): + per_seq_queries.append(1) + + if len(per_seq_queries) != num_seqs: + per_seq_queries = per_seq_queries[:num_seqs] + while len(per_seq_queries) < num_seqs: + per_seq_queries.append(1) + + phase = ("mixed" if n_pf > 0 and n_dec > 0 + else "prefill" if n_pf > 0 else "decode") + source = "annotation" + else: + tokens_per_seq = num_query_tokens / num_seqs if num_seqs else 1 + if tokens_per_seq > 1: + base_q = num_query_tokens // num_seqs + rem_q = num_query_tokens % num_seqs + per_seq_queries = [base_q + (1 if s < rem_q else 0) + for s in range(num_seqs)] + phase = "prefill" + else: + per_seq_queries = [1] * num_seqs + phase = "decode" + source = "heuristic" + + # -- seq_lens: max_seq_len for every sequence -------------------------- + lengths = np.full(num_seqs, max_seq_len, dtype=np.int32) + + # -- block_tables: permutation of physical block pool ------------------ + bt = np.zeros((num_seqs, max_blocks_per_seq), dtype=np.int32) + all_block_ids = rng.permutation(num_blocks_total) + block_cursor = 0 + for s in range(num_seqs): + blocks_needed = (int(lengths[s]) + block_size - 1) // block_size + blocks_needed = min(blocks_needed, max_blocks_per_seq) + for b in range(blocks_needed): + bt[s, b] = all_block_ids[block_cursor % num_blocks_total] + block_cursor += 1 + + import torch + + block_tables.copy_(torch.from_numpy(bt).to(block_tables.device)) + seq_lens.copy_(torch.from_numpy(lengths).to(seq_lens.device)) + + # -- query_start_loc: CSR indptr encoding per-seq query counts --------- + qsl = _by_name_or_pos("query_start_loc", 11) + if (qsl is not None + and hasattr(qsl, "shape") + and qsl.numel() > 0): + qloc = np.zeros(num_seqs + 1, dtype=np.int32) + for s in range(num_seqs): + qloc[s + 1] = qloc[s] + per_seq_queries[s] + qloc = qloc[: qsl.numel()] + qsl.copy_(torch.from_numpy(qloc).to(qsl.device)) + + ctx_str = "" + if batch_ctx is not None: + ctx_str = (f" Annotation: {batch_ctx['n_prefill']} prefill " + f"({batch_ctx['prefill_tokens']} tok) + " + f"{batch_ctx['n_decode']} decode " + f"({batch_ctx['decode_tokens']} tok).") + + return ( + f"[custom init] {op_name} — paged attention metadata: " + f"phase={phase} ({source}), num_seqs={num_seqs}, " + f"max_seq_len={max_seq_len}, block_size={block_size}, " + f"num_blocks={num_blocks_total}, " + f"max_blocks_per_seq={max_blocks_per_seq}.{ctx_str}" + ) + + +class MoeRoutingInit(CustomInit): + """Initialize MoE routing tensors (sorted_token_ids, sorted_expert_ids, + num_valid_ids) so the CK kernel processes real token-to-expert assignments + instead of short-circuiting on num_valid_ids=0. + + Supported kwargs: + moe_distribution: "uniform" (default) or "zipf" + moe_zipf_s: Zipf exponent (default 1.2), only used with "zipf" + + Arg layout for aiter::ck_moe_stage1/2: + [0] hidden_states [M, K] bf16 + [1] w1 [E, N, K] bf16 + [2] w2 [E, K2, N2] bf16 + [3] sorted_token_ids [padded] int32 <- init + [4] sorted_expert_ids [blocks+1] int32 <- init + [5] num_valid_ids [2] int32 <- init + [6] output [M, top_k, N2] bf16 + [7] top_k (scalar) + ... + [11] block_m (scalar) + """ + + op_patterns = ["ck_moe_stage1", "ck_moe_stage2"] + + def initialize(self, replayer: Any, **kwargs) -> Optional[str]: + try: + import numpy as np + except ImportError: + return "[custom init] MoeRoutingInit skipped — numpy not available" + + distribution = kwargs.get("moe_distribution", "uniform") + zipf_s = kwargs.get("moe_zipf_s", 1.2) + + args = replayer.args + op_name = replayer.event.get("name", "") + + # Locate args by name from the IR when available, fall back to position + ir = replayer.event_replay_IR + arg_names = [a["arg_name"] for a in ir["list_pos_args"]] + def _by_name_or_pos(name, pos): + if name in arg_names: + return args[arg_names.index(name)] + return args[pos] + + sorted_token_ids = _by_name_or_pos("sorted_token_ids", 3) + sorted_expert_ids = _by_name_or_pos("sorted_expert_ids", 4) + num_valid_ids = _by_name_or_pos("num_valid_ids", 5) + top_k = int(_by_name_or_pos("topk", 7)) + block_m_val = _by_name_or_pos("block_m", 11) + block_m = int(block_m_val) if block_m_val is not None else 32 + + hidden = _by_name_or_pos("hidden_states", 0) + M = hidden.shape[0] + w1 = _by_name_or_pos("w1", 1) + E = w1.shape[0] + num_tokens = M * top_k + padded_total = sorted_token_ids.shape[0] + num_blocks = sorted_expert_ids.shape[0] + + rng = np.random.default_rng(42) + + if distribution == "zipf": + ranks = np.arange(1, E + 1, dtype=np.float64) + weights = 1.0 / np.power(ranks, zipf_s) + probs = weights / weights.sum() + expert_assignments = rng.choice(E, size=num_tokens, p=probs) + else: + expert_assignments = rng.integers(0, E, size=num_tokens) + + token_ids_list: list = [] + expert_ids_list: list = [] + for expert_id in range(E): + tokens_for_expert = np.where(expert_assignments == expert_id)[0] + count = len(tokens_for_expert) + if count == 0: + continue + padded_count = ((count + block_m - 1) // block_m) * block_m + n_blocks_for_expert = padded_count // block_m + padded_tokens = np.full(padded_count, num_tokens, dtype=np.int32) + padded_tokens[:count] = tokens_for_expert // top_k + token_ids_list.append(padded_tokens) + expert_ids_list.extend([expert_id] * n_blocks_for_expert) + + all_token_ids = ( + np.concatenate(token_ids_list) + if token_ids_list + else np.array([], dtype=np.int32) + ) + + if len(all_token_ids) < padded_total: + padding = np.full( + padded_total - len(all_token_ids), num_tokens, dtype=np.int32 + ) + all_token_ids = np.concatenate([all_token_ids, padding]) + else: + all_token_ids = all_token_ids[:padded_total] + + all_expert_ids = np.array(expert_ids_list, dtype=np.int32) + if len(all_expert_ids) < num_blocks: + padding = np.zeros(num_blocks - len(all_expert_ids), dtype=np.int32) + all_expert_ids = np.concatenate([all_expert_ids, padding]) + else: + all_expert_ids = all_expert_ids[:num_blocks] + + import torch + + sorted_token_ids.copy_( + torch.from_numpy(all_token_ids).to(sorted_token_ids.device) + ) + sorted_expert_ids.copy_( + torch.from_numpy(all_expert_ids).to(sorted_expert_ids.device) + ) + + valid_count = min( + len(np.concatenate(token_ids_list)) if token_ids_list else 0, + padded_total, + ) + if num_valid_ids.numel() >= 1: + num_valid_ids[0] = valid_count + if num_valid_ids.numel() >= 2: + num_valid_ids[1] = valid_count + + dist_label = f"zipf(s={zipf_s})" if distribution == "zipf" else "uniform" + experts_active = len(set(expert_assignments.tolist())) + return ( + f"[custom init] {op_name} — initialized MoE routing: " + f"dist={dist_label}, M={M}, top_k={top_k}, E={E}, block_m={block_m}, " + f"num_tokens={num_tokens}, active_experts={experts_active}/{E}, " + f"valid_ids={valid_count}/{padded_total}, " + f"blocks={len(expert_ids_list)}/{num_blocks}. " + f"Assumptions: {dist_label} expert distribution, deterministic seed." + ) diff --git a/TraceLens/EventReplay/event_replay.py b/TraceLens/EventReplay/event_replay.py index 685190b1a..bc26a13b6 100644 --- a/TraceLens/EventReplay/event_replay.py +++ b/TraceLens/EventReplay/event_replay.py @@ -20,28 +20,56 @@ build_tensor, list_profile_tensor_types, ) +from .custom_inits import CustomInit, PagedAttentionInit, MoeRoutingInit logger = logging.getLogger(__name__) -# ── Known defaults for string arguments the profiler drops ────────────── -# The PyTorch profiler records `str` arguments as empty strings. When we -# know the only sensible default we fill it in automatically and warn. -# Key = argument name, Value = default string value. +# -- Known defaults for string arguments the profiler drops ---------------- _STR_ARG_DEFAULTS: Dict[str, str] = { "kv_cache_dtype": "auto", } -# ── Op-name aliases ───────────────────────────────────────────────────── -# Some frameworks profile an op under one namespace but register the -# actual callable under a different one. -# Key = name as it appears in the trace, Value = list of candidates to try. -# NOTE: aiter::paged_attention_v1/v2 are NOT aliasable — the aiter JIT -# wrapper records a different arg layout than the underlying _C:: / _rocm_C:: -# schemas, so arg mapping would fail even if resolution succeeds. +# -- Op-name aliases ------------------------------------------------------- _OP_NAME_ALIASES: Dict[str, List[str]] = { "_rocm_C::wvSplitK": ["_rocm_C::wvSpltK"], } +# -- Auto-import registry -------------------------------------------------- +_NAMESPACE_IMPORTS: Dict[str, List[str]] = { + "aiter": ["aiter"], + "_rocm_C": ["vllm._rocm_C"], + "_C": ["vllm._C"], + "_C_cache_ops": ["vllm._C"], + "vllm": ["vllm._C", "vllm._rocm_C"], +} + +_auto_import_attempted: set = set() + + +def _try_auto_import(op_name: str, verbose: bool = False) -> bool: + """Try to import the library that registers a custom op's schema. + + Returns True if at least one new module was successfully imported. + """ + namespace = op_name.split("::")[0] if "::" in op_name else "" + if not namespace or namespace == "aten": + return False + if namespace in _auto_import_attempted: + return False + _auto_import_attempted.add(namespace) + + modules = _NAMESPACE_IMPORTS.get(namespace, [namespace]) + imported_any = False + for mod in modules: + try: + __import__(mod) + print(f"[EventReplayer] Auto-imported '{mod}' for op '{op_name}'") + imported_any = True + except ImportError: + if verbose: + print(f"[EventReplayer] Could not import '{mod}' for namespace '{namespace}'") + return imported_any + def _try_resolve(op_name: str): """Attempt JIT + torch.ops + module resolution for a single op name. @@ -50,7 +78,7 @@ def _try_resolve(op_name: str): torch = _get_torch_or_raise() import importlib - # 1. JIT registry — preserves dispatch behaviour for aten ops. + # 1. JIT registry try: func, _ = torch._C._jit_get_operation(op_name) if func is not None: @@ -61,16 +89,14 @@ def _try_resolve(op_name: str): if "::" in op_name: ns, func_name = op_name.split("::", 1) - # 2. torch.ops namespace — custom ops registered via torch.library. + # 2. torch.ops namespace ns_obj = getattr(torch.ops, ns, None) if ns_obj is not None: func_obj = getattr(ns_obj, func_name, None) if callable(func_obj): return func_obj, "torch.ops" - # 3. Direct Python module lookup — handles JIT-compiled ops (e.g. - # aiter) that exist as Python callables but aren't registered in - # the torch op registry. + # 3. Direct Python module lookup try: mod = importlib.import_module(ns) func_obj = getattr(mod, func_name, None) @@ -82,82 +108,92 @@ def _try_resolve(op_name: str): return None, None -def _resolve_op_func(op_name: str): - """ - Resolve an op name (e.g. 'aten::mm', 'vllm::rocm_unquantized_gemm') to a - callable. Tries multiple resolution strategies: - - 1. JIT registry (preserves original dispatch behaviour). - 2. torch.ops namespace (custom ops registered via torch.library / pybind). - 3. Known aliases from _OP_NAME_ALIASES (handles trace-name mismatches). +def _resolve_op_func(op_name: str, verbose: bool = False): + """Resolve an op name to a callable, with auto-import on failure. + Tries: JIT registry -> torch.ops -> module import -> auto-import -> aliases. Returns (func, source_str, resolved_name) or raises RuntimeError. """ - func, source = _try_resolve(op_name) - if func is not None: - return func, source, op_name - - for alias in _OP_NAME_ALIASES.get(op_name, []): - func, source = _try_resolve(alias) + for attempt in range(2): + func, source = _try_resolve(op_name) if func is not None: - logger.warning( - "Op '%s' resolved via alias '%s' (%s). " - "The trace recorded a different namespace than the runtime registration.", - op_name, alias, source, - ) - return func, source, alias + return func, source, op_name + + for alias in _OP_NAME_ALIASES.get(op_name, []): + func, source = _try_resolve(alias) + if func is not None: + logger.warning( + "Op '%s' resolved via alias '%s' (%s).", + op_name, alias, source, + ) + return func, source, alias + + if attempt == 0 and _try_auto_import(op_name, verbose): + continue + break + + ns = op_name.split("::")[0] if "::" in op_name else "" + hint = "" + if ns and ns != "aten": + known = _NAMESPACE_IMPORTS.get(ns) + if known: + hint = f" Try: {', '.join(f'import {m}' for m in known)}" + else: + hint = (f" The namespace '{ns}' is not in the auto-import registry." + f" Use EventReplayer.register_namespace('{ns}', ['your.module'])" + f" to add it.") raise RuntimeError( - f"Cannot resolve op '{op_name}'. Ensure the library that defines it " - f"is imported (e.g. 'import vllm', 'import aiter')." + f"Cannot resolve op '{op_name}'.{hint} " + f"Ensure the library that defines it is imported." ) def _search_schemas(op_name: str, verbose: bool = False): - """ - Return all registered FunctionSchemas for *op_name*. - - Searches both the JIT schema registry and the torch.ops namespace, which - covers aten ops, custom C++ ops, and Python-defined torch.library ops. + """Return all registered FunctionSchemas for *op_name*, + with auto-import on empty results. """ torch = _get_torch_or_raise() - schemas: list = [] - seen_strs: set = set() - - # JIT registry - for s in torch._C._jit_get_all_schemas(): - if s.name == op_name: - s_str = str(s) - if s_str not in seen_strs: - schemas.append(s) - seen_strs.add(s_str) - - # torch.ops namespace (catches custom ops not in the JIT list) - if "::" in op_name: - ns, func_name = op_name.split("::", 1) - ns_obj = getattr(torch.ops, ns, None) - if ns_obj is not None: - op_obj = getattr(ns_obj, func_name, None) - if op_obj is not None: - # OpOverloadPacket exposes overloads - try: - for overload_name in op_obj.overloads(): - overload = getattr(op_obj, overload_name) - s = overload._schema - s_str = str(s) - if s_str not in seen_strs: - schemas.append(s) - seen_strs.add(s_str) - except Exception: - # Fallback: try .default directly + + for attempt in range(2): + schemas: list = [] + seen_strs: set = set() + + for s in torch._C._jit_get_all_schemas(): + if s.name == op_name: + s_str = str(s) + if s_str not in seen_strs: + schemas.append(s) + seen_strs.add(s_str) + + if "::" in op_name: + ns, func_name = op_name.split("::", 1) + ns_obj = getattr(torch.ops, ns, None) + if ns_obj is not None: + op_obj = getattr(ns_obj, func_name, None) + if op_obj is not None: try: - s = op_obj.default._schema - s_str = str(s) - if s_str not in seen_strs: - schemas.append(s) - seen_strs.add(s_str) + for overload_name in op_obj.overloads(): + overload = getattr(op_obj, overload_name) + s = overload._schema + s_str = str(s) + if s_str not in seen_strs: + schemas.append(s) + seen_strs.add(s_str) except Exception: - pass + try: + s = op_obj.default._schema + s_str = str(s) + if s_str not in seen_strs: + schemas.append(s) + seen_strs.add(s_str) + except Exception: + pass + + if schemas or attempt > 0: + break + if not _try_auto_import(op_name, verbose): + break if verbose: print(f"Found {len(schemas)} schemas for {op_name}:") @@ -169,12 +205,34 @@ def _search_schemas(op_name: str, verbose: bool = False): class EventReplayer: + _custom_init_registry: List[CustomInit] = [ + PagedAttentionInit(), + MoeRoutingInit(), + ] + + @classmethod + def register_custom_init(cls, init: CustomInit): + """Add a custom initializer to the registry.""" + cls._custom_init_registry.append(init) + + @classmethod + def register_namespace(cls, namespace: str, modules: List[str]): + """Register a namespace-to-module mapping for auto-import.""" + _NAMESPACE_IMPORTS[namespace] = modules + + @classmethod + def list_custom_inits(cls) -> List[Tuple[str, List[str]]]: + """List all registered custom initializers and their op patterns.""" + return [(type(i).__name__, i.op_patterns) for i in cls._custom_init_registry] + def __init__( self, event: Dict[str, Any], device: str = "cuda", lazy: bool = False, verbose: bool = False, + auto_init: bool = True, + init_kwargs: Optional[Dict[str, Any]] = None, ): """ Initialize the EventReplayer with the event data and device type. @@ -182,12 +240,21 @@ def __init__( Args: event (Dict[str, Any]): From the pytorch profile json data['traceEvents'] device (str): The device type ('cuda' or 'cpu'). + lazy (bool): If True, defer tensor creation until replay(). verbose (bool): Flag to enable verbose output. + auto_init (bool): If True, automatically apply custom initializers + for ops that need realistic tensor content (e.g., paged attention + block tables, MoE routing tensors). + init_kwargs (Dict[str, Any]): Parameters passed to custom initializers + (e.g., {"moe_distribution": "zipf", "moe_zipf_s": 1.5}). """ self.event = event self.device = device self.lazy = lazy self.verbose = verbose + self._auto_init = auto_init + self._init_kwargs = init_kwargs or {} + self._inits_applied = False self._func = None self._setup() @@ -199,7 +266,7 @@ def _setup(self): print(f"Preparing {self.event['name']} event for replay") self._func, self._func_source, self._resolved_name = _resolve_op_func( - self.event["name"] + self.event["name"], verbose=self.verbose ) if self.verbose: print(f"Resolved op via {self._func_source}") @@ -246,8 +313,25 @@ def replay(self): else: args, kwargs = self.args, self.kwargs + if not self._inits_applied and self._auto_init: + self._apply_custom_inits() + self._func(*args, **kwargs) + def _apply_custom_inits(self): + """Run all applicable custom initializers on this replayer's tensors.""" + for custom_init in self._custom_init_registry: + if custom_init.applies_to(self): + try: + summary = custom_init.initialize(self, **self._init_kwargs) + if summary: + print(summary) + except Exception as e: + warnings.warn( + f"[custom init] {type(custom_init).__name__} failed: {e}" + ) + self._inits_applied = True + @staticmethod def _search_schema( event: Dict[str, Any], @@ -297,7 +381,6 @@ def _is_schema_match( print(f"\tProfiled type: {profiled_type}") is_match = True - # Optional types: schema ends with '?' => profiled type can be blank if schema_type.endswith("?"): schema_type = schema_type[:-1] if profiled_type == "": @@ -312,7 +395,7 @@ def _is_schema_match( is_match = False elif schema_type == "bool": profiled_value = event["args"]["Concrete Inputs"][idx] - if profiled_value.lower() not in ["true", "false"]: + if profiled_value.lower() not in ("true", "false"): is_match = False elif schema_type in ("int", "SymInt"): if profiled_type != "Scalar": @@ -383,27 +466,17 @@ def _is_tensor_schema_type(schema_type: str) -> bool: """Check if a schema type string represents a Tensor argument.""" if schema_type in ("Tensor", "Tensor?"): return True - # Handles annotated variants like Tensor(a!), Tensor(a), Tensor(b!) if schema_type.startswith("Tensor("): return True return False @staticmethod def _should_skip_tensor_init(evt_name: str, arg_name: str, arg_idx: int) -> bool: - """ - Determine whether a tensor argument is an output-only buffer that - does not need random initialization. - - Generalizes the old aten::fill_ / aten::copy_ special-cases to - any in-place or out-of-place output tensor. - """ - # In-place ops (name ends with '_'): first tensor is the mutated output + """Determine whether a tensor argument is an output-only buffer.""" if evt_name.endswith("_") and arg_idx == 0: return True - # Explicit 'out' arguments in .out variants if arg_name == "out": return True - # aten::copy_ destination if evt_name == "aten::copy_" and arg_name != "src": return True return False @@ -412,9 +485,7 @@ def _should_skip_tensor_init(evt_name: str, arg_name: str, arg_idx: int) -> bool def _get_event_replay_IR( event: Dict[str, Any], schema: "torch._C.FunctionSchema", verbose: bool = False ) -> Dict[str, Any]: - """ - Get the event replay IR from the event and schema. - """ + """Get the event replay IR from the event and schema.""" evt_name = event["name"] op_name, pos_args_schema, kwargs_schema, return_type = ( EventReplayer.parse_schema_string(schema) @@ -457,7 +528,6 @@ def _get_event_replay_IR( if EventReplayer._should_skip_tensor_init(evt_name, arg_name, idx): init = None profiled_dtype = event["args"]["Input type"][idx] - # Non-floating-point tensors cannot use 'normal' init if profiled_dtype in ("long", "long int", "int", "bool", "unsigned char"): init = "zeros" if init == "normal" else init value = TensorCfg( @@ -534,17 +604,7 @@ def _get_event_replay_IR_schemaless( verbose: bool = False, resolved_name: Optional[str] = None, ) -> Dict[str, Any]: - """ - Build a replay IR without a schema by inferring types directly from the - profiled data. All arguments are treated as positional. - - Heuristics: - - If Input type is a known tensor dtype -> TensorCfg - - If Input type is 'Scalar' and Concrete Inputs looks like int -> int - - If Input type is 'Scalar' and Concrete Inputs looks like float -> float - - If Input type is 'Scalar' and Concrete Inputs is true/false -> bool - - If Input type is '' and Concrete Inputs is '' -> check _STR_ARG_DEFAULTS - """ + """Build a replay IR without a schema by inferring types from profile data.""" evt_name = event["name"] schema_arg_names = EventReplayer._get_schema_arg_names( resolved_name or evt_name @@ -592,7 +652,6 @@ def _get_event_replay_IR_schemaless( value = concrete arg_type = "str" elif profiled_type == "" and concrete == "": - # Likely a dropped str arg — check known defaults hint_name = ( schema_arg_names[idx] if idx < len(schema_arg_names) else None ) @@ -637,9 +696,7 @@ def _get_event_replay_IR_schemaless( def _get_args_kwargs( event_replay_IR: Dict[str, Any], device: str = "cuda" ) -> tuple[List["torch.Tensor"], Dict[str, Any]]: - """ - Get the arguments and keyword arguments from the event replay IR. - """ + """Get the arguments and keyword arguments from the event replay IR.""" pos_args = [] for arg in event_replay_IR["list_pos_args"]: value = arg["value"] @@ -670,7 +727,6 @@ def parse_schema_string( kwarg_part = parts[1].lstrip(",").strip() if len(parts) > 1 else "" def _parse_arg(raw_arg: str) -> Tuple[str, str, Optional[str], bool]: - # Match type (may contain spaces, e.g. "Tensor($0! -> )") then name[=default]. # Greedy (.+) consumes everything up to the last whitespace before # a valid identifier, so "Tensor($0! -> ) key_cache" parses correctly. m = re.match( @@ -705,10 +761,6 @@ def _parse_arg(raw_arg: str) -> Tuple[str, str, Optional[str], bool]: def get_repro_info(self) -> Dict[str, Any]: """ Extracts the minimal, serializable information needed to reproduce the event call. - - Returns: - Dict[str, Any]: A dictionary containing the operator name and the replay IR. - Suitable for JSON serialization using the custom encoder. """ dict_repro_info = {} dict_repro_info["op_name"] = self.event["name"] diff --git a/TraceLens/EventReplay/utils.py b/TraceLens/EventReplay/utils.py index a4d92ebca..9ea2f1764 100644 --- a/TraceLens/EventReplay/utils.py +++ b/TraceLens/EventReplay/utils.py @@ -124,32 +124,74 @@ def summarize_tensor(tensor: "torch.Tensor") -> str: return f"Tensor(shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}, strides={tensor.stride()})" -def benchmark_func(func, device, warmup=50, avg_steps=100): - """ - Benchmark a function with warmup and average steps. - Disclaimer: This method would be inaccurate for very short ops. +_L2_FLUSH_BUFFER = None +_L2_FLUSH_SIZE = 256 * 1024 * 1024 # 256 MB -- larger than any GPU's L2 + + +def _flush_l2(device: str): + """Force-evict GPU L2 cache by reading a large buffer.""" + global _L2_FLUSH_BUFFER + torch = _get_torch_or_raise() + if _L2_FLUSH_BUFFER is None or str(_L2_FLUSH_BUFFER.device) != device: + _L2_FLUSH_BUFFER = torch.empty( + _L2_FLUSH_SIZE // 4, dtype=torch.float32, device=device + ) + _L2_FLUSH_BUFFER.sum() + + +def benchmark_func( + func, + device, + warmup=50, + avg_steps=100, + flush_l2=False, +): + """Benchmark a function with warmup and per-iteration CUDA event timing. + Args: - func (callable): The function to benchmark. - warmup (int): Number of warmup iterations. - avg_steps (int): Number of iterations to average over. + func: Callable to benchmark. + device: CUDA device string. + warmup: Number of warmup iterations. + avg_steps: Number of measured iterations. + flush_l2: If True, flush the GPU L2 cache before each measured iteration + to simulate cold-cache conditions (more representative of real + inference where other kernels pollute L2 between invocations). + Returns: - float: Average time taken per iteration in microseconds. + dict with keys: median_us, mean_us, std_us, min_us, max_us, + all_us (list of per-iteration timings in microseconds). """ torch = _get_torch_or_raise() - # Warmup phase + for _ in range(warmup): func() - - # Benchmarking phase torch.cuda.synchronize(device) - start_time = time.time() + + timings_ms: List[float] = [] for _ in range(avg_steps): + if flush_l2: + _flush_l2(device) + torch.cuda.synchronize(device) + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() func() - torch.cuda.synchronize(device) - end_time = time.time() - - elapsed_time = end_time - start_time - avg_time_sec = elapsed_time / avg_steps - avg_time_us = avg_time_sec * 1e6 - - return avg_time_us + end_evt.record() + torch.cuda.synchronize(device) + timings_ms.append(start_evt.elapsed_time(end_evt)) + + timings_us = [t * 1000.0 for t in timings_ms] + sorted_us = sorted(timings_us) + n = len(sorted_us) + median = (sorted_us[n // 2] + sorted_us[(n - 1) // 2]) / 2.0 + mean = sum(timings_us) / n + variance = sum((t - mean) ** 2 for t in timings_us) / n + std = variance ** 0.5 + return { + "median_us": median, + "mean_us": mean, + "std_us": std, + "min_us": sorted_us[0], + "max_us": sorted_us[-1], + "all_us": timings_us, + } diff --git a/docs/EventReplay.md b/docs/EventReplay.md index bf6f4fb79..1602e2507 100644 --- a/docs/EventReplay.md +++ b/docs/EventReplay.md @@ -6,49 +6,167 @@ See LICENSE for license information. # Event Replay -Optimizing GPU performance in deep learning requires isolating and benchmarking individual operations to identify bottlenecks. However, reproducing operations directly from complex model code or large profiles can be cumbersome. - -Event Replay is a Python-based tool within TraceLens that extracts and replays almost arbitrary PyTorch operations using minimal, portable Intermediate Representation (IR). It enables users to easily reproduce, analyze, and benchmark specific operators independently from the original model execution, streamlining performance optimization workflows. - ---- - -## Key Features - -- **Generic Operator Replay**: Reconstructs and benchmarks any PyTorch operator from profile data, including convolutions, GEMMs, reductions, element-wise operations, and more. -- **Minimalistic IR**: Extracts essential operator attributes (tensor shapes, strides, dtypes, and other arguments) into a lightweight, portable JSON-based IR. -- **Portable Artifacts**: Enables sharing standalone artifacts (JSON IR and scripts) with teammates or upstream repositories without requiring access to the model or TraceLens. +Optimizing GPU performance in deep learning requires isolating and benchmarking +individual operations to identify bottlenecks. However, reproducing operations +directly from complex model code or large profiles can be cumbersome — the +profiler captures tensor dimensions and types but strips argument names, semantic +context, and the relationship between arguments. + +Event Replay is a Python-based tool within TraceLens that extracts and replays +almost arbitrary PyTorch operations using minimal, portable Intermediate +Representation (IR). It translates the opaque profiler output into human-readable +JSON — named arguments, tensor shapes, dtypes, strides, and scalar values — +making profiler traces interpretable, shareable, and replayable on any machine +with the right op libraries installed. + +**Contents:** +[Quick Start](#quick-start) | +[Batch Replay](#batch-replay) | +[Architecture](#architecture) | +[Custom Initializers](#custom-initializers) | +[Auto-Import](#auto-import-for-custom-ops) | +[Batch Context](#batch-context-vllm-traces) | +[Validation](#validation) | +[Limitations](#known-limitations) | +[Use Cases](#use-cases) --- - ## Quick Start -### Example: Replay a Single Event +### Replay a Single Event ```python -from TraceLens import TreePerfAnalyzer, EventReplayer +from TraceLens import TreePerfAnalyzer +from TraceLens.EventReplay import EventReplayer -# Load profile and get event perf_analyzer = TreePerfAnalyzer.from_file('/path/to/profile.json') -uid = 12345 # Replace with actual UID of interest +uid = 12345 event = perf_analyzer.tree.get_UID2event(uid) -# Initialize and replay replayer = EventReplayer(event, device='cuda') replayer.replay() ``` +### Inspect the IR (without replaying) + +Even without a GPU, the extracted IR is valuable for understanding what a +profiled op actually does. The profiler's native format stores arguments as +unlabeled dimension lists with no argument names or semantic context: + +```json +{ + "cat": "cpu_op", "name": "aten::mm", + "args": { + "Input Dims": [[20, 2048], [2048, 11264]], + "Input type": ["BFloat16", "BFloat16"] + } +} +``` + +What are these two tensors? Which is the activation, which is the weight? Is +`mat2` transposed? You can't tell from `Input Dims` alone. Event Replay resolves +the op's registered schema and produces named, typed JSON: + +```python +replayer = EventReplayer(event, lazy=True) +ir = replayer.get_repro_info() +``` + +**After — Event Replay IR for `aten::mm`:** + +```json +{ + "op_name": "aten::mm", + "replay_ir": { + "list_pos_args": [ + { + "arg_name": "self", + "arg_type": "Tensor", + "value": { "shape": [20, 2048], "dtype": "c10::BFloat16", + "strides": [2048, 1], "init": "normal" } + }, + { + "arg_name": "mat2", + "arg_type": "Tensor", + "value": { "shape": [2048, 11264], "dtype": "c10::BFloat16", + "strides": [1, 2048], "init": "normal" } + } + ] + } +} +``` + +Now you can immediately read: BF16 GEMM, M=20, K=2048, N=11264, `mat2` +is column-major (stride pattern `[1, K]`). + +The contrast is sharper for complex ops. Here's what the profiler gives you +for a MoE fused expert call: + +**Raw profiler — `aiter::ck_moe_stage1`:** + +```json +{ + "cat": "cpu_op", "name": "aiter::ck_moe_stage1", + "args": { + "Input Dims": [[2, 2048], [60, 2816, 2048], [60, 2048, 1408], + [1924], [61], [2], [2, 4, 1408], [], [], [], [], [], [], [], [], [], [], []], + "Input type": ["BFloat16", "BFloat16", "BFloat16", + "Int", "Int", "Int", "BFloat16", + "Scalar", "Scalar", "Scalar", "Scalar", "Scalar", "Scalar", + "Scalar", "Scalar", "Scalar", "Scalar", "Scalar"] + } +} +``` + +18 arguments, most labeled just "Scalar" — which one is `topk`? Which is +`block_m`? What does `[1924]` represent? Uninterpretable without reading the +source code. + +**After — Event Replay IR:** + +```json +{ + "op_name": "aiter::ck_moe_stage1", + "replay_ir": { + "list_pos_args": [ + { "arg_name": "hidden_states", "arg_type": "Tensor", + "value": { "shape": [2, 2048], "dtype": "c10::BFloat16" } }, + { "arg_name": "w1", "arg_type": "Tensor", + "value": { "shape": [60, 2816, 2048], "dtype": "c10::BFloat16" } }, + { "arg_name": "w2", "arg_type": "Tensor", + "value": { "shape": [60, 2048, 1408], "dtype": "c10::BFloat16" } }, + { "arg_name": "sorted_token_ids", "arg_type": "Tensor", + "value": { "shape": [1924], "dtype": "int", "init": "zeros" } }, + { "arg_name": "sorted_expert_ids", "arg_type": "Tensor", + "value": { "shape": [61], "dtype": "int", "init": "zeros" } }, + { "arg_name": "num_valid_ids", "arg_type": "Tensor", + "value": { "shape": [2], "dtype": "int", "init": "zeros" } }, + { "arg_name": "out", "arg_type": "Tensor", + "value": { "shape": [2, 4, 1408], "dtype": "c10::BFloat16", "init": null } }, + { "arg_name": "topk", "arg_type": "SymInt", "value": 4 }, + { "arg_name": "block_m", "arg_type": "SymInt?", "value": 32 }, + { "arg_name": "use_non_temporal_load", "arg_type": "bool", "value": true } + ] + } +} +``` + +Now you can read: 2 tokens routed to top-4 of 60 experts, gate hidden dim 2048, +up-projection to 2816, down-projection through 1408, `[1924]` is +`sorted_token_ids` (the routing table), block tile size 32, NTL enabled. + --- -## Batch Replay and Benchmark +## Batch Replay -### Extract Operator IR from TraceLens Profiles +### Extract IR for Multiple Events ```python import json -# Extract replay IR for events of interest -repro_data = [EventReplayer(event, lazy=True).get_repro_info() for event in events_of_interest] +repro_data = [EventReplayer(event, lazy=True).get_repro_info() + for event in events_of_interest] with open('event_replay_ir.json', 'w') as f: json.dump(repro_data, f, indent=4) @@ -77,47 +195,21 @@ python batched_replay.py event_replay_ir.json Average time taken: 100.38 microseconds Successfully executed aten::convolution. Result: Tensor(shape=torch.Size([20, 256, 14, 14]), dtype=torch.bfloat16, device=cuda:0) - -[8/11] Replaying: aten::convolution - Reconstructing arguments for 'aten::convolution'... - Positional Args: - input Tensor: {'shape': [20, 256, 14, 14], 'dtype': 'c10::BFloat16', 'strides': [50176, 196, 14, 1]} - weight Tensor: {'shape': [512, 256, 3, 3], 'dtype': 'c10::BFloat16', 'strides': [2304, 9, 3, 1]} - bias Tensor?: None - stride SymInt[]: [2, 2] - padding SymInt[]: [1, 1] - dilation SymInt[]: [1, 1] - transposed bool: False - output_padding SymInt[]: [0, 0] - groups SymInt: 1 - Keyword Args: - Average time taken: 92.83 microseconds - Successfully executed aten::convolution. - Result: Tensor(shape=torch.Size([20, 512, 7, 7]), dtype=torch.bfloat16, device=cuda:0) ... --- Replay Summary --- Total operations in file: 11 Attempted replays: 11 Successful replays: 11 Errors encountered: 0 - ``` -------------------- -### Creating Standalone Replay Artifacts -You can optionally package the extracted replay IR and scripts into a standalone zip file for easy sharing and reproduction, independent of the original model code or TraceLens repository. - -Artifacts included: -- `event_replay_ir.json`: Serialized operator replay instructions. -- `utils.py`: Tensor creation and helper utilities. -- `batched_replay.py`: Script to batch replay and benchmark operations. -- `batched_replay_readme.md`: Instructions for running the replay. +### Creating Standalone Replay Artifacts -Example packaging code: +Package the IR and scripts into a standalone zip for sharing and reproduction, +independent of the original model code or TraceLens: ```python -import zipfile -import os +import zipfile, os from TraceLens.EventReplay import utils as tl_utils from TraceLens.EventReplay import batched_replay @@ -128,25 +220,282 @@ files = [ batched_replay.__file__.replace('batched_replay.py', 'batched_replay_readme.md') ] -zip_file_path = '/path/to/replay_code.zip' -with zipfile.ZipFile(zip_file_path, 'w') as zipf: +with zipfile.ZipFile('/path/to/replay_code.zip', 'w') as zipf: for file in files: zipf.write(file, arcname=os.path.basename(file)) +``` + +--- + +## Architecture + +Event Replay operates in two distinct phases: + +### Phase 1: IR Extraction (deterministic) + +When an `EventReplayer` is constructed from a profiler event, it resolves the +op's registered schema and extracts a complete Intermediate Representation: + +- **Op name** — the fully qualified operator name (e.g., `aten::mm`, `_rocm_C::paged_attention`) +- **Argument metadata** — for each positional and keyword argument: + - Tensors: shape, dtype, strides, initialization hint + - Scalars: concrete value (int, float, bool, str) + - Lists: element values + - Optionals: `null` when not provided + +This phase is purely mechanical: the same profiler event always produces the same +IR. The output is a portable JSON dictionary. + +**Prerequisite — ops must be registered with the PyTorch dispatcher.** If an op +is called as a plain Python function (e.g., a Triton kernel launched directly), +there is no schema for the profiler to capture and no IR can be extracted. The +op must go through `torch.ops`, `torch.library`, or the JIT registry. This is +why aiter's CK-based ops (`aiter::ck_moe_stage1`) produce full IR while direct +Triton kernel calls do not. See +[Shape Metadata Guide](conceptual/shape_metadata_guide.md) for details on +registering ops that lack schemas. + +### Phase 2: Init & Replay (requires judgment) + +Given an IR, Event Replay: + +1. **Allocates tensors** matching the recorded shapes, dtypes, and strides +2. **Initializes values** — `randn` for floating-point tensors, `zeros` for + integer/bool tensors, `None` for optional arguments +3. **Resolves the op** to a callable function (via JIT registry, `torch.ops`, + or direct module import) +4. **Calls the op** and optionally benchmarks it + +The default initialization works well for compute-bound ops like GEMMs and +convolutions, where kernel performance is independent of input values. However, +**control and index tensors require realistic values** — zeroed-out metadata +produces behavior that is not representative of the true workload: + +| Op family | Affected tensors | Effect of zeros | +|-----------|-----------------|-----------------| +| Paged Attention | `block_tables`, `seq_lens`, `query_start_loc` | Kernel sees 0 context length — does no real work | +| MoE Routing | `sorted_token_ids`, `sorted_expert_ids`, `num_valid_ids` | Kernel sees 0 valid tokens — skips all computation | -print(f"Created zip file: {zip_file_path}") +A true reproduction would require the exact tensor values from the original +execution, but the profiler doesn't capture tensor contents — only shapes and +dtypes. **Custom Initializers** bridge this gap by constructing plausible values +from shapes and metadata already in the IR, without additional instrumentation. + +--- + +## Custom Initializers + +Custom initializers fill metadata tensors with realistic values before replay. +They are applied automatically when `auto_init=True` (the default). + +### Built-in Initializers + +These ship with TraceLens and require no setup — they activate automatically +when the op name matches: + +**`PagedAttentionInit`** — matches `paged_attention` + +Initializes the KV cache metadata so the attention kernel does real work: +- `block_tables` — random permutation of the physical block pool (simulates + realistic scattered memory allocation) +- `seq_lens` — all sequences set to `max_seq_len` +- `query_start_loc` — CSR indptr encoding per-sequence query token counts. + When batch context is available (see [Batch Context](#batch-context-vllm-traces)), + uses the exact prefill/decode split; otherwise falls back to heuristics + +**`MoeRoutingInit`** — matches `ck_moe_stage1`, `ck_moe_stage2` + +Constructs a complete token-to-expert routing table: +- `sorted_token_ids` — padded to `block_m` boundaries per expert +- `sorted_expert_ids` — block-level expert assignment +- `num_valid_ids` — total valid (non-padding) token slots + +Supports configurable token distribution via `init_kwargs`: + +```python +# Default: uniform random assignment across experts +replayer = EventReplayer(event, device='cuda') + +# Zipf: skewed distribution (few experts get most tokens, closer to real routing) +replayer = EventReplayer(event, device='cuda', + init_kwargs={"moe_distribution": "zipf", "moe_zipf_s": 1.5}) ``` + +### Disabling Built-in Inits + +```python +replayer = EventReplayer(event, device='cuda', auto_init=False) +replayer.replay() +``` + +### Adding Your Own Initializer + +For ops not covered by the built-ins, you can write and register your own. + +When `replay()` runs with `auto_init=True`, it iterates over registered +initializers and calls `initialize()` on the first one whose `op_patterns` +match the op name. The initializer mutates `replayer.args` / `replayer.kwargs` +**in-place** before the op is called. + +The initializer can access: + +- `replayer.args` — list of allocated tensors/scalars (in schema order) +- `replayer.kwargs` — dict of keyword arguments +- `replayer.event` — the raw profiler event dict +- `replayer.event_replay_IR` — the extracted IR with named argument metadata + +Arguments should be looked up **by name** from the IR rather than hardcoded +positions, making initializers robust to schema changes across library versions: + +```python +ir = replayer.event_replay_IR +arg_names = [a["arg_name"] for a in ir["list_pos_args"]] + +def _by_name(name, fallback_pos): + if name in arg_names: + return replayer.args[arg_names.index(name)] + return replayer.args[fallback_pos] +``` + +**Real example — `aten::index_add_`:** This op was found as a bottleneck in +gsplat (Gaussian Splatting) on MI325X. The `index` tensor (5M elements) must +contain valid row indices into `self`; the default zero-init makes every source +row accumulate into row 0, which is not representative of the real scatter +pattern: + +```python +from TraceLens.EventReplay import EventReplayer, CustomInit + +class IndexAddInit(CustomInit): + op_patterns = ["index_add"] + + def initialize(self, replayer, **kwargs): + import torch + + ir = replayer.event_replay_IR + arg_names = [a["arg_name"] for a in ir["list_pos_args"]] + + self_tensor = replayer.args[arg_names.index("self")] + dim = replayer.args[arg_names.index("dim")] + index = replayer.args[arg_names.index("index")] + + dim_size = self_tensor.shape[dim] + index.copy_(torch.randint(0, dim_size, index.shape, + device=index.device)) + + return (f"[custom init] index_add — index randint(0, {dim_size}), " + f"shape={list(index.shape)}") + +# Register — all future replays of index_add ops will use this +EventReplayer.register_custom_init(IndexAddInit()) +``` + +After `register_custom_init`, every subsequent `EventReplayer` will +automatically apply `IndexAddInit` when replaying any op whose name contains +`"index_add"`. + +To see what's currently registered: + +```python +EventReplayer.list_custom_inits() +``` + --- -## Use Cases +## Auto-Import for Custom Ops -- **Performance Debugging**: Quickly isolate and reproduce performance issues from large models. -- **Regression Testing**: Automate benchmarks to detect performance regressions at the operator level. -- **Kernel Development**: Extract minimal reproducers for GPU kernel optimization and debugging. -- **Numerical Validation**: Evaluate numerical correctness and stability of isolated operations across hardware. -- **Hardware Counter Profiling**: Use with hardware counters to analyze performance bottlenecks in specific operations. +When EventReplayer encounters an op from a non-`aten` namespace (e.g., +`_rocm_C::paged_attention`, `aiter::ck_moe_stage1`), it automatically attempts +to import the library that registers the op's schema. The import is conditional +— it only fires if the namespace is recognized and hasn't been attempted yet. + +Built-in namespace mappings: + +| Namespace | Imported modules | +|-----------|-----------------| +| `aiter` | `aiter` | +| `_rocm_C` | `vllm._rocm_C` | +| `_C` | `vllm._C` | +| `vllm` | `vllm._C`, `vllm._rocm_C` | + +Register additional namespaces: + +```python +EventReplayer.register_namespace("my_lib", ["my_lib.ops"]) +``` + +--- + +## Batch Context (vLLM traces) + +vLLM annotates each forward step with `user_annotation` events that encode +the exact prefill/decode batch composition (e.g., +`execute_context_1(21)_generation_5(5)`). `extract_batch_context` parses these +annotations and attaches the breakdown to each paged attention event, so +`PagedAttentionInit` can construct an accurate `query_start_loc`. + +```python +from TraceLens.EventReplay import extract_batch_context + +num_annotated = extract_batch_context(perf_analyzer) +print(f"Annotated {num_annotated} paged_attention events with batch context") +``` + +Without batch context, the initializer falls back to a heuristic (uniform +token distribution for prefill, 1 token/seq for decode). + +--- + +## Validation + +Tested on Qwen1.5-MoE-A2.7B (aiter backend, MI300X) — 30 unique op configs: + +| Metric | Result | +|--------|--------| +| Kernel name match | 27/30 (3 `aten::copy_` mismatches — runtime DMA path selection, expected) | +| GPU busy time delta (median, warm cache) | -17% (replay faster — isolation effect) | +| GPU busy time delta range | -32% (large mem-bound GEMMs) to +5% (latency-bound ops) | + +Also validated on Qwen2.5-3B (dense, no MoE): 26/30 match, similar timing profile. + +Replay is systematically faster because isolated single-op execution has no +inter-op cache contention. This is a fundamental property of micro-benchmarks, +not a bug. + +```python +from TraceLens.EventReplay import benchmark_func + +metrics = benchmark_func(replayer.replay, warmup=5, repeat=20) +print(f"Median: {metrics['median_ms']:.3f} ms") +``` --- -## Notes +## Known Limitations + +- **Unregistered ops are invisible.** Triton kernels called directly from Python + (e.g., aiter's Triton attention path) have no schema in the profiler. The fix + is wrapping them in `torch.library.custom_op` — a one-time registration effort + in the upstream library. See + [Shape Metadata Guide](conceptual/shape_metadata_guide.md). -- Event Replay uses randomized data based on extracted tensor shapes; thus, replay timings approximate real-world performance. +- **Single-op isolation vs. real workload.** Replay runs each op in isolation + with no surrounding memory traffic. Timings are a lower bound on in-model + performance. Sequence replay (ops in trace order to reproduce natural cache + pollution) is a planned enhancement. + +- **Data-dependent kernels.** Custom initializers provide plausible but not exact + values from the original execution (the profiler doesn't capture tensor + contents). For most ops this doesn't matter; for ops with data-dependent + control flow (e.g., sparse attention with variable sequence lengths), timing + may vary. + +--- + +## Use Cases + +- **Trace Interpretation**: Translate opaque profiler arguments into named, typed JSON for understanding what each op actually computes. +- **Performance Debugging**: Isolate and reproduce performance issues from large models without running the model. +- **Regression Testing**: Automate benchmarks to detect performance regressions at the operator level. +- **Kernel Development**: Extract minimal reproducers for GPU kernel optimization and debugging. +- **Portable Sharing**: Package IR + replay scripts as standalone zip artifacts for teammates or upstream repos. From 95406f3adb49bc24b56ada6b44a1159c2cbd749e Mon Sep 17 00:00:00 2001 From: Jassani Date: Tue, 28 Apr 2026 15:38:57 -0400 Subject: [PATCH 4/4] Fix bugs, add tests, and improve EventReplay docs Bug fixes: - Fix lazy+auto_init crash: replay() now sets self.args in lazy mode so custom initializers can access them (BUG-1) - Fix get_repro_info() shallow copy corruption: no longer mutates event_replay_IR on repeated calls (BUG-2) - Fix batched_replay.py: handle benchmark_func dict return type, implement --op-filter and --op-limit flags (BUG-3) - replay() now returns the op result instead of None (CLAIM-4) - First-match-wins for custom initializers (CLAIM-1) - Exact name matching for op_patterns (no more substring matching) Tests: - Add CPU-only unit tests (test_event_replay.py, 11 tests) - Add GPU integration tests (test_event_replay_gpu.py) with kernel name validation Docs (EventReplay.md): - Fix benchmark_func example (wrong params and key names) - Remove broken Shape Metadata Guide links - Rewrite custom initializer section as step-by-step guide - Rewrite iteration annotations section with full explanation - Add batch replay CLI flag examples - Update all op_patterns to fully-qualified names --- TraceLens/EventReplay/batched_replay.py | 21 +- TraceLens/EventReplay/custom_inits.py | 7 +- TraceLens/EventReplay/event_replay.py | 45 +-- TraceLens/EventReplay/test_event_replay.py | 204 ++++++++++++++ .../EventReplay/test_event_replay_gpu.py | 261 ++++++++++++++++++ docs/EventReplay.md | 219 +++++++++------ 6 files changed, 641 insertions(+), 116 deletions(-) create mode 100644 TraceLens/EventReplay/test_event_replay.py create mode 100644 TraceLens/EventReplay/test_event_replay_gpu.py diff --git a/TraceLens/EventReplay/batched_replay.py b/TraceLens/EventReplay/batched_replay.py index d948e6d45..97008cb75 100644 --- a/TraceLens/EventReplay/batched_replay.py +++ b/TraceLens/EventReplay/batched_replay.py @@ -99,11 +99,17 @@ def _get_args_kwargs_from_ir( replayed_count = 0 errors = 0 - for i, repro_info in enumerate(repro_data_list): + ops_to_replay = repro_data_list + if args.op_filter: + ops_to_replay = [r for r in ops_to_replay if args.op_filter in r["op_name"]] + if args.op_limit: + ops_to_replay = ops_to_replay[: args.op_limit] + + for i, repro_info in enumerate(ops_to_replay): op_name = repro_info["op_name"] replay_ir = repro_info["replay_ir"] - print(f"\n[{replayed_count + 1}/{len(repro_data_list)}] Replaying: {op_name}") + print(f"\n[{replayed_count + 1}/{len(ops_to_replay)}] Replaying: {op_name}") # Get the PyTorch operation function try: @@ -151,15 +157,16 @@ def _get_args_kwargs_from_ir( errors += 1 continue # --- Benchmark the function --- - mean_time_us = benchmark_func( + metrics = benchmark_func( lambda: func(*pos_args, **kwargs), args.device, warmup=50, avg_steps=100 ) - print(f" Average time taken: {mean_time_us:.2f} microseconds") + mean_time_us = metrics["mean_us"] + print(f" Average time taken: {mean_time_us:.2f} us (median: {metrics['median_us']:.2f} us)") if "count" in repro_info: count_workload = repro_info["count"] total_time_us = mean_time_us * count_workload print(f" Count in workload: {count_workload}") - print(f" Est time in workload: {total_time_us:.2f} microseconds") + print(f" Est time in workload: {total_time_us:.2f} us") # --- Optionally sync again --- if args.device == "cuda": torch.cuda.synchronize() @@ -190,7 +197,9 @@ def _get_args_kwargs_from_ir( print("\n--- Replay Summary ---") print(f"Total operations in file: {len(repro_data_list)}") if args.op_filter: - print(f"Filter applied: '{args.op_filter}'") + print(f"Filter applied: '{args.op_filter}' ({len(ops_to_replay)} matched)") + if args.op_limit: + print(f"Limit applied: {args.op_limit}") print(f"Attempted replays: {replayed_count}") print(f"Successful replays: {replayed_count - errors}") print(f"Errors encountered: {errors}") diff --git a/TraceLens/EventReplay/custom_inits.py b/TraceLens/EventReplay/custom_inits.py index 9ded92540..bce9960ef 100644 --- a/TraceLens/EventReplay/custom_inits.py +++ b/TraceLens/EventReplay/custom_inits.py @@ -24,6 +24,7 @@ from __future__ import annotations import re +import warnings from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -116,7 +117,7 @@ class CustomInit(ABC): def applies_to(self, replayer: Any) -> bool: op_name = replayer.event.get("name", "") - return any(pat in op_name for pat in self.op_patterns) + return op_name in self.op_patterns @abstractmethod def initialize(self, replayer: Any, **kwargs) -> Optional[str]: @@ -138,7 +139,7 @@ class PagedAttentionInit(CustomInit): - Block table entries drawn from a random permutation of the pool. """ - op_patterns = ["paged_attention"] + op_patterns = ["_rocm_C::paged_attention"] def initialize(self, replayer: Any, **kwargs) -> Optional[str]: try: @@ -277,7 +278,7 @@ class MoeRoutingInit(CustomInit): [11] block_m (scalar) """ - op_patterns = ["ck_moe_stage1", "ck_moe_stage2"] + op_patterns = ["aiter::ck_moe_stage1", "aiter::ck_moe_stage2"] def initialize(self, replayer: Any, **kwargs) -> Optional[str]: try: diff --git a/TraceLens/EventReplay/event_replay.py b/TraceLens/EventReplay/event_replay.py index bc26a13b6..715ede6f0 100644 --- a/TraceLens/EventReplay/event_replay.py +++ b/TraceLens/EventReplay/event_replay.py @@ -305,21 +305,22 @@ def _setup(self): def replay(self): """ Replay the event using the matched schema and event replay IR. + + Returns: + The result of the PyTorch operation. """ if self.lazy: - args, kwargs = EventReplayer._get_args_kwargs( + self.args, self.kwargs = EventReplayer._get_args_kwargs( self.event_replay_IR, device=self.device ) - else: - args, kwargs = self.args, self.kwargs if not self._inits_applied and self._auto_init: self._apply_custom_inits() - self._func(*args, **kwargs) + return self._func(*self.args, **self.kwargs) def _apply_custom_inits(self): - """Run all applicable custom initializers on this replayer's tensors.""" + """Apply the first matching custom initializer to this replayer's tensors.""" for custom_init in self._custom_init_registry: if custom_init.applies_to(self): try: @@ -330,6 +331,7 @@ def _apply_custom_inits(self): warnings.warn( f"[custom init] {type(custom_init).__name__} failed: {e}" ) + break self._inits_applied = True @staticmethod @@ -761,22 +763,21 @@ def _parse_arg(raw_arg: str) -> Tuple[str, str, Optional[str], bool]: def get_repro_info(self) -> Dict[str, Any]: """ Extracts the minimal, serializable information needed to reproduce the event call. + + Safe to call multiple times — does not mutate self.event_replay_IR. """ - dict_repro_info = {} - dict_repro_info["op_name"] = self.event["name"] - list_pos_args, list_kwargs = ( - self.event_replay_IR["list_pos_args"], - self.event_replay_IR["list_kwargs"], - ) - list_pos_args_copy, list_kwargs_copy = list_pos_args.copy(), list_kwargs.copy() - for idx, val in enumerate(list_pos_args_copy): - if isinstance(val["value"], TensorCfg): - list_pos_args_copy[idx]["value"] = val["value"].__dict__ - for idx, val in enumerate(list_kwargs_copy): - if isinstance(val["value"], TensorCfg): - list_kwargs_copy[idx]["value"] = val["value"].__dict__ - dict_repro_info["replay_ir"] = { - "list_pos_args": list_pos_args_copy, - "list_kwargs": list_kwargs_copy, + def _serialize_arg(arg: Dict[str, Any]) -> Dict[str, Any]: + val = arg["value"] + return { + **arg, + "value": val.__dict__.copy() if isinstance(val, TensorCfg) else val, + } + + ir = self.event_replay_IR + return { + "op_name": self.event["name"], + "replay_ir": { + "list_pos_args": [_serialize_arg(a) for a in ir["list_pos_args"]], + "list_kwargs": [_serialize_arg(a) for a in ir["list_kwargs"]], + }, } - return dict_repro_info diff --git a/TraceLens/EventReplay/test_event_replay.py b/TraceLens/EventReplay/test_event_replay.py new file mode 100644 index 000000000..7948caa0f --- /dev/null +++ b/TraceLens/EventReplay/test_event_replay.py @@ -0,0 +1,204 @@ +############################################################################### +# Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +Tests for EventReplay core functionality. + +All tests use CPU-only ops (aten::mm) so they run without a GPU. +Run from the repo root: + python -m pytest TraceLens/EventReplay/test_event_replay.py -v +""" + +import sys +import os +import pytest +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from TraceLens.EventReplay.event_replay import EventReplayer # noqa: E402 +from TraceLens.EventReplay.custom_inits import CustomInit # noqa: E402 +from TraceLens.EventReplay.utils import TensorCfg # noqa: E402 + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_mm_event(M=4, K=8, N=16): + """Minimal profiler event dict for aten::mm (M x K) @ (K x N).""" + return { + "name": "aten::mm", + "args": { + "Input Dims": [[M, K], [K, N]], + "Input type": ["float", "float"], + "Input Strides": [[K, 1], [N, 1]], + "Concrete Inputs": ["", ""], + }, + } + + +@pytest.fixture(autouse=True) +def _isolate_registry(): + """Save and restore the global custom-init registry around every test.""" + saved = EventReplayer._custom_init_registry[:] + yield + EventReplayer._custom_init_registry = saved + + +# --------------------------------------------------------------------------- +# BUG-1: lazy=True + auto_init=True must not crash +# --------------------------------------------------------------------------- + +class TestLazyAutoInit: + def test_lazy_replay_sets_self_args(self): + """replay() in lazy mode must populate self.args.""" + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True, auto_init=False) + assert not hasattr(replayer, "args") + replayer.replay() + assert hasattr(replayer, "args") + assert isinstance(replayer.args, list) + + def test_lazy_with_custom_init_no_crash(self): + """A custom init that reads replayer.args must work in lazy mode.""" + accessed = {} + + class ProbeInit(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + accessed["args"] = replayer.args + accessed["kwargs"] = replayer.kwargs + return "[probe] ok" + + EventReplayer.register_custom_init(ProbeInit()) + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True, auto_init=True) + replayer.replay() + assert "args" in accessed + assert len(accessed["args"]) == 2 # self, mat2 + + +# --------------------------------------------------------------------------- +# BUG-2: get_repro_info() must not corrupt event_replay_IR +# --------------------------------------------------------------------------- + +class TestGetReproInfo: + def test_idempotent(self): + """Calling get_repro_info() twice must produce identical output.""" + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True) + assert replayer.get_repro_info() == replayer.get_repro_info() + + def test_does_not_mutate_ir(self): + """TensorCfg objects in the IR must survive get_repro_info().""" + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True) + replayer.get_repro_info() + for arg in replayer.event_replay_IR["list_pos_args"]: + if arg["arg_type"].startswith("Tensor"): + assert isinstance(arg["value"], TensorCfg), ( + f"arg '{arg['arg_name']}' is {type(arg['value'])}, expected TensorCfg" + ) + + def test_replay_works_after_get_repro_info(self): + """replay() must succeed after get_repro_info() (IR still intact).""" + replayer = EventReplayer(_make_mm_event(), device="cpu", lazy=True) + replayer.get_repro_info() + result = replayer.replay() + assert isinstance(result, torch.Tensor) + + +# --------------------------------------------------------------------------- +# CLAIM-1: first-match-wins (only one init runs) +# --------------------------------------------------------------------------- + +class TestFirstMatchWins: + def test_only_first_matching_init_runs(self): + """When two inits match, only the first registered one executes.""" + log = [] + + class InitA(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("A") + + class InitB(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("B") + + EventReplayer._custom_init_registry = [InitA(), InitB()] + EventReplayer(_make_mm_event(), device="cpu", auto_init=True).replay() + assert log == ["A"] + + +# --------------------------------------------------------------------------- +# CLAIM-4: replay() returns the op result +# --------------------------------------------------------------------------- + +class TestReplayReturn: + def test_returns_tensor(self): + """replay() of aten::mm must return a correctly-shaped Tensor.""" + replayer = EventReplayer(_make_mm_event(M=4, K=8, N=16), device="cpu") + result = replayer.replay() + assert isinstance(result, torch.Tensor) + assert result.shape == (4, 16) + + def test_returns_tensor_lazy(self): + """Lazy replay must also return the result.""" + result = EventReplayer(_make_mm_event(), device="cpu", lazy=True).replay() + assert isinstance(result, torch.Tensor) + + +# --------------------------------------------------------------------------- +# Exact name matching for op_patterns +# --------------------------------------------------------------------------- + +class TestExactNameMatching: + def test_exact_match_hits(self): + """op_patterns=["aten::mm"] matches event name "aten::mm".""" + matched = [] + + class ExactInit(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + matched.append(True) + + EventReplayer._custom_init_registry = [ExactInit()] + EventReplayer(_make_mm_event(), device="cpu", auto_init=True).replay() + assert matched == [True] + + def test_substring_does_not_match(self): + """op_patterns=["mm"] must NOT match "aten::mm" (exact only).""" + matched = [] + + class SubstringInit(CustomInit): + op_patterns = ["mm"] + def initialize(self, replayer, **kwargs): + matched.append(True) + + EventReplayer._custom_init_registry = [SubstringInit()] + EventReplayer(_make_mm_event(), device="cpu", auto_init=True).replay() + assert matched == [] + + +# --------------------------------------------------------------------------- +# auto_init=False skips all inits +# --------------------------------------------------------------------------- + +class TestAutoInitDisabled: + def test_no_init_runs_when_disabled(self): + log = [] + + class AlwaysInit(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("ran") + + EventReplayer._custom_init_registry = [AlwaysInit()] + EventReplayer(_make_mm_event(), device="cpu", auto_init=False).replay() + assert log == [] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/TraceLens/EventReplay/test_event_replay_gpu.py b/TraceLens/EventReplay/test_event_replay_gpu.py new file mode 100644 index 000000000..9998ec956 --- /dev/null +++ b/TraceLens/EventReplay/test_event_replay_gpu.py @@ -0,0 +1,261 @@ +############################################################################### +# Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +""" +GPU integration tests for EventReplay. + +Profiles real ops, replays from the captured trace, and validates: + 1. Kernel name match between original and replayed execution + 2. BUG-1: lazy=True + auto_init=True works on GPU + 3. BUG-2: get_repro_info() is idempotent (doesn't corrupt IR) + 4. CLAIM-4: replay() returns a tensor + 5. CLAIM-1: first-match-wins with real ops + +Requires a GPU (MI300X / MI210 / etc). Run from the repo root: + python TraceLens/EventReplay/test_event_replay_gpu.py +""" + +import sys, os, json, time +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +import torch +from torch.profiler import profile, ProfilerActivity + +from TraceLens.EventReplay.event_replay import EventReplayer +from TraceLens.EventReplay.custom_inits import CustomInit +from TraceLens.EventReplay.utils import TensorCfg + +assert torch.cuda.is_available(), "GPU required for this test" + +DEVICE = "cuda" +TRACE_FILE = "/tmp/test_event_replay_gpu_trace.json" +REPLAY_TRACE = "/tmp/test_event_replay_gpu_replay.json" + +# --------------------------------------------------------------------------- +# Step 1: Profile a set of real ops +# --------------------------------------------------------------------------- + +print("=" * 80) +print("Step 1: Profiling real ops") +print("=" * 80) + +M, K, N = 256, 1024, 512 +mm_a = torch.randn(M, K, dtype=torch.bfloat16, device=DEVICE) +mm_b = torch.randn(K, N, dtype=torch.bfloat16, device=DEVICE) +add_a = torch.randn(M, N, dtype=torch.bfloat16, device=DEVICE) +add_b = torch.randn(M, N, dtype=torch.bfloat16, device=DEVICE) +bmm_a = torch.randn(4, M, K, dtype=torch.bfloat16, device=DEVICE) +bmm_b = torch.randn(4, K, N, dtype=torch.bfloat16, device=DEVICE) + +def run_ops(): + torch.mm(mm_a, mm_b) + torch.add(add_a, add_b) + torch.bmm(bmm_a, bmm_b) + torch.mul(add_a, add_b) + torch.sigmoid(add_a) + +for _ in range(10): + run_ops() +torch.cuda.synchronize() + +def trace_handler(p): + p.export_chrome_trace(TRACE_FILE) + +wait, warmup, active = 3, 3, 5 +with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), + record_shapes=True, + on_trace_ready=trace_handler, +) as p: + for _ in range(wait + warmup + active): + run_ops() + p.step() + +print(f"Trace saved to {TRACE_FILE}") + +# --------------------------------------------------------------------------- +# Step 2: Load trace and find events +# --------------------------------------------------------------------------- + +print(f"\n{'=' * 80}") +print("Step 2: Loading trace") +print("=" * 80) + +with open(TRACE_FILE) as f: + trace_data = json.load(f) + +all_events = trace_data.get("traceEvents", []) + +OPS_TO_TEST = ["aten::mm", "aten::add", "aten::bmm", "aten::mul", "aten::sigmoid"] + +def find_event(events, op_name): + """Find a cpu_op event with the right name and shape data.""" + candidates = [ + e for e in events + if e.get("cat") == "cpu_op" + and e.get("name") == op_name + and "args" in e + and "Input Dims" in e.get("args", {}) + ] + if candidates: + return candidates[len(candidates) // 2] + return None + +results = [] +errors = [] + +# --------------------------------------------------------------------------- +# Step 3: Replay each op and validate +# --------------------------------------------------------------------------- + +print(f"\n{'=' * 80}") +print("Step 3: Replay and validate") +print("=" * 80) + +print(f"\n{'Op':<30} {'Kernel Match':<15} {'Return':<10} {'Lazy':<10} {'ReproInfo':<12} {'Status'}") +print("-" * 100) + +for op_name in OPS_TO_TEST: + evt = find_event(all_events, op_name) + if evt is None: + print(f"{op_name:<30} {'SKIP':<15} {'---':<10} {'---':<10} {'---':<12} not in trace") + continue + + status = [] + + # --- Test: basic replay returns a result (CLAIM-4) --- + try: + replayer = EventReplayer(evt, device=DEVICE, auto_init=False) + result = replayer.replay() + returns_ok = isinstance(result, torch.Tensor) + except Exception as e: + returns_ok = False + status.append(f"replay error: {e}") + + # --- Test: lazy mode works (BUG-1) --- + try: + lazy_replayer = EventReplayer(evt, device=DEVICE, lazy=True, auto_init=False) + lazy_result = lazy_replayer.replay() + lazy_ok = isinstance(lazy_result, torch.Tensor) + assert hasattr(lazy_replayer, "args"), "self.args not set after lazy replay" + except Exception as e: + lazy_ok = False + status.append(f"lazy error: {e}") + + # --- Test: get_repro_info idempotent (BUG-2) --- + try: + repro_replayer = EventReplayer(evt, device=DEVICE, lazy=True) + info1 = repro_replayer.get_repro_info() + info2 = repro_replayer.get_repro_info() + repro_ok = (info1 == info2) + for arg in repro_replayer.event_replay_IR["list_pos_args"]: + if arg["arg_type"].startswith("Tensor"): + assert isinstance(arg["value"], TensorCfg), "IR corrupted after get_repro_info" + repro_replayer.replay() + except Exception as e: + repro_ok = False + status.append(f"repro error: {e}") + + # --- Test: kernel name match --- + kernel_match = "N/A" + try: + replay_replayer = EventReplayer(evt, device=DEVICE, auto_init=False) + for _ in range(5): + replay_replayer.replay() + torch.cuda.synchronize() + + def th(p): + p.export_chrome_trace(REPLAY_TRACE) + + w, wu, a = 2, 2, 3 + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=w, warmup=wu, active=a, repeat=1), + record_shapes=True, + on_trace_ready=th, + ) as p: + for _ in range(w + wu + a): + replay_replayer.replay() + p.step() + + with open(REPLAY_TRACE) as f: + replay_trace = json.load(f) + + orig_gpu = set() + for e in all_events: + if e.get("cat") == "kernel" and e.get("name", ""): + orig_gpu.add(e["name"]) + + replay_gpu = set() + for e in replay_trace.get("traceEvents", []): + if e.get("cat") == "kernel" and e.get("name", ""): + replay_gpu.add(e["name"]) + + kernel_match = "MATCH" if replay_gpu.issubset(orig_gpu) else "MISMATCH" + except Exception as e: + kernel_match = "ERROR" + status.append(f"kernel error: {e}") + + ok = returns_ok and lazy_ok and repro_ok and kernel_match in ("MATCH", "N/A") + tag = "PASS" if ok else "FAIL" + detail = "; ".join(status) if status else "" + + print(f"{op_name:<30} {kernel_match:<15} {'OK' if returns_ok else 'FAIL':<10} {'OK' if lazy_ok else 'FAIL':<10} {'OK' if repro_ok else 'FAIL':<12} {tag} {detail}") + results.append({"op": op_name, "ok": ok, "kernel": kernel_match, + "returns": returns_ok, "lazy": lazy_ok, "repro": repro_ok}) + +# --------------------------------------------------------------------------- +# Step 4: First-match-wins test (CLAIM-1) on GPU +# --------------------------------------------------------------------------- + +print(f"\n{'=' * 80}") +print("Step 4: First-match-wins (CLAIM-1)") +print("=" * 80) + +log = [] + +class InitA(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("A") + +class InitB(CustomInit): + op_patterns = ["aten::mm"] + def initialize(self, replayer, **kwargs): + log.append("B") + +saved_registry = EventReplayer._custom_init_registry[:] +try: + EventReplayer._custom_init_registry = [InitA(), InitB()] + mm_evt = find_event(all_events, "aten::mm") + if mm_evt: + r = EventReplayer(mm_evt, device=DEVICE, auto_init=True) + r.replay() + first_match_ok = (log == ["A"]) + print(f" First-match-wins: {'PASS' if first_match_ok else 'FAIL'} (log={log})") + else: + first_match_ok = True + print(" SKIP: aten::mm not in trace") +finally: + EventReplayer._custom_init_registry = saved_registry + +# --------------------------------------------------------------------------- +# Summary +# --------------------------------------------------------------------------- + +print(f"\n{'=' * 80}") +print("Summary") +print("=" * 80) + +total = len(results) +passed = sum(1 for r in results if r["ok"]) +print(f"Op tests: {passed}/{total} passed") +print(f"First-match-wins: {'PASS' if first_match_ok else 'FAIL'}") + +all_pass = passed == total and first_match_ok +print(f"\nOverall: {'ALL PASSED' if all_pass else 'FAILURES DETECTED'}") +sys.exit(0 if all_pass else 1) diff --git a/docs/EventReplay.md b/docs/EventReplay.md index 1602e2507..6e6898465 100644 --- a/docs/EventReplay.md +++ b/docs/EventReplay.md @@ -25,8 +25,7 @@ with the right op libraries installed. [Architecture](#architecture) | [Custom Initializers](#custom-initializers) | [Auto-Import](#auto-import-for-custom-ops) | -[Batch Context](#batch-context-vllm-traces) | -[Validation](#validation) | +[Iteration Annotations](#iteration-annotations-vllm-traces) | [Limitations](#known-limitations) | [Use Cases](#use-cases) @@ -173,10 +172,13 @@ with open('event_replay_ir.json', 'w') as f: ``` ```bash -python batched_replay.py event_replay_ir.json +python batched_replay.py event_replay_ir.json # default (timing only) +python batched_replay.py -v event_replay_ir.json # verbose (shows args) +python batched_replay.py --op-filter aten::mm event_replay_ir.json # filter by name +python batched_replay.py --op-limit 5 event_replay_ir.json # first 5 ops ``` -#### Example Output +#### Example Output (`-v`) ``` [7/11] Replaying: aten::convolution @@ -192,7 +194,7 @@ python batched_replay.py event_replay_ir.json output_padding SymInt[]: [0, 0] groups SymInt: 1 Keyword Args: - Average time taken: 100.38 microseconds + Average time taken: 100.38 us (median: 98.21 us) Successfully executed aten::convolution. Result: Tensor(shape=torch.Size([20, 256, 14, 14]), dtype=torch.bfloat16, device=cuda:0) ... @@ -233,8 +235,28 @@ Event Replay operates in two distinct phases: ### Phase 1: IR Extraction (deterministic) -When an `EventReplayer` is constructed from a profiler event, it resolves the -op's registered schema and extracts a complete Intermediate Representation: +The profiler captures tensor dimensions and types but not argument names: + +``` +Input Dims: [[2, 2048], [60, 2816, 2048], [60, 2048, 1408], [1924], [61], [2], ...] +Input type: [BFloat16, BFloat16, BFloat16, Int, Int, Int, ...] +``` + +EventReplayer looks up the op's **registered schema** from the PyTorch dispatcher +(via `torch._C._jit_get_all_schemas()` or `torch.ops`). For example, querying +the registry for `aiter::ck_moe_stage1` returns: + +``` +aiter::ck_moe_stage1(Tensor(a0!) hidden_states, Tensor(a1!) w1, Tensor(a2!) w2, + Tensor(a3!) sorted_token_ids, Tensor(a4!) sorted_expert_ids, + Tensor(a5!) num_valid_ids, Tensor(a6!) out, SymInt topk, + str? kernelName="", Tensor(a9!)? w1_scale=None, + Tensor(a10!)? a1_scale=None, SymInt? block_m=None, ...) -> () +``` + +This schema provides argument names, types, and defaults. EventReplayer zips the +schema with the profiler's `Input Dims` / `Input type` arrays to produce the +named, typed IR: - **Op name** — the fully qualified operator name (e.g., `aten::mm`, `_rocm_C::paged_attention`) - **Argument metadata** — for each positional and keyword argument: @@ -248,12 +270,11 @@ IR. The output is a portable JSON dictionary. **Prerequisite — ops must be registered with the PyTorch dispatcher.** If an op is called as a plain Python function (e.g., a Triton kernel launched directly), -there is no schema for the profiler to capture and no IR can be extracted. The -op must go through `torch.ops`, `torch.library`, or the JIT registry. This is -why aiter's CK-based ops (`aiter::ck_moe_stage1`) produce full IR while direct -Triton kernel calls do not. See -[Shape Metadata Guide](conceptual/shape_metadata_guide.md) for details on -registering ops that lack schemas. +there is no schema to query and no IR can be extracted. The op must go through +`torch.ops`, `torch.library`, or the JIT registry. This is why aiter's CK-based +ops (`aiter::ck_moe_stage1`) produce full IR while direct Triton kernel calls do +not. The fix is wrapping such kernels in `torch.library.custom_op` so the +dispatcher has a schema to query. ### Phase 2: Init & Replay (requires judgment) @@ -293,17 +314,17 @@ They are applied automatically when `auto_init=True` (the default). These ship with TraceLens and require no setup — they activate automatically when the op name matches: -**`PagedAttentionInit`** — matches `paged_attention` +**`PagedAttentionInit`** — matches `_rocm_C::paged_attention` Initializes the KV cache metadata so the attention kernel does real work: - `block_tables` — random permutation of the physical block pool (simulates realistic scattered memory allocation) - `seq_lens` — all sequences set to `max_seq_len` - `query_start_loc` — CSR indptr encoding per-sequence query token counts. - When batch context is available (see [Batch Context](#batch-context-vllm-traces)), + When iteration annotations are available (see [Iteration Annotations](#iteration-annotations-vllm-traces)), uses the exact prefill/decode split; otherwise falls back to heuristics -**`MoeRoutingInit`** — matches `ck_moe_stage1`, `ck_moe_stage2` +**`MoeRoutingInit`** — matches `aiter::ck_moe_stage1`, `aiter::ck_moe_stage2` Constructs a complete token-to-expert routing table: - `sorted_token_ids` — padded to `block_m` boundaries per expert @@ -321,54 +342,38 @@ replayer = EventReplayer(event, device='cuda', init_kwargs={"moe_distribution": "zipf", "moe_zipf_s": 1.5}) ``` -### Disabling Built-in Inits +### Writing Your Own Initializer -```python -replayer = EventReplayer(event, device='cuda', auto_init=False) -replayer.replay() -``` +If you're replaying an op that needs realistic tensor content but isn't +covered by the built-ins, you can write your own custom initializer in +three steps. For real-world examples, see `PagedAttentionInit` and +`MoeRoutingInit` in `TraceLens/EventReplay/custom_inits.py`. -### Adding Your Own Initializer +**Step 1 — Subclass `CustomInit`.** Set `op_patterns` to the exact op name(s) +you want to target. This is an exact match against the profiler event name +(e.g., `"aten::index_add_"`, not `"index_add"`): -For ops not covered by the built-ins, you can write and register your own. +```python +from TraceLens.EventReplay import EventReplayer, CustomInit -When `replay()` runs with `auto_init=True`, it iterates over registered -initializers and calls `initialize()` on the first one whose `op_patterns` -match the op name. The initializer mutates `replayer.args` / `replayer.kwargs` -**in-place** before the op is called. +class IndexAddInit(CustomInit): + op_patterns = ["aten::index_add_"] +``` -The initializer can access: +**Step 2 — Implement `initialize()`.** This method receives the `replayer` +object and mutates its tensors **in-place** before the op executes. You have +access to: - `replayer.args` — list of allocated tensors/scalars (in schema order) - `replayer.kwargs` — dict of keyword arguments - `replayer.event` — the raw profiler event dict - `replayer.event_replay_IR` — the extracted IR with named argument metadata -Arguments should be looked up **by name** from the IR rather than hardcoded -positions, making initializers robust to schema changes across library versions: - -```python -ir = replayer.event_replay_IR -arg_names = [a["arg_name"] for a in ir["list_pos_args"]] - -def _by_name(name, fallback_pos): - if name in arg_names: - return replayer.args[arg_names.index(name)] - return replayer.args[fallback_pos] -``` - -**Real example — `aten::index_add_`:** This op was found as a bottleneck in -gsplat (Gaussian Splatting) on MI325X. The `index` tensor (5M elements) must -contain valid row indices into `self`; the default zero-init makes every source -row accumulate into row 0, which is not representative of the real scatter -pattern: +Look up arguments **by name** from the IR rather than hardcoding positional +indices — this keeps your initializer robust to schema changes across library +versions: ```python -from TraceLens.EventReplay import EventReplayer, CustomInit - -class IndexAddInit(CustomInit): - op_patterns = ["index_add"] - def initialize(self, replayer, **kwargs): import torch @@ -385,14 +390,24 @@ class IndexAddInit(CustomInit): return (f"[custom init] index_add — index randint(0, {dim_size}), " f"shape={list(index.shape)}") +``` + +In this example, `aten::index_add_` accumulates source rows into `self` at +positions given by `index`. The default zero-init makes every row land on +row 0 — not representative of the real scatter pattern. The initializer +fills `index` with random valid indices so the kernel exercises realistic +memory access. -# Register — all future replays of index_add ops will use this +**Step 3 — Register it.** Once registered, the initializer fires automatically +on every future replay of matching ops: + +```python EventReplayer.register_custom_init(IndexAddInit()) ``` -After `register_custom_init`, every subsequent `EventReplayer` will -automatically apply `IndexAddInit` when replaying any op whose name contains -`"index_add"`. +When `replay()` runs with `auto_init=True`, it iterates over registered +initializers and applies the **first match** (built-ins are checked first, +then user-registered ones in order). To see what's currently registered: @@ -426,49 +441,84 @@ EventReplayer.register_namespace("my_lib", ["my_lib.ops"]) --- -## Batch Context (vLLM traces) +## Iteration Annotations (vLLM traces) -vLLM annotates each forward step with `user_annotation` events that encode -the exact prefill/decode batch composition (e.g., -`execute_context_1(21)_generation_5(5)`). `extract_batch_context` parses these -annotations and attaches the breakdown to each paged attention event, so -`PagedAttentionInit` can construct an accurate `query_start_loc`. +### The problem -```python -from TraceLens.EventReplay import extract_batch_context +Paged attention's `query_start_loc` is a CSR indptr array that encodes how +many query tokens each sequence contributes. In a mixed batch (common in +vLLM's continuous batching), some sequences are **prefill** (many query tokens) +and others are **decode** (1 token each). The profiler captures the tensor shape +but not the per-sequence breakdown, so without additional information the +custom initializer has to guess — and guessing wrong changes the compute +pattern significantly (prefill is quadratic in sequence length, decode is +linear). + +### How vLLM exposes this + +vLLM emits a `user_annotation` event for each `execute_model` iteration +with the exact prefill/decode composition encoded in the name: -num_annotated = extract_batch_context(perf_analyzer) -print(f"Annotated {num_annotated} paged_attention events with batch context") +``` +execute_context_2(18)_generation_5(5) ``` -Without batch context, the initializer falls back to a heuristic (uniform -token distribution for prefill, 1 token/seq for decode). +This is an **iteration annotation** — it describes one forward pass. Here it +means: **2 prefill sequences** with **18 total query tokens**, and **5 decode +sequences** with **5 tokens** (1 each). ---- +### Extracting batch context from iteration annotations -## Validation +`extract_batch_context` parses these iteration annotations by timestamp and +attaches a `batch_context` dict to each paged attention event that falls +within the annotation's time range: -Tested on Qwen1.5-MoE-A2.7B (aiter backend, MI300X) — 30 unique op configs: +```python +from TraceLens import TreePerfAnalyzer +from TraceLens.EventReplay import EventReplayer, extract_batch_context + +analyzer = TreePerfAnalyzer.from_file("vllm_trace.json") -| Metric | Result | -|--------|--------| -| Kernel name match | 27/30 (3 `aten::copy_` mismatches — runtime DMA path selection, expected) | -| GPU busy time delta (median, warm cache) | -17% (replay faster — isolation effect) | -| GPU busy time delta range | -32% (large mem-bound GEMMs) to +5% (latency-bound ops) | +# Annotate paged attention events with prefill/decode split +num_annotated = extract_batch_context(analyzer) +print(f"Annotated {num_annotated} paged_attention events") -Also validated on Qwen2.5-3B (dense, no MoE): 26/30 match, similar timing profile. +# Now replay — PagedAttentionInit reads event["batch_context"] automatically +event = analyzer.tree.get_UID2event(some_uid) +replayer = EventReplayer(event, device='cuda') +replayer.replay() # query_start_loc reflects the real prefill/decode split +``` -Replay is systematically faster because isolated single-op execution has no -inter-op cache contention. This is a fundamental property of micro-benchmarks, -not a bug. +After `extract_batch_context`, each annotated event carries: ```python -from TraceLens.EventReplay import benchmark_func - -metrics = benchmark_func(replayer.replay, warmup=5, repeat=20) -print(f"Median: {metrics['median_ms']:.3f} ms") +event["batch_context"] = { + "n_prefill": 2, # number of prefill sequences + "prefill_tokens": 18, # total query tokens across prefill sequences + "n_decode": 5, # number of decode sequences + "decode_tokens": 5, # total query tokens across decode sequences (1 each) +} ``` +`PagedAttentionInit` uses this to build `query_start_loc` accurately: +prefill sequences get `prefill_tokens / n_prefill` tokens each, decode +sequences get 1 token each. + +### Without iteration annotations + +`PagedAttentionInit` still runs — it always initializes `block_tables` +(random permutation of the physical block pool) and `seq_lens` (set to +`max_seq_len`). The only difference is how `query_start_loc` is built. +Without annotations, it falls back to heuristics: + +- `query_tokens == num_seqs` → assumes pure decode (1 token/seq) +- `query_tokens > num_seqs` → assumes pure prefill (tokens distributed + uniformly) + +This is a reasonable approximation for homogeneous batches but inaccurate +for mixed prefill+decode batches, where the iteration annotation provides +the exact split. + --- ## Known Limitations @@ -476,8 +526,7 @@ print(f"Median: {metrics['median_ms']:.3f} ms") - **Unregistered ops are invisible.** Triton kernels called directly from Python (e.g., aiter's Triton attention path) have no schema in the profiler. The fix is wrapping them in `torch.library.custom_op` — a one-time registration effort - in the upstream library. See - [Shape Metadata Guide](conceptual/shape_metadata_guide.md). + in the upstream library. - **Single-op isolation vs. real workload.** Replay runs each op in isolation with no surrounding memory traffic. Timings are a lower bound on in-model