Skip to content

Commit 4126c9b

Browse files
committed
Initial env bind updates
1 parent 3a53e10 commit 4126c9b

15 files changed

Lines changed: 217 additions & 776 deletions

File tree

ocean/cartpole/cartpole.c

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@
1515

1616
const char* WEIGHTS_PATH = "resources/cartpole/cartpole_weights.bin";
1717

18-
float movement(int discrete_action, int userControlMode) {
18+
float movement(float action, int userControlMode) {
1919
if (userControlMode) {
2020
return (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) ? 1.0f : -1.0f;
2121
} else {
22-
return (discrete_action == 1) ? 1.0f : -1.0f;
22+
return (action > 0.5f) ? 1.0f : -1.0f;
2323
}
2424
}
2525

2626
void demo() {
2727
Weights* weights = load_weights(WEIGHTS_PATH, NUM_WEIGHTS);
28-
LinearLSTM* net;
2928

3029
int logit_sizes[1] = {ACTIONS_SIZE};
31-
net = make_linearlstm(weights, 1, OBSERVATIONS_SIZE, logit_sizes, 1);
30+
PufferNet* net = make_puffernet(weights, 1, OBSERVATIONS_SIZE, 64, 2, logit_sizes, 1);
31+
3232
Cartpole env = {0};
3333
env.continuous = CONTINUOUS;
3434
allocate(&env);
@@ -41,9 +41,8 @@ void demo() {
4141
int userControlMode = IsKeyDown(KEY_LEFT_SHIFT);
4242

4343
if (!userControlMode) {
44-
int action_value;
45-
forward_linearlstm(net, env.observations, &action_value);
46-
env.actions[0] = movement(action_value, 0);
44+
forward_puffernet(net, env.observations, env.actions);
45+
env.actions[0] = movement(env.actions[0], 0);
4746
} else {
4847
env.actions[0] = movement(env.actions[0], userControlMode);
4948
}
@@ -55,12 +54,12 @@ void demo() {
5554
c_render(&env);
5655
EndDrawing();
5756

58-
if (env.terminals[0]) {
57+
if (env.terminals[0] > 0.5f) {
5958
c_reset(&env);
6059
}
6160
}
6261

63-
free_linearlstm(net);
62+
free_puffernet(net);
6463
free(weights);
6564
free_allocated(&env);
6665
}

ocean/connect4/connect4.c

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,17 @@
44

55
const unsigned char NOOP = 8;
66

7-
void interactive() {
7+
void demo() {
88
Weights* weights = load_weights("resources/connect4/connect4_weights.bin", 138632);
99
int logit_sizes[] = {7};
10-
LinearLSTM* net = make_linearlstm(weights, 1, 42, logit_sizes, 1);
10+
PufferNet* net = make_puffernet(weights, 1, 42, 64, 2, logit_sizes, 1);
1111

1212
CConnect4 env = {
1313
};
1414
allocate_cconnect4(&env);
1515
c_reset(&env);
1616

1717
env.client = make_client();
18-
float observations[42] = {0};
19-
int actions[1] = {0};
2018

2119
int tick = 0;
2220
while (!WindowShouldClose()) {
@@ -31,11 +29,7 @@ void interactive() {
3129
if(IsKeyPressed(KEY_SIX)) env.actions[0] = 5;
3230
if(IsKeyPressed(KEY_SEVEN)) env.actions[0] = 6;
3331
} else if (tick % 30 == 0) {
34-
for (int i = 0; i < 42; i++) {
35-
observations[i] = env.observations[i];
36-
}
37-
forward_linearlstm(net, (float*)&observations, (int*)&actions);
38-
env.actions[0] = actions[0];
32+
forward_puffernet(net, env.observations, env.actions);
3933
}
4034

4135
tick = (tick + 1) % 60;
@@ -45,33 +39,13 @@ void interactive() {
4539

4640
c_render(&env);
4741
}
48-
free_linearlstm(net);
42+
free_puffernet(net);
4943
free(weights);
5044
close_client(env.client);
5145
free_allocated_cconnect4(&env);
5246
}
5347

54-
void performance_test() {
55-
long test_time = 10;
56-
CConnect4 env = {
57-
};
58-
allocate_cconnect4(&env);
59-
c_reset(&env);
60-
61-
long start = time(NULL);
62-
int i = 0;
63-
while (time(NULL) - start < test_time) {
64-
env.actions[0] = rand() % 7;
65-
c_step(&env);
66-
i++;
67-
}
68-
long end = time(NULL);
69-
printf("SPS: %ld\n", i / (end - start));
70-
free_allocated_cconnect4(&env);
71-
}
72-
7348
int main() {
74-
// performance_test();
75-
interactive();
49+
demo();
7650
return 0;
7751
}

ocean/drone/drone.c

Lines changed: 18 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
// Standalone C demo for drone environment
2-
// Compile using: ./scripts/build_ocean.sh drone [local|fast]
3-
// Run with: ./drone
4-
51
#include "drone.h"
62
#include "puffernet.h"
73
#include "render.h"
@@ -11,126 +7,18 @@
117
#include <emscripten.h>
128
#endif
139

14-
double randn(double mean, double std) {
15-
static int has_spare = 0;
16-
static double spare;
17-
18-
if (has_spare) {
19-
has_spare = 0;
20-
return mean + std * spare;
21-
}
22-
23-
has_spare = 1;
24-
double u, v, s;
25-
do {
26-
u = 2.0 * rand() / RAND_MAX - 1.0;
27-
v = 2.0 * rand() / RAND_MAX - 1.0;
28-
s = u * u + v * v;
29-
} while (s >= 1.0 || s == 0.0);
30-
31-
s = sqrt(-2.0 * log(s) / s);
32-
spare = v * s;
33-
return mean + std * (u * s);
34-
}
35-
36-
#ifndef LINEAR_DIM
37-
#define LINEAR_DIM 64
38-
#endif
39-
40-
#ifndef LSTM_DIM
41-
#define LSTM_DIM 16
42-
#endif
43-
44-
typedef struct LinearContLSTM LinearContLSTM;
45-
struct LinearContLSTM {
46-
int num_agents;
47-
float* obs;
48-
int num_actions;
49-
float* log_std;
50-
Linear* encoder1;
51-
GELU* gelu1;
52-
Linear* encoder2;
53-
GELU* gelu2;
54-
LSTM* lstm;
55-
Linear* actor;
56-
Linear* value_fn;
57-
};
58-
59-
LinearContLSTM* make_linearcontlstm(Weights* weights, int num_agents, int input_dim,
60-
int logit_sizes[], int num_actions) {
61-
LinearContLSTM* net = calloc(1, sizeof(LinearContLSTM));
62-
net->num_agents = num_agents;
63-
net->obs = calloc(num_agents * input_dim, sizeof(float));
64-
net->num_actions = logit_sizes[0];
65-
66-
// Must match export order exactly:
67-
net->log_std = get_weights(weights, net->num_actions); // 1. decoder_logstd
68-
net->encoder1 = make_linear(weights, num_agents, input_dim, LINEAR_DIM); // 2-3. encoder.0
69-
net->gelu1 = make_gelu(num_agents, LINEAR_DIM);
70-
net->encoder2 = make_linear(weights, num_agents, LINEAR_DIM, LSTM_DIM); // 4-5. encoder.2
71-
net->gelu2 = make_gelu(num_agents, LSTM_DIM);
72-
net->actor = make_linear(weights, num_agents, LSTM_DIM, net->num_actions);// 6-7. decoder_mean
73-
net->value_fn = make_linear(weights, num_agents, LSTM_DIM, 1); // 8-9. value (FIX: was LSTM_DIM+4)
74-
net->lstm = make_lstm(weights, num_agents, LSTM_DIM, LSTM_DIM); // 10-13. lstm
75-
76-
return net;
77-
}
78-
79-
void free_linearcontlstm(LinearContLSTM* net) {
80-
free(net->obs);
81-
free(net->encoder1);
82-
free(net->gelu1);
83-
free(net->encoder2);
84-
free(net->gelu2);
85-
free(net->lstm);
86-
free(net->actor);
87-
free(net->value_fn);
88-
free(net);
89-
}
90-
91-
void forward_linearcontlstm(LinearContLSTM* net, float* observations) {
92-
linear(net->encoder1, observations);
93-
gelu(net->gelu1, net->encoder1->output);
94-
linear(net->encoder2, net->gelu1->output);
95-
gelu(net->gelu2, net->encoder2->output);
96-
lstm(net->lstm, net->gelu2->output);
97-
linear(net->actor, net->lstm->state_h);
98-
}
99-
100-
void sample_linearcontlstm(LinearContLSTM* net, float* actions, int deterministic) {
101-
for (int b = 0; b < net->num_agents; b++) {
102-
for (int i = 0; i < net->num_actions; i++) {
103-
int idx = b * net->num_actions + i;
104-
float mean = net->actor->output[idx];
105-
if (deterministic) {
106-
actions[idx] = mean;
107-
} else {
108-
float std = expf(net->log_std[i]);
109-
actions[idx] = (float)randn(mean, std);
110-
}
111-
}
112-
}
113-
}
114-
115-
void generate_dummy_actions(DroneEnv* env) {
116-
// Generate random floats in [-1, 1] range
117-
env->actions[0] = ((float)rand() / (float)RAND_MAX) * 2.0f - 1.0f;
118-
env->actions[1] = ((float)rand() / (float)RAND_MAX) * 2.0f - 1.0f;
119-
env->actions[2] = ((float)rand() / (float)RAND_MAX) * 2.0f - 1.0f;
120-
env->actions[3] = ((float)rand() / (float)RAND_MAX) * 2.0f - 1.0f;
121-
}
122-
12310
#ifdef __EMSCRIPTEN__
12411
typedef struct {
12512
DroneEnv* env;
126-
LinearContLSTM* net;
13+
PufferNet* net;
12714
Weights* weights;
12815
} WebRenderArgs;
12916

13017
void emscriptenStep(void* e) {
13118
WebRenderArgs* args = (WebRenderArgs*)e;
13219
DroneEnv* env = args->env;
133-
LinearContLSTM* net = args->net;
20+
PufferNet* net = args->net;
21+
size_t obs_size = 23;
13422

13523
for (int i = 0; i < env->num_agents; i++) {
13624
int base = i * obs_size;
@@ -140,22 +28,19 @@ void emscriptenStep(void* e) {
14028
env->observations[base + 22] = 0.0f;
14129
}
14230

143-
forward_linearcontlstm(net, env->observations);
144-
sample_linearcontlstm(net, env->actions, 0);
31+
forward_puffernet(net, env->observations, env->actions);
14532
c_step(env);
14633
c_render(env);
147-
return;
14834
}
14935

15036
WebRenderArgs* web_args = NULL;
15137
#endif
15238

153-
int main() {
154-
srand(time(NULL)); // Seed random number generator
39+
void demo() {
40+
srand(time(NULL));
15541

15642
DroneEnv* env = calloc(1, sizeof(DroneEnv));
15743
size_t obs_size = 23;
158-
size_t act_size = 4;
15944

16045
env->num_agents = 64;
16146
env->max_rings = 10;
@@ -166,26 +51,13 @@ int main() {
16651
env->hover_vel = 0.01;
16752
init(env);
16853

169-
env->observations = (float*)calloc(env->num_agents * obs_size, sizeof(float));
170-
env->actions = (float*)calloc(env->num_agents * act_size, sizeof(float));
171-
env->rewards = (float*)calloc(env->num_agents, sizeof(float));
172-
env->terminals = (float*)calloc(env->num_agents, sizeof(float));
54+
allocate(env);
17355

17456
Weights* weights = load_weights("resources/drone/puffer_drone_weights.bin", 4841);
175-
int logit_sizes[1] = {4};
176-
LinearContLSTM* net = make_linearcontlstm(weights, env->num_agents, obs_size, logit_sizes, 1);
177-
178-
if (!env->observations || !env->actions || !env->rewards) {
179-
fprintf(stderr, "ERROR: Failed to allocate memory for demo buffers.\n");
180-
free(env->observations);
181-
free(env->actions);
182-
free(env->rewards);
183-
free(env->terminals);
184-
free(env);
185-
return 0;
186-
}
57+
int logit_sizes[4] = {1, 1, 1, 1};
58+
// make_puffernet(weights, num_agents, obs_size, hidden_size, num_layers, logit_sizes, num_actions)
59+
PufferNet* net = make_puffernet(weights, env->num_agents, obs_size, 64, 2, logit_sizes, 4);
18760

188-
init(env);
18961
c_reset(env);
19062

19163
#ifdef __EMSCRIPTEN__
@@ -198,6 +70,7 @@ int main() {
19870
emscripten_set_main_loop_arg(emscriptenStep, args, 0, true);
19971
#else
20072
c_render(env);
73+
SetTargetFPS(60);
20174

20275
while (!WindowShouldClose()) {
20376
for (int i = 0; i < env->num_agents; i++) {
@@ -207,20 +80,20 @@ int main() {
20780
env->observations[base + 21] = 0.0f;
20881
env->observations[base + 22] = 0.0f;
20982
}
210-
forward_linearcontlstm(net, env->observations);
211-
sample_linearcontlstm(net, env->actions, 0);
83+
forward_puffernet(net, env->observations, env->actions);
21284
c_step(env);
21385
c_render(env);
21486
}
21587

21688
c_close(env);
217-
free_linearcontlstm(net);
218-
free(env->observations);
219-
free(env->actions);
220-
free(env->rewards);
221-
free(env->terminals);
89+
free_puffernet(net);
90+
free(weights);
91+
free_allocated(env);
22292
free(env);
22393
#endif
94+
}
22495

96+
int main() {
97+
demo();
22598
return 0;
22699
}

0 commit comments

Comments
 (0)