-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathcontainer.py
More file actions
134 lines (110 loc) · 5.26 KB
/
container.py
File metadata and controls
134 lines (110 loc) · 5.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import threading
import typing
import typing_extensions
from modern_di import errors, types
from modern_di.group import Group
from modern_di.providers.abstract import AbstractProvider
from modern_di.providers.container_provider import container_provider
from modern_di.registries.cache_registry import CacheRegistry
from modern_di.registries.context_registry import ContextRegistry
from modern_di.registries.overrides_registry import OverridesRegistry
from modern_di.registries.providers_registry import ProvidersRegistry
from modern_di.scope import Scope
class Container:
    """Dependency-injection container bound to a single ``Scope``.

    A container resolves providers for its own scope and links to a parent container
    of a lower scope. Containers in one parent chain share the root container's
    providers registry and overrides registry. Each container owns its own cache
    registry and context registry.
    """

    __slots__ = (
        "cache_registry",
        "context_registry",
        "lock",
        "overrides_registry",
        "parent_container",
        "providers_registry",
        "scope",
    )

    def __init__(
        self,
        scope: Scope = Scope.APP,
        parent_container: typing.Optional["typing_extensions.Self"] = None,
        context: dict[type[typing.Any], typing.Any] | None = None,
        groups: list[type[Group]] | None = None,
        use_lock: bool = True,
    ) -> None:
        """Create a container.

        Args:
            scope: Scope served by this container.
            parent_container: Container of an outer scope. When given, its providers
                and overrides registries are reused instead of creating new ones.
            context: Initial context objects keyed by type. ``None`` means empty.
            groups: Provider groups whose providers are added to the providers
                registry. For a child container that registry is shared with the parent.
            use_lock: Whether to create a ``threading.Lock`` and store it on ``self.lock``.
        """
        # The lock is not acquired anywhere in this class; presumably providers use
        # ``container.lock`` when resolving -- NOTE(review): confirm in provider code.
        self.lock = threading.Lock() if use_lock else None
        self.scope = scope
        self.parent_container = parent_container
        self.cache_registry = CacheRegistry()
        self.context_registry = ContextRegistry(context=context or {})
        self.providers_registry: ProvidersRegistry
        self.overrides_registry: OverridesRegistry
        if parent_container is not None:
            # Child containers share registries with their parent (and so with the root).
            self.providers_registry = parent_container.providers_registry
            self.overrides_registry = parent_container.overrides_registry
        else:
            self.providers_registry = ProvidersRegistry()
            # Make the container itself resolvable as a dependency.
            self.providers_registry.register(type(self), container_provider)
            self.overrides_registry = OverridesRegistry()

        if groups:
            for one_group in groups:
                self.providers_registry.add_providers(*one_group.get_providers())

    def build_child_container(
        self, context: dict[type[typing.Any], typing.Any] | None = None, scope: Scope | None = None
    ) -> "typing_extensions.Self":
        """Create a container of a higher scope that uses this one as its parent.

        Args:
            context: Initial context objects for the child container.
            scope: Scope of the child. When omitted, the scope whose value is one
                greater than this container's scope is used.

        Returns:
            The new child container. It uses a lock only if this container does.

        Raises:
            RuntimeError: If ``scope`` is not greater than this container's scope, or if
                ``scope`` is omitted and no scope with the next value exists.
        """
        scope_type = type(self.scope)
        if scope is not None and scope <= self.scope:
            raise RuntimeError(
                errors.CONTAINER_SCOPE_IS_LOWER_ERROR.format(
                    parent_scope=self.scope.name,
                    child_scope=scope.name,
                    allowed_scopes=[x.name for x in scope_type if x > self.scope],
                )
            )

        if scope is None:
            try:
                scope = scope_type(self.scope.value + 1)
            except ValueError as exc:
                raise RuntimeError(
                    errors.CONTAINER_MAX_SCOPE_REACHED_ERROR.format(parent_scope=self.scope.name)
                ) from exc

        # Propagate the locking choice so a lock-free container yields lock-free children.
        return self.__class__(
            scope=scope,
            parent_container=self,
            context=context,
            use_lock=self.lock is not None,
        )

    def find_container(self, scope: Scope) -> "typing_extensions.Self":
        """Return this container or the nearest ancestor whose scope equals ``scope``.

        Raises:
            RuntimeError: If ``scope`` is greater than this container's scope, or if no
                container in the parent chain has exactly ``scope``.
        """
        container = self
        if container.scope < scope:
            raise RuntimeError(
                errors.CONTAINER_NOT_INITIALIZED_SCOPE_ERROR.format(
                    provider_scope=scope.name, container_scope=self.scope.name
                )
            )

        while container.scope > scope and container.parent_container is not None:
            container = container.parent_container

        if container.scope != scope:
            raise RuntimeError(errors.CONTAINER_SCOPE_IS_SKIPPED_ERROR.format(provider_scope=scope.name))

        return container

    def resolve(self, dependency_type: type[types.T]) -> types.T:
        """Resolve an object of ``dependency_type`` via its registered provider.

        Raises:
            RuntimeError: If the providers registry has no provider for ``dependency_type``.
        """
        provider = self.providers_registry.find_provider(dependency_type)
        if not provider:
            raise RuntimeError(errors.CONTAINER_MISSING_PROVIDER_ERROR.format(provider_type=dependency_type))

        return self.resolve_provider(provider)

    def resolve_provider(self, provider: "AbstractProvider[types.T]") -> types.T:
        """Return the override registered for ``provider`` if any, else resolve it."""
        if (override := self.overrides_registry.fetch_override(provider.provider_id)) is not types.UNSET:
            return typing.cast(types.T, override)

        return typing.cast(types.T, provider.resolve(self))

    def validate_provider(self, provider: "AbstractProvider[types.T]") -> types.T:
        """Delegate to ``provider.validate`` with this container."""
        return typing.cast(types.T, provider.validate(self))

    async def close_async(self) -> None:
        """Close this container's cache asynchronously.

        Only the root container clears overrides, because the overrides registry is
        shared by every container in the chain.
        """
        if self.parent_container is None:
            self.overrides_registry.reset_override()
        await self.cache_registry.close_async()

    def close_sync(self) -> None:
        """Close this container's cache synchronously.

        Only the root container clears overrides, because the overrides registry is
        shared by every container in the chain.
        """
        if self.parent_container is None:
            self.overrides_registry.reset_override()
        self.cache_registry.close_sync()

    def override(self, provider: AbstractProvider[types.T], override_object: object) -> None:
        """Make ``provider`` resolve to ``override_object`` (shared across the chain)."""
        self.overrides_registry.override(provider.provider_id, override_object)

    def reset_override(self, provider: AbstractProvider[types.T] | None = None) -> None:
        """Remove the override for ``provider``, or pass ``None`` to the registry when omitted."""
        self.overrides_registry.reset_override(provider.provider_id if provider else None)

    def set_context(self, context_type: type[types.T], obj: types.T) -> None:
        """Store ``obj`` under ``context_type`` in this container's context registry."""
        self.context_registry.set_context(context_type, obj)

    def __deepcopy__(self, *_: object, **__: object) -> "typing_extensions.Self":
        """Prevent cloning object."""
        return self

    def __copy__(self, *_: object, **__: object) -> "typing_extensions.Self":
        """Prevent cloning object."""
        return self