@@ -488,22 +488,32 @@ static void encoder_free_activations(void* activations) {
488488#include " ocean.cu"
489489
// Parameters of the decoder (output linear) layer.
//
// Shapes (see decoder_reg_params):
//   weight : {output_dim + 1, hidden_dim}  -- the extra row presumably feeds a
//            value head alongside the action logits (backward splits grads into
//            grad_logits / grad_value) -- confirm against assemble_decoder_grad.
//   bias   : {output_dim + 1}
//   logstd : {1, output_dim}, allocated only when `continuous` is set
//            (learned log-std for continuous action distributions).
struct DecoderWeights {
    PrecisionTensor weight, bias, logstd;
    int hidden_dim;    // width of the incoming hidden representation
    int output_dim;    // number of outputs before the +1 extra row
    bool continuous;   // true => logstd tensor is registered and trained
};
495495
// Per-pass activation and gradient scratch buffers for the decoder layer.
//
// Shapes (see decoder_reg_train): saved_input / grad_input are
// {B_TT, hidden_dim}; wgrad_scratch mirrors the weight ({output_dim + 1,
// hidden_dim}); bgrad_scratch mirrors the bias ({output_dim + 1});
// logstd_scratch is {1, output_dim} and only registered in continuous mode.
// out / grad_out hold B_TT * (output_dim + 1) elements (the forward pass
// bias-adds over exactly that many). saved_input may be left unallocated
// (null data) in inference-only configurations.
struct DecoderActivations {
    PrecisionTensor out, grad_out,
                    saved_input, grad_input,
                    wgrad_scratch, bgrad_scratch, logstd_scratch;
};
499499
// In-place broadcast bias add over a row-major matrix.
//
// `data` holds `total` elements laid out as consecutive rows of length `dim`;
// `bias` holds `dim` elements and is added element-wise to every row.
// Launch: 1-D grid with at least `total` threads; one element per thread.
__global__ void bias_add_kernel(precision_t* __restrict__ data,
                                const precision_t* __restrict__ bias,
                                int total, int dim) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < total) {
        // Accumulate in float: precision_t is converted through
        // to_float/from_float, so it may be a reduced-precision storage type.
        const float sum = to_float(data[i]) + to_float(bias[i % dim]);
        data[i] = from_float(sum);
    }
}
506+
// Decoder forward pass: matrix-multiplies `input` by the decoder weight
// (via puf_mm), then broadcast-adds the bias to every output row.
//
// When training buffers are registered (saved_input.data non-null), the
// input is cached for the weight-gradient matmul in the backward pass.
// All work is enqueued on `stream`; returns the output activation tensor.
static PrecisionTensor decoder_forward(void* w, void* activations,
                                       PrecisionTensor input,
                                       cudaStream_t stream) {
    DecoderWeights* dw = (DecoderWeights*)w;
    DecoderActivations* acts = (DecoderActivations*)activations;

    // Stash the input for backward; skipped in inference-only setups
    // where saved_input was never allocated.
    if (acts->saved_input.data) {
        puf_copy(&acts->saved_input, &input, stream);
    }

    puf_mm(&input, &dw->weight, &acts->out, stream);

    // Bias covers output_dim + 1 columns (logits plus the extra row the
    // weight carries); one thread per output element.
    const int rows = input.shape[0];
    const int cols = dw->output_dim + 1;
    bias_add_kernel<<<grid_size(rows * cols), BLOCK_SIZE, 0, stream>>>(
        acts->out.data, dw->bias.data, rows * cols, cols);

    return acts->out;
}
509519
@@ -514,12 +524,15 @@ static void decoder_init_weights(void* w, ulong* seed, cudaStream_t stream) {
514524 .shape = {dw->output_dim + 1 , dw->hidden_dim },
515525 };
516526 puf_kaiming_init (&wt, 0 .01f , (*seed)++, stream);
527+ cudaMemsetAsync (dw->bias .data , 0 , numel (dw->bias .shape ) * sizeof (precision_t ), stream);
517528}
518529
519530static void decoder_reg_params (void * w, Allocator* alloc) {
520531 DecoderWeights* dw = (DecoderWeights*)w;
521532 dw->weight = {.shape = {dw->output_dim + 1 , dw->hidden_dim }};
533+ dw->bias = {.shape = {dw->output_dim + 1 }};
522534 alloc_register (alloc,&dw->weight );
535+ alloc_register (alloc,&dw->bias );
523536 if (dw->continuous ) {
524537 dw->logstd = {.shape = {1 , dw->output_dim }};
525538 alloc_register (alloc,&dw->logstd );
@@ -536,13 +549,15 @@ static void decoder_reg_train(void* w, void* activations, Allocator* acts, Alloc
536549 .saved_input = {.shape = {B_TT, dw->hidden_dim }},
537550 .grad_input = {.shape = {B_TT, dw->hidden_dim }},
538551 .wgrad_scratch = {.shape = {od1, dw->hidden_dim }},
552+ .bgrad_scratch = {.shape = {od1}},
539553 .logstd_scratch = {.shape = {1 , dw->output_dim }},
540554 };
541555 alloc_register (acts,&a->out );
542556 alloc_register (acts,&a->saved_input );
543557 alloc_register (acts,&a->grad_out );
544558 alloc_register (acts,&a->grad_input );
545559 alloc_register (grads,&a->wgrad_scratch );
560+ alloc_register (grads,&a->bgrad_scratch );
546561 if (dw->continuous ) alloc_register (grads,&a->logstd_scratch );
547562}
548563
@@ -577,6 +592,8 @@ static PrecisionTensor decoder_backward(void* w, void* activations,
577592 assemble_decoder_grad_kernel<<<grid_size(B_TT * od1), BLOCK_SIZE, 0 , stream>>> (
578593 a->grad_out .data , grad_logits.data , grad_value.data , B_TT, od, od1);
579594 puf_mm_tn (&a->grad_out , &a->saved_input , &a->wgrad_scratch , stream);
595+ n3_bias_grad_kernel<<<od1, 256 , 0 , stream>>> (
596+ a->bgrad_scratch .data , a->grad_out .data , B_TT, od1);
580597 if (dw->continuous && grad_logstd.data != nullptr ) {
581598 sum_rows_to_precision_kernel<<<grid_size(dw->output_dim), BLOCK_SIZE, 0 , stream>>> (
582599 a->logstd_scratch .data , grad_logstd.data , B_TT, dw->output_dim );
0 commit comments