
Commit c4b6a75

yushangdi, svekars, and AlannaBurke authored
Tutorial for DebugMode (#3697)
A tutorial for DebugMode cc @pianpwk

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Alanna Burke <burkealanna@meta.com>
1 parent e156b07 commit c4b6a75

2 files changed

Lines changed: 287 additions & 0 deletions

File tree

recipes_index.rst

Lines changed: 8 additions & 0 deletions
@@ -152,6 +152,13 @@ from our full-length tutorials.
    :link: recipes/recipes/tensorboard_with_pytorch.html
    :tags: Visualization,TensorBoard

+.. customcarditem::
+   :header: DebugMode: Recording Dispatched Operations and Numerical Debugging
+   :card_description: Inspect dispatched ops, tensor hashes, and module boundaries to debug eager and ``torch.compile`` runs.
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: recipes/debug_mode_tutorial.html
+   :tags: Interpretability,Compiler
+
 .. Automatic Mixed Precision

 .. customcarditem::
@@ -354,3 +361,4 @@ from our full-length tutorials.
    recipes/distributed_device_mesh
    recipes/distributed_checkpoint_recipe
    recipes/distributed_async_checkpoint_recipe
+   recipes/debug_mode_tutorial
Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-

"""
DebugMode: Recording Dispatched Operations and Numerical Debugging
====================================================================

**Authors:** Pian Pawakapan, Shangdi Yu

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to capture dispatched ops for eager and ``torch.compile`` runs
       * How to use tensor hashes and stack traces in DebugMode to pinpoint numerical divergence

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.10 or later

"""

######################################################################
# Overview
# --------
#
# ``DebugMode`` (:class:`torch.utils._debug_mode.DebugMode`) is a
# ``TorchDispatchMode`` that intercepts PyTorch runtime calls and emits a
# hierarchical log of operations. It is particularly useful when you need to
# understand *what* actually runs, both in eager mode and under
# ``torch.compile``, or when you need to pinpoint numerical divergence between
# two runs.
#
# Key capabilities:
#
# * **Runtime logging** – Records dispatched operations and TorchInductor-compiled
#   Triton kernels.
# * **Tensor hashing** – Attaches deterministic hashes to inputs and outputs so
#   that two runs can be diffed to locate numerical divergences.
# * **Dispatch hooks** – Allows registration of custom hooks that annotate each
#   call with extra metadata.
#
# .. note::
#
#    This recipe describes a prototype feature. Prototype features are typically
#    at an early stage for feedback and testing and are subject to change.
#

######################################################################
# Quick start
# -----------
#
# The snippet below captures a small eager workload and prints the debug string:

import torch
from torch.utils._debug_mode import DebugMode


def run_once():
    x = torch.randn(8, 8)
    y = torch.randn(8, 8)
    return torch.mm(torch.relu(x), y)


with DebugMode() as debug_mode:
    out = run_once()

print("DebugMode output:")
print(debug_mode.debug_string())


######################################################################
# Getting more metadata
# ---------------------
#
# For most investigations, you'll want to enable stack traces, tensor IDs, and
# tensor hashing. These features provide metadata to correlate operations back
# to model code.
#
# ``DebugMode.log_tensor_hashes`` decorates the log with hashes for every call.
# The ``hash_tensor`` hash function uses ``torch.hash_tensor``, which returns 0
# for tensors whose elements are all the same. The ``norm`` hash function uses
# ``norm`` with ``p=1``. With both functions, especially ``norm``, numerically
# close tensors produce close hashes, which makes the log easier to interpret.
# The default ``hash_fn`` is ``norm``.

with (
    DebugMode(
        # record_stack_trace is only supported for eager in PyTorch 2.10
        record_stack_trace=True,
        record_ids=True,
    ) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm"],  # this is the default
        hash_inputs=True,
    ),
):
    result = run_once()

print("DebugMode output with more metadata:")
print(
    debug_mode.debug_string(show_stack_trace=True)
)

######################################################################
# Each line follows ``op(args) -> outputs``. When ``record_ids`` is enabled,
# tensors are suffixed with ``$<id>`` and DTensors are labeled ``dt``.


######################################################################
# Log Triton kernels
# ------------------
#
# Though Triton kernels are not dispatched, DebugMode has custom logic that
# logs their inputs and outputs.
#
# Inductor-generated Triton kernels show up with a ``[triton]`` prefix.
# Pre/post hash annotations report buffer hashes around each kernel call, which
# is helpful when isolating incorrect kernels.
#
# Note that this example requires a CUDA device.


def f(x):
    return torch.mm(torch.relu(x), x.T)


x = torch.randn(3, 3, device="cuda")

with (
    DebugMode(record_output=True) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ),
):
    a = torch.compile(f)(x)

print("Triton in DebugMode logs:")
print(debug_mode.debug_string())

######################################################################
# Numerical debugging with tensor hashes
# --------------------------------------
#
# If you see numerical divergence between modes, you can use DebugMode to find
# where it originates.
# In the example below, all tensor hashes are the same for the eager and
# compiled runs. If any hash were different, that call would be where the
# numerical divergence comes from.

def run_model(model, data, *, compile_with=None):
    if compile_with is not None:
        model = torch.compile(model, backend=compile_with)
    with DebugMode(record_output=True) as dm, DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ):
        dm_out = model(*data)
    return dm, dm_out


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).mm(x.T)


inputs = (torch.randn(4, 4),)
dm_eager, _ = run_model(Toy(), inputs)
dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager")

print("Eager mode:")
print(dm_eager.debug_string())
print("Compiled aot_eager mode:")
print(dm_compiled.debug_string())

###############################################################################################
# Now let's look at an example where the tensor hashes are different.
# We intentionally register an incorrect decomposition that rewrites cosine as
# sine, which causes numerical divergence.


from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import get_nop_func


def wrong_decomp(x):
    return torch.sin(x)


decomp_table = {}
decomp_table[torch.ops.aten.cos.default] = wrong_decomp

backend = aot_autograd(
    fw_compiler=get_nop_func(),
    bw_compiler=get_nop_func(),
    decompositions=decomp_table,
)


def f(x):
    y = x.relu()
    z = torch.cos(x)
    return y + z


x = torch.randn(3, 3)
with DebugMode(record_output=True) as dm_eager, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    f(x)

with DebugMode(record_output=True) as dm_compiled, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    torch.compile(f, backend=backend)(x)

print("Eager:")
print(dm_eager.debug_string())
print()
print("Compiled with wrong decomposition:")
print(dm_compiled.debug_string())

###############################################################################################
# In the eager log, we have ``aten::cos``, but in the compiled log, we have
# ``aten::sin``. Moreover, the output hashes differ between the eager and
# compiled runs. Diffing the two logs shows that the first numerical divergence
# appears at the ``aten::cos`` call.


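######################################################################
# To locate that first divergence programmatically, a minimal sketch (using
# only Python's standard ``difflib``, not part of DebugMode) is to diff the
# two log strings line by line and print the result:

import difflib

eager_lines = dm_eager.debug_string().splitlines()
compiled_lines = dm_compiled.debug_string().splitlines()

# Print a unified diff of the two logs; the first hunk points at the diverging call.
for line in difflib.unified_diff(
    eager_lines, compiled_lines, fromfile="eager", tofile="compiled", lineterm=""
):
    print(line)
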
######################################################################
# Custom dispatch hooks
# ---------------------
#
# Hooks allow you to annotate each call with custom metadata such as GPU memory
# usage. The function passed as ``log_hook`` returns a mapping that is rendered
# inline in the debug string.

MB = 1024 * 1024.0


def memory_hook(func, types, args, kwargs, result):
    mem = torch.cuda.memory_allocated() / MB if torch.cuda.is_available() else 0.0
    peak = torch.cuda.max_memory_allocated() / MB if torch.cuda.is_available() else 0.0
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"}


with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=memory_hook),
):
    run_once()

print("DebugMode output with memory usage:")
print(dm.debug_string())

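######################################################################
# The same hook signature can carry other metadata. As an illustrative sketch
# (this hook is not part of DebugMode), a hook can flag NaNs or Infs in each
# tensor output, which helps localize where numerics blow up:


def nan_check_hook(func, types, args, kwargs, result):
    # Only floating-point tensor outputs are checked; everything else is marked "n/a".
    if isinstance(result, torch.Tensor) and result.is_floating_point():
        return {
            "has_nan": bool(torch.isnan(result).any()),
            "has_inf": bool(torch.isinf(result).any()),
        }
    return {"has_nan": "n/a", "has_inf": "n/a"}


with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=nan_check_hook),
):
    run_once()

print("DebugMode output with NaN/Inf checks:")
print(dm.debug_string())
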
######################################################################
# Module boundaries
# -----------------
#
# ``record_nn_module=True`` inserts ``[nn.Mod]`` markers that show which
# module executed each set of operations. As of PyTorch 2.10 it only works in
# eager mode, but support for compiled modes is under development.


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))


mod = Bar()
inp = torch.randn(4, 4)
with DebugMode(record_nn_module=True, record_output=False) as debug_mode:
    _ = mod(inp)

print("DebugMode output with module boundaries:")
print(debug_mode.debug_string())

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we saw how DebugMode gives you a lightweight, runtime-only
# view of what PyTorch actually executed, whether you are running eager code or
# compiled graphs. By layering tensor hashing, Triton logging, and custom
# dispatch hooks, you can quickly track down numerical differences. This is
# especially helpful when checking bit-wise equivalence between runs.
