Skip to content

Commit abeb03c

Browse files
committed
Major alignment bug fix on copy_bytes
1 parent 919609b commit abeb03c

3 files changed

Lines changed: 65 additions & 39 deletions

File tree

src/kernels.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,18 @@ __device__ __forceinline__ float logaddexp(float a, float b) {
153153
return (diff < -88.0f) ? m : m + log1pf(__expf(diff));
154154
}
155155

156+
//TODO: Speed up. The previous version was misaligned.
157+
__device__ __forceinline__ void copy_bytes(
158+
const char* __restrict__ src, char* __restrict__ dst,
159+
int src_row, int dst_row, int row_bytes) {
160+
const char* s = src + (int64_t)src_row * row_bytes;
161+
char* d = dst + (int64_t)dst_row * row_bytes;
162+
for (int i = threadIdx.x; i < row_bytes; i += blockDim.x) {
163+
d[i] = s[i];
164+
}
165+
}
166+
167+
/*
156168
__device__ __forceinline__ void copy_bytes(const char* __restrict__ src,
157169
char* __restrict__ dst, int src_row, int dst_row, int row_bytes) {
158170
const int* soffset = (const int*)(src + (int64_t)src_row * row_bytes);
@@ -161,6 +173,7 @@ __device__ __forceinline__ void copy_bytes(const char* __restrict__ src,
161173
doffset[i] = soffset[i];
162174
}
163175
}
176+
*/
164177

165178
// Transpose dims 0,1: [A, B, C] -> [B, A, C]. For 2D, pass C=1.
166179
__global__ void transpose_102(precision_t* __restrict__ dst,

src/pufferlib.cu

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -494,14 +494,16 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
494494

495495
cudaStream_t current_stream = tl_stream;
496496
if (pufferl->rollout_captured) {
497-
cudaGraphLaunch(pufferl->fused_rollout_cudagraphs[graph], current_stream);
497+
assert(cudaGraphLaunch(pufferl->fused_rollout_cudagraphs[graph], current_stream) == cudaSuccess
498+
&& "cudaGraphLaunch failed");
498499
profile_end(hypers.profile);
499500
return;
500501
}
501502

502503
bool capturing = pufferl->epoch == hypers.cudagraphs;
503504
if (capturing) {
504-
cudaStreamBeginCapture(current_stream, cudaStreamCaptureModeGlobal);
505+
assert(cudaStreamBeginCapture(current_stream, cudaStreamCaptureModeGlobal) == cudaSuccess
506+
&& "cudaStreamBeginCapture failed");
505507
}
506508

507509
RolloutBuf& rollouts = pufferl->rollouts;
@@ -552,9 +554,11 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
552554

553555
if (capturing) {
554556
cudaGraph_t _graph;
555-
cudaStreamEndCapture(current_stream, &_graph);
556-
cudaGraphInstantiate(&pufferl->fused_rollout_cudagraphs[graph], _graph, 0);
557-
cudaGraphDestroy(_graph);
557+
assert(cudaStreamEndCapture(current_stream, &_graph) == cudaSuccess
558+
&& "cudaStreamEndCapture failed");
559+
assert(cudaGraphInstantiate(&pufferl->fused_rollout_cudagraphs[graph], _graph, 0) == cudaSuccess
560+
&& "cudaGraphInstantiate failed");
561+
assert(cudaGraphDestroy(_graph) == cudaSuccess && "cudaGraphDestroy failed");
558562
cudaDeviceSynchronize();
559563
}
560564
profile_end(hypers.profile);
@@ -1008,41 +1012,43 @@ __global__ void compute_prio_imp_weights(
10081012
}
10091013
}
10101014

1011-
// Multinomial with replacement (uses cuRAND)
1012-
__global__ void multinomial_sample(
1013-
int* __restrict__ out_idx, const float* __restrict__ probs,
1014-
float* __restrict__ cdf, int B, int num_samples,
1015-
uint64_t seed, int64_t* __restrict__ offset_ptr) {
1016-
int tid = threadIdx.x;
1017-
if (tid == 0) {
1015+
__global__ void build_cdf(
1016+
float* __restrict__ cdf, const float* __restrict__ probs, int B) {
1017+
if (blockIdx.x == 0 && threadIdx.x == 0) {
10181018
float cum = 0.0f;
10191019
for (int i = 0; i < B; i++) {
10201020
cum += probs[i];
10211021
cdf[i] = cum;
10221022
}
10231023
}
1024-
__syncthreads();
1025-
if (tid < num_samples) {
1026-
uint64_t base_off = *offset_ptr;
1027-
curandStatePhilox4_32_10_t rng_state;
1028-
curand_init(seed, base_off + tid, 0, &rng_state);
1029-
float u = curand_uniform(&rng_state);
1030-
int lo = 0, hi = B - 1;
1031-
while (lo < hi) {
1032-
int mid = (lo + hi) / 2;
1033-
if (cdf[mid] < u) {
1034-
lo = mid + 1;
1035-
} else {
1036-
hi = mid;
1037-
}
1038-
}
1039-
out_idx[tid] = lo;
1040-
}
1041-
if (tid == 0) {
1042-
atomicAdd((unsigned long long*)offset_ptr, (unsigned long long)num_samples);
1024+
}
1025+
1026+
__global__ void advance_rng_offset(int64_t* __restrict__ offset_ptr, int64_t delta) {
1027+
if (blockIdx.x == 0 && threadIdx.x == 0) {
1028+
*offset_ptr += delta;
10431029
}
10441030
}
10451031

1032+
// Multinomial with replacement (uses cuRAND)
1033+
__global__ void multinomial_sample(int* __restrict__ out_idx, const float* __restrict__ cdf,
1034+
int B, int num_samples, uint64_t seed, const int64_t* __restrict__ offset_ptr) {
1035+
int tid = blockIdx.x * blockDim.x + threadIdx.x;
1036+
if (tid >= num_samples) return;
1037+
1038+
uint64_t base_off = (uint64_t)(*offset_ptr);
1039+
curandStatePhilox4_32_10_t rng_state;
1040+
curand_init(seed, base_off + tid, 0, &rng_state);
1041+
float u = curand_uniform(&rng_state);
1042+
1043+
int lo = 0, hi = B - 1;
1044+
while (lo < hi) {
1045+
int mid = (lo + hi) / 2;
1046+
if (cdf[mid] < u) lo = mid + 1;
1047+
else hi = mid;
1048+
}
1049+
out_idx[tid] = lo;
1050+
}
1051+
10461052
// Prioritize high absolute advantage trajectories
10471053
// This is a form of implicit curriculum learning
10481054
// It is a major improvement in some complex environments
@@ -1056,10 +1062,14 @@ void prio_replay_cuda(PrecisionTensor& advantages, float prio_alpha,
10561062
advantages.data, bufs.prio_probs.data, prio_alpha, T);
10571063
compute_prio_normalize<<<1, PRIO_BLOCK_SIZE, 0, stream>>>(
10581064
bufs.prio_probs.data, B);
1059-
int block = fmaxf(((minibatch_segments + 31) / 32) * 32, 32);
1060-
multinomial_sample<<<1, block, 0, stream>>>(
1061-
bufs.idx.data, bufs.prio_probs.data,
1062-
bufs.cdf.data, B, minibatch_segments, seed, offset_ptr);
1065+
//int block = fmaxf(((minibatch_segments + 31) / 32) * 32, 32);
1066+
build_cdf<<<1, 1, 0, stream>>>(bufs.cdf.data, bufs.prio_probs.data, B);
1067+
int threads = 256;
1068+
int blocks = (minibatch_segments + threads - 1) / threads;
1069+
multinomial_sample<<<blocks, threads, 0, stream>>>(
1070+
bufs.idx.data, bufs.cdf.data, B, minibatch_segments, seed, offset_ptr);
1071+
advance_rng_offset<<<1, 1, 0, stream>>>(offset_ptr, (int64_t)minibatch_segments);
1072+
10631073
int p3_blocks = (minibatch_segments + PRIO_BLOCK_SIZE - 1) / PRIO_BLOCK_SIZE;
10641074
compute_prio_imp_weights<<<p3_blocks, PRIO_BLOCK_SIZE, 0, stream>>>(
10651075
bufs.idx.data, bufs.prio_probs.data,
@@ -1368,7 +1378,8 @@ void train_impl(PuffeRL& pufferl) {
13681378
} else {
13691379
bool capturing = pufferl.train_warmup == hypers.cudagraphs;
13701380
if (capturing) {
1371-
cudaStreamBeginCapture(train_stream, cudaStreamCaptureModeGlobal);
1381+
assert(cudaStreamBeginCapture(train_stream, cudaStreamCaptureModeGlobal) == cudaSuccess
1382+
&& "cudaStreamBeginCapture failed");
13721383
}
13731384

13741385
cudaStream_t stream = train_stream;
@@ -1400,9 +1411,11 @@ void train_impl(PuffeRL& pufferl) {
14001411
}
14011412
if (capturing) {
14021413
cudaGraph_t _graph;
1403-
cudaStreamEndCapture(train_stream, &_graph);
1404-
cudaGraphInstantiate(&pufferl.train_cudagraph, _graph, 0);
1405-
cudaGraphDestroy(_graph);
1414+
assert(cudaStreamEndCapture(train_stream, &_graph) == cudaSuccess
1415+
&& "cudaStreamEndCapture failed");
1416+
assert(cudaGraphInstantiate(&pufferl.train_cudagraph, _graph, 0) == cudaSuccess
1417+
&& "cudaGraphInstantiate failed");
1418+
assert(cudaGraphDestroy(_graph) == cudaSuccess && "cudaGraphDestroy failed");
14061419
cudaDeviceSynchronize();
14071420
pufferl.train_captured = true;
14081421
}

0 commit comments

Comments (0)