Skip to content

Commit 634eb94

Browse files
feat(oauth,oidc): add redirect URL (#278)
1 parent 386528c commit 634eb94

6 files changed

Lines changed: 54 additions & 31 deletions

File tree

packages/methods/shield-oauth/src/actions/sign_in.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::{
1818
#[serde(rename_all = "camelCase")]
1919
pub struct SignInData {
2020
pub redirect_origin: Option<Url>,
21+
pub redirect_url: Option<Url>,
2122
}
2223

2324
pub struct OauthSignInAction {
@@ -74,15 +75,26 @@ impl Action<OauthProvider, OauthSession> for OauthSignInAction {
7475
let data = serde_json::from_value::<SignInData>(request.form_data)
7576
.map_err(|err| ShieldError::Validation(err.to_string()))?;
7677

77-
let redirect_origin = if let Some(redirect_origins) = &self.options.redirect_origins
78-
&& let Some(redirect_origin) = data.redirect_origin
79-
// TODO: Consider returning an error when redirect origin is not allowed.
80-
&& redirect_origins.contains(&redirect_origin)
78+
let redirect_url = data.redirect_url.or_else(|| {
79+
data.redirect_origin.and_then(|redirect_origin| {
80+
redirect_origin.join(&self.options.sign_in_redirect).ok()
81+
})
82+
});
83+
84+
if let Some(redirect_url) = &redirect_url
85+
&& let Some(redirect_origins) = &self.options.redirect_origins
8186
{
82-
Some(redirect_origin)
83-
} else {
84-
None
85-
};
87+
let redirect_origin = Url::parse(&redirect_url.origin().ascii_serialization())
88+
.map_err(|err| {
89+
ShieldError::Validation(format!("redirect origin parse error: {err}"))
90+
})?;
91+
92+
if !redirect_origins.contains(&redirect_origin) {
93+
return Err(ShieldError::Validation(format!(
94+
"redirect origin `{redirect_origin}` not allowed"
95+
)));
96+
}
97+
}
8698

8799
let client = provider.oauth_client().await?;
88100

@@ -120,7 +132,7 @@ impl Action<OauthProvider, OauthSession> for OauthSignInAction {
120132
Ok(Response::new(ResponseType::Redirect(auth_url.to_string()))
121133
.session_action(SessionAction::Unauthenticate)
122134
.session_action(SessionAction::data(OauthSession {
123-
redirect_origin,
135+
redirect_url,
124136
csrf: Some(csrf_token.secret().clone()),
125137
pkce_verifier: pkce_code_challenge
126138
.map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()),

packages/methods/shield-oauth/src/actions/sign_in_callback.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,16 @@ impl<U: User + 'static> Action<OauthProvider, OauthSession> for OauthSignInCallb
257257
};
258258

259259
Ok(Response::new(ResponseType::Redirect(
260-
self.options.sign_in_redirect.clone(),
260+
session
261+
.method
262+
.redirect_url
263+
.as_ref()
264+
.map(ToString::to_string)
265+
.unwrap_or_else(|| self.options.sign_in_redirect.clone()),
261266
))
262267
.session_action(SessionAction::authenticate(user))
263268
.session_action(SessionAction::data(OauthSession {
264-
redirect_origin: None,
269+
redirect_url: None,
265270
csrf: None,
266271
pkce_verifier: None,
267272
oauth_connection_id: Some(connection.id),

packages/methods/shield-oauth/src/session.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use url::Url;
33

44
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
55
pub struct OauthSession {
6-
pub redirect_origin: Option<Url>,
6+
pub redirect_url: Option<Url>,
77
pub csrf: Option<String>,
88
pub pkce_verifier: Option<String>,
99
pub oauth_connection_id: Option<String>,

packages/methods/shield-oidc/src/actions/sign_in.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::{
2121
#[serde(rename_all = "camelCase")]
2222
pub struct SignInData {
2323
pub redirect_origin: Option<Url>,
24+
pub redirect_url: Option<Url>,
2425
}
2526

2627
pub struct OidcSignInAction {
@@ -85,15 +86,26 @@ impl Action<OidcProvider, OidcSession> for OidcSignInAction {
8586
let data = serde_json::from_value::<SignInData>(request.form_data)
8687
.map_err(|err| ShieldError::Validation(err.to_string()))?;
8788

88-
let redirect_origin = if let Some(redirect_origins) = &self.options.redirect_origins
89-
&& let Some(redirect_origin) = data.redirect_origin
90-
// TODO: Consider returning an error when redirect origin is not allowed.
91-
&& redirect_origins.contains(&redirect_origin)
89+
let redirect_url = data.redirect_url.or_else(|| {
90+
data.redirect_origin.and_then(|redirect_origin| {
91+
redirect_origin.join(&self.options.sign_in_redirect).ok()
92+
})
93+
});
94+
95+
if let Some(redirect_url) = &redirect_url
96+
&& let Some(redirect_origins) = &self.options.redirect_origins
9297
{
93-
Some(redirect_origin)
94-
} else {
95-
None
96-
};
98+
let redirect_origin = Url::parse(&redirect_url.origin().ascii_serialization())
99+
.map_err(|err| {
100+
ShieldError::Validation(format!("redirect origin parse error: {err}"))
101+
})?;
102+
103+
if !redirect_origins.contains(&redirect_origin) {
104+
return Err(ShieldError::Validation(format!(
105+
"redirect origin `{redirect_origin}` not allowed"
106+
)));
107+
}
108+
}
97109

98110
let client = provider.oidc_client().await?;
99111

@@ -133,7 +145,7 @@ impl Action<OidcProvider, OidcSession> for OidcSignInAction {
133145
Ok(Response::new(ResponseType::Redirect(auth_url.to_string()))
134146
.session_action(SessionAction::unauthenticate())
135147
.session_action(SessionAction::data(OidcSession {
136-
redirect_origin,
148+
redirect_url,
137149
csrf: Some(csrf_token.secret().clone()),
138150
nonce: Some(nonce.secret().clone()),
139151
pkce_verifier: pkce_code_challenge

packages/methods/shield-oidc/src/actions/sign_in_callback.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -293,20 +293,14 @@ impl<U: User + 'static> Action<OidcProvider, OidcSession> for OidcSignInCallback
293293
Ok(Response::new(ResponseType::Redirect(
294294
session
295295
.method
296-
.redirect_origin
296+
.redirect_url
297297
.as_ref()
298-
.and_then(|redirect_origin| {
299-
redirect_origin
300-
.join(&self.options.sign_in_redirect)
301-
.as_ref()
302-
.map(ToString::to_string)
303-
.ok()
304-
})
298+
.map(ToString::to_string)
305299
.unwrap_or_else(|| self.options.sign_in_redirect.clone()),
306300
))
307301
.session_action(SessionAction::authenticate(user))
308302
.session_action(SessionAction::data(OidcSession {
309-
redirect_origin: None,
303+
redirect_url: None,
310304
csrf: None,
311305
nonce: None,
312306
pkce_verifier: None,

packages/methods/shield-oidc/src/session.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use url::Url;
33

44
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
55
pub struct OidcSession {
6-
pub redirect_origin: Option<Url>,
6+
pub redirect_url: Option<Url>,
77
pub csrf: Option<String>,
88
pub nonce: Option<String>,
99
pub pkce_verifier: Option<String>,

0 commit comments

Comments
 (0)