88#include " muon.cu"
99#include " vecenv.h"
1010
11- // Loss component indices
1211enum LossIdx {
1312 LOSS_PG = 0 , LOSS_VF = 1 , LOSS_ENT = 2 , LOSS_TOTAL = 3 ,
1413 LOSS_OLD_APPROX_KL = 4 , LOSS_APPROX_KL = 5 , LOSS_CLIPFRAC = 6 ,
@@ -26,6 +25,27 @@ struct RolloutBuf {
2625 PrecisionTensor importance; // (horizon, segments)
2726};
2827
// Sets the shape of every tensor in the rollout buffer and registers each one
// with the allocator. H = rollout horizon (time steps), S = number of
// segments; observations carry input_size features and actions num_atns
// components per step, while the remaining buffers hold one scalar per
// (step, segment) pair.
void register_rollout_buffers (RolloutBuf& bufs, Allocator* alloc, int H, int S, int input_size, int num_atns) {
    // Aggregate/compound-literal assignment: any RolloutBuf members not
    // listed here are value-initialized (zeroed) before the shapes are set.
    bufs = (RolloutBuf){
        .observations = {.shape = {H, S, input_size}},
        .actions = {.shape = {H, S, num_atns}},
        .values = {.shape = {H, S}},
        .logprobs = {.shape = {H, S}},
        .rewards = {.shape = {H, S}},
        .terminals = {.shape = {H, S}},
        .ratio = {.shape = {H, S}},
        .importance = {.shape = {H, S}},
    };
    // NOTE(review): registration order presumably determines each tensor's
    // placement inside the allocator's arena — confirm before reordering.
    alloc_register (alloc, &bufs.observations );
    alloc_register (alloc, &bufs.actions );
    alloc_register (alloc, &bufs.values );
    alloc_register (alloc, &bufs.logprobs );
    alloc_register (alloc, &bufs.rewards );
    alloc_register (alloc, &bufs.terminals );
    alloc_register (alloc, &bufs.ratio );
    alloc_register (alloc, &bufs.importance );
}
48+
2949struct TrainGraph {
3050 PrecisionTensor mb_obs; // (S, H, input_size)
3151 PrecisionTensor mb_state; // (L, S, 1, hidden)
@@ -39,8 +59,6 @@ struct TrainGraph {
3959 PrecisionTensor mb_newvalue; // (S, H, 1)
4060};
4161
42- // Fused PPO forward + backward kernel: computes loss partials AND gradients in one pass.
43- // Avoids redundant recomputation of logits, logsumexp, ratio, advantage normalization.
4462struct PPOGraphArgs {
4563 precision_t * out_ratio;
4664 precision_t * out_newvalue;
@@ -53,18 +71,15 @@ struct PPOGraphArgs {
5371};
5472
5573struct PPOKernelArgs {
56- // Gradient outputs
57- float * grad_logits; // For continuous: grad_mean
58- float * grad_logstd; // For continuous: grad_logstd (nullptr for discrete)
74+ float * grad_logits;
75+ float * grad_logstd; // For continuous actions
5976 float * grad_values_pred;
60- // Inputs (from dec_out)
6177 const precision_t * logits;
62- const precision_t * logstd; // nullptr for discrete
78+ const precision_t * logstd; // Continuous only
6379 const precision_t * values_pred;
6480 const float * adv_mean;
6581 const float * adv_var;
6682 const int * act_sizes;
67- // Scalars
6883 int num_atns;
6984 float clip_coef, vf_clip_coef, vf_coef, ent_coef;
7085 int T_seq, A_total, N;
@@ -73,32 +88,12 @@ struct PPOKernelArgs {
7388 bool is_continuous;
7489};
7590
76- // Pre-allocated buffers for PPO loss
// Scratch tensors allocated once up front and reused across PPO loss
// invocations (see register_ppo_buffers for how they are sized).
struct PPOBuffersPuf {
    FloatTensor loss_output;    // loss components produced by the fused pass
    FloatTensor grad_loss;      // upstream gradient fed into the backward pass
    DoubleTensor saved_for_bwd; // forward-pass values saved for backward; double precision — TODO confirm why
    FloatTensor grad_logits;    // gradient outputs of the fused kernel
    FloatTensor grad_values;
    FloatTensor grad_logstd;    // continuous-action path only — verify against kernel
    FloatTensor adv_scratch;    // workspace, presumably for advantage normalization
};
8296
83- // Pre-allocated buffers for prio_replay
84- struct PrioBuffers {
85- FloatTensor prio_probs, cdf, mb_prio;
86- LongTensor idx;
87- };
88-
89- void register_prio_buffers (PrioBuffers& bufs, Allocator* alloc, int S, int minibatch_segments) {
90- bufs = (PrioBuffers){
91- .prio_probs = {.shape = {S}},
92- .cdf = {.shape = {S}},
93- .mb_prio = {.shape = {minibatch_segments, 1 }},
94- .idx = {.shape = {minibatch_segments}},
95- };
96- alloc_register (alloc, &bufs.prio_probs );
97- alloc_register (alloc, &bufs.cdf );
98- alloc_register (alloc, &bufs.idx );
99- alloc_register (alloc, &bufs.mb_prio );
100- }
101-
10297void register_ppo_buffers (PPOBuffersPuf& bufs, Allocator* alloc, int N, int T, int A_total, bool is_continuous) {
10398 long total = (long )N * T;
10499 bufs = (PPOBuffersPuf){
@@ -119,25 +114,24 @@ void register_ppo_buffers(PPOBuffersPuf& bufs, Allocator* alloc, int N, int T, i
119114 alloc_register (alloc, &bufs.adv_scratch );
120115}
121116
122- void register_rollout_buffers (RolloutBuf& bufs, Allocator* alloc, int H, int S, int input_size, int num_atns) {
123- bufs = (RolloutBuf){
124- .observations = {.shape = {H, S, input_size}},
125- .actions = {.shape = {H, S, num_atns}},
126- .values = {.shape = {H, S}},
127- .logprobs = {.shape = {H, S}},
128- .rewards = {.shape = {H, S}},
129- .terminals = {.shape = {H, S}},
130- .ratio = {.shape = {H, S}},
131- .importance = {.shape = {H, S}},
// Prioritized replay over single-epoch data. These kernels are the least
// polished in the file, since a better method is expected to replace them
// in 5.0.
struct PrioBuffers {
    FloatTensor prio_probs; // per-segment sampling probabilities, shape (S)
    FloatTensor cdf;        // cumulative distribution over prio_probs, shape (S)
    FloatTensor mb_prio;    // priorities of the sampled minibatch, shape (minibatch_segments, 1)
    LongTensor idx;         // sampled segment indices, shape (minibatch_segments)
};
123+
124+ void register_prio_buffers (PrioBuffers& bufs, Allocator* alloc, int S, int minibatch_segments) {
125+ bufs = (PrioBuffers){
126+ .prio_probs = {.shape = {S}},
127+ .cdf = {.shape = {S}},
128+ .mb_prio = {.shape = {minibatch_segments, 1 }},
129+ .idx = {.shape = {minibatch_segments}},
132130 };
133- alloc_register (alloc, &bufs.observations );
134- alloc_register (alloc, &bufs.actions );
135- alloc_register (alloc, &bufs.values );
136- alloc_register (alloc, &bufs.logprobs );
137- alloc_register (alloc, &bufs.rewards );
138- alloc_register (alloc, &bufs.terminals );
139- alloc_register (alloc, &bufs.ratio );
140- alloc_register (alloc, &bufs.importance );
131+ alloc_register (alloc, &bufs.prio_probs );
132+ alloc_register (alloc, &bufs.cdf );
133+ alloc_register (alloc, &bufs.idx );
134+ alloc_register (alloc, &bufs.mb_prio );
141135}
142136
143137void register_train_buffers (TrainGraph& bufs, Allocator* alloc, int S, int H, int input_size,
0 commit comments