Skip to content

Commit 41edaef

Browse files
Gayathri Srividya RajavarapuGayathri Srividya Rajavarapu
authored andcommitted
test: cover typed flat rest auth env vars
1 parent 7403f51 commit 41edaef

2 files changed

Lines changed: 138 additions & 1 deletion

File tree

pyiceberg/catalog/rest/__init__.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,18 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import importlib
1920
import json
2021
from collections import deque
2122
from enum import Enum
23+
from types import UnionType
2224
from typing import (
2325
TYPE_CHECKING,
2426
Any,
27+
Union,
28+
get_args,
29+
get_origin,
30+
get_type_hints,
2531
)
2632
from urllib.parse import quote, unquote
2733

@@ -397,6 +403,64 @@ class ListViewsResponse(IcebergBaseModel):
397403
_PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse)
398404

399405

406+
def _get_auth_manager_class(class_or_name: str) -> type[AuthManager]:
407+
if class_or_name in AuthManagerFactory._registry:
408+
return AuthManagerFactory._registry[class_or_name]
409+
410+
try:
411+
module_path, class_name = class_or_name.rsplit(".", 1)
412+
module = importlib.import_module(module_path)
413+
return getattr(module, class_name)
414+
except Exception as err:
415+
raise ValueError(f"Could not load AuthManager class for '{class_or_name}'") from err
416+
417+
418+
def _coerce_auth_option_value(key: str, value: Any, annotation: Any) -> Any:
419+
if not isinstance(value, str):
420+
return value
421+
422+
origin = get_origin(annotation)
423+
if origin is list:
424+
try:
425+
parsed = json.loads(value)
426+
except json.JSONDecodeError as err:
427+
raise ValueError(f"Failed to parse auth configuration value '{key}' as JSON array") from err
428+
429+
if not isinstance(parsed, list) or not all(isinstance(item, str) for item in parsed):
430+
raise ValueError(f"auth configuration value '{key}' must be a JSON array of strings")
431+
return parsed
432+
433+
if origin in (Union, UnionType):
434+
non_none_args = [arg for arg in get_args(annotation) if arg is not type(None)]
435+
if len(non_none_args) == 1:
436+
return _coerce_auth_option_value(key, value, non_none_args[0])
437+
438+
if origin is not None:
439+
if origin is list:
440+
try:
441+
parsed = json.loads(value)
442+
except json.JSONDecodeError as err:
443+
raise ValueError(f"Failed to parse auth configuration value '{key}' as JSON array") from err
444+
445+
if not isinstance(parsed, list) or not all(isinstance(item, str) for item in parsed):
446+
raise ValueError(f"auth configuration value '{key}' must be a JSON array of strings")
447+
return parsed
448+
449+
if annotation is int:
450+
try:
451+
return int(value)
452+
except ValueError as err:
453+
raise ValueError(f"Failed to parse auth configuration value '{key}' as integer") from err
454+
455+
return value
456+
457+
458+
def _coerce_auth_config_values(class_or_name: str, config: dict[str, Any]) -> dict[str, Any]:
459+
manager_class = _get_auth_manager_class(class_or_name)
460+
hints = get_type_hints(manager_class.__init__)
461+
return {key: _coerce_auth_option_value(key, value, hints.get(key, Any)) for key, value in config.items()}
462+
463+
400464
class RestCatalog(Catalog):
401465
uri: str
402466
_session: Session
@@ -467,14 +531,16 @@ def _create_session(self) -> Session:
467531
raise ValueError("auth.type must be defined")
468532
auth_type_config = auth_config.get(auth_type, {})
469533
auth_impl = auth_config.get("impl")
534+
auth_manager_name = auth_impl or auth_type
470535

471536
if auth_type == CUSTOM and not auth_impl:
472537
raise ValueError("auth.impl must be specified when using custom auth.type")
473538

474539
if auth_type != CUSTOM and auth_impl:
475540
raise ValueError("auth.impl can only be specified when using custom auth.type")
476541

477-
self._auth_manager = AuthManagerFactory.create(auth_impl or auth_type, auth_type_config)
542+
typed_auth_type_config = _coerce_auth_config_values(auth_manager_name, auth_type_config)
543+
self._auth_manager = AuthManagerFactory.create(auth_manager_name, typed_auth_type_config)
478544
session.auth = AuthManagerAdapter(self._auth_manager)
479545
else:
480546
self._auth_manager = self._create_legacy_oauth2_auth_manager(session)

tests/catalog/test_rest.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2580,6 +2580,77 @@ def test_rest_catalog_with_oauth2_auth_flat_environment_variables(requests_mock:
25802580
assert catalog.uri == TEST_URI
25812581

25822582

2583+
@pytest.mark.parametrize(
2584+
"auth_type, env_overrides, expected_config",
2585+
[
2586+
pytest.param(
2587+
"oauth2",
2588+
{
2589+
"PYICEBERG_CATALOG__REST__AUTH__OAUTH2__CLIENT_ID": "some_client_id",
2590+
"PYICEBERG_CATALOG__REST__AUTH__OAUTH2__CLIENT_SECRET": "some_client_secret",
2591+
"PYICEBERG_CATALOG__REST__AUTH__OAUTH2__TOKEN_URL": f"{TEST_URI}oauth2/token",
2592+
"PYICEBERG_CATALOG__REST__AUTH__OAUTH2__REFRESH_MARGIN": "90",
2593+
"PYICEBERG_CATALOG__REST__AUTH__OAUTH2__EXPIRES_IN": "3600",
2594+
},
2595+
{
2596+
"client_id": "some_client_id",
2597+
"client_secret": "some_client_secret",
2598+
"token_url": f"{TEST_URI}oauth2/token",
2599+
"refresh_margin": 90,
2600+
"expires_in": 3600,
2601+
},
2602+
id="oauth2-numeric-fields",
2603+
),
2604+
pytest.param(
2605+
"google",
2606+
{
2607+
"PYICEBERG_CATALOG__REST__AUTH__GOOGLE__CREDENTIALS_PATH": "/fake/path.json",
2608+
"PYICEBERG_CATALOG__REST__AUTH__GOOGLE__SCOPES": '["scope-a", "scope-b"]',
2609+
},
2610+
{
2611+
"credentials_path": "/fake/path.json",
2612+
"scopes": ["scope-a", "scope-b"],
2613+
},
2614+
id="google-scopes",
2615+
),
2616+
pytest.param(
2617+
"entra",
2618+
{
2619+
"PYICEBERG_CATALOG__REST__AUTH__ENTRA__SCOPES": '["scope-a", "scope-b"]',
2620+
},
2621+
{
2622+
"scopes": ["scope-a", "scope-b"],
2623+
},
2624+
id="entra-scopes",
2625+
),
2626+
],
2627+
)
2628+
def test_rest_catalog_with_typed_auth_flat_environment_variables(
2629+
rest_mock: Mocker,
2630+
auth_type: str,
2631+
env_overrides: dict[str, str],
2632+
expected_config: dict[str, Any],
2633+
) -> None:
2634+
rest_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200)
2635+
2636+
fake_auth_manager = mock.Mock()
2637+
fake_auth_manager.auth_header.return_value = ""
2638+
env = {
2639+
"PYICEBERG_CATALOG__REST__URI": TEST_URI,
2640+
"PYICEBERG_CATALOG__REST__AUTH__TYPE": auth_type,
2641+
**env_overrides,
2642+
}
2643+
2644+
with (
2645+
mock.patch.dict(os.environ, env, clear=True),
2646+
mock.patch("pyiceberg.catalog.rest.AuthManagerFactory.create", return_value=fake_auth_manager) as create_auth_manager,
2647+
):
2648+
catalog = RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore
2649+
2650+
assert catalog.uri == TEST_URI
2651+
assert create_auth_manager.call_args_list == [mock.call(auth_type, expected_config), mock.call(auth_type, expected_config)]
2652+
2653+
25832654
EXAMPLE_ENV = {"PYICEBERG_CATALOG__PRODUCTION__URI": TEST_URI}
25842655

25852656

0 commit comments

Comments
 (0)