Skip to content

Commit b590bdc

Browse files
committed
feat: Fix ClientSideCredentialAccessBoundary race condition when multiple concurrent calls are made to generateToken.
1 parent 0f92593 commit b590bdc

2 files changed

Lines changed: 54 additions & 21 deletions

File tree

cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ void refreshCredentialsIfRequired() throws IOException {
248248
}
249249
try {
250250
// Wait for the refresh task to complete.
251-
currentRefreshTask.task.get();
251+
currentRefreshTask.get();
252252
} catch (InterruptedException e) {
253253
// Restore the interrupted status and throw an exception.
254254
Thread.currentThread().interrupt();
@@ -495,31 +495,17 @@ class RefreshTask extends AbstractFuture<IntermediateCredentials> implements Run
495495
this.task = task;
496496
this.isNew = isNew;
497497

498-
// Add listener to update factory's credentials when the task completes.
498+
// Single listener to guarantee that finishRefreshTask updates the internal state BEFORE
499+
// the outer future completes and unblocks waiters.
499500
task.addListener(
500501
() -> {
501502
try {
502503
finishRefreshTask(task);
504+
RefreshTask.this.set(Futures.getDone(task));
503505
} catch (ExecutionException e) {
504-
Throwable cause = e.getCause();
505-
RefreshTask.this.setException(cause);
506-
}
507-
},
508-
MoreExecutors.directExecutor());
509-
510-
// Add callback to set the result or exception based on the outcome.
511-
Futures.addCallback(
512-
task,
513-
new FutureCallback<IntermediateCredentials>() {
514-
@Override
515-
public void onSuccess(IntermediateCredentials result) {
516-
RefreshTask.this.set(result);
517-
}
518-
519-
@Override
520-
public void onFailure(@Nullable Throwable t) {
521-
RefreshTask.this.setException(
522-
t != null ? t : new IOException("Refresh failed with null Throwable."));
506+
RefreshTask.this.setException(e.getCause());
507+
} catch (Exception e) {
508+
RefreshTask.this.setException(e);
523509
}
524510
},
525511
MoreExecutors.directExecutor());

cab-token-generator/javatests/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactoryTest.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,4 +988,51 @@ void generateToken_withMalformSessionKey_failure() throws Exception {
988988

989989
assertThrows(GeneralSecurityException.class, () -> factory.generateToken(accessBoundary));
990990
}
991+
992+
@org.junit.jupiter.api.Test
993+
void generateToken_freshInstance_concurrent_noNpe() throws Exception {
994+
for (int run = 0; run < 10; run++) { // Run 10 times in a single test instance to save time
995+
GoogleCredentials sourceCredentials = getServiceAccountSourceCredentials(mockTokenServerTransportFactory);
996+
ClientSideCredentialAccessBoundaryFactory factory = ClientSideCredentialAccessBoundaryFactory.newBuilder()
997+
.setSourceCredential(sourceCredentials)
998+
.setHttpTransportFactory(mockStsTransportFactory)
999+
.build();
1000+
1001+
CredentialAccessBoundary.Builder cabBuilder = CredentialAccessBoundary.newBuilder();
1002+
CredentialAccessBoundary accessBoundary = cabBuilder
1003+
.addRule(
1004+
CredentialAccessBoundary.AccessBoundaryRule.newBuilder()
1005+
.setAvailableResource("resource")
1006+
.setAvailablePermissions(ImmutableList.of("role"))
1007+
.build())
1008+
.build();
1009+
1010+
int numThreads = 5;
1011+
Thread[] threads = new Thread[numThreads];
1012+
CountDownLatch latch = new CountDownLatch(numThreads);
1013+
java.util.concurrent.atomic.AtomicInteger npeCount = new java.util.concurrent.atomic.AtomicInteger();
1014+
1015+
for (int i = 0; i < numThreads; i++) {
1016+
threads[i] = new Thread(() -> {
1017+
try {
1018+
latch.countDown();
1019+
latch.await();
1020+
factory.generateToken(accessBoundary);
1021+
} catch (NullPointerException e) {
1022+
npeCount.incrementAndGet();
1023+
} catch (Exception e) {
1024+
// Ignore other exceptions for the sake of the race reproduction
1025+
}
1026+
});
1027+
threads[i].start();
1028+
}
1029+
1030+
for (Thread thread : threads) {
1031+
thread.join();
1032+
}
1033+
1034+
org.junit.jupiter.api.Assertions.assertEquals(0, npeCount.get(),
1035+
"Expected zero NullPointerExceptions due to the race condition, but some were thrown.");
1036+
}
1037+
}
9911038
}

0 commit comments

Comments
 (0)