-
Notifications
You must be signed in to change notification settings - Fork 442
Expand file tree
/
Copy pathpufferlib.cu
More file actions
145 lines (122 loc) · 5.42 KB
/
pufferlib.cu
File metadata and controls
145 lines (122 loc) · 5.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#include <cstdint>
#include <stdexcept>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>
namespace pufferlib {

// Scalar implementation of the per-row advantage recursion.
//
// Operates on one row of `horizon` contiguous floats per array and walks
// backwards in time. For t = horizon-2 down to 0:
//   nonterm = 1 - dones[t+1]
//   rho_t   = min(importance[t], rho_clip)
//   c_t     = min(importance[t], c_clip)
//   delta   = rho_t * (rewards[t+1] + gamma * values[t+1] * nonterm - values[t])
//   adv[t]  = delta + gamma * lambda * c_t * nonterm * adv[t+1]
// with adv[horizon-1] = 0 (the last step has no successor).
//
// Callable from host and device; the kernel uses it whenever the float4 path
// in puff_advantage_row_cuda is not applicable.
__host__ __device__ void puff_advantage_row_cuda_fallback(float* values, float* rewards, float* dones,
        float* importance, float* advantages, float gamma, float lambda,
        float rho_clip, float c_clip, int horizon) {
    if (horizon <= 0) {
        return;
    }
    // Write the final step explicitly so this path produces the same output as
    // the float4 path, which always stores 0 in the last element of the row.
    advantages[horizon - 1] = 0.0f;
    float lastpufferlam = 0.0f;
    for (int t = horizon - 2; t >= 0; t--) {
        const int t_next = t + 1;
        // 1.0f (not 1.0) keeps the arithmetic in single precision on device.
        const float nextnonterminal = 1.0f - dones[t_next];
        const float rho_t = fminf(importance[t], rho_clip);
        const float c_t = fminf(importance[t], c_clip);
        const float delta = rho_t * (rewards[t_next] + gamma * values[t_next] * nextnonterminal - values[t]);
        lastpufferlam = delta + gamma * lambda * c_t * lastpufferlam * nextnonterminal;
        advantages[t] = lastpufferlam;
    }
}

// True when `p` is 16-byte aligned, which float4 loads/stores require.
static __device__ __forceinline__ bool puff_is_aligned16(const void* p) {
    return (reinterpret_cast<uintptr_t>(p) & 15u) == 0;
}

// Same recursion as puff_advantage_row_cuda_fallback, but moves data four
// elements at a time with float4 loads/stores. Requires horizon % 4 == 0 and
// 16-byte aligned row pointers; otherwise it defers to the scalar fallback.
// A tensor view with a storage offset can yield misaligned pointers even when
// horizon % 4 == 0, so alignment is checked explicitly.
__device__ __forceinline__ void puff_advantage_row_cuda(float* values, float* rewards, float* dones,
        float* importance, float* advantages, float gamma, float lambda,
        float rho_clip, float c_clip, int horizon) {
    if (horizon <= 0) {
        // Nothing to compute; also prevents reading index horizon-1 below.
        return;
    }
    const bool vectorizable = (horizon % 4 == 0)
        && puff_is_aligned16(values)
        && puff_is_aligned16(rewards)
        && puff_is_aligned16(dones)
        && puff_is_aligned16(importance)
        && puff_is_aligned16(advantages);
    if (!vectorizable) {
        puff_advantage_row_cuda_fallback(values, rewards, dones,
            importance, advantages, gamma, lambda, rho_clip, c_clip, horizon);
        return;
    }
    float lastpufferlam = 0.0f;
    const int num_chunks = horizon / 4;
    // Successor-step (t+1) values carried across chunk boundaries.
    float next_value = values[horizon - 1];
    float next_done = dones[horizon - 1];
    float next_reward = rewards[horizon - 1];
    // Process chunks from end to beginning.
    for (int chunk = num_chunks - 1; chunk >= 0; chunk--) {
        const int base = chunk * 4;
        const float4 v4 = *reinterpret_cast<float4*>(values + base);
        const float4 r4 = *reinterpret_cast<float4*>(rewards + base);
        const float4 d4 = *reinterpret_cast<float4*>(dones + base);
        const float4 i4 = *reinterpret_cast<float4*>(importance + base);
        const float v[4] = {v4.x, v4.y, v4.z, v4.w};
        const float r[4] = {r4.x, r4.y, r4.z, r4.w};
        const float d[4] = {d4.x, d4.y, d4.z, d4.w};
        const float imp[4] = {i4.x, i4.y, i4.z, i4.w};
        // adv[3] of the last chunk is the final step and stays 0.
        float adv[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        const int start_idx = (chunk == num_chunks - 1) ? 2 : 3;
        #pragma unroll
        for (int i = start_idx; i >= 0; i--) {
            const float nextnonterminal = 1.0f - next_done;
            const float rho_t = fminf(imp[i], rho_clip);
            const float c_t = fminf(imp[i], c_clip);
            const float delta = rho_t * (next_reward + gamma * next_value * nextnonterminal - v[i]);
            lastpufferlam = delta + gamma * lambda * c_t * lastpufferlam * nextnonterminal;
            adv[i] = lastpufferlam;
            next_value = v[i];
            next_done = d[i];
            next_reward = r[i];
        }
        *reinterpret_cast<float4*>(advantages + base) = make_float4(adv[0], adv[1], adv[2], adv[3]);
    }
}

// Validates that every tensor is 2D with shape [num_steps, horizon], float32,
// and on the same device as `values`. Throws (via TORCH_CHECK) on violation.
// Contiguity is NOT enforced here; compute_puff_advantage_cuda makes
// contiguous working copies itself.
void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
        int num_steps, int horizon) {
    const torch::Device device = values.device();
    for (const torch::Tensor& t : {values, rewards, dones, importance, advantages}) {
        TORCH_CHECK(t.dim() == 2, "Tensor must be 2D");
        TORCH_CHECK(t.device() == device, "All tensors must be on same device");
        TORCH_CHECK(t.size(0) == num_steps, "First dimension must match num_steps");
        TORCH_CHECK(t.size(1) == horizon, "Second dimension must match horizon");
        TORCH_CHECK(t.dtype() == torch::kFloat32, "All tensors must be float32");
    }
}

// One thread per row of a row-major [num_steps, horizon] buffer. Launch with a
// 1D grid covering at least num_steps threads; surplus threads exit early.
__global__ void puff_advantage_kernel(float* values, float* rewards,
        float* dones, float* importance, float* advantages, float gamma,
        float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= num_steps) {
        return;
    }
    // 64-bit offset: row * horizon can exceed INT_MAX for large buffers.
    const int64_t offset = static_cast<int64_t>(row) * horizon;
    puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
        importance + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon);
}

// CUDA implementation of pufferlib::compute_puff_advantage.
//
// All tensors must be 2D float32 CUDA tensors of identical shape
// [num_steps, horizon] on the same device. Results are written into
// `advantages` in place. The kernel is enqueued on PyTorch's current CUDA
// stream for the tensors' device and is asynchronous with respect to the host;
// launch errors are reported as std::runtime_error.
void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
        double gamma, double lambda, double rho_clip, double c_clip) {
    // Check rank before reading size(1) so a bad input gets a clear message.
    TORCH_CHECK(values.dim() == 2, "Tensor must be 2D");
    const int num_steps = values.size(0);
    const int horizon = values.size(1);
    vtrace_check_cuda(values, rewards, dones, importance, advantages, num_steps, horizon);
    TORCH_CHECK(values.is_cuda(), "All tensors must be on GPU");
    if (num_steps == 0 || horizon == 0) {
        // Nothing to compute, and a zero-block launch is an invalid configuration.
        return;
    }

    // Launch on the tensors' device and PyTorch's current stream for it, so the
    // kernel is ordered with surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(values.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    // The kernel indexes rows as row * horizon, which requires dense row-major
    // storage. contiguous() returns the same tensor when no copy is needed.
    torch::Tensor values_c = values.contiguous();
    torch::Tensor rewards_c = rewards.contiguous();
    torch::Tensor dones_c = dones.contiguous();
    torch::Tensor importance_c = importance.contiguous();
    torch::Tensor advantages_c = advantages.contiguous();

    constexpr int threads_per_block = 32;
    const int blocks = (num_steps + threads_per_block - 1) / threads_per_block;
    puff_advantage_kernel<<<blocks, threads_per_block, 0, stream>>>(
        values_c.data_ptr<float>(),
        rewards_c.data_ptr<float>(),
        dones_c.data_ptr<float>(),
        importance_c.data_ptr<float>(),
        advantages_c.data_ptr<float>(),
        static_cast<float>(gamma),
        static_cast<float>(lambda),
        static_cast<float>(rho_clip),
        static_cast<float>(c_clip),
        num_steps,
        horizon
    );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }

    // If `advantages` was a non-contiguous view, the kernel wrote into a
    // temporary; copy the results back into the caller's tensor.
    if (!advantages_c.is_same(advantages)) {
        advantages.copy_(advantages_c);
    }
}

TORCH_LIBRARY_IMPL(pufferlib, CUDA, m) {
    m.impl("compute_puff_advantage", &compute_puff_advantage_cuda);
}

}