Skip to content

Commit f9f7155

Browse files
authored
Merge pull request #537 from Infatoshi/craftax-full-pr
Craftax Full: native C port + optimizations + renderer
2 parents abeb03c + e122c9a commit f9f7155

25 files changed

Lines changed: 10532 additions & 916 deletions

build.sh

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,13 @@ if [ -z "$NCCL_LFLAG" ]; then
208208
NCCL_LFLAG=$(python -c "import nvidia.nccl, os; print('-L' + os.path.join(nvidia.nccl.__path__[0], 'lib'))" 2>/dev/null || echo "")
209209
fi
210210

211+
WHEEL_RPATH_FLAGS=()
212+
for lib_flag in "$CUDNN_LFLAG" "$NCCL_LFLAG"; do
213+
if [[ "$lib_flag" == -L* ]]; then
214+
WHEEL_RPATH_FLAGS+=("-Wl,-rpath,${lib_flag#-L}")
215+
fi
216+
done
217+
211218
export CCACHE_DIR="${CCACHE_DIR:-$HOME/.ccache}"
212219
export CCACHE_BASEDIR="$(pwd)"
213220
export CCACHE_COMPILERCHECK=content
@@ -232,7 +239,7 @@ if [ ! -f "$BINDING_SRC" ]; then
232239
fi
233240

234241
echo "Compiling static library for $ENV..."
235-
${CC:-clang} -c "${CLANG_OPT[@]}" \
242+
${CC:-clang} -c "${CLANG_OPT[@]}" $EXTRA_CFLAGS \
236243
-I. -Isrc -I$SRC_DIR -Ivendor \
237244
-I./$RAYLIB_NAME/include -I$CUDA_HOME/include \
238245
-DPLATFORM_DESKTOP \
@@ -268,6 +275,7 @@ if [ -z "$MODE" ]; then
268275
${CXX:-g++} -shared -fPIC -fopenmp
269276
build/bindings.o "$STATIC_LIB" "$RAYLIB_A"
270277
-L$CUDA_HOME/lib64 $CUDNN_LFLAG $NCCL_LFLAG
278+
"${WHEEL_RPATH_FLAGS[@]}"
271279
-lcudart -lnccl -lnvidia-ml -lcublas -lcusolver -lcurand -lcudnn
272280
$OMP_LIB $LINK_OPT
273281
"${SHARED_LDFLAGS[@]}"

config/craftax.ini

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ num_buffers = 4
77
num_threads = 16
88

99
[env]
10+
seed_offset = 0
11+
# Pre-generated world pool. Each reset memcpys from a pool entry
12+
# instead of re-running generate_world (~ms -> ~us per reset).
13+
# Bounds world diversity: at most reset_pool_size unique maps are
14+
# ever seen per process. Set to 0 to disable (required for the
15+
# parity harness to maintain exact per-seed determinism).
16+
reset_pool_size = 1024
1017

1118
[train]
1219
total_timesteps = 200_000_000

config/craftax_classic.ini

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[base]
2+
env_name = craftax_classic
3+
4+
[vec]
5+
total_agents = 8192
6+
num_buffers = 4
7+
num_threads = 16
8+
9+
[env]
10+
11+
[train]
12+
total_timesteps = 200_000_000

ocean/craftax/PORT_NOTES.md

Lines changed: 543 additions & 0 deletions
Large diffs are not rendered by default.

ocean/craftax/binding.c

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,59 @@
1+
#define CRAFTAX_ENABLE_ENV_IMPL
12
#include "craftax.h"
3+
#include "step_crafting.h"
4+
#include "step_update_mobs.h"
5+
#include "step_spawn_mobs.h"
26

3-
#define OBS_SIZE 1345
7+
#define OBS_SIZE CRAFTAX_OBS_SIZE
48
#define NUM_ATNS 1
5-
#define ACT_SIZES {17}
9+
#define ACT_SIZES {CRAFTAX_NUM_ACTIONS}
610
#define OBS_TENSOR_T FloatTensor
711

812
#define Env Craftax
913
#include "vecenv.h"
1014

1115
void my_init(Env* env, Dict* kwargs) {
12-
// No per-env kwargs for Craftax-Classic: the 64x64 map, inventory sizes,
13-
// mob caps, etc. are all compile-time constants.
16+
env->num_agents = 1;
17+
18+
uint64_t seed_offset = 0;
19+
DictItem* item = dict_get_unsafe(kwargs, "seed_offset");
20+
if (item != NULL) {
21+
seed_offset = (uint64_t)item->value;
22+
}
23+
env->seed = seed_offset + (uint64_t)env->rng;
24+
25+
// Process-wide reset pool (first caller wins, rest block until ready).
26+
// 0 disables caching -- regenerate every reset (exact parity mode).
27+
int reset_pool_size = 0;
28+
DictItem* pool_item = dict_get_unsafe(kwargs, "reset_pool_size");
29+
if (pool_item != NULL) reset_pool_size = (int)pool_item->value;
30+
craftax_set_reset_pool_size(reset_pool_size);
31+
1432
c_init(env);
1533
}
1634

1735
void my_log(Log* log, Dict* out) {
18-
dict_set(out, "perf", log->perf);
19-
dict_set(out, "score", log->score);
36+
dict_set(out, "perf", log->perf);
37+
dict_set(out, "score", log->score);
2038
dict_set(out, "episode_return", log->episode_return);
2139
dict_set(out, "episode_length", log->episode_length);
2240

23-
static const char* ACH_NAMES[NUM_ACHIEVEMENTS] = {
24-
"collect_wood", "place_table", "eat_cow", "collect_sapling",
25-
"collect_drink", "make_wood_pick", "make_wood_sword","place_plant",
26-
"defeat_zombie", "collect_stone", "place_stone", "eat_plant",
27-
"defeat_skeleton","make_stone_pick","make_stone_sword","wake_up",
28-
"place_furnace", "collect_coal", "collect_iron", "collect_diamond",
29-
"make_iron_pick", "make_iron_sword",
41+
// Log 8 checkpoint achievements that form the tech / exploration curve.
42+
// perf (above) already aggregates all 67 into a normalized score; the
43+
// individual lines here are the milestones worth watching on a dashboard.
44+
// The env still tracks all 67 internally for reward and perf; we just
45+
// don't send every one through the log Dict.
46+
struct { const char* name; int idx; } checkpoints[] = {
47+
{"collect_wood", 0},
48+
{"make_wood_pickaxe", 5},
49+
{"make_stone_pickaxe", 13},
50+
{"collect_iron", 18},
51+
{"make_iron_pickaxe", 20},
52+
{"collect_diamond", 19},
53+
{"enter_gnomish_mines", 28},
54+
{"defeat_necromancer", 48},
3055
};
31-
for (int i = 0; i < NUM_ACHIEVEMENTS; i++) {
32-
dict_set(out, ACH_NAMES[i], log->achievements[i]);
56+
for (int i = 0; i < (int)(sizeof(checkpoints) / sizeof(checkpoints[0])); i++) {
57+
dict_set(out, checkpoints[i].name, log->achievements[checkpoints[i].idx]);
3358
}
3459
}

ocean/craftax/craftax.c

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Standalone viewer for Craftax (random-action policy).
2+
//
3+
// Build:
4+
// ./build.sh craftax --fast # optimized
5+
// ./build.sh craftax --local # debug with sanitizers
6+
// Run:
7+
// ./craftax
8+
9+
#define CRAFTAX_ENABLE_ENV_IMPL
10+
#include "craftax.h"
11+
#include "step_crafting.h"
12+
#include "step_update_mobs.h"
13+
#include "step_spawn_mobs.h"
14+
15+
#include <stdio.h>
16+
#include <stdlib.h>
17+
#include <time.h>
18+
19+
static uint32_t xorshift32(uint32_t* s) {
20+
uint32_t x = *s;
21+
x ^= x << 13; x ^= x >> 17; x ^= x << 5;
22+
*s = x ? x : 0xdeadbeef;
23+
return x;
24+
}
25+
26+
int main(int argc, char** argv) {
27+
uint64_t seed = (argc > 1) ? strtoull(argv[1], NULL, 10) : (uint64_t)time(NULL);
28+
29+
Craftax env;
30+
memset(&env, 0, sizeof(env));
31+
env.num_agents = 1;
32+
env.seed = seed;
33+
env.rng = (uint32_t)seed;
34+
35+
// Minimal buffers for a single agent
36+
env.observations = calloc(CRAFTAX_OBS_SIZE, sizeof(float));
37+
env.actions = calloc(1, sizeof(float));
38+
env.rewards = calloc(1, sizeof(float));
39+
env.terminals = calloc(1, sizeof(float));
40+
41+
c_init(&env);
42+
c_reset(&env);
43+
44+
uint32_t action_rng = (uint32_t)(seed ^ 0x9E3779B9u);
45+
bool human_control = false;
46+
int human_action = CRAFTAX_ACTION_NOOP;
47+
48+
while (!WindowShouldClose()) {
49+
// Toggle human control
50+
if (IsKeyPressed(KEY_H)) human_control = !human_control;
51+
52+
if (human_control) {
53+
human_action = CRAFTAX_ACTION_NOOP;
54+
if (IsKeyPressed(KEY_A) || IsKeyPressed(KEY_LEFT)) human_action = CRAFTAX_ACTION_LEFT;
55+
if (IsKeyPressed(KEY_D) || IsKeyPressed(KEY_RIGHT)) human_action = CRAFTAX_ACTION_RIGHT;
56+
if (IsKeyPressed(KEY_W) || IsKeyPressed(KEY_UP)) human_action = CRAFTAX_ACTION_UP;
57+
if (IsKeyPressed(KEY_S) || IsKeyPressed(KEY_DOWN)) human_action = CRAFTAX_ACTION_DOWN;
58+
if (IsKeyPressed(KEY_SPACE)) human_action = CRAFTAX_ACTION_DO;
59+
if (IsKeyPressed(KEY_Z)) human_action = CRAFTAX_ACTION_SLEEP;
60+
env.actions[0] = (float)human_action;
61+
if (human_action != CRAFTAX_ACTION_NOOP || IsKeyPressed(KEY_PERIOD)) c_step(&env);
62+
} else {
63+
env.actions[0] = (float)(xorshift32(&action_rng) % CRAFTAX_NUM_ACTIONS);
64+
c_step(&env);
65+
}
66+
67+
c_render(&env);
68+
}
69+
70+
c_close(&env);
71+
free(env.observations);
72+
free(env.actions);
73+
free(env.rewards);
74+
free(env.terminals);
75+
return 0;
76+
}

0 commit comments

Comments
 (0)