@@ -187,7 +187,10 @@ __global__ void reluKernel(const ValueType* input, ValueType* output, std::size_
187187
188188__global__ void reluDerivativeKernel (const ValueType* input, ValueType* output, std::size_t count) {
189189 std::size_t idx = blockIdx .x * blockDim .x + threadIdx .x ;
190- if (idx < count) output[idx] = input[idx] > 0 .0f ? 1 .0f : 0 .0f ;
190+ if (idx < count) {
191+ ValueType derivative = (input[idx] > 0 .0f ) ? 1 .0f : 0 .0f ;
192+ output[idx] *= derivative; // FIX: Changed = to *=
193+ }
191194}
192195
193196void relu (const ValueType* input, ValueType* output, std::size_t count) {
@@ -219,7 +222,8 @@ __global__ void sigmoidDerivativeKernel(const ValueType* input, ValueType* outpu
219222 if (idx < count) {
220223 ValueType x = input[idx];
221224 ValueType s = 1 .0f / (1 .0f + expf (-x));
222- output[idx] = s * (1 .0f - s);
225+ ValueType derivative = s * (1 .0f - s);
226+ output[idx] *= derivative;
223227 }
224228}
225229
@@ -248,7 +252,8 @@ __global__ void tanhDerivativeKernel(const ValueType* input, ValueType* output,
248252 std::size_t idx = blockIdx .x * blockDim .x + threadIdx .x ;
249253 if (idx < count) {
250254 ValueType t = tanhf (input[idx]);
251- output[idx] = 1 .0f - t * t;
255+ ValueType derivative = 1 .0f - t * t;
256+ output[idx] *= derivative;
252257 }
253258}
254259
@@ -275,7 +280,10 @@ __global__ void leakyReluKernel(const ValueType* input, ValueType* output, std::
275280
276281__global__ void leakyReluDerivativeKernel (const ValueType* input, ValueType* output, std::size_t count, ValueType alpha) {
277282 std::size_t idx = blockIdx .x * blockDim .x + threadIdx .x ;
278- if (idx < count) output[idx] = (input[idx] > 0 .0f ) ? 1 .0f : alpha;
283+ if (idx < count) {
284+ ValueType derivative = (input[idx] > 0 .0f ) ? 1 .0f : alpha;
285+ output[idx] *= derivative; // FIX: Changed = to *=
286+ }
279287}
280288
281289void leaky_relu (const ValueType* input, ValueType* output, std::size_t count, ValueType alpha) {
@@ -384,7 +392,7 @@ __global__ void outerKernel(const ValueType* a, const ValueType* b, ValueType* r
384392 if (idx < total) {
385393 size_t i = idx / n;
386394 size_t j = idx % n;
387- result[i * n + j] = a[i] * b[j];
395+ result[i * n + j] + = a[i] * b[j];
388396 }
389397}
390398
0 commit comments