Skip to content

Commit c50990c

Browse files
authored
fix: ListTasks validation and serialization across all transports (#460)
These fixes are needed for the TCK tests in a2aproject/a2a-tck#93 - Fix pageSize=0 validation using hasPageSize() instead of zeroToNull - Add includingDefaultValueFields() for complete JSON responses - Validate enum UNRECOGNIZED, negative timestamps across transports - Support multiple timestamp/status formats in REST - Ensure consistent InvalidParamsError behavior across JSON-RPC, gRPC, REST
1 parent 882f064 commit c50990c

11 files changed

Lines changed: 335 additions & 98 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pom.xml.releaseBackup
55
pom.xml.versionsBackup
66
release.properties
77
.flattened-pom.xml
8+
*.args
89

910
# Eclipse
1011
.project

extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,25 @@ public ListTasksResult list(ListTasksParams params) {
226226
LOGGER.debug("Listing tasks with params: contextId={}, status={}, pageSize={}, pageToken={}",
227227
params.contextId(), params.status(), params.pageSize(), params.pageToken());
228228

229+
// Parse pageToken once at the beginning
230+
Instant tokenTimestamp = null;
231+
String tokenId = null;
232+
if (params.pageToken() != null && !params.pageToken().isEmpty()) {
233+
String[] tokenParts = params.pageToken().split(":", 2);
234+
if (tokenParts.length == 2) {
235+
try {
236+
long timestampMillis = Long.parseLong(tokenParts[0]);
237+
tokenId = tokenParts[1];
238+
tokenTimestamp = Instant.ofEpochMilli(timestampMillis);
239+
} catch (NumberFormatException e) {
240+
throw new io.a2a.spec.InvalidParamsError(null,
241+
"Invalid pageToken format: timestamp must be numeric milliseconds", null);
242+
}
243+
} else {
244+
throw new io.a2a.spec.InvalidParamsError(null, "Invalid pageToken format: expected 'timestamp:id'", null);
245+
}
246+
}
247+
229248
// Build dynamic JPQL query with WHERE clauses for filtering
230249
StringBuilder queryBuilder = new StringBuilder("SELECT t FROM JpaTask t WHERE 1=1");
231250
StringBuilder countQueryBuilder = new StringBuilder("SELECT COUNT(t) FROM JpaTask t WHERE 1=1");
@@ -249,18 +268,9 @@ public ListTasksResult list(ListTasksParams params) {
249268
}
250269

251270
// Apply pagination cursor using keyset pagination for composite sort (timestamp DESC, id ASC)
252-
// PageToken format: "timestamp_millis:taskId" (e.g., "1699999999000:task-123")
253-
if (params.pageToken() != null && !params.pageToken().isEmpty()) {
254-
String[] tokenParts = params.pageToken().split(":", 2);
255-
if (tokenParts.length == 2) {
256-
// Keyset pagination: get tasks where timestamp < tokenTimestamp OR (timestamp = tokenTimestamp AND id > tokenId)
257-
// All tasks have timestamps (TaskStatus canonical constructor ensures this)
258-
queryBuilder.append(" AND (t.statusTimestamp < :tokenTimestamp OR (t.statusTimestamp = :tokenTimestamp AND t.id > :tokenId))");
259-
} else {
260-
// Legacy ID-only pageToken format is not supported with timestamp-based sorting
261-
// Throw error to prevent incorrect pagination results
262-
throw new io.a2a.spec.InvalidParamsError(null, "Invalid pageToken format: expected 'timestamp:id'", null);
263-
}
271+
if (tokenTimestamp != null) {
272+
// Keyset pagination: get tasks where timestamp < tokenTimestamp OR (timestamp = tokenTimestamp AND id > tokenId)
273+
queryBuilder.append(" AND (t.statusTimestamp < :tokenTimestamp OR (t.statusTimestamp = :tokenTimestamp AND t.id > :tokenId))");
264274
}
265275

266276
// Sort by status timestamp descending (most recent first), then by ID for stable ordering
@@ -279,25 +289,9 @@ public ListTasksResult list(ListTasksParams params) {
279289
if (params.lastUpdatedAfter() != null) {
280290
query.setParameter("lastUpdatedAfter", params.lastUpdatedAfter());
281291
}
282-
if (params.pageToken() != null && !params.pageToken().isEmpty()) {
283-
String[] tokenParts = params.pageToken().split(":", 2);
284-
if (tokenParts.length == 2) {
285-
// Parse keyset pagination parameters
286-
try {
287-
long timestampMillis = Long.parseLong(tokenParts[0]);
288-
String tokenId = tokenParts[1];
289-
290-
// All tasks have timestamps (TaskStatus canonical constructor ensures this)
291-
Instant tokenTimestamp = Instant.ofEpochMilli(timestampMillis);
292-
query.setParameter("tokenTimestamp", tokenTimestamp);
293-
query.setParameter("tokenId", tokenId);
294-
} catch (NumberFormatException e) {
295-
// Malformed timestamp in pageToken
296-
throw new io.a2a.spec.InvalidParamsError(null,
297-
"Invalid pageToken format: timestamp must be numeric milliseconds", null);
298-
}
299-
}
300-
// Note: Legacy ID-only format already rejected in query building phase
292+
if (tokenTimestamp != null) {
293+
query.setParameter("tokenTimestamp", tokenTimestamp);
294+
query.setParameter("tokenId", tokenId);
301295
}
302296

303297
// Apply page size limit (+1 to check for next page)
@@ -362,7 +356,10 @@ public ListTasksResult list(ListTasksParams params) {
362356
private Task transformTask(Task task, int historyLength, boolean includeArtifacts) {
363357
// Limit history if needed (keep most recent N messages)
364358
List<Message> history = task.history();
365-
if (historyLength > 0 && history != null && history.size() > historyLength) {
359+
if (historyLength == 0) {
360+
// When historyLength is 0, return empty history
361+
history = List.of();
362+
} else if (historyLength > 0 && history != null && history.size() > historyLength) {
366363
history = history.subList(history.size() - historyLength, history.size());
367364
}
368365

server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,24 @@ public ListTasksResult list(ListTasksParams params) {
9090
int mid = left + (right - left) / 2;
9191
Task task = allFilteredTasks.get(mid);
9292

93-
// All tasks have timestamps (TaskStatus canonical constructor ensures this)
94-
// Truncate to milliseconds for consistency with pageToken precision
95-
java.time.Instant taskTimestamp = task.status().timestamp().toInstant()
96-
.truncatedTo(java.time.temporal.ChronoUnit.MILLIS);
97-
int timestampCompare = taskTimestamp.compareTo(tokenTimestamp);
98-
99-
if (timestampCompare < 0 || (timestampCompare == 0 && task.id().compareTo(tokenId) > 0)) {
100-
// This task is after the token, search left half
101-
right = mid;
102-
} else {
103-
// This task is before or equal to token, search right half
93+
java.time.Instant taskTimestamp = (task.status() != null && task.status().timestamp() != null)
94+
? task.status().timestamp().toInstant().truncatedTo(java.time.temporal.ChronoUnit.MILLIS)
95+
: null;
96+
97+
if (taskTimestamp == null) {
98+
// Task with null timestamp is always "before" a token with a timestamp, as they are sorted last.
99+
// So, we search in the right half.
104100
left = mid + 1;
101+
} else {
102+
int timestampCompare = taskTimestamp.compareTo(tokenTimestamp);
103+
104+
if (timestampCompare < 0 || (timestampCompare == 0 && task.id().compareTo(tokenId) > 0)) {
105+
// This task is after the token, search left half
106+
right = mid;
107+
} else {
108+
// This task is before or equal to token, search right half
109+
left = mid + 1;
110+
}
105111
}
106112
}
107113
startIndex = left;
@@ -144,7 +150,10 @@ public ListTasksResult list(ListTasksParams params) {
144150
private Task transformTask(Task task, int historyLength, boolean includeArtifacts) {
145151
// Limit history if needed (keep most recent N messages)
146152
List<Message> history = task.history();
147-
if (historyLength > 0 && history != null && history.size() > historyLength) {
153+
if (historyLength == 0) {
154+
// When historyLength is 0, return empty history
155+
history = List.of();
156+
} else if (historyLength > 0 && history != null && history.size() > historyLength) {
148157
history = history.subList(history.size() - historyLength, history.size());
149158
}
150159

spec-grpc/src/main/java/io/a2a/grpc/mapper/A2ACommonFieldMapper.java

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,23 @@ default Map<String, Object> metadataFromProto(Struct struct) {
285285
*/
286286
@Named("zeroToNull")
287287
default Integer zeroToNull(int value) {
288-
return value > 0 ? value : null;
288+
return value != 0 ? value : null;
289+
}
290+
291+
/**
292+
* Converts protobuf int to Integer, preserving all values including 0.
293+
* <p>
294+
* Unlike zeroToNull, this method preserves 0 values, allowing compact constructor
295+
* validation to catch invalid values (e.g., pageSize=0 must fail validation).
296+
* For truly optional fields where 0 means "unset", use zeroToNull instead.
297+
* Use this with {@code @Mapping(qualifiedByName = "intToIntegerOrNull")}.
298+
*
299+
* @param value the protobuf int value
300+
* @return Integer (never null for primitive int input)
301+
*/
302+
@Named("intToIntegerOrNull")
303+
default Integer intToIntegerOrNull(int value) {
304+
return value;
289305
}
290306

291307
/**
@@ -337,35 +353,53 @@ default long instantToMillis(Instant instant) {
337353
* Converts protobuf milliseconds-since-epoch (int64) to domain Instant.
338354
* <p>
339355
* Returns null if input is 0 (protobuf default for unset field).
356+
* Throws InvalidParamsError for negative values (invalid timestamps).
340357
* Use this with {@code @Mapping(qualifiedByName = "millisToInstant")}.
341358
*
342359
* @param millis milliseconds since epoch
343360
* @return domain Instant, or null if millis is 0
361+
* @throws InvalidParamsError if millis is negative
344362
*/
345363
@Named("millisToInstant")
346364
default Instant millisToInstant(long millis) {
365+
if (millis < 0L) {
366+
throw new InvalidParamsError(null,
367+
"Timestamp must be a non-negative number of milliseconds since epoch, but got: " + millis,
368+
null);
369+
}
347370
return millis > 0L ? Instant.ofEpochMilli(millis) : null;
348371
}
349372

350373
// ========================================================================
351374
// Enum Conversions (handling UNSPECIFIED/UNKNOWN)
352375
// ========================================================================
353376
/**
354-
* Converts protobuf TaskState to domain TaskState, treating UNSPECIFIED/UNKNOWN as null.
377+
* Converts protobuf TaskState to domain TaskState, treating UNSPECIFIED as null.
355378
* <p>
356-
* Protobuf enums default to UNSPECIFIED (0 value) when unset. The domain may also have
357-
* UNKNOWN for unparseable values. Both should map to null for optional fields.
379+
* Protobuf enums default to UNSPECIFIED (0 value) when unset, which maps to null for optional fields.
380+
* However, UNRECOGNIZED (invalid enum values from JSON) throws InvalidParamsError for proper validation.
358381
* Use this with {@code @Mapping(qualifiedByName = "taskStateOrNull")}.
359382
*
360383
* @param state the protobuf TaskState
361-
* @return domain TaskState or null if UNSPECIFIED/UNKNOWN
384+
* @return domain TaskState or null if UNSPECIFIED
385+
* @throws InvalidParamsError if state is UNRECOGNIZED (invalid enum value)
362386
*/
363387
@Named("taskStateOrNull")
364388
default io.a2a.spec.TaskState taskStateOrNull(io.a2a.grpc.TaskState state) {
365389
if (state == null || state == io.a2a.grpc.TaskState.TASK_STATE_UNSPECIFIED) {
366390
return null;
367391
}
392+
// Reject invalid enum values (e.g., "INVALID_STATUS" from JSON)
393+
if (state == io.a2a.grpc.TaskState.UNRECOGNIZED) {
394+
String validStates = java.util.Arrays.stream(io.a2a.spec.TaskState.values())
395+
.filter(s -> s != io.a2a.spec.TaskState.UNKNOWN)
396+
.map(Enum::name)
397+
.collect(java.util.stream.Collectors.joining(", "));
398+
throw new InvalidParamsError(null,
399+
"Invalid task state value. Must be one of: " + validStates,
400+
null);
401+
}
368402
io.a2a.spec.TaskState result = TaskStateMapper.INSTANCE.fromProto(state);
369-
return result == io.a2a.spec.TaskState.UNKNOWN ? null : result;
403+
return result;
370404
}
371405
}

spec-grpc/src/main/java/io/a2a/grpc/mapper/ListTasksParamsMapper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ public interface ListTasksParamsMapper {
3838
*/
3939
@Mapping(target = "contextId", source = "contextId", qualifiedByName = "emptyToNull")
4040
@Mapping(target = "status", source = "status", qualifiedByName = "taskStateOrNull")
41-
@Mapping(target = "pageSize", source = "pageSize", qualifiedByName = "zeroToNull")
41+
// pageSize: Check if field is set using hasPageSize() to distinguish unset (null) from explicit 0 (validation error)
42+
@Mapping(target = "pageSize", expression = "java(request.hasPageSize() ? request.getPageSize() : null)")
4243
@Mapping(target = "pageToken", source = "pageToken", qualifiedByName = "emptyToNull")
4344
@Mapping(target = "historyLength", source = "historyLength", qualifiedByName = "zeroToNull")
4445
@Mapping(target = "lastUpdatedAfter", source = "lastUpdatedAfter", qualifiedByName = "millisToInstant")

spec-grpc/src/main/java/io/a2a/grpc/utils/JSONRPCUtils.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -458,14 +458,12 @@ private static JsonProcessingException convertProtoBufExceptionToJsonProcessingE
458458
return new JsonProcessingException(ERROR_MESSAGE.formatted("unknown parsing error"));
459459
}
460460

461-
// Extract field name if present in error message
462-
String prefix = "Cannot find field: ";
463-
if (message.contains(prefix)) {
464-
return new InvalidParamsJsonMappingException(ERROR_MESSAGE.formatted(message.substring(message.indexOf(prefix) + prefix.length())), id);
465-
}
466-
prefix = "Invalid value for";
467-
if (message.contains(prefix)) {
468-
return new InvalidParamsJsonMappingException(ERROR_MESSAGE.formatted(message.substring(message.indexOf(prefix) + prefix.length())), id);
461+
// Extract field name if present in error message - check common prefixes
462+
String[] prefixes = {"Cannot find field: ", "Invalid value for", "Invalid enum value:"};
463+
for (String prefix : prefixes) {
464+
if (message.contains(prefix)) {
465+
return new InvalidParamsJsonMappingException(ERROR_MESSAGE.formatted(message.substring(message.indexOf(prefix) + prefix.length())), id);
466+
}
469467
}
470468

471469
// Try to extract specific error details using regex patterns
@@ -548,7 +546,7 @@ public static String toJsonRPCRequest(@Nullable String requestId, String method,
548546
output.name("method").value(method);
549547
}
550548
if (payload != null) {
551-
String resultValue = JsonFormat.printer().omittingInsignificantWhitespace().print(payload);
549+
String resultValue = JsonFormat.printer().includingDefaultValueFields().omittingInsignificantWhitespace().print(payload);
552550
output.name("params").jsonValue(resultValue);
553551
}
554552
output.endObject();
@@ -571,7 +569,7 @@ public static String toJsonRPCResultResponse(Object requestId, com.google.protob
571569
output.name("id").value(number.longValue());
572570
}
573571
}
574-
String resultValue = JsonFormat.printer().omittingInsignificantWhitespace().print(builder);
572+
String resultValue = JsonFormat.printer().includingDefaultValueFields().omittingInsignificantWhitespace().print(builder);
575573
output.name("result").jsonValue(resultValue);
576574
output.endObject();
577575
return result.toString();

spec/src/main/java/io/a2a/spec/ListTasksParams.java

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import java.time.Instant;
44

5-
import io.a2a.util.Assert;
65
import org.jspecify.annotations.Nullable;
76

87
/**
@@ -27,9 +26,12 @@ public record ListTasksParams(
2726
@Nullable Boolean includeArtifacts,
2827
String tenant
2928
) {
29+
private static final int MIN_PAGE_SIZE = 1;
30+
private static final int MAX_PAGE_SIZE = 100;
31+
private static final int DEFAULT_PAGE_SIZE = 50;
3032
/**
3133
* Compact constructor for validation.
32-
* Validates that the tenant parameter is not null.
34+
* Validates that the tenant parameter is not null and parameters are within valid ranges.
3335
*
3436
* @param contextId filter by context ID
3537
* @param status filter by task status
@@ -39,9 +41,24 @@ public record ListTasksParams(
3941
* @param lastUpdatedAfter filter by last update timestamp
4042
* @param includeArtifacts whether to include artifacts
4143
* @param tenant the tenant identifier
44+
* @throws InvalidParamsError if tenant is null or if pageSize or historyLength are out of valid range
4245
*/
4346
public ListTasksParams {
44-
Assert.checkNotNullParam("tenant", tenant);
47+
if (tenant == null) {
48+
throw new InvalidParamsError(null, "Parameter 'tenant' may not be null", null);
49+
}
50+
51+
// Validate pageSize (1-100)
52+
if (pageSize != null && (pageSize < MIN_PAGE_SIZE || pageSize > MAX_PAGE_SIZE)) {
53+
throw new InvalidParamsError(null,
54+
"pageSize must be between " + MIN_PAGE_SIZE + " and " + MAX_PAGE_SIZE + ", got: " + pageSize, null);
55+
}
56+
57+
// Validate historyLength (>= 0)
58+
if (historyLength != null && historyLength < 0) {
59+
throw new InvalidParamsError(null,
60+
"historyLength must be non-negative, got: " + historyLength, null);
61+
}
4562
}
4663
/**
4764
* Default constructor for listing all tasks.
@@ -61,33 +78,23 @@ public ListTasksParams(Integer pageSize, String pageToken) {
6178
}
6279

6380
/**
64-
* Validates and returns the effective page size (between 1 and 100, defaults to 50).
81+
* Returns the effective page size (defaults to 50 if not specified).
82+
* Values are validated in the constructor to be within the range [1, 100].
6583
*
6684
* @return the effective page size
6785
*/
6886
public int getEffectivePageSize() {
69-
if (pageSize == null) {
70-
return 50;
71-
}
72-
if (pageSize < 1) {
73-
return 1;
74-
}
75-
if (pageSize > 100) {
76-
return 100;
77-
}
78-
return pageSize;
87+
return pageSize != null ? pageSize : DEFAULT_PAGE_SIZE;
7988
}
8089

8190
/**
82-
* Returns the effective history length (non-negative, defaults to 0).
91+
* Returns the effective history length (defaults to 0 if not specified).
92+
* Values are validated in the constructor to be non-negative.
8393
*
8494
* @return the effective history length
8595
*/
8696
public int getEffectiveHistoryLength() {
87-
if (historyLength == null || historyLength < 0) {
88-
return 0;
89-
}
90-
return historyLength;
97+
return historyLength != null ? historyLength : 0;
9198
}
9299

93100
/**

0 commit comments

Comments
 (0)