File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -571,3 +571,17 @@ pub fn csrf_cookie_name() -> &'static str {
571571pub fn csrf_header_name ( ) -> & ' static str {
572572 CSRF_HEADER_NAME
573573}
574+
575+ /// Middleware to enforce CSRF protection.
576+ ///
577+ /// This middleware extracts the `CsrfGuard` which performs the validation.
578+ /// It is designed to be used with `axum::middleware::from_fn_with_state`
579+ /// to ensure the database pool state is available for extraction.
580+ pub async fn enforce_csrf (
581+ axum:: extract:: State ( _pool) : axum:: extract:: State < crate :: db:: DbPool > ,
582+ _guard : CsrfGuard ,
583+ req : axum:: extract:: Request ,
584+ next : axum:: middleware:: Next ,
585+ ) -> axum:: response:: Response {
586+ next. run ( req) . await
587+ }
Original file line number Diff line number Diff line change @@ -247,9 +247,8 @@ async fn main() {
247247 "/api/upload" ,
248248 post ( handlers:: upload:: upload_image) ,
249249 )
250- . with_state ( pool. clone ( ) )
251- . route_layer ( from_extractor :: < csrf:: CsrfGuard > ( ) )
252- . route_layer ( from_fn ( middleware:: auth:: auth_middleware) )
250+ . route_layer ( axum:: middleware:: from_fn_with_state ( pool. clone ( ) , csrf:: enforce_csrf) )
251+ . route_layer ( axum:: middleware:: from_fn_with_state ( pool. clone ( ) , middleware:: auth:: auth_middleware) )
253252 . layer ( RequestBodyLimitLayer :: new ( ADMIN_BODY_LIMIT ) )
254253 . layer ( GovernorLayer :: new ( admin_rate_limit_config. clone ( ) ) ) ;
255254
Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ use axum::{
3333/// On success, inserts Claims into request extensions for easy access
3434/// by downstream handlers.
3535pub async fn auth_middleware (
36+ axum:: extract:: State ( pool) : axum:: extract:: State < crate :: db:: DbPool > ,
3637 mut request : axum:: extract:: Request ,
3738 next : axum:: middleware:: Next ,
3839) -> Result < axum:: response:: Response , ( StatusCode , Json < crate :: models:: ErrorResponse > ) > {
@@ -57,12 +58,7 @@ pub async fn auth_middleware(
5758 } ) ?;
5859
5960 // Check if token is blacklisted
60- let pool = request
61- . extensions ( )
62- . get :: < crate :: db:: DbPool > ( )
63- . expect ( "Database pool not found in request extensions" ) ;
64-
65- if let Ok ( true ) = repositories:: token_blacklist:: is_token_blacklisted ( pool, & token) . await {
61+ if let Ok ( true ) = repositories:: token_blacklist:: is_token_blacklisted ( & pool, & token) . await {
6662 return Err ( (
6763 StatusCode :: UNAUTHORIZED ,
6864 Json ( crate :: models:: ErrorResponse {
You can’t perform that action at this time.
0 commit comments