Commit b940322

nmmo fixes - maybe will train?

1 parent 27b45db · 4 files changed · 291 additions, 556 deletions

pufferlib/src/models.cu (19 additions, 2 deletions)
```diff
@@ -488,22 +488,32 @@ static void encoder_free_activations(void* activations) {
 #include "ocean.cu"
 
 struct DecoderWeights {
-    PrecisionTensor weight, logstd;
+    PrecisionTensor weight, bias, logstd;
     int hidden_dim, output_dim;
     bool continuous;
 };
 
 struct DecoderActivations {
-    PrecisionTensor out, grad_out, saved_input, grad_input, wgrad_scratch, logstd_scratch;
+    PrecisionTensor out, grad_out, saved_input, grad_input, wgrad_scratch, bgrad_scratch, logstd_scratch;
 };
 
+__global__ void bias_add_kernel(precision_t* __restrict__ data,
+        const precision_t* __restrict__ bias, int total, int dim) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= total) return;
+    data[idx] = from_float(to_float(data[idx]) + to_float(bias[idx % dim]));
+}
+
 static PrecisionTensor decoder_forward(void* w, void* activations, PrecisionTensor input, cudaStream_t stream) {
     DecoderWeights* dw = (DecoderWeights*)w;
     DecoderActivations* a = (DecoderActivations*)activations;
     if (a->saved_input.data) {
         puf_copy(&a->saved_input, &input, stream);
     }
     puf_mm(&input, &dw->weight, &a->out, stream);
+    int B = input.shape[0], od1 = dw->output_dim + 1;
+    bias_add_kernel<<<grid_size(B * od1), BLOCK_SIZE, 0, stream>>>(
+        a->out.data, dw->bias.data, B * od1, od1);
     return a->out;
 }
```
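The `idx % dim` indexing broadcasts one bias row across every row of the row-major (B, output_dim + 1) output; judging by the backward pass below, the extra column is the value head. A minimal standalone sketch of the same broadcast, assuming plain `float` in place of `precision_t` (the real build routes through `to_float`/`from_float`) and stand-in `BLOCK_SIZE`/`grid_size` helpers:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Stand-ins for pufferlib's launch helpers (assumed, not the real definitions).
constexpr int BLOCK_SIZE = 256;
static int grid_size(int n) { return (n + BLOCK_SIZE - 1) / BLOCK_SIZE; }

// Same pattern as bias_add_kernel, specialized to float: element idx of a
// row-major (total/dim, dim) matrix gets bias[idx % dim] added to it.
__global__ void bias_add_f32(float* __restrict__ data,
        const float* __restrict__ bias, int total, int dim) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total) return;
    data[idx] += bias[idx % dim];
}

int main() {
    const int B = 4, dim = 3, total = B * dim;
    float h_data[total] = {0}, h_bias[dim] = {1.f, 2.f, 3.f};
    float *d_data, *d_bias;
    cudaMalloc(&d_data, total * sizeof(float));
    cudaMalloc(&d_bias, dim * sizeof(float));
    cudaMemcpy(d_data, h_data, total * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_bias, h_bias, dim * sizeof(float), cudaMemcpyHostToDevice);
    bias_add_f32<<<grid_size(total), BLOCK_SIZE>>>(d_data, d_bias, total, dim);
    cudaMemcpy(h_data, d_data, total * sizeof(float), cudaMemcpyDeviceToHost);
    for (int b = 0; b < B; b++)  // every row should print 1 2 3
        printf("%g %g %g\n", h_data[b * dim], h_data[b * dim + 1], h_data[b * dim + 2]);
    cudaFree(d_data); cudaFree(d_bias);
    return 0;
}
```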

```diff
@@ -514,12 +524,15 @@ static void decoder_init_weights(void* w, ulong* seed, cudaStream_t stream) {
         .shape = {dw->output_dim + 1, dw->hidden_dim},
     };
     puf_kaiming_init(&wt, 0.01f, (*seed)++, stream);
+    cudaMemsetAsync(dw->bias.data, 0, numel(dw->bias.shape) * sizeof(precision_t), stream);
 }
 
 static void decoder_reg_params(void* w, Allocator* alloc) {
     DecoderWeights* dw = (DecoderWeights*)w;
     dw->weight = {.shape = {dw->output_dim + 1, dw->hidden_dim}};
+    dw->bias = {.shape = {dw->output_dim + 1}};
     alloc_register(alloc,&dw->weight);
+    alloc_register(alloc,&dw->bias);
     if (dw->continuous) {
         dw->logstd = {.shape = {1, dw->output_dim}};
         alloc_register(alloc,&dw->logstd);
@@ -536,13 +549,15 @@ static void decoder_reg_train(void* w, void* activations, Allocator* acts, Alloc
         .saved_input = {.shape = {B_TT, dw->hidden_dim}},
         .grad_input = {.shape = {B_TT, dw->hidden_dim}},
         .wgrad_scratch = {.shape = {od1, dw->hidden_dim}},
+        .bgrad_scratch = {.shape = {od1}},
         .logstd_scratch = {.shape = {1, dw->output_dim}},
     };
     alloc_register(acts,&a->out);
     alloc_register(acts,&a->saved_input);
     alloc_register(acts,&a->grad_out);
     alloc_register(acts,&a->grad_input);
     alloc_register(grads,&a->wgrad_scratch);
+    alloc_register(grads,&a->bgrad_scratch);
     if (dw->continuous) alloc_register(grads,&a->logstd_scratch);
 }
 
@@ -577,6 +592,8 @@ static PrecisionTensor decoder_backward(void* w, void* activations,
     assemble_decoder_grad_kernel<<<grid_size(B_TT * od1), BLOCK_SIZE, 0, stream>>>(
         a->grad_out.data, grad_logits.data, grad_value.data, B_TT, od, od1);
     puf_mm_tn(&a->grad_out, &a->saved_input, &a->wgrad_scratch, stream);
+    n3_bias_grad_kernel<<<od1, 256, 0, stream>>>(
+        a->bgrad_scratch.data, a->grad_out.data, B_TT, od1);
     if (dw->continuous && grad_logstd.data != nullptr) {
         sum_rows_to_precision_kernel<<<grid_size(dw->output_dim), BLOCK_SIZE, 0, stream>>>(
             a->logstd_scratch.data, grad_logstd.data, B_TT, dw->output_dim);
```
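The `<<<od1, 256>>>` launch assigns one 256-thread block per output column, so `n3_bias_grad_kernel` presumably reduces `grad_out` over the B_TT batch rows per column. Its body is not part of this diff; a hypothetical sketch of such a per-column reduction, reusing the file's `precision_t`/`to_float`/`from_float` conventions (the actual kernel may differ):

```cuda
// Hypothetical per-column bias-gradient reduction; the real n3_bias_grad_kernel
// is defined elsewhere in the repo. One block handles one column c:
// bgrad[c] = sum over rows of grad_out[row, c].
__global__ void bias_grad_sketch(precision_t* __restrict__ bgrad,
        const precision_t* __restrict__ grad_out, int rows, int cols) {
    __shared__ float sh[256];
    int c = blockIdx.x;  // column index; gridDim.x == cols
    float acc = 0.0f;
    for (int r = threadIdx.x; r < rows; r += blockDim.x)  // grid-stride over rows
        acc += to_float(grad_out[r * cols + c]);
    sh[threadIdx.x] = acc;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // shared-memory tree reduction
        if (threadIdx.x < s) sh[threadIdx.x] += sh[threadIdx.x + s];
        __syncthreads();
    }
    if (threadIdx.x == 0) bgrad[c] = from_float(sh[0]);
}
```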

pufferlib/src/ocean.cu (36 additions, 1 deletion)
```diff
@@ -140,6 +140,32 @@ __global__ void n3_concat_backward_conv_kernel(
         conv_grad[b * N3_CONV_FLAT + c] = concat_grad[b * N3_CONCAT + c];
 }
 
+// Embedding backward: scatter-add grad from concat_grad's player_embed region
+// into embed_wgrad (float accumulation buffer).
+// Each (b, f) looked up row obs[b, MAP_SIZE+f] from the table.
+__global__ void n3_embedding_backward_kernel(
+        float* __restrict__ embed_wgrad_f,
+        const precision_t* __restrict__ concat_grad,
+        const precision_t* __restrict__ obs,
+        int B, int obs_size) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= B * N3_PLAYER * N3_EMBED_DIM) return;
+    int b = idx / (N3_PLAYER * N3_EMBED_DIM);
+    int rem = idx % (N3_PLAYER * N3_EMBED_DIM);
+    int f = rem / N3_EMBED_DIM;
+    int d = rem % N3_EMBED_DIM;
+    int val = (int)to_float(obs[b * obs_size + N3_MAP_SIZE + f]);
+    float g = to_float(concat_grad[b * N3_CONCAT + N3_CONV_FLAT + f * N3_EMBED_DIM + d]);
+    atomicAdd(&embed_wgrad_f[val * N3_EMBED_DIM + d], g);
+}
+
+// Cast float buffer to precision_t
+__global__ void n3_float_to_precision_kernel(
+        precision_t* __restrict__ dst, const float* __restrict__ src, int n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) dst[idx] = from_float(src[idx]);
+}
+
 // ---- NMMO3 encoder structs ----
 
 struct NMMO3EncoderWeights {
```
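Many (b, f) pairs can look up the same vocabulary row, so the scatter-add accumulates through `atomicAdd` on a `float` buffer and downcasts to `precision_t` once at the end; atomics directly in reduced precision would swallow small contributions as the row sum grows (and aren't available for every `precision_t` on every arch). A standalone sketch of the accumulate-then-cast pattern, with `__half` standing in for `precision_t` and made-up sizes:

```cuda
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Scatter-add: every element carrying the same index lands on the same slot,
// so collisions must go through atomicAdd on a float accumulator.
__global__ void scatter_add_f32(float* __restrict__ table_grad_f,
        const int* __restrict__ idx, const float* __restrict__ grad, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(&table_grad_f[idx[i]], grad[i]);
}

// One downcast at the end, instead of atomics in reduced precision.
__global__ void cast_to_half(__half* __restrict__ dst,
        const float* __restrict__ src, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) dst[i] = __float2half(src[i]);
}

int main() {
    const int n = 4096, rows = 8;  // 512 collisions per row
    int h_idx[n]; float h_grad[n];
    for (int i = 0; i < n; i++) { h_idx[i] = i % rows; h_grad[i] = 1e-3f; }
    int* d_idx; float *d_grad, *d_acc; __half* d_out;
    cudaMalloc(&d_idx, n * sizeof(int));
    cudaMalloc(&d_grad, n * sizeof(float));
    cudaMalloc(&d_acc, rows * sizeof(float));
    cudaMalloc(&d_out, rows * sizeof(__half));
    cudaMemcpy(d_idx, h_idx, n * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_grad, h_grad, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_acc, 0, rows * sizeof(float));
    scatter_add_f32<<<(n + 255) / 256, 256>>>(d_acc, d_idx, d_grad, n);
    cast_to_half<<<1, rows>>>(d_out, d_acc, rows);
    __half h_out[rows];
    cudaMemcpy(h_out, d_out, rows * sizeof(__half), cudaMemcpyDeviceToHost);
    printf("row 0: %g (expect ~%g)\n", __half2float(h_out[0]), (n / rows) * 1e-3f);
    cudaFree(d_idx); cudaFree(d_grad); cudaFree(d_acc); cudaFree(d_out);
    return 0;
}
```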
```diff
@@ -152,6 +178,7 @@ struct NMMO3EncoderActivations {
     ConvActivations conv1, conv2;
     PrecisionTensor multihot, embed_out, concat, out, saved_obs;
     PrecisionTensor embed_wgrad, proj_wgrad, proj_bgrad;
+    FloatTensor embed_wgrad_f; // float accumulation buffer for scatter-add
 };
 
 static NMMO3EncoderWeights* nmmo3_encoder_create(int obs_size, int hidden) {
@@ -219,7 +246,13 @@ static void nmmo3_encoder_backward(void* w, void* activations, PrecisionTensor g
         B, ew->conv1.OC, ew->conv1.OH * ew->conv1.OW);
     conv_backward(&ew->conv1, &a->conv1, NULL, stream);
 
-    cudaMemsetAsync(a->embed_wgrad.data, 0, numel(a->embed_wgrad.shape) * sizeof(precision_t), stream);
+    // Embedding backward: scatter-add from concat gradient into float buffer, then cast
+    int embed_n = N3_EMBED_VOCAB * N3_EMBED_DIM;
+    cudaMemsetAsync(a->embed_wgrad_f.data, 0, embed_n * sizeof(float), stream);
+    n3_embedding_backward_kernel<<<grid_size(B * N3_PLAYER * N3_EMBED_DIM), BLOCK_SIZE, 0, stream>>>(
+        a->embed_wgrad_f.data, grad_concat.data, a->saved_obs.data, B, ew->obs_size);
+    n3_float_to_precision_kernel<<<grid_size(embed_n), BLOCK_SIZE, 0, stream>>>(
+        a->embed_wgrad.data, a->embed_wgrad_f.data, embed_n);
 }
 
 static void nmmo3_encoder_init_weights(void* w, uint64_t* seed, cudaStream_t stream) {
```
```diff
@@ -261,9 +294,11 @@ static void nmmo3_encoder_reg_train(void* w, void* activations, Allocator* acts,
     alloc_register(acts,&a->embed_out); alloc_register(acts,&a->concat);
     alloc_register(acts,&a->out); alloc_register(acts,&a->saved_obs);
     a->embed_wgrad = {.shape = {N3_EMBED_VOCAB, N3_EMBED_DIM}};
+    a->embed_wgrad_f = {.shape = {N3_EMBED_VOCAB, N3_EMBED_DIM}};
     a->proj_wgrad = {.shape = {ew->hidden, N3_CONCAT}};
     a->proj_bgrad = {.shape = {ew->hidden}};
     alloc_register(grads,&a->embed_wgrad);
+    alloc_register(acts,&a->embed_wgrad_f);
     alloc_register(grads,&a->proj_wgrad); alloc_register(grads,&a->proj_bgrad);
 }
 
```
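Note the asymmetry in the registration: `embed_wgrad` stays in the `grads` allocator, while the new `embed_wgrad_f` buffer goes to `acts` — presumably because the grads pool holds the `precision_t` tensors the optimizer consumes, and the float buffer is pure scratch, re-zeroed at the start of every backward pass.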
