@@ -16,7 +16,7 @@ enum LossIdx {
1616
1717struct RolloutBuf {
1818 PrecisionTensor observations; // (horizon, segments, input_size)
19- DoubleTensor actions; // (horizon, segments, num_atns)
19+ PrecisionTensor actions; // (horizon, segments, num_atns)
2020 PrecisionTensor values; // (horizon, segments)
2121 PrecisionTensor logprobs; // (horizon, segments)
2222 PrecisionTensor rewards; // (horizon, segments)
@@ -49,7 +49,7 @@ void register_rollout_buffers(RolloutBuf& bufs, Allocator* alloc, int H, int S,
4949struct TrainGraph {
5050 PrecisionTensor mb_obs; // (S, H, input_size)
5151 PrecisionTensor mb_state; // (L, S, 1, hidden)
52- DoubleTensor mb_actions; // (S, H, num_atns)
52+ PrecisionTensor mb_actions; // (S, H, num_atns)
5353 PrecisionTensor mb_logprobs; // (S, H)
5454 FloatTensor mb_advantages; // (S, H) f32
5555 PrecisionTensor mb_prio; // (S, 1)
@@ -62,7 +62,7 @@ struct TrainGraph {
6262struct PPOGraphArgs {
6363 precision_t * out_ratio;
6464 precision_t * out_newvalue;
65- const double * actions;
65+ const precision_t * actions;
6666 const precision_t * old_logprobs;
6767 const float * advantages;
6868 const precision_t * prio;
@@ -90,7 +90,7 @@ struct PPOKernelArgs {
9090
9191struct PPOBuffersPuf {
9292 FloatTensor loss_output, grad_loss;
93- DoubleTensor saved_for_bwd;
93+ FloatTensor saved_for_bwd;
9494 FloatTensor grad_logits, grad_values, grad_logstd, adv_scratch;
9595};
9696
@@ -179,19 +179,10 @@ inline PrecisionTensor puf_slice(PrecisionTensor& p, int t, int start, int count
179179 return {.data = p.data + (t*S + start), .shape = {count}};
180180 }
181181}
182- inline DoubleTensor puf_slice (DoubleTensor& p, int t, int start, int count) {
183- if (ndim (p.shape ) == 3 ) {
184- long S = p.shape [1 ], F = p.shape [2 ];
185- return {.data = p.data + (t*S + start)*F, .shape = {count, F}};
186- } else {
187- long S = p.shape [1 ];
188- return {.data = p.data + (t*S + start), .shape = {count}};
189- }
190- }
191182
192183struct EnvBuf {
193184 OBS_TENSOR_T obs; // (total_agents, obs_size) — type defined per-env in binding.c
194- DoubleTensor actions; // (total_agents, num_atns) f64
185+ FloatTensor actions; // (total_agents, num_atns) f64
195186 FloatTensor rewards; // (total_agents,) f32
196187 FloatTensor terminals;// (total_agents,) f32
197188};
@@ -204,7 +195,7 @@ StaticVec* create_environments(int num_buffers, int total_agents,
204195 .shape = {total_agents, get_obs_size ()},
205196 };
206197 env.actions = {
207- .data = (double *)vec->gpu_actions ,
198+ .data = (float *)vec->gpu_actions ,
208199 .shape = {total_agents, get_num_atns ()},
209200 };
210201 env.rewards = {
@@ -386,7 +377,7 @@ __global__ void sample_logits_kernel(
386377 PrecisionTensor dec_out, // (B, fused_cols) fused logits+value from decoder
387378 PrecisionTensor logstd_puf, // (1, od) log std for continuous, or empty
388379 IntTensor act_sizes_puf, // (num_atns,) action head sizes
389- double * __restrict__ actions, // (B, num_atns) output
380+ precision_t * __restrict__ actions, // (B, num_atns) output
390381 precision_t * __restrict__ logprobs, // (B,) output
391382 precision_t * __restrict__ value_out, // (B,) output
392383 uint64_t seed,
@@ -437,7 +428,7 @@ __global__ void sample_logits_kernel(
437428 float normalized = (action - mean) / std;
438429 float log_prob = -0 .5f * normalized * normalized - 0 .5f * LOG_2PI - log_std;
439430
440- actions[idx * num_atns + h] = double (action);
431+ actions[idx * num_atns + h] = from_float (action);
441432 total_log_prob += log_prob;
442433 }
443434 } else {
@@ -508,7 +499,7 @@ __global__ void sample_logits_kernel(
508499 float log_prob = sampled_logit - logsumexp;
509500
510501 // Write action for this head
511- actions[idx * num_atns + h] = double (sampled_action);
502+ actions[idx * num_atns + h] = from_float (sampled_action);
512503 total_log_prob += log_prob;
513504
514505 // Advance to next action head
@@ -578,7 +569,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
578569 PrecisionTensor dec_puf = policy_forward (&pufferl->policy , pufferl->weights , pufferl->buffer_activations [buf], obs_dst, state_puf, stream);
579570
580571 // Sample actions, logprobs, values into rollout buffer
581- DoubleTensor act_slice = puf_slice (rollouts.actions , t, start, block_size);
572+ PrecisionTensor act_slice = puf_slice (rollouts.actions , t, start, block_size);
582573 PrecisionTensor lp_slice = puf_slice (rollouts.logprobs , t, start, block_size);
583574 PrecisionTensor val_slice = puf_slice (rollouts.values , t, start, block_size);
584575
@@ -598,9 +589,8 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
598589
599590 // Copy actions to env
600591 long act_cols = env.actions .shape [1 ];
601- cudaMemcpyAsync (
602- env.actions .data + start * act_cols,
603- act_slice.data , numel (act_slice.shape ) * sizeof (double ), cudaMemcpyDeviceToDevice, stream);
592+ cast_kernel<<<grid_size(numel(act_slice.shape)), BLOCK_SIZE, 0 , stream>>> (
593+ env.actions .data + start * act_cols, act_slice.data , numel (act_slice.shape ));
604594
605595 if (capturing) {
606596 cudagraph_capture_end (&pufferl->fused_rollout_cudagraphs [graph], cap_stream_raw);
@@ -1301,7 +1291,7 @@ __global__ void select_copy_kernel(
13011291
13021292 // Compute row byte counts from tensor shapes
13031293 int obs_row_bytes = (numel (rollouts.observations .shape ) / rollouts.observations .shape [0 ]) * sizeof (precision_t );
1304- int act_row_bytes = (numel (rollouts.actions .shape ) / rollouts.actions .shape [0 ]) * sizeof (double );
1294+ int act_row_bytes = (numel (rollouts.actions .shape ) / rollouts.actions .shape [0 ]) * sizeof (precision_t );
13051295 int lp_row_bytes = (numel (rollouts.logprobs .shape ) / rollouts.logprobs .shape [0 ]) * sizeof (precision_t );
13061296 int horizon = rollouts.values .shape [1 ];
13071297
0 commit comments