@@ -18,6 +18,7 @@ use crate::{
1818#[ serde( rename_all = "camelCase" ) ]
1919pub struct SignInData {
2020 pub redirect_origin : Option < Url > ,
21+ pub redirect_url : Option < Url > ,
2122}
2223
2324pub struct OauthSignInAction {
@@ -74,15 +75,26 @@ impl Action<OauthProvider, OauthSession> for OauthSignInAction {
7475 let data = serde_json:: from_value :: < SignInData > ( request. form_data )
7576 . map_err ( |err| ShieldError :: Validation ( err. to_string ( ) ) ) ?;
7677
77- let redirect_origin = if let Some ( redirect_origins) = & self . options . redirect_origins
78- && let Some ( redirect_origin) = data. redirect_origin
79- // TODO: Consider returning an error when redirect origin is not allowed.
80- && redirect_origins. contains ( & redirect_origin)
78+ let redirect_url = data. redirect_url . or_else ( || {
79+ data. redirect_origin . and_then ( |redirect_origin| {
80+ redirect_origin. join ( & self . options . sign_in_redirect ) . ok ( )
81+ } )
82+ } ) ;
83+
84+ if let Some ( redirect_url) = & redirect_url
85+ && let Some ( redirect_origins) = & self . options . redirect_origins
8186 {
82- Some ( redirect_origin)
83- } else {
84- None
85- } ;
87+ let redirect_origin = Url :: parse ( & redirect_url. origin ( ) . ascii_serialization ( ) )
88+ . map_err ( |err| {
89+ ShieldError :: Validation ( format ! ( "redirect origin parse error: {err}" ) )
90+ } ) ?;
91+
92+ if !redirect_origins. contains ( & redirect_origin) {
93+ return Err ( ShieldError :: Validation ( format ! (
94+ "redirect origin `{redirect_origin}` not allowed"
95+ ) ) ) ;
96+ }
97+ }
8698
8799 let client = provider. oauth_client ( ) . await ?;
88100
@@ -120,7 +132,7 @@ impl Action<OauthProvider, OauthSession> for OauthSignInAction {
120132 Ok ( Response :: new ( ResponseType :: Redirect ( auth_url. to_string ( ) ) )
121133 . session_action ( SessionAction :: Unauthenticate )
122134 . session_action ( SessionAction :: data ( OauthSession {
123- redirect_origin ,
135+ redirect_url ,
124136 csrf : Some ( csrf_token. secret ( ) . clone ( ) ) ,
125137 pkce_verifier : pkce_code_challenge
126138 . map ( |( _, pkce_code_verifier) | pkce_code_verifier. secret ( ) . clone ( ) ) ,
0 commit comments