|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | +""" |
| 4 | +DebugMode: Recording Dispatched Operations and Numerical Debugging |
| 5 | +================================================================= |
| 6 | +
|
| 7 | +**Authors:** Pian Pawakapan, Shangdi Yu |
| 8 | +
|
| 9 | +.. grid:: 2 |
| 10 | +
|
| 11 | + .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn |
| 12 | + :class-card: card-prerequisites |
| 13 | +
|
| 14 | + * How to capture dispatched ops for eager and ``torch.compile`` runs |
| 15 | + * How to use tensor hashes and stack traces in DebugMode to pinpoint numerical divergence |
| 16 | +
|
| 17 | + .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites |
| 18 | + :class-card: card-prerequisites |
| 19 | +
|
| 20 | + * PyTorch 2.10 or later |
| 21 | +
|
| 22 | +""" |
| 23 | + |
| 24 | +###################################################################### |
| 25 | +# Overview |
| 26 | +# -------- |
| 27 | +# |
| 28 | +# ``DebugMode`` (:class:`torch.utils._debug_mode.DebugMode`) is a |
| 29 | +# ``TorchDispatchMode`` that intercepts PyTorch runtime calls and emits a |
| 30 | +# hierarchical log of operations. It is particularly useful when you need to |
| 31 | +# understand *what* actually runs, both in eager mode and under ``torch.compile`` |
| 32 | +# or when you need to pinpoint numerical divergence between two runs. |
| 33 | +# |
| 34 | +# Key capabilities: |
| 35 | +# |
| 36 | +# * **Runtime logging** – Records dispatched operations and TorchInductor compiled |
| 37 | +# Triton kernels. |
| 38 | +# * **Tensor hashing** – Attaches deterministic hashes to inputs/outputs to enable |
| 39 | +# diffing runs to locate numerical divergences. |
| 40 | +# * **Dispatch hooks** – Allows registration of custom hooks to annotate calls |
| 41 | +# |
| 42 | +# .. note:: |
| 43 | +# |
| 44 | +# This recipe describes a prototype feature. Prototype features are typically |
| 45 | +# at an early stage for feedback and testing and are subject to change. |
| 46 | +# |
| 47 | + |
| 48 | +###################################################################### |
| 49 | +# Quick start |
| 50 | +# ----------- |
| 51 | +# |
| 52 | +# The snippet below captures a small eager workload and prints the debug string: |
| 53 | + |
| 54 | +from torch._inductor.decomposition import decomps_to_exclude |
| 55 | +import torch |
| 56 | +from torch.utils._debug_mode import DebugMode |
| 57 | + |
| 58 | +def run_once(): |
| 59 | + x = torch.randn(8, 8) |
| 60 | + y = torch.randn(8, 8) |
| 61 | + return torch.mm(torch.relu(x), y) |
| 62 | + |
| 63 | +with DebugMode() as debug_mode: |
| 64 | + out = run_once() |
| 65 | + |
| 66 | +print("DebugMode output:") |
| 67 | +print(debug_mode.debug_string()) |
| 68 | + |
| 69 | + |
| 70 | +###################################################################### |
| 71 | +# Getting more metadata |
| 72 | +# ----------- |
| 73 | +# |
| 74 | +# For most investigations, you'll want to enable stack traces, tensor IDs, and tensor hashing. |
| 75 | +# These features provide metadata to correlate operations back to model code. |
| 76 | +# |
| 77 | +# ``DebugMode.log_tensor_hashes`` decorates the log with hashes for every call. |
| 78 | +# The ``hash_tensor`` hash function uses ``torch.hash_tensor``, which returns 0 for tensors whose |
| 79 | +# elements are all the same. The ``norm`` hash function uses ``norm`` with ``p=1``. |
| 80 | +# With both these functions, especially ``norm``, tensor closeness in numerics is related to hash closeness, |
| 81 | +# so it's rather interpretable. The default ``hash_fn`` is ``norm``. |
| 82 | + |
| 83 | +with ( |
| 84 | + DebugMode( |
| 85 | + # record_stack_trace is only supported for eager in pytorch 2.10 |
| 86 | + record_stack_trace=True, |
| 87 | + record_ids=True, |
| 88 | + ) as debug_mode, |
| 89 | + DebugMode.log_tensor_hashes( |
| 90 | + hash_fn=["norm"], # this is the default |
| 91 | + hash_inputs=True, |
| 92 | + ), |
| 93 | +): |
| 94 | + result = run_once() |
| 95 | + |
| 96 | +print("DebugMode output with more metadata:") |
| 97 | +print( |
| 98 | + debug_mode.debug_string(show_stack_trace=True) |
| 99 | +) |
| 100 | + |
| 101 | +###################################################################### |
| 102 | +# Each line follows ``op(args) -> outputs``. When ``record_ids`` is enabled, |
| 103 | +# tensors are suffixed with ``$<id>`` and DTensors are labeled ``dt``. |
| 104 | + |
| 105 | + |
| 106 | +###################################################################### |
| 107 | +# Log Triton kernels |
| 108 | +# ------------------ |
| 109 | +# |
| 110 | +# Though Triton kernels are not dispatched, DebugMode has custom logic that logs their inputs and outputs. |
| 111 | +# |
| 112 | +# Inductor-generated Triton kernels show up with a ``[triton]`` prefix. |
| 113 | +# Pre/post hash annotations report buffer hashes around each kernel call, which |
| 114 | +# is helpful when isolating incorrect kernels. |
| 115 | +def f(x): |
| 116 | + return torch.mm(torch.relu(x), x.T) |
| 117 | + |
| 118 | +x = torch.randn(3, 3, device="cuda") |
| 119 | + |
| 120 | +with ( |
| 121 | + DebugMode(record_output=True) as debug_mode, |
| 122 | + DebugMode.log_tensor_hashes( |
| 123 | + hash_inputs=True, |
| 124 | + ) |
| 125 | +): |
| 126 | + a = torch.compile(f)(x) |
| 127 | + |
| 128 | +print("Triton in DebugMode logs:") |
| 129 | +print(debug_mode.debug_string()) |
| 130 | + |
| 131 | +###################################################################### |
| 132 | +# Numerical debugging with tensor hashes |
| 133 | +# -------------------------------------- |
| 134 | +# |
| 135 | +# If you have numerical divergence between modes, you can use DebugMode to find where the |
| 136 | +# numerical divergence originates. |
| 137 | +# In the example below, you can see that all tensor hashes are the same for eager mode and compiled mode. |
| 138 | +# If any hash is different, then that's where the numerical divergence is coming from. |
| 139 | + |
| 140 | +def run_model(model, data, *, compile_with=None): |
| 141 | + if compile_with is not None: |
| 142 | + model = torch.compile(model, backend=compile_with) |
| 143 | + with DebugMode(record_output=True) as dm, DebugMode.log_tensor_hashes( |
| 144 | + hash_inputs=True, |
| 145 | + ): |
| 146 | + dm_out = model(*data) |
| 147 | + return dm, dm_out |
| 148 | + |
| 149 | +class Toy(torch.nn.Module): |
| 150 | + def forward(self, x): |
| 151 | + return torch.relu(x).mm(x.T) |
| 152 | + |
| 153 | +inputs = (torch.randn(4, 4),) |
| 154 | +dm_eager, _ = run_model(Toy(), inputs) |
| 155 | +dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager") |
| 156 | + |
| 157 | +print("Eager mode:") |
| 158 | +print(dm_eager.debug_string()) |
| 159 | +print("Compiled aot_eager mode:") |
| 160 | +print(dm_compiled.debug_string()) |
| 161 | + |
| 162 | +############################################################################################### |
| 163 | +# Now let's look at an example where the tensor hashes are different. |
| 164 | +# I intentionally wrote a wrong decomposition that decomposes cosine to sin. |
| 165 | +# This will cause numerical divergence. |
| 166 | + |
| 167 | + |
| 168 | +from torch._dynamo.backends.common import aot_autograd |
| 169 | +from torch._dynamo.backends.debugging import get_nop_func |
| 170 | + |
| 171 | +def wrong_decomp(x): |
| 172 | + return torch.sin(x) |
| 173 | + |
| 174 | +decomp_table = {} |
| 175 | +decomp_table[torch.ops.aten.cos.default] = wrong_decomp |
| 176 | + |
| 177 | +backend = aot_autograd( |
| 178 | + fw_compiler=get_nop_func(), |
| 179 | + bw_compiler=get_nop_func(), |
| 180 | + decompositions=decomp_table |
| 181 | +) |
| 182 | + |
| 183 | +def f(x): |
| 184 | + y = x.relu() |
| 185 | + z = torch.cos(x) |
| 186 | + return y + z |
| 187 | + |
| 188 | +x = torch.randn(3, 3) |
| 189 | +with DebugMode(record_output=True) as dm_eager, DebugMode.log_tensor_hashes( |
| 190 | + hash_inputs=True, |
| 191 | +): |
| 192 | + f(x) |
| 193 | + |
| 194 | +with DebugMode(record_output=True) as dm_compiled, DebugMode.log_tensor_hashes( |
| 195 | + hash_inputs=True, |
| 196 | +): |
| 197 | + torch.compile(f, backend=backend)(x) |
| 198 | + |
| 199 | +print("Eager:") |
| 200 | +print(dm_eager.debug_string(show_stack_trace=True)) |
| 201 | +print() |
| 202 | +print("Compiled with wrong decomposition:") |
| 203 | +print(dm_compiled.debug_string()) |
| 204 | + |
| 205 | +############################################################################################### |
| 206 | +# In the eager log, we have ``aten::cos``, but in the compiled log, we have ``aten::sin``. |
| 207 | +# Moreover, the output hash is different between eager and compiled mode. |
| 208 | +# Diffing the two logs would show that the first numerical divergence shows up in the ``aten::cos`` call. |
| 209 | + |
| 210 | + |
| 211 | + |
| 212 | + |
| 213 | +###################################################################### |
| 214 | +# Custom dispatch hooks |
| 215 | +# --------------------- |
| 216 | +# |
| 217 | +# Hooks allow you to annotate each call with custom metadata such as GPU memory usage. ``log_hook`` returns a mapping |
| 218 | +# that is rendered inline with the debug string. |
| 219 | + |
| 220 | +MB = 1024 * 1024.0 |
| 221 | + |
| 222 | +def memory_hook(func, types, args, kwargs, result): |
| 223 | + mem = torch.cuda.memory_allocated() / MB if torch.cuda.is_available() else 0.0 |
| 224 | + peak = torch.cuda.max_memory_allocated() / MB if torch.cuda.is_available() else 0.0 |
| 225 | + torch.cuda.reset_peak_memory_stats() if torch.cuda.is_available() else None |
| 226 | + return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"} |
| 227 | + |
| 228 | +with ( |
| 229 | + DebugMode() as dm, |
| 230 | + DebugMode.dispatch_hooks(log_hook=memory_hook), |
| 231 | +): |
| 232 | + run_once() |
| 233 | + |
| 234 | +print("DebugMode output with memory usage:") |
| 235 | +print(dm.debug_string()) |
| 236 | + |
| 237 | +###################################################################### |
| 238 | +# Module boundaries |
| 239 | +# ---------------------------------- |
| 240 | +# |
| 241 | +# ``record_nn_module=True`` inserts ``[nn.Mod]`` markers that show which |
| 242 | +# module executed each set of operations. As of PyTorch 2.10 it only works in eager mode, |
| 243 | +# but support for compiled modes is under development. |
| 244 | + |
| 245 | +class Foo(torch.nn.Module): |
| 246 | + def __init__(self): |
| 247 | + super().__init__() |
| 248 | + self.l1 = torch.nn.Linear(4, 4) |
| 249 | + self.l2 = torch.nn.Linear(4, 4) |
| 250 | + |
| 251 | + def forward(self, x): |
| 252 | + return self.l2(self.l1(x)) |
| 253 | + |
| 254 | +class Bar(torch.nn.Module): |
| 255 | + def __init__(self): |
| 256 | + super().__init__() |
| 257 | + self.abc = Foo() |
| 258 | + self.xyz = torch.nn.Linear(4, 4) |
| 259 | + |
| 260 | + def forward(self, x): |
| 261 | + return self.xyz(self.abc(x)) |
| 262 | + |
| 263 | +mod = Bar() |
| 264 | +inp = torch.randn(4, 4) |
| 265 | +with DebugMode(record_nn_module=True, record_output=False) as debug_mode: |
| 266 | + _ = mod(inp) |
| 267 | + |
| 268 | +print("DebugMode output with stack traces and module boundaries:") |
| 269 | +print(debug_mode.debug_string(show_stack_trace=True)) |
| 270 | + |
| 271 | +###################################################################### |
| 272 | +# Conclusion |
| 273 | +# ---------- |
| 274 | +# |
| 275 | +# In this tutorial, we saw how DebugMode gives you a lightweight, runtime-only |
| 276 | +# view of what PyTorch actually executed, whether you are running eager code or |
| 277 | +# compiled graphs. By layering tensor hashing, Triton logging, and custom |
| 278 | +# dispatch hooks you can quickly track down numerical differences. This is |
| 279 | +# especially helpful in debugging bit-wise equivalence between runs. |
0 commit comments