|
34 | 34 | from pydantic import ConfigDict, Field, TypeAdapter, field_validator |
35 | 35 | from requests import HTTPError, Session |
36 | 36 | from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt |
37 | | -from typing_extensions import override |
| 37 | +from typing_extensions import NotRequired, TypedDict, override |
38 | 38 |
|
39 | 39 | from pyiceberg import __version__ |
40 | 40 | from pyiceberg.catalog import BOTOCORE_SESSION, TOKEN, URI, WAREHOUSE_LOCATION, Catalog, PropertiesUpdateSummary |
@@ -403,6 +403,17 @@ class ListViewsResponse(IcebergBaseModel): |
403 | 403 | _PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse) |
404 | 404 |
|
405 | 405 |
|
| 406 | +class ParsedAuthConfig(TypedDict): |
| 407 | + auth_type: str |
| 408 | + auth_manager_name: str |
| 409 | + auth_type_config: dict[str, Any] |
| 410 | + |
| 411 | + |
| 412 | +class AuthConfigEnvelope(TypedDict): |
| 413 | + type: str |
| 414 | + impl: NotRequired[str] |
| 415 | + |
| 416 | + |
406 | 417 | def _get_auth_manager_class(class_or_name: str) -> type[AuthManager]: |
407 | 418 | if class_or_name in AuthManagerFactory._registry: |
408 | 419 | return AuthManagerFactory._registry[class_or_name] |
@@ -461,6 +472,65 @@ def _coerce_auth_config_values(class_or_name: str, config: dict[str, Any]) -> di |
461 | 472 | return {key: _coerce_auth_option_value(key, value, hints.get(key, Any)) for key, value in config.items()} |
462 | 473 |
|
463 | 474 |
|
| 475 | +def _load_auth_config_from_properties(properties: Properties) -> AuthConfigEnvelope | dict[str, Any] | None: |
| 476 | + raw_auth = properties.get(AUTH) |
| 477 | + if isinstance(raw_auth, str): |
| 478 | + try: |
| 479 | + decoded_auth = json.loads(raw_auth) |
| 480 | + except json.JSONDecodeError as e: |
| 481 | + raise ValueError("Failed to parse auth configuration as JSON") from e |
| 482 | + if decoded_auth is not None and not isinstance(decoded_auth, dict): |
| 483 | + raise ValueError("auth configuration must be a dictionary") |
| 484 | + return decoded_auth |
| 485 | + |
| 486 | + if raw_auth is not None: |
| 487 | + if not isinstance(raw_auth, dict): |
| 488 | + raise ValueError("auth configuration must be a dictionary") |
| 489 | + return raw_auth |
| 490 | + |
| 491 | + if auth_type := properties.get(f"{AUTH}.type"): |
| 492 | + type_prefix = f"{AUTH}.{auth_type}." |
| 493 | + return { |
| 494 | + "type": auth_type, |
| 495 | + "impl": properties.get(f"{AUTH}.impl"), |
| 496 | + auth_type: { |
| 497 | + key[len(type_prefix) :].replace("-", "_"): value |
| 498 | + for key, value in properties.items() |
| 499 | + if key.startswith(type_prefix) |
| 500 | + }, |
| 501 | + } |
| 502 | + |
| 503 | + return None |
| 504 | + |
| 505 | + |
| 506 | +def _resolve_auth_config(auth_config: AuthConfigEnvelope | dict[str, Any]) -> ParsedAuthConfig: |
| 507 | + auth_type = auth_config.get("type") |
| 508 | + if not isinstance(auth_type, str): |
| 509 | + raise ValueError("auth.type must be defined") |
| 510 | + |
| 511 | + auth_type_config = auth_config.get(auth_type, {}) |
| 512 | + if not isinstance(auth_type_config, dict): |
| 513 | + raise ValueError(f"auth.{auth_type} must be a dictionary") |
| 514 | + |
| 515 | + auth_impl = auth_config.get("impl") |
| 516 | + if auth_impl is not None and not isinstance(auth_impl, str): |
| 517 | + raise ValueError("auth.impl must be a string") |
| 518 | + |
| 519 | + auth_manager_name = auth_impl or auth_type |
| 520 | + |
| 521 | + if auth_type == CUSTOM and not auth_impl: |
| 522 | + raise ValueError("auth.impl must be specified when using custom auth.type") |
| 523 | + |
| 524 | + if auth_type != CUSTOM and auth_impl: |
| 525 | + raise ValueError("auth.impl can only be specified when using custom auth.type") |
| 526 | + |
| 527 | + return { |
| 528 | + "auth_type": auth_type, |
| 529 | + "auth_manager_name": auth_manager_name, |
| 530 | + "auth_type_config": auth_type_config, |
| 531 | + } |
| 532 | + |
| 533 | + |
464 | 534 | class RestCatalog(Catalog): |
465 | 535 | uri: str |
466 | 536 | _session: Session |
@@ -500,47 +570,13 @@ def _create_session(self) -> Session: |
500 | 570 | elif ssl_client_cert := ssl_client.get(CERT): |
501 | 571 | session.cert = ssl_client_cert |
502 | 572 |
|
503 | | - raw_auth = self.properties.get(AUTH) |
504 | | - if isinstance(raw_auth, str): |
505 | | - try: |
506 | | - auth_config: dict[str, Any] | None = json.loads(raw_auth) |
507 | | - except json.JSONDecodeError as e: |
508 | | - raise ValueError("Failed to parse auth configuration as JSON") from e |
509 | | - elif raw_auth is not None: |
510 | | - auth_config = raw_auth |
511 | | - elif auth_type := self.properties.get(f"{AUTH}.type"): |
512 | | - type_prefix = f"{AUTH}.{auth_type}." |
513 | | - auth_config = { |
514 | | - "type": auth_type, |
515 | | - "impl": self.properties.get(f"{AUTH}.impl"), |
516 | | - auth_type: { |
517 | | - key[len(type_prefix) :].replace("-", "_"): value |
518 | | - for key, value in self.properties.items() |
519 | | - if key.startswith(type_prefix) |
520 | | - }, |
521 | | - } |
522 | | - else: |
523 | | - auth_config = None |
524 | | - |
525 | | - if auth_config is not None and not isinstance(auth_config, dict): |
526 | | - raise ValueError("auth configuration must be a dictionary") |
527 | | - |
| 573 | + auth_config = _load_auth_config_from_properties(self.properties) |
528 | 574 | if auth_config: |
529 | | - auth_type = auth_config.get("type") |
530 | | - if auth_type is None: |
531 | | - raise ValueError("auth.type must be defined") |
532 | | - auth_type_config = auth_config.get(auth_type, {}) |
533 | | - auth_impl = auth_config.get("impl") |
534 | | - auth_manager_name = auth_impl or auth_type |
535 | | - |
536 | | - if auth_type == CUSTOM and not auth_impl: |
537 | | - raise ValueError("auth.impl must be specified when using custom auth.type") |
538 | | - |
539 | | - if auth_type != CUSTOM and auth_impl: |
540 | | - raise ValueError("auth.impl can only be specified when using custom auth.type") |
541 | | - |
542 | | - typed_auth_type_config = _coerce_auth_config_values(auth_manager_name, auth_type_config) |
543 | | - self._auth_manager = AuthManagerFactory.create(auth_manager_name, typed_auth_type_config) |
| 575 | + resolved_auth = _resolve_auth_config(auth_config) |
| 576 | + typed_auth_type_config = _coerce_auth_config_values( |
| 577 | + resolved_auth["auth_manager_name"], resolved_auth["auth_type_config"] |
| 578 | + ) |
| 579 | + self._auth_manager = AuthManagerFactory.create(resolved_auth["auth_manager_name"], typed_auth_type_config) |
544 | 580 | session.auth = AuthManagerAdapter(self._auth_manager) |
545 | 581 | else: |
546 | 582 | self._auth_manager = self._create_legacy_oauth2_auth_manager(session) |
|
0 commit comments