Skip to content

Commit 0336926

Browse files
authored
Merge pull request #501 from PufferAI/bfloatatns
bfloat atns
2 parents b936162 + fccba07 commit 0336926

7 files changed

Lines changed: 23 additions & 62 deletions

File tree

pufferlib/ocean/breakout/binding.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#define NUM_ATNS 1
44
#define ACT_SIZES {3}
55
#define OBS_TENSOR_T FloatTensor
6-
#define ACT_TYPE DOUBLE
76

87
#define Env Breakout
98
#include "vecenv.h"

pufferlib/ocean/breakout/breakout.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ typedef struct Breakout {
4040
Client* client;
4141
Log log;
4242
float* observations;
43-
double* actions;
43+
float* actions;
4444
float* rewards;
4545
float* terminals;
4646
int num_agents;
@@ -121,7 +121,7 @@ void init(Breakout* env) {
121121
void allocate(Breakout* env) {
122122
init(env);
123123
env->observations = (float*)calloc(11 + env->num_bricks, sizeof(float));
124-
env->actions = (double*)calloc(1, sizeof(double));
124+
env->actions = (float*)calloc(1, sizeof(float));
125125
env->rewards = (float*)calloc(1, sizeof(float));
126126
env->terminals = (float*)calloc(1, sizeof(float));
127127
}

pufferlib/src/bindings.cu

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,10 +453,6 @@ PYBIND11_MODULE(_C, m) {
453453
.def("__repr__", [](const PrecisionTensor& t) { return std::string(puf_repr(&t)); })
454454
.def("ndim", [](const PrecisionTensor& t) { return ndim(t.shape); })
455455
.def("numel", [](const PrecisionTensor& t) { return numel(t.shape); });
456-
py::class_<DoubleTensor>(m, "DoubleTensor")
457-
.def("__repr__", [](const DoubleTensor& t) { return std::string(puf_repr(&t)); })
458-
.def("ndim", [](const DoubleTensor& t) { return ndim(t.shape); })
459-
.def("numel", [](const DoubleTensor& t) { return numel(t.shape); });
460456
py::class_<FloatTensor>(m, "FloatTensor")
461457
.def("__repr__", [](const FloatTensor& t) { return std::string(puf_repr(&t)); })
462458
.def("ndim", [](const FloatTensor& t) { return ndim(t.shape); })

pufferlib/src/kernels.cu

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -138,18 +138,6 @@ __global__ void transpose_102(precision_t* __restrict__ dst,
138138
dst[b * A * C + a * C + c] = src[idx];
139139
}
140140

141-
// This exists for actions (currently fp64)
142-
__global__ void transpose_102(double* __restrict__ dst,
143-
const double* __restrict__ src, int A, int B, int C) {
144-
int idx = blockIdx.x * blockDim.x + threadIdx.x;
145-
int total = A * B * C;
146-
if (idx >= total) {
147-
return;
148-
}
149-
int a = idx / (B * C), rem = idx % (B * C), b = rem / C, c = rem % C;
150-
dst[b * A * C + a * C + c] = src[idx];
151-
}
152-
153141
__global__ void fill_precision_kernel(precision_t* __restrict__ dst, precision_t val, int n) {
154142
int idx = blockIdx.x * blockDim.x + threadIdx.x;
155143
if (idx < n) {
@@ -247,10 +235,6 @@ inline const char* puf_repr(const PrecisionTensor* t) {
247235
return _puf_repr_impl("PrecisionTensor", USE_BF16 ? "bf16" : "f32",
248236
t->shape, ndim(t->shape), numel(t->shape), !t->data);
249237
}
250-
inline const char* puf_repr(const DoubleTensor* t) {
251-
return _puf_repr_impl("DoubleTensor", "f64",
252-
t->shape, ndim(t->shape), numel(t->shape), !t->data);
253-
}
254238
inline const char* puf_repr(const FloatTensor* t) {
255239
return _puf_repr_impl("FloatTensor", "f32",
256240
t->shape, ndim(t->shape), numel(t->shape), !t->data);
@@ -431,9 +415,6 @@ void alloc_register(Allocator* a, PrecisionTensor* t) {
431415
void alloc_register(Allocator* a, FloatTensor* t) {
432416
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(float));
433417
}
434-
void alloc_register(Allocator* a, DoubleTensor* t) {
435-
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(double));
436-
}
437418
void alloc_register(Allocator* a, LongTensor* t) {
438419
alloc_register_impl(a, (void**)&t->data, t->shape, sizeof(long));
439420
}

pufferlib/src/pufferlib.cu

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ enum LossIdx {
1616

1717
struct RolloutBuf {
1818
PrecisionTensor observations; // (horizon, segments, input_size)
19-
DoubleTensor actions; // (horizon, segments, num_atns)
19+
PrecisionTensor actions; // (horizon, segments, num_atns)
2020
PrecisionTensor values; // (horizon, segments)
2121
PrecisionTensor logprobs; // (horizon, segments)
2222
PrecisionTensor rewards; // (horizon, segments)
@@ -49,7 +49,7 @@ void register_rollout_buffers(RolloutBuf& bufs, Allocator* alloc, int H, int S,
4949
struct TrainGraph {
5050
PrecisionTensor mb_obs; // (S, H, input_size)
5151
PrecisionTensor mb_state; // (L, S, 1, hidden)
52-
DoubleTensor mb_actions; // (S, H, num_atns)
52+
PrecisionTensor mb_actions; // (S, H, num_atns)
5353
PrecisionTensor mb_logprobs; // (S, H)
5454
FloatTensor mb_advantages; // (S, H) f32
5555
PrecisionTensor mb_prio; // (S, 1)
@@ -62,7 +62,7 @@ struct TrainGraph {
6262
struct PPOGraphArgs {
6363
precision_t* out_ratio;
6464
precision_t* out_newvalue;
65-
const double* actions;
65+
const precision_t* actions;
6666
const precision_t* old_logprobs;
6767
const float* advantages;
6868
const precision_t* prio;
@@ -90,7 +90,7 @@ struct PPOKernelArgs {
9090

9191
struct PPOBuffersPuf {
9292
FloatTensor loss_output, grad_loss;
93-
DoubleTensor saved_for_bwd;
93+
FloatTensor saved_for_bwd;
9494
FloatTensor grad_logits, grad_values, grad_logstd, adv_scratch;
9595
};
9696

@@ -179,19 +179,10 @@ inline PrecisionTensor puf_slice(PrecisionTensor& p, int t, int start, int count
179179
return {.data = p.data + (t*S + start), .shape = {count}};
180180
}
181181
}
182-
inline DoubleTensor puf_slice(DoubleTensor& p, int t, int start, int count) {
183-
if (ndim(p.shape) == 3) {
184-
long S = p.shape[1], F = p.shape[2];
185-
return {.data = p.data + (t*S + start)*F, .shape = {count, F}};
186-
} else {
187-
long S = p.shape[1];
188-
return {.data = p.data + (t*S + start), .shape = {count}};
189-
}
190-
}
191182

192183
struct EnvBuf {
193184
OBS_TENSOR_T obs; // (total_agents, obs_size) — type defined per-env in binding.c
194-
DoubleTensor actions; // (total_agents, num_atns) f64
185+
FloatTensor actions; // (total_agents, num_atns) f64
195186
FloatTensor rewards; // (total_agents,) f32
196187
FloatTensor terminals;// (total_agents,) f32
197188
};
@@ -204,7 +195,7 @@ StaticVec* create_environments(int num_buffers, int total_agents,
204195
.shape = {total_agents, get_obs_size()},
205196
};
206197
env.actions = {
207-
.data = (double*)vec->gpu_actions,
198+
.data = (float*)vec->gpu_actions,
208199
.shape = {total_agents, get_num_atns()},
209200
};
210201
env.rewards = {
@@ -386,7 +377,7 @@ __global__ void sample_logits_kernel(
386377
PrecisionTensor dec_out, // (B, fused_cols) fused logits+value from decoder
387378
PrecisionTensor logstd_puf, // (1, od) log std for continuous, or empty
388379
IntTensor act_sizes_puf, // (num_atns,) action head sizes
389-
double* __restrict__ actions, // (B, num_atns) output
380+
precision_t* __restrict__ actions, // (B, num_atns) output
390381
precision_t* __restrict__ logprobs, // (B,) output
391382
precision_t* __restrict__ value_out, // (B,) output
392383
uint64_t seed,
@@ -437,7 +428,7 @@ __global__ void sample_logits_kernel(
437428
float normalized = (action - mean) / std;
438429
float log_prob = -0.5f * normalized * normalized - 0.5f * LOG_2PI - log_std;
439430

440-
actions[idx * num_atns + h] = double(action);
431+
actions[idx * num_atns + h] = from_float(action);
441432
total_log_prob += log_prob;
442433
}
443434
} else {
@@ -508,7 +499,7 @@ __global__ void sample_logits_kernel(
508499
float log_prob = sampled_logit - logsumexp;
509500

510501
// Write action for this head
511-
actions[idx * num_atns + h] = double(sampled_action);
502+
actions[idx * num_atns + h] = from_float(sampled_action);
512503
total_log_prob += log_prob;
513504

514505
// Advance to next action head
@@ -578,7 +569,7 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
578569
PrecisionTensor dec_puf = policy_forward(&pufferl->policy, pufferl->weights, pufferl->buffer_activations[buf], obs_dst, state_puf, stream);
579570

580571
// Sample actions, logprobs, values into rollout buffer
581-
DoubleTensor act_slice = puf_slice(rollouts.actions, t, start, block_size);
572+
PrecisionTensor act_slice = puf_slice(rollouts.actions, t, start, block_size);
582573
PrecisionTensor lp_slice = puf_slice(rollouts.logprobs, t, start, block_size);
583574
PrecisionTensor val_slice = puf_slice(rollouts.values, t, start, block_size);
584575

@@ -598,9 +589,8 @@ extern "C" void net_callback_wrapper(void* ctx, int buf, int t) {
598589

599590
// Copy actions to env
600591
long act_cols = env.actions.shape[1];
601-
cudaMemcpyAsync(
602-
env.actions.data + start * act_cols,
603-
act_slice.data, numel(act_slice.shape) * sizeof(double), cudaMemcpyDeviceToDevice, stream);
592+
cast_kernel<<<grid_size(numel(act_slice.shape)), BLOCK_SIZE, 0, stream>>>(
593+
env.actions.data + start * act_cols, act_slice.data, numel(act_slice.shape));
604594

605595
if (capturing) {
606596
cudagraph_capture_end(&pufferl->fused_rollout_cudagraphs[graph], cap_stream_raw);
@@ -1301,7 +1291,7 @@ __global__ void select_copy_kernel(
13011291

13021292
// Compute row byte counts from tensor shapes
13031293
int obs_row_bytes = (numel(rollouts.observations.shape) / rollouts.observations.shape[0]) * sizeof(precision_t);
1304-
int act_row_bytes = (numel(rollouts.actions.shape) / rollouts.actions.shape[0]) * sizeof(double);
1294+
int act_row_bytes = (numel(rollouts.actions.shape) / rollouts.actions.shape[0]) * sizeof(precision_t);
13051295
int lp_row_bytes = (numel(rollouts.logprobs.shape) / rollouts.logprobs.shape[0]) * sizeof(precision_t);
13061296
int horizon = rollouts.values.shape[1];
13071297

pufferlib/src/tensor.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@ typedef struct {
1010
int64_t shape[PUF_MAX_DIMS];
1111
} FloatTensor;
1212

13-
typedef struct {
14-
double* data;
15-
int64_t shape[PUF_MAX_DIMS];
16-
} DoubleTensor;
17-
1813
typedef struct {
1914
unsigned char* data;
2015
int64_t shape[PUF_MAX_DIMS];

pufferlib/src/vecenv.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,11 @@ typedef struct StaticVec {
7979
int* buffer_env_starts;
8080
int* buffer_env_counts;
8181
void* observations;
82-
double* actions;
82+
float* actions;
8383
float* rewards;
8484
float* terminals;
8585
void* gpu_observations;
86-
double* gpu_actions;
86+
float* gpu_actions;
8787
float* gpu_rewards;
8888
float* gpu_terminals;
8989
cudaStream_t* streams;
@@ -252,7 +252,7 @@ static void* static_omp_threadmanager(void* arg) {
252252
cudaMemcpyAsync(
253253
&vec->actions[agent_start * NUM_ATNS],
254254
&vec->gpu_actions[agent_start * NUM_ATNS],
255-
agents_per_buffer * NUM_ATNS * sizeof(double),
255+
agents_per_buffer * NUM_ATNS * sizeof(float),
256256
cudaMemcpyDeviceToHost, stream);
257257
cudaStreamSynchronize(stream);
258258
clock_gettime(CLOCK_MONOTONIC, &t1);
@@ -384,17 +384,17 @@ StaticVec* create_static_vec(int total_agents, int num_buffers, Dict* vec_kwargs
384384

385385
size_t obs_elem_size = obs_element_size();
386386
cudaHostAlloc((void**)&vec->observations, total_agents * OBS_SIZE * obs_elem_size, cudaHostAllocPortable);
387-
cudaHostAlloc((void**)&vec->actions, total_agents * NUM_ATNS * sizeof(double), cudaHostAllocPortable);
387+
cudaHostAlloc((void**)&vec->actions, total_agents * NUM_ATNS * sizeof(float), cudaHostAllocPortable);
388388
cudaHostAlloc((void**)&vec->rewards, total_agents * sizeof(float), cudaHostAllocPortable);
389389
cudaHostAlloc((void**)&vec->terminals, total_agents * sizeof(float), cudaHostAllocPortable);
390390

391391
cudaMalloc((void**)&vec->gpu_observations, total_agents * OBS_SIZE * obs_elem_size);
392-
cudaMalloc((void**)&vec->gpu_actions, total_agents * NUM_ATNS * sizeof(double));
392+
cudaMalloc((void**)&vec->gpu_actions, total_agents * NUM_ATNS * sizeof(float));
393393
cudaMalloc((void**)&vec->gpu_rewards, total_agents * sizeof(float));
394394
cudaMalloc((void**)&vec->gpu_terminals, total_agents * sizeof(float));
395395

396396
cudaMemset(vec->gpu_observations, 0, total_agents * OBS_SIZE * obs_elem_size);
397-
cudaMemset(vec->gpu_actions, 0, total_agents * NUM_ATNS * sizeof(double));
397+
cudaMemset(vec->gpu_actions, 0, total_agents * NUM_ATNS * sizeof(float));
398398
cudaMemset(vec->gpu_rewards, 0, total_agents * sizeof(float));
399399
cudaMemset(vec->gpu_terminals, 0, total_agents * sizeof(float));
400400

@@ -483,7 +483,7 @@ void static_vec_close(StaticVec* vec) {
483483

484484
cudaDeviceSynchronize();
485485
size_t obs_bytes = vec->total_agents * OBS_SIZE * obs_element_size();
486-
size_t act_bytes = vec->total_agents * NUM_ATNS * sizeof(double);
486+
size_t act_bytes = vec->total_agents * NUM_ATNS * sizeof(float);
487487
size_t rew_bytes = vec->total_agents * sizeof(float);
488488
size_t term_bytes = vec->total_agents * sizeof(float);
489489
cudaFree(vec->gpu_observations);
@@ -578,7 +578,7 @@ size_t get_obs_elem_size(void) { return obs_element_size(); }
578578
void static_vec_step(StaticVec* vec) {
579579
// D2H: copy GPU actions to CPU pinned memory so envs can read them
580580
cudaMemcpy(vec->actions, vec->gpu_actions,
581-
(size_t)vec->total_agents * NUM_ATNS * sizeof(double),
581+
(size_t)vec->total_agents * NUM_ATNS * sizeof(float),
582582
cudaMemcpyDeviceToHost);
583583

584584
memset(vec->rewards, 0, vec->total_agents * sizeof(float));

0 commit comments

Comments
 (0)