|
16 | 16 | # under the License. |
17 | 17 | from __future__ import annotations |
18 | 18 |
|
| 19 | +import importlib |
19 | 20 | import json |
20 | 21 | from collections import deque |
21 | 22 | from enum import Enum |
| 23 | +from types import UnionType |
22 | 24 | from typing import ( |
23 | 25 | TYPE_CHECKING, |
24 | 26 | Any, |
| 27 | + Union, |
| 28 | + get_args, |
| 29 | + get_origin, |
| 30 | + get_type_hints, |
25 | 31 | ) |
26 | 32 | from urllib.parse import quote, unquote |
27 | 33 |
|
@@ -397,6 +403,64 @@ class ListViewsResponse(IcebergBaseModel): |
397 | 403 | _PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse) |
398 | 404 |
|
399 | 405 |
|
| 406 | +def _get_auth_manager_class(class_or_name: str) -> type[AuthManager]: |
| 407 | + if class_or_name in AuthManagerFactory._registry: |
| 408 | + return AuthManagerFactory._registry[class_or_name] |
| 409 | + |
| 410 | + try: |
| 411 | + module_path, class_name = class_or_name.rsplit(".", 1) |
| 412 | + module = importlib.import_module(module_path) |
| 413 | + return getattr(module, class_name) |
| 414 | + except Exception as err: |
| 415 | + raise ValueError(f"Could not load AuthManager class for '{class_or_name}'") from err |
| 416 | + |
| 417 | + |
| 418 | +def _coerce_auth_option_value(key: str, value: Any, annotation: Any) -> Any: |
| 419 | + if not isinstance(value, str): |
| 420 | + return value |
| 421 | + |
| 422 | + origin = get_origin(annotation) |
| 423 | + if origin is list: |
| 424 | + try: |
| 425 | + parsed = json.loads(value) |
| 426 | + except json.JSONDecodeError as err: |
| 427 | + raise ValueError(f"Failed to parse auth configuration value '{key}' as JSON array") from err |
| 428 | + |
| 429 | + if not isinstance(parsed, list) or not all(isinstance(item, str) for item in parsed): |
| 430 | + raise ValueError(f"auth configuration value '{key}' must be a JSON array of strings") |
| 431 | + return parsed |
| 432 | + |
| 433 | + if origin in (Union, UnionType): |
| 434 | + non_none_args = [arg for arg in get_args(annotation) if arg is not type(None)] |
| 435 | + if len(non_none_args) == 1: |
| 436 | + return _coerce_auth_option_value(key, value, non_none_args[0]) |
| 437 | + |
| 438 | + if origin is not None: |
| 439 | + if origin is list: |
| 440 | + try: |
| 441 | + parsed = json.loads(value) |
| 442 | + except json.JSONDecodeError as err: |
| 443 | + raise ValueError(f"Failed to parse auth configuration value '{key}' as JSON array") from err |
| 444 | + |
| 445 | + if not isinstance(parsed, list) or not all(isinstance(item, str) for item in parsed): |
| 446 | + raise ValueError(f"auth configuration value '{key}' must be a JSON array of strings") |
| 447 | + return parsed |
| 448 | + |
| 449 | + if annotation is int: |
| 450 | + try: |
| 451 | + return int(value) |
| 452 | + except ValueError as err: |
| 453 | + raise ValueError(f"Failed to parse auth configuration value '{key}' as integer") from err |
| 454 | + |
| 455 | + return value |
| 456 | + |
| 457 | + |
| 458 | +def _coerce_auth_config_values(class_or_name: str, config: dict[str, Any]) -> dict[str, Any]: |
| 459 | + manager_class = _get_auth_manager_class(class_or_name) |
| 460 | + hints = get_type_hints(manager_class.__init__) |
| 461 | + return {key: _coerce_auth_option_value(key, value, hints.get(key, Any)) for key, value in config.items()} |
| 462 | + |
| 463 | + |
400 | 464 | class RestCatalog(Catalog): |
401 | 465 | uri: str |
402 | 466 | _session: Session |
@@ -467,14 +531,16 @@ def _create_session(self) -> Session: |
467 | 531 | raise ValueError("auth.type must be defined") |
468 | 532 | auth_type_config = auth_config.get(auth_type, {}) |
469 | 533 | auth_impl = auth_config.get("impl") |
| 534 | + auth_manager_name = auth_impl or auth_type |
470 | 535 |
|
471 | 536 | if auth_type == CUSTOM and not auth_impl: |
472 | 537 | raise ValueError("auth.impl must be specified when using custom auth.type") |
473 | 538 |
|
474 | 539 | if auth_type != CUSTOM and auth_impl: |
475 | 540 | raise ValueError("auth.impl can only be specified when using custom auth.type") |
476 | 541 |
|
477 | | - self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config) |
| 542 | + typed_auth_type_config = _coerce_auth_config_values(auth_manager_name, auth_type_config) |
| 543 | + self._auth_manager = AuthManagerFactory.create(auth_manager_name, typed_auth_type_config) |
478 | 544 | session.auth = AuthManagerAdapter(self._auth_manager) |
479 | 545 | else: |
480 | 546 | self._auth_manager = self._create_legacy_oauth2_auth_manager(session) |
|
0 commit comments