@@ -14,27 +14,31 @@ enum LossIdx {
1414 LOSS_N = 7 , NUM_LOSSES = 8 ,
1515};
1616
17+ // Data collected by parallel environment workers. Each worker handles
18+ // a fixed subset of agents.
1719struct RolloutBuf {
18- PrecisionTensor observations; // (horizon, segments , input_size)
19- PrecisionTensor actions; // (horizon, segments , num_atns)
20- PrecisionTensor values; // (horizon, segments)
21- PrecisionTensor logprobs; // (horizon, segments)
22- PrecisionTensor rewards; // (horizon, segments)
23- PrecisionTensor terminals; // (horizon, segments)
24- PrecisionTensor ratio; // (horizon, segments)
25- PrecisionTensor importance; // (horizon, segments)
20+ PrecisionTensor observations; // (horizon, agents , input_size)
21+ PrecisionTensor actions; // (horizon, agents , num_atns)
22+ PrecisionTensor values; // (horizon, agents) - all remaining tensors share this shape
23+ PrecisionTensor logprobs;
24+ PrecisionTensor rewards;
25+ PrecisionTensor terminals;
26+ PrecisionTensor ratio;
27+ PrecisionTensor importance;
2628};
2729
28- void register_rollout_buffers (RolloutBuf& bufs, Allocator* alloc, int H, int S, int input_size, int num_atns) {
30+ // Buffers are initialized as raw structs with only shape information. alloc_register
31+ // stores the shape and data pointer. Memory is only allocated after all buffers are registered.
32+ void register_rollout_buffers (RolloutBuf& bufs, Allocator* alloc,int T, int B, int input_size, int num_atns) {
2933 bufs = (RolloutBuf){
30- .observations = {.shape = {H, S , input_size}},
31- .actions = {.shape = {H, S , num_atns}},
32- .values = {.shape = {H, S }},
33- .logprobs = {.shape = {H, S }},
34- .rewards = {.shape = {H, S }},
35- .terminals = {.shape = {H, S }},
36- .ratio = {.shape = {H, S }},
37- .importance = {.shape = {H, S }},
34+ .observations = {.shape = {T, B , input_size}},
35+ .actions = {.shape = {T, B , num_atns}},
36+ .values = {.shape = {T, B }},
37+ .logprobs = {.shape = {T, B }},
38+ .rewards = {.shape = {T, B }},
39+ .terminals = {.shape = {T, B }},
40+ .ratio = {.shape = {T, B }},
41+ .importance = {.shape = {T, B }},
3842 };
3943 alloc_register (alloc, &bufs.observations );
4044 alloc_register (alloc, &bufs.actions );
@@ -46,25 +50,28 @@ void register_rollout_buffers(RolloutBuf& bufs, Allocator* alloc, int H, int S,
4650 alloc_register (alloc, &bufs.importance );
4751}
4852
53+ // Train data layout is transposed to (B, T) from rollouts layout (T, B)
54+ // This allows env workers to collect data with contiguous writes and
55+ // training to perform several (though not all) ops in contiguous memory
4956struct TrainGraph {
50- PrecisionTensor mb_obs; // (S, H , input_size)
51- PrecisionTensor mb_state; // (L, S , 1, hidden)
52- PrecisionTensor mb_actions; // (S, H , num_atns)
53- PrecisionTensor mb_logprobs; // (S, H )
54- FloatTensor mb_advantages; // (S, H) f32
55- PrecisionTensor mb_prio; // (S, 1 )
56- PrecisionTensor mb_values; // (S, H )
57- PrecisionTensor mb_returns; // (S, H )
58- PrecisionTensor mb_ratio; // (S, H )
59- PrecisionTensor mb_newvalue; // (S, H , 1)
57+ PrecisionTensor mb_obs; // (B, T , input_size)
58+ PrecisionTensor mb_state; // (layers, B , 1, hidden)
59+ PrecisionTensor mb_actions; // (B, T , num_atns)
60+ PrecisionTensor mb_logprobs; // (B, T )
61+ PrecisionTensor mb_advantages; // (B, T)
62+ PrecisionTensor mb_prio; // (B, T )
63+ PrecisionTensor mb_values; // (B, T )
64+ PrecisionTensor mb_returns; // (B, T )
65+ PrecisionTensor mb_ratio; // (B, T )
66+ PrecisionTensor mb_newvalue; // (B, T , 1)
6067};
6168
6269struct PPOGraphArgs {
6370 precision_t * out_ratio;
6471 precision_t * out_newvalue;
6572 const precision_t * actions;
6673 const precision_t * old_logprobs;
67- const float * advantages;
74+ const precision_t * advantages;
6875 const precision_t * prio;
6976 const precision_t * values;
7077 const precision_t * returns;
@@ -300,7 +307,7 @@ typedef struct {
300307 RolloutBuf train_rollouts; // Pre-allocated transposed copy for train_impl
301308 EnvBuf env;
302309 TrainGraph train_buf;
303- FloatTensor advantages_puf; // Pre-allocated for train_impl (S, H) f32
310+ PrecisionTensor advantages_puf; // Pre-allocated for train_impl (B, T)
304311 cudaGraphExec_t* fused_rollout_cudagraphs; // [horizon][num_buffers]
305312 cudaGraphExec_t train_cudagraph;
306313 cudaStream_t* streams; // per-buffer raw CUDA streams
@@ -696,7 +703,7 @@ __global__ void ppo_loss_fwd_bwd_kernel(
696703 // --- Shared computation (used by both forward and backward) ---
697704
698705 float old_logp = to_float (g.old_logprobs [nt]);
699- float adv = float (g.advantages [nt]);
706+ float adv = to_float (g.advantages [nt]);
700707 float w = to_float (g.prio [n]);
701708 float val = to_float (g.values [nt]);
702709 float ret = to_float (g.returns [nt]);
@@ -880,13 +887,13 @@ __global__ void ppo_loss_reduce_kernel(
880887 }
881888}
882889
883- __global__ void var_mean_kernel (const float * __restrict__ src, float * __restrict__ var_out,
890+ __global__ void var_mean_kernel (const precision_t * __restrict__ src, float * __restrict__ var_out,
884891 float * __restrict__ mean_out, int n) {
885892 __shared__ float sdata[256 ];
886893 int tid = threadIdx .x ;
887894 float sum = 0 .0f ;
888895 for (int i = tid; i < n; i += blockDim .x ) {
889- sum += src[i];
896+ sum += to_float ( src[i]) ;
890897 }
891898 sdata[tid] = sum;
892899 __syncthreads ();
@@ -903,7 +910,7 @@ __global__ void var_mean_kernel(const float* __restrict__ src, float* __restrict
903910 __syncthreads ();
904911 float ss = 0 .0f ;
905912 for (int i = tid; i < n; i += blockDim .x ) {
906- float d = src[i] - mean;
913+ float d = to_float ( src[i]) - mean;
907914 ss += d * d;
908915 }
909916 sdata[tid] = ss;
@@ -994,7 +1001,7 @@ void ppo_loss_fwd_bwd(
9941001#define PRIO_BLOCK_SIZE 256
9951002#define PRIO_NUM_WARPS (PRIO_BLOCK_SIZE / PRIO_WARP_SIZE)
9961003__global__ void compute_prio_adv_reduction (
997- const float * __restrict__ advantages,
1004+ const precision_t * __restrict__ advantages,
9981005 float * prio_weights,
9991006 float prio_alpha,
10001007 int stride
@@ -1005,7 +1012,7 @@ __global__ void compute_prio_adv_reduction(
10051012
10061013 float local_sum = 0 .0f ;
10071014 for (int t = tx; t < stride; t += blockDim .x ) {
1008- local_sum += fabsf (advantages[offset + t]);
1015+ local_sum += fabsf (to_float ( advantages[offset + t]) );
10091016 }
10101017
10111018 for (int s = PRIO_WARP_SIZE / 2 ; s >= 1 ; s /= 2 ) {
@@ -1112,7 +1119,7 @@ __global__ void multinomial_with_replacement_kernel(
11121119 }
11131120}
11141121
1115- void prio_replay_cuda (FloatTensor & advantages, float prio_alpha,
1122+ void prio_replay_cuda (PrecisionTensor & advantages, float prio_alpha,
11161123 int minibatch_segments, int total_agents, float anneal_beta,
11171124 PrioBuffers& bufs, ulong seed, long * offset_ptr, cudaStream_t stream) {
11181125 int S = advantages.shape [0 ], T = advantages.shape [1 ];
@@ -1132,7 +1139,7 @@ void prio_replay_cuda(FloatTensor& advantages, float prio_alpha,
11321139
11331140__device__ void puff_advantage_row_scalar (
11341141 const precision_t * values, const precision_t * rewards, const precision_t * dones,
1135- const precision_t * importance, float * advantages, float gamma, float lambda,
1142+ const precision_t * importance, precision_t * advantages, float gamma, float lambda,
11361143 float rho_clip, float c_clip, int horizon
11371144) {
11381145 float lastpufferlam = 0 ;
@@ -1147,7 +1154,7 @@ __device__ void puff_advantage_row_scalar(
11471154 float v_nxt = to_float (values[t_next]);
11481155 float delta = rho_t *r_nxt + gamma*v_nxt*nextnonterminal - v;
11491156 lastpufferlam = delta + gamma*lambda*c_t *lastpufferlam*nextnonterminal;
1150- advantages[t] = lastpufferlam;
1157+ advantages[t] = from_float ( lastpufferlam) ;
11511158 }
11521159}
11531160
@@ -1165,9 +1172,22 @@ __device__ __forceinline__ void adv_vec_load(const __nv_bfloat16* ptr, float* ou
11651172 }
11661173}
11671174
1175+ // Store N floats as precision_t via 128-bit writes (float4 for f32, uint4 for bf16)
1176+ __device__ __forceinline__ void adv_vec_store (float * ptr, const float * vals) {
1177+ *reinterpret_cast <float4 *>(ptr) = make_float4 (vals[0 ], vals[1 ], vals[2 ], vals[3 ]);
1178+ }
1179+
1180+ __device__ __forceinline__ void adv_vec_store (__nv_bfloat16* ptr, const float * vals) {
1181+ // N=8 for bf16: all 8 elements fit in one uint4 (128 bits)
1182+ __nv_bfloat16 tmp[8 ];
1183+ #pragma unroll
1184+ for (int i = 0 ; i < 8 ; i++) tmp[i] = __float2bfloat16 (vals[i]);
1185+ *reinterpret_cast <uint4 *>(ptr) = *reinterpret_cast <const uint4 *>(tmp);
1186+ }
1187+
11681188__device__ __forceinline__ void puff_advantage_row_vec (
11691189 const precision_t * values, const precision_t * rewards, const precision_t * dones,
1170- const precision_t * importance, float * advantages, float gamma, float lambda,
1190+ const precision_t * importance, precision_t * advantages, float gamma, float lambda,
11711191 float rho_clip, float c_clip, int horizon
11721192) {
11731193 constexpr int N = 16 / sizeof (precision_t );
@@ -1204,17 +1224,12 @@ __device__ __forceinline__ void puff_advantage_row_vec(
12041224 next_reward = r[i];
12051225 }
12061226
1207- *reinterpret_cast <float4 *>(advantages + base) =
1208- make_float4 (adv[0 ], adv[1 ], adv[2 ], adv[3 ]);
1209- if (N > 4 ) {
1210- *reinterpret_cast <float4 *>(advantages + base + 4 ) =
1211- make_float4 (adv[4 ], adv[5 ], adv[6 ], adv[7 ]);
1212- }
1227+ adv_vec_store (advantages + base, adv);
12131228 }
12141229}
12151230
12161231__global__ void puff_advantage_kernel (const precision_t * values, const precision_t * rewards,
1217- const precision_t * dones, const precision_t * importance, float * advantages, float gamma,
1232+ const precision_t * dones, const precision_t * importance, precision_t * advantages, float gamma,
12181233 float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
12191234 int row = blockIdx .x *blockDim .x + threadIdx .x ;
12201235 if (row >= num_steps) {
@@ -1226,7 +1241,7 @@ __global__ void puff_advantage_kernel(const precision_t* values, const precision
12261241}
12271242
12281243__global__ void puff_advantage_kernel_scalar (const precision_t * values, const precision_t * rewards,
1229- const precision_t * dones, const precision_t * importance, float * advantages, float gamma,
1244+ const precision_t * dones, const precision_t * importance, precision_t * advantages, float gamma,
12301245 float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
12311246 int row = blockIdx .x *blockDim .x + threadIdx .x ;
12321247 if (row >= num_steps) {
@@ -1238,7 +1253,7 @@ __global__ void puff_advantage_kernel_scalar(const precision_t* values, const pr
12381253}
12391254
12401255void puff_advantage_cuda (PrecisionTensor& values, PrecisionTensor& rewards,
1241- PrecisionTensor& dones, PrecisionTensor& importance, FloatTensor & advantages,
1256+ PrecisionTensor& dones, PrecisionTensor& importance, PrecisionTensor & advantages,
12421257 float gamma, float lambda, float rho_clip, float c_clip, cudaStream_t stream) {
12431258 int num_steps = values.shape [0 ], horizon = values.shape [1 ];
12441259 int blocks = grid_size (num_steps);
@@ -1260,30 +1275,30 @@ __global__ void index_copy_kernel(char* __restrict__ dst, const int64_t* __restr
12601275
12611276__device__ __forceinline__ void copy_values_adv_returns (
12621277 const precision_t * __restrict__ src_values, precision_t * __restrict__ dst_values,
1263- const float * __restrict__ src_advantages, float * __restrict__ dst_advantages,
1278+ const precision_t * __restrict__ src_advantages, precision_t * __restrict__ dst_advantages,
12641279 precision_t * __restrict__ dst_returns,
12651280 int src_row, int dst_row, int horizon
12661281) {
12671282 int srh = (int64_t )src_row * horizon;
12681283 int drh = (int64_t )dst_row * horizon;
12691284 const precision_t * s_values = src_values + srh;
1270- const float * s_adv = src_advantages + srh;
1285+ const precision_t * s_adv = src_advantages + srh;
12711286 precision_t * d_values = dst_values + drh;
1272- float * d_adv = dst_advantages + drh;
1287+ precision_t * d_adv = dst_advantages + drh;
12731288 precision_t * d_returns = dst_returns + drh;
12741289 for (int i = threadIdx .x ; i < horizon; i += blockDim .x ) {
12751290 precision_t val = s_values[i];
1276- float adv = s_adv[i];
1291+ precision_t adv = s_adv[i];
12771292 d_values[i] = val;
12781293 d_adv[i] = adv;
1279- d_returns[i] = from_float (to_float (val) + adv);
1294+ d_returns[i] = from_float (to_float (val) + to_float ( adv) );
12801295 }
12811296}
12821297
12831298__global__ void select_copy_kernel (
12841299 RolloutBuf rollouts, TrainGraph graph,
12851300 const int64_t * __restrict__ idx,
1286- const float * __restrict__ advantages, const float * __restrict__ mb_prio
1301+ const precision_t * __restrict__ advantages, const float * __restrict__ mb_prio
12871302) {
12881303 int mb = blockIdx .x ;
12891304 int ch = blockIdx .y ;
@@ -1364,7 +1379,7 @@ void train_impl(PuffeRL& pufferl) {
13641379 rollouts.ratio .data , from_float (1 .0f ), numel (rollouts.ratio .shape ));
13651380
13661381 // Zero pre-allocated advantages buffer
1367- FloatTensor & advantages_puf = pufferl.advantages_puf ;
1382+ PrecisionTensor & advantages_puf = pufferl.advantages_puf ;
13681383
13691384 // Inline any of these only used once
13701385 int minibatch_size = hypers.minibatch_size ;
0 commit comments