Commit 27b45db

Small refactors

1 parent 7901b82 commit 27b45db

8 files changed; 807 additions & 84 deletions

pufferlib/src/cudnn_conv2d.cu
Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@
 #include <cudnn.h>
 #include <cstdio>
 
+#include "kernels.cu"
+
 #ifndef CHECK_CUDNN
 #define CHECK_CUDNN(call) do { \
     cudnnStatus_t e = call; \

pufferlib/src/kernels.cu
Lines changed: 4 additions & 0 deletions

@@ -172,12 +172,14 @@ __global__ void add_kernel(float* __restrict__ dst, const precision_t* __restric
     }
 }
 
+#ifndef PRECISION_FLOAT
 __global__ void add_kernel(precision_t* __restrict__ dst, const precision_t* __restrict__ src, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) {
         dst[idx] = from_float(to_float(dst[idx]) + to_float(src[idx]));
     }
 }
+#endif
 
 #include "tensor.h"
 
@@ -322,13 +324,15 @@ __global__ void cast_kernel(precision_t* __restrict__ dst,
     }
 }
 
+#ifndef PRECISION_FLOAT
 __global__ void cast_kernel(float* __restrict__ dst,
         const precision_t* __restrict__ src, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) {
         dst[idx] = to_float(src[idx]);
     }
 }
+#endif
 
 __global__ void cast_kernel(precision_t* __restrict__ dst,
         const unsigned char* __restrict__ src, int n) {
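
Note: the #ifndef PRECISION_FLOAT guards exist because some builds alias precision_t to float; in that configuration these overloads would have the same signature as the float kernels above them and fail to compile as redefinitions. A minimal self-contained sketch of the pattern, assuming a bf16 precision_t when the macro is unset (demo_cast, the typedefs, and to_float here are illustrative stand-ins, not the actual kernels.cu definitions):

// Sketch only: when PRECISION_FLOAT makes precision_t an alias of float,
// the guarded overload would collapse into the same (float*, const float*)
// signature as the one above it, so it must be compiled out.
#include <cuda_bf16.h>

#ifdef PRECISION_FLOAT
typedef float precision_t;
#else
typedef __nv_bfloat16 precision_t;
#endif

__device__ float to_float(precision_t v) { return (float)v; }

// Always present: float destination, precision_t source.
__global__ void demo_cast(float* dst, const precision_t* src, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) dst[idx] = to_float(src[idx]);
}

#ifndef PRECISION_FLOAT
// Only a distinct overload when precision_t != float.
__global__ void demo_cast(precision_t* dst, const float* src, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) dst[idx] = (precision_t)src[idx];
}
#endif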

pufferlib/src/models.cu
Lines changed: 42 additions & 19 deletions

@@ -1,18 +1,15 @@
-// Uses vector for MinGRU activations
+// Removed vector dependency for MinGRU activations - now uses raw pointers
 
 #ifndef PUFFERLIB_MODELS_CU
 #define PUFFERLIB_MODELS_CU
 
 #include <cuda_runtime.h>
-#include <vector>
 #include <string>
 #include <cstdint>
 
 #include <stdio.h>
 #include <stdlib.h>
 
-using std::vector;
-
 #include "kernels.cu"
 
 // Shared function pointer types (same signature for encoder and decoder)
@@ -22,6 +19,7 @@ typedef void (*reg_train_fn)(void* weights, void* buf, Allocator* acts, Allocato
 typedef void (*reg_rollout_fn)(void* weights, void* buf, Allocator* alloc, int B);
 typedef void* (*create_weights_fn)(void* self);
 typedef void (*free_weights_fn)(void* weights);
+typedef void (*free_activations_fn)(void* activations);
 typedef PrecisionTensor (*forward_fn)(void* weights, void* activations, PrecisionTensor input, cudaStream_t stream);
 typedef void (*encoder_backward_fn)(void* weights, void* activations,
     PrecisionTensor grad, cudaStream_t stream);
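
Note: free_activations_fn adds the missing release hook to this function-pointer vtable, so owners can free activations through the same indirection they already use for weights, without knowing the concrete activation type. A minimal sketch of the idiom (Module, dense_create, and dense_free are hypothetical names, not pufferlib APIs):

#include <cstdlib>

typedef void* (*create_acts_fn)(int dim);
typedef void  (*free_acts_fn)(void* activations);

struct Module {
    create_acts_fn create_activations;
    free_acts_fn   free_activations;  // caller frees via the slot, not the type
};

static void* dense_create(int dim) { return calloc((size_t)dim, sizeof(float)); }
static void  dense_free(void* acts) { free(acts); }

int main() {
    Module m = { dense_create, dense_free };
    void* acts = m.create_activations(128);
    m.free_activations(acts);  // dispatch through the vtable slot
    return 0;
}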
@@ -43,6 +41,7 @@ struct Encoder {
     reg_rollout_fn reg_rollout;
     create_weights_fn create_weights;
     free_weights_fn free_weights;
+    free_activations_fn free_activations;
     int in_dim, out_dim;
 };
 
@@ -55,6 +54,7 @@ struct Decoder {
     reg_rollout_fn reg_rollout;
     create_weights_fn create_weights;
     free_weights_fn free_weights;
+    free_activations_fn free_activations;
     int hidden_dim, output_dim;
     bool continuous;
 };
@@ -69,6 +69,7 @@ struct Network {
     reg_rollout_fn reg_rollout;
     create_weights_fn create_weights;
     free_weights_fn free_weights;
+    free_activations_fn free_activations;
     int hidden, num_layers, horizon;
 };
 
@@ -480,6 +481,10 @@ static void encoder_free_weights(void* weights) {
     free(weights);
 }
 
+static void encoder_free_activations(void* activations) {
+    free(activations);
+}
+
 #include "ocean.cu"
 
 struct DecoderWeights {
@@ -559,6 +564,10 @@ static void decoder_free_weights(void* weights) {
     free(weights);
 }
 
+static void decoder_free_activations(void* activations) {
+    free(activations);
+}
+
 static PrecisionTensor decoder_backward(void* w, void* activations,
     FloatTensor grad_logits, FloatTensor grad_logstd, FloatTensor grad_value, cudaStream_t stream) {
     DecoderWeights* dw = (DecoderWeights*)w;
@@ -579,18 +588,26 @@ static PrecisionTensor decoder_backward(void* w, void* activations,
 struct MinGRUActivations {
     int num_layers;
     // Rollout
-    vector<PrecisionTensor> combined;     // per-layer (B_inf, 3*H)
+    PrecisionTensor* combined;            // per-layer (B_inf, 3*H) - malloc'd
     PrecisionTensor out;                  // (B_inf, H)
     PrecisionTensor next_state;           // (B_inf, H)
     // Training
-    vector<PrecisionTensor> saved_inputs; // per-layer (B, TT, H)
-    vector<PrefixScan> scan_bufs;         // per-layer scan state
-    vector<PrecisionTensor> combined_bufs;// per-layer (B_TT, 3*H)
-    vector<PrecisionTensor> wgrad_scratch;// per-layer (3*H, H) weight grad output
+    PrecisionTensor* saved_inputs;        // per-layer (B, TT, H) - malloc'd
+    PrefixScan* scan_bufs;                // per-layer scan state - malloc'd
+    PrecisionTensor* combined_bufs;       // per-layer (B_TT, 3*H) - malloc'd
+    PrecisionTensor* wgrad_scratch;       // per-layer (3*H, H) weight grad output - malloc'd
     PrecisionTensor grad_input_buf;       // (B_TT, H)
     PrecisionTensor grad_next_state;      // (B, 1, H)
 };
 
+void mingru_activations_free(MinGRUActivations* a) {
+    free(a->combined);
+    free(a->saved_inputs);
+    free(a->scan_bufs);
+    free(a->combined_bufs);
+    free(a->wgrad_scratch);
+}
+
 struct MinGRUWeights {
     int hidden, num_layers, horizon;
     PrecisionTensor* weights; // [num_layers], malloc'd
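
Note: swapping vector<T> members for calloc'd arrays works here because the element types are plain structs: calloc matches vector::resize's value-initialization, no destructors need to run, and each calloc just needs a paired free (which mingru_activations_free provides). A sketch of the trade under that POD assumption (Tensor, Acts, and the helpers are illustrative, not the real types):

#include <cstdlib>

struct Tensor { int shape[2]; float* data; };  // stand-in POD payload

struct Acts {
    int num_layers;
    Tensor* per_layer;  // calloc'd: zero-initialized, no ctor/dtor to run
};

static void acts_init(Acts* a, int layers) {
    a->num_layers = layers;
    // calloc mirrors vector::resize's zero-initialization for POD elements
    a->per_layer = (Tensor*)calloc((size_t)layers, sizeof(Tensor));
}

static void acts_free(Acts* a) {
    free(a->per_layer);  // one free pairs with the one calloc
    a->per_layer = nullptr;
}

int main() {
    Acts a;
    acts_init(&a, 4);
    acts_free(&a);
}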
@@ -625,10 +642,10 @@ static void mingru_reg_train(void* w, void* activations, Allocator* acts, Alloca
     MinGRUActivations* a = (MinGRUActivations*)activations;
     int H = m->hidden, TT = m->horizon, B = B_TT / TT;
     a->num_layers = m->num_layers;
-    a->saved_inputs.resize(m->num_layers);
-    a->scan_bufs.resize(m->num_layers);
-    a->combined_bufs.resize(m->num_layers);
-    a->wgrad_scratch.resize(m->num_layers);
+    a->saved_inputs = (PrecisionTensor*)calloc(m->num_layers, sizeof(PrecisionTensor));
+    a->scan_bufs = (PrefixScan*)calloc(m->num_layers, sizeof(PrefixScan));
+    a->combined_bufs = (PrecisionTensor*)calloc(m->num_layers, sizeof(PrecisionTensor));
+    a->wgrad_scratch = (PrecisionTensor*)calloc(m->num_layers, sizeof(PrecisionTensor));
     a->grad_input_buf = {.shape = {B_TT, H}};
     a->grad_next_state = {.shape = {B, 1, H}};
     alloc_register(acts,&a->grad_input_buf);
@@ -666,7 +683,7 @@ static void mingru_reg_rollout(void* weights, void* activations, Allocator* allo
     MinGRUActivations* a = (MinGRUActivations*)activations;
     int H = w->hidden;
     a->num_layers = w->num_layers;
-    a->combined.resize(w->num_layers);
+    a->combined = (PrecisionTensor*)calloc(w->num_layers, sizeof(PrecisionTensor));
     for (int i = 0; i < w->num_layers; i++) {
         a->combined[i] = {.shape = {B_inf, 3 * H}};
         alloc_register(alloc,&a->combined[i]);
@@ -684,12 +701,19 @@ static void* mingru_create_weights(void* self) {
     mw->weights = (PrecisionTensor*)calloc(n->num_layers, sizeof(PrecisionTensor));
     return mw;
 }
+
 static void mingru_free_weights(void* weights) {
     MinGRUWeights* mw = (MinGRUWeights*)weights;
     free(mw->weights);
     free(mw);
 }
 
+static void mingru_free_activations(void* activations) {
+    MinGRUActivations* a = (MinGRUActivations*)activations;
+    mingru_activations_free(a);
+    free(a);
+}
+
 static PrecisionTensor mingru_forward(void* w, PrecisionTensor x, PrecisionTensor state,
     void* activations, cudaStream_t stream) {
     MinGRUWeights* m = (MinGRUWeights*)w;
@@ -764,11 +788,10 @@ struct PolicyWeights {
     void* network;
 };
 
-static void policy_activations_free(PolicyActivations& a) {
-    free(a.encoder);
-    free(a.decoder);
-    ((MinGRUActivations*)a.network)->~MinGRUActivations();
-    free(a.network);
+static void policy_activations_free(Policy* p, PolicyActivations& a) {
+    p->encoder.free_activations(a.encoder);
+    p->decoder.free_activations(a.decoder);
+    p->network.free_activations(a.network);
 }
 
 PrecisionTensor policy_forward(Policy* p, PolicyWeights& w, PolicyActivations& activations,

pufferlib/src/ocean.cu
Lines changed: 2 additions & 0 deletions

@@ -285,6 +285,7 @@ static void* nmmo3_encoder_create_weights(void* self) {
     return nmmo3_encoder_create(e->in_dim, e->out_dim);
 }
 static void nmmo3_encoder_free_weights(void* weights) { free(weights); }
+static void nmmo3_encoder_free_activations(void* activations) { free(activations); }
 
 // Override encoder vtable for known ocean environments. No-op for unknown envs.
 static void create_custom_encoder(const std::string& env_name, Encoder* enc) {
@@ -298,6 +299,7 @@ static void create_custom_encoder(const std::string& env_name, Encoder* enc) {
         .reg_rollout = nmmo3_encoder_reg_rollout,
         .create_weights = nmmo3_encoder_create_weights,
         .free_weights = nmmo3_encoder_free_weights,
+        .free_activations = nmmo3_encoder_free_activations,
         .in_dim = enc->in_dim, .out_dim = enc->out_dim,
     };
 }

pufferlib/src/pufferlib.cu
Lines changed: 24 additions & 41 deletions

@@ -1,8 +1,8 @@
 #include <cuda_runtime.h>
 #include <cuda_profiler_api.h>
-#include <nccl.h>
 #include <nvtx3/nvToolsExt.h>
 #include <nvml.h>
+#include <nccl.h>
 
 #include "models.cu"
 #include "muon.cu"
@@ -166,32 +166,13 @@ void register_train_buffers(TrainGraph& bufs, Allocator* alloc, int S, int H, in
     alloc_register(alloc, &bufs.mb_newvalue);
 }
 
-// Minimal CUDA graph wrapper using raw APIs (no torch dependency)
-struct RawCudaGraph {
-    cudaGraph_t graph = nullptr;
-    cudaGraphExec_t exec = nullptr;
-
-    void capture_begin(cudaStream_t stream) {
-        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
-    }
-    void capture_end(cudaStream_t stream) {
-        cudaStreamEndCapture(stream, &graph);
-        cudaGraphInstantiate(&exec, graph, 0);
-    }
-    void replay(cudaStream_t stream) {
-        cudaGraphLaunch(exec, stream);
-    }
-    void reset() {
-        if (exec) {
-            cudaGraphExecDestroy(exec);
-            exec = nullptr;
-        }
-        if (graph) {
-            cudaGraphDestroy(graph);
-            graph = nullptr;
-        }
-    }
-};
+// CUDA graph helpers
+inline void cudagraph_capture_end(cudaGraphExec_t* exec, cudaStream_t stream) {
+    cudaGraph_t graph;
+    cudaStreamEndCapture(stream, &graph);
+    cudaGraphInstantiate(exec, graph, 0);
+    cudaGraphDestroy(graph);
+}
 
 // Slice: select dim0 index t, then narrow dim0 from start for count.
 // 3D (H, S, F) -> (count, F); 2D (H, S) -> (count,)
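
Note: cudagraph_capture_end covers the middle of the standard CUDA graph lifecycle: end stream capture into a cudaGraph_t, instantiate a cudaGraphExec_t, and destroy the intermediate graph, since only the exec handle is needed for replay. A standalone sketch of the full lifecycle using the same API calls as the diff (the scale kernel and sizes are illustrative):

#include <cuda_runtime.h>

__global__ void scale(float* x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= 2.0f;
}

int main() {
    const int n = 1 << 20;
    float* d; cudaMalloc(&d, n * sizeof(float));
    cudaStream_t s; cudaStreamCreate(&s);

    // Capture the work once into a cudaGraph_t...
    cudaStreamBeginCapture(s, cudaStreamCaptureModeGlobal);
    scale<<<(n + 255) / 256, 256, 0, s>>>(d, n);
    cudaGraph_t g;
    cudaStreamEndCapture(s, &g);

    // ...instantiate an executable graph; the cudaGraph_t can then be freed.
    cudaGraphExec_t exec;
    cudaGraphInstantiate(&exec, g, 0);
    cudaGraphDestroy(g);

    // Replay cheaply any number of times.
    for (int i = 0; i < 10; i++) cudaGraphLaunch(exec, s);
    cudaStreamSynchronize(s);

    cudaGraphExecDestroy(exec);
    cudaStreamDestroy(s);
    cudaFree(d);
}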
@@ -335,8 +316,8 @@ typedef struct {
     EnvBuf env;
     TrainGraph train_buf;
     FloatTensor advantages_puf; // Pre-allocated for train_impl (S, H) f32
-    RawCudaGraph* fused_rollout_cudagraphs; // [horizon][num_buffers]
-    RawCudaGraph train_cudagraph;
+    cudaGraphExec_t* fused_rollout_cudagraphs; // [horizon][num_buffers]
+    cudaGraphExec_t train_cudagraph;
     cudaStream_t* streams; // per-buffer raw CUDA streams
     cudaStream_t default_stream; // main-thread stream (captured once at init)
     IntTensor act_sizes_puf; // CUDA int32 tensor of action head sizes
@@ -562,7 +543,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
 
     cudaStream_t current_stream = tl_stream;
     if (pufferl->rollout_captured) {
-        pufferl->fused_rollout_cudagraphs[graph].replay(current_stream);
+        cudaGraphLaunch(pufferl->fused_rollout_cudagraphs[graph], current_stream);
         profile_end(hypers.profile);
         return;
     }
@@ -572,7 +553,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
     if (capturing) {
         cudaStreamCreate(&cap_stream_raw);
         current_stream = cap_stream_raw;
-        pufferl->fused_rollout_cudagraphs[graph].capture_begin(cap_stream_raw);
+        cudaStreamBeginCapture(cap_stream_raw, cudaStreamCaptureModeGlobal);
     }
 
     RolloutBuf& rollouts = pufferl->rollouts;
@@ -628,7 +609,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
         act_slice.data, numel(act_slice.shape) * sizeof(double), cudaMemcpyDeviceToDevice, stream);
 
     if (capturing) {
-        pufferl->fused_rollout_cudagraphs[graph].capture_end(cap_stream_raw);
+        cudagraph_capture_end(&pufferl->fused_rollout_cudagraphs[graph], cap_stream_raw);
         cudaStreamSynchronize(cap_stream_raw);
         cudaDeviceSynchronize();
         cudaStreamDestroy(cap_stream_raw);
@@ -1460,13 +1441,13 @@ void train_impl(PuffeRL& pufferl) {
 
     cudaEventRecord(pufferl.profile.events[3]); // end misc / start forward
     if (pufferl.train_captured) {
-        pufferl.train_cudagraph.replay(train_stream);
+        cudaGraphLaunch(pufferl.train_cudagraph, train_stream);
     } else {
         bool capturing = pufferl.train_warmup == hypers.cudagraphs;
         cudaStream_t cap_stream_raw = train_stream;
         if (capturing) {
             cudaStreamCreate(&cap_stream_raw);
-            pufferl.train_cudagraph.capture_begin(cap_stream_raw);
+            cudaStreamBeginCapture(cap_stream_raw, cudaStreamCaptureModeGlobal);
         }
 
         cudaStream_t stream = cap_stream_raw;
@@ -1499,9 +1480,8 @@ void train_impl(PuffeRL& pufferl) {
         cast_kernel<<<grid_size(n), BLOCK_SIZE, 0, stream>>>(
             pufferl.param_puf.data, pufferl.master_weights.data, n);
     }
-
     if (capturing) {
-        pufferl.train_cudagraph.capture_end(cap_stream_raw);
+        cudagraph_capture_end(&pufferl.train_cudagraph, cap_stream_raw);
         cudaStreamSynchronize(cap_stream_raw);
         cudaDeviceSynchronize();
         cudaStreamDestroy(cap_stream_raw);
@@ -1636,6 +1616,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers,
         .reg_rollout = encoder_reg_rollout,
         .create_weights = encoder_create_weights,
         .free_weights = encoder_free_weights,
+        .free_activations = encoder_free_activations,
         .in_dim = input_size, .out_dim = hidden_size,
     };
     create_custom_encoder(env_name, &encoder);
@@ -1648,6 +1629,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers,
         .reg_rollout = decoder_reg_rollout,
         .create_weights = decoder_create_weights,
         .free_weights = decoder_free_weights,
+        .free_activations = decoder_free_activations,
         .hidden_dim = hidden_size, .output_dim = decoder_output_size, .continuous = is_continuous,
     };
     Network network = {
@@ -1660,6 +1642,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers,
         .reg_rollout = mingru_reg_rollout,
         .create_weights = mingru_create_weights,
         .free_weights = mingru_free_weights,
+        .free_activations = mingru_free_activations,
         .hidden = hidden_size, .num_layers = num_layers, .horizon = hypers.horizon,
     };
     pufferl->policy = Policy{
@@ -1744,7 +1727,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers,
     muon_post_create(&pufferl->muon);
 
     if (hypers.cudagraphs >= 0) {
-        pufferl->fused_rollout_cudagraphs = (RawCudaGraph*)calloc(horizon*num_buffers, sizeof(RawCudaGraph));
+        pufferl->fused_rollout_cudagraphs = (cudaGraphExec_t*)calloc(horizon*num_buffers, sizeof(cudaGraphExec_t));
         pufferl->train_warmup = 0;
 
         // Snapshot weights + optimizer state before init-time capture
@@ -1831,15 +1814,15 @@ void close_impl(PuffeRL& pufferl) {
         cudaProfilerStop();
     }
 
-    pufferl.train_cudagraph.reset();
+    cudaGraphExecDestroy(pufferl.train_cudagraph);
     for (int i = 0; i < pufferl.hypers.horizon * pufferl.hypers.num_buffers; i++) {
-        pufferl.fused_rollout_cudagraphs[i].reset();
+        cudaGraphExecDestroy(pufferl.fused_rollout_cudagraphs[i]);
     }
 
     policy_weights_free(&pufferl.policy, &pufferl.weights);
-    policy_activations_free(pufferl.train_activations);
+    policy_activations_free(&pufferl.policy, pufferl.train_activations);
     for (int buf = 0; buf < pufferl.hypers.num_buffers; buf++) {
-        policy_activations_free(pufferl.buffer_activations[buf]);
+        policy_activations_free(&pufferl.policy, pufferl.buffer_activations[buf]);
     }
 
     if (USE_BF16) {
