@@ -494,14 +494,16 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
494494
495495 cudaStream_t current_stream = tl_stream;
496496 if (pufferl->rollout_captured ) {
497- cudaGraphLaunch (pufferl->fused_rollout_cudagraphs [graph], current_stream);
497+ assert (cudaGraphLaunch (pufferl->fused_rollout_cudagraphs [graph], current_stream) == cudaSuccess
498+ && " cudaGraphLaunch failed" );
498499 profile_end (hypers.profile );
499500 return ;
500501 }
501502
502503 bool capturing = pufferl->epoch == hypers.cudagraphs ;
503504 if (capturing) {
504- cudaStreamBeginCapture (current_stream, cudaStreamCaptureModeGlobal);
505+ assert (cudaStreamBeginCapture (current_stream, cudaStreamCaptureModeGlobal) == cudaSuccess
506+ && " cudaStreamBeginCapture failed" );
505507 }
506508
507509 RolloutBuf& rollouts = pufferl->rollouts ;
@@ -552,9 +554,11 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
552554
553555 if (capturing) {
554556 cudaGraph_t _graph;
555- cudaStreamEndCapture (current_stream, &_graph);
556- cudaGraphInstantiate (&pufferl->fused_rollout_cudagraphs [graph], _graph, 0 );
557- cudaGraphDestroy (_graph);
557+ assert (cudaStreamEndCapture (current_stream, &_graph) == cudaSuccess
558+ && " cudaStreamEndCapture failed" );
559+ assert (cudaGraphInstantiate (&pufferl->fused_rollout_cudagraphs [graph], _graph, 0 ) == cudaSuccess
560+ && " cudaGraphInstantiate failed" );
561+ assert (cudaGraphDestroy (_graph) == cudaSuccess && " cudaGraphDestroy failed" );
558562 cudaDeviceSynchronize ();
559563 }
560564 profile_end (hypers.profile );
@@ -1008,41 +1012,43 @@ __global__ void compute_prio_imp_weights(
10081012 }
10091013}
10101014
1011- // Multinomial with replacement (uses cuRAND)
1012- __global__ void multinomial_sample (
1013- int * __restrict__ out_idx, const float * __restrict__ probs,
1014- float * __restrict__ cdf, int B, int num_samples,
1015- uint64_t seed, int64_t * __restrict__ offset_ptr) {
1016- int tid = threadIdx .x ;
1017- if (tid == 0 ) {
// Writes the inclusive running sum of probs[0..B) into cdf[0..B).
// The accumulation is a serial loop executed by a single thread (block 0,
// thread 0); any other thread in the launch returns without touching memory.
__global__ void build_cdf(
    float* __restrict__ cdf, const float* __restrict__ probs, int B) {
    // Everyone except the very first thread of the grid has nothing to do.
    if (blockIdx.x != 0 || threadIdx.x != 0) return;

    float running = 0.0f;
    int k = 0;
    while (k < B) {
        running += probs[k];
        cdf[k] = running;
        ++k;
    }
}
1025+
// Adds delta to the 64-bit counter stored at offset_ptr.
// Exactly one thread (block 0, thread 0) performs the update, using a plain
// non-atomic read-modify-write; every other thread exits immediately.
__global__ void advance_rng_offset(int64_t* __restrict__ offset_ptr, int64_t delta) {
    const bool is_first_thread = (blockIdx.x == 0) && (threadIdx.x == 0);
    if (!is_first_thread) return;

    *offset_ptr = *offset_ptr + delta;
}
10451031
// Multinomial with replacement (uses cuRAND)
// One thread per sample, indexed by the flat global thread id; threads whose
// id is >= num_samples write nothing. Each active thread:
//   1. seeds a Philox state with `seed` and subsequence (*offset_ptr + id),
//   2. draws one value with curand_uniform,
//   3. binary-searches cdf[0..B) for the first entry that is not less than the
//      draw (the search range ends at B - 1, so the result never exceeds it),
//   4. stores that index in out_idx[id].
// offset_ptr is only read here; this kernel does not advance it.
__global__ void multinomial_sample(int* __restrict__ out_idx, const float* __restrict__ cdf,
        int B, int num_samples, uint64_t seed, const int64_t* __restrict__ offset_ptr) {
    const int sample = blockIdx.x * blockDim.x + threadIdx.x;
    if (sample < num_samples) {
        // Distinct subsequence per sample so threads get independent streams.
        const uint64_t subsequence = (uint64_t)(*offset_ptr) + sample;
        curandStatePhilox4_32_10_t philox;
        curand_init(seed, subsequence, 0, &philox);
        const float draw = curand_uniform(&philox);

        // Lower-bound search over the cumulative distribution.
        int first = 0;
        int last = B - 1;
        while (first < last) {
            const int middle = (first + last) / 2;
            if (cdf[middle] < draw) {
                first = middle + 1;
            } else {
                last = middle;
            }
        }
        out_idx[sample] = first;
    }
}
1051+
10461052// Prioritize high absolute advantage trajectories
10471053// This is a form of implicit curriculum learning
10481054// It is a major improvement in some complex environments
@@ -1056,10 +1062,14 @@ void prio_replay_cuda(PrecisionTensor& advantages, float prio_alpha,
10561062 advantages.data , bufs.prio_probs .data , prio_alpha, T);
10571063 compute_prio_normalize<<<1 , PRIO_BLOCK_SIZE, 0 , stream>>> (
10581064 bufs.prio_probs .data , B);
1059- int block = fmaxf (((minibatch_segments + 31 ) / 32 ) * 32 , 32 );
1060- multinomial_sample<<<1 , block, 0 , stream>>> (
1061- bufs.idx .data , bufs.prio_probs .data ,
1062- bufs.cdf .data , B, minibatch_segments, seed, offset_ptr);
1065+ // int block = fmaxf(((minibatch_segments + 31) / 32) * 32, 32);
1066+ build_cdf<<<1 , 1 , 0 , stream>>> (bufs.cdf .data , bufs.prio_probs .data , B);
1067+ int threads = 256 ;
1068+ int blocks = (minibatch_segments + threads - 1 ) / threads;
1069+ multinomial_sample<<<blocks, threads, 0 , stream>>> (
1070+ bufs.idx .data , bufs.cdf .data , B, minibatch_segments, seed, offset_ptr);
1071+ advance_rng_offset<<<1 , 1 , 0 , stream>>> (offset_ptr, (int64_t )minibatch_segments);
1072+
10631073 int p3_blocks = (minibatch_segments + PRIO_BLOCK_SIZE - 1 ) / PRIO_BLOCK_SIZE;
10641074 compute_prio_imp_weights<<<p3_blocks, PRIO_BLOCK_SIZE, 0 , stream>>> (
10651075 bufs.idx .data , bufs.prio_probs .data ,
@@ -1368,7 +1378,8 @@ void train_impl(PuffeRL& pufferl) {
13681378 } else {
13691379 bool capturing = pufferl.train_warmup == hypers.cudagraphs ;
13701380 if (capturing) {
1371- cudaStreamBeginCapture (train_stream, cudaStreamCaptureModeGlobal);
1381+ assert (cudaStreamBeginCapture (train_stream, cudaStreamCaptureModeGlobal) == cudaSuccess
1382+ && " cudaStreamBeginCapture failed" );
13721383 }
13731384
13741385 cudaStream_t stream = train_stream;
@@ -1400,9 +1411,11 @@ void train_impl(PuffeRL& pufferl) {
14001411 }
14011412 if (capturing) {
14021413 cudaGraph_t _graph;
1403- cudaStreamEndCapture (train_stream, &_graph);
1404- cudaGraphInstantiate (&pufferl.train_cudagraph , _graph, 0 );
1405- cudaGraphDestroy (_graph);
1414+ assert (cudaStreamEndCapture (train_stream, &_graph) == cudaSuccess
1415+ && " cudaStreamEndCapture failed" );
1416+ assert (cudaGraphInstantiate (&pufferl.train_cudagraph , _graph, 0 ) == cudaSuccess
1417+ && " cudaGraphInstantiate failed" );
1418+ assert (cudaGraphDestroy (_graph) == cudaSuccess && " cudaGraphDestroy failed" );
14061419 cudaDeviceSynchronize ();
14071420 pufferl.train_captured = true ;
14081421 }
0 commit comments