Skip to content

Commit d452013

Browse files
authored
Merge pull request #16 from NESTLab/feat/gsp-b-broadcast
feat(agent): add GSP-B (full-broadcast) variant
2 parents ceaa2f6 + 2b1f2f2 commit d452013

3 files changed

Lines changed: 203 additions & 11 deletions

File tree

rl_code/Main.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
'recurrent': config['RECURRENT'],
112112
'attention': config['ATTENTION'],
113113
'neighbors': config['NEIGHBORS'],
114+
'broadcast': config.get('BROADCAST', False),
114115
'gsp_input_size':config['GSP_INPUT_SIZE'],
115116
'gsp_output_size':config['GSP_OUTPUT_SIZE'],
116117
'gsp_look_back':config['GSP_LOOK_BACK'],
@@ -362,6 +363,11 @@
362363
if model.gsp_neighbors:
363364
agent_gsp_states = model.make_gsp_states(agent_prox_flags, old_heading_gsp)
364365
ctde_gsp = model.choose_agent_gsp(agent_gsp_states, test_mode)
366+
elif model.gsp_broadcast:
367+
# GSP-B: per-agent self-centric view with full-broadcast
368+
# [self_prox, self_prev_gsp, other_i_prox, other_i_prev_gsp, ...]
369+
agent_gsp_states = model.make_gsp_states_broadcast(agent_prox_flags, old_heading_gsp)
370+
ctde_gsp = model.choose_agent_gsp(agent_gsp_states, test_mode)
365371
else:
366372
ctde_gsp = model.choose_agent_gsp(agent_prox_flags, test_mode)
367373
for i in range(Utility.params['num_robots']):
@@ -377,20 +383,27 @@
377383
states, state_prox_flags = model.make_gsp_states(old_agent_prox_flags, neighbors_old_heading_gsp, True)
378384
new_states = model.make_gsp_states(agent_prox_flags, old_heading_gsp)
379385
for i in range(Utility.params['num_robots']):
380-
# print(f'[AGENT] {i} PROX FLAGS:', state_prox_flags[i])
381-
# only store if state has value
382386
if np.sum(state_prox_flags[i]) > 0:
383-
# print(f'[AGENT] {i} Has Value, Storing GSP State: {states[i]}')
384387
if model.gsp_networks['learning_scheme'] == 'attention':
385388
model.store_gsp_transition(states[i], label, 0, 0, 0)
386389
else:
387-
# Under the direct-MSE GSP training path, the 2nd arg
388-
# (action field) carries the supervised target label.
389-
# See GSP-RL fix/gsp-direct-mse-training PR #24 and
390-
# Stelaris docs/research/2026-04-13-gsp-information-collapse-analysis.md.
390+
# 2nd arg = label (supervised target for direct-MSE GSP training)
391391
state = states[i]
392392
new_state = new_states[i]
393393
model.store_gsp_transition(state, label, 0, new_state, 0)
394+
elif model.gsp_broadcast:
395+
# GSP-B per-agent storage with broadcast inputs.
396+
# state_t : broadcast view at previous step (uses neighbors_old_heading_gsp so
397+
# the prev_gsp slot reflects the prediction from the previous tick)
398+
# state_{t+1}: broadcast view at current step
399+
states = model.make_gsp_states_broadcast(old_agent_prox_flags, neighbors_old_heading_gsp)
400+
new_states = model.make_gsp_states_broadcast(agent_prox_flags, old_heading_gsp)
401+
for i in range(Utility.params['num_robots']):
402+
# Gate on self-prox being non-zero so we only store informative transitions,
403+
# matching the GSP and GSP-N branches. Self-prox lives at index 0 under the
404+
# self-first layout.
405+
if states[i][0] != 0:
406+
model.store_gsp_transition(states[i], label, 0, new_states[i], 0)
394407
else:
395408
for i in range(Utility.params['num_robots']):
396409
if model.gsp_networks['learning_scheme'] == 'attention':

rl_code/src/agent.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,25 @@ def __init__(
3232
gsp_min_max_action: float,
3333
gsp_look_back: int,
3434
gsp_sequence_length: int,
35+
broadcast: bool = False,
3536
prox_filter_angle_deg: float = 45.0,
3637
n_hop_neighbors: int = 1,
3738
):
39+
if neighbors and broadcast:
40+
raise ValueError(
41+
"GSP variants neighbors=True and broadcast=True are mutually exclusive — "
42+
"they overload gsp_input_size differently. Pick one."
43+
)
3844
if neighbors:
3945
# 2 inputs from ownship (prev_gsp, avg_prox)
4046
# 2 inputs from each neighbor (prev_gsp, avg_prox)
4147
# 2*n_hop_neighbors for symmetry in both CW and CCW
42-
gsp_input_size = 2+2*(n_hop_neighbors*2)
48+
gsp_input_size = 2+2*(n_hop_neighbors*2)
49+
if broadcast:
50+
# GSP-B: each agent's view is (self_prox, self_prev_gsp) + (other_prox, other_prev_gsp)
51+
# for all (n_agents - 1) other agents. Total 2*n_agents. Known limitation:
52+
# coupled to team size, not transferable across num_robots.
53+
gsp_input_size = 2 * n_agents
4354

4455
output_size = n_actions
4556
if network in ['DQN', 'DDQN']:
@@ -68,13 +79,16 @@ def __init__(
6879
self._network = network
6980
self._n_actions = n_actions
7081
self._neighbors = neighbors
82+
self._broadcast = broadcast
7183
self._n_hop_neighbors = n_hop_neighbors
7284
self.neighbors_dict = {}
7385
self._options_per_action = options_per_action
7486
self._prox_filter_angle_deg = prox_filter_angle_deg
7587

7688

77-
if self._neighbors:
89+
if self._neighbors or self._broadcast:
90+
# Per-agent observation ring buffers: GSP-N and GSP-B both produce
91+
# per-agent self-centric views, so each agent has its own history.
7892
self.gsp_observation = []
7993
for _ in range(self._n_agents):
8094
self.gsp_observation.append([[0 for _ in range(self.gsp_network_input)] for _ in range(self.gsp_sequence_length)])
@@ -98,6 +112,10 @@ def __init__(
98112
def gsp_neighbors(self):
99113
return self._neighbors
100114

115+
@property
116+
def gsp_broadcast(self):
117+
return self._broadcast
118+
101119
@property
102120
def n_agents(self):
103121
return self._n_agents
@@ -155,6 +173,40 @@ def make_agent_state(self, env_obs, heading_gsp=None, global_knowledge=None):
155173
env_obs = np.concatenate((env_obs, global_knowledge))
156174
return env_obs
157175

176+
def make_gsp_states_broadcast(self, agent_prox_values, agent_prev_gsp):
177+
"""Build per-agent GSP inputs for GSP-B (full-broadcast variant).
178+
179+
Each agent's view is self-first: [self_prox, self_prev_gsp, other_0_prox,
180+
    other_0_prev_gsp, other_1_prox, other_1_prev_gsp, ..., other_{n-2}_prox,
181+
    other_{n-2}_prev_gsp] (with n = n_agents, there are n-1 "other" pairs). "other" iterates all agents in ascending id order,
182+
skipping self. Total length = 2 * n_agents.
183+
184+
Known limitation: the network input size is coupled to n_agents, so a
185+
trained GSP-B policy does not transfer to teams of different size. This
186+
is the tradeoff vs GSP-N, which uses fixed (self + n_hop_neighbors * 2)
187+
inputs and transfers across team sizes.
188+
"""
189+
states = []
190+
for agent in range(self._n_agents):
191+
agent_state = np.zeros(self.gsp_network_input)
192+
# Self first
193+
agent_state[0] = agent_prox_values[agent]
194+
agent_state[1] = agent_prev_gsp[agent]
195+
i = 2
196+
# Then every other agent in ascending id order, skipping self
197+
for other in range(self._n_agents):
198+
if other == agent:
199+
continue
200+
agent_state[i] = agent_prox_values[other]
201+
agent_state[i + 1] = agent_prev_gsp[other]
202+
i += 2
203+
# Maintain gsp_observation ring buffer the same way make_gsp_states does,
204+
# so recurrent/attention variants can still see sequences if added later.
205+
self.gsp_observation[agent].pop(0)
206+
self.gsp_observation[agent].append(agent_state)
207+
states.append(agent_state)
208+
return states
209+
158210
def make_gsp_states(self, agent_prox_values, agent_prev_gsp, return_prox_flags = False):
159211
states = []
160212
prox_flags = []
@@ -242,7 +294,11 @@ def choose_agent_action(self, observation, failures, test=False):
242294
return actions, action_num
243295

244296
def choose_agent_gsp(self, agent_gsp_states, test = False):
245-
if self._neighbors:
297+
if self._neighbors or self._broadcast:
298+
# Per-agent predictions with self-centric inputs. GSP-N (neighbors)
299+
# and GSP-B (broadcast) share the same per-agent forward-pass shape;
300+
# only the input vector differs. Non-recurrent broadcast uses the
301+
# same stateless path as non-recurrent neighbors.
246302
actions = []
247303
for i in range(self._n_agents):
248304
if self.recurrent_gsp:
@@ -257,7 +313,7 @@ def choose_agent_gsp(self, agent_gsp_states, test = False):
257313
)
258314
# Take the last timestep's action
259315
actions.append(action_tensor[-1].cpu().detach().numpy())
260-
else:
316+
else:
261317
actions.append(self.choose_action(agent_gsp_states[i], self.gsp_networks, test))
262318
return actions
263319
else:
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""Tests for GSP-B (full-broadcast variant) state construction.
2+
3+
GSP-B: each agent's input is [self_prox, self_prev_gsp, other_0_prox,
4+
other_0_prev_gsp, other_1_prox, other_1_prev_gsp, ..., other_{n-2}_prox,
5+
other_{n-2}_prev_gsp] where n = n_agents (n-1 "other" pairs), length 2*n_agents. Self-first ordering.
6+
7+
Known limitation (inherited from plain GSP): the network input size is
8+
coupled to n_agents, so a trained GSP-B policy only transfers to the same
9+
team size. This is the tradeoff vs GSP-N's fixed (self + n_hop_neighbors)
10+
input which transfers across team sizes.
11+
"""
12+
13+
import numpy as np
14+
import pytest
15+
16+
from src.agent import Agent
17+
18+
19+
BASE_CONFIG = {
20+
"GAMMA": 0.99, "TAU": 0.005, "ALPHA": 0.001, "BETA": 0.002, "LR": 0.0001,
21+
"EPSILON": 0.0, "EPS_MIN": 0.0, "EPS_DEC": 0.0,
22+
"BATCH_SIZE": 16, "MEM_SIZE": 1000, "REPLACE_TARGET_COUNTER": 10,
23+
"NOISE": 0.0, "UPDATE_ACTOR_ITER": 1, "WARMUP": 0,
24+
"GSP_LEARNING_FREQUENCY": 1, "GSP_BATCH_SIZE": 16,
25+
}
26+
27+
28+
def make_agent(n_agents=4, network="DDQN", broadcast=True):
29+
return Agent(
30+
config=BASE_CONFIG,
31+
network=network,
32+
n_agents=n_agents,
33+
n_obs=8,
34+
n_actions=4,
35+
options_per_action=3,
36+
id=0,
37+
min_max_action=1.0,
38+
meta_param_size=1,
39+
gsp=True,
40+
recurrent=False,
41+
attention=False,
42+
neighbors=False,
43+
broadcast=broadcast,
44+
gsp_input_size=4, # overridden when broadcast=True
45+
gsp_output_size=1,
46+
gsp_min_max_action=1.0,
47+
gsp_look_back=2,
48+
gsp_sequence_length=5,
49+
)
50+
51+
52+
def test_broadcast_agent_has_gsp_broadcast_property_true():
53+
agent = make_agent()
54+
assert agent.gsp_broadcast is True
55+
56+
57+
def test_broadcast_agent_gsp_input_size_is_two_times_n_agents():
58+
"""For 4 agents, the broadcast input is the self (prox, prev_gsp) pair plus 3 other pairs: 2 + 3*2 = 8."""
59+
agent = make_agent(n_agents=4)
60+
assert agent.gsp_network_input == 8
61+
62+
63+
def test_broadcast_agent_gsp_input_size_scales_with_n_agents():
64+
"""For 8 agents, input is 16. Known limitation: coupled to team size."""
65+
agent = make_agent(n_agents=8)
66+
assert agent.gsp_network_input == 16
67+
68+
69+
def test_make_gsp_states_broadcast_returns_one_state_per_agent():
70+
agent = make_agent(n_agents=4)
71+
prox = [0.1, 0.2, 0.3, 0.4]
72+
prev_gsp = [-0.5, 0.0, 0.25, 0.75]
73+
states = agent.make_gsp_states_broadcast(prox, prev_gsp)
74+
assert len(states) == 4
75+
for s in states:
76+
assert len(s) == 8
77+
78+
79+
def test_make_gsp_states_broadcast_self_first_ordering():
80+
"""For each agent i, the first two entries must be (prox[i], prev_gsp[i])."""
81+
agent = make_agent(n_agents=4)
82+
prox = [0.11, 0.22, 0.33, 0.44]
83+
prev_gsp = [-0.1, -0.2, -0.3, -0.4]
84+
states = agent.make_gsp_states_broadcast(prox, prev_gsp)
85+
for i in range(4):
86+
assert states[i][0] == pytest.approx(prox[i]), f"agent {i} self_prox"
87+
assert states[i][1] == pytest.approx(prev_gsp[i]), f"agent {i} self_prev_gsp"
88+
89+
90+
def test_make_gsp_states_broadcast_others_in_order():
91+
"""After the self-pair, the remaining entries are other agents in ascending id order (skipping self)."""
92+
agent = make_agent(n_agents=4)
93+
prox = [0.10, 0.20, 0.30, 0.40]
94+
prev_gsp = [0.01, 0.02, 0.03, 0.04]
95+
states = agent.make_gsp_states_broadcast(prox, prev_gsp)
96+
# Agent 0: self=0, others=[1, 2, 3]
97+
assert list(states[0]) == pytest.approx([0.10, 0.01, 0.20, 0.02, 0.30, 0.03, 0.40, 0.04])
98+
# Agent 2: self=2, others=[0, 1, 3]
99+
assert list(states[2]) == pytest.approx([0.30, 0.03, 0.10, 0.01, 0.20, 0.02, 0.40, 0.04])
100+
# Agent 3: self=3, others=[0, 1, 2]
101+
assert list(states[3]) == pytest.approx([0.40, 0.04, 0.10, 0.01, 0.20, 0.02, 0.30, 0.03])
102+
103+
104+
def test_broadcast_is_mutually_exclusive_with_neighbors():
105+
"""Can't have both neighbors=True and broadcast=True; they overload gsp_input_size."""
106+
with pytest.raises((ValueError, AssertionError)):
107+
Agent(
108+
config=BASE_CONFIG,
109+
network="DDQN", n_agents=4, n_obs=8, n_actions=4,
110+
options_per_action=3, id=0, min_max_action=1.0, meta_param_size=1,
111+
gsp=True, recurrent=False, attention=False,
112+
neighbors=True, broadcast=True,
113+
gsp_input_size=4, gsp_output_size=1,
114+
gsp_min_max_action=1.0, gsp_look_back=2, gsp_sequence_length=5,
115+
)
116+
117+
118+
def test_plain_gsp_without_broadcast_unchanged():
119+
"""Plain GSP (neighbors=False, broadcast=False) keeps the legacy input size."""
120+
agent = make_agent(broadcast=False)
121+
# Should fall through to the config-provided gsp_input_size=4
122+
assert agent.gsp_network_input == 4
123+
assert agent.gsp_broadcast is False

0 commit comments

Comments
 (0)