1- // Standalone C demo for drone environment
2- // Compile using: ./scripts/build_ocean.sh drone [local|fast]
3- // Run with: ./drone
4-
51#include "drone.h"
62#include "puffernet.h"
73#include "render.h"
117#include <emscripten.h>
128#endif
139
14- double randn (double mean , double std ) {
15- static int has_spare = 0 ;
16- static double spare ;
17-
18- if (has_spare ) {
19- has_spare = 0 ;
20- return mean + std * spare ;
21- }
22-
23- has_spare = 1 ;
24- double u , v , s ;
25- do {
26- u = 2.0 * rand () / RAND_MAX - 1.0 ;
27- v = 2.0 * rand () / RAND_MAX - 1.0 ;
28- s = u * u + v * v ;
29- } while (s >= 1.0 || s == 0.0 );
30-
31- s = sqrt (-2.0 * log (s ) / s );
32- spare = v * s ;
33- return mean + std * (u * s );
34- }
35-
36- #ifndef LINEAR_DIM
37- #define LINEAR_DIM 64
38- #endif
39-
40- #ifndef LSTM_DIM
41- #define LSTM_DIM 16
42- #endif
43-
44- typedef struct LinearContLSTM LinearContLSTM ;
45- struct LinearContLSTM {
46- int num_agents ;
47- float * obs ;
48- int num_actions ;
49- float * log_std ;
50- Linear * encoder1 ;
51- GELU * gelu1 ;
52- Linear * encoder2 ;
53- GELU * gelu2 ;
54- LSTM * lstm ;
55- Linear * actor ;
56- Linear * value_fn ;
57- };
58-
59- LinearContLSTM * make_linearcontlstm (Weights * weights , int num_agents , int input_dim ,
60- int logit_sizes [], int num_actions ) {
61- LinearContLSTM * net = calloc (1 , sizeof (LinearContLSTM ));
62- net -> num_agents = num_agents ;
63- net -> obs = calloc (num_agents * input_dim , sizeof (float ));
64- net -> num_actions = logit_sizes [0 ];
65-
66- // Must match export order exactly:
67- net -> log_std = get_weights (weights , net -> num_actions ); // 1. decoder_logstd
68- net -> encoder1 = make_linear (weights , num_agents , input_dim , LINEAR_DIM ); // 2-3. encoder.0
69- net -> gelu1 = make_gelu (num_agents , LINEAR_DIM );
70- net -> encoder2 = make_linear (weights , num_agents , LINEAR_DIM , LSTM_DIM ); // 4-5. encoder.2
71- net -> gelu2 = make_gelu (num_agents , LSTM_DIM );
72- net -> actor = make_linear (weights , num_agents , LSTM_DIM , net -> num_actions );// 6-7. decoder_mean
73- net -> value_fn = make_linear (weights , num_agents , LSTM_DIM , 1 ); // 8-9. value (FIX: was LSTM_DIM+4)
74- net -> lstm = make_lstm (weights , num_agents , LSTM_DIM , LSTM_DIM ); // 10-13. lstm
75-
76- return net ;
77- }
78-
79- void free_linearcontlstm (LinearContLSTM * net ) {
80- free (net -> obs );
81- free (net -> encoder1 );
82- free (net -> gelu1 );
83- free (net -> encoder2 );
84- free (net -> gelu2 );
85- free (net -> lstm );
86- free (net -> actor );
87- free (net -> value_fn );
88- free (net );
89- }
90-
91- void forward_linearcontlstm (LinearContLSTM * net , float * observations ) {
92- linear (net -> encoder1 , observations );
93- gelu (net -> gelu1 , net -> encoder1 -> output );
94- linear (net -> encoder2 , net -> gelu1 -> output );
95- gelu (net -> gelu2 , net -> encoder2 -> output );
96- lstm (net -> lstm , net -> gelu2 -> output );
97- linear (net -> actor , net -> lstm -> state_h );
98- }
99-
100- void sample_linearcontlstm (LinearContLSTM * net , float * actions , int deterministic ) {
101- for (int b = 0 ; b < net -> num_agents ; b ++ ) {
102- for (int i = 0 ; i < net -> num_actions ; i ++ ) {
103- int idx = b * net -> num_actions + i ;
104- float mean = net -> actor -> output [idx ];
105- if (deterministic ) {
106- actions [idx ] = mean ;
107- } else {
108- float std = expf (net -> log_std [i ]);
109- actions [idx ] = (float )randn (mean , std );
110- }
111- }
112- }
113- }
114-
115- void generate_dummy_actions (DroneEnv * env ) {
116- // Generate random floats in [-1, 1] range
117- env -> actions [0 ] = ((float )rand () / (float )RAND_MAX ) * 2.0f - 1.0f ;
118- env -> actions [1 ] = ((float )rand () / (float )RAND_MAX ) * 2.0f - 1.0f ;
119- env -> actions [2 ] = ((float )rand () / (float )RAND_MAX ) * 2.0f - 1.0f ;
120- env -> actions [3 ] = ((float )rand () / (float )RAND_MAX ) * 2.0f - 1.0f ;
121- }
122-
12310#ifdef __EMSCRIPTEN__
12411typedef struct {
12512 DroneEnv * env ;
126- LinearContLSTM * net ;
13+ PufferNet * net ;
12714 Weights * weights ;
12815} WebRenderArgs ;
12916
13017void emscriptenStep (void * e ) {
13118 WebRenderArgs * args = (WebRenderArgs * )e ;
13219 DroneEnv * env = args -> env ;
133- LinearContLSTM * net = args -> net ;
20+ PufferNet * net = args -> net ;
21+ size_t obs_size = 23 ;
13422
13523 for (int i = 0 ; i < env -> num_agents ; i ++ ) {
13624 int base = i * obs_size ;
@@ -140,22 +28,19 @@ void emscriptenStep(void* e) {
14028 env -> observations [base + 22 ] = 0.0f ;
14129 }
14230
143- forward_linearcontlstm (net , env -> observations );
144- sample_linearcontlstm (net , env -> actions , 0 );
31+ forward_puffernet (net , env -> observations , env -> actions );
14532 c_step (env );
14633 c_render (env );
147- return ;
14834}
14935
15036WebRenderArgs * web_args = NULL ;
15137#endif
15238
153- int main () {
154- srand (time (NULL )); // Seed random number generator
39+ void demo () {
40+ srand (time (NULL ));
15541
15642 DroneEnv * env = calloc (1 , sizeof (DroneEnv ));
15743 size_t obs_size = 23 ;
158- size_t act_size = 4 ;
15944
16045 env -> num_agents = 64 ;
16146 env -> max_rings = 10 ;
@@ -166,26 +51,13 @@ int main() {
16651 env -> hover_vel = 0.01 ;
16752 init (env );
16853
169- env -> observations = (float * )calloc (env -> num_agents * obs_size , sizeof (float ));
170- env -> actions = (float * )calloc (env -> num_agents * act_size , sizeof (float ));
171- env -> rewards = (float * )calloc (env -> num_agents , sizeof (float ));
172- env -> terminals = (float * )calloc (env -> num_agents , sizeof (float ));
54+ allocate (env );
17355
17456 Weights * weights = load_weights ("resources/drone/puffer_drone_weights.bin" , 4841 );
175- int logit_sizes [1 ] = {4 };
176- LinearContLSTM * net = make_linearcontlstm (weights , env -> num_agents , obs_size , logit_sizes , 1 );
177-
178- if (!env -> observations || !env -> actions || !env -> rewards ) {
179- fprintf (stderr , "ERROR: Failed to allocate memory for demo buffers.\n" );
180- free (env -> observations );
181- free (env -> actions );
182- free (env -> rewards );
183- free (env -> terminals );
184- free (env );
185- return 0 ;
186- }
57+ int logit_sizes [4 ] = {1 , 1 , 1 , 1 };
58+ // make_puffernet(weights, num_agents, obs_size, hidden_size, num_layers, logit_sizes, num_actions)
59+ PufferNet * net = make_puffernet (weights , env -> num_agents , obs_size , 64 , 2 , logit_sizes , 4 );
18760
188- init (env );
18961 c_reset (env );
19062
19163#ifdef __EMSCRIPTEN__
@@ -198,6 +70,7 @@ int main() {
19870 emscripten_set_main_loop_arg (emscriptenStep , args , 0 , true);
19971#else
20072 c_render (env );
73+ SetTargetFPS (60 );
20174
20275 while (!WindowShouldClose ()) {
20376 for (int i = 0 ; i < env -> num_agents ; i ++ ) {
@@ -207,20 +80,20 @@ int main() {
20780 env -> observations [base + 21 ] = 0.0f ;
20881 env -> observations [base + 22 ] = 0.0f ;
20982 }
210- forward_linearcontlstm (net , env -> observations );
211- sample_linearcontlstm (net , env -> actions , 0 );
83+ forward_puffernet (net , env -> observations , env -> actions );
21284 c_step (env );
21385 c_render (env );
21486 }
21587
21688 c_close (env );
217- free_linearcontlstm (net );
218- free (env -> observations );
219- free (env -> actions );
220- free (env -> rewards );
221- free (env -> terminals );
89+ free_puffernet (net );
90+ free (weights );
91+ free_allocated (env );
22292 free (env );
22393#endif
94+ }
22495
96+ int main () {
97+ demo ();
22598 return 0 ;
22699}
0 commit comments