Skip to content
16 changes: 14 additions & 2 deletions openfeature/provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def attach(

def detach(self) -> None: ...

def initialize(self, evaluation_context: EvaluationContext) -> None: ...
def initialize(
self,
evaluation_context: EvaluationContext,
domain: str | None = None,
) -> None: ...

def shutdown(self) -> None: ...

Expand Down Expand Up @@ -140,9 +144,17 @@ def detach(self) -> None:
if hasattr(self, "_on_emit"):
del self._on_emit

def initialize(self, evaluation_context: EvaluationContext) -> None:
def initialize(
self,
evaluation_context: EvaluationContext,
domain: str | None = None,
) -> None:
pass

@property
def domain_scoped(self) -> bool:
return False

def shutdown(self) -> None:
pass

Expand Down
82 changes: 74 additions & 8 deletions openfeature/provider/_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import inspect
import threading
from collections.abc import Callable

from openfeature._event_support import run_handlers_for_provider
from openfeature.evaluation_context import EvaluationContext, get_evaluation_context
Expand All @@ -11,6 +13,36 @@
from openfeature.provider.no_op_provider import NoOpProvider


def _is_domain_scoped(provider: FeatureProvider) -> bool:
return getattr(provider, "domain_scoped", False) is True


def _callable_accepts_domain(callable_obj: Callable[..., object]) -> bool:
try:
signature = inspect.signature(callable_obj)
except (TypeError, ValueError):
return False
return "domain" in signature.parameters or any(
param.kind == inspect.Parameter.VAR_KEYWORD
for param in signature.parameters.values()
)


def _initialize_accepts_domain(provider: FeatureProvider) -> bool:
return _callable_accepts_domain(provider.initialize)


def _call_initialize(
provider: FeatureProvider,
evaluation_context: EvaluationContext,
domain: str | None,
) -> None:
if _initialize_accepts_domain(provider):
provider.initialize(evaluation_context, domain=domain)
else:
provider.initialize(evaluation_context)


class ProviderRegistry:
_default_provider: FeatureProvider
_providers: dict[str, FeatureProvider]
Expand All @@ -36,17 +68,20 @@ def set_provider(
old_provider: FeatureProvider | None = None
needs_init = False
with self._lock:
self._reject_domain_scoped_rebind(provider, domain)
old_provider = self._providers.get(domain)
self._providers[domain] = provider
already_bound = provider is self._default_provider or any(
p is provider for d, p in self._providers.items() if d != domain
)
self._providers[domain] = provider
if not already_bound:
needs_init = True
self._provider_status[provider] = ProviderStatus.NOT_READY

if needs_init:
self._initialize_provider(provider, wait_for_init=wait_for_init)
self._initialize_provider(
provider, domain=domain, wait_for_init=wait_for_init
)

# old-provider shutdown is always async so a hanging shutdown() cannot
# block set_provider.
Expand All @@ -67,6 +102,7 @@ def set_default_provider(
old_provider: FeatureProvider | None = None
needs_init = False
with self._lock:
self._reject_domain_scoped_rebind(provider, None)
old_provider = self._default_provider
self._default_provider = provider
if (
Expand All @@ -77,7 +113,9 @@ def set_default_provider(
self._provider_status[provider] = ProviderStatus.NOT_READY

if needs_init:
self._initialize_provider(provider, wait_for_init=wait_for_init)
self._initialize_provider(
provider, domain=None, wait_for_init=wait_for_init
)

if old_provider is not None and old_provider is not provider:
self._shutdown_if_unused(old_provider)
Expand All @@ -104,8 +142,32 @@ def shutdown(self) -> None:
def _get_evaluation_context(self) -> EvaluationContext:
return get_evaluation_context()

def _provider_bindings(self, provider: FeatureProvider) -> list[str | None]:
bindings: list[str | None] = []
if provider is self._default_provider:
bindings.append(None)
bindings.extend(d for d, p in self._providers.items() if p is provider)
return bindings

def _reject_domain_scoped_rebind(
self, provider: FeatureProvider, domain: str | None
) -> None:
if not _is_domain_scoped(provider):
return
bindings = self._provider_bindings(provider)
if bindings and domain not in bindings:
raise GeneralError(
error_message=(
"Cannot bind domain-scoped provider to more than one domain"
)
)

def _initialize_provider(
self, provider: FeatureProvider, wait_for_init: bool
self,
provider: FeatureProvider,
*,
domain: str | None,
wait_for_init: bool,
) -> None:
provider.attach(self.dispatch_event)
if not hasattr(provider, "initialize"):
Expand All @@ -115,22 +177,26 @@ def _initialize_provider(
)
return
if wait_for_init:
self._run_initialize(provider, raise_on_error=True)
self._run_initialize(provider, domain=domain, raise_on_error=True)
return

thread = threading.Thread(
target=self._run_initialize,
args=(provider,),
kwargs={"raise_on_error": False},
kwargs={"domain": domain, "raise_on_error": False},
daemon=True,
)
thread.start()

def _run_initialize(
self, provider: FeatureProvider, raise_on_error: bool = False
self,
provider: FeatureProvider,
*,
domain: str | None,
raise_on_error: bool = False,
) -> None:
try:
provider.initialize(self._get_evaluation_context())
_call_initialize(provider, self._get_evaluation_context(), domain)
# stale init: provider was replaced/shut down during initialize(); drop event.
# Check active registration, not _provider_status, since replaced providers
# remain in _provider_status until async shutdown pops them.
Expand Down
15 changes: 15 additions & 0 deletions tests/legacy_init_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from openfeature.evaluation_context import EvaluationContext
from openfeature.provider.no_op_provider import NoOpProvider


class LegacyInitProvider(NoOpProvider):
"""Provider mirroring contrib overrides: initialize(context) without domain."""

def __init__(self) -> None:
super().__init__()
self.initialize_calls = 0
self.last_evaluation_context: EvaluationContext | None = None

def initialize(self, evaluation_context: EvaluationContext) -> None:
self.initialize_calls += 1
self.last_evaluation_context = evaluation_context
Loading
Loading