44#include "puffernet.h"
55#include "nmmo3.h"
66
7- // Only runs a few agents in the C
8- // version, and reduces for web.
7+ // Only run 1 agent in the C version
98// You can run the full 1024 on GPU
10- // with PyTorch.
11- #if defined(PLATFORM_WEB )
12- #define NUM_AGENTS 4
13- #else
14- #define NUM_AGENTS 16
15- #endif
16-
9+ #define NUM_AGENTS 1
1710
1811typedef struct MMONet MMONet ;
1912struct MMONet {
@@ -27,12 +20,10 @@ struct MMONet {
2720 Conv2D * map_conv2 ;
2821 Embedding * player_embed ;
2922 float * proj_buffer ;
30- Linear * proj ;
23+ Affine * proj ;
3124 ReLU * proj_relu ;
32- LayerNorm * layer_norm ;
33- LSTM * lstm ;
34- Linear * actor ;
35- Linear * value_fn ;
25+ Linear * decoder ;
26+ MinGRU * mingru ;
3627 Multidiscrete * multidiscrete ;
3728};
3829
@@ -49,12 +40,10 @@ MMONet* init_mmonet(Weights* weights, int num_agents) {
4940 net -> map_conv2 = make_conv2d (weights , num_agents , 4 , 3 , 128 , 128 , 3 , 1 );
5041 net -> player_embed = make_embedding (weights , num_agents * 47 , 128 , 32 );
5142 net -> proj_buffer = calloc (num_agents * 1817 , sizeof (float ));
52- net -> proj = make_linear (weights , num_agents , 1817 , hidden );
43+ net -> proj = make_affine (weights , num_agents , 1817 , hidden );
5344 net -> proj_relu = make_relu (num_agents , hidden );
54- net -> layer_norm = make_layernorm (weights , num_agents , hidden );
55- net -> actor = make_linear (weights , num_agents , hidden , 26 );
56- net -> value_fn = make_linear (weights , num_agents , hidden , 1 );
57- net -> lstm = make_lstm (weights , num_agents , hidden , hidden );
45+ net -> decoder = make_linear (weights , num_agents , hidden , 26 + 1 );
46+ net -> mingru = make_mingru (weights , num_agents , hidden , 4 );
5847 int logit_sizes [1 ] = {26 };
5948 net -> multidiscrete = make_multidiscrete (num_agents , logit_sizes , 1 );
6049 return net ;
@@ -72,15 +61,21 @@ void free_mmonet(MMONet* net) {
7261 free (net -> proj_buffer );
7362 free (net -> proj );
7463 free (net -> proj_relu );
75- free (net -> layer_norm );
76- free (net -> actor );
77- free (net -> value_fn );
78- free (net -> lstm );
64+ free (net -> decoder );
65+ free_mingru (net -> mingru );
7966 free (net -> multidiscrete );
8067 free (net );
8168}
8269
83- void forward (MMONet * net , unsigned char * observations , int * actions ) {
70+ void forward (MMONet * net , unsigned char * observations , float * terminals , float * actions ) {
71+ for (int b = 0 ; b < net -> num_agents ; b ++ ) {
72+ if (terminals [b ] > 0.5f ) {
73+ for (int l = 0 ; l < net -> mingru -> num_layers ; l ++ ) {
74+ memset (net -> mingru -> state + l * net -> mingru -> batch_size * net -> mingru -> hidden_size + b * net -> mingru -> hidden_size , 0 , net -> mingru -> hidden_size * sizeof (float ));
75+ }
76+ terminals [b ] = 0.0f ;
77+ }
78+ }
8479 memset (net -> ob_map , 0 , net -> num_agents * 11 * 15 * 59 * sizeof (float ));
8580
8681 // DUMMY INPUT FOR TESTING
@@ -147,53 +142,55 @@ void forward(MMONet* net, unsigned char* observations, int* actions) {
147142 }
148143 }
149144
150- linear (net -> proj , net -> proj_buffer );
145+ affine (net -> proj , net -> proj_buffer );
151146 relu (net -> proj_relu , net -> proj -> output );
152147
153- lstm (net -> lstm , net -> proj_relu -> output );
154- layernorm (net -> layer_norm , net -> lstm -> state_h );
155-
156- linear (net -> actor , net -> layer_norm -> output );
157- linear (net -> value_fn , net -> layer_norm -> output );
148+ mingru (net -> mingru , net -> proj_relu -> output );
149+ linear (net -> decoder , net -> mingru -> output );
158150
159- softmax_multidiscrete (net -> multidiscrete , net -> actor -> output , actions );
151+ softmax_multidiscrete (net -> multidiscrete , net -> decoder -> output , actions );
160152}
161153
162154void demo (int num_players ) {
163- Weights * weights = load_weights ("resources/nmmo3/nmmo3_weights.bin" , 3387547 );
155+ Weights * weights = load_weights ("resources/nmmo3/nmmo3_weights.bin" , 4430976 );
164156 MMONet * net = init_mmonet (weights , num_players );
165157
166158 MMO env = {
167159 .client = NULL ,
168160 .width = 512 ,
169161 .height = 512 ,
170- .num_players = num_players ,
162+ .num_agents = num_players ,
171163 .num_enemies = 2048 ,
172164 .num_resources = 2048 ,
173165 .num_weapons = 1024 ,
174166 .num_gems = 512 ,
175167 .tiers = 5 ,
176168 .levels = 40 ,
177- .teleportitis_prob = 0.0 ,
169+ .teleportitis_prob = 0.001 ,
178170 .enemy_respawn_ticks = 2 ,
179171 .item_respawn_ticks = 100 ,
180172 .x_window = 7 ,
181173 .y_window = 5 ,
174+ .reward_combat_level = 1.0 ,
175+ .reward_prof_level = 1.0 ,
176+ .reward_item_level = 1.0 ,
177+ .reward_market = 0.0 ,
178+ .reward_death = -1.0 ,
182179 };
183180 allocate_mmo (& env );
184181
185182 c_reset (& env );
186183 c_render (& env );
187184
188- int human_action = ATN_NOOP ;
185+ float human_action = ATN_NOOP ;
189186 bool human_mode = false;
190187 int i = 1 ;
191188 while (!WindowShouldClose ()) {
192189 if (IsKeyPressed (KEY_LEFT_CONTROL )) {
193190 human_mode = !human_mode ;
194191 }
195192 if (i % 36 == 0 ) {
196- forward (net , env .observations , env .actions );
193+ forward (net , env .observations , env .terminals , env . actions );
197194 if (human_mode ) {
198195 env .actions [0 ] = human_action ;
199196 }
@@ -221,7 +218,7 @@ void test_mmonet_performance(int num_players, int timeout) {
221218 MMO env = {
222219 .width = 512 ,
223220 .height = 512 ,
224- .num_players = num_players ,
221+ .num_agents = num_players ,
225222 .num_enemies = 128 ,
226223 .num_resources = 32 ,
227224 .num_weapons = 32 ,
@@ -240,7 +237,7 @@ void test_mmonet_performance(int num_players, int timeout) {
240237 int start = time (NULL );
241238 int num_steps = 0 ;
242239 while (time (NULL ) - start < timeout ) {
243- forward (net , env .observations , env .actions );
240+ forward (net , env .observations , env .terminals , env . actions );
244241 c_step (& env );
245242 num_steps ++ ;
246243 }
@@ -410,7 +407,7 @@ void test_cellular_automata(int width, int height, int colors, int max_fill) {
410407void test_generate_terrain (int width , int height , int x_border , int y_border ) {
411408 char terrain [width ][height ];
412409 unsigned char rendered [width ][height ][3 ];
413- generate_terrain ((char * )terrain , (unsigned char * )rendered , width , height , x_border , y_border );
410+ unsigned int rng = 42 ; generate_terrain ((char * )terrain , (unsigned char * )rendered , width , height , x_border , y_border , & rng );
414411
415412
416413 // Colorize
@@ -435,7 +432,7 @@ void test_performance(int num_players, int timeout) {
435432 MMO env = {
436433 .width = 512 ,
437434 .height = 512 ,
438- .num_players = num_players ,
435+ .num_agents = num_players ,
439436 .num_enemies = 128 ,
440437 .num_resources = 32 ,
441438 .num_weapons = 32 ,
@@ -467,6 +464,64 @@ void test_performance(int num_players, int timeout) {
467464 free_allocated_mmo (& env );
468465}
469466
467+ void test_no_render_log (int num_players , int target_episodes ) {
468+ Weights * weights = load_weights ("resources/nmmo3/nmmo3_weights.bin" , 4430976 );
469+ MMONet * net = init_mmonet (weights , num_players );
470+
471+ MMO env = {
472+ .client = NULL ,
473+ .width = 512 ,
474+ .height = 512 ,
475+ .num_agents = num_players ,
476+ .num_enemies = 2048 ,
477+ .num_resources = 2048 ,
478+ .num_weapons = 1024 ,
479+ .num_gems = 512 ,
480+ .tiers = 5 ,
481+ .levels = 40 ,
482+ .teleportitis_prob = 0.001 ,
483+ .enemy_respawn_ticks = 2 ,
484+ .item_respawn_ticks = 100 ,
485+ .x_window = 7 ,
486+ .y_window = 5 ,
487+ .reward_combat_level = 1.0 ,
488+ .reward_prof_level = 1.0 ,
489+ .reward_item_level = 1.0 ,
490+ .reward_market = 0.0 ,
491+ .reward_death = -1.0 ,
492+ };
493+ allocate_mmo (& env );
494+ c_reset (& env );
495+
496+ int num_steps = 0 ;
497+ int prev_n = 0 ;
498+ float prev_mcp = 0.0f ;
499+ while ((int )env .log .n < target_episodes ) {
500+ forward (net , env .observations , env .terminals , env .actions );
501+ c_step (& env );
502+ num_steps ++ ;
503+
504+ int curr_n = (int )env .log .n ;
505+ if (curr_n > prev_n ) {
506+ float ep_mcp = env .log .min_comb_prof - prev_mcp ;
507+ float running_mean = env .log .min_comb_prof / (float )curr_n ;
508+ printf ("Episode %d: min_comb_prof=%.3f running_mean=%.4f (step %d)\n" ,
509+ curr_n , ep_mcp , running_mean , num_steps );
510+ prev_n = curr_n ;
511+ prev_mcp = env .log .min_comb_prof ;
512+ }
513+ }
514+
515+ printf ("\n--- C eval summary (%d episodes, %d steps) ---\n" ,
516+ prev_n , num_steps );
517+ printf ("mean min_comb_prof = %.4f\n" ,
518+ env .log .min_comb_prof / (float )prev_n );
519+
520+ free_allocated_mmo (& env );
521+ free_mmonet (net );
522+ free (weights );
523+ }
524+
470525int main () {
471526
472527 /*
@@ -481,6 +536,6 @@ int main() {
481536 test_generate_terrain(width, height, 8, 8);
482537 */
483538 //test_performance(64, 10);
539+ //test_no_render_log(1, 100);
484540 demo (NUM_AGENTS );
485- //test_mmonet_performance(1024, 10);
486541}
0 commit comments