From 40fc193636ae4c58916fb9ffb0340ae085890822 Mon Sep 17 00:00:00 2001 From: Spencer Date: Sun, 5 Apr 2026 00:30:40 +0000 Subject: [PATCH] go --- ocean/go/binding.c | 15 +- ocean/go/go.h | 501 +++++++++++++++++++++++++++++++-------------- 2 files changed, 359 insertions(+), 157 deletions(-) diff --git a/ocean/go/binding.c b/ocean/go/binding.c index 6566abdd57..c63516794f 100644 --- a/ocean/go/binding.c +++ b/ocean/go/binding.c @@ -1,25 +1,26 @@ #include "go.h" -#define OBS_SIZE 100 +// 9x9 - obs 326, act 82 +// 13x13 - obs 678, act 170 +// 19x19 - obs 1446, act 362 +#define OBS_SIZE 326 #define NUM_ATNS 1 -#define ACT_SIZES {50} -#define OBS_TYPE FLOAT -#define ACT_TYPE DOUBLE +#define ACT_SIZES {82} +#define OBS_TENSOR_T FloatTensor #define Env CGo #include "vecenv.h" void my_init(Env* env, Dict* kwargs) { env->num_agents = 1; + env->side = (rand_r(&env->rng) % 2) + 1; + env->selfplay = dict_get(kwargs, "selfplay")->value; env->width = dict_get(kwargs, "width")->value; env->height = dict_get(kwargs, "height")->value; env->grid_size = dict_get(kwargs, "grid_size")->value; env->board_width = dict_get(kwargs, "board_width")->value; env->board_height = dict_get(kwargs, "board_height")->value; env->grid_square_size = dict_get(kwargs, "grid_square_size")->value; - env->moves_made = dict_get(kwargs, "moves_made")->value; env->komi = dict_get(kwargs, "komi")->value; - env->score = dict_get(kwargs, "score")->value; - env->last_capture_position = dict_get(kwargs, "last_capture_position")->value; env->reward_move_pass = dict_get(kwargs, "reward_move_pass")->value; env->reward_move_invalid = dict_get(kwargs, "reward_move_invalid")->value; env->reward_move_valid = dict_get(kwargs, "reward_move_valid")->value; diff --git a/ocean/go/go.h b/ocean/go/go.h index 7a02ccf74b..2088d54bc6 100644 --- a/ocean/go/go.h +++ b/ocean/go/go.h @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -11,6 +12,7 @@ #define NUM_DIRECTIONS 4 #define ENV_WIN -1 #define PLAYER_WIN 1 +#define MAX_CHANGED_PER_MOVE 362 static const int DIRECTIONS[NUM_DIRECTIONS][2] = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}}; // LD_LIBRARY_PATH=raylib/lib ./go @@ -21,6 +23,11 @@ struct Log { float episode_return; float episode_length; float n; + float illegal_move_count; + float legal_move_count; + float pass_move_count; + float white_wins; + float black_wins; }; typedef struct Group Group; @@ -64,12 +71,12 @@ typedef struct CGo CGo; struct CGo { Client* client; float* observations; - double* actions; + float* actions; float* rewards; float* terminals; - int num_agents; Log log; float score; + int num_agents; int width; int height; int* board_x; @@ -78,14 +85,14 @@ struct CGo { int board_height; int grid_square_size; int grid_size; - int* board_states; - int* previous_board_state; + uint8_t* board_states; + uint8_t* previous_board_state; int last_capture_position; - int* temp_board_states; int moves_made; int* capture_count; float komi; - int* visited; + uint8_t* visited; + uint8_t current_version; Group* groups; Group* temp_groups; float reward_move_pass; @@ -94,25 +101,56 @@ struct CGo { float reward_player_capture; float reward_opponent_capture; float tick; + int selfplay; + int turn; + int side; + int legal_move_count; + int illegal_move_count; + int pass_move_count; + int previous_move; + int human_play; + // undo stack + int changed_pos[MAX_CHANGED_PER_MOVE]; + uint8_t old_board_values[MAX_CHANGED_PER_MOVE]; + int changed_count; + int old_capture_count[2]; + float old_reward; + float old_episode_return; + unsigned int rng; }; void add_log(CGo* env) { env->log.episode_length += env->tick; // Calculate perf as a win rate (1.0 if win, 0.0 if loss) - float win_value = 0.0; - if (env->score > 0) { - win_value = 1.0; // Win - } - else if (env->score < 0) { - win_value = 0.0; // Loss - } - else { - win_value = 0.0; // Tie + float win_value = (env->score > 0) ? 1.0f : (env->score < 0) ? 0.0f : 0.5f; + float black_win = 0.0; + float white_win = 0.0; + if(env->score > 0){ + if(env->side == 1){ + black_win = 1.0; + } + else{ + white_win = 1.0; + } } - - env->log.perf = (env->log.perf * env->log.n + win_value) / (env->log.n + 1.0); - + else if (env->score < 0){ + if(env->side == 1){ + white_win = 1.0; + } + else{ + black_win = 1.0; + } + } else { + black_win = 0.5; + white_win = 0.5; + } + env->log.illegal_move_count += env->illegal_move_count; + env->log.legal_move_count += env->legal_move_count; + env->log.pass_move_count += env->pass_move_count; + env->log.perf += win_value; + env->log.black_wins += black_win; + env->log.white_wins += white_win; env->log.score += env->score; env->log.episode_return += env->rewards[0]; env->log.n += 1.0; @@ -141,10 +179,10 @@ void init(CGo* env) { int grid_size = env->grid_size*env->grid_size; env->board_x = (int*)calloc(board_render_size, sizeof(int)); env->board_y = (int*)calloc(board_render_size, sizeof(int)); - env->board_states = (int*)calloc(grid_size, sizeof(int)); - env->visited = (int*)calloc(grid_size, sizeof(int)); - env->previous_board_state = (int*)calloc(grid_size, sizeof(int)); - env->temp_board_states = (int*)calloc(grid_size, sizeof(int)); + env->board_states = (uint8_t*)calloc(grid_size, sizeof(uint8_t)); + env->visited = (uint8_t*)calloc(grid_size, sizeof(uint8_t)); + env->current_version = 1; + env->previous_board_state = (uint8_t*)calloc(grid_size, sizeof(uint8_t)); env->capture_count = (int*)calloc(2, sizeof(int)); env->groups = (Group*)calloc(grid_size, sizeof(Group)); env->temp_groups = (Group*)calloc(grid_size, sizeof(Group)); @@ -154,8 +192,13 @@ void init(CGo* env) { void allocate(CGo* env) { init(env); - env->observations = (float*)calloc((env->grid_size)*(env->grid_size)*2 + 2, sizeof(float)); - env->actions = (double*)calloc(1, sizeof(double)); + if(env->selfplay){ + env->observations = (float*)calloc(2*((env->grid_size)*(env->grid_size)*4 +2), sizeof(float)); + env->actions = (float*)calloc(2, sizeof(float)); + } else{ + env->observations = (float*)calloc((env->grid_size)*(env->grid_size)*4 +1, sizeof(float)); + env->actions = (float*)calloc(1, sizeof(float)); + } env->rewards = (float*)calloc(1, sizeof(float)); env->terminals = (float*)calloc(1, sizeof(float)); } @@ -166,7 +209,6 @@ void c_close(CGo* env) { free(env->board_states); free(env->visited); free(env->previous_board_state); - free(env->temp_board_states); free(env->capture_count); free(env->temp_groups); free(env->groups); @@ -180,38 +222,60 @@ void free_allocated(CGo* env) { c_close(env); } -void compute_observations(CGo* env) { - int observation_indx=0; - for (int i = 0; i < (env->grid_size)*(env->grid_size); i++) { - if(env->board_states[i] ==1 ){ - env->observations[observation_indx] = 1.0; - } - else { - env->observations[observation_indx] = 0.0; - } - observation_indx++; +static inline void increment_version(CGo* env) { + env->current_version++; + if (env->current_version == 0) { + memset(env->visited, 0, (env->grid_size) * (env->grid_size)); + env->current_version = 1; } - for (int i = 0; i < (env->grid_size)*(env->grid_size); i++) { - if(env->board_states[i] ==2 ){ - env->observations[observation_indx] = 1.0; - } - else { - env->observations[observation_indx] = 0.0; +} + +void compute_observations(CGo* env) { + int obs_len = env->grid_size * env->grid_size * 4 + 2; + int N = env->grid_size * env->grid_size; + int iterations = env->selfplay ? 2 : 1; + + for(int i = 0; i < iterations; i++){ + float* current_obs = env->observations + (i * obs_len); + + int self, opp; + if (i == 0) { + self = env->side; + opp = 3 - self; + } else { + // Flip perspective for selfplay + self = 3 - env->side; + opp = env->side; } - observation_indx++; - } - env->observations[observation_indx] = env->capture_count[0]; - env->observations[observation_indx+1] = env->capture_count[1]; + int turn = env->turn + 1 == self ? 1 : 0; + + // Memory Layout: [Current Self][Current Opp][Prev Self][Prev Opp] + float* plane_self = current_obs; + float* plane_opp = current_obs + N; + float* plane_prev_self = current_obs + (2 * N); + float* plane_prev_opp = current_obs + (3 * N); + + for (int idx = 0; idx < N; idx++) { + int val = env->board_states[idx]; + int prev_val = env->previous_board_state[idx]; + plane_self[idx] = (float)(val == self); + plane_opp[idx] = (float)(val == opp); + + plane_prev_self[idx] = (float)(prev_val == self); + plane_prev_opp[idx] = (float)(prev_val == opp); + } + + // Set the color bit at the very end + current_obs[4 * N] = (float)(self - 1); + current_obs[4 * N + 1] = (float)(turn); + } } int is_valid_position(CGo* env, int x, int y) { return (x >= 0 && x < env->grid_size && y >= 0 && y < env->grid_size); } -void reset_visited(CGo* env) { - memset(env->visited, 0, sizeof(int) * (env->grid_size) * (env->grid_size)); -} void flood_fill(CGo* env, int x, int y, int* territory, int player) { if (!is_valid_position(env, x, y)) { @@ -219,10 +283,10 @@ void flood_fill(CGo* env, int x, int y, int* territory, int player) { } int pos = y * (env->grid_size) + x; - if (env->visited[pos] || env->board_states[pos] != 0) { + if (env->visited[pos] == env->current_version || env->board_states[pos] != 0) { return; } - env->visited[pos] = 1; + env->visited[pos] = env->current_version; territory[player]++; // Check adjacent positions for (int i = 0; i < 4; i++) { @@ -233,17 +297,18 @@ void flood_fill(CGo* env, int x, int y, int* territory, int player) { void compute_score_tromp_taylor(CGo* env) { int player_score = 0; int opponent_score = 0; - reset_visited(env); - + int player = env->side; + int opponent = 3 - player; + increment_version(env); // Queue for BFS int queue_size = (env->grid_size) * (env->grid_size); int queue[queue_size]; // First count stones for (int i = 0; i < queue_size; i++) { - if (env->board_states[i] == 1) { + if (env->board_states[i] == player) { player_score++; - } else if (env->board_states[i] == 2) { + } else if (env->board_states[i] == opponent) { opponent_score++; } } @@ -251,7 +316,7 @@ void compute_score_tromp_taylor(CGo* env) { // Then process empty territories for (int start_pos = 0; start_pos < queue_size; start_pos++) { // Skip if not empty or already visited - if (env->board_states[start_pos] != 0 || env->visited[start_pos]) { + if (env->board_states[start_pos] != 0 || env->visited[start_pos] == env->current_version) { continue; } @@ -261,7 +326,7 @@ void compute_score_tromp_taylor(CGo* env) { int bordering_player = 0; // 0=neutral, 1=player1, 2=player2, 3=mixed queue[rear++] = start_pos; - env->visited[start_pos] = 1; + env->visited[start_pos] = env->current_version; // Process connected empty points while (front < rear) { @@ -280,29 +345,32 @@ void compute_score_tromp_taylor(CGo* env) { } int npos = ny * env->grid_size + nx; - - if (env->board_states[npos] == 0 && !env->visited[npos]) { + int neighbor_color = env->board_states[npos]; + if (neighbor_color ==0) { // Add unvisited empty points to queue - queue[rear++] = npos; - env->visited[npos] = 1; + if(env->visited[npos] != env->current_version) { + queue[rear++] = npos; + env->visited[npos] = env->current_version; + } } else if (bordering_player == 0) { - bordering_player = env->board_states[npos]; - } else if (bordering_player != env->board_states[npos]) { + bordering_player = neighbor_color; + } else if (bordering_player != neighbor_color) { bordering_player = 3; // Mixed territory } } } // Assign territory points - if (bordering_player == 1) { + if (bordering_player == player) { player_score += territory_size; - } else if (bordering_player == 2) { + } else if (bordering_player == opponent) { opponent_score += territory_size; } // Mixed territories (bordering_player == 3) are neutral and not counted } - - env->score = (float)player_score - (float)opponent_score - env->komi; + float komi = (env->side == 2) ? env->komi : -env->komi; + env->score = (float)player_score - (float)opponent_score + komi; + //printf("Score: %f\n", env->score); } int find_in_group(int* group, int group_size, int value) { @@ -315,26 +383,28 @@ int find_in_group(int* group, int group_size, int value) { } -void capture_group(CGo* env, int* board, int root, int* affected_groups, int* affected_count) { - // Reset visited array - reset_visited(env); - +void capture_group(CGo* env, uint8_t* board, int root, int* affected_groups, int* affected_count) { + increment_version(env); // Use a queue for BFS int queue_size = (env->grid_size) * (env->grid_size); int queue[queue_size]; int front = 0, rear = 0; int captured_player = board[root]; // Player whose stones are being captured + if (captured_player != 1 && captured_player !=2) return; int capturing_player = 3 - captured_player; // Player who captures queue[rear++] = root; - env->visited[root] = 1; + env->visited[root] = env->current_version; while (front != rear) { int pos = queue[front++]; + env->old_board_values[env->changed_count] = board[pos]; // captured_player + env->changed_pos[env->changed_count] = pos; + env->changed_count++; board[pos] = 0; // Remove stone env->capture_count[capturing_player - 1]++; // Update capturing player's count - if(capturing_player-1 == 0){ + if(capturing_player == env->side){ env->rewards[0] += env->reward_player_capture; env->log.episode_return += env->reward_player_capture; } else{ @@ -353,8 +423,8 @@ void capture_group(CGo* env, int* board, int root, int* affected_groups, int* af continue; } - if (board[npos] == captured_player && !env->visited[npos]) { - env->visited[npos] = 1; + if (board[npos] == captured_player && env->visited[npos]!=env->current_version) { + env->visited[npos] = env->current_version; queue[rear++] = npos; } else if (board[npos] == capturing_player) { @@ -370,14 +440,14 @@ void capture_group(CGo* env, int* board, int root, int* affected_groups, int* af } -int count_liberties(CGo* env, int root, int* queue) { - reset_visited(env); +int count_liberties(CGo* env, int root, int* queue, uint8_t* board) { + increment_version(env); int liberties = 0; int front = 0; int rear = 0; queue[rear++] = root; - env->visited[root] = 1; + env->visited[root] = env->current_version; while (front < rear) { int pos = queue[front++]; int x = pos % (env->grid_size); @@ -391,44 +461,46 @@ int count_liberties(CGo* env, int root, int* queue) { } int npos = ny * (env->grid_size) + nx; - if (env->visited[npos]) { + if (env->visited[npos]== env->current_version) { continue; } - int temp_npos = env->temp_board_states[npos]; + int temp_npos = board[npos]; if (temp_npos == 0) { liberties++; - env->visited[npos] = 1; - } else if (temp_npos == env->temp_board_states[root]) { + env->visited[npos] = env->current_version; + } else if (temp_npos == board[root]) { queue[rear++] = npos; - env->visited[npos] = 1; + env->visited[npos] = env->current_version; } } } return liberties; } -int is_ko(CGo* env) { - for (int i = 0; i < (env->grid_size) * (env->grid_size); i++) { - if (env->temp_board_states[i] != env->previous_board_state[i]) { - return 0; // Not a ko - } - } - return 1; // Is a ko -} - int make_move(CGo* env, int pos, int player){ int x = pos % (env->grid_size); int y = pos / (env->grid_size); // cannot place stone on occupied tile if (env->board_states[pos] != 0) { + if(player == env->side){ + env->illegal_move_count+=1; + } return 0 ; } + env->old_capture_count[0] = env->capture_count[0]; + env->old_capture_count[1] = env->capture_count[1]; + env->old_reward = env->rewards[0]; + env->old_episode_return = env->log.episode_return; + + env->changed_count = 0; + + env->old_board_values[env->changed_count] = env->board_states[pos]; + env->changed_pos[env->changed_count++] = pos; + env->board_states[pos] = player; // temp structures - memcpy(env->temp_board_states, env->board_states, sizeof(int) * (env->grid_size) * (env->grid_size)); memcpy(env->temp_groups, env->groups, sizeof(Group) * (env->grid_size) * (env->grid_size)); // create new group - env->temp_board_states[pos] = player; env->temp_groups[pos].parent = pos; env->temp_groups[pos].rank = 0; env->temp_groups[pos].size = 1; @@ -449,10 +521,10 @@ int make_move(CGo* env, int pos, int player){ if (!is_valid_position(env, nx, ny)) { continue; } - if (env->temp_board_states[npos] == player) { + if (env->board_states[npos] == player) { union_groups(env->temp_groups, pos, npos); affected_groups[affected_count++] = npos; - } else if (env->temp_board_states[npos] == 3 - player) { + } else if (env->board_states[npos] == 3 - player) { affected_groups[affected_count++] = npos; } } @@ -460,15 +532,15 @@ int make_move(CGo* env, int pos, int player){ // Recalculate liberties only for affected groups for (int i = 0; i < affected_count; i++) { int root = find(env->temp_groups, affected_groups[i]); - env->temp_groups[root].liberties = count_liberties(env, root, queue); + env->temp_groups[root].liberties = count_liberties(env, root, queue, env->board_states); } // Check for captures bool captured = false; for (int i = 0; i < affected_count; i++) { int root = find(env->temp_groups, affected_groups[i]); - if (env->temp_board_states[root] == 3 - player && env->temp_groups[root].liberties == 0) { - capture_group(env, env->temp_board_states, root, affected_groups, &affected_count); + if (env->board_states[root] == 3 - player && env->temp_groups[root].liberties == 0) { + capture_group(env, env->board_states, root, affected_groups, &affected_count); captured = true; } } @@ -476,26 +548,43 @@ int make_move(CGo* env, int pos, int player){ if (captured) { for (int i = 0; i < affected_count; i++) { int root = find(env->temp_groups, affected_groups[i]); - env->temp_groups[root].liberties = count_liberties(env, root, queue); + env->temp_groups[root].liberties = count_liberties(env, root, queue, env->board_states); } // Check for ko rule violation - if(is_ko(env)) { - return 0; - } } + // self capture int root = find(env->temp_groups, pos); if (env->temp_groups[root].liberties == 0) { - return 0; + goto rollback; + } + + if(captured && memcmp(env->board_states, env->previous_board_state, env->grid_size*env->grid_size*sizeof(uint8_t)) == 0){ + goto rollback; } - memcpy(env->board_states, env->temp_board_states, sizeof(int) * (env->grid_size) * (env->grid_size)); + memcpy(env->previous_board_state, env->board_states, sizeof(uint8_t) * (env->grid_size) * (env->grid_size)); memcpy(env->groups, env->temp_groups, sizeof(Group) * (env->grid_size) * (env->grid_size)); + for(int i = 0; i < env->changed_count; i++){ + env->previous_board_state[env->changed_pos[i]] = env->old_board_values[i]; + } return 1; +rollback: + for (int i = 0; i < env->changed_count; i++) { + env->board_states[env->changed_pos[i]] = env->old_board_values[i]; + } + env->capture_count[0] = env->old_capture_count[0]; + env->capture_count[1] = env->old_capture_count[1]; + env->rewards[0] = env->old_reward; + env->log.episode_return = env->old_episode_return; + + if (player == env->side) env->illegal_move_count++; + + return 0; } -void enemy_random_move(CGo* env){ +void enemy_random_move(CGo* env, int side){ int num_positions = (env->grid_size)*(env->grid_size); int positions[num_positions]; int count = 0; @@ -508,27 +597,29 @@ void enemy_random_move(CGo* env){ } // Shuffle the positions for(int i = count - 1; i > 0; i--){ - int j = rand() % (i + 1); + int j = rand_r(&env->rng) % (i + 1); int temp = positions[i]; positions[i] = positions[j]; positions[j] = temp; } // Try to make a move in a random empty position for(int i = 0; i < count; i++){ - if(make_move(env, positions[i], 2)){ + if(make_move(env, positions[i], side)){ + env->previous_move = positions[i] + 1; return; } } // If no move is possible, pass or end the game - env->terminals[0] = 1.0f; + env->previous_move = 0; + env->terminals[0] = 1; } int find_group_liberty(CGo* env, int root){ - reset_visited(env); + increment_version(env); int queue[(env->grid_size)*(env->grid_size)]; int front = 0, rear = 0; queue[rear++] = root; - env->visited[root] = 1; + env->visited[root] = env->current_version; while(front < rear){ int pos = queue[front++]; @@ -544,8 +635,8 @@ int find_group_liberty(CGo* env, int root){ } if(env->board_states[npos] == 0){ return npos; // Found a liberty - } else if(env->board_states[npos] == env->board_states[root] && !env->visited[npos]){ - env->visited[npos] = 1; + } else if(env->board_states[npos] == env->board_states[root] && env->visited[npos] != env->current_version){ + env->visited[npos] = env->current_version; queue[rear++] = npos; } } @@ -553,7 +644,9 @@ int find_group_liberty(CGo* env, int root){ return -1; // Should not happen if liberties > 0 } -void enemy_greedy_hard(CGo* env){ +void enemy_greedy_hard(CGo* env, int side){ + + int opp = 3 - side; // Attempt to capture opponent stones in atari int liberties[4][(env->grid_size) * (env->grid_size)]; int liberty_counts[4] = {0}; @@ -561,14 +654,14 @@ void enemy_greedy_hard(CGo* env){ if(env->board_states[i]==0){ continue; } - if (env->board_states[i]==1){ + if (env->board_states[i]==opp){ int root = find(env->groups, i); int group_liberties = env->groups[root].liberties; if (group_liberties >= 1 && group_liberties <= 4) { int liberty = find_group_liberty(env, root); liberties[group_liberties - 1][liberty_counts[group_liberties - 1]++] = liberty; } - } else if (env->board_states[i]==2){ + } else if (env->board_states[i]==side){ int root = find(env->groups, i); int group_liberties = env->groups[root].liberties; if (group_liberties==1) { @@ -580,16 +673,18 @@ void enemy_greedy_hard(CGo* env){ // make move to attack or defend for (int priority = 0; priority < 4; priority++) { for (int i = 0; i < liberty_counts[priority]; i++) { - if (make_move(env, liberties[priority][i], 2)) { + if (make_move(env, liberties[priority][i], side)) { + env->previous_move = liberties[priority][i]+1; return; } } } + // random move - enemy_random_move(env); + enemy_random_move(env, side); } -void enemy_greedy_easy(CGo* env){ +void enemy_greedy_easy(CGo* env, int side){ // Attempt to capture opponent stones in atari for(int i = 0; i < (env->grid_size)*(env->grid_size); i++){ if(env->board_states[i] != 1){ @@ -618,17 +713,20 @@ void enemy_greedy_easy(CGo* env){ } } // Play a random legal move - enemy_random_move(env); + enemy_random_move(env, side); } void c_reset(CGo* env) { env->tick = 0; + env->illegal_move_count = 0; + env->legal_move_count = 0; + env->pass_move_count = 0; + env->turn = 0; + env->previous_move = -1; // We don't reset the log struct - leave it accumulating like in Pong - env->terminals[0] = 0.0f; env->score = 0; for (int i = 0; i < (env->grid_size)*(env->grid_size); i++) { env->board_states[i] = 0; - env->temp_board_states[i] = 0; env->visited[i] = 0; env->previous_board_state[i] = 0; env->groups[i].parent = i; @@ -643,6 +741,15 @@ void c_reset(CGo* env) { compute_observations(env); } +void clip_rewards(CGo* env){ + if(env->rewards[0] > 1){ + env->rewards[0] = 1; + } + if(env->rewards[0] < -1){ + env->rewards[0] = -1; + } +} + void end_game(CGo* env){ compute_score_tromp_taylor(env); if (env->score > 0) { @@ -654,60 +761,139 @@ void end_game(CGo* env){ else { env->rewards[0] = 0.0; } + //env->rewards[0] = env->score / 10.0f; + clip_rewards(env); + env->terminals[0] = 1; add_log(env); c_reset(env); } +void human_play(CGo* env){ + int indx=1; + if(!env->selfplay || !env->human_play){ + return; + } + if(env->selfplay && env->turn + 1 != env->side){ + env->actions[indx] = -1; + } + if (IsMouseButtonPressed(MOUSE_LEFT_BUTTON)) { + Vector2 mousePos = GetMousePosition(); + + // Calculate the offset for the board + int boardOffsetX = env->grid_square_size; + int boardOffsetY = env->grid_square_size; + + // Adjust mouse position relative to the board + int relativeX = mousePos.x - boardOffsetX; + int relativeY = mousePos.y - boardOffsetY; + + // Calculate cell indices for the corners + int cellX = (relativeX + env->grid_square_size / 2) / env->grid_square_size; + int cellY = (relativeY + env->grid_square_size / 2) / env->grid_square_size; + + // Ensure the click is within the game board + if (cellX >= 0 && cellX <= env->grid_size && cellY >= 0 && cellY <= env->grid_size) { + // Calculate the point index (1-19) based on the click position + int pointIndex = cellY * (env->grid_size) + cellX + 1; + env->actions[indx] = (unsigned short)pointIndex; + } + // Check if pass button is clicked + int left = (env->grid_size + 1)*env->grid_square_size; + int top = env->grid_square_size; + int passButtonX = left; + int passButtonY = top + 90; + int passButtonWidth = 100; + int passButtonHeight = 50; + + if (mousePos.x >= passButtonX && mousePos.x <= passButtonX + passButtonWidth && + mousePos.y >= passButtonY && mousePos.y <= passButtonY + passButtonHeight) { + env->actions[indx] = 0; // Send action 0 for pass + } + } + +} + void c_step(CGo* env) { env->tick += 1; env->rewards[0] = 0.0; - int action = (int)env->actions[0]; + env->terminals[0] = 0; + int action = 0; + int bot_side = 3 - env->side; + int is_legal = 0; + if(env->human_play){ + human_play(env); + } + if(env->selfplay){ + action = (env->turn +1 == env->side) ? (int)env->actions[0] : (int)env->actions[1]; + } else { + action = (int)env->actions[0]; + } + if(action == -1){ + compute_observations(env); + return; + } // useful for training , can prob be a hyper param. Recommend to increase with larger board size float max_moves = 3 * env->grid_size * env->grid_size; - if (env->tick > max_moves) { - env->terminals[0] = 1.0f; + if (env->tick > max_moves && !env->human_play) { + env->terminals[0] = 1; end_game(env); compute_observations(env); return; } - if(action == NOOP){ - env->rewards[0] = env->reward_move_pass; - env->log.episode_return += env->reward_move_pass; - enemy_greedy_hard(env); + // play against bots + if(!env->selfplay && env->turn == (bot_side - 1)){ + enemy_greedy_hard(env, bot_side); if (env->terminals[0] == 1) { end_game(env); + } + compute_observations(env); + clip_rewards(env); + env->turn = (env->turn + 1) % 2; + return; + } + // process action + if(action == NOOP){ + if(env->turn + 1 == env->side){ + //printf("Pass\n"); + env->legal_move_count +=1; + env->rewards[0] = env->reward_move_pass; + env->log.episode_return += env->reward_move_pass; + env->pass_move_count += 1; + } + if (env->terminals[0] == 1 || env->previous_move == NOOP) { + end_game(env); return; } + env->previous_move = NOOP; + env->turn = (env->turn+1)%2; compute_observations(env); return; } if (action >= MOVE_MIN && action <= (env->grid_size)*(env->grid_size)) { - memcpy(env->previous_board_state, env->board_states, sizeof(int) * (env->grid_size) * (env->grid_size)); - if(make_move(env, action-1, 1)) { + is_legal = make_move(env, action - 1, env->turn + 1); + if(is_legal) { env->moves_made++; - env->rewards[0] = env->reward_move_valid; - env->log.episode_return += env->reward_move_valid; - enemy_greedy_hard(env); - + if(env->turn + 1 == env->side){ + env->legal_move_count +=1; + env->rewards[0] += env->reward_move_valid; + env->log.episode_return += env->reward_move_valid; + } } else { - env->rewards[0] = env->reward_move_invalid; - env->log.episode_return += env->reward_move_invalid; + if(env->turn + 1 == env->side){ + env->rewards[0] = env->reward_move_invalid; + env->log.episode_return += env->reward_move_invalid; + } } - compute_observations(env); - } - - if(env->rewards[0] > 1){ - env->rewards[0] = 1; - } - if(env->rewards[0] < -1){ - env->rewards[0] = -1; } + env->previous_move = action; if (env->terminals[0] == 1) { end_game(env); return; } - + if(is_legal){ + env->turn = (env->turn + 1) % 2; + } compute_observations(env); } @@ -728,10 +914,11 @@ Client* make_client(int width, int height) { client->width = width; client->height = height; InitWindow(width, height, "PufferLib Ray Go"); - SetTargetFPS(60); + SetTargetFPS(10); return client; } + void c_render(CGo* env) { if (env->client == NULL) { env->client = make_client(env->width, env->height); @@ -793,14 +980,28 @@ void c_render(CGo* env) { int top = env->grid_square_size; DrawRectangle(left, top + 90, 100, 50, GRAY); DrawText("Pass", left + 25, top + 105, 20, PUFF_WHITE); - - // show capture count for both players DrawText( - TextFormat("Player 1 Capture Count: %d", env->capture_count[0]), + TextFormat("Tick: %d", (int)env->tick), + left, top + 150, 20, PUFF_WHITE + ); + if(env->side == 1){ + DrawText( + TextFormat("Agent: black"), + left, top + 170, 20, PUFF_WHITE + ); + } + else { + DrawText( + TextFormat("Agent: white"), + left, top + 170, 20, PUFF_WHITE + ); + } + DrawText( + TextFormat("Black Capture Count: %d", env->capture_count[0]), left, top, 20, PUFF_WHITE ); DrawText( - TextFormat("Player 2 Capture Count: %d", env->capture_count[1]), + TextFormat("White Capture Count: %d", env->capture_count[1]), left, top + 40, 20, PUFF_WHITE ); EndDrawing();