Skip to content

Commit b936162

Browse files
committed
minor cleanup
1 parent 8adcc9e commit b936162

1 file changed

Lines changed: 41 additions & 47 deletions

File tree

pufferlib/src/pufferlib.cu

Lines changed: 41 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "muon.cu"
99
#include "vecenv.h"
1010

11-
// Loss component indices
1211
enum LossIdx {
1312
LOSS_PG = 0, LOSS_VF = 1, LOSS_ENT = 2, LOSS_TOTAL = 3,
1413
LOSS_OLD_APPROX_KL = 4, LOSS_APPROX_KL = 5, LOSS_CLIPFRAC = 6,
@@ -26,6 +25,27 @@ struct RolloutBuf {
2625
PrecisionTensor importance; // (horizon, segments)
2726
};
2827

28+
void register_rollout_buffers(RolloutBuf& bufs, Allocator* alloc, int H, int S, int input_size, int num_atns) {
29+
bufs = (RolloutBuf){
30+
.observations = {.shape = {H, S, input_size}},
31+
.actions = {.shape = {H, S, num_atns}},
32+
.values = {.shape = {H, S}},
33+
.logprobs = {.shape = {H, S}},
34+
.rewards = {.shape = {H, S}},
35+
.terminals = {.shape = {H, S}},
36+
.ratio = {.shape = {H, S}},
37+
.importance = {.shape = {H, S}},
38+
};
39+
alloc_register(alloc, &bufs.observations);
40+
alloc_register(alloc, &bufs.actions);
41+
alloc_register(alloc, &bufs.values);
42+
alloc_register(alloc, &bufs.logprobs);
43+
alloc_register(alloc, &bufs.rewards);
44+
alloc_register(alloc, &bufs.terminals);
45+
alloc_register(alloc, &bufs.ratio);
46+
alloc_register(alloc, &bufs.importance);
47+
}
48+
2949
struct TrainGraph {
3050
PrecisionTensor mb_obs; // (S, H, input_size)
3151
PrecisionTensor mb_state; // (L, S, 1, hidden)
@@ -39,8 +59,6 @@ struct TrainGraph {
3959
PrecisionTensor mb_newvalue; // (S, H, 1)
4060
};
4161

42-
// Fused PPO forward + backward kernel: computes loss partials AND gradients in one pass.
43-
// Avoids redundant recomputation of logits, logsumexp, ratio, advantage normalization.
4462
struct PPOGraphArgs {
4563
precision_t* out_ratio;
4664
precision_t* out_newvalue;
@@ -53,18 +71,15 @@ struct PPOGraphArgs {
5371
};
5472

5573
struct PPOKernelArgs {
56-
// Gradient outputs
57-
float* grad_logits; // For continuous: grad_mean
58-
float* grad_logstd; // For continuous: grad_logstd (nullptr for discrete)
74+
float* grad_logits;
75+
float* grad_logstd; // For continuous actions
5976
float* grad_values_pred;
60-
// Inputs (from dec_out)
6177
const precision_t* logits;
62-
const precision_t* logstd; // nullptr for discrete
78+
const precision_t* logstd; // Continuous only
6379
const precision_t* values_pred;
6480
const float* adv_mean;
6581
const float* adv_var;
6682
const int* act_sizes;
67-
// Scalars
6883
int num_atns;
6984
float clip_coef, vf_clip_coef, vf_coef, ent_coef;
7085
int T_seq, A_total, N;
@@ -73,32 +88,12 @@ struct PPOKernelArgs {
7388
bool is_continuous;
7489
};
7590

76-
// Pre-allocated buffers for PPO loss
7791
struct PPOBuffersPuf {
7892
FloatTensor loss_output, grad_loss;
7993
DoubleTensor saved_for_bwd;
8094
FloatTensor grad_logits, grad_values, grad_logstd, adv_scratch;
8195
};
8296

83-
// Pre-allocated buffers for prio_replay
84-
struct PrioBuffers {
85-
FloatTensor prio_probs, cdf, mb_prio;
86-
LongTensor idx;
87-
};
88-
89-
void register_prio_buffers(PrioBuffers& bufs, Allocator* alloc, int S, int minibatch_segments) {
90-
bufs = (PrioBuffers){
91-
.prio_probs = {.shape = {S}},
92-
.cdf = {.shape = {S}},
93-
.mb_prio = {.shape = {minibatch_segments, 1}},
94-
.idx = {.shape = {minibatch_segments}},
95-
};
96-
alloc_register(alloc, &bufs.prio_probs);
97-
alloc_register(alloc, &bufs.cdf);
98-
alloc_register(alloc, &bufs.idx);
99-
alloc_register(alloc, &bufs.mb_prio);
100-
}
101-
10297
void register_ppo_buffers(PPOBuffersPuf& bufs, Allocator* alloc, int N, int T, int A_total, bool is_continuous) {
10398
long total = (long)N * T;
10499
bufs = (PPOBuffersPuf){
@@ -119,25 +114,24 @@ void register_ppo_buffers(PPOBuffersPuf& bufs, Allocator* alloc, int N, int T, i
119114
alloc_register(alloc, &bufs.adv_scratch);
120115
}
121116

122-
void register_rollout_buffers(RolloutBuf& bufs, Allocator* alloc, int H, int S, int input_size, int num_atns) {
123-
bufs = (RolloutBuf){
124-
.observations = {.shape = {H, S, input_size}},
125-
.actions = {.shape = {H, S, num_atns}},
126-
.values = {.shape = {H, S}},
127-
.logprobs = {.shape = {H, S}},
128-
.rewards = {.shape = {H, S}},
129-
.terminals = {.shape = {H, S}},
130-
.ratio = {.shape = {H, S}},
131-
.importance = {.shape = {H, S}},
117+
// Prioritized replay over single-epoch data. These kernels are
118+
// the least cleaned up because we will likely have a better method in 5.0.
119+
struct PrioBuffers {
120+
FloatTensor prio_probs, cdf, mb_prio;
121+
LongTensor idx;
122+
};
123+
124+
void register_prio_buffers(PrioBuffers& bufs, Allocator* alloc, int S, int minibatch_segments) {
125+
bufs = (PrioBuffers){
126+
.prio_probs = {.shape = {S}},
127+
.cdf = {.shape = {S}},
128+
.mb_prio = {.shape = {minibatch_segments, 1}},
129+
.idx = {.shape = {minibatch_segments}},
132130
};
133-
alloc_register(alloc, &bufs.observations);
134-
alloc_register(alloc, &bufs.actions);
135-
alloc_register(alloc, &bufs.values);
136-
alloc_register(alloc, &bufs.logprobs);
137-
alloc_register(alloc, &bufs.rewards);
138-
alloc_register(alloc, &bufs.terminals);
139-
alloc_register(alloc, &bufs.ratio);
140-
alloc_register(alloc, &bufs.importance);
131+
alloc_register(alloc, &bufs.prio_probs);
132+
alloc_register(alloc, &bufs.cdf);
133+
alloc_register(alloc, &bufs.idx);
134+
alloc_register(alloc, &bufs.mb_prio);
141135
}
142136

143137
void register_train_buffers(TrainGraph& bufs, Allocator* alloc, int S, int H, int input_size,

0 commit comments

Comments
 (0)