Skip to content

Commit cae7c85

Browse files
committed
Update typing
1 parent cea512e commit cae7c85

5 files changed

Lines changed: 33 additions & 25 deletions

File tree

src/rain_api_core/auth.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import contextlib
22
import dataclasses
33
import logging
4+
from collections.abc import Mapping
45
from http.cookies import CookieError, SimpleCookie
56
from time import time
6-
from typing import List, Mapping, Optional
7+
from typing import Optional
78
from wsgiref.handlers import format_date_time as format_7231_date
89

910
import jwt
@@ -15,12 +16,12 @@
1516
class UserProfile:
1617
user_id: str
1718
token: str
18-
groups: List[str]
19+
groups: list[str]
1920
first_name: str
2021
last_name: str
2122
email: str
22-
iat: int = None
23-
exp: int = None
23+
iat: Optional[int] = None
24+
exp: Optional[int] = None
2425

2526
@classmethod
2627
def from_jwt_payload(cls, payload):
@@ -115,7 +116,9 @@ def _in_blacklist(self, user_profile: UserProfile):
115116
return True
116117
return False
117118

118-
def get_profile_from_headers(self, headers) -> Optional[UserProfile]:
119+
def get_profile_from_headers(
120+
self, headers: Mapping[str, str],
121+
) -> Optional[UserProfile]:
119122
"""Inspects headers for auth cookie and return user_profile if authenticated, None otherwise"""
120123
auth_cookie = self._get_auth_cookie(headers)
121124
if not auth_cookie:
@@ -130,7 +133,7 @@ def get_profile_from_headers(self, headers) -> Optional[UserProfile]:
130133
return None
131134
return user_profile
132135

133-
def get_header_to_set_auth_cookie(self, user_profile: Optional[UserProfile], cookie_domain=''):
136+
def get_header_to_set_auth_cookie(self, user_profile: Optional[UserProfile], cookie_domain: str = ''):
134137
""" Gets a header to set auth-cookie
135138
136139
Parameters:

src/rain_api_core/aws_util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def get_yaml_file(bucket: str, key: str) -> dict:
155155
return get_yaml(bucket, key)
156156

157157

158-
def get_role_creds(user_id: str = None, in_region: bool = False):
158+
def get_role_creds(user_id: Optional[str] = None, in_region: bool = False):
159159
"""
160160
:param user_id: string with URS username
161161
:param in_region: boolean If True a download role that works only in region will be returned
@@ -205,7 +205,10 @@ def get_role_creds(user_id: str = None, in_region: bool = False):
205205
return role_creds_cache[download_role_arn][user_id]["session"], session_offset
206206

207207

208-
def get_role_session(creds: dict = None, user_id: str = None) -> boto_Session:
208+
def get_role_session(
209+
creds: Optional[dict] = None,
210+
user_id: Optional[str] = None,
211+
) -> boto_Session:
209212
global session_cache # pylint: disable=global-statement
210213
sts_resp = creds if creds else get_role_creds(user_id)[0]
211214
log.debug('sts_resp: {0}'.format(sts_resp))

src/rain_api_core/bucket_map.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from collections import defaultdict
2+
from collections.abc import Generator, Iterable, Sequence
23
from dataclasses import dataclass, field
3-
from typing import Generator, Iterable, Optional, Sequence, Tuple
4+
from typing import Optional
45

56
# By default, buckets are accessible to any logged in users. This is
67
# represented by an empty set.
78
_DEFAULT_PERMISSION_FACTORY = set
89

910

1011
def _is_accessible(
11-
required_groups: Optional[set],
12-
groups: Optional[Iterable[str]]
12+
required_groups: Optional[set[str]],
13+
groups: Optional[Iterable[str]],
1314
) -> bool:
1415
# Check for public access
1516
if required_groups is None:
@@ -30,7 +31,7 @@ class BucketMapEntry():
3031
headers: dict = field(default_factory=dict)
3132
_access_control: Optional[dict] = None
3233

33-
def is_accessible(self, groups: Iterable[str] = None) -> bool:
34+
def is_accessible(self, groups: Optional[Iterable[str]] = None) -> bool:
3435
"""Check if the object is accessible with the given permissions.
3536
3637
Setting `groups` to an iterable implies that the user has logged in,
@@ -41,7 +42,7 @@ def is_accessible(self, groups: Iterable[str] = None) -> bool:
4142
required_groups = self.get_required_groups()
4243
return _is_accessible(required_groups, groups)
4344

44-
def get_required_groups(self) -> Optional[set]:
45+
def get_required_groups(self) -> Optional[set[str]]:
4546
"""Get a set of permissions protecting this object.
4647
4748
It is sufficient to have one of the permissions in the set in order to
@@ -126,7 +127,7 @@ def get_path(self, path: Sequence[str]) -> Optional[BucketMapEntry]:
126127

127128
return None
128129

129-
def entries(self):
130+
def entries(self) -> Generator[BucketMapEntry]:
130131
for bucket, path_parts, headers in _walk_entries(self._get_map()):
131132
yield self._make_entry(
132133
bucket=bucket,
@@ -135,7 +136,7 @@ def entries(self):
135136
headers=headers
136137
)
137138

138-
def to_iam_policy(self, groups: Iterable[str] = None) -> dict:
139+
def to_iam_policy(self, groups: Optional[Iterable[str]] = None) -> Optional[dict]:
139140
if not self._iam_compatible:
140141
_check_iam_compatible(self.access_control)
141142
generator = IamPolicyGenerator(groups)
@@ -150,8 +151,8 @@ def _make_entry(
150151
bucket: str,
151152
bucket_path: str,
152153
object_key: str,
153-
headers: Optional[dict] = None
154-
):
154+
headers: Optional[dict] = None,
155+
) -> BucketMapEntry:
155156
return BucketMapEntry(
156157
bucket=self.bucket_name_prefix + bucket,
157158
bucket_path=bucket_path,
@@ -164,7 +165,7 @@ def _make_entry(
164165
)
165166

166167

167-
def _walk_entries(node: dict, path=()) -> Generator[Tuple[str, tuple, Optional[dict]], None, None]:
168+
def _walk_entries(node: dict, path=()) -> Generator[tuple[str, tuple, Optional[dict]]]:
168169
"""A generator to recursively yield all leaves of a bucket map"""
169170

170171
for key, val in node.items():
@@ -294,7 +295,7 @@ def _access_text(access) -> str:
294295

295296

296297
class IamPolicyGenerator:
297-
def __init__(self, groups: Iterable[str]):
298+
def __init__(self, groups: Optional[Iterable[str]]):
298299
self.groups = groups
299300

300301
def _is_accessible(self, required_groups: Optional[set]) -> bool:

src/rain_api_core/timer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
22
import time
3+
from collections.abc import Callable
34
from dataclasses import dataclass
4-
from typing import Callable, Optional
5+
from typing import Optional
56

67

78
@dataclass(eq=False)
@@ -28,7 +29,7 @@ def __init__(self, timer: Callable[[], float] = time.time):
2829
self.last_name: Optional[str] = None
2930
self.total = Interval()
3031

31-
def mark(self, name: str = None) -> float:
32+
def mark(self, name: Optional[str] = None) -> float:
3233
"""Record a new event.
3334
3435
If called without `name`, any previously started event will be marked

src/rain_api_core/urs_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
log = logging.getLogger(__name__)
1111

1212

13-
def get_base_url(ctxt: dict = None) -> str:
13+
def get_base_url(ctxt: Optional[dict] = None) -> str:
1414
# Make a redirect url using optional custom domain_name, otherwise use raw domain/stage provided by API Gateway.
1515
try:
1616
domain = os.getenv('DOMAIN_NAME') or f"{ctxt['domainName']}/{ctxt['stage']}"
@@ -20,7 +20,7 @@ def get_base_url(ctxt: dict = None) -> str:
2020
raise
2121

2222

23-
def get_redirect_url(ctxt: dict = None) -> str:
23+
def get_redirect_url(ctxt: Optional[dict] = None) -> str:
2424
return f'{get_base_url(ctxt)}login'
2525

2626

@@ -49,7 +49,7 @@ def do_auth(code: str, redirect_url: str, aux_headers: dict = {}) -> dict:
4949
return {}
5050

5151

52-
def get_urs_url(ctxt: dict, to: str = None) -> str:
52+
def get_urs_url(ctxt: dict, to: Optional[str] = None) -> str:
5353
base_url = os.getenv('AUTH_BASE_URL', 'https://urs.earthdata.nasa.gov') + '/oauth/authorize'
5454

5555
# From URS Application
@@ -93,7 +93,7 @@ def get_user_profile(urs_user_payload: dict, access_token) -> UserProfile:
9393
def get_profile(
9494
user_id: str,
9595
token: str,
96-
temptoken: str = None,
96+
temptoken: Optional[str] = None,
9797
aux_headers: dict = {},
9898
) -> Optional[UserProfile]:
9999
if not user_id or not token:

0 commit comments

Comments
 (0)