|
| 1 | +#define CRAFTAX_ENABLE_ENV_IMPL |
1 | 2 | #include "craftax.h" |
| 3 | +#include "step_crafting.h" |
| 4 | +#include "step_update_mobs.h" |
| 5 | +#include "step_spawn_mobs.h" |
2 | 6 |
|
3 | | -#define OBS_SIZE 1345 |
| 7 | +#define OBS_SIZE CRAFTAX_OBS_SIZE |
4 | 8 | #define NUM_ATNS 1 |
5 | | -#define ACT_SIZES {17} |
| 9 | +#define ACT_SIZES {CRAFTAX_NUM_ACTIONS} |
6 | 10 | #define OBS_TENSOR_T FloatTensor |
7 | 11 |
|
8 | 12 | #define Env Craftax |
9 | 13 | #include "vecenv.h" |
10 | 14 |
|
11 | 15 | void my_init(Env* env, Dict* kwargs) { |
12 | | - // No per-env kwargs for Craftax-Classic: the 64x64 map, inventory sizes, |
13 | | - // mob caps, etc. are all compile-time constants. |
| 16 | + env->num_agents = 1; |
| 17 | + |
| 18 | + uint64_t seed_offset = 0; |
| 19 | + DictItem* item = dict_get_unsafe(kwargs, "seed_offset"); |
| 20 | + if (item != NULL) { |
| 21 | + seed_offset = (uint64_t)item->value; |
| 22 | + } |
| 23 | + env->seed = seed_offset + (uint64_t)env->rng; |
| 24 | + |
| 25 | + // Process-wide reset pool (first caller wins, rest block until ready). |
| 26 | + // 0 disables caching -- regenerate every reset (exact parity mode). |
| 27 | + int reset_pool_size = 0; |
| 28 | + DictItem* pool_item = dict_get_unsafe(kwargs, "reset_pool_size"); |
| 29 | + if (pool_item != NULL) reset_pool_size = (int)pool_item->value; |
| 30 | + craftax_set_reset_pool_size(reset_pool_size); |
| 31 | + |
14 | 32 | c_init(env); |
15 | 33 | } |
16 | 34 |
|
17 | 35 | void my_log(Log* log, Dict* out) { |
18 | | - dict_set(out, "perf", log->perf); |
19 | | - dict_set(out, "score", log->score); |
| 36 | + dict_set(out, "perf", log->perf); |
| 37 | + dict_set(out, "score", log->score); |
20 | 38 | dict_set(out, "episode_return", log->episode_return); |
21 | 39 | dict_set(out, "episode_length", log->episode_length); |
22 | 40 |
|
23 | | - static const char* ACH_NAMES[NUM_ACHIEVEMENTS] = { |
24 | | - "collect_wood", "place_table", "eat_cow", "collect_sapling", |
25 | | - "collect_drink", "make_wood_pick", "make_wood_sword","place_plant", |
26 | | - "defeat_zombie", "collect_stone", "place_stone", "eat_plant", |
27 | | - "defeat_skeleton","make_stone_pick","make_stone_sword","wake_up", |
28 | | - "place_furnace", "collect_coal", "collect_iron", "collect_diamond", |
29 | | - "make_iron_pick", "make_iron_sword", |
| 41 | + // Log 8 checkpoint achievements that form the tech / exploration curve. |
| 42 | + // perf (above) already aggregates all 67 into a normalized score; the |
| 43 | + // individual lines here are the milestones worth watching on a dashboard. |
| 44 | + // The env still tracks all 67 internally for reward and perf; we just |
| 45 | + // don't send every one through the log Dict. |
| 46 | + struct { const char* name; int idx; } checkpoints[] = { |
| 47 | + {"collect_wood", 0}, |
| 48 | + {"make_wood_pickaxe", 5}, |
| 49 | + {"make_stone_pickaxe", 13}, |
| 50 | + {"collect_iron", 18}, |
| 51 | + {"make_iron_pickaxe", 20}, |
| 52 | + {"collect_diamond", 19}, |
| 53 | + {"enter_gnomish_mines", 28}, |
| 54 | + {"defeat_necromancer", 48}, |
30 | 55 | }; |
31 | | - for (int i = 0; i < NUM_ACHIEVEMENTS; i++) { |
32 | | - dict_set(out, ACH_NAMES[i], log->achievements[i]); |
| 56 | + for (int i = 0; i < (int)(sizeof(checkpoints) / sizeof(checkpoints[0])); i++) { |
| 57 | + dict_set(out, checkpoints[i].name, log->achievements[checkpoints[i].idx]); |
33 | 58 | } |
34 | 59 | } |
0 commit comments