Skip to content

Commit b4dd7f6

Browse files
committed
Fixed -> 1. getRequestMetadata calls refreshTrustBoundaryIfExpired without a try catch block. 2. Lock acquiral for refreshFuture.compareAndSet(null, future) now fixed. 3. Oauth2Credentials isn't caching RAB which was earlier leading to serialization issues.
1 parent 8249823 commit b4dd7f6

4 files changed

Lines changed: 185 additions & 20 deletions

File tree

oauth2_http/java/com/google/auth/oauth2/GoogleCredentials.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,12 @@ void refreshRegionalAccessBoundaryWithSelfSignedJwtIfExpired(
440440
public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
441441
Map<String, List<String>> metadata = super.getRequestMetadata(uri);
442442
metadata = addRegionalAccessBoundaryToRequestMetadata(metadata);
443-
refreshRegionalAccessBoundaryIfExpired(uri, getAccessToken(), null);
443+
try {
444+
// Sets off an async refresh for request-metadata.
445+
refreshRegionalAccessBoundaryIfExpired(uri, getAccessToken(), null);
446+
} catch (IOException e) {
447+
// Ignore failure in async refresh trigger.
448+
}
444449
return metadata;
445450
}
446451

@@ -535,8 +540,9 @@ static Map<String, List<String>> addQuotaProjectIdToRequestMetadata(
535540

536541
/**
537542
* Adds Regional Access Boundary header to requestMetadata if available. Overwrites if present.
543+
* If the current RAB is null, it removes any stale header that might have survived serialization.
538544
*
539-
* @return a new map with Regional Access Boundary header added or updated
545+
* @return a new map with Regional Access Boundary header added, updated, or removed
540546
*/
541547
Map<String, List<String>> addRegionalAccessBoundaryToRequestMetadata(
542548
Map<String, List<String>> requestMetadata) {
@@ -550,6 +556,12 @@ Map<String, List<String>> addRegionalAccessBoundaryToRequestMetadata(
550556
RegionalAccessBoundary.X_ALLOWED_LOCATIONS_HEADER_KEY,
551557
Collections.singletonList(rab.getEncodedLocations()));
552558
return ImmutableMap.copyOf(newMetadata);
559+
} else if (requestMetadata.containsKey(RegionalAccessBoundary.X_ALLOWED_LOCATIONS_HEADER_KEY)) {
560+
// If RAB is null but the header exists (e.g., from a serialized cache), we must strip it
561+
// to prevent sending stale data to the server.
562+
Map<String, List<String>> newMetadata = new HashMap<>(requestMetadata);
563+
newMetadata.remove(RegionalAccessBoundary.X_ALLOWED_LOCATIONS_HEADER_KEY);
564+
return ImmutableMap.copyOf(newMetadata);
553565
}
554566
return requestMetadata;
555567
}
@@ -558,13 +570,6 @@ Map<String, List<String>> addRegionalAccessBoundaryToRequestMetadata(
558570
protected Map<String, List<String>> getAdditionalHeaders() {
559571
Map<String, List<String>> headers = new HashMap<>(super.getAdditionalHeaders());
560572

561-
RegionalAccessBoundary rab = regionalAccessBoundaryManager.getCachedRAB();
562-
if (rab != null) {
563-
headers.put(
564-
RegionalAccessBoundary.X_ALLOWED_LOCATIONS_HEADER_KEY,
565-
Collections.singletonList(rab.getEncodedLocations()));
566-
}
567-
568573
String quotaProjectId = this.getQuotaProjectId();
569574
return addQuotaProjectIdToRequestMetadata(quotaProjectId, headers);
570575
}

oauth2_http/java/com/google/auth/oauth2/RegionalAccessBoundaryManager.java

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,26 @@ void triggerAsyncRefresh(
139139
}
140140
};
141141

142-
if (executor != null) {
143-
executor.execute(refreshTask);
144-
} else {
145-
// We use new Thread() here instead of
146-
// CompletableFuture.runAsync() (which uses ForkJoinPool.commonPool()).
147-
// This avoids consuming CPU resources since
148-
// The common pool has a small, fixed number of threads designed for
149-
// CPU-bound tasks.
150-
Thread refreshThread = new Thread(refreshTask, "RAB-refresh-thread");
151-
refreshThread.setDaemon(true);
152-
refreshThread.start();
142+
try {
143+
if (executor != null) {
144+
executor.execute(refreshTask);
145+
} else {
146+
// We use new Thread() here instead of
147+
// CompletableFuture.runAsync() (which uses ForkJoinPool.commonPool()).
148+
// This avoids consuming CPU resources since
149+
// The common pool has a small, fixed number of threads designed for
150+
// CPU-bound tasks.
151+
Thread refreshThread = new Thread(refreshTask, "RAB-refresh-thread");
152+
refreshThread.setDaemon(true);
153+
refreshThread.start();
154+
}
155+
} catch (Exception | Error e) {
156+
// If scheduling fails (e.g., RejectedExecutionException, OutOfMemoryError for threads),
157+
// the task's finally block will never execute. We must release the lock here.
158+
refreshFuture.set(null);
159+
future.completeExceptionally(e);
160+
handleRefreshFailure(
161+
new Exception("Regional Access Boundary background refresh failed to schedule", e));
153162
}
154163
}
155164
}

oauth2_http/javatests/com/google/auth/oauth2/GoogleCredentialsTest.java

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import com.google.api.client.json.GenericJson;
4545
import com.google.api.client.util.Clock;
4646
import com.google.auth.Credentials;
47+
import com.google.auth.RequestMetadataCallback;
4748
import com.google.auth.TestUtils;
4849
import com.google.auth.http.HttpTransportFactory;
4950
import com.google.auth.oauth2.ExternalAccountAuthorizedUserCredentialsTest.MockExternalAccountAuthorizedUserCredentialsTransportFactory;
@@ -56,6 +57,7 @@
5657
import java.util.*;
5758
import java.util.concurrent.atomic.AtomicLong;
5859
import java.util.concurrent.atomic.AtomicReference;
60+
import javax.annotation.Nullable;
5961
import org.junit.Test;
6062
import org.junit.runner.RunWith;
6163
import org.junit.runners.JUnit4;
@@ -801,6 +803,51 @@ public void serialize() throws IOException, ClassNotFoundException {
801803
assertNotNull(deserializedCredentials.regionalAccessBoundaryManager);
802804
}
803805

806+
@Test
807+
public void serialize_removesStaleRabHeaders() throws Exception {
808+
GoogleCredentials.disableRabRefreshForTest = false;
809+
810+
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
811+
RegionalAccessBoundary rab =
812+
new RegionalAccessBoundary(
813+
"test-encoded", Collections.singletonList("test-loc"), System.currentTimeMillis());
814+
transportFactory.transport.setRegionalAccessBoundary(rab);
815+
transportFactory.transport.addServiceAccount(SA_CLIENT_EMAIL, ACCESS_TOKEN);
816+
817+
GoogleCredentials credentials =
818+
new ServiceAccountCredentials.Builder()
819+
.setClientEmail(SA_CLIENT_EMAIL)
820+
.setPrivateKey(OAuth2Utils.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8))
821+
.setPrivateKeyId(SA_PRIVATE_KEY_ID)
822+
.setHttpTransportFactory(transportFactory)
823+
.setScopes(SCOPES)
824+
.build();
825+
826+
// 1. Trigger request metadata to start async RAB refresh
827+
credentials.getRequestMetadata(URI.create("https://foo.com"));
828+
829+
// Wait for the RAB to be fetched and cached
830+
waitForRegionalAccessBoundary(credentials);
831+
832+
// 2. Verify the live credential has the RAB header
833+
Map<String, List<String>> metadata = credentials.getRequestMetadata();
834+
assertEquals(
835+
Collections.singletonList("test-encoded"),
836+
metadata.get(RegionalAccessBoundary.X_ALLOWED_LOCATIONS_HEADER_KEY));
837+
838+
// 3. Serialize and deserialize.
839+
GoogleCredentials deserialized = serializeAndDeserialize(credentials);
840+
841+
// 4. Verify.
842+
// The manager is transient, so it should be empty.
843+
assertNull(deserialized.getRegionalAccessBoundary());
844+
845+
// The metadata should NOT contain the RAB header anymore, preventing stale headers.
846+
Map<String, List<String>> deserializedMetadata = deserialized.getRequestMetadata();
847+
assertNull(
848+
deserializedMetadata.get(RegionalAccessBoundary.X_ALLOWED_LOCATIONS_HEADER_KEY));
849+
}
850+
804851
@Test
805852
public void toString_containsFields() throws IOException {
806853
String expectedToString =
@@ -1173,6 +1220,70 @@ public void regionalAccessBoundary_shouldSkipRefreshForRegionalEndpoints() throw
11731220
assertEquals(0, transport.getRegionalAccessBoundaryRequestCount());
11741221
}
11751222

1223+
@Test
1224+
public void getRequestMetadata_ignoresRabRefreshException() throws IOException {
1225+
GoogleCredentials credentials =
1226+
new GoogleCredentials() {
1227+
@Override
1228+
public AccessToken refreshAccessToken() throws IOException {
1229+
return new AccessToken("token", null);
1230+
}
1231+
1232+
@Override
1233+
void refreshRegionalAccessBoundaryIfExpired(
1234+
@Nullable URI uri,
1235+
@Nullable AccessToken token,
1236+
@Nullable java.util.concurrent.Executor executor)
1237+
throws IOException {
1238+
throw new IOException("Simulated RAB failure");
1239+
}
1240+
};
1241+
1242+
// This should not throw the IOException from refreshRegionalAccessBoundaryIfExpired
1243+
Map<String, List<String>> metadata =
1244+
credentials.getRequestMetadata(URI.create("https://foo.com"));
1245+
assertTrue(metadata.containsKey("Authorization"));
1246+
}
1247+
1248+
@Test
1249+
public void getRequestMetadataAsync_ignoresRabRefreshException() throws IOException {
1250+
GoogleCredentials credentials =
1251+
new GoogleCredentials() {
1252+
@Override
1253+
public AccessToken refreshAccessToken() throws IOException {
1254+
return new AccessToken("token", null);
1255+
}
1256+
1257+
@Override
1258+
void refreshRegionalAccessBoundaryIfExpired(
1259+
@Nullable URI uri,
1260+
@Nullable AccessToken token,
1261+
@Nullable java.util.concurrent.Executor executor)
1262+
throws IOException {
1263+
throw new IOException("Simulated RAB failure");
1264+
}
1265+
};
1266+
1267+
java.util.concurrent.atomic.AtomicBoolean success =
1268+
new java.util.concurrent.atomic.AtomicBoolean(false);
1269+
credentials.getRequestMetadata(
1270+
URI.create("https://foo.com"),
1271+
Runnable::run,
1272+
new RequestMetadataCallback() {
1273+
@Override
1274+
public void onSuccess(Map<String, List<String>> metadata) {
1275+
success.set(true);
1276+
}
1277+
1278+
@Override
1279+
public void onFailure(Throwable exception) {
1280+
fail("Should not have failed");
1281+
}
1282+
});
1283+
1284+
assertTrue(success.get());
1285+
}
1286+
11761287
private GoogleCredentials createTestCredentials(MockTokenServerTransport transport)
11771288
throws IOException {
11781289
transport.addServiceAccount(SA_CLIENT_EMAIL, ACCESS_TOKEN);

oauth2_http/javatests/com/google/auth/oauth2/RegionalAccessBoundaryTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,46 @@ public void testManagerTriggersRefreshInGracePeriod() throws InterruptedExceptio
182182
assertEquals(newerEncoded, resultRab.getEncodedLocations());
183183
}
184184

185+
@Test
186+
public void testManagerReleasesLockOnSchedulingFailure() {
187+
RegionalAccessBoundaryManager manager = new RegionalAccessBoundaryManager();
188+
HttpTransportFactory transportFactory = () -> new MockHttpTransport();
189+
RegionalAccessBoundaryProvider provider = () -> "https://dummy";
190+
AccessToken token =
191+
new AccessToken("token", new java.util.Date(System.currentTimeMillis() + 10 * 3600000L));
192+
193+
java.util.concurrent.Executor rejectingExecutor =
194+
new java.util.concurrent.Executor() {
195+
@Override
196+
public void execute(Runnable command) {
197+
throw new java.util.concurrent.RejectedExecutionException("Simulated rejection");
198+
}
199+
};
200+
201+
manager.triggerAsyncRefresh(transportFactory, provider, token, rejectingExecutor);
202+
203+
// After rejection, the lock should be released, but it should be in cooldown.
204+
assertTrue(manager.isCooldownActive());
205+
206+
// Advance the clock to bypass cooldown
207+
testClock.set(
208+
testClock.currentTimeMillis() + RegionalAccessBoundaryManager.MAX_COOLDOWN_MILLIS + 1000);
209+
assertFalse(manager.isCooldownActive());
210+
211+
// Schedule again with a valid executor to prove the lock was released.
212+
java.util.concurrent.atomic.AtomicBoolean taskRan =
213+
new java.util.concurrent.atomic.AtomicBoolean(false);
214+
java.util.concurrent.Executor workingExecutor =
215+
new java.util.concurrent.Executor() {
216+
@Override
217+
public void execute(Runnable command) {
218+
taskRan.set(true);
219+
}
220+
};
221+
manager.triggerAsyncRefresh(transportFactory, provider, token, workingExecutor);
222+
assertTrue(taskRan.get());
223+
}
224+
185225
private static class TestClock implements Clock {
186226
private final AtomicLong currentTime = new AtomicLong(System.currentTimeMillis());
187227

0 commit comments

Comments
 (0)