Skip to content

Commit 8973613

Browse files
authored
Merge pull request #518 from l1onh3art88/4.0
4.0 -rware
2 parents c00352e + bae4041 commit 8973613

4 files changed

Lines changed: 59 additions & 38 deletions

File tree

config/go.ini

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
[base]
22
env_name = go
3+
policy_name = MinGRU
4+
rnn_name = Recurrent
5+
6+
[vec]
7+
total_agents = 4096
8+
9+
[policy]
10+
num_layers = 2
11+
hidden_size = 64
312

413
[env]
514
width = 950
@@ -17,9 +26,10 @@ reward_move_valid = 0
1726
reward_move_invalid = -0.5393516480382454
1827
reward_opponent_capture = -0.3152783593705354
1928
reward_player_capture = 0.42122681325442923
29+
selfplay = 0
2030

2131
[train]
22-
total_timesteps = 100_000_000
32+
total_timesteps = 500_000_000
2333
adam_beta1 = 0.5686370767889766
2434
adam_beta2 = 0.9999454817221638
2535
adam_eps = 2.007252656207671e-12
@@ -41,9 +51,9 @@ vtrace_rho_clip = 4.060318960532289
4151
[sweep.train.total_timesteps]
4252
distribution = log_normal
4353
min = 1e8
44-
max = 5e8
45-
mean = 2e8
46-
scale = 0.25
54+
max = 1e9
55+
mean = 3e8
56+
scale = auto
4757

4858
[sweep.env.reward_move_invalid]
4959
distribution = uniform

config/rware.ini

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
[base]
22
env_name = rware
3+
policy_name = MinGRU
4+
rnn_name = Recurrent
35

46
[vec]
5-
num_envs = 4
7+
total_agents = 4096
68

79
[env]
810
num_envs = 256

ocean/rware/binding.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
#define OBS_SIZE 27
33
#define NUM_ATNS 1
44
#define ACT_SIZES {5}
5-
#define OBS_TYPE FLOAT
6-
#define ACT_TYPE DOUBLE
5+
#define OBS_TENSOR_T FloatTensor
76

87
#define Env CRware
98
#include "vecenv.h"

ocean/rware/rware.h

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ struct MovementGraph {
151151
struct CRware {
152152
Client* client;
153153
float* observations;
154-
double* actions;
154+
float* actions;
155155
float* rewards;
156156
float* terminals;
157157
Log* agent_logs;
@@ -172,6 +172,7 @@ struct CRware {
172172
int grid_square_size;
173173
int* original_shelve_locations;
174174
MovementGraph* movement_graph;
175+
unsigned int rng;
175176
};
176177

177178
void add_log(CRware* env, Log* agent_log) {
@@ -197,7 +198,7 @@ void place_agent(CRware* env, int agent_idx) {
197198

198199
int found_valid_position = 0;
199200
while (!found_valid_position) {
200-
int random_pos = rand() % map_size;
201+
int random_pos = rand_r(&env->rng) % map_size;
201202

202203
// Skip if position is not empty
203204
if (env->warehouse_states[random_pos] != EMPTY) {
@@ -213,7 +214,7 @@ void place_agent(CRware* env, int agent_idx) {
213214
// Position is valid, place the agent
214215
env->old_agent_locations[agent_idx] = random_pos;
215216
env->agent_locations[agent_idx] = random_pos;
216-
env->agent_directions[agent_idx] = rand() % 4;
217+
env->agent_directions[agent_idx] = rand_r(&env->rng) % 4;
217218
env->agent_states[agent_idx] = 0;
218219
found_valid_position = 1;
219220
}
@@ -232,7 +233,7 @@ int request_new_shelf(CRware* env) {
232233
total_shelves = 144;
233234
shelf_locations = medium_shelf_locations;
234235
}
235-
int random_index = rand() % total_shelves;
236+
int random_index = rand_r(&env->rng) % total_shelves;
236237
int shelf_location = shelf_locations[random_index];
237238
if (env->warehouse_states[shelf_location] == SHELF ) {
238239
env->warehouse_states[shelf_location] = REQUESTED_SHELF;
@@ -296,7 +297,7 @@ void init(CRware* env) {
296297
void allocate(CRware* env) {
297298
init(env);
298299
env->observations = (float*)calloc(env->num_agents*(SELF_OBS+VISION_OBS), sizeof(float));
299-
env->actions = (double*)calloc(env->num_agents, sizeof(double));
300+
env->actions = (float*)calloc(env->num_agents, sizeof(float));
300301
env->rewards = (float*)calloc(env->num_agents, sizeof(float));
301302
env->terminals = (float*)calloc(env->num_agents, sizeof(float));
302303
}
@@ -529,26 +530,14 @@ void calculate_weights(CRware* env) {
529530
}
530531
}
531532

532-
void update_movement_graph(CRware* env, int agent_idx) {
533+
void reset_movement_graph(CRware* env) {
533534
MovementGraph* graph = env->movement_graph;
534-
int new_position = get_new_position(env, agent_idx);
535-
if (new_position == -1) {
536-
return;
537-
}
538-
graph->target_positions[agent_idx] = new_position;
539-
540-
// reset cycle and weights
541535
for (int i = 0; i < env->num_agents; i++) {
536+
graph->target_positions[i] = -1;
542537
graph->cycle_ids[i] = -1;
543538
graph->weights[i] = 0;
544539
}
545540
graph->num_cycles = 0;
546-
547-
// detect cycles with Floyd algorithm
548-
detect_cycles(env);
549-
550-
// calculate weights for tree
551-
calculate_weights(env);
552541
}
553542

554543
void move_agent(CRware* env, int agent_idx) {
@@ -585,7 +574,6 @@ void move_agent(CRware* env, int agent_idx) {
585574
env->warehouse_states[new_position] = SHELF;
586575
}
587576
env->agent_locations[agent_idx] = new_position;
588-
env->movement_graph->target_positions[agent_idx] = -1;
589577
}
590578

591579
void pickup_shelf(CRware* env, int agent_idx) {
@@ -620,10 +608,24 @@ void pickup_shelf(CRware* env, int agent_idx) {
620608
env->rewards[agent_idx] = 0.5;
621609
env->agent_logs[agent_idx].episode_return += 0.5;
622610
env->agent_logs[agent_idx].score = 1.0;
611+
// Try random selection first, then fall back to linear scan to avoid infinite loop
612+
// when all shelves are currently being carried (warehouse_states == EMPTY).
613+
int total_shelves;
614+
const int* shelf_locations;
615+
if (env->map_choice == 1) { total_shelves = 32; shelf_locations = tiny_shelf_locations; }
616+
else if (env->map_choice == 2) { total_shelves = 80; shelf_locations = small_shelf_locations; }
617+
else { total_shelves = 144; shelf_locations = medium_shelf_locations; }
623618
int shelf_count = 0;
624-
while (shelf_count < 1) {
619+
for (int attempt = 0; attempt < total_shelves && !shelf_count; attempt++) {
625620
shelf_count += request_new_shelf(env);
626621
}
622+
if (shelf_count) return;
623+
for (int i = 0; i < total_shelves && !shelf_count; i++) {
624+
if (env->warehouse_states[shelf_locations[i]] == SHELF) {
625+
env->warehouse_states[shelf_locations[i]] = REQUESTED_SHELF;
626+
shelf_count = 1;
627+
}
628+
}
627629
}
628630
}
629631

@@ -684,30 +686,38 @@ void process_tree_movements(CRware* env, MovementGraph* graph) {
684686
void c_step(CRware* env) {
685687
memset(env->rewards, 0, env->num_agents * sizeof(float));
686688
MovementGraph* graph = env->movement_graph;
689+
690+
// Reset movement graph so stale targets from previous steps don't
691+
// create phantom cycles in Floyd's detection or the weight propagation loop.
692+
reset_movement_graph(env);
693+
694+
int is_movement = 0;
687695
for (int i = 0; i < env->num_agents; i++) {
688696
env->old_agent_locations[i] = env->agent_locations[i];
689697
env->agent_logs[i].episode_length += 1;
690698
int action = (int)env->actions[i];
691-
692-
// Handle direction changes and non-movement actions
699+
693700
if (action != NOOP && action != TOGGLE_LOAD) {
694701
env->agent_directions[i] = get_direction(env, action, i);
695702
}
696703
if (action == TOGGLE_LOAD) {
697704
pickup_shelf(env, i);
698705
}
699706
if (action == FORWARD) {
700-
update_movement_graph(env, i);
707+
int new_pos = get_new_position(env, i);
708+
if (new_pos != -1) {
709+
graph->target_positions[i] = new_pos;
710+
is_movement++;
711+
}
701712
}
702713
}
703-
int is_movement=0;
704-
for(int i=0; i<env->num_agents; i++) {
705-
if ((int)env->actions[i] == FORWARD) is_movement++;
706-
}
707-
if (is_movement>=1) {
708-
// Process movements in cycles first
714+
715+
if (is_movement >= 1) {
716+
// Run cycle detection and weight calculation once with the complete graph.
717+
// Per-agent incremental updates caused stale intermediate state.
718+
detect_cycles(env);
719+
calculate_weights(env);
709720
process_cycle_movements(env, graph);
710-
// process tree movements
711721
process_tree_movements(env, graph);
712722
}
713723

0 commit comments

Comments
 (0)