#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <nvml.h>
#include <nvtx3/nvToolsExt.h>

#include <cstdio>
#include <cstdlib>

#include "models.cu"
#include "muon.cu"
@@ -166,32 +166,13 @@ void register_train_buffers(TrainGraph& bufs, Allocator* alloc, int S, int H, in
166166 alloc_register (alloc, &bufs.mb_newvalue );
167167}
168168
// Minimal CUDA graph wrapper using raw APIs (no torch dependency).
// Usage: capture_begin(s); <launch work on s>; capture_end(s); then replay(s)
// any number of times; reset() releases both handles.
struct RawCudaGraph {
    cudaGraph_t graph = nullptr;
    cudaGraphExec_t exec = nullptr;

    // Put `stream` into capture mode; subsequent launches on it are recorded
    // instead of executed.
    void capture_begin(cudaStream_t stream) {
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    }

    // Stop recording on `stream` and instantiate the captured topology as an
    // executable graph.
    void capture_end(cudaStream_t stream) {
        cudaStreamEndCapture(stream, &graph);
        cudaGraphInstantiate(&exec, graph, 0);
    }

    // Launch the previously instantiated graph on `stream`.
    void replay(cudaStream_t stream) {
        cudaGraphLaunch(exec, stream);
    }

    // Destroy whichever handles are live and null them out; calling reset()
    // twice in a row is a no-op the second time.
    void reset() {
        if (graph != nullptr) {
            cudaGraphDestroy(graph);
            graph = nullptr;
        }
        if (exec != nullptr) {
            cudaGraphExecDestroy(exec);
            exec = nullptr;
        }
    }
};
// CUDA graph helpers
//
// Finish a stream capture begun with cudaStreamBeginCapture: materialize the
// work recorded on `stream` as an executable graph in *exec, then free the
// intermediate cudaGraph_t (the exec handle is all that is needed to replay
// via cudaGraphLaunch).
//
// exec   - out parameter; receives the instantiated executable graph.
// stream - the stream currently in capture mode.
//
// Aborts with a message on any CUDA error: both calls return cudaError_t, and
// ignoring a failed capture would otherwise surface later as a confusing
// sticky error at the next, unrelated launch site.
inline void cudagraph_capture_end(cudaGraphExec_t* exec, cudaStream_t stream) {
    cudaGraph_t graph = nullptr;
    cudaError_t err = cudaStreamEndCapture(stream, &graph);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaStreamEndCapture failed: %s\n",
                cudaGetErrorString(err));
        abort();
    }
    err = cudaGraphInstantiate(exec, graph, 0);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaGraphInstantiate failed: %s\n",
                cudaGetErrorString(err));
        abort();
    }
    // The executable graph owns its own copy of the topology; the capture
    // graph can be released immediately.
    cudaGraphDestroy(graph);
}
195176
196177// Slice: select dim0 index t, then narrow dim0 from start for count.
197178// 3D (H, S, F) -> (count, F); 2D (H, S) -> (count,)
@@ -335,8 +316,8 @@ typedef struct {
335316 EnvBuf env;
336317 TrainGraph train_buf;
337318 FloatTensor advantages_puf; // Pre-allocated for train_impl (S, H) f32
338- RawCudaGraph * fused_rollout_cudagraphs; // [horizon][num_buffers]
339- RawCudaGraph train_cudagraph;
319+ cudaGraphExec_t * fused_rollout_cudagraphs; // [horizon][num_buffers]
320+ cudaGraphExec_t train_cudagraph;
340321 cudaStream_t* streams; // per-buffer raw CUDA streams
341322 cudaStream_t default_stream; // main-thread stream (captured once at init)
342323 IntTensor act_sizes_puf; // CUDA int32 tensor of action head sizes
@@ -562,7 +543,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
562543
563544 cudaStream_t current_stream = tl_stream;
564545 if (pufferl->rollout_captured ) {
565- pufferl->fused_rollout_cudagraphs [graph]. replay ( current_stream);
546+ cudaGraphLaunch ( pufferl->fused_rollout_cudagraphs [graph], current_stream);
566547 profile_end (hypers.profile );
567548 return ;
568549 }
@@ -572,7 +553,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
572553 if (capturing) {
573554 cudaStreamCreate (&cap_stream_raw);
574555 current_stream = cap_stream_raw;
575- pufferl-> fused_rollout_cudagraphs [graph]. capture_begin (cap_stream_raw);
556+ cudaStreamBeginCapture (cap_stream_raw, cudaStreamCaptureModeGlobal );
576557 }
577558
578559 RolloutBuf& rollouts = pufferl->rollouts ;
@@ -628,7 +609,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
628609 act_slice.data , numel (act_slice.shape ) * sizeof (double ), cudaMemcpyDeviceToDevice, stream);
629610
630611 if (capturing) {
631- pufferl->fused_rollout_cudagraphs [graph]. capture_end ( cap_stream_raw);
612+ cudagraph_capture_end (& pufferl->fused_rollout_cudagraphs [graph], cap_stream_raw);
632613 cudaStreamSynchronize (cap_stream_raw);
633614 cudaDeviceSynchronize ();
634615 cudaStreamDestroy (cap_stream_raw);
@@ -1460,13 +1441,13 @@ void train_impl(PuffeRL& pufferl) {
14601441
14611442 cudaEventRecord (pufferl.profile .events [3 ]); // end misc / start forward
14621443 if (pufferl.train_captured ) {
1463- pufferl.train_cudagraph . replay ( train_stream);
1444+ cudaGraphLaunch ( pufferl.train_cudagraph , train_stream);
14641445 } else {
14651446 bool capturing = pufferl.train_warmup == hypers.cudagraphs ;
14661447 cudaStream_t cap_stream_raw = train_stream;
14671448 if (capturing) {
14681449 cudaStreamCreate (&cap_stream_raw);
1469- pufferl. train_cudagraph . capture_begin (cap_stream_raw);
1450+ cudaStreamBeginCapture (cap_stream_raw, cudaStreamCaptureModeGlobal );
14701451 }
14711452
14721453 cudaStream_t stream = cap_stream_raw;
@@ -1499,9 +1480,8 @@ void train_impl(PuffeRL& pufferl) {
14991480 cast_kernel<<<grid_size(n), BLOCK_SIZE, 0 , stream>>> (
15001481 pufferl.param_puf .data , pufferl.master_weights .data , n);
15011482 }
1502-
15031483 if (capturing) {
1504- pufferl.train_cudagraph . capture_end ( cap_stream_raw);
1484+ cudagraph_capture_end (& pufferl.train_cudagraph , cap_stream_raw);
15051485 cudaStreamSynchronize (cap_stream_raw);
15061486 cudaDeviceSynchronize ();
15071487 cudaStreamDestroy (cap_stream_raw);
@@ -1636,6 +1616,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers,
16361616 .reg_rollout = encoder_reg_rollout,
16371617 .create_weights = encoder_create_weights,
16381618 .free_weights = encoder_free_weights,
1619+ .free_activations = encoder_free_activations,
16391620 .in_dim = input_size, .out_dim = hidden_size,
16401621 };
16411622 create_custom_encoder (env_name, &encoder);
@@ -1648,6 +1629,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers,
16481629 .reg_rollout = decoder_reg_rollout,
16491630 .create_weights = decoder_create_weights,
16501631 .free_weights = decoder_free_weights,
1632+ .free_activations = decoder_free_activations,
16511633 .hidden_dim = hidden_size, .output_dim = decoder_output_size, .continuous = is_continuous,
16521634 };
16531635 Network network = {
@@ -1660,6 +1642,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers,
16601642 .reg_rollout = mingru_reg_rollout,
16611643 .create_weights = mingru_create_weights,
16621644 .free_weights = mingru_free_weights,
1645+ .free_activations = mingru_free_activations,
16631646 .hidden = hidden_size, .num_layers = num_layers, .horizon = hypers.horizon ,
16641647 };
16651648 pufferl->policy = Policy{
@@ -1744,7 +1727,7 @@ std::unique_ptr<PuffeRL> create_pufferl_impl(HypersT& hypers,
17441727 muon_post_create (&pufferl->muon );
17451728
17461729 if (hypers.cudagraphs >= 0 ) {
1747- pufferl->fused_rollout_cudagraphs = (RawCudaGraph *)calloc (horizon*num_buffers, sizeof (RawCudaGraph ));
1730+ pufferl->fused_rollout_cudagraphs = (cudaGraphExec_t *)calloc (horizon*num_buffers, sizeof (cudaGraphExec_t ));
17481731 pufferl->train_warmup = 0 ;
17491732
17501733 // Snapshot weights + optimizer state before init-time capture
@@ -1831,15 +1814,15 @@ void close_impl(PuffeRL& pufferl) {
18311814 cudaProfilerStop ();
18321815 }
18331816
1834- pufferl.train_cudagraph . reset ( );
1817+ cudaGraphExecDestroy ( pufferl.train_cudagraph );
18351818 for (int i = 0 ; i < pufferl.hypers .horizon * pufferl.hypers .num_buffers ; i++) {
1836- pufferl.fused_rollout_cudagraphs [i]. reset ( );
1819+ cudaGraphExecDestroy ( pufferl.fused_rollout_cudagraphs [i]);
18371820 }
18381821
18391822 policy_weights_free (&pufferl.policy , &pufferl.weights );
1840- policy_activations_free (pufferl.train_activations );
1823+ policy_activations_free (&pufferl. policy , pufferl.train_activations );
18411824 for (int buf = 0 ; buf < pufferl.hypers .num_buffers ; buf++) {
1842- policy_activations_free (pufferl.buffer_activations [buf]);
1825+ policy_activations_free (&pufferl. policy , pufferl.buffer_activations [buf]);
18431826 }
18441827
18451828 if (USE_BF16) {
0 commit comments