@@ -38,7 +38,8 @@ __global__ void BinaryForwardKernel(T *output, Func fn, size_t num_elements_a, s
 // launch the given kernel function with the given output and inputs
 template <size_t BLOCK_SIZE, typename T, typename Kernel, typename... Inputs>
 void LaunchKernel(Kernel &&kernel, const std::shared_ptr<Tensor> &output, const Inputs &...inputs) {
-    auto extract_ptrs = [](const auto &...ts) { return std::make_tuple(static_cast<T *>(ts->DataPtr())...); };
+    auto extract_ptrs
+        = [](const auto &...ts) { return std::make_tuple(static_cast<T *>(ts ? ts->DataPtr() : nullptr)...); };
     auto input_ptrs = extract_ptrs(inputs...);
 
     cudaDeviceProp prop;
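
// Minimal standalone sketch (not part of the diff) of the null-safe pointer extraction
// introduced above: the ternary inside the pack expansion lets LaunchKernel tolerate
// optional inputs, so a null shared_ptr yields a nullptr in the extracted tuple instead
// of being dereferenced. FakeTensor and ExtractPtrs are hypothetical stand-ins for the
// real Tensor type and the lambda in LaunchKernel.
#include <cstdio>
#include <memory>
#include <tuple>

struct FakeTensor {
    float data[4] = {};
    void *DataPtr() { return data; }
};

template <typename T, typename... Inputs>
auto ExtractPtrs(const Inputs &...ts) {
    // Same idiom as the lambda in LaunchKernel: null inputs map to nullptr.
    return std::make_tuple(static_cast<T *>(ts ? ts->DataPtr() : nullptr)...);
}

int main() {
    auto a = std::make_shared<FakeTensor>();
    std::shared_ptr<FakeTensor> b;  // intentionally null, e.g. an absent operand
    auto ptrs = ExtractPtrs<float>(a, b);
    std::printf("a valid: %d, b null: %d\n",
                std::get<0>(ptrs) != nullptr, std::get<1>(ptrs) == nullptr);
    return 0;
}
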
@@ -135,7 +136,6 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
     T *output_a_ptr = static_cast<T *>(output_a->DataPtr());
     T *output_b_ptr = static_cast<T *>(output_b->DataPtr());
     const T *grad_output_ptr = static_cast<const T *>(grad_output->DataPtr());
-
     LaunchKernel<BLOCK_SIZE, T>(
         [=](dim3 grid, dim3 block, size_t offset, auto... ptrs) {
             BinaryBackwardKernel<<<grid, block>>>(output_a_ptr, output_b_ptr, fun_a, fun_b, a_num_elements,
@@ -201,7 +201,6 @@ std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
 BinaryBackward(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr<Tensor> &a,
                const std::shared_ptr<Tensor> &b, const std::vector<int64_t> &a_dims, const std::vector<int64_t> &b_dims,
                FuncA fn_a, FuncB fn_b) {
-
     const auto a_num_elements = std::accumulate(a_dims.begin(), a_dims.end(), 1, std::multiplies<int64_t>());
     const auto b_num_elements = std::accumulate(b_dims.begin(), b_dims.end(), 1, std::multiplies<int64_t>());
 
@@ -212,14 +211,15 @@ BinaryBackward(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr
     if (b) {
         CHECK(b_num_elements == b->NumElements());
     }
-
     auto dtype = grad_output->Dtype();
-    auto device = a->GetDevice();
-    // Currently a and b should have the same data type
-    CHECK(dtype == b->Dtype());
-    auto grad_a = std::make_shared<Tensor>(a->Dims(), dtype, device);
-    auto grad_b = std::make_shared<Tensor>(b->Dims(), dtype, device);
+    auto device = grad_output->GetDevice();
 
+    // Currently a and b should have the same data type
+    if (a && b) {
+        CHECK(a->Dtype() == b->Dtype());
+    }
+    auto grad_a = std::make_shared<Tensor>(a_dims, dtype, device);
+    auto grad_b = std::make_shared<Tensor>(b_dims, dtype, device);
     switch (dtype) {
         case DataType::kFLOAT32:
             LaunchBackward<256, float>(fn_a, fn_b, grad_a, grad_b, a_num_elements, b_num_elements, grad_output, a, b);
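
// Minimal standalone sketch (not part of the diff) of the null-aware setup BinaryBackward
// now uses: dtype and device are read from grad_output, the dtype check only fires when
// both operands exist, and the gradient shapes come from a_dims/b_dims rather than
// a->Dims()/b->Dims(), so a null operand is never dereferenced. MiniTensor and MakeGrads
// are hypothetical stand-ins, not the project's real Tensor API.
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

enum class DataType { kFLOAT32 };

struct MiniTensor {
    std::vector<int64_t> dims;
    DataType dtype = DataType::kFLOAT32;
    int device = 0;
    DataType Dtype() const { return dtype; }
    int GetDevice() const { return device; }
};

std::pair<std::shared_ptr<MiniTensor>, std::shared_ptr<MiniTensor>>
MakeGrads(const std::shared_ptr<MiniTensor> &grad_output, const std::shared_ptr<MiniTensor> &a,
          const std::shared_ptr<MiniTensor> &b, const std::vector<int64_t> &a_dims,
          const std::vector<int64_t> &b_dims) {
    auto dtype = grad_output->Dtype();
    auto device = grad_output->GetDevice();  // taken from grad_output, not from a
    if (a && b) {
        assert(a->Dtype() == b->Dtype());  // dtype check only when both operands exist
    }
    auto grad_a = std::make_shared<MiniTensor>(MiniTensor{a_dims, dtype, device});
    auto grad_b = std::make_shared<MiniTensor>(MiniTensor{b_dims, dtype, device});
    return {grad_a, grad_b};
}

int main() {
    auto grad_output = std::make_shared<MiniTensor>(MiniTensor{{2, 3}});
    auto a = std::make_shared<MiniTensor>(MiniTensor{{2, 3}});
    std::shared_ptr<MiniTensor> b;  // absent operand: still safe to build both grads
    auto grads = MakeGrads(grad_output, a, b, {2, 3}, {3});
    return (grads.first && grads.second) ? 0 : 1;
}
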