Skip to content

Commit aed88e5

Browse files
committed
Lights out
1 parent 40c2ff4 commit aed88e5

4 files changed

Lines changed: 325 additions & 0 deletions

File tree

config/lightsout.ini

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[base]
2+
env_name = lightsout
3+
4+
[env]
5+
max_steps = 100
6+
7+
[train]
8+
total_timesteps = 200_000_000

ocean/lightsout/binding.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include "lightsout.h"
2+
3+
#define GRID_SIZE 5
4+
#define OBS_SIZE (GRID_SIZE * GRID_SIZE)
5+
#define NUM_ATNS 1
6+
#define ACT_SIZES {GRID_SIZE * GRID_SIZE}
7+
#define OBS_TENSOR_T ByteTensor
8+
9+
#define Env LightsOut
10+
#include "vecenv.h"
11+
12+
void my_init(Env* env, Dict* kwargs) {
13+
env->grid_size = GRID_SIZE;
14+
env->cell_size = 1280 / GRID_SIZE;
15+
if (1280 % GRID_SIZE != 0) env->cell_size++; // ceil
16+
env->max_steps = (int)dict_get(kwargs, "max_steps")->value;
17+
env->observation_size = OBS_SIZE;
18+
env->num_agents = 1;
19+
20+
env->ema = 0.5f;
21+
env->score_ema = 0.0f;
22+
env->scramble_prob = 0.15f;
23+
24+
init_lightsout(env);
25+
}
26+
27+
void my_log(Log* log, Dict* out) {
28+
dict_set(out, "perf", log->perf);
29+
dict_set(out, "score", log->score);
30+
dict_set(out, "episode_return", log->episode_return);
31+
dict_set(out, "episode_length", log->episode_length);
32+
dict_set(out, "scramble_p", log->scramble_p);
33+
}

ocean/lightsout/lightsout.c

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include <stdio.h>
2+
#include <time.h>
3+
#include "lightsout.h"
4+
5+
static LightsOut* g_env = NULL;
6+
7+
static void demo_cleanup(void) {
8+
if (g_env == NULL) {
9+
return;
10+
}
11+
free(g_env->observations);
12+
free(g_env->actions);
13+
free(g_env->rewards);
14+
free(g_env->terminals);
15+
c_close(g_env);
16+
g_env = NULL;
17+
}
18+
19+
int demo(){
20+
srand((unsigned)time(NULL));
21+
LightsOut env = {.grid_size = 5, .cell_size = 100, .client = NULL};
22+
g_env = &env;
23+
atexit(demo_cleanup);
24+
env.observations = (unsigned char*)calloc(env.grid_size * env.grid_size, sizeof(unsigned char));
25+
env.actions = (float*)calloc(1, sizeof(float));
26+
env.rewards = (float*)calloc(1, sizeof(float));
27+
env.terminals = (float*)calloc(1, sizeof(float));
28+
29+
c_reset(&env);
30+
env.client = make_client(env.cell_size, env.grid_size);
31+
32+
while (!WindowShouldClose()) {
33+
if (IsKeyPressed(KEY_UP) || IsKeyPressed(KEY_W)) env.client->cursor_row = (env.client->cursor_row - 1 + env.grid_size) % env.grid_size;
34+
if (IsKeyPressed(KEY_DOWN) || IsKeyPressed(KEY_S)) env.client->cursor_row = (env.client->cursor_row + 1) % env.grid_size;
35+
if (IsKeyPressed(KEY_LEFT) || IsKeyPressed(KEY_A)) env.client->cursor_col = (env.client->cursor_col - 1 + env.grid_size) % env.grid_size;
36+
if (IsKeyPressed(KEY_RIGHT) || IsKeyPressed(KEY_D)) env.client->cursor_col = (env.client->cursor_col + 1) % env.grid_size;
37+
if (IsKeyPressed(KEY_SPACE)) {
38+
int idx = env.client->cursor_row * env.grid_size + env.client->cursor_col;
39+
env.actions[0] = (float)idx;
40+
c_step(&env);
41+
} else if (IsKeyPressed(KEY_R)) {
42+
c_reset(&env);
43+
}
44+
c_render(&env);
45+
}
46+
47+
demo_cleanup();
48+
return 0;
49+
}
50+
int main(void) {
51+
demo();
52+
return 0;
53+
}

ocean/lightsout/lightsout.h

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
#include <stdlib.h>
2+
#include <math.h>
3+
#include <string.h>
4+
#include "raylib.h"
5+
6+
// Only use floats.
7+
typedef struct {
8+
float perf;
9+
float score;
10+
float episode_return;
11+
float episode_length;
12+
float scramble_p;
13+
float n; // Required as the last field.
14+
} Log;
15+
16+
typedef struct Client {
17+
int cell_size;
18+
int cursor_row;
19+
int cursor_col;
20+
} Client;
21+
22+
typedef struct {
23+
Log log; // Required field.
24+
unsigned char* observations; // Required field. Ensure type matches in .py and .c.
25+
float* actions; // Required field. Ensure type matches in .py and .c.
26+
float* rewards; // Required field.
27+
float* terminals; // Required field.
28+
int grid_size;
29+
int cell_size;
30+
int max_steps;
31+
int step_count;
32+
int lights_on;
33+
int prev_action;
34+
int last_action;
35+
float episode_return;
36+
float ema;
37+
float score_ema;
38+
float scramble_prob;
39+
unsigned char* grid;
40+
Client* client;
41+
int num_agents;
42+
int observation_size;
43+
unsigned int rng;
44+
} LightsOut;
45+
46+
void step_grid(LightsOut* env, int idx) {
47+
if (idx < 0 || idx >= env->grid_size * env->grid_size) return;
48+
int row = idx/env->grid_size;
49+
int col = idx%env->grid_size;
50+
51+
static const int dirs[5][2] = {{0,0}, {1,0}, {0,1}, {-1,0}, {0,-1}};
52+
for (int i = 0; i < 5; i++) {
53+
int dr = dirs[i][0];
54+
int dc = dirs[i][1];
55+
int r = row + dr;
56+
int c = col + dc;
57+
if (r >= 0 && r < env->grid_size && c >= 0 && c < env->grid_size) {
58+
int offset = r*env->grid_size + c;
59+
unsigned char old = env->grid[offset];
60+
env->grid[offset] = (unsigned char)!old;
61+
env->lights_on += old ? -1 : 1;
62+
}
63+
}
64+
}
65+
66+
void init_lightsout(LightsOut* env) {
67+
int n = env->grid_size * env->grid_size;
68+
if (env->grid == NULL) {
69+
env->grid = (unsigned char*)calloc(n, sizeof(unsigned char));
70+
} else {
71+
memset(env->grid, 0, n * sizeof(unsigned char));
72+
}
73+
74+
if (env->ema > 0.7f && env->score_ema > 0.0f) {
75+
env->scramble_prob = fminf(0.5f, env->scramble_prob + 0.01f); // Increase scramble prob if EMA is high
76+
} else if (env->ema < 0.3f) {
77+
env->scramble_prob = fmaxf(0.15f, env->scramble_prob - 0.01f); // Decrease scramble prob if EMA is low
78+
}
79+
80+
env->step_count = 0;
81+
env->lights_on = 0;
82+
env->prev_action = -1;
83+
env->last_action = -1;
84+
env->episode_return = 0.0f;
85+
86+
for (int i = 0; i < n; i++) {
87+
float u = (float)rand_r(&env->rng) / (float)RAND_MAX;
88+
if (u < env->scramble_prob) {
89+
step_grid(env, i);
90+
}
91+
}
92+
}
93+
94+
void c_close(LightsOut* env) {
95+
free(env->grid);
96+
env->grid = NULL;
97+
if (env->client != NULL) {
98+
if (IsWindowReady()) {
99+
CloseWindow();
100+
}
101+
free(env->client);
102+
env->client = NULL;
103+
}
104+
}
105+
106+
void compute_observations(LightsOut* env) {
107+
for (int i = 0; i < env->grid_size * env->grid_size; i++) {
108+
env->observations[i] = env->grid[i];
109+
}
110+
}
111+
112+
void c_reset(LightsOut* env) {
113+
env->rewards[0] = 0.0f;
114+
env->terminals[0] = 0.0f;
115+
init_lightsout(env);
116+
compute_observations(env);
117+
}
118+
119+
void c_step(LightsOut* env) {
120+
int num_cells = env->grid_size * env->grid_size;
121+
int atn = env->actions[0];
122+
env->terminals[0] = 0.0f;
123+
124+
float reward = -0.02 * (36.0 / (env->grid_size * env->grid_size)); // Base step penalty.
125+
int prev_on = env->lights_on;
126+
if (atn < 0 || atn >= num_cells) {
127+
reward -= 0.5f; // Invalid action penalty.
128+
} else {
129+
if (atn == env->last_action) {
130+
reward -= 0.03f; // Penalty for pressing the same cell twice in a row.
131+
} else if (atn == env->prev_action) {
132+
reward -= 0.02f; // Penalty for 2-step loop (A,B,A).
133+
}
134+
if (env->client != NULL) {
135+
env->client->cursor_row = atn / env->grid_size;
136+
env->client->cursor_col = atn % env->grid_size;
137+
}
138+
step_grid(env, atn);
139+
env->prev_action = env->last_action;
140+
env->last_action = atn;
141+
int next_on = env->lights_on;
142+
reward += 0.005f * (float)(prev_on - next_on); // Dense shaping: improve when lights decrease.
143+
}
144+
env->step_count += 1;
145+
146+
if (env->lights_on == 0) {
147+
reward = 2.0f; // Solved reward.
148+
env->ema = 0.85f * env->ema + 0.15f; // Update EMA of steps to solve.
149+
env->terminals[0] = 1.0f;
150+
} else if (env->client == NULL && env->step_count >= env->max_steps) {
151+
reward -= 0.5f; // Timeout penalty during training.
152+
env->ema = 0.85f * env->ema; // Decay EMA since we failed to solve.
153+
env->terminals[0] = 1.0f;
154+
}
155+
156+
env->rewards[0] = reward;
157+
env->episode_return += reward;
158+
159+
if (env->terminals[0] > 0.0f) {
160+
env->log.episode_return += env->episode_return;
161+
env->log.episode_length += (float)env->step_count;
162+
env->log.n += 1.0f;
163+
env->log.perf += (env->lights_on == 0) ? 1.0f : 0.0f;
164+
env->log.score += env->episode_return;
165+
env->log.scramble_p += env->scramble_prob;
166+
167+
env->score_ema = 0.9f * env->score_ema + 0.1f * env->episode_return;
168+
init_lightsout(env);
169+
}
170+
171+
compute_observations(env);
172+
}
173+
174+
// Raylib client
175+
static const Color COLORS[] = {
176+
(Color){6, 24, 24, 255},
177+
(Color){0, 0, 255, 255},
178+
(Color){255, 255, 255, 255}
179+
};
180+
181+
Client* make_client(int cell_size, int grid_size) {
182+
Client* client= (Client*)malloc(sizeof(Client));
183+
client->cell_size = cell_size;
184+
client->cursor_row = 0;
185+
client->cursor_col = 0;
186+
InitWindow(grid_size*cell_size, grid_size*cell_size, "PufferLib LightsOut");
187+
SetTargetFPS(5);
188+
return client;
189+
}
190+
191+
void c_render(LightsOut* env) {
192+
if (IsWindowReady() && (WindowShouldClose() || IsKeyPressed(KEY_ESCAPE))) {
193+
c_close(env);
194+
exit(0);
195+
}
196+
197+
if (env->client == NULL) {
198+
env->client = make_client(env->cell_size, env->grid_size);
199+
}
200+
201+
Client* client = env->client;
202+
203+
BeginDrawing();
204+
ClearBackground(COLORS[0]);
205+
int sz = client->cell_size;
206+
for (int y = 0; y < env->grid_size; y++) {
207+
for (int x = 0; x < env->grid_size; x++){
208+
int tile = env->grid[y*env->grid_size + x];
209+
if (tile != 0)
210+
DrawRectangle(x*sz, y*sz, sz, sz, COLORS[tile]);
211+
}
212+
}
213+
DrawRectangleLinesEx(
214+
(Rectangle){client->cursor_col * sz, client->cursor_row * sz, sz, sz},
215+
3.0f,
216+
COLORS[2]
217+
);
218+
219+
if (env->terminals[0] > 0.0f) {
220+
const char* msg = "Solved";
221+
int font_size = 48;
222+
int text_w = MeasureText(msg, font_size);
223+
int screen_w = env->grid_size * env->cell_size;
224+
int screen_h = env->grid_size * env->cell_size;
225+
226+
DrawRectangle(0, 0, screen_w, screen_h, (Color){0, 0, 0, 120}); // dim overlay
227+
DrawText(msg, (screen_w - text_w) / 2, (screen_h - font_size) / 2, font_size, RAYWHITE);
228+
}
229+
230+
EndDrawing();
231+
}

0 commit comments

Comments
 (0)