11import json
22import logging
33
4+ from urllib .parse import quote
5+
46from mauth_client .authenticator import LocalAuthenticator
57from mauth_client .config import Config
68from mauth_client .consts import (
@@ -22,17 +24,17 @@ def __init__(self, app, exempt=None):
2224 self .exempt = exempt .copy () if exempt else set ()
2325
2426 def __call__ (self , environ , start_response ):
25- req = environ [ "werkzeug.request" ]
27+ path = environ . get ( "PATH_INFO" , "" )
2628
27- if req . path in self .exempt :
29+ if path in self .exempt :
2830 return self .app (environ , start_response )
2931
3032 signable = RequestSignable (
31- method = req . method ,
32- url = req . url ,
33+ method = environ [ "REQUEST_METHOD" ] ,
34+ url = self . _extract_url ( environ ) ,
3335 body = self ._read_body (environ ),
3436 )
35- signed = Signed .from_headers (dict ( req . headers ))
37+ signed = Signed .from_headers (self . _extract_headers ( environ ))
3638 authenticator = LocalAuthenticator (signable , signed , logger )
3739 is_authentic , status , message = authenticator .is_authentic ()
3840
@@ -60,3 +62,56 @@ def _read_body(self, environ):
6062 body = input .read ()
6163 input .seek (0 )
6264 return body
65+
66+ def _extract_headers (self , environ ):
67+ """
68+ Adapted from werkzeug package: https://github.com/pallets/werkzeug
69+ """
70+ headers = {}
71+
72+ # don't care to titleize the header keys since
73+ # the Signed class is just going to lowercase them
74+ for k , v in environ .items ():
75+ if k .startswith ("HTTP_" ) and k not in {
76+ "HTTP_CONTENT_TYPE" ,
77+ "HTTP_CONTENT_LENGTH" ,
78+ }:
79+ key = k [5 :].replace ("_" , "-" )
80+ headers [key ] = v
81+ elif k in {"CONTENT_TYPE" , "CONTENT_LENGTH" }:
82+ key = k .replace ("_" , "-" )
83+ headers [key ] = v
84+
85+ return headers
86+
87+ SAFE_CHARS = "!$&'()*+,/:;=@%"
88+
89+ def _extract_url (self , environ ):
90+ """
91+ Adapted from https://peps.python.org/pep-0333/#url-reconstruction
92+ """
93+ scheme = environ ["wsgi.url_scheme" ]
94+ url_parts = [scheme , "://" ]
95+ http_host = environ .get ("HTTP_HOST" )
96+
97+ if http_host :
98+ url_parts .append (http_host )
99+ else :
100+ url_parts .append (environ ["SERVER_NAME" ])
101+ port = environ ["SERVER_PORT" ]
102+
103+ if (scheme == "https" and port != 443 ) or (scheme != "https" and port != 80 ):
104+ url_parts .append (f":{ port } " )
105+
106+ url_parts .append (
107+ quote (environ .get ("SCRIPT_NAME" , "" ), safe = self .SAFE_CHARS )
108+ )
109+ url_parts .append (
110+ quote (environ .get ("PATH_INFO" , "" ), safe = self .SAFE_CHARS )
111+ )
112+
113+ qs = environ .get ("QUERY_STRING" )
114+ if qs :
115+ url_parts .append (f"?{ quote (qs , safe = self .SAFE_CHARS )} " )
116+
117+ return "" .join (url_parts )
0 commit comments