Skip to content

Commit fccba07

Browse files
committed
bfloat atns
1 parent 550ff94 commit fccba07

7 files changed

Lines changed: 23 additions & 62 deletions

File tree

pufferlib/ocean/breakout/binding.c

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
#define NUM_ATNS 1
44
#define ACT_SIZES {3}
55
#define OBS_TENSOR_T FloatTensor
6-
#define ACT_TYPE DOUBLE
76

87
#define Env Breakout
98
#include "vecenv.h"

pufferlib/ocean/breakout/breakout.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ typedef struct Breakout {
4040
Client* client;
4141
Log log;
4242
float* observations;
43-
double* actions;
43+
float* actions;
4444
float* rewards;
4545
float* terminals;
4646
int num_agents;
@@ -121,7 +121,7 @@ void init(Breakout* env) {
121121
void allocate(Breakout* env) {
122122
init(env);
123123
env->observations = (float*)calloc(11 + env->num_bricks, sizeof(float));
124-
env->actions = (double*)calloc(1, sizeof(double));
124+
env->actions = (float*)calloc(1, sizeof(float));
125125
env->rewards = (float*)calloc(1, sizeof(float));
126126
env->terminals = (float*)calloc(1, sizeof(float));
127127
}

pufferlib/src/bindings.cu

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -453,10 +453,6 @@ PYBIND11_MODULE(_C, m) {
453453
.def("__repr__", [](const PrecisionTensor& t) { return std::string(puf_repr(&t)); })
454454
.def("ndim", [](const PrecisionTensor& t) { return ndim(t.shape); })
455455
.def("numel", [](const PrecisionTensor& t) { return numel(t.shape); });
456-
py::class_<DoubleTensor>(m, "DoubleTensor")
457-
.def("__repr__", [](const DoubleTensor& t) { return std::string(puf_repr(&t)); })
458-
.def("ndim", [](const DoubleTensor& t) { return ndim(t.shape); })
459-
.def("numel", [](const DoubleTensor& t) { return numel(t.shape); });
460456
py::class_<FloatTensor>(m, "FloatTensor")
461457
.def("__repr__", [](const FloatTensor& t) { return std::string(puf_repr(&t)); })
462458
.def("ndim", [](const FloatTensor& t) { return ndim(t.shape); })

pufferlib/src/kernels.cu

Lines changed: 0 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -138,18 +138,6 @@ __global__ void transpose_102(precision_t* __restrict__ dst,
138138
dst[b * A * C + a * C + c] = src[idx];
139139
}
140140

141-
// This exists for actions (currently fp64)
142-
__global__ void transpose_102(double* __restrict__ dst,
143-
const double* __restrict__ src, int A, int B, int C) {
144-
int idx = blockIdx.x * blockDim.x + threadIdx.x;
145-
int total = A * B * C;
146-
if (idx >= total) {
147-
return;
148-
}
149-
int a = idx / (B * C), rem = idx % (B * C), b = rem / C, c = rem % C;
150-
dst[b * A * C + a * C + c] = src[idx];
151-
}
152-
153141
__global__ void fill_precision_kernel(precision_t* __restrict__ dst, precision_t val, int n) {
154142
int idx = blockIdx.x * blockDim.x + threadIdx.x;
155143
if (idx < n) {
@@ -247,10 +235,6 @@ inline const char* puf_repr(const PrecisionTensor* t) {
247235
return _puf_repr_impl("PrecisionTensor", USE_BF16 ? "bf16" : "f32",
248236
t->shape, ndim(t->shape), numel(t->shape), !t->data);
249237
}
250-
inline const char* puf_repr(const DoubleTensor* t) {
251-
return _puf_repr_impl("DoubleTensor", "f64",
252-
t->shape, ndim(t->shape), numel(t->shape), !t->data);
253-
}
254238
inline const char* puf_repr(const FloatTensor* t) {
255239
return _puf_repr_impl("FloatTensor", "f32",
256240
t->shape, ndim(t->shape), numel(t->shape), !t->data);
@@ -431,9 +415,6 @@ void alloc_register(Allocator* a, PrecisionTensor* t) {
431415
void alloc_register(Allocator* a, FloatTensor* t) {
432416
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(float));
433417
}
434-
void alloc_register(Allocator* a, DoubleTensor* t) {
435-
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(double));
436-
}
437418
void alloc_register(Allocator* a, LongTensor* t) {
438419
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(long));
439420
}

pufferlib/src/pufferlib.cu

Lines changed: 13 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ enum LossIdx {
1717

1818
struct RolloutBuf {
1919
PrecisionTensor observations; // (horizon, segments, input_size)
20-
DoubleTensor actions; // (horizon, segments, num_atns)
20+
PrecisionTensor actions; // (horizon, segments, num_atns)
2121
PrecisionTensor values; // (horizon, segments)
2222
PrecisionTensor logprobs; // (horizon, segments)
2323
PrecisionTensor rewards; // (horizon, segments)
@@ -29,7 +29,7 @@ struct RolloutBuf {
2929
struct TrainGraph {
3030
PrecisionTensor mb_obs; // (S, H, input_size)
3131
PrecisionTensor mb_state; // (L, S, 1, hidden)
32-
DoubleTensor mb_actions; // (S, H, num_atns)
32+
PrecisionTensor mb_actions; // (S, H, num_atns)
3333
PrecisionTensor mb_logprobs; // (S, H)
3434
FloatTensor mb_advantages; // (S, H) f32
3535
PrecisionTensor mb_prio; // (S, 1)
@@ -44,7 +44,7 @@ struct TrainGraph {
4444
struct PPOGraphArgs {
4545
precision_t* out_ratio;
4646
precision_t* out_newvalue;
47-
const double* actions;
47+
const precision_t* actions;
4848
const precision_t* old_logprobs;
4949
const float* advantages;
5050
const precision_t* prio;
@@ -76,7 +76,7 @@ struct PPOKernelArgs {
7676
// Pre-allocated buffers for PPO loss
7777
struct PPOBuffersPuf {
7878
FloatTensor loss_output, grad_loss;
79-
DoubleTensor saved_for_bwd;
79+
FloatTensor saved_for_bwd;
8080
FloatTensor grad_logits, grad_values, grad_logstd, adv_scratch;
8181
};
8282

@@ -185,19 +185,10 @@ inline PrecisionTensor puf_slice(PrecisionTensor& p, int t, int start, int count
185185
return {.data = p.data + (t*S + start), .shape = {count}};
186186
}
187187
}
188-
inline DoubleTensor puf_slice(DoubleTensor& p, int t, int start, int count) {
189-
if (ndim(p.shape) == 3) {
190-
long S = p.shape[1], F = p.shape[2];
191-
return {.data = p.data + (t*S + start)*F, .shape = {count, F}};
192-
} else {
193-
long S = p.shape[1];
194-
return {.data = p.data + (t*S + start), .shape = {count}};
195-
}
196-
}
197188

198189
struct EnvBuf {
199190
OBS_TENSOR_T obs; // (total_agents, obs_size) — type defined per-env in binding.c
200-
DoubleTensor actions; // (total_agents, num_atns) f64
191+
FloatTensor actions; // (total_agents, num_atns) f32
201192
FloatTensor rewards; // (total_agents,) f32
202193
FloatTensor terminals;// (total_agents,) f32
203194
};
@@ -210,7 +201,7 @@ StaticVec* create_environments(int num_buffers, int total_agents,
210201
.shape = {total_agents, get_obs_size()},
211202
};
212203
env.actions = {
213-
.data = (double*)vec->gpu_actions,
204+
.data = (float*)vec->gpu_actions,
214205
.shape = {total_agents, get_num_atns()},
215206
};
216207
env.rewards = {
@@ -392,7 +383,7 @@ __global__ void sample_logits_kernel(
392383
PrecisionTensor dec_out, // (B, fused_cols) fused logits+value from decoder
393384
PrecisionTensor logstd_puf, // (1, od) log std for continuous, or empty
394385
IntTensor act_sizes_puf, // (num_atns,) action head sizes
395-
double* __restrict__ actions, // (B, num_atns) output
386+
precision_t* __restrict__ actions, // (B, num_atns) output
396387
precision_t* __restrict__ logprobs, // (B,) output
397388
precision_t* __restrict__ value_out, // (B,) output
398389
uint64_t seed,
@@ -443,7 +434,7 @@ __global__ void sample_logits_kernel(
443434
float normalized = (action - mean) / std;
444435
float log_prob = -0.5f * normalized * normalized - 0.5f * LOG_2PI - log_std;
445436

446-
actions[idx * num_atns + h] = double(action);
437+
actions[idx * num_atns + h] = from_float(action);
447438
total_log_prob += log_prob;
448439
}
449440
} else {
@@ -514,7 +505,7 @@ __global__ void sample_logits_kernel(
514505
float log_prob = sampled_logit - logsumexp;
515506

516507
// Write action for this head
517-
actions[idx * num_atns + h] = double(sampled_action);
508+
actions[idx * num_atns + h] = from_float(sampled_action);
518509
total_log_prob += log_prob;
519510

520511
// Advance to next action head
@@ -584,7 +575,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
584575
PrecisionTensor dec_puf = policy_forward(&pufferl->policy, pufferl->weights, pufferl->buffer_activations[buf], obs_dst, state_puf, stream);
585576

586577
// Sample actions, logprobs, values into rollout buffer
587-
DoubleTensor act_slice = puf_slice(rollouts.actions, t, start, block_size);
578+
PrecisionTensor act_slice = puf_slice(rollouts.actions, t, start, block_size);
588579
PrecisionTensor lp_slice = puf_slice(rollouts.logprobs, t, start, block_size);
589580
PrecisionTensor val_slice = puf_slice(rollouts.values, t, start, block_size);
590581

@@ -604,9 +595,8 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
604595

605596
// Copy actions to env
606597
long act_cols = env.actions.shape[1];
607-
cudaMemcpyAsync(
608-
env.actions.data + start * act_cols,
609-
act_slice.data, numel(act_slice.shape) * sizeof(double), cudaMemcpyDeviceToDevice, stream);
598+
cast_kernel<<<grid_size(numel(act_slice.shape)), BLOCK_SIZE, 0, stream>>>(
599+
env.actions.data + start * act_cols, act_slice.data, numel(act_slice.shape));
610600

611601
if (capturing) {
612602
cudagraph_capture_end(&pufferl->fused_rollout_cudagraphs[graph], cap_stream_raw);
@@ -1307,7 +1297,7 @@ __global__ void select_copy_kernel(
13071297

13081298
// Compute row byte counts from tensor shapes
13091299
int obs_row_bytes = (numel(rollouts.observations.shape) / rollouts.observations.shape[0]) * sizeof(precision_t);
1310-
int act_row_bytes = (numel(rollouts.actions.shape) / rollouts.actions.shape[0]) * sizeof(double);
1300+
int act_row_bytes = (numel(rollouts.actions.shape) / rollouts.actions.shape[0]) * sizeof(precision_t);
13111301
int lp_row_bytes = (numel(rollouts.logprobs.shape) / rollouts.logprobs.shape[0]) * sizeof(precision_t);
13121302
int horizon = rollouts.values.shape[1];
13131303

pufferlib/src/tensor.h

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,6 @@ typedef struct {
1010
int64_t shape[PUF_MAX_DIMS];
1111
} FloatTensor;
1212

13-
typedef struct {
14-
double* data;
15-
int64_t shape[PUF_MAX_DIMS];
16-
} DoubleTensor;
17-
1813
typedef struct {
1914
unsigned char* data;
2015
int64_t shape[PUF_MAX_DIMS];

pufferlib/src/vecenv.h

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -79,11 +79,11 @@ typedef struct StaticVec {
7979
int* buffer_env_starts;
8080
int* buffer_env_counts;
8181
void* observations;
82-
double* actions;
82+
float* actions;
8383
float* rewards;
8484
float* terminals;
8585
void* gpu_observations;
86-
double* gpu_actions;
86+
float* gpu_actions;
8787
float* gpu_rewards;
8888
float* gpu_terminals;
8989
cudaStream_t* streams;
@@ -252,7 +252,7 @@ static void* static_omp_threadmanager(void* arg) {
252252
cudaMemcpyAsync(
253253
&vec->actions[agent_start * NUM_ATNS],
254254
&vec->gpu_actions[agent_start * NUM_ATNS],
255-
agents_per_buffer * NUM_ATNS * sizeof(double),
255+
agents_per_buffer * NUM_ATNS * sizeof(float),
256256
cudaMemcpyDeviceToHost, stream);
257257
cudaStreamSynchronize(stream);
258258
clock_gettime(CLOCK_MONOTONIC, &t1);
@@ -384,17 +384,17 @@ StaticVec* create_static_vec(int total_agents, int num_buffers, Dict* vec_kwargs
384384

385385
size_t obs_elem_size = obs_element_size();
386386
cudaHostAlloc((void**)&vec->observations, total_agents * OBS_SIZE * obs_elem_size, cudaHostAllocPortable);
387-
cudaHostAlloc((void**)&vec->actions, total_agents * NUM_ATNS * sizeof(double), cudaHostAllocPortable);
387+
cudaHostAlloc((void**)&vec->actions, total_agents * NUM_ATNS * sizeof(float), cudaHostAllocPortable);
388388
cudaHostAlloc((void**)&vec->rewards, total_agents * sizeof(float), cudaHostAllocPortable);
389389
cudaHostAlloc((void**)&vec->terminals, total_agents * sizeof(float), cudaHostAllocPortable);
390390

391391
cudaMalloc((void**)&vec->gpu_observations, total_agents * OBS_SIZE * obs_elem_size);
392-
cudaMalloc((void**)&vec->gpu_actions, total_agents * NUM_ATNS * sizeof(double));
392+
cudaMalloc((void**)&vec->gpu_actions, total_agents * NUM_ATNS * sizeof(float));
393393
cudaMalloc((void**)&vec->gpu_rewards, total_agents * sizeof(float));
394394
cudaMalloc((void**)&vec->gpu_terminals, total_agents * sizeof(float));
395395

396396
cudaMemset(vec->gpu_observations, 0, total_agents * OBS_SIZE * obs_elem_size);
397-
cudaMemset(vec->gpu_actions, 0, total_agents * NUM_ATNS * sizeof(double));
397+
cudaMemset(vec->gpu_actions, 0, total_agents * NUM_ATNS * sizeof(float));
398398
cudaMemset(vec->gpu_rewards, 0, total_agents * sizeof(float));
399399
cudaMemset(vec->gpu_terminals, 0, total_agents * sizeof(float));
400400

@@ -483,7 +483,7 @@ void static_vec_close(StaticVec* vec) {
483483

484484
cudaDeviceSynchronize();
485485
size_t obs_bytes = vec->total_agents * OBS_SIZE * obs_element_size();
486-
size_t act_bytes = vec->total_agents * NUM_ATNS * sizeof(double);
486+
size_t act_bytes = vec->total_agents * NUM_ATNS * sizeof(float);
487487
size_t rew_bytes = vec->total_agents * sizeof(float);
488488
size_t term_bytes = vec->total_agents * sizeof(float);
489489
cudaFree(vec->gpu_observations);
@@ -578,7 +578,7 @@ size_t get_obs_elem_size(void) { return obs_element_size(); }
578578
void static_vec_step(StaticVec* vec) {
579579
// D2H: copy GPU actions to CPU pinned memory so envs can read them
580580
cudaMemcpy(vec->actions, vec->gpu_actions,
581-
(size_t)vec->total_agents * NUM_ATNS * sizeof(double),
581+
(size_t)vec->total_agents * NUM_ATNS * sizeof(float),
582582
cudaMemcpyDeviceToHost);
583583

584584
memset(vec->rewards, 0, vec->total_agents * sizeof(float));

0 commit comments

Comments
 (0)