Skip to content

Commit 5c8b12b

Browse files
chen2021673claude
and committed
fix(linear): apply compute_dtype cast only to saved tensors that are needed
Previously, saved_tensors_ was set twice: first with cast tensors for both input and weight, then immediately overwritten with the needs_input_grad-conditional version without casting. This meant saved tensors were never cast to compute_dtype, causing dtype mismatches in backward. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8f73c80 commit 5c8b12b

2 files changed

Lines changed: 8 additions & 7 deletions

File tree

infini_train/src/autograd/linear.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ void Linear::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tens
3131
// FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be
3232
// determined by autocast, not derived from output_tensors[0]->Dtype().
3333
auto compute_dtype = output_tensors[0]->Dtype();
34-
saved_tensors_ = {
35-
input->Dtype() == compute_dtype ? input : std::make_shared<Tensor>(input->To(compute_dtype)),
36-
weight->Dtype() == compute_dtype ? weight : std::make_shared<Tensor>(weight->To(compute_dtype)),
37-
};
3834
bool need_input = needs_input_grad_.size() > 0 && needs_input_grad_[0];
3935
bool need_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1];
4036

37+
auto cast = [&](const std::shared_ptr<Tensor> &t) {
38+
return t->Dtype() == compute_dtype ? t : std::make_shared<Tensor>(t->To(compute_dtype));
39+
};
40+
4141
// grad_input needs weight, grad_weight needs input
42-
saved_tensors_ = {need_weight ? input : nullptr, need_input ? weight : nullptr};
42+
saved_tensors_ = {need_weight ? cast(input) : nullptr, need_input ? cast(weight) : nullptr};
4343

4444
transpose_ = true;
4545
bias_ = input_tensors.size() == 3;
@@ -61,6 +61,7 @@ std::vector<std::shared_ptr<Tensor>> Linear::Backward(const std::vector<std::sha
6161
.bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]};
6262

6363
auto device = grad_output->GetDevice().type();
64+
// TODO: skip autograd graph construction entirely when no input requires grad
6465
auto [grad_input, grad_weight, grad_bias]
6566
= Dispatcher::Instance()
6667
.Call<std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(

infini_train/src/kernels/cuda/linear.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
335335
auto dtype = grad_output->Dtype();
336336

337337
// For type promotion, use available tensors
338-
DataType input_dtype = input ? input->Dtype() : dtype;
339-
DataType weight_dtype = weight ? weight->Dtype() : dtype;
338+
DataType input_dtype = input ? input->Dtype() : (weight ? weight->Dtype() : dtype);
339+
DataType weight_dtype = weight ? weight->Dtype() : (input ? input->Dtype() : dtype);
340340
// Compute dtype determined by saved tensors (forward compute dtype), not grad_output
341341
DataType compute_dtype = DispatchFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
342342
{input_dtype, weight_dtype}, [=]<typename Tin, typename Tw>() { return DataTypeMap_v<WidestType_t<Tin, Tw>>; },

0 commit comments

Comments
 (0)