Skip to content

Commit 3947f26

Browse files
authored
fix: Event consumer should stop on input_required (#566)
# Description Event consumer should stop on input_required task as fixed in the python version a2aproject/a2a-python#167
1 parent 00e6093 commit 3947f26

2 files changed

Lines changed: 70 additions & 62 deletions

File tree

server-common/src/main/java/io/a2a/server/events/EventConsumer.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.a2a.spec.Event;
77
import io.a2a.spec.Message;
88
import io.a2a.spec.Task;
9+
import io.a2a.spec.TaskState;
910
import io.a2a.spec.TaskStatusUpdateEvent;
1011
import mutiny.zero.BackpressureStrategy;
1112
import mutiny.zero.TubeConfiguration;
@@ -77,7 +78,7 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
7778
} else if (event instanceof Message) {
7879
isFinalEvent = true;
7980
} else if (event instanceof Task task) {
80-
isFinalEvent = task.status().state().isFinal();
81+
isFinalEvent = isStreamTerminatingTask(task);
8182
} else if (event instanceof QueueClosedEvent) {
8283
// Poison pill event - signals queue closure from remote node
8384
// Do NOT send to subscribers - just close the queue
@@ -94,7 +95,7 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
9495
}
9596

9697
if (isFinalEvent) {
97-
LOGGER.debug("Final event detected, closing queue and breaking loop for queue {}", System.identityHashCode(queue));
98+
LOGGER.debug("Final or interrupted event detected, closing queue and breaking loop for queue {}", System.identityHashCode(queue));
9899
queue.close();
99100
LOGGER.debug("Queue closed, breaking loop for queue {}", System.identityHashCode(queue));
100101
break;
@@ -120,6 +121,21 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
120121
});
121122
}
122123

124+
/**
125+
* Determines if a task is in a state for terminating the stream.
126+
* <p>A task is terminating if:</p>
127+
* <ul>
128+
* <li>Its state is final (e.g., completed, canceled, rejected, failed), OR</li>
129+
* <li>Its state is interrupted (e.g., input-required)</li>
130+
* </ul>
131+
* @param task the task to check
132+
* @return true if the task has a final state or an interrupted state, false otherwise
133+
*/
134+
private boolean isStreamTerminatingTask(Task task) {
135+
TaskState state = task.status().state();
136+
return state.isFinal() || state == TaskState.INPUT_REQUIRED;
137+
}
138+
123139
public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() {
124140
return agentRunnable -> {
125141
if (agentRunnable.getError() != null) {

server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java

Lines changed: 52 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -114,32 +114,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException {
114114
final List<Event> receivedEvents = new ArrayList<>();
115115
final AtomicReference<Throwable> error = new AtomicReference<>();
116116

117-
publisher.subscribe(new Flow.Subscriber<>() {
118-
private Flow.Subscription subscription;
119-
120-
@Override
121-
public void onSubscribe(Flow.Subscription subscription) {
122-
this.subscription = subscription;
123-
subscription.request(1);
124-
}
125-
126-
@Override
127-
public void onNext(EventQueueItem item) {
128-
receivedEvents.add(item.getEvent());
129-
subscription.request(1);
130-
131-
}
132-
133-
@Override
134-
public void onError(Throwable throwable) {
135-
error.set(throwable);
136-
}
137-
138-
@Override
139-
public void onComplete() {
140-
subscription.cancel();
141-
}
142-
});
117+
publisher.subscribe(getSubscriber(receivedEvents, error));
143118

144119
assertNull(error.get());
145120
assertEquals(events.size(), receivedEvents.size());
@@ -175,32 +150,7 @@ public void testConsumeUntilMessage() throws Exception {
175150
final List<Event> receivedEvents = new ArrayList<>();
176151
final AtomicReference<Throwable> error = new AtomicReference<>();
177152

178-
publisher.subscribe(new Flow.Subscriber<>() {
179-
private Flow.Subscription subscription;
180-
181-
@Override
182-
public void onSubscribe(Flow.Subscription subscription) {
183-
this.subscription = subscription;
184-
subscription.request(1);
185-
}
186-
187-
@Override
188-
public void onNext(EventQueueItem item) {
189-
receivedEvents.add(item.getEvent());
190-
subscription.request(1);
191-
192-
}
193-
194-
@Override
195-
public void onError(Throwable throwable) {
196-
error.set(throwable);
197-
}
198-
199-
@Override
200-
public void onComplete() {
201-
subscription.cancel();
202-
}
203-
});
153+
publisher.subscribe(getSubscriber(receivedEvents, error));
204154

205155
assertNull(error.get());
206156
assertEquals(3, receivedEvents.size());
@@ -224,7 +174,55 @@ public void testConsumeMessageEvents() throws Exception {
224174
final List<Event> receivedEvents = new ArrayList<>();
225175
final AtomicReference<Throwable> error = new AtomicReference<>();
226176

227-
publisher.subscribe(new Flow.Subscriber<>() {
177+
publisher.subscribe(getSubscriber(receivedEvents, error));
178+
179+
assertNull(error.get());
180+
// The stream is closed after the first Message
181+
assertEquals(1, receivedEvents.size());
182+
assertSame(message, receivedEvents.get(0));
183+
}
184+
185+
@Test
186+
public void testConsumeTaskInputRequired() {
187+
Task task = Task.builder()
188+
.id("task-id")
189+
.contextId("task-context")
190+
.status(new TaskStatus(TaskState.INPUT_REQUIRED))
191+
.build();
192+
List<Event> events = List.of(
193+
task,
194+
TaskArtifactUpdateEvent.builder()
195+
.taskId("task-123")
196+
.contextId("session-xyz")
197+
.artifact(Artifact.builder()
198+
.artifactId("11")
199+
.parts(new TextPart("text"))
200+
.build())
201+
.build(),
202+
TaskStatusUpdateEvent.builder()
203+
.taskId("task-123")
204+
.contextId("session-xyz")
205+
.status(new TaskStatus(TaskState.WORKING))
206+
.isFinal(true)
207+
.build());
208+
for (Event event : events) {
209+
eventQueue.enqueueEvent(event);
210+
}
211+
212+
Flow.Publisher<EventQueueItem> publisher = eventConsumer.consumeAll();
213+
final List<Event> receivedEvents = new ArrayList<>();
214+
final AtomicReference<Throwable> error = new AtomicReference<>();
215+
216+
publisher.subscribe(getSubscriber(receivedEvents, error));
217+
218+
assertNull(error.get());
219+
// The stream is closed after the input_required task
220+
assertEquals(1, receivedEvents.size());
221+
assertSame(task, receivedEvents.get(0));
222+
}
223+
224+
private Flow.Subscriber<EventQueueItem> getSubscriber(List<Event> receivedEvents, AtomicReference<Throwable> error) {
225+
return new Flow.Subscriber<>() {
228226
private Flow.Subscription subscription;
229227

230228
@Override
@@ -237,7 +235,6 @@ public void onSubscribe(Flow.Subscription subscription) {
237235
public void onNext(EventQueueItem item) {
238236
receivedEvents.add(item.getEvent());
239237
subscription.request(1);
240-
241238
}
242239

243240
@Override
@@ -249,12 +246,7 @@ public void onError(Throwable throwable) {
249246
public void onComplete() {
250247
subscription.cancel();
251248
}
252-
});
253-
254-
assertNull(error.get());
255-
// The stream is closed after the first Message
256-
assertEquals(1, receivedEvents.size());
257-
assertSame(message, receivedEvents.get(0));
249+
};
258250
}
259251

260252
@Test

0 commit comments

Comments
 (0)