Commit 5c8b12b
fix(linear): apply compute_dtype cast only to saved tensors that are needed
Previously, saved_tensors_ was set twice: first with cast tensors for
both input and weight, then immediately overwritten with the
needs_input_grad-conditional version without casting. This meant saved
tensors were never cast to compute_dtype, causing dtype mismatches in
backward.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>1 parent 8f73c80 commit 5c8b12b
2 files changed
Lines changed: 8 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
31 | 31 | | |
32 | 32 | | |
33 | 33 | | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | 34 | | |
39 | 35 | | |
40 | 36 | | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
41 | 41 | | |
42 | | - | |
| 42 | + | |
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
| |||
61 | 61 | | |
62 | 62 | | |
63 | 63 | | |
| 64 | + | |
64 | 65 | | |
65 | 66 | | |
66 | 67 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
335 | 335 | | |
336 | 336 | | |
337 | 337 | | |
338 | | - | |
339 | | - | |
| 338 | + | |
| 339 | + | |
340 | 340 | | |
341 | 341 | | |
342 | 342 | | |
| |||
0 commit comments