|
1 | 1 | #include "go.h" |
2 | | -#define OBS_SIZE 100 |
| 2 | +// 9x9 - obs 326, act 82 (the values set below) |
| 3 | +// 13x13 - obs 678, act 170 |
| 4 | +// 19x19 - obs 1446, act 362 (every row fits obs = 4*N*N + 2, act = N*N + 1) |
| 5 | +#define OBS_SIZE 326 |
3 | 6 | #define NUM_ATNS 1 |
4 | | -#define ACT_SIZES {50} |
5 | | -#define OBS_TYPE FLOAT |
6 | | -#define ACT_TYPE DOUBLE |
| 7 | +#define ACT_SIZES {82} |
| 8 | +#define OBS_TENSOR_T FloatTensor |
7 | 9 |
|
8 | 10 | #define Env CGo |
9 | 11 | #include "vecenv.h" |
10 | 12 |
|
11 | 13 | void my_init(Env* env, Dict* kwargs) { |
12 | 14 | env->num_agents = 1; |
| 15 | + env->side = (rand_r(&env->rng) % 2) + 1; |
| 16 | + env->selfplay = dict_get(kwargs, "selfplay")->value; |
13 | 17 | env->width = dict_get(kwargs, "width")->value; |
14 | 18 | env->height = dict_get(kwargs, "height")->value; |
15 | 19 | env->grid_size = dict_get(kwargs, "grid_size")->value; |
16 | 20 | env->board_width = dict_get(kwargs, "board_width")->value; |
17 | 21 | env->board_height = dict_get(kwargs, "board_height")->value; |
18 | 22 | env->grid_square_size = dict_get(kwargs, "grid_square_size")->value; |
19 | | - env->moves_made = dict_get(kwargs, "moves_made")->value; |
20 | 23 | env->komi = dict_get(kwargs, "komi")->value; |
21 | | - env->score = dict_get(kwargs, "score")->value; |
22 | | - env->last_capture_position = dict_get(kwargs, "last_capture_position")->value; |
23 | 24 | env->reward_move_pass = dict_get(kwargs, "reward_move_pass")->value; |
24 | 25 | env->reward_move_invalid = dict_get(kwargs, "reward_move_invalid")->value; |
25 | 26 | env->reward_move_valid = dict_get(kwargs, "reward_move_valid")->value; |
|
0 commit comments