diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 1b2b5206..ad2c6929 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -35,7 +35,11 @@ def attach( def detach(self) -> None: ... - def initialize(self, evaluation_context: EvaluationContext) -> None: ... + def initialize( + self, + evaluation_context: EvaluationContext, + domain: str | None = None, + ) -> None: ... def shutdown(self) -> None: ... @@ -140,9 +144,17 @@ def detach(self) -> None: if hasattr(self, "_on_emit"): del self._on_emit - def initialize(self, evaluation_context: EvaluationContext) -> None: + def initialize( + self, + evaluation_context: EvaluationContext, + domain: str | None = None, + ) -> None: pass + @property + def domain_scoped(self) -> bool: + return False + def shutdown(self) -> None: pass diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index e46caadd..ad55314a 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -1,4 +1,6 @@ +import inspect import threading +from collections.abc import Callable from openfeature._event_support import run_handlers_for_provider from openfeature.evaluation_context import EvaluationContext, get_evaluation_context @@ -11,6 +13,36 @@ from openfeature.provider.no_op_provider import NoOpProvider +def _is_domain_scoped(provider: FeatureProvider) -> bool: + return getattr(provider, "domain_scoped", False) is True + + +def _callable_accepts_domain(callable_obj: Callable[..., object]) -> bool: + try: + signature = inspect.signature(callable_obj) + except (TypeError, ValueError): + return False + return "domain" in signature.parameters or any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in signature.parameters.values() + ) + + +def _initialize_accepts_domain(provider: FeatureProvider) -> bool: + return _callable_accepts_domain(provider.initialize) + + +def _call_initialize( + provider: FeatureProvider, + evaluation_context: EvaluationContext, + domain: str | None, +) -> None: + if _initialize_accepts_domain(provider): + provider.initialize(evaluation_context, domain=domain) + else: + provider.initialize(evaluation_context) + + class ProviderRegistry: _default_provider: FeatureProvider _providers: dict[str, FeatureProvider] @@ -36,17 +68,20 @@ def set_provider( old_provider: FeatureProvider | None = None needs_init = False with self._lock: + self._reject_domain_scoped_rebind(provider, domain) old_provider = self._providers.get(domain) - self._providers[domain] = provider already_bound = provider is self._default_provider or any( p is provider for d, p in self._providers.items() if d != domain ) + self._providers[domain] = provider if not already_bound: needs_init = True self._provider_status[provider] = ProviderStatus.NOT_READY if needs_init: - self._initialize_provider(provider, wait_for_init=wait_for_init) + self._initialize_provider( + provider, domain=domain, wait_for_init=wait_for_init + ) # old-provider shutdown is always async so a hanging shutdown() cannot # block set_provider. @@ -67,6 +102,7 @@ def set_default_provider( old_provider: FeatureProvider | None = None needs_init = False with self._lock: + self._reject_domain_scoped_rebind(provider, None) old_provider = self._default_provider self._default_provider = provider if ( @@ -77,7 +113,9 @@ def set_default_provider( self._provider_status[provider] = ProviderStatus.NOT_READY if needs_init: - self._initialize_provider(provider, wait_for_init=wait_for_init) + self._initialize_provider( + provider, domain=None, wait_for_init=wait_for_init + ) if old_provider is not None and old_provider is not provider: self._shutdown_if_unused(old_provider) @@ -104,8 +142,32 @@ def shutdown(self) -> None: def _get_evaluation_context(self) -> EvaluationContext: return get_evaluation_context() + def _provider_bindings(self, provider: FeatureProvider) -> list[str | None]: + bindings: list[str | None] = [] + if provider is self._default_provider: + bindings.append(None) + bindings.extend(d for d, p in self._providers.items() if p is provider) + return bindings + + def _reject_domain_scoped_rebind( + self, provider: FeatureProvider, domain: str | None + ) -> None: + if not _is_domain_scoped(provider): + return + bindings = self._provider_bindings(provider) + if bindings and domain not in bindings: + raise GeneralError( + error_message=( + "Cannot bind domain-scoped provider to more than one domain" + ) + ) + def _initialize_provider( - self, provider: FeatureProvider, wait_for_init: bool + self, + provider: FeatureProvider, + *, + domain: str | None, + wait_for_init: bool, ) -> None: provider.attach(self.dispatch_event) if not hasattr(provider, "initialize"): @@ -115,22 +177,26 @@ def _initialize_provider( ) return if wait_for_init: - self._run_initialize(provider, raise_on_error=True) + self._run_initialize(provider, domain=domain, raise_on_error=True) return thread = threading.Thread( target=self._run_initialize, args=(provider,), - kwargs={"raise_on_error": False}, + kwargs={"domain": domain, "raise_on_error": False}, daemon=True, ) thread.start() def _run_initialize( - self, provider: FeatureProvider, raise_on_error: bool = False + self, + provider: FeatureProvider, + *, + domain: str | None, + raise_on_error: bool = False, ) -> None: try: - provider.initialize(self._get_evaluation_context()) + _call_initialize(provider, self._get_evaluation_context(), domain) # stale init: provider was replaced/shut down during initialize(); drop event. # Check active registration, not _provider_status, since replaced providers # remain in _provider_status until async shutdown pops them. diff --git a/tests/legacy_init_provider.py b/tests/legacy_init_provider.py new file mode 100644 index 00000000..c65337b4 --- /dev/null +++ b/tests/legacy_init_provider.py @@ -0,0 +1,15 @@ +from openfeature.evaluation_context import EvaluationContext +from openfeature.provider.no_op_provider import NoOpProvider + + +class LegacyInitProvider(NoOpProvider): + """Provider mirroring contrib overrides: initialize(context) without domain.""" + + def __init__(self) -> None: + super().__init__() + self.initialize_calls = 0 + self.last_evaluation_context: EvaluationContext | None = None + + def initialize(self, evaluation_context: EvaluationContext) -> None: + self.initialize_calls += 1 + self.last_evaluation_context = evaluation_context diff --git a/tests/provider/test_registry.py b/tests/provider/test_registry.py index f7c55712..228b558f 100644 --- a/tests/provider/test_registry.py +++ b/tests/provider/test_registry.py @@ -4,10 +4,17 @@ import pytest +from openfeature.evaluation_context import EvaluationContext, set_evaluation_context from openfeature.exception import GeneralError, ProviderFatalError from openfeature.provider import ProviderStatus -from openfeature.provider._registry import ProviderRegistry +from openfeature.provider._registry import ( + ProviderRegistry, + _callable_accepts_domain, + _is_domain_scoped, +) +from openfeature.provider.metadata import Metadata from openfeature.provider.no_op_provider import NoOpProvider +from tests.legacy_init_provider import LegacyInitProvider def test_registry_serves_noop_as_default(): @@ -240,7 +247,7 @@ def test_set_provider_returns_before_initialization_completes(): init_may_proceed = threading.Event() provider = Mock() - def slow_initialize(ctx): + def slow_initialize(ctx, domain=None): init_started.set() init_may_proceed.wait() @@ -261,7 +268,7 @@ def test_set_provider_and_wait_blocks_until_ready(): initialized = threading.Event() provider = Mock() - def tracking_initialize(ctx): + def tracking_initialize(ctx, domain=None): initialized.set() provider.initialize.side_effect = tracking_initialize @@ -290,7 +297,7 @@ def test_concurrent_set_provider_for_same_provider_initializes_once(): init_count = 0 start_gate = threading.Event() - def slow_initialize(ctx): + def slow_initialize(ctx, domain=None): nonlocal init_count # widen the window in which two threads can both observe "not bound" start_gate.wait(timeout=2) @@ -323,7 +330,7 @@ def test_provider_replaced_during_async_init_does_not_set_ready_status(): slow_provider = Mock() - def slow_initialize(ctx): + def slow_initialize(ctx, domain=None): init_started.set() init_may_proceed.wait(timeout=2) @@ -424,3 +431,236 @@ def slow_shutdown(): "stale shutdown of A clobbered the fresh registration's status" ) provider_a.detach.assert_not_called() + + +def test_initialize_receives_bound_domain(): + registry = ProviderRegistry() + provider = Mock() + + registry.set_provider("my-domain", provider, wait_for_init=True) + + provider.initialize.assert_called_once() + _, kwargs = provider.initialize.call_args + assert kwargs.get("domain") == "my-domain" + + +def test_initialize_receives_none_domain_for_default_provider(): + registry = ProviderRegistry() + provider = Mock() + + registry.set_default_provider(provider, wait_for_init=True) + + provider.initialize.assert_called_once() + _, kwargs = provider.initialize.call_args + assert kwargs.get("domain") is None + + +def test_domain_scoped_provider_rejects_second_domain(): + registry = ProviderRegistry() + provider = Mock() + provider.domain_scoped = True + + registry.set_provider("domain1", provider, wait_for_init=True) + + with pytest.raises(GeneralError) as exc_info: + registry.set_provider("domain2", provider) + + assert ( + exc_info.value.error_message + == "Cannot bind domain-scoped provider to more than one domain" + ) + assert registry.get_provider("domain1") is provider + provider.initialize.assert_called_once() + + +def test_domain_scoped_provider_rejects_default_after_domain_binding(): + registry = ProviderRegistry() + provider = Mock() + provider.domain_scoped = True + + registry.set_provider("domain", provider, wait_for_init=True) + + with pytest.raises(GeneralError): + registry.set_default_provider(provider) + + assert registry.get_default_provider() is not provider + + +def test_domain_scoped_provider_rejects_domain_after_default_binding(): + registry = ProviderRegistry() + provider = Mock() + provider.domain_scoped = True + + registry.set_default_provider(provider, wait_for_init=True) + + with pytest.raises(GeneralError): + registry.set_provider("domain", provider) + + assert registry.get_provider("domain") is registry.get_default_provider() + + +def test_initialize_skips_domain_for_legacy_signature(): + registry = ProviderRegistry() + provider = LegacyInitProvider() + + registry.set_provider("domain", provider, wait_for_init=True) + + assert provider.initialize_calls == 1 + + +def test_legacy_abstract_provider_initialize_without_domain(): + registry = ProviderRegistry() + evaluation_context = EvaluationContext("targeting_key", {"attr": "val"}) + set_evaluation_context(evaluation_context) + provider = LegacyInitProvider() + + registry.set_provider("domain", provider, wait_for_init=True) + + assert provider.initialize_calls == 1 + assert provider.last_evaluation_context == evaluation_context + assert registry.get_provider_status(provider) == ProviderStatus.READY + + +def test_initialize_does_not_retry_when_domain_aware_provider_raises_type_error(): + registry = ProviderRegistry() + + class BrokenProvider: + def __init__(self): + self.call_count = 0 + + def attach(self, on_emit): + pass + + def detach(self): + pass + + def get_metadata(self): + return Metadata(name="broken") + + def initialize(self, evaluation_context, domain=None): + self.call_count += 1 + raise TypeError("configuration error") + + provider = BrokenProvider() + with pytest.raises(TypeError, match="configuration error"): + registry.set_provider("domain", provider, wait_for_init=True) # type: ignore[arg-type] + + assert provider.call_count == 1 + provider.detach() + + +def test_is_domain_scoped_uses_class_level_bool_attribute(): + class ClassScopedProvider: + domain_scoped = True + + assert _is_domain_scoped(ClassScopedProvider()) is True # type: ignore[arg-type] + + +def test_is_domain_scoped_uses_property(): + class PropertyScopedProvider: + @property + def domain_scoped(self) -> bool: + return True + + assert _is_domain_scoped(PropertyScopedProvider()) is True # type: ignore[arg-type] + + +def test_is_domain_scoped_rejects_truthy_non_bool_values(): + class StrScopedProvider: + def __init__(self) -> None: + self.domain_scoped = "us-east" + + assert _is_domain_scoped(StrScopedProvider()) is False # type: ignore[arg-type] + + +def test_domain_scoped_property_provider_rejects_second_domain(): + registry = ProviderRegistry() + + class PropertyScopedProvider(LegacyInitProvider): + @property + def domain_scoped(self) -> bool: + return True + + provider = PropertyScopedProvider() + registry.set_provider("domain1", provider, wait_for_init=True) + + with pytest.raises(GeneralError) as exc_info: + registry.set_provider("domain2", provider) + + assert ( + exc_info.value.error_message + == "Cannot bind domain-scoped provider to more than one domain" + ) + + +def test_reregistering_same_provider_on_same_domain_reinitializes(): + registry = ProviderRegistry() + provider = Mock() + init_count = 0 + + def counting_initialize(evaluation_context, domain=None): + nonlocal init_count + init_count += 1 + + provider.initialize.side_effect = counting_initialize + + registry.set_provider("domain", provider, wait_for_init=True) + registry.set_provider("domain", provider, wait_for_init=True) + + assert init_count == 2 + + +def test_reregistering_same_provider_after_failed_init_retries(): + registry = ProviderRegistry() + provider = Mock() + attempts = 0 + + def flaky_initialize(evaluation_context, domain=None): + nonlocal attempts + attempts += 1 + if attempts == 1: + raise ProviderFatalError() + + provider.initialize.side_effect = flaky_initialize + + with pytest.raises(ProviderFatalError): + registry.set_provider("domain", provider, wait_for_init=True) + + assert registry.get_provider_status(provider) == ProviderStatus.FATAL + + registry.set_provider("domain", provider, wait_for_init=True) + + assert attempts == 2 + assert registry.get_provider_status(provider) == ProviderStatus.READY + + +def test_callable_accepts_domain_returns_false_for_uninspectable_callable(): + assert _callable_accepts_domain(object()) is False # type: ignore[arg-type] + + +def test_callable_accepts_domain_returns_true_for_kwargs_signature(): + def initialize(evaluation_context, **kwargs): + kwargs["domain"] = "ignored" + + assert _callable_accepts_domain(initialize) is True + initialize(EvaluationContext()) + + +def test_provider_without_initialize_is_ready_immediately(): + registry = ProviderRegistry() + + class NoInitProvider: + def attach(self, on_emit): + pass + + def detach(self): + pass + + def get_metadata(self): + return Metadata(name="no-init") + + provider = NoInitProvider() + registry.set_provider("domain", provider, wait_for_init=True) # type: ignore[arg-type] + + assert registry.get_provider_status(provider) == ProviderStatus.READY + provider.detach() diff --git a/tests/test_api.py b/tests/test_api.py index cdb077fe..483989df 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -31,6 +31,7 @@ get_transaction_context, set_transaction_context_propagator, ) +from tests.legacy_init_provider import LegacyInitProvider def wait_for_mock_call(mock: MagicMock, timeout: float = 1.0) -> None: @@ -84,7 +85,7 @@ def test_should_invoke_provider_initialize_function_on_newly_registered_provider set_provider_and_wait(provider) # Then - provider.initialize.assert_called_with(evaluation_context) + provider.initialize.assert_called_with(evaluation_context, domain=None) def test_should_invoke_provider_shutdown_function_once_provider_is_no_longer_in_use(): @@ -179,6 +180,46 @@ def test_should_provide_a_function_to_bind_provider_through_domain(): assert test_client.domain == "test" +def test_should_pass_domain_to_provider_initialize(): + evaluation_context = EvaluationContext("targeting_key", {"attr1": "val1"}) + provider = MagicMock(spec=FeatureProvider) + + set_evaluation_context(evaluation_context) + set_provider_and_wait(provider, domain="test") + + provider.initialize.assert_called_with(evaluation_context, domain="test") + + +def test_legacy_initialize_provider_via_api(): + evaluation_context = EvaluationContext("targeting_key", {"attr1": "val1"}) + provider = LegacyInitProvider() + + set_evaluation_context(evaluation_context) + set_provider_and_wait(provider, domain="test") + + assert provider.initialize_calls == 1 + assert provider.last_evaluation_context == evaluation_context + assert provider_registry.get_provider_status(provider) == ProviderStatus.READY + + client = get_client("test") + assert client.get_boolean_value("flag", True) is True + + +def test_should_reject_domain_scoped_provider_bound_to_second_domain(): + provider = MagicMock(spec=FeatureProvider) + provider.domain_scoped = True + set_provider_and_wait(provider, "foo") + + with pytest.raises(GeneralError) as exc_info: + set_provider(provider, "bar") + + assert ( + exc_info.value.error_message + == "Cannot bind domain-scoped provider to more than one domain" + ) + provider.initialize.assert_called_once() + + def test_should_not_initialize_provider_already_bound_to_another_domain(): # Given provider = MagicMock(spec=FeatureProvider) @@ -426,7 +467,7 @@ def test_set_provider_returns_before_initialization_completes(): provider = MagicMock(spec=FeatureProvider) - def slow_initialize(ctx): + def slow_initialize(ctx, domain=None): init_started.set() init_may_proceed.wait() @@ -446,7 +487,7 @@ def test_provider_status_is_not_ready_during_async_initialization(): init_may_proceed = threading.Event() provider = MagicMock(spec=FeatureProvider) - def slow_initialize(ctx): + def slow_initialize(ctx, domain=None): init_may_proceed.wait() provider.initialize.side_effect = slow_initialize @@ -467,7 +508,7 @@ def test_set_provider_and_wait_blocks_until_initialization_completes(): initialized = threading.Event() provider = MagicMock(spec=FeatureProvider) - def slow_initialize(ctx): + def slow_initialize(ctx, domain=None): initialized.set() provider.initialize.side_effect = slow_initialize @@ -495,7 +536,7 @@ def test_set_provider_swallows_error_and_emits_provider_error_event(): provider = MagicMock(spec=FeatureProvider) error_fired = threading.Event() - def failing_initialize(ctx): + def failing_initialize(ctx, domain=None): raise ProviderFatalError() provider.initialize.side_effect = failing_initialize