
Commit 91e4ce9

minor refactor

1 parent 0336926

2 files changed: 71 additions & 56 deletions

pufferlib/src/bindings.cu

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ void py_puff_advantage(
     kernel<<<blocks, 256>>>(
         (const precision_t*)values_ptr, (const precision_t*)rewards_ptr,
         (const precision_t*)dones_ptr, (const precision_t*)importance_ptr,
-        (float*)advantages_ptr,
+        (precision_t*)advantages_ptr,
         gamma, lambda, rho_clip, c_clip, num_steps, horizon);
 }
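
For context on the conversion helpers this change leans on: `precision_t`, `to_float`, and `from_float` let the same kernels compile for either f32 or bf16. A minimal sketch of such a shim, assuming a hypothetical `PUFFER_BF16` build flag (pufferlib's actual definitions may differ):

```cuda
#include <cuda_bf16.h>

// Sketch only: compile-time precision switch. The real pufferlib shim and
// its flag name are not shown in this commit.
#ifdef PUFFER_BF16
typedef __nv_bfloat16 precision_t;
__device__ __forceinline__ float to_float(precision_t x) { return __bfloat162float(x); }
__device__ __forceinline__ precision_t from_float(float x) { return __float2bfloat16(x); }
#else
typedef float precision_t;
__device__ __forceinline__ float to_float(precision_t x) { return x; }
__device__ __forceinline__ precision_t from_float(float x) { return x; }
#endif
```

Under that reading, the bindings fix above makes the advantages pointer follow the same compile-time precision as every other rollout tensor instead of being pinned to f32.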

pufferlib/src/pufferlib.cu

Lines changed: 70 additions & 55 deletions
@@ -14,27 +14,31 @@ enum LossIdx {
     LOSS_N = 7, NUM_LOSSES = 8,
 };
 
+// Data collected by parallel environment workers. Each worker handles
+// a constant subset of agents.
 struct RolloutBuf {
-    PrecisionTensor observations;  // (horizon, segments, input_size)
-    PrecisionTensor actions;       // (horizon, segments, num_atns)
-    PrecisionTensor values;        // (horizon, segments)
-    PrecisionTensor logprobs;      // (horizon, segments)
-    PrecisionTensor rewards;       // (horizon, segments)
-    PrecisionTensor terminals;     // (horizon, segments)
-    PrecisionTensor ratio;         // (horizon, segments)
-    PrecisionTensor importance;    // (horizon, segments)
+    PrecisionTensor observations;  // (horizon, agents, input_size)
+    PrecisionTensor actions;       // (horizon, agents, num_atns)
+    PrecisionTensor values;        // (horizon, agents) - all remaining tensors share this shape
+    PrecisionTensor logprobs;
+    PrecisionTensor rewards;
+    PrecisionTensor terminals;
+    PrecisionTensor ratio;
+    PrecisionTensor importance;
 };
 
-void register_rollout_buffers(RolloutBuf& bufs, Allocator* alloc, int H, int S, int input_size, int num_atns) {
+// Buffers are initialized as raw structs with only shape information. alloc_register
+// stores the shape and data pointer. Memory is only allocated after all buffers are registered.
+void register_rollout_buffers(RolloutBuf& bufs, Allocator* alloc, int T, int B, int input_size, int num_atns) {
     bufs = (RolloutBuf){
-        .observations = {.shape = {H, S, input_size}},
-        .actions = {.shape = {H, S, num_atns}},
-        .values = {.shape = {H, S}},
-        .logprobs = {.shape = {H, S}},
-        .rewards = {.shape = {H, S}},
-        .terminals = {.shape = {H, S}},
-        .ratio = {.shape = {H, S}},
-        .importance = {.shape = {H, S}},
+        .observations = {.shape = {T, B, input_size}},
+        .actions = {.shape = {T, B, num_atns}},
+        .values = {.shape = {T, B}},
+        .logprobs = {.shape = {T, B}},
+        .rewards = {.shape = {T, B}},
+        .terminals = {.shape = {T, B}},
+        .ratio = {.shape = {T, B}},
+        .importance = {.shape = {T, B}},
     };
     alloc_register(alloc, &bufs.observations);
     alloc_register(alloc, &bufs.actions);
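
A minimal sketch of the register-then-allocate scheme the new comment describes, with hypothetical `Allocator` internals (the real pufferlib allocator may differ):

```cuda
#include <cuda_runtime.h>
#include <stddef.h>

// Sketch: tensors are registered with shapes only; one cudaMalloc later
// backs them all. Zeroed trailing dims act as a "no more dims" sentinel,
// matching the zero-initialized designated initializers above.
typedef struct { int shape[3]; float* data; } TensorSketch;
typedef struct { TensorSketch* regs[64]; int n; } AllocatorSketch;

void register_tensor(AllocatorSketch* a, TensorSketch* t) { a->regs[a->n++] = t; }

void commit_allocations(AllocatorSketch* a) {
    size_t offsets[64], total = 0;
    for (int i = 0; i < a->n; i++) {
        size_t sz = sizeof(float);
        for (int d = 0; d < 3 && a->regs[i]->shape[d] != 0; d++)
            sz *= (size_t)a->regs[i]->shape[d];
        offsets[i] = total;
        total += sz;                       // pack every buffer into one arena
    }
    char* base = NULL;
    cudaMalloc((void**)&base, total);      // single allocation for all tensors
    for (int i = 0; i < a->n; i++)
        a->regs[i]->data = (float*)(base + offsets[i]);
}
```

One arena keeps related buffers adjacent and makes teardown a single free, which is presumably why registration is decoupled from allocation here.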
@@ -46,25 +50,28 @@ void register_rollout_buffers(RolloutBuf& bufs, Allocator* alloc, int H, int S,
     alloc_register(alloc, &bufs.importance);
 }
 
+// Train data layout is transposed to (B, T) from the rollouts layout (T, B).
+// This allows env workers to collect data with contiguous writes and
+// training to perform several (though not all) ops in contiguous memory.
 struct TrainGraph {
-    PrecisionTensor mb_obs;        // (S, H, input_size)
-    PrecisionTensor mb_state;      // (L, S, 1, hidden)
-    PrecisionTensor mb_actions;    // (S, H, num_atns)
-    PrecisionTensor mb_logprobs;   // (S, H)
-    FloatTensor mb_advantages;     // (S, H) f32
-    PrecisionTensor mb_prio;       // (S, 1)
-    PrecisionTensor mb_values;     // (S, H)
-    PrecisionTensor mb_returns;    // (S, H)
-    PrecisionTensor mb_ratio;      // (S, H)
-    PrecisionTensor mb_newvalue;   // (S, H, 1)
+    PrecisionTensor mb_obs;        // (B, T, input_size)
+    PrecisionTensor mb_state;      // (layers, B, 1, hidden)
+    PrecisionTensor mb_actions;    // (B, T, num_atns)
+    PrecisionTensor mb_logprobs;   // (B, T)
+    PrecisionTensor mb_advantages; // (B, T)
+    PrecisionTensor mb_prio;       // (B, T)
+    PrecisionTensor mb_values;     // (B, T)
+    PrecisionTensor mb_returns;    // (B, T)
+    PrecisionTensor mb_ratio;      // (B, T)
+    PrecisionTensor mb_newvalue;   // (B, T, 1)
 };
 
 struct PPOGraphArgs {
     precision_t* out_ratio;
     precision_t* out_newvalue;
     const precision_t* actions;
     const precision_t* old_logprobs;
-    const float* advantages;
+    const precision_t* advantages;
     const precision_t* prio;
     const precision_t* values;
     const precision_t* returns;
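
As a toy illustration of that layout argument (a hypothetical standalone kernel; the repo's select_copy_kernel fuses this transpose with index selection and returns computation):

```cuda
// (T, B) -> (B, T): rollout writes are contiguous across agents at a fixed
// step; train-time reads are contiguous across steps for a fixed agent.
__global__ void transpose_tb(const float* src, float* dst, int T, int B) {
    int b = blockIdx.x;                              // one block per agent
    for (int t = threadIdx.x; t < T; t += blockDim.x) {
        dst[(size_t)b * T + t] = src[(size_t)t * B + b];
    }
}
```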
@@ -300,7 +307,7 @@ typedef struct {
     RolloutBuf train_rollouts; // Pre-allocated transposed copy for train_impl
     EnvBuf env;
     TrainGraph train_buf;
-    FloatTensor advantages_puf; // Pre-allocated for train_impl (S, H) f32
+    PrecisionTensor advantages_puf; // Pre-allocated for train_impl (S, H)
     cudaGraphExec_t* fused_rollout_cudagraphs; // [horizon][num_buffers]
     cudaGraphExec_t train_cudagraph;
     cudaStream_t* streams; // per-buffer raw CUDA streams
@@ -696,7 +703,7 @@ __global__ void ppo_loss_fwd_bwd_kernel(
     // --- Shared computation (used by both forward and backward) ---
 
     float old_logp = to_float(g.old_logprobs[nt]);
-    float adv = float(g.advantages[nt]);
+    float adv = to_float(g.advantages[nt]);
     float w = to_float(g.prio[n]);
     float val = to_float(g.values[nt]);
     float ret = to_float(g.returns[nt]);
@@ -880,13 +887,13 @@ __global__ void ppo_loss_reduce_kernel(
     }
 }
 
-__global__ void var_mean_kernel(const float* __restrict__ src, float* __restrict__ var_out,
+__global__ void var_mean_kernel(const precision_t* __restrict__ src, float* __restrict__ var_out,
     float* __restrict__ mean_out, int n) {
     __shared__ float sdata[256];
     int tid = threadIdx.x;
     float sum = 0.0f;
     for (int i = tid; i < n; i += blockDim.x) {
-        sum += src[i];
+        sum += to_float(src[i]);
     }
     sdata[tid] = sum;
     __syncthreads();
@@ -903,7 +910,7 @@ __global__ void var_mean_kernel(const float* __restrict__ src, float* __restrict
     __syncthreads();
     float ss = 0.0f;
     for (int i = tid; i < n; i += blockDim.x) {
-        float d = src[i] - mean;
+        float d = to_float(src[i]) - mean;
         ss += d * d;
     }
     sdata[tid] = ss;
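
The diff elides the reduction between these two hunks; for readers following along, here is a self-contained sketch of the same two-pass shared-memory pattern (launch as <<<1, 256>>>; this is the textbook form, not the repo's exact code):

```cuda
__global__ void var_mean_sketch(const float* src, float* var_out,
                                float* mean_out, int n) {
    __shared__ float sdata[256];
    int tid = threadIdx.x;

    // Pass 1: block-strided partial sums, then a tree reduction for the mean.
    float sum = 0.0f;
    for (int i = tid; i < n; i += blockDim.x) sum += src[i];
    sdata[tid] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    float mean = sdata[0] / n;
    __syncthreads();                 // everyone reads sdata[0] before reuse

    // Pass 2: same reduction over squared deviations for the variance.
    float ss = 0.0f;
    for (int i = tid; i < n; i += blockDim.x) {
        float d = src[i] - mean;
        ss += d * d;
    }
    sdata[tid] = ss;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    if (tid == 0) { *mean_out = mean; *var_out = sdata[0] / n; }
}
```

The committed change is orthogonal to the reduction itself: only the loads gain a to_float, so the kernel accepts precision_t input while still accumulating in f32.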
@@ -994,7 +1001,7 @@ void ppo_loss_fwd_bwd(
 #define PRIO_BLOCK_SIZE 256
 #define PRIO_NUM_WARPS (PRIO_BLOCK_SIZE / PRIO_WARP_SIZE)
 __global__ void compute_prio_adv_reduction(
-    const float* __restrict__ advantages,
+    const precision_t* __restrict__ advantages,
     float* prio_weights,
     float prio_alpha,
     int stride
@@ -1005,7 +1012,7 @@ __global__ void compute_prio_adv_reduction(
 
     float local_sum = 0.0f;
     for (int t = tx; t < stride; t += blockDim.x) {
-        local_sum += fabsf(advantages[offset + t]);
+        local_sum += fabsf(to_float(advantages[offset + t]));
     }
 
     for (int s = PRIO_WARP_SIZE / 2; s >= 1; s /= 2) {
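
The hunk cuts off at the opening of a warp-level reduction. A self-contained sketch of the usual shuffle-based form it presumably expands into (assuming PRIO_WARP_SIZE is 32; `out` must be zeroed before launch):

```cuda
__global__ void abs_sum_sketch(const float* advantages, float* out, int stride) {
    float local_sum = 0.0f;
    for (int t = threadIdx.x; t < stride; t += blockDim.x)
        local_sum += fabsf(advantages[t]);
    for (int s = 16; s >= 1; s /= 2)      // butterfly down the warp lanes
        local_sum += __shfl_down_sync(0xffffffffu, local_sum, s);
    if ((threadIdx.x & 31) == 0)          // lane 0 of each warp owns its total
        atomicAdd(out, local_sum);
}
```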
@@ -1112,7 +1119,7 @@ __global__ void multinomial_with_replacement_kernel(
     }
 }
 
-void prio_replay_cuda(FloatTensor& advantages, float prio_alpha,
+void prio_replay_cuda(PrecisionTensor& advantages, float prio_alpha,
     int minibatch_segments, int total_agents, float anneal_beta,
     PrioBuffers& bufs, ulong seed, long* offset_ptr, cudaStream_t stream) {
     int S = advantages.shape[0], T = advantages.shape[1];
@@ -1132,7 +1139,7 @@ void prio_replay_cuda(FloatTensor& advantages, float prio_alpha,
 
 __device__ void puff_advantage_row_scalar(
     const precision_t* values, const precision_t* rewards, const precision_t* dones,
-    const precision_t* importance, float* advantages, float gamma, float lambda,
+    const precision_t* importance, precision_t* advantages, float gamma, float lambda,
     float rho_clip, float c_clip, int horizon
 ) {
     float lastpufferlam = 0;
@@ -1147,7 +1154,7 @@ __device__ void puff_advantage_row_scalar(
         float v_nxt = to_float(values[t_next]);
         float delta = rho_t*r_nxt + gamma*v_nxt*nextnonterminal - v;
         lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
-        advantages[t] = lastpufferlam;
+        advantages[t] = from_float(lastpufferlam);
     }
 }
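
For reference, the recursion these lines implement runs backward from t = horizon − 1 with A_horizon = 0. Reading ρ_t and c_t as the importance ratios clipped at rho_clip and c_clip (an inference from the parameter names; the clipping itself is not shown in this hunk):

$$\delta_t = \rho_t\, r_{t+1} + \gamma\,(1 - d_{t+1})\,V(s_{t+1}) - V(s_t), \qquad A_t = \delta_t + \gamma\lambda\, c_t\,(1 - d_{t+1})\,A_{t+1}$$

The only change in this hunk is that A_t is now rounded to precision_t via from_float on store instead of being kept in f32.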

@@ -1165,9 +1172,22 @@ __device__ __forceinline__ void adv_vec_load(const __nv_bfloat16* ptr, float* ou
     }
 }
 
+// Store N floats as precision_t via 128-bit writes (float4 for f32, uint4 for bf16)
+__device__ __forceinline__ void adv_vec_store(float* ptr, const float* vals) {
+    *reinterpret_cast<float4*>(ptr) = make_float4(vals[0], vals[1], vals[2], vals[3]);
+}
+
+__device__ __forceinline__ void adv_vec_store(__nv_bfloat16* ptr, const float* vals) {
+    // N=8 for bf16: all 8 elements fit in one uint4 (128 bits)
+    __nv_bfloat16 tmp[8];
+    #pragma unroll
+    for (int i = 0; i < 8; i++) tmp[i] = __float2bfloat16(vals[i]);
+    *reinterpret_cast<uint4*>(ptr) = *reinterpret_cast<const uint4*>(tmp);
+}
+
 __device__ __forceinline__ void puff_advantage_row_vec(
     const precision_t* values, const precision_t* rewards, const precision_t* dones,
-    const precision_t* importance, float* advantages, float gamma, float lambda,
+    const precision_t* importance, precision_t* advantages, float gamma, float lambda,
     float rho_clip, float c_clip, int horizon
 ) {
     constexpr int N = 16 / sizeof(precision_t);
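
The design choice in the new overloads is to stage conversions in a local array and emit a single 16-byte transaction rather than N scalar stores. A standalone demo of the same trick (hypothetical kernel; `dst` must be 16-byte aligned or the reinterpret_cast is undefined behavior):

```cuda
#include <cuda_bf16.h>

// Each thread writes 8 bf16 values (16 bytes) as one uint4 store.
__global__ void vec_store_demo(__nv_bfloat16* dst, const float* src) {
    __nv_bfloat16 tmp[8];
    #pragma unroll
    for (int i = 0; i < 8; i++)
        tmp[i] = __float2bfloat16(src[8 * threadIdx.x + i]);
    *reinterpret_cast<uint4*>(dst + 8 * (size_t)threadIdx.x) =
        *reinterpret_cast<const uint4*>(tmp);
}
```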
@@ -1204,17 +1224,12 @@ __device__ __forceinline__ void puff_advantage_row_vec(
             next_reward = r[i];
         }
 
-        *reinterpret_cast<float4*>(advantages + base) =
-            make_float4(adv[0], adv[1], adv[2], adv[3]);
-        if (N > 4) {
-            *reinterpret_cast<float4*>(advantages + base + 4) =
-                make_float4(adv[4], adv[5], adv[6], adv[7]);
-        }
+        adv_vec_store(advantages + base, adv);
     }
 }
 
 __global__ void puff_advantage_kernel(const precision_t* values, const precision_t* rewards,
-    const precision_t* dones, const precision_t* importance, float* advantages, float gamma,
+    const precision_t* dones, const precision_t* importance, precision_t* advantages, float gamma,
     float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
     int row = blockIdx.x*blockDim.x + threadIdx.x;
     if (row >= num_steps) {
@@ -1226,7 +1241,7 @@ __global__ void puff_advantage_kernel(const precision
 }
 
 __global__ void puff_advantage_kernel_scalar(const precision_t* values, const precision_t* rewards,
-    const precision_t* dones, const precision_t* importance, float* advantages, float gamma,
+    const precision_t* dones, const precision_t* importance, precision_t* advantages, float gamma,
     float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
     int row = blockIdx.x*blockDim.x + threadIdx.x;
     if (row >= num_steps) {
@@ -1238,7 +1253,7 @@ __global__ void puff_advantage_kernel_scalar(const precision_t* values, const pr
 }
 
 void puff_advantage_cuda(PrecisionTensor& values, PrecisionTensor& rewards,
-    PrecisionTensor& dones, PrecisionTensor& importance, FloatTensor& advantages,
+    PrecisionTensor& dones, PrecisionTensor& importance, PrecisionTensor& advantages,
     float gamma, float lambda, float rho_clip, float c_clip, cudaStream_t stream) {
     int num_steps = values.shape[0], horizon = values.shape[1];
     int blocks = grid_size(num_steps);
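
The launch logic itself is not part of this diff; presumably the wrapper picks the vectorized kernel only when each row divides into whole 128-bit chunks. A hedged sketch of that dispatch (the condition is an assumption, not something this hunk shows):

```cuda
// Hypothetical body for puff_advantage_cuda's kernel selection.
constexpr int N = 16 / sizeof(precision_t);   // elements per 128-bit access
if (horizon % N == 0) {
    puff_advantage_kernel<<<blocks, 256, 0, stream>>>(
        values.data, rewards.data, dones.data, importance.data,
        advantages.data, gamma, lambda, rho_clip, c_clip, num_steps, horizon);
} else {
    puff_advantage_kernel_scalar<<<blocks, 256, 0, stream>>>(
        values.data, rewards.data, dones.data, importance.data,
        advantages.data, gamma, lambda, rho_clip, c_clip, num_steps, horizon);
}
```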
@@ -1260,30 +1275,30 @@ __global__ void index_copy_kernel(char* __restrict__ dst, const int64_t* __restr
 
 __device__ __forceinline__ void copy_values_adv_returns(
     const precision_t* __restrict__ src_values, precision_t* __restrict__ dst_values,
-    const float* __restrict__ src_advantages, float* __restrict__ dst_advantages,
+    const precision_t* __restrict__ src_advantages, precision_t* __restrict__ dst_advantages,
     precision_t* __restrict__ dst_returns,
     int src_row, int dst_row, int horizon
 ) {
     int srh = (int64_t)src_row * horizon;
     int drh = (int64_t)dst_row * horizon;
     const precision_t* s_values = src_values + srh;
-    const float* s_adv = src_advantages + srh;
+    const precision_t* s_adv = src_advantages + srh;
     precision_t* d_values = dst_values + drh;
-    float* d_adv = dst_advantages + drh;
+    precision_t* d_adv = dst_advantages + drh;
     precision_t* d_returns = dst_returns + drh;
     for (int i = threadIdx.x; i < horizon; i += blockDim.x) {
         precision_t val = s_values[i];
-        float adv = s_adv[i];
+        precision_t adv = s_adv[i];
         d_values[i] = val;
         d_adv[i] = adv;
-        d_returns[i] = from_float(to_float(val) + adv);
+        d_returns[i] = from_float(to_float(val) + to_float(adv));
     }
 }
 
 __global__ void select_copy_kernel(
     RolloutBuf rollouts, TrainGraph graph,
     const int64_t* __restrict__ idx,
-    const float* __restrict__ advantages, const float* __restrict__ mb_prio
+    const precision_t* __restrict__ advantages, const float* __restrict__ mb_prio
 ) {
     int mb = blockIdx.x;
     int ch = blockIdx.y;
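
Worth noting in this hunk: the fused copy also materializes the returns target inline as

$$R_t = V(s_t) + A_t$$

and with advantages now stored as precision_t, both operands are widened with to_float so the sum is still formed in f32 before from_float rounds it back down.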
@@ -1364,7 +1379,7 @@ void train_impl(PuffeRL& pufferl) {
         rollouts.ratio.data, from_float(1.0f), numel(rollouts.ratio.shape));
 
     // Zero pre-allocated advantages buffer
-    FloatTensor& advantages_puf = pufferl.advantages_puf;
+    PrecisionTensor& advantages_puf = pufferl.advantages_puf;
 
     // Inline any of these only used once
     int minibatch_size = hypers.minibatch_size;
