Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
Expand Down Expand Up @@ -166,6 +168,67 @@ public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() {
assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration);
}

@Test
public void requestWhenNoServletRequestThenAuthorizeAndSendRequest() {
RequestContextHolder.resetRequestAttributes();
InMemoryOAuth2AuthorizedClientService delegate = new InMemoryOAuth2AuthorizedClientService(
this.clientRegistrationRepository);
OAuth2AuthorizedClientService clientService = spy(new OAuth2AuthorizedClientService() {
@Override
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId,
String principal) {
return delegate.loadAuthorizedClient(clientRegistrationId, principal);
}

@Override
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
delegate.saveAuthorizedClient(authorizedClient, principal);
}

@Override
public void removeAuthorizedClient(String clientRegistrationId, String principal) {
delegate.removeAuthorizedClient(clientRegistrationId, principal);
}
});
this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
new AuthorizedClientServiceOAuth2AuthorizedClientManager(this.clientRegistrationRepository,
clientService));
this.webClient = WebClient.builder().apply(this.authorizedClientFilter.oauth2Configuration()).build();

// @formatter:off
String accessTokenResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
String clientResponse = "{\n"
+ " \"attribute1\": \"value1\",\n"
+ " \"attribute2\": \"value2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.tokenUri(this.serverUrl)
.build();
given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId())))
.willReturn(clientRegistration);

this.webClient.get()
.uri(this.serverUrl)
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction
.clientRegistrationId(clientRegistration.getRegistrationId()))
.retrieve()
.bodyToMono(String.class)
.block();
assertThat(this.server.getRequestCount()).isEqualTo(2);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor
.forClass(OAuth2AuthorizedClient.class);
verify(clientService).saveAuthorizedClient(authorizedClientCaptor.capture(), eq(this.authentication));
assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration);
}

@Test
public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() {
// @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;

Expand Down Expand Up @@ -134,6 +136,9 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private OAuth2AuthorizedClientRepository authorizedClientRepository;

@Mock
private OAuth2AuthorizedClientService oAuth2AuthorizedClientService;

@Mock
private ClientRegistrationRepository clientRegistrationRepository;

Expand Down Expand Up @@ -661,6 +666,41 @@ public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalRe
authentication, servletRequest);
}

@Test
public void filterWhenServletRequestNullClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
new AuthorizedClientServiceOAuth2AuthorizedClientManager(this.clientRegistrationRepository,
oAuth2AuthorizedClientService));
this.function.setDefaultOAuth2AuthorizedClient(true);
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, authorities,
"initial-registration-id");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, authorities,
this.registration.getRegistrationId());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken);
given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(this.registration);
given(this.oAuth2AuthorizedClientService.loadAuthorizedClient(this.registration.getRegistrationId(),
initialAuthentication.getName()))
.willReturn(authorizedClient);
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.build();
this.function.setPrincipalResolver((request) -> authentication);
this.function.filter(clientRequest, this.exchange)
.contextWrite(context(null, null, initialAuthentication))
.block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
verify(this.oAuth2AuthorizedClientService).loadAuthorizedClient(this.registration.getRegistrationId(),
authentication.getName());
}

@Test
public void filterWhenUnauthorizedThenInvokeFailureHandler() {
assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);
Expand Down