Skip to content

Commit c2ec25b

Browse files
committed
Refactored PlannerAgent to contrib/planners and started using a test jar
1 parent d7b8aa2 commit c2ec25b

10 files changed

Lines changed: 231 additions & 45 deletions

File tree

contrib/planners/pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@
4141
</dependency>
4242

4343
<!-- Test dependencies -->
44+
<dependency>
45+
<groupId>com.google.adk</groupId>
46+
<artifactId>google-adk</artifactId>
47+
<version>${project.version}</version>
48+
<type>test-jar</type>
49+
<scope>test</scope>
50+
</dependency>
4451
<dependency>
4552
<groupId>org.junit.jupiter</groupId>
4653
<artifactId>junit-jupiter-api</artifactId>

core/src/main/java/com/google/adk/agents/Planner.java renamed to contrib/planners/src/main/java/com/google/adk/agents/Planner.java

File renamed without changes.

core/src/main/java/com/google/adk/agents/PlannerAction.java renamed to contrib/planners/src/main/java/com/google/adk/agents/PlannerAction.java

File renamed without changes.

core/src/main/java/com/google/adk/agents/PlannerAgent.java renamed to contrib/planners/src/main/java/com/google/adk/agents/PlannerAgent.java

File renamed without changes.

core/src/main/java/com/google/adk/agents/PlanningContext.java renamed to contrib/planners/src/main/java/com/google/adk/agents/PlanningContext.java

File renamed without changes.

contrib/planners/src/main/java/com/google/adk/planner/goap/DependencyGraphSearch.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818

1919
import com.google.common.collect.ImmutableList;
2020
import java.util.Collection;
21+
import java.util.HashMap;
2122
import java.util.HashSet;
23+
import java.util.LinkedHashMap;
2224
import java.util.LinkedHashSet;
25+
import java.util.List;
26+
import java.util.Map;
2327
import java.util.Set;
2428

2529
/**
@@ -56,6 +60,75 @@ public static ImmutableList<String> search(
5660
return ImmutableList.copyOf(executionOrder);
5761
}
5862

63+
/**
64+
* Groups agents into parallelizable execution levels.
65+
*
66+
* <p>Each group contains agents whose dependencies are all satisfied by agents in earlier groups
67+
* or by initial preconditions. Agents within the same group are independent and can run in
68+
* parallel.
69+
*
70+
* @param graph the dependency graph
71+
* @param metadata agent metadata used to compute dependency levels
72+
* @param preconditions state keys already available
73+
* @param goal the target output key
74+
* @return ordered list of agent groups; agents within each group can run in parallel
75+
* @throws IllegalStateException if a dependency cannot be resolved or a cycle is detected
76+
*/
77+
public static ImmutableList<ImmutableList<String>> searchGrouped(
78+
GoalOrientedSearchGraph graph,
79+
List<AgentMetadata> metadata,
80+
Collection<String> preconditions,
81+
String goal) {
82+
83+
ImmutableList<String> flatOrder = search(graph, preconditions, goal);
84+
85+
if (flatOrder.isEmpty()) {
86+
return ImmutableList.of();
87+
}
88+
89+
Map<String, AgentMetadata> agentToMeta = new HashMap<>();
90+
for (AgentMetadata m : metadata) {
91+
agentToMeta.put(m.agentName(), m);
92+
}
93+
94+
// Assign execution levels: level = 1 + max(level of dependency agents).
95+
// Agents at the same level have no mutual dependencies and can run in parallel.
96+
Set<String> preconSet = new HashSet<>(preconditions);
97+
Map<String, Integer> agentLevel = new LinkedHashMap<>();
98+
99+
for (String agentName : flatOrder) {
100+
AgentMetadata meta = agentToMeta.get(agentName);
101+
int maxDepLevel = -1;
102+
103+
for (String inputKey : meta.inputKeys()) {
104+
if (preconSet.contains(inputKey)) {
105+
continue;
106+
}
107+
String producerAgent = graph.getProducerAgent(inputKey);
108+
if (producerAgent != null && agentLevel.containsKey(producerAgent)) {
109+
maxDepLevel = Math.max(maxDepLevel, agentLevel.get(producerAgent));
110+
}
111+
}
112+
113+
agentLevel.put(agentName, maxDepLevel + 1);
114+
}
115+
116+
int maxLevel = agentLevel.values().stream().mapToInt(Integer::intValue).max().orElse(0);
117+
ImmutableList.Builder<ImmutableList<String>> groups = ImmutableList.builder();
118+
for (int level = 0; level <= maxLevel; level++) {
119+
final int l = level;
120+
ImmutableList<String> group =
121+
flatOrder.stream()
122+
.filter(name -> agentLevel.get(name) == l)
123+
.collect(ImmutableList.toImmutableList());
124+
if (!group.isEmpty()) {
125+
groups.add(group);
126+
}
127+
}
128+
129+
return groups.build();
130+
}
131+
59132
private static void resolve(
60133
GoalOrientedSearchGraph graph,
61134
String outputKey,

contrib/planners/src/main/java/com/google/adk/planner/goap/GoalOrientedPlanner.java

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import com.google.common.collect.ImmutableList;
2424
import io.reactivex.rxjava3.core.Single;
2525
import java.util.List;
26-
import java.util.concurrent.atomic.AtomicInteger;
2726
import org.slf4j.Logger;
2827
import org.slf4j.LoggerFactory;
2928

@@ -44,7 +43,8 @@
4443
* Agent D: inputs=["person", "horoscope"], output="writeup"
4544
* Goal: "writeup"
4645
*
47-
* Resolved path: A → B → C → D
46+
* Resolved groups: [A, B] → [C] → [D]
47+
* (A and B are independent and run in parallel)
4848
* </pre>
4949
*/
5050
public final class GoalOrientedPlanner implements Planner {
@@ -53,8 +53,9 @@ public final class GoalOrientedPlanner implements Planner {
5353

5454
private final String goal;
5555
private final List<AgentMetadata> metadata;
56-
private ImmutableList<BaseAgent> executionPath;
57-
private final AtomicInteger cursor = new AtomicInteger(0);
56+
// Mutable state — planners are used within a single reactive pipeline and are not thread-safe.
57+
private ImmutableList<ImmutableList<BaseAgent>> executionGroups;
58+
private int cursor;
5859

5960
public GoalOrientedPlanner(String goal, List<AgentMetadata> metadata) {
6061
this.goal = goal;
@@ -64,19 +65,23 @@ public GoalOrientedPlanner(String goal, List<AgentMetadata> metadata) {
6465
@Override
6566
public void init(PlanningContext context) {
6667
GoalOrientedSearchGraph graph = new GoalOrientedSearchGraph(metadata);
67-
ImmutableList<String> agentOrder =
68-
DependencyGraphSearch.search(graph, context.state().keySet(), goal);
68+
ImmutableList<ImmutableList<String>> agentGroups =
69+
DependencyGraphSearch.searchGrouped(graph, metadata, context.state().keySet(), goal);
6970

70-
logger.info("GoalOrientedPlanner resolved execution order: {}", agentOrder);
71+
logger.info("GoalOrientedPlanner resolved execution groups: {}", agentGroups);
7172

72-
executionPath =
73-
agentOrder.stream().map(context::findAgent).collect(ImmutableList.toImmutableList());
74-
cursor.set(0);
73+
executionGroups =
74+
agentGroups.stream()
75+
.map(
76+
group ->
77+
group.stream().map(context::findAgent).collect(ImmutableList.toImmutableList()))
78+
.collect(ImmutableList.toImmutableList());
79+
cursor = 0;
7580
}
7681

7782
@Override
7883
public Single<PlannerAction> firstAction(PlanningContext context) {
79-
cursor.set(0);
84+
cursor = 0;
8085
return selectNext();
8186
}
8287

@@ -86,10 +91,10 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
8691
}
8792

8893
private Single<PlannerAction> selectNext() {
89-
int idx = cursor.getAndIncrement();
90-
if (executionPath == null || idx >= executionPath.size()) {
94+
if (executionGroups == null || cursor >= executionGroups.size()) {
9195
return Single.just(new PlannerAction.Done());
9296
}
93-
return Single.just(new PlannerAction.RunAgents(executionPath.get(idx)));
97+
ImmutableList<BaseAgent> group = executionGroups.get(cursor++);
98+
return Single.just(new PlannerAction.RunAgents(group));
9499
}
95100
}

contrib/planners/src/main/java/com/google/adk/planner/goap/GoalOrientedSearchGraph.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
* <li>Each output key maps to the input keys (dependencies) required to produce it
2929
* </ul>
3030
*
31-
* <p>Used by {@link DependencyGraphSearch} for A* search.
31+
* <p>Used by {@link DependencyGraphSearch} for backward-chaining dependency resolution.
3232
*/
3333
public final class GoalOrientedSearchGraph {
3434

core/src/test/java/com/google/adk/agents/PlannerAgentTest.java renamed to contrib/planners/src/test/java/com/google/adk/agents/PlannerAgentTest.java

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,12 @@
1616

1717
package com.google.adk.agents;
1818

19-
import static com.google.adk.testing.TestUtils.createEvent;
20-
import static com.google.adk.testing.TestUtils.createInvocationContext;
21-
import static com.google.adk.testing.TestUtils.createSubAgent;
2219
import static com.google.common.truth.Truth.assertThat;
2320

2421
import com.google.adk.events.Event;
2522
import com.google.adk.events.EventActions;
2623
import com.google.adk.testing.TestBaseAgent;
24+
import com.google.adk.testing.TestUtils;
2725
import com.google.common.collect.ImmutableList;
2826
import io.reactivex.rxjava3.core.Flowable;
2927
import io.reactivex.rxjava3.core.Single;
@@ -39,7 +37,7 @@ public final class PlannerAgentTest {
3937

4038
@Test
4139
public void runAsync_withDone_stopsImmediately() {
42-
TestBaseAgent subAgent = createSubAgent("sub", createEvent("e1"));
40+
TestBaseAgent subAgent = TestUtils.createSubAgent("sub", TestUtils.createEvent("e1"));
4341
Planner donePlanner =
4442
new Planner() {
4543
@Override
@@ -56,15 +54,15 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
5654
PlannerAgent agent =
5755
PlannerAgent.builder().name("planner").subAgents(subAgent).planner(donePlanner).build();
5856

59-
InvocationContext ctx = createInvocationContext(agent);
57+
InvocationContext ctx = TestUtils.createInvocationContext(agent);
6058
List<Event> events = agent.runAsync(ctx).toList().blockingGet();
6159

6260
assertThat(events).isEmpty();
6361
}
6462

6563
@Test
6664
public void runAsync_withDoneWithResult_emitsResultEvent() {
67-
TestBaseAgent subAgent = createSubAgent("sub");
65+
TestBaseAgent subAgent = TestUtils.createSubAgent("sub");
6866
Planner resultPlanner =
6967
new Planner() {
7068
@Override
@@ -81,7 +79,7 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
8179
PlannerAgent agent =
8280
PlannerAgent.builder().name("planner").subAgents(subAgent).planner(resultPlanner).build();
8381

84-
InvocationContext ctx = createInvocationContext(agent);
82+
InvocationContext ctx = TestUtils.createInvocationContext(agent);
8583
List<Event> events = agent.runAsync(ctx).toList().blockingGet();
8684

8785
assertThat(events).hasSize(1);
@@ -90,8 +88,8 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
9088

9189
@Test
9290
public void runAsync_withNoOp_skipsAndContinues() {
93-
Event event1 = createEvent("e1");
94-
TestBaseAgent subAgent = createSubAgent("sub", event1);
91+
Event event1 = TestUtils.createEvent("e1");
92+
TestBaseAgent subAgent = TestUtils.createSubAgent("sub", event1);
9593

9694
AtomicInteger callCount = new AtomicInteger(0);
9795
Planner noOpThenRunPlanner =
@@ -118,15 +116,16 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
118116
.planner(noOpThenRunPlanner)
119117
.build();
120118

121-
InvocationContext ctx = createInvocationContext(agent);
119+
InvocationContext ctx = TestUtils.createInvocationContext(agent);
122120
List<Event> events = agent.runAsync(ctx).toList().blockingGet();
123121

124122
assertThat(events).containsExactly(event1);
125123
}
126124

127125
@Test
128126
public void runAsync_withMaxIterations_stopsAtLimit() {
129-
TestBaseAgent subAgent = createSubAgent("sub", () -> Flowable.just(createEvent("e")));
127+
TestBaseAgent subAgent =
128+
TestUtils.createSubAgent("sub", () -> Flowable.just(TestUtils.createEvent("e")));
130129

131130
Planner alwaysRunPlanner =
132131
new Planner() {
@@ -149,7 +148,7 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
149148
.maxIterations(3)
150149
.build();
151150

152-
InvocationContext ctx = createInvocationContext(agent);
151+
InvocationContext ctx = TestUtils.createInvocationContext(agent);
153152
List<Event> events = agent.runAsync(ctx).toList().blockingGet();
154153

155154
// 3 iterations: first + 2 next calls, each producing 1 event
@@ -158,12 +157,12 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
158157

159158
@Test
160159
public void runAsync_sequentialPlannerPattern() {
161-
Event event1 = createEvent("e1");
162-
Event event2 = createEvent("e2");
163-
Event event3 = createEvent("e3");
164-
TestBaseAgent agentA = createSubAgent("agentA", event1);
165-
TestBaseAgent agentB = createSubAgent("agentB", event2);
166-
TestBaseAgent agentC = createSubAgent("agentC", event3);
160+
Event event1 = TestUtils.createEvent("e1");
161+
Event event2 = TestUtils.createEvent("e2");
162+
Event event3 = TestUtils.createEvent("e3");
163+
TestBaseAgent agentA = TestUtils.createSubAgent("agentA", event1);
164+
TestBaseAgent agentB = TestUtils.createSubAgent("agentB", event2);
165+
TestBaseAgent agentC = TestUtils.createSubAgent("agentC", event3);
167166

168167
AtomicInteger cursor = new AtomicInteger(0);
169168
ImmutableList<String> order = ImmutableList.of("agentA", "agentB", "agentC");
@@ -196,18 +195,18 @@ private Single<PlannerAction> selectNext(PlanningContext context) {
196195
.planner(seqPlanner)
197196
.build();
198197

199-
InvocationContext ctx = createInvocationContext(agent);
198+
InvocationContext ctx = TestUtils.createInvocationContext(agent);
200199
List<Event> events = agent.runAsync(ctx).toList().blockingGet();
201200

202201
assertThat(events).containsExactly(event1, event2, event3).inOrder();
203202
}
204203

205204
@Test
206205
public void runAsync_withParallelRunAgents_runsMultipleAgents() {
207-
Event event1 = createEvent("e1");
208-
Event event2 = createEvent("e2");
209-
TestBaseAgent agentA = createSubAgent("agentA", event1);
210-
TestBaseAgent agentB = createSubAgent("agentB", event2);
206+
Event event1 = TestUtils.createEvent("e1");
207+
Event event2 = TestUtils.createEvent("e2");
208+
TestBaseAgent agentA = TestUtils.createSubAgent("agentA", event1);
209+
TestBaseAgent agentB = TestUtils.createSubAgent("agentB", event2);
211210

212211
Planner parallelPlanner =
213212
new Planner() {
@@ -229,7 +228,7 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
229228
.planner(parallelPlanner)
230229
.build();
231230

232-
InvocationContext ctx = createInvocationContext(agent);
231+
InvocationContext ctx = TestUtils.createInvocationContext(agent);
233232
List<Event> events = agent.runAsync(ctx).toList().blockingGet();
234233

235234
assertThat(events).containsExactly(event1, event2);
@@ -257,23 +256,23 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
257256
.planner(planner)
258257
.build();
259258

260-
InvocationContext ctx = createInvocationContext(agent);
259+
InvocationContext ctx = TestUtils.createInvocationContext(agent);
261260
List<Event> events = agent.runAsync(ctx).toList().blockingGet();
262261

263262
assertThat(events).isEmpty();
264263
}
265264

266265
@Test(expected = IllegalStateException.class)
267266
public void builder_withoutPlanner_throwsIllegalState() {
268-
TestBaseAgent subAgent = createSubAgent("sub");
267+
TestBaseAgent subAgent = TestUtils.createSubAgent("sub");
269268
PlannerAgent.builder().name("planner").subAgents(subAgent).build();
270269
}
271270

272271
@Test
273272
public void runAsync_stateIsSharedAcrossAgents() {
274273
// Agent A writes to state, Agent B reads from state
275274
Event eventA =
276-
createEvent("eA").toBuilder()
275+
TestUtils.createEvent("eA").toBuilder()
277276
.actions(
278277
EventActions.builder()
279278
.stateDelta(
@@ -282,8 +281,8 @@ public void runAsync_stateIsSharedAcrossAgents() {
282281
.build())
283282
.build();
284283

285-
TestBaseAgent agentA = createSubAgent("agentA", eventA);
286-
TestBaseAgent agentB = createSubAgent("agentB", createEvent("eB"));
284+
TestBaseAgent agentA = TestUtils.createSubAgent("agentA", eventA);
285+
TestBaseAgent agentB = TestUtils.createSubAgent("agentB", TestUtils.createEvent("eB"));
287286

288287
AtomicInteger cursor = new AtomicInteger(0);
289288
Planner seqPlanner =
@@ -314,7 +313,7 @@ public Single<PlannerAction> nextAction(PlanningContext context) {
314313
.planner(seqPlanner)
315314
.build();
316315

317-
InvocationContext ctx = createInvocationContext(agent);
316+
InvocationContext ctx = TestUtils.createInvocationContext(agent);
318317
List<Event> events = agent.runAsync(ctx).toList().blockingGet();
319318

320319
// Both events should be emitted

0 commit comments

Comments
 (0)