Skip to content

Commit 3a53e10

Browse files
committed
Nmmo3 model port
1 parent 1ae8df6 commit 3a53e10

4 files changed

Lines changed: 185 additions & 201 deletions

File tree

ocean/g2048/g2048_net.h

Lines changed: 0 additions & 131 deletions
This file was deleted.

ocean/nmmo3/nmmo3.c

Lines changed: 96 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -4,16 +4,9 @@
44
#include "puffernet.h"
55
#include "nmmo3.h"
66

7-
// Only rens a few agents in the C
8-
// version, and reduces for web.
7+
// Only run 1 agent in the C version
98
// You can run the full 1024 on GPU
10-
// with PyTorch.
11-
#if defined(PLATFORM_WEB)
12-
#define NUM_AGENTS 4
13-
#else
14-
#define NUM_AGENTS 16
15-
#endif
16-
9+
#define NUM_AGENTS 1
1710

1811
typedef struct MMONet MMONet;
1912
struct MMONet {
@@ -27,12 +20,10 @@ struct MMONet {
2720
Conv2D* map_conv2;
2821
Embedding* player_embed;
2922
float* proj_buffer;
30-
Linear* proj;
23+
Affine* proj;
3124
ReLU* proj_relu;
32-
LayerNorm* layer_norm;
33-
LSTM* lstm;
34-
Linear* actor;
35-
Linear* value_fn;
25+
Linear* decoder;
26+
MinGRU* mingru;
3627
Multidiscrete* multidiscrete;
3728
};
3829

@@ -49,12 +40,10 @@ MMONet* init_mmonet(Weights* weights, int num_agents) {
4940
net->map_conv2 = make_conv2d(weights, num_agents, 4, 3, 128, 128, 3, 1);
5041
net->player_embed = make_embedding(weights, num_agents*47, 128, 32);
5142
net->proj_buffer = calloc(num_agents*1817, sizeof(float));
52-
net->proj = make_linear(weights, num_agents, 1817, hidden);
43+
net->proj = make_affine(weights, num_agents, 1817, hidden);
5344
net->proj_relu = make_relu(num_agents, hidden);
54-
net->layer_norm = make_layernorm(weights, num_agents, hidden);
55-
net->actor = make_linear(weights, num_agents, hidden, 26);
56-
net->value_fn = make_linear(weights, num_agents, hidden, 1);
57-
net->lstm = make_lstm(weights, num_agents, hidden, hidden);
45+
net->decoder = make_linear(weights, num_agents, hidden, 26 + 1);
46+
net->mingru = make_mingru(weights, num_agents, hidden, 4);
5847
int logit_sizes[1] = {26};
5948
net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, 1);
6049
return net;
@@ -72,15 +61,21 @@ void free_mmonet(MMONet* net) {
7261
free(net->proj_buffer);
7362
free(net->proj);
7463
free(net->proj_relu);
75-
free(net->layer_norm);
76-
free(net->actor);
77-
free(net->value_fn);
78-
free(net->lstm);
64+
free(net->decoder);
65+
free_mingru(net->mingru);
7966
free(net->multidiscrete);
8067
free(net);
8168
}
8269

83-
void forward(MMONet* net, unsigned char* observations, int* actions) {
70+
void forward(MMONet* net, unsigned char* observations, float* terminals, float* actions) {
71+
for (int b = 0; b < net->num_agents; b++) {
72+
if (terminals[b] > 0.5f) {
73+
for (int l = 0; l < net->mingru->num_layers; l++) {
74+
memset(net->mingru->state + l * net->mingru->batch_size * net->mingru->hidden_size + b * net->mingru->hidden_size, 0, net->mingru->hidden_size * sizeof(float));
75+
}
76+
terminals[b] = 0.0f;
77+
}
78+
}
8479
memset(net->ob_map, 0, net->num_agents*11*15*59*sizeof(float));
8580

8681
// DUMMY INPUT FOR TESTING
@@ -147,53 +142,55 @@ void forward(MMONet* net, unsigned char* observations, int* actions) {
147142
}
148143
}
149144

150-
linear(net->proj, net->proj_buffer);
145+
affine(net->proj, net->proj_buffer);
151146
relu(net->proj_relu, net->proj->output);
152147

153-
lstm(net->lstm, net->proj_relu->output);
154-
layernorm(net->layer_norm, net->lstm->state_h);
155-
156-
linear(net->actor, net->layer_norm->output);
157-
linear(net->value_fn, net->layer_norm->output);
148+
mingru(net->mingru, net->proj_relu->output);
149+
linear(net->decoder, net->mingru->output);
158150

159-
softmax_multidiscrete(net->multidiscrete, net->actor->output, actions);
151+
softmax_multidiscrete(net->multidiscrete, net->decoder->output, actions);
160152
}
161153

162154
void demo(int num_players) {
163-
Weights* weights = load_weights("resources/nmmo3/nmmo3_weights.bin", 3387547);
155+
Weights* weights = load_weights("resources/nmmo3/nmmo3_weights.bin", 4430976);
164156
MMONet* net = init_mmonet(weights, num_players);
165157

166158
MMO env = {
167159
.client = NULL,
168160
.width = 512,
169161
.height = 512,
170-
.num_players = num_players,
162+
.num_agents = num_players,
171163
.num_enemies = 2048,
172164
.num_resources = 2048,
173165
.num_weapons = 1024,
174166
.num_gems = 512,
175167
.tiers = 5,
176168
.levels = 40,
177-
.teleportitis_prob = 0.0,
169+
.teleportitis_prob = 0.001,
178170
.enemy_respawn_ticks = 2,
179171
.item_respawn_ticks = 100,
180172
.x_window = 7,
181173
.y_window = 5,
174+
.reward_combat_level = 1.0,
175+
.reward_prof_level = 1.0,
176+
.reward_item_level = 1.0,
177+
.reward_market = 0.0,
178+
.reward_death = -1.0,
182179
};
183180
allocate_mmo(&env);
184181

185182
c_reset(&env);
186183
c_render(&env);
187184

188-
int human_action = ATN_NOOP;
185+
float human_action = ATN_NOOP;
189186
bool human_mode = false;
190187
int i = 1;
191188
while (!WindowShouldClose()) {
192189
if (IsKeyPressed(KEY_LEFT_CONTROL)) {
193190
human_mode = !human_mode;
194191
}
195192
if (i % 36 == 0) {
196-
forward(net, env.observations, env.actions);
193+
forward(net, env.observations, env.terminals, env.actions);
197194
if (human_mode) {
198195
env.actions[0] = human_action;
199196
}
@@ -221,7 +218,7 @@ void test_mmonet_performance(int num_players, int timeout) {
221218
MMO env = {
222219
.width = 512,
223220
.height = 512,
224-
.num_players = num_players,
221+
.num_agents = num_players,
225222
.num_enemies = 128,
226223
.num_resources = 32,
227224
.num_weapons = 32,
@@ -240,7 +237,7 @@ void test_mmonet_performance(int num_players, int timeout) {
240237
int start = time(NULL);
241238
int num_steps = 0;
242239
while (time(NULL) - start < timeout) {
243-
forward(net, env.observations, env.actions);
240+
forward(net, env.observations, env.terminals, env.actions);
244241
c_step(&env);
245242
num_steps++;
246243
}
@@ -410,7 +407,7 @@ void test_cellular_automata(int width, int height, int colors, int max_fill) {
410407
void test_generate_terrain(int width, int height, int x_border, int y_border) {
411408
char terrain[width][height];
412409
unsigned char rendered[width][height][3];
413-
generate_terrain((char*)terrain, (unsigned char*)rendered, width, height, x_border, y_border);
410+
unsigned int rng = 42; generate_terrain((char*)terrain, (unsigned char*)rendered, width, height, x_border, y_border, &rng);
414411

415412

416413
// Colorize
@@ -435,7 +432,7 @@ void test_performance(int num_players, int timeout) {
435432
MMO env = {
436433
.width = 512,
437434
.height = 512,
438-
.num_players = num_players,
435+
.num_agents = num_players,
439436
.num_enemies = 128,
440437
.num_resources = 32,
441438
.num_weapons = 32,
@@ -467,6 +464,64 @@ void test_performance(int num_players, int timeout) {
467464
free_allocated_mmo(&env);
468465
}
469466

467+
void test_no_render_log(int num_players, int target_episodes) {
468+
Weights* weights = load_weights("resources/nmmo3/nmmo3_weights.bin", 4430976);
469+
MMONet* net = init_mmonet(weights, num_players);
470+
471+
MMO env = {
472+
.client = NULL,
473+
.width = 512,
474+
.height = 512,
475+
.num_agents = num_players,
476+
.num_enemies = 2048,
477+
.num_resources = 2048,
478+
.num_weapons = 1024,
479+
.num_gems = 512,
480+
.tiers = 5,
481+
.levels = 40,
482+
.teleportitis_prob = 0.001,
483+
.enemy_respawn_ticks = 2,
484+
.item_respawn_ticks = 100,
485+
.x_window = 7,
486+
.y_window = 5,
487+
.reward_combat_level = 1.0,
488+
.reward_prof_level = 1.0,
489+
.reward_item_level = 1.0,
490+
.reward_market = 0.0,
491+
.reward_death = -1.0,
492+
};
493+
allocate_mmo(&env);
494+
c_reset(&env);
495+
496+
int num_steps = 0;
497+
int prev_n = 0;
498+
float prev_mcp = 0.0f;
499+
while ((int)env.log.n < target_episodes) {
500+
forward(net, env.observations, env.terminals, env.actions);
501+
c_step(&env);
502+
num_steps++;
503+
504+
int curr_n = (int)env.log.n;
505+
if (curr_n > prev_n) {
506+
float ep_mcp = env.log.min_comb_prof - prev_mcp;
507+
float running_mean = env.log.min_comb_prof / (float)curr_n;
508+
printf("Episode %d: min_comb_prof=%.3f running_mean=%.4f (step %d)\n",
509+
curr_n, ep_mcp, running_mean, num_steps);
510+
prev_n = curr_n;
511+
prev_mcp = env.log.min_comb_prof;
512+
}
513+
}
514+
515+
printf("\n--- C eval summary (%d episodes, %d steps) ---\n",
516+
prev_n, num_steps);
517+
printf("mean min_comb_prof = %.4f\n",
518+
env.log.min_comb_prof / (float)prev_n);
519+
520+
free_allocated_mmo(&env);
521+
free_mmonet(net);
522+
free(weights);
523+
}
524+
470525
int main() {
471526

472527
/*
@@ -481,6 +536,6 @@ int main() {
481536
test_generate_terrain(width, height, 8, 8);
482537
*/
483538
//test_performance(64, 10);
539+
//test_no_render_log(1, 100);
484540
demo(NUM_AGENTS);
485-
//test_mmonet_performance(1024, 10);
486541
}

ocean/nmmo3/nmmo3.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -759,6 +759,9 @@ void add_player_log(MMO* env, int pid) {
759759
log->perf = log->min_comb_prof / (float)env->levels;
760760
log->n++;
761761
*ret = (Reward){0};
762+
if (pid < env->num_agents) {
763+
env->terminals[pid] = 1.0f;
764+
}
762765
}
763766

764767
void init(MMO* env) {

0 commit comments

Comments
 (0)