@@ -17,7 +17,7 @@ enum LossIdx {
1717
1818struct RolloutBuf {
1919 PrecisionTensor observations; // (horizon, segments, input_size)
20- DoubleTensor actions; // (horizon, segments, num_atns)
20+ PrecisionTensor actions; // (horizon, segments, num_atns)
2121 PrecisionTensor values; // (horizon, segments)
2222 PrecisionTensor logprobs; // (horizon, segments)
2323 PrecisionTensor rewards; // (horizon, segments)
@@ -29,7 +29,7 @@ struct RolloutBuf {
2929struct TrainGraph {
3030 PrecisionTensor mb_obs; // (S, H, input_size)
3131 PrecisionTensor mb_state; // (L, S, 1, hidden)
32- DoubleTensor mb_actions; // (S, H, num_atns)
32+ PrecisionTensor mb_actions; // (S, H, num_atns)
3333 PrecisionTensor mb_logprobs; // (S, H)
3434 FloatTensor mb_advantages; // (S, H) f32
3535 PrecisionTensor mb_prio; // (S, 1)
@@ -44,7 +44,7 @@ struct TrainGraph {
4444struct PPOGraphArgs {
4545 precision_t * out_ratio;
4646 precision_t * out_newvalue;
47- const double * actions;
47+ const precision_t * actions;
4848 const precision_t * old_logprobs;
4949 const float * advantages;
5050 const precision_t * prio;
@@ -76,7 +76,7 @@ struct PPOKernelArgs {
7676// Pre-allocated buffers for PPO loss
7777struct PPOBuffersPuf {
7878 FloatTensor loss_output, grad_loss;
79- DoubleTensor saved_for_bwd;
79+ FloatTensor saved_for_bwd;
8080 FloatTensor grad_logits, grad_values, grad_logstd, adv_scratch;
8181};
8282
@@ -185,19 +185,10 @@ inline PrecisionTensor puf_slice(PrecisionTensor& p, int t, int start, int count
185185 return {.data = p.data + (t*S + start), .shape = {count}};
186186 }
187187}
188- inline DoubleTensor puf_slice (DoubleTensor& p, int t, int start, int count) {
189- if (ndim (p.shape ) == 3 ) {
190- long S = p.shape [1 ], F = p.shape [2 ];
191- return {.data = p.data + (t*S + start)*F, .shape = {count, F}};
192- } else {
193- long S = p.shape [1 ];
194- return {.data = p.data + (t*S + start), .shape = {count}};
195- }
196- }
197188
198189struct EnvBuf {
199190 OBS_TENSOR_T obs; // (total_agents, obs_size) — type defined per-env in binding.c
200- DoubleTensor actions; // (total_agents, num_atns) f64
191+ FloatTensor actions; // (total_agents, num_atns) f64
201192 FloatTensor rewards; // (total_agents,) f32
202193 FloatTensor terminals;// (total_agents,) f32
203194};
@@ -210,7 +201,7 @@ StaticVec* create_environments(int num_buffers, int total_agents,
210201 .shape = {total_agents, get_obs_size ()},
211202 };
212203 env.actions = {
213- .data = (double *)vec->gpu_actions ,
204+ .data = (float *)vec->gpu_actions ,
214205 .shape = {total_agents, get_num_atns ()},
215206 };
216207 env.rewards = {
@@ -392,7 +383,7 @@ __global__ void sample_logits_kernel(
392383 PrecisionTensor dec_out, // (B, fused_cols) fused logits+value from decoder
393384 PrecisionTensor logstd_puf, // (1, od) log std for continuous, or empty
394385 IntTensor act_sizes_puf, // (num_atns,) action head sizes
395- double * __restrict__ actions, // (B, num_atns) output
386+ precision_t * __restrict__ actions, // (B, num_atns) output
396387 precision_t * __restrict__ logprobs, // (B,) output
397388 precision_t * __restrict__ value_out, // (B,) output
398389 uint64_t seed,
@@ -443,7 +434,7 @@ __global__ void sample_logits_kernel(
443434 float normalized = (action - mean) / std;
444435 float log_prob = -0 .5f * normalized * normalized - 0 .5f * LOG_2PI - log_std;
445436
446- actions[idx * num_atns + h] = double (action);
437+ actions[idx * num_atns + h] = from_float (action);
447438 total_log_prob += log_prob;
448439 }
449440 } else {
@@ -514,7 +505,7 @@ __global__ void sample_logits_kernel(
514505 float log_prob = sampled_logit - logsumexp;
515506
516507 // Write action for this head
517- actions[idx * num_atns + h] = double (sampled_action);
508+ actions[idx * num_atns + h] = from_float (sampled_action);
518509 total_log_prob += log_prob;
519510
520511 // Advance to next action head
@@ -584,7 +575,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
584575 PrecisionTensor dec_puf = policy_forward (&pufferl->policy , pufferl->weights , pufferl->buffer_activations [buf], obs_dst, state_puf, stream);
585576
586577 // Sample actions, logprobs, values into rollout buffer
587- DoubleTensor act_slice = puf_slice (rollouts.actions , t, start, block_size);
578+ PrecisionTensor act_slice = puf_slice (rollouts.actions , t, start, block_size);
588579 PrecisionTensor lp_slice = puf_slice (rollouts.logprobs , t, start, block_size);
589580 PrecisionTensor val_slice = puf_slice (rollouts.values , t, start, block_size);
590581
@@ -604,9 +595,8 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
604595
605596 // Copy actions to env
606597 long act_cols = env.actions .shape [1 ];
607- cudaMemcpyAsync (
608- env.actions .data + start * act_cols,
609- act_slice.data , numel (act_slice.shape ) * sizeof (double ), cudaMemcpyDeviceToDevice, stream);
598+ cast_kernel<<<grid_size(numel(act_slice.shape)), BLOCK_SIZE, 0 , stream>>> (
599+ env.actions .data + start * act_cols, act_slice.data , numel (act_slice.shape ));
610600
611601 if (capturing) {
612602 cudagraph_capture_end (&pufferl->fused_rollout_cudagraphs [graph], cap_stream_raw);
@@ -1307,7 +1297,7 @@ __global__ void select_copy_kernel(
13071297
13081298 // Compute row byte counts from tensor shapes
13091299 int obs_row_bytes = (numel (rollouts.observations .shape ) / rollouts.observations .shape [0 ]) * sizeof (precision_t );
1310- int act_row_bytes = (numel (rollouts.actions .shape ) / rollouts.actions .shape [0 ]) * sizeof (double );
1300+ int act_row_bytes = (numel (rollouts.actions .shape ) / rollouts.actions .shape [0 ]) * sizeof (precision_t );
13111301 int lp_row_bytes = (numel (rollouts.logprobs .shape ) / rollouts.logprobs .shape [0 ]) * sizeof (precision_t );
13121302 int horizon = rollouts.values .shape [1 ];
13131303
0 commit comments