-
Notifications
You must be signed in to change notification settings - Fork 444
Expand file tree
/
Copy pathkernels.cu
More file actions
1857 lines (1605 loc) · 61.8 KB
/
kernels.cu
File metadata and controls
1857 lines (1605 loc) · 61.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#ifndef PUFFERLIB_KERNELS_CU
#define PUFFERLIB_KERNELS_CU
/* Kernels must launch on the current torch stream to be traced by cudagraphs.
* Launch functions take cudaStream_t as parameter - callers (modules.cu) should
* pass at::cuda::getCurrentCUDAStream() when using with torch.
*/
#include <cuda_runtime.h>
#include "ops.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <curand_kernel.h>
#include <c10/util/BFloat16.h>
#include <cstdio>
#include <cstdint>
#define SEQ_SIZE 256
#define BLOCK_SIZE 256
// Number of BLOCK_SIZE-thread blocks needed so that each of N work items gets one thread
// (ceiling division). Returns 0 for N == 0.
inline int grid_size(int N) {
    const int rounded_up = N + (BLOCK_SIZE - 1);
    return rounded_up / BLOCK_SIZE;
}
// Same ceiling division, for launches that use SEQ_SIZE threads per block.
inline int seq_size(int N) {
    const int rounded_up = N + (SEQ_SIZE - 1);
    return rounded_up / SEQ_SIZE;
}
// If you can get this to work, go ahead. I tried.
// NVCC won't parse templated types in kernel launches
/*
template <template <class> class KernelFn, typename... Args>
void dispatch_and_launch(const at::Tensor& example_tensor, Args... args) {
const int64_t N = example_tensor.numel();
const int64_t block = LAUNCH_BLOCK_SIZE;
const int64_t grid = (N + block - 1) / block;
auto stream = at::cuda::getCurrentCUDAStream();
at::cuda::CUDAGuard device_guard(example_tensor.device());
at::ScalarType dtype = example_tensor.scalar_type();
if (dtype == at::ScalarType::Float) {
KernelFn<float><<<grid, block, 0, stream>>>(args..., N);
} else if (dtype == at::ScalarType::Half) {
KernelFn<__half><<<grid, block, 0, stream>>>(args..., N);
} else if (dtype == at::ScalarType::BFloat16) {
KernelFn<__nv_bfloat16><<<grid, block, 0, stream>>>(args..., N);
} else {
AT_ERROR("Unsupported dtype: ", dtype);
}
}
*/
// RMSNorm forward. One thread per row; a row is the H contiguous elements at x[row * H ...].
// Expected launch: 1-D grid covering B * T_total threads (see launch_rmsnorm_forward).
//   inv_rms        = 1 / sqrt(mean(x_row^2) + eps)
//   out[row, h]    = weight[h] * x[row, h] * inv_rms
//   inv_norm_buf[row] = inv_rms   (cached for rmsnorm_backward_kernel)
// All arithmetic is done in float, so __half / bf16 inputs are only rounded once, on the store.
// `eps` stays double for interface compatibility but is applied in float.
template<typename T>
__global__ void rmsnorm_forward_kernel(
    T* __restrict__ out,
    float* __restrict__ inv_norm_buf,
    const T* __restrict__ x,
    const T* __restrict__ weight,
    double eps,
    int T_total,
    int H,
    int B
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * T_total) return;
    // Row idx = b * T_total + t starts at element idx * H; use 64-bit so large tensors don't overflow.
    const long long base = static_cast<long long>(idx) * H;
    float sum_sq = 0.0f;
    for (int h = 0; h < H; h++) {
        float x_val = float(x[base + h]);
        sum_sq += x_val * x_val;
    }
    float inv_rms = 1.0f / sqrtf(sum_sq / float(H) + float(eps));
    inv_norm_buf[idx] = inv_rms;
    for (int h = 0; h < H; h++) {
        // Convert both operands first: multiplying in T would round the product to half/bf16.
        float y = float(weight[h]) * float(x[base + h]) * inv_rms;
        out[base + h] = T(y);
    }
}
// RMSNorm backward. One thread per element (row, h).
// Expected launch: 1-D grid covering B * T_total * H threads (see launch_rmsnorm_backward).
// With r = inv_norm_buf[row] (cached by the forward) and out_j = w_j * x_j * r:
//   grad_weight[row, h] = g_h * x_h * r          (per element; presumably summed over rows by
//                                                  the caller -- confirm in modules.cu)
//   grad_x[row, h]      = w_h * g_h * r - x_h * r^3 / H * sum_j (w_j * g_j * x_j)
// The row reduction must read the *row* (offset row * H), not the feature index.
// All math is in float; each output is written once. `eps` is unused (r is cached).
template<typename T>
__global__ void rmsnorm_backward_kernel(
    T* __restrict__ grad_x,
    T* __restrict__ grad_weight,
    const T* __restrict__ grad_out,
    const float* __restrict__ inv_norm_buf,
    const T* __restrict__ x_buf,
    const T* __restrict__ weight,
    double eps,
    int T_total,
    int H,
    int B
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= T_total*H*B) return;
    int h = idx % H;          // feature index within the row
    int norm_idx = idx / H;   // row index (b * T_total + t)
    int row = norm_idx * H;   // first element of this row
    float inv_rms = inv_norm_buf[norm_idx];
    float inv_rms_3 = inv_rms * inv_rms * inv_rms;
    float g = float(grad_out[idx]);
    float x = float(x_buf[idx]);
    grad_weight[idx] = T(g * x * inv_rms);
    // sum_j w_j * g_j * x_j over this element's row.
    float wg_x = 0.0f;
    for (int j = 0; j < H; j++) {
        wg_x += float(weight[j]) * float(grad_out[row + j]) * float(x_buf[row + j]);
    }
    float dx = float(weight[h]) * g * inv_rms - x * wg_x * inv_rms_3 / float(H);
    grad_x[idx] = T(dx);
}
/*
template<typename T>
__global__ void rmsnorm_backward_kernel(
T* grad_x,
T* grad_weight,
const T* grad_out,
const float* inv_norm_buf,
const T* x,
const T* weight,
double eps,
int T_total,
int H,
int B
) {
int total_elements = B * T_total * H;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= total_elements) return;
int h = idx % H;
int vec_idx = idx / H; // index of the vector (b,t)
int offset = vec_idx * H;
float inv_rms = inv_norm_buf[vec_idx];
float inv_rms3 = inv_rms * inv_rms * inv_rms;
// ∂L/∂γ_h += grad_out * (x / rms)
float gw = grad_out[idx] * (float)x[idx] * inv_rms;
atomicAdd((float*)&grad_weight[h], gw);
// Compute reduction: sum_h weight[h] * grad_out[h] * x[h]
float sum = 0.0f;
for (int i = 0; i < H; ++i) {
sum += (float)weight[i] * (float)grad_out[offset + i] * (float)x[offset + i];
}
float reduction = sum * inv_rms; // = σ γ g hat_x
float dx = (float)weight[h] * (float)grad_out[idx] * inv_rms
- (float)x[idx] * reduction * inv_rms3 / H;
grad_x[idx] = T(dx);
}
*/
// Host launcher for rmsnorm_forward_kernel: one thread per (b, t) row, BLOCK_SIZE threads per
// block, enqueued on `stream` (pass the current torch stream so CUDA graphs can capture it).
// Launch errors are reported to stderr, not returned. Empty inputs are a no-op.
template<typename T>
void launch_rmsnorm_forward(
    T* __restrict__ out,
    float* __restrict__ inv_norm_buf,
    const T* __restrict__ x,
    const T* __restrict__ weight,
    double eps,
    int T_total,
    int H,
    int B,
    cudaStream_t stream
) {
    int total = B * T_total;
    // Nothing to normalize; a zero-block grid would be an invalid launch configuration.
    if (total <= 0) {
        return;
    }
    int grid = grid_size(total);
    rmsnorm_forward_kernel<T><<<grid, BLOCK_SIZE, 0, stream>>>(
        out,
        inv_norm_buf,
        x,
        weight,
        eps,
        T_total,
        H,
        B
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error in forward: %s\n", cudaGetErrorString(err));
    }
}
// Host launcher for rmsnorm_backward_kernel: one thread per element (B * T_total * H),
// BLOCK_SIZE threads per block, enqueued on `stream`. Launch errors are reported to stderr,
// not returned. Empty inputs are a no-op.
template<typename T>
void launch_rmsnorm_backward(
    T* __restrict__ grad_x,
    T* __restrict__ grad_weight,
    const T* __restrict__ grad_out,
    const float* __restrict__ inv_norm_buf,
    const T* __restrict__ x_buf,
    const T* __restrict__ weight,
    double eps,
    int T_total,
    int H,
    int B,
    cudaStream_t stream
) {
    // The backward is fully parallel
    // since the inv norm is cached
    int total = B * T_total * H;
    // Nothing to do; a zero-block grid would be an invalid launch configuration.
    if (total <= 0) {
        return;
    }
    int grid = grid_size(total);
    rmsnorm_backward_kernel<T><<<grid, BLOCK_SIZE, 0, stream>>>(
        grad_x,
        grad_weight,
        grad_out,
        inv_norm_buf,
        x_buf,
        weight,
        eps,
        T_total,
        H,
        B
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error in backward: %s\n", cudaGetErrorString(err));
    }
}
// Fused single-step minGRU for inference: chunk + mingru_gate + sigmoid(proj) * out.
//   combined   : (B, 1, 3*H), last dim laid out as [hidden | gate | proj]
//   state_in   : (B, 1, H)
//   next_state : (B, 1, H) <- lerp(state, tilde_relu_fwd(hidden), sigmoid(gate))  (recurrent state)
//   out        : (B, 1, H) <- sigmoid(proj) * next_state                        (final output)
// One thread per (b, h); expects a 1-D grid covering B * H threads.
template<typename T>
__global__ void mingru_gate_inference_kernel(
    T* out,
    T* next_state,
    const T* combined, // (B, 1, 3*H) = [hidden, gate, proj]
    const T* state_in, // (B, 1, H)
    int H,
    int B
) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < B * H) {
        const int batch = tid / H;
        const int feat = tid - batch * H;
        // This batch's packed [hidden | gate | proj] row.
        const T* row = combined + batch * 3 * H;
        const float hidden_raw = float(row[feat]);
        const float gate_raw = float(row[H + feat]);
        const float proj_raw = float(row[2 * H + feat]);
        const float prev = float(state_in[tid]);
        const float blended = lerp(prev, tilde_relu_fwd(hidden_raw), sigmoid(gate_raw));
        next_state[tid] = T(blended);
        out[tid] = T(sigmoid(proj_raw) * blended);
    }
}
// Host launcher for mingru_gate_inference_kernel: one thread per (b, h), BLOCK_SIZE threads per
// block, enqueued on `stream`. Launch errors are reported to stderr, not returned.
// Empty inputs are a no-op.
template<typename T>
void launch_mingru_gate_inference(
    T* out,
    T* next_state,
    const T* combined,
    const T* state_in,
    int H,
    int B,
    cudaStream_t stream
) {
    int N = B * H;
    // Nothing to do; a zero-block grid would be an invalid launch configuration.
    if (N <= 0) {
        return;
    }
    int grid = grid_size(N);
    mingru_gate_inference_kernel<T><<<grid, BLOCK_SIZE, 0, stream>>>(
        out,
        next_state,
        combined,
        state_in,
        H,
        B
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error: %s\n", cudaGetErrorString(err));
    }
}
// Elementwise log-space minGRU inputs; one thread per element, 1-D grid covering N threads.
//   log_coeffs = -softplus(gate)
//   log_values = -softplus(-gate) + log_tilde, where
//     log_tilde = log(relu(hidden) + 0.5)   if hidden >= 0
//               = -softplus(-hidden)        otherwise
template<typename T>
__global__ void log_coeffs_and_values_kernel(
    T* log_coeffs,
    T* log_values,
    const T* gate,
    const T* hidden,
    int N
) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        const float g = float(gate[i]);
        const float hv = float(hidden[i]);
        log_coeffs[i] = -softplus_fwd(g);
        const float log_tilde = (hv >= 0.0f) ? logf(relu(hv) + 0.5f) : -softplus_fwd(-hv);
        log_values[i] = -softplus_fwd(-g) + log_tilde;
    }
}
// Backward of log_coeffs_and_values_kernel; one thread per element, 1-D grid covering N threads.
// Gradients w.r.t. gate combine both uses of gate:
//   log_coeffs = -softplus(gate)   and   log_z = -softplus(-gate)
// Gradient w.r.t. hidden follows the branch taken in the forward:
//   hidden >= 0 : d/dh log(relu(h) + 0.5), routed through relu_backward
//   hidden <  0 : d/dh -softplus(-h)
// NOTE(review): the sign conventions assume softplus_bwd(grad, x) == grad * sigmoid(x) and
// relu_backward(x, grad) masks grad by x > 0 (both defined in ops.cuh) -- confirm.
// The forward value log(tilde_h) is not needed here, so it is not recomputed.
template<typename T>
__global__ void log_coeffs_and_values_backward_kernel(
    T* grad_gate,
    T* grad_hidden,
    const T* grad_log_coeffs,
    const T* grad_log_values,
    const T* gate,
    const T* hidden,
    int N
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N) return;
    float g = float(gate[idx]);
    float h = float(hidden[idx]);
    float grad_lc = float(grad_log_coeffs[idx]);
    float grad_lv = float(grad_log_values[idx]);
    float grad_g_from_lc = -softplus_bwd(grad_lc, g);
    float grad_g_from_lz = -softplus_bwd(-grad_lv, -g);
    grad_gate[idx] = T(grad_g_from_lc + grad_g_from_lz);
    float grad_h_from_lt;
    if (h >= 0.0f) {
        float inner_grad = 1.0f / (relu(h) + 0.5f);
        grad_h_from_lt = relu_backward(h, inner_grad * grad_lv);
    } else {
        grad_h_from_lt = -softplus_bwd(-grad_lv, -h);
    }
    grad_hidden[idx] = T(grad_h_from_lt);
}
// Host launcher for log_coeffs_and_values_kernel: one thread per element, BLOCK_SIZE threads per
// block, enqueued on `stream`. Launch errors are reported to stderr, not returned.
// Empty inputs are a no-op.
template<typename T>
void launch_log_coeffs_and_values(
    T* log_coeffs,
    T* log_values,
    const T* gate,
    const T* hidden,
    int N,
    cudaStream_t stream
) {
    // Nothing to do; a zero-block grid would be an invalid launch configuration.
    if (N <= 0) {
        return;
    }
    int grid = grid_size(N);
    log_coeffs_and_values_kernel<T><<<grid, BLOCK_SIZE, 0, stream>>>(
        log_coeffs,
        log_values,
        gate,
        hidden,
        N
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error: %s\n", cudaGetErrorString(err));
    }
}
// Host launcher for log_coeffs_and_values_backward_kernel: one thread per element, BLOCK_SIZE
// threads per block, enqueued on `stream`. Launch errors are reported to stderr, not returned.
// Empty inputs are a no-op.
template<typename T>
void launch_log_coeffs_and_values_backward(
    T* grad_gate,
    T* grad_hidden,
    const T* grad_log_coeffs,
    const T* grad_log_values,
    const T* gate,
    const T* hidden,
    int N,
    cudaStream_t stream
) {
    // Nothing to do; a zero-block grid would be an invalid launch configuration.
    if (N <= 0) {
        return;
    }
    int grid = grid_size(N);
    log_coeffs_and_values_backward_kernel<T><<<grid, BLOCK_SIZE, 0, stream>>>(
        grad_gate,
        grad_hidden,
        grad_log_coeffs,
        grad_log_values,
        gate,
        hidden,
        N
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error: %s\n", cudaGetErrorString(err));
    }
}
// Double-precision logaddexp(acc, x): one step of a running logcumsumexp.
// acc == -INFINITY marks an empty accumulator, in which case x is returned unchanged.
// Uses the double overloads of log1p/exp; the float versions (log1pf/expf) would silently
// throw away the extra precision this routine exists to provide.
__device__ __forceinline__ double logcumsumexp_forward(double x, double acc) {
    if (acc == -INFINITY) {
        return x;
    }
    double min_val = fmin(acc, x);
    double max_val = fmax(acc, x);
    return max_val + log1p(exp(min_val - max_val));
}
// Double-precision reverse step of a running logcumsumexp.
// Updates the carried gradient: *acc <- grad + *acc * exp(s - *s_nxt).
// Stores s into *s_nxt for the next (earlier) step.
// Returns the gradient w.r.t. the step's input x: *acc * exp(x - s).
__device__ __forceinline__ double logcumsumexp_backward(double x, double* acc, double grad, double s, double* s_nxt) {
    const double decay = exp(s - *s_nxt);
    *acc = grad + *acc * decay;
    *s_nxt = s;
    const double share = exp(x - s);
    return *acc * share;
}
// float32, branch-free logaddexp(acc, x) = max + log1p(exp(min - max)), using the fast
// __expf intrinsic. Unlike logcumsumexp_forward there is no -INFINITY special case:
// if acc and x are both -INFINITY, (min - max) is NaN and so is the result.
__device__ __forceinline__ float logcumsumexp_forward_opt(float x, float acc) {
    const float hi = fmaxf(acc, x);
    const float lo = fminf(acc, x);
    return hi + log1pf(__expf(lo - hi));
}
// float32 counterpart of logcumsumexp_backward, using the fast __expf intrinsic and fmaf.
// *acc <- *acc * exp(s - *s_nxt) + grad; *s_nxt <- s; returns *acc * exp(x - s).
__device__ __forceinline__ float logcumsumexp_backward_opt(float x, float* acc, float grad, float s, float* s_nxt) {
    const float decay = __expf(s - *s_nxt);
    *acc = fmaf(*acc, decay, grad);
    *s_nxt = s;
    const float share = __expf(x - s);
    return *acc * share;
}
// Fully fused forward: chunk + log_coeffs_and_values + scan + sigmoid(proj)*out
// Takes combined (B, T, 3*H) = [hidden, gate, proj] and outputs gated result
//
// One thread per (b, h) walks the sequence serially. Expected launch: 1-D grid covering B * H
// threads (see launch_fused_scan_forward). The scan runs in log space:
//   t = 0     : a_star = 0, s = log(state)         (a zero state gives s = -INFINITY)
//   t = 1..T  : a_star += log_coeff_t, z = log_value_t - a_star, s = logaddexp(s, z)
//   result_t  = exp(a_star + s);  out[t-1] = sigmoid(proj_t) * result_t
// a_star, s and log_value are saved for every t (including t = 0) for the backward pass.
// next_state receives the raw result at the last step. For an empty sequence (T_seq == 0),
// that is the incoming state.
template<typename T>
__global__ void fused_scan_forward_kernel(
    T* __restrict__ out,                 // (B, T, H) - sigmoid(proj) * scan_result
    T* __restrict__ next_state,          // (B, 1, H) - raw scan_result at T (for recurrence)
    float* __restrict__ a_star_buf,      // (B, T+1, H) - for backward
    float* __restrict__ s_buf,           // (B, T+1, H) - for backward
    float* __restrict__ log_values_buf,  // (B, T+1, H) - cached log_values for backward
    const T* __restrict__ combined,      // (B, T, 3*H) = [hidden(H), gate(H), proj(H)]
    const T* __restrict__ state,         // (B, 1, H)
    int T_seq,                           // sequence length (T)
    int H,
    int B
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * H) return;
    int b = idx / H;
    int h = idx % H;
    int T_out = T_seq + 1;
    int buf_base = b * T_out * H + h;   // base for a_star/s/log_values buffers (T+1 timesteps)
    int out_base = b * T_seq * H + h;   // base for output (T timesteps)
    int state_idx = b * H + h;          // state is (B, 1, H) -> flatten to (B, H)
    float state_val = float(state[state_idx]);
    // t = 0: log_value = log(state), coeff = 0, so z = s = log(state).
    float a_star = 0.0f;
    float s = logf(state_val);
    log_values_buf[buf_base] = s;
    a_star_buf[buf_base] = a_star;
    s_buf[buf_base] = s;
    // Start from the incoming state so an empty sequence passes it through to next_state;
    // overwritten on every timestep otherwise.
    float scan_result = state_val;
    for (int t = 1; t < T_out; t++) {
        int buf_curr = buf_base + t * H;
        int combined_base = b * T_seq * 3 * H + (t - 1) * 3 * H;
        float hidden_val = float(combined[combined_base + h]);
        float gate_val = float(combined[combined_base + H + h]);
        float proj_val = float(combined[combined_base + 2 * H + h]);
        float log_coeff_val, log_value_val;
        log_coeffs_and_values_fwd(gate_val, hidden_val, &log_coeff_val, &log_value_val);
        // Cache log_value for backward (avoid recomputation)
        log_values_buf[buf_curr] = log_value_val;
        // a_star[t] = sum_{i=0}^t log_coeffs[i]
        a_star += log_coeff_val;
        float z = log_value_val - a_star;
        // s = logaddexp(s, z); -INFINITY (zero state) means "empty so far".
        if (s == -INFINITY) {
            s = z;
        } else {
            float min_val = fminf(s, z);
            float max_val = fmaxf(s, z);
            s = max_val + log1pf(expf(min_val - max_val));
        }
        scan_result = expf(a_star + s);
        // sigmoid(proj) * out
        int out_curr = out_base + (t - 1) * H;
        float proj_sigmoid = sigmoid(proj_val);
        out[out_curr] = T(proj_sigmoid * scan_result);
        a_star_buf[buf_curr] = a_star;
        s_buf[buf_curr] = s;
    }
    // Raw scan_result (no proj gate) is the recurrent state.
    next_state[state_idx] = T(scan_result);
}
// Fully fused backward: chains through sigmoid(proj)*out and log_coeffs_and_values
// Takes combined (B, T, 3*H), outputs grad_combined (B, T, 3*H) = [grad_hidden, grad_gate, grad_proj]
//
// One thread per (b, h) walks t = T_seq .. 1 in reverse, using the a_star / s / log_values
// buffers written by fused_scan_forward_kernel; each is (B, T+1, H) with t = 0 holding the state.
// Expected launch: 1-D grid covering B * H threads (see launch_fused_scan_backward).
//
// Forward recap: h_t = exp(a_star_t + s_t), z_t = log_value_t - a_star_t,
// s_t = logaddexp(s_{t-1}, z_t).
// The backward carries three running quantities:
//   acc            : total d(loss)/d(s_t), carried backward through the logaddexp chain.
//   carry_grad_a   : reverse cumulative sum of d(loss)/d(a_star), i.e. d(loss)/d(log_coeff_t).
//   grad_state_acc : d(loss)/d(state).
//     z_0 = log(state) and a_star_0 = 0, so h_t = exp(a_star_t) * (state + sum_{i=1..t} exp(z_i)).
//     Hence d(loss)/d(state) = sum_t d(loss)/d(h_t) * exp(a_star_t).
//     This form never divides by `state`, so a zero initial state yields a finite gradient
//     instead of 0/0 = NaN.
template<typename T>
__global__ void fused_scan_backward_kernel(
    T* __restrict__ grad_combined,           // (B, T, 3*H) = [grad_hidden, grad_gate, grad_proj]
    T* __restrict__ grad_state,              // (B, 1, H)
    const T* __restrict__ grad_out,          // (B, T, H) - gradient of sigmoid(proj)*scan_result
    const T* __restrict__ grad_next_state,   // (B, 1, H) - gradient of raw scan_result at T
    const T* __restrict__ combined,          // (B, T, 3*H) = [hidden, gate, proj]
    const T* __restrict__ state,             // (B, 1, H)
    const float* __restrict__ a_star_buf,    // (B, T+1, H)
    const float* __restrict__ s_buf,         // (B, T+1, H)
    const float* __restrict__ log_values_buf,  // (B, T+1, H) - cached from forward
    int T_seq,                               // sequence length (T)
    int H,
    int B
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * H) return;
    int b = idx / H;
    int h = idx % H;
    int T_out = T_seq + 1;
    int buf_base = b * T_out * H + h;   // base for a_star/s/log_values buffers (T+1 timesteps)
    int out_base = b * T_seq * H + h;   // base for grad_out (T timesteps)
    int state_idx = b * H + h;          // state is (B, 1, H) -> flatten to (B, H)
    float acc = 0.0f;
    float s_val_next = 0.0f;
    float carry_grad_a = 0.0f;
    // t = 0 term of d(loss)/d(state): exp(a_star_0) = 1, and h_0 only receives a gradient
    // (from grad_next_state) when the sequence is empty.
    float grad_state_acc = (T_seq == 0) ? float(grad_next_state[state_idx]) : 0.0f;
    for (int t = T_seq; t >= 1; --t) {
        int buf_curr = buf_base + t * H;
        int hidden_adr = (b * T_seq + (t - 1)) * 3 * H + h;
        int gate_adr = hidden_adr + H;
        int proj_adr = hidden_adr + 2 * H;
        float a_star = a_star_buf[buf_curr];
        float s = s_buf[buf_curr];
        float scan_result = expf(a_star + s);  // reconstruct scan result
        // Read cached log_value from forward pass (no recomputation needed)
        float z = log_values_buf[buf_curr] - a_star;
        float hidden_val = float(combined[hidden_adr]);
        float gate_val = float(combined[gate_adr]);
        float proj_val = float(combined[proj_adr]);
        // d(loss)/d(h_t): through out = sigmoid(proj) * h_t, plus grad_next_state at t = T.
        float grad_gated_out = float(grad_out[out_base + (t - 1) * H]);
        float grad_scan_result = (t == T_seq) ? float(grad_next_state[state_idx]) : 0.0f;
        float proj_sigmoid = sigmoid(proj_val);
        grad_scan_result += grad_gated_out * proj_sigmoid;
        // sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
        float grad_proj = grad_gated_out * scan_result * proj_sigmoid * (1.0f - proj_sigmoid);
        grad_state_acc += grad_scan_result * expf(a_star);
        // Now chain grad_scan_result through the scan backward
        float grad_log_h = grad_scan_result * scan_result;
        if (t == T_seq) {
            acc = grad_log_h;
        } else {
            acc = grad_log_h + acc * expf(s - s_val_next);
        }
        float grad_z = acc * expf(z - s);
        s_val_next = s;
        float grad_a = grad_log_h + carry_grad_a - grad_z;
        carry_grad_a = grad_a;
        // Chain through log_coeffs_and_values backward to get grad_gate, grad_hidden
        float grad_g, grad_h;
        log_coeffs_and_values_bwd(grad_a, grad_z, gate_val, hidden_val, &grad_g, &grad_h);
        // Write to grad_combined: [grad_hidden, grad_gate, grad_proj]
        grad_combined[hidden_adr] = T(grad_h);
        grad_combined[gate_adr] = T(grad_g);
        grad_combined[proj_adr] = T(grad_proj);
    }
    grad_state[state_idx] = T(grad_state_acc);
}
/*
template<typename T>
__global__ void fused_scan_backward_kernel(
T* __restrict__ grad_log_coeffs,
T* __restrict__ grad_log_values,
const T* __restrict__ grad_out,
const T* __restrict__ out_buf,
const double* __restrict__ a_star_buf,
const double* __restrict__ s_buf,
const T* __restrict__ log_values,
int T_total,
int H,
int B
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= B * H) return;
int b = idx / H;
int h = idx % H;
int base = b * T_total * H + h;
double carry_grad_a = 0.0;
double carry_grad_s = 0.0;
for (int t = T_total - 1; t >= 0; --t) {
int curr = base + t * H;
double a_star = a_star_buf[curr];
double s = s_buf[curr];
double z = double(log_values[curr]) - a_star;
double grad_log_h = double(grad_out[curr]) * double(out_buf[curr]); // out_buf[t] = exp(log_h[t])
double grad_s = grad_log_h + carry_grad_s;
double s_prev = -INFINITY;
if (t > 0) {
s_prev = s_buf[base + (t - 1) * H];
}
double max_val = fmax(s_prev, z);
double exp_prev = 0.0;
if (s_prev != -INFINITY) {
exp_prev = exp(s_prev - max_val);
}
double exp_z = 0.0;
if (z != -INFINITY) {
exp_z = exp(z - max_val);
}
double denom = exp_prev + exp_z;
double frac_prev = 0.0;
double frac_z = 0.0;
if (denom != 0.0) {
frac_prev = exp_prev / denom;
frac_z = exp_z / denom;
}
// grad_z = (grad_log_h + carry_grad_s) * exp(z - max_val) / (exp(s_prev - max_val) + exp(z - max_val))
// grad_z = (grad_log_h + exp(s - exp_nxt)) * exp(z - s)
double d_Z = frac_z * grad_s;
double d_A = grad_log_h + carry_grad_a - d_Z;
grad_log_values[curr] = T(d_Z);
grad_log_coeffs[curr] = T(d_A);
carry_grad_a = d_A;
carry_grad_s = frac_prev * grad_s;
}
}
*/
/*
template<typename T>
__global__ void fused_scan_backward_kernel(
T* __restrict__ grad_log_coeffs,
T* __restrict__ grad_log_values,
const T* __restrict__ grad_out,
const T* __restrict__ log_coeffs,
const T* __restrict__ log_values,
const T* __restrict__ out,
const double* __restrict__ a_star_buf,
const double* __restrict__ s_buf,
int T_total,
int H,
int B
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= B * H) return;
int b = idx / H;
int h = idx % H;
int base = b * T_total * H + h;
double grad_a_star[1025] = {0}; // Assuming T_total <= 1024
double W = 0.0; // Accumulates sum_{i=t}^{T-1} [grad_log_h[i] * exp(-s[i])]
for (int t = T_total - 1; t >= 0; t--) {
int curr = base + t * H;
double a_star = a_star_buf[curr];
double s_val = s_buf[curr];
double z_val = double(log_values[curr]) - a_star;
// Compute dL/d(log_h[t]) = dL/d(out[t]) * d(out[t])/d(log_h[t])
double grad_log_h = double(grad_out[curr]) * double(out[curr]);
// Update W: W[t] = grad_log_h[t] * exp(-s_val) + W[t+1]
W = grad_log_h * exp(-s_val) + W;
// Compute dL/d(z[t]) = exp(z_val) * W[t]
double grad_z = exp(z_val) * W;
// dL/d(log_values[t]) = dL/d(z[t]) * dz[t]/d(log_values[t]) = grad_z
grad_log_values[curr] = T(grad_z);
// dL/da_star[t] = dL/d(log_h[t]) - dL/d(z[t]) (due to chain rule)
grad_a_star[t] = grad_log_h - grad_z;
}
// Compute dL/d(log_coeffs) via cumulative sum of dL/da_star
double accum = 0.0;
for (int t = T_total - 1; t >= 0; t--) {
accum += grad_a_star[t];
grad_log_coeffs[base + t * H] = T(accum);
}
}
*/
/*
template<typename T>
__global__ void fused_scan_backward_kernel(
T* __restrict__ grad_log_coeffs,
T* __restrict__ grad_log_values,
const T* __restrict__ grad_out,
const T* __restrict__ log_coeffs,
const T* __restrict__ log_values,
const T* __restrict__ out,
const float* __restrict__ a_star_buf,
const float* __restrict__ s_buf,
int T_total,
int H,
int B
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= B * H) return;
int b = idx / H;
int h = idx % H;
int base = b * T_total * H + h;
float grad_a_star[1025] = {0};
float G = 0.0f; // G[t] = sum_{i=t}^{T-1} grad_s[i]
for (int t = T_total - 1; t >= 0; t--) {
int curr = base + t * H;
float a_star = a_star_buf[curr];
float s_val = s_buf[curr];
float z = float(log_values[curr]) - a_star;
// grad_log_h[t] = grad_out[t] * out[t]
float grad_log_h = float(grad_out[curr]) * float(out[curr]);
// G = sum of grad_s from t to end (grad_s[t] = grad_log_h[t])
G += grad_log_h;
// grad_z[t] = exp(z - s_val) * G
float prob = expf(z - s_val);
float grad_z = prob * G;
// grad_log_values[t] = grad_z
grad_log_values[curr] = T(grad_z);
// grad_a_star[t] gets:
// - +grad_log_h (from log_h = a_star + s)
// - -grad_z (from z = log_values - a_star)
grad_a_star[t] = grad_log_h - grad_z;
}
// grad_log_coeffs[t] = sum_{i=t}^{T-1} grad_a_star[i]
float accum = 0.0f;
for (int t = T_total - 1; t >= 0; t--) {
accum += grad_a_star[t];
grad_log_coeffs[base + t * H] = T(accum);
}
}
template<typename T>
__global__ void fused_scan_backward_kernel(
T* __restrict__ grad_log_coeffs,
T* __restrict__ grad_log_values,
const T* __restrict__ grad_out,
const T* __restrict__ log_coeffs,
const T* __restrict__ log_values,
const T* __restrict__ out,
const float* __restrict__ a_star_buf,
const float* __restrict__ s_buf,
int T_total,
int H,
int B
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= B * H) return;
int b = idx / H;
int h = idx % H;
int base = b * T_total * H + h;
// Recompute z[t] = log_values[t] - a_star[t]
float z[1025];
for (int t = 0; t < T_total; t++) {
int curr = base + t * H;
z[t] = float(log_values[curr]) - a_star_buf[curr];
}
// g_log_h[t] = grad_out[t] * out[t]
float g_log_h[1025];
for (int t = 0; t < T_total; t++) {
int curr = base + t * H;
g_log_h[t] = float(grad_out[curr]) * float(out[curr]);
}
// Step: Online logcumsumexp backward for g_z
float g_z[1025] = {0};
g_z[T_total - 1] = g_log_h[T_total - 1];
for (int t = T_total - 2; t >= 0; t--) {
float exp_term = expf(z[t] - s_buf[base + (t + 1) * H]);
g_z[t] = g_log_h[t] + g_z[t + 1] * exp_term;
}
// grad_log_values[t] = g_z[t]
for (int t = 0; t < T_total; t++) {
int curr = base + t * H;
grad_log_values[curr] = T(g_z[t]);
}
// g_a_star[t] = g_log_h[t] - g_z[t]
float g_a_star[1025] = {0};
for (int t = 0; t < T_total; t++) {
g_a_star[t] = g_log_h[t] - g_z[t];
}
// grad_log_coeffs[t] = reverse cumsum of g_a_star
float accum = 0.0f;
for (int t = T_total - 1; t >= 0; t--) {
accum += g_a_star[t];
grad_log_coeffs[base + t * H] = T(accum);
}
}
*/
// Fully fused forward launch: takes combined (B, T, 3*H) = [hidden, gate, proj]
// One thread per (b, h), SEQ_SIZE threads per block, enqueued on `stream` (pass the current torch
// stream so CUDA graphs can capture it). Launch errors are reported to stderr, not returned.
// Empty inputs (B * H == 0) are a no-op.
template<typename T>
void launch_fused_scan_forward(
    T* out,
    T* next_state,
    float* a_star,
    float* s_vals,
    float* log_values_buf,  // (B, T+1, H) - cached for backward
    const T* combined,      // (B, T, 3*H) = [hidden, gate, proj]
    const T* state,
    int T_seq,
    int H,
    int B,
    cudaStream_t stream
) {
    int total = B * H;
    // Nothing to do; a zero-block grid would be an invalid launch configuration.
    if (total <= 0) {
        return;
    }
    int grid = seq_size(total);
    fused_scan_forward_kernel<T><<<grid, SEQ_SIZE, 0, stream>>>(
        out,
        next_state,
        a_star,
        s_vals,
        log_values_buf,
        combined,
        state,
        T_seq,
        H,
        B
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error in forward: %s\n", cudaGetErrorString(err));
    }
}
// Fully fused backward launch: outputs grad_combined (B, T, 3*H) = [grad_hidden, grad_gate, grad_proj]
// One thread per (b, h), SEQ_SIZE threads per block, enqueued on `stream`. Launch errors are
// reported to stderr, not returned. Empty inputs (B * H == 0) are a no-op.
template<typename T>
void launch_fused_scan_backward(
    T* grad_combined,              // (B, T, 3*H) = [grad_hidden, grad_gate, grad_proj]
    T* grad_state,
    const T* grad_out,
    const T* grad_next_state,
    const T* combined,             // (B, T, 3*H) = [hidden, gate, proj]
    const T* state,
    const float* a_star_buf,
    const float* s_buf,
    const float* log_values_buf,   // (B, T+1, H) - cached from forward
    int T_seq,
    int H,
    int B,
    cudaStream_t stream
) {
    int total = B * H;
    // Nothing to do; a zero-block grid would be an invalid launch configuration.
    if (total <= 0) {
        return;
    }
    int grid = seq_size(total);
    fused_scan_backward_kernel<T><<<grid, SEQ_SIZE, 0, stream>>>(
        grad_combined,
        grad_state,
        grad_out,
        grad_next_state,
        combined,
        state,
        a_star_buf,
        s_buf,
        log_values_buf,
        T_seq,
        H,
        B
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA kernel launch error in backward: %s\n", cudaGetErrorString(err));
    }
}
/*
__device__ __forceinline__ float log_add_exp(const float a, const float b) {
if (::isnan(a) || ::isnan(b)) {
return std::numeric_limits<float>::quiet_NaN();
}
float min_val = fminf(a, b);
float max_val = fmaxf(a, b);
if (min_val != max_val || ::isfinite(min_val)) {
return max_val + log1pf(expf(min_val - max_val));
} else {
return a;
}
}
__device__ __forceinline__ float log_add_exp_backward(float x_val, float s_val) {
if (::isnan(x_val) || ::isnan(s_val)) {
return std::numeric_limits<float>::quiet_NaN();