|
7 | 7 |
|
8 | 8 | from guardpost.abc import BaseStrategy |
9 | 9 | from guardpost.authentication import Identity |
| 10 | +from guardpost.common import RolesRequirement |
10 | 11 |
|
11 | 12 |
|
12 | 13 | class AuthorizationError(Exception): |
@@ -208,46 +209,78 @@ def with_default_policy(self, policy: Policy) -> "AuthorizationStrategy": |
208 | 209 | return self |
209 | 210 |
|
210 | 211 | async def authorize( |
211 | | - self, policy_name: Optional[str], identity: Identity, scope: Any = None |
| 212 | + self, |
| 213 | + policy_name: Optional[str], |
| 214 | + identity: Identity, |
| 215 | + scope: Any = None, |
| 216 | + roles: Optional[Sequence[str]] = None, |
212 | 217 | ): |
213 | 218 | if policy_name: |
214 | 219 | policy = self.get_policy(policy_name) |
215 | 220 |
|
216 | 221 | if not policy: |
217 | 222 | raise PolicyNotFoundError(policy_name) |
218 | 223 |
|
219 | | - await self._handle_with_policy(policy, identity, scope) |
| 224 | + await self._handle_with_policy(policy, identity, scope, roles) |
220 | 225 | else: |
221 | 226 | if self.default_policy: |
222 | | - await self._handle_with_policy(self.default_policy, identity, scope) |
| 227 | + await self._handle_with_policy( |
| 228 | + self.default_policy, identity, scope, roles |
| 229 | + ) |
| 230 | + return |
| 231 | + |
| 232 | + if roles: |
| 233 | + # This code is only executed if the user specified roles without |
| 234 | + # specifying an authorization policy. |
| 235 | + await self._handle_with_roles(identity, roles) |
223 | 236 | return |
224 | 237 |
|
225 | 238 | if not identity: |
226 | 239 | raise UnauthorizedError("Missing identity", []) |
227 | 240 | if not identity.is_authenticated(): |
228 | 241 | raise UnauthorizedError("The resource requires authentication", []) |
229 | 242 |
|
230 | | - def _get_requirements(self, policy: Policy, scope: Any) -> Iterable[Requirement]: |
| 243 | + def _get_requirements( |
| 244 | + self, policy: Policy, scope: Any, roles: Optional[Sequence[str]] = None |
| 245 | + ) -> Iterable[Requirement]: |
| 246 | + if roles: |
| 247 | + yield RolesRequirement(roles=roles) |
231 | 248 | yield from self._get_instances(policy.requirements, scope) |
232 | 249 |
|
233 | | - async def _handle_with_policy(self, policy: Policy, identity: Identity, scope: Any): |
| 250 | + async def _handle_with_policy( |
| 251 | + self, |
| 252 | + policy: Policy, |
| 253 | + identity: Identity, |
| 254 | + scope: Any, |
| 255 | + roles: Optional[Sequence[str]] = None, |
| 256 | + ): |
234 | 257 | with AuthorizationContext( |
235 | | - identity, list(self._get_requirements(policy, scope)) |
| 258 | + identity, list(self._get_requirements(policy, scope, roles)) |
236 | 259 | ) as context: |
237 | | - for requirement in context.requirements: |
238 | | - if _is_async_handler(type(requirement)): # type: ignore |
239 | | - await requirement.handle(context) |
240 | | - else: |
241 | | - requirement.handle(context) # type: ignore |
242 | | - |
243 | | - if not context.has_succeeded: |
244 | | - if identity and identity.is_authenticated(): |
245 | | - raise ForbiddenError( |
246 | | - context.forced_failure, context.pending_requirements |
247 | | - ) |
248 | | - raise UnauthorizedError( |
| 260 | + await self._handle_context(identity, context) |
| 261 | + |
| 262 | + async def _handle_with_roles( |
| 263 | + self, identity: Identity, roles: Optional[Sequence[str]] = None |
| 264 | + ): |
| 265 | + # This method is to be used only when the user specified roles without a policy |
| 266 | + with AuthorizationContext(identity, [RolesRequirement(roles=roles)]) as context: |
| 267 | + await self._handle_context(identity, context) |
| 268 | + |
| 269 | + async def _handle_context(self, identity: Identity, context: AuthorizationContext): |
| 270 | + for requirement in context.requirements: |
| 271 | + if _is_async_handler(type(requirement)): # type: ignore |
| 272 | + await requirement.handle(context) |
| 273 | + else: |
| 274 | + requirement.handle(context) # type: ignore |
| 275 | + |
| 276 | + if not context.has_succeeded: |
| 277 | + if identity and identity.is_authenticated(): |
| 278 | + raise ForbiddenError( |
249 | 279 | context.forced_failure, context.pending_requirements |
250 | 280 | ) |
| 281 | + raise UnauthorizedError( |
| 282 | + context.forced_failure, context.pending_requirements |
| 283 | + ) |
251 | 284 |
|
252 | 285 | async def _handle_with_identity_getter( |
253 | 286 | self, policy_name: Optional[str], *args, **kwargs |
|
0 commit comments