diff --git a/MANIFEST.in b/MANIFEST.in index df74e1c..8ae9c24 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,5 @@ include *.txt *.rst *.cfg *.py *.ini *.toml *.yaml exclude .installed.cfg -recursive-include reg *.py +recursive-include reg *.py py.typed recursive-include tests *.py -recursive-include doc *.rst Makefile *.py *.bat -include .coveragerc \ No newline at end of file +recursive-include doc *.rst Makefile *.py *.bat \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f105828..191629d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,10 +37,18 @@ test = ["pytest >= 8", "pytest-env", "sphinx"] docs = ["sphinx"] coverage = ["pytest-cov"] lint = ["black", "flake8", "flake8-pyproject"] +# NOTE: sphinx 9.1 introduced PEP-695 syntax into the code-base, so we +# can't type check in pre-3.12 mode this way, which we need to, +# as long as we still support these versions. +mypy = ["mypy", "pytest", "sphinx < 9.1.0"] +pyright = ["pyright", "pytest", "sphinx < 9.1.0"] [tool.setuptools.packages] find = {} +[tool.setuptools.package-data] +reg = ["py.typed"] + [tool.setuptools.dynamic] readme = {file = ["README.rst", "CHANGES.txt"]} @@ -51,7 +59,7 @@ addopts = ["-vv"] env = ["RUN_ENV=test"] [tool.coverage.run] -omit = ["reg/tests/*"] +omit = ["reg/tests/*", "reg/types.py"] source = ["reg"] [tool.coverage.report] @@ -59,9 +67,18 @@ show_missing = true [tool.flake8] show-source = true -ignore = ["E203", "E731", "W503"] +ignore = ["E203", "E301", "E501", "E704", "E731", "W503"] max-line-length = 88 +[tool.mypy] +python_version = "3.10" +strict = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = "reg.tests.fixtures.*" +disallow_untyped_defs = false + [tool.tox] requires = ["tox>=4"] env_list = [ @@ -75,11 +92,13 @@ env_list = [ "pre-commit", "docs", "perf", + "mypy", + "pyright" ] skip_missing_interpreters = true [tool.tox.gh.python] -"3.10" = ["py310", "perf"] +"3.10" = ["py310", "mypy", "pyright", "perf"] "3.11" = ["py311"] "3.12" = ["py312"] "3.13" = ["py313"] @@ -119,3 +138,25 @@ extras = [] commands = [ ["python", "{toxinidir}/tox_perf.py"], ] + +[tool.tox.env.mypy] +base_python = ["python3"] +extras = ["mypy"] +commands = [ + ["mypy", "-p", "reg", "--python-version", "3.10"], + ["mypy", "-p", "reg", "--python-version", "3.11"], + ["mypy", "-p", "reg", "--python-version", "3.12"], + ["mypy", "-p", "reg", "--python-version", "3.13"], + ["mypy", "-p", "reg", "--python-version", "3.14"], +] + +[tool.tox.env.pyright] +base_python = ["python3"] +extras = ["pyright"] +commands = [ + ["pyright", "reg", "--pythonversion", "3.10"], + ["pyright", "reg", "--pythonversion", "3.11"], + ["pyright", "reg", "--pythonversion", "3.12"], + ["pyright", "reg", "--pythonversion", "3.13"], + ["pyright", "reg", "--pythonversion", "3.14"], +] diff --git a/reg/__init__.py b/reg/__init__.py index cb4eb46..30f4205 100644 --- a/reg/__init__.py +++ b/reg/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa from .dispatch import dispatch, Dispatch, LookupEntry from .context import ( dispatch_method, @@ -17,3 +16,23 @@ match_class, ) from .cache import DictCachingKeyLookup, LruCachingKeyLookup + +__all__ = ( + "ClassIndex", + "DictCachingKeyLookup", + "Dispatch", + "DispatchMethod", + "KeyIndex", + "LookupEntry", + "LruCachingKeyLookup", + "Predicate", + "RegistrationError", + "arginfo", + "clean_dispatch_methods", + "dispatch", + "dispatch_method", + "match_class", + "match_instance", + "match_key", + "methodify", +) diff --git a/reg/arginfo.py b/reg/arginfo.py index 42c3ee7..456e918 100644 --- a/reg/arginfo.py +++ b/reg/arginfo.py @@ -1,21 +1,42 @@ +from __future__ import annotations + import inspect import sys +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from collections.abc import Callable + from .types import ArgInfo + if sys.version_info < (3, 14): - def get_signature(callable): # pragma: no cover + def get_signature( + callable: Callable[..., Any], + ) -> inspect.Signature: # pragma: no cover """A compatibility wrapper for `inspect.signature`.""" return inspect.signature(callable) else: from annotationlib import Format # pragma: no cover - def get_signature(callable): # pragma: no cover + def get_signature( + callable: Callable[..., Any], + ) -> inspect.Signature: # pragma: no cover """A compatibility wrapper for `inspect.signature`.""" return inspect.signature(callable, annotation_format=Format.FORWARDREF) -def arginfo(callable): +# NOTE: This no-op decorator lets type checkers know about the extra +# attributes we add to the arginfo callable +def _coerce_to_arginfo( + f: Callable[[Callable[..., Any]], inspect.FullArgSpec | None], +) -> ArgInfo: + return cast("ArgInfo", f) + + +@_coerce_to_arginfo +def arginfo(callable: Callable[..., Any]) -> inspect.FullArgSpec | None: """Get information about the arguments of a callable. Returns a :class:`inspect.FullArgSpec` object as for @@ -43,10 +64,11 @@ def arginfo(callable): except KeyError: # Try to get __call__ function from the cache. try: - return arginfo._cache[callable.__call__] + return arginfo._cache[callable.__call__] # type: ignore except (AttributeError, KeyError): pass + cache_key: Callable[..., Any] if inspect.isfunction(callable): cache_key = callable elif inspect.ismethod(callable): @@ -63,7 +85,7 @@ def arginfo(callable): # Since arbitrary callable objects may not be hashable # we instead retrieve their call method, which should be try: - cache_key = callable.__call__ + cache_key = callable.__call__ # type: ignore except AttributeError: return None @@ -111,17 +133,17 @@ def arginfo(callable): return result -def is_cached(callable): +def is_cached(callable: Callable[..., Any]) -> bool: if callable in arginfo._cache: return True - return callable.__call__ in arginfo._cache + return callable.__call__ in arginfo._cache # type: ignore arginfo._cache = {} arginfo.is_cached = is_cached -def fake_empty_init(): +def fake_empty_init() -> None: pass # pragma: nocoverage @@ -129,4 +151,4 @@ class Dummy: pass -WRAPPER_DESCRIPTOR = Dummy.__init__ +WRAPPER_DESCRIPTOR: object = Dummy.__init__ diff --git a/reg/cache.py b/reg/cache.py index 18d706b..fc1a224 100644 --- a/reg/cache.py +++ b/reg/cache.py @@ -1,18 +1,35 @@ -from repoze.lru import lru_cache +from __future__ import annotations +from repoze.lru import lru_cache # type: ignore +from typing import TYPE_CHECKING, Any, Generic -class Cache(dict): +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from typing_extensions import TypeVar + from .types import KeyLookup + + _ValueT = TypeVar("_ValueT", default=Callable[..., Any]) +else: + from typing import TypeVar + + _ValueT = TypeVar("_ValueT") + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +class Cache(dict[_KT, _VT]): """A dict to cache a function.""" - def __init__(self, func): + def __init__(self, func: Callable[[_KT], _VT]) -> None: self.func = func - def __missing__(self, key): + def __missing__(self, key: _KT) -> _VT: self[key] = result = self.func(key) return result -class DictCachingKeyLookup: +class DictCachingKeyLookup(Generic[_ValueT]): """A key lookup that caches. Implements the read-only API of :class:`reg.PredicateRegistry` using @@ -28,14 +45,32 @@ class DictCachingKeyLookup: """ - def __init__(self, key_lookup): + def __init__(self, key_lookup: KeyLookup[_ValueT]) -> None: self.key_lookup = key_lookup - self.component = Cache(key_lookup.component).__getitem__ - self.fallback = Cache(key_lookup.fallback).__getitem__ - self.all = Cache(lambda key: list(key_lookup.all(key))).__getitem__ + self.component = Cache(key_lookup.component).__getitem__ # type: ignore + self.fallback = Cache(key_lookup.fallback).__getitem__ # type: ignore + + def _all(key: Sequence[Any]) -> list[_ValueT]: + return list(key_lookup.all(key)) + + self.all = Cache(_all).__getitem__ # type: ignore + + if TYPE_CHECKING: + # NOTE: For pyright's sake we declare these callable instance attributes + # as methods, even though they're not, since pyright does not seem + # to be able to match protocols against them. mypy can deal with + # it just fine + def component(self, key: Sequence[Any], /) -> _ValueT | None: + raise NotImplementedError + + def fallback(self, key: Sequence[Any], /) -> _ValueT | None: + raise NotImplementedError + + def all(self, key: Sequence[Any], /) -> list[_ValueT]: + raise NotImplementedError -class LruCachingKeyLookup: +class LruCachingKeyLookup(Generic[_ValueT]): """A key lookup that caches. Implements the read-only API of :class:`reg.PredicateRegistry`, using @@ -57,12 +92,26 @@ class LruCachingKeyLookup: def __init__( self, - key_lookup, - component_cache_size, - all_cache_size, - fallback_cache_size, - ): + key_lookup: KeyLookup[_ValueT], + component_cache_size: int, + all_cache_size: int, + fallback_cache_size: int, + ) -> None: self.key_lookup = key_lookup - self.component = lru_cache(component_cache_size)(key_lookup.component) - self.fallback = lru_cache(fallback_cache_size)(key_lookup.fallback) - self.all = lru_cache(all_cache_size)(lambda key: list(key_lookup.all(key))) + self.component = lru_cache(component_cache_size)(key_lookup.component) # type: ignore + self.fallback = lru_cache(fallback_cache_size)(key_lookup.fallback) # type: ignore + self.all = lru_cache(all_cache_size)(lambda key: list(key_lookup.all(key))) # type: ignore + + if TYPE_CHECKING: + # NOTE: For pyright's sake we declare these callable instance attributes + # as methods, even though they're not, since pyright does not seem + # to be able to match protocols against them. mypy can deal with + # it just fine + def component(self, key: Sequence[Any], /) -> _ValueT | None: + raise NotImplementedError + + def fallback(self, key: Sequence[Any], /) -> _ValueT | None: + raise NotImplementedError + + def all(self, key: Sequence[Any], /) -> list[_ValueT]: + raise NotImplementedError diff --git a/reg/context.py b/reg/context.py index 96bfaa7..3aabf92 100644 --- a/reg/context.py +++ b/reg/context.py @@ -1,10 +1,36 @@ +from __future__ import annotations + import inspect from types import MethodType -from .dispatch import dispatch, Dispatch, format_signature, execute +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + Generic, + NoReturn as Never, + ParamSpec, + TypeVar, + cast, + overload, +) +from .dispatch import dispatch, Dispatch, format_signature, execute, identity from .arginfo import arginfo +if TYPE_CHECKING: + from collections.abc import Callable + from .dispatch import LookupEntry + from .predicate import Predicate, PredicateRegistry + from .types import BoundDispatchMethodCall, DispatchMethodCall, KeyLookup + +_T = TypeVar("_T") +_T1 = TypeVar("_T1") +_R = TypeVar("_R") +_R1 = TypeVar("_R1") +_P = ParamSpec("_P") +_P1 = ParamSpec("_P1") + -class dispatch_method(dispatch): +class dispatch_method(dispatch, Generic[_P, _T, _R]): """Decorator to make a method on a context class dispatch. This takes the predicates to dispatch on as zero or more parameters. @@ -28,16 +54,48 @@ class in which this decorator is used. It is invoked the first """ - def __init__(self, *predicates, **kw): - self.first_invocation_hook = kw.pop("first_invocation_hook", lambda x: None) - super().__init__(*predicates, **kw) - self._cache = {} - - def __call__(self, callable): + callable: Callable[Concatenate[_T, _P], _R] + + def __init__( + # NOTE: When we first create this object, we don't know yet what kind + # of object and callable we're binding to, so to avoid unknown + # type errors we solve the type vars here and return a new + # type from `__call__`. It would be more robust if this was + # a factory for `dispatch_method` decorators instead, but it's + # a little difficult to justify changing at this point. + self: dispatch_method[Any, Any, Any], + *predicates: str | Predicate, + first_invocation_hook: Callable[[Any], object] = lambda x: None, + get_key_lookup: Callable[[PredicateRegistry], KeyLookup] = identity, + # NOTE: We keep allowing arbitrary keyword arguments at runtime + # for now, but type checkers should emit an error for these. + **kw: Never, + ) -> None: + self.first_invocation_hook = first_invocation_hook + super().__init__(*predicates, get_key_lookup=get_key_lookup, **kw) + self._cache: dict[type[_T] | None, DispatchMethodCall[_P, _T, _R]] = {} + + # NOTE: This needs to be able to modify the bound type vars, so we + # use different type vars, with a separate scope. + def __call__( # type: ignore[override] + self: dispatch_method[Any, Any, Any], + callable: Callable[Concatenate[_T1, _P1], _R1], + ) -> dispatch_method[_P1, _T1, _R1]: self.callable = callable return self - def __get__(self, obj, type=None): + @overload + def __get__( + self, obj: _T, type: type[_T] | None = None + ) -> BoundDispatchMethodCall[_P, _T, _R]: ... + @overload + def __get__( + self, obj: None, type: type[_T] | None = None + ) -> DispatchMethodCall[_P, _T, _R]: ... + + def __get__( + self, obj: _T | None, type: type[_T] | None = None + ) -> BoundDispatchMethodCall[_P, _T, _R] | DispatchMethodCall[_P, _T, _R]: # we get the method from the cache # this guarantees that we distinguish between dispatches # on a per class basis, and on the name of the method @@ -63,7 +121,7 @@ def __get__(self, obj, type=None): self.first_invocation_hook(obj) # if we access the instance, we simulate binding it - bound = MethodType(dispatch, obj) + bound = cast("BoundDispatchMethodCall[_P, _T, _R]", MethodType(dispatch, obj)) # we store it on the instance, so that next time we # access this, we do not hit the descriptor anymore # but return the bound dispatch function directly @@ -71,18 +129,37 @@ def __get__(self, obj, type=None): return bound -class DispatchMethod(Dispatch): - def by_args(self, *args, **kw): +class DispatchMethod(Dispatch[Concatenate[_T, _P], _R], Generic[_P, _T, _R]): + call: DispatchMethodCall[ + _P, _T, _R + ] # pyright: ignore[reportIncompatibleVariableOverride] + + def by_args(self, *args: _P.args, **kw: _P.kwargs) -> LookupEntry[Callable[Concatenate[_T, _P], _R]]: # type: ignore[override] """Lookup an implementation by invocation arguments. :param args: positional arguments used in invocation. :param kw: named arguments used in invocation. :returns: a :class:`reg.LookupEntry`. """ - return super().by_args(None, *args, **kw) + return super().by_args(None, *args, **kw) # type: ignore[arg-type] + + +@overload +def methodify( + func: Callable[_P, _T], selfname: None = None +) -> Callable[Concatenate[Any, _P], _T]: ... + + +# NOTE: For this overload we no longer know what the signature will look +# like since it might be concatenated or not, so we have to discard +# it if we want to preserve the gradual guarantee +@overload +def methodify(func: Callable[..., _T], selfname: str) -> Callable[..., _T]: ... -def methodify(func, selfname=None): +def methodify( + func: Callable[_P, _T], selfname: str | None = None +) -> Callable[Concatenate[Any, _P], _T] | Callable[..., _T]: """Turn a function into a method, if needed. If ``selfname`` is not specified, wrap the function so that it @@ -125,10 +202,10 @@ def methodify(func, selfname=None): code_source = code_template.format( signature=format_signature(args), selfname=selfname or "_" ) - return execute(code_source, _func=func)["wrapper"] + return execute(code_source, _func=func)["wrapper"] # type: ignore[no-any-return] -def clean_dispatch_methods(cls): +def clean_dispatch_methods(cls: type[object]) -> None: """For a given class clean all dispatch methods. This resets their registry to the original state using @@ -139,4 +216,4 @@ def clean_dispatch_methods(cls): for name in dir(cls): attr = getattr(cls, name) if inspect.isfunction(attr) and hasattr(attr, "clean"): - attr.clean() + attr.clean() # pyright: ignore[reportFunctionMemberAccess] diff --git a/reg/dispatch.py b/reg/dispatch.py index 1e169d2..7497002 100644 --- a/reg/dispatch.py +++ b/reg/dispatch.py @@ -1,10 +1,36 @@ +from __future__ import annotations + from functools import partial, wraps -from collections import namedtuple +from typing import ( + TYPE_CHECKING, + Any, + Generic, + NamedTuple, + NoReturn as Never, + ParamSpec, + TypeVar, + cast, + overload, +) from .predicate import match_instance from .predicate import PredicateRegistry from .arginfo import arginfo from .error import RegistrationError +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from inspect import FullArgSpec + from .predicate import Predicate + from .types import DispatchCall, GetKeyLookup, KeyLookup + +_T = TypeVar("_T") +_F = TypeVar("_F", bound="Callable[..., Any]") +_P = ParamSpec("_P") + + +def identity(registry: PredicateRegistry) -> PredicateRegistry: + return registry + class dispatch: """Decorator to make a function dispatch based on its arguments. @@ -29,50 +55,56 @@ class dispatch: """ - def __init__(self, *predicates, **kw): + def __init__( + self, + *predicates: str | Predicate, + get_key_lookup: GetKeyLookup = identity, + # NOTE: We keep allowing arbitrary keyword arguments at runtime + # for now, but type checkers should emit an error for these. + **kw: Never, + ) -> None: self.predicates = [self._make_predicate(predicate) for predicate in predicates] - self.get_key_lookup = kw.pop("get_key_lookup", identity) + self.get_key_lookup = get_key_lookup - def _make_predicate(self, predicate): + def _make_predicate(self, predicate: str | Predicate) -> Predicate: if isinstance(predicate, str): return match_instance(predicate) return predicate - def __call__(self, callable): + def __call__(self, callable: Callable[_P, _T]) -> DispatchCall[_P, _T]: return Dispatch(self.predicates, callable, self.get_key_lookup).call -def identity(registry): - return registry +class _LookupEntry(NamedTuple): + lookup: KeyLookup + key: tuple[Any, ...] -class LookupEntry(namedtuple("LookupEntry", "lookup key")): +class LookupEntry(_LookupEntry, Generic[_F]): """The dispatch data associated to a key.""" - __slots__ = () - @property - def component(self): + def component(self) -> _F | None: """The function to dispatch to, excluding fallbacks.""" - return self.lookup.component(self.key) + return cast("_F | None", self.lookup.component(self.key)) @property - def fallback(self): + def fallback(self) -> _F | None: """The approriate fallback implementation.""" - return self.lookup.fallback(self.key) + return cast("_F | None", self.lookup.fallback(self.key)) @property - def matches(self): + def matches(self) -> Iterable[_F]: """An iterator over all the compatible implementations.""" - return self.lookup.all(self.key) + return cast("Iterable[_F]", self.lookup.all(self.key)) @property - def all_matches(self): + def all_matches(self) -> list[_F]: """The list of all compatible implementations.""" return list(self.matches) -class Dispatch: +class Dispatch(Generic[_P, _T]): """Dispatch function. You can register implementations based on particular predicates. The @@ -92,14 +124,19 @@ class Dispatch: :class:`reg.LruCachingKeyLookup`) to make it more efficient. """ - def __init__(self, predicates, callable, get_key_lookup): + def __init__( + self, + predicates: list[Predicate], + callable: Callable[_P, _T], + get_key_lookup: GetKeyLookup, + ) -> None: self.wrapped_func = callable self.get_key_lookup = get_key_lookup self._original_predicates = predicates self._define_call() self._register_predicates(predicates) - def _register_predicates(self, predicates): + def _register_predicates(self, predicates: list[Predicate]) -> None: self.registry = PredicateRegistry(*predicates) self.predicates = predicates self.call.key_lookup = self.key_lookup = self.get_key_lookup(self.registry) @@ -110,10 +147,14 @@ def _register_predicates(self, predicates): ) self._predicate_key.__globals__.update( _registry_key=self.registry.key, - _return_type=partial(LookupEntry, self.key_lookup), + _return_type=partial(LookupEntry["Callable[_P, _T]"], self.key_lookup), ) - def _define_call(self): + # tell type checkers about these auto-generated functions + call: DispatchCall[_P, _T] + _predicate_key: Callable[..., LookupEntry[Callable[_P, _T]]] + + def _define_call(self) -> None: # We build the generic function on the fly. Its definition # requires the signature of the wrapped function and the # arguments needed by the registered predicates @@ -127,6 +168,7 @@ def call({signature}): """ args = arginfo(self.wrapped_func) + assert args is not None signature = format_signature(args) predicate_args = ", ".join("{0}={0}".format(x) for x in args.args) code_source = code_template.format( @@ -134,14 +176,17 @@ def call({signature}): ) # We now compile call to byte-code: - self.call = call = wraps(self.wrapped_func)( - execute( - code_source, - _registry_key=None, - _component_lookup=None, - _fallback_lookup=None, - _fallback=self.wrapped_func, - )["call"] + self.call = call = cast( + "DispatchCall[_P, _T]", + wraps(self.wrapped_func)( + execute( + code_source, + _registry_key=None, + _component_lookup=None, + _fallback_lookup=None, + _fallback=self.wrapped_func, + )["call"] + ), ) # We copy over the defaults from the wrapped function. @@ -163,7 +208,7 @@ def call({signature}): _return_type=None, )["predicate_key"] - def clean(self): + def clean(self) -> None: """Clean up implementations and added predicates. This restores the dispatch function to its original state, @@ -172,7 +217,7 @@ def clean(self): """ self._register_predicates(self._original_predicates) - def add_predicates(self, predicates): + def add_predicates(self, predicates: list[Predicate]) -> None: """Add new predicates. Extend the predicates used by this predicates. This can be @@ -184,7 +229,16 @@ def add_predicates(self, predicates): """ self._register_predicates(self.predicates + predicates) - def register(self, func=None, **key_dict): + @overload + def register(self, func: Callable[_P, _T], **key_dict: Any) -> Callable[_P, _T]: ... + @overload + def register( + self, func: None = None, **key_dict: Any + ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... + + def register( + self, func: Callable[_P, _T] | None = None, **key_dict: Any + ) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]: """Register an implementation. If ``func`` is not specified, this method can be used as a @@ -207,7 +261,7 @@ def register(self, func=None, **key_dict): self.registry.register(predicate_key, func) return func - def by_args(self, *args, **kw): + def by_args(self, *args: _P.args, **kw: _P.kwargs) -> LookupEntry[Callable[_P, _T]]: """Lookup an implementation by invocation arguments. :param args: positional arguments used in invocation. @@ -216,7 +270,7 @@ def by_args(self, *args, **kw): """ return self._predicate_key(*args, **kw) - def by_predicates(self, **predicate_values): + def by_predicates(self, **predicate_values: Any) -> LookupEntry[Callable[_P, _T]]: """Lookup an implementation by predicate values. :param predicate_values: the values of the predicates to lookup. @@ -228,20 +282,22 @@ def by_predicates(self, **predicate_values): ) -def validate_signature(f, dispatch): +def validate_signature(f: Callable[..., Any], dispatch: Callable[..., Any]) -> None: f_arginfo = arginfo(f) if f_arginfo is None: raise RegistrationError( "Cannot register non-callable for dispatch " "%r: %r" % (dispatch, f) ) - if not same_signature(arginfo(dispatch), f_arginfo): + d_arginfo = arginfo(dispatch) + assert d_arginfo is not None + if not same_signature(d_arginfo, f_arginfo): raise RegistrationError( "Signature of callable dispatched to (%r) " "not that of dispatch (%r)" % (f, dispatch) ) -def format_signature(args): +def format_signature(args: FullArgSpec) -> str: return ", ".join( args.args + (["*" + args.varargs] if args.varargs else []) @@ -249,7 +305,7 @@ def format_signature(args): ) -def same_signature(a, b): +def same_signature(a: FullArgSpec, b: FullArgSpec) -> bool: """Check whether a arginfo and b arginfo are the same signature. Actual names of arguments may differ. Default arguments may be @@ -260,7 +316,7 @@ def same_signature(a, b): return len(a_args) == len(b_args) and a.varargs == b.varargs and a.varkw == b.varkw -def execute(code_source, **namespace): +def execute(code_source: str, **namespace: Any) -> dict[str, Any]: """Execute code in a namespace, returning the namespace.""" code_object = compile(code_source, f"", "exec") exec(code_object, namespace) diff --git a/reg/predicate.py b/reg/predicate.py index ca8f418..1b25edd 100644 --- a/reg/predicate.py +++ b/reg/predicate.py @@ -1,11 +1,24 @@ +from __future__ import annotations + import inspect from operator import itemgetter from itertools import product +from typing import TYPE_CHECKING, Any, Generic from .error import RegistrationError +if TYPE_CHECKING: + from collections.abc import Callable, Iterator, Sequence + from typing_extensions import TypeVar + + _ValueT = TypeVar("_ValueT", default=Callable[..., Any]) +else: + from typing import TypeVar + + _ValueT = TypeVar("_ValueT") + -class Predicate: +class Predicate(Generic[_ValueT]): """A dispatch predicate. :param name: name used to identify the predicate when specifying @@ -25,21 +38,36 @@ class Predicate: """ - def __init__(self, name, index, get_key=None, fallback=None, default=None): + def __init__( + self, + name: str, + index: Callable[[_ValueT | None], KeyIndex[_ValueT]], + # FIXME: This maybe shouldn't be optional, considering + # PredicateRegistry.key will crash if it got a + # Predicate without a get_key. + get_key: Callable[[dict[str, Any]], Any] | None = None, + fallback: _ValueT | None = None, + default: Any | None = None, + ) -> None: self.name = name self.index = index self.fallback = fallback self.get_key = get_key self.default = default - def create_index(self): + def create_index(self) -> KeyIndex[_ValueT]: return self.index(self.fallback) - def key_by_predicate_name(self, d): + def key_by_predicate_name(self, d: dict[str, Any]) -> Any | None: return d.get(self.name, self.default) -def match_key(name, func=None, fallback=None, default=None): +def match_key( + name: str, + func: Callable[..., Any] | None = None, + fallback: Any | None = None, + default: Any | None = None, +) -> Predicate[Any]: """Predicate that returns a value used for dispatching. :name: predicate name. @@ -54,6 +82,7 @@ def match_key(name, func=None, fallback=None, default=None): :returns: a :class:`Predicate`. """ + get_key: Callable[[dict[str, Any]], Any] if func is None: get_key = itemgetter(name) else: @@ -61,7 +90,12 @@ def match_key(name, func=None, fallback=None, default=None): return Predicate(name, KeyIndex, get_key, fallback, default) -def match_instance(name, func=None, fallback=None, default=None): +def match_instance( + name: str, + func: Callable[..., Any] | None = None, + fallback: Any | None = None, + default: Any | None = None, +) -> Predicate[Any]: """Predicate that returns an instance whose class is used for dispatching. :name: predicate name. @@ -82,7 +116,12 @@ def match_instance(name, func=None, fallback=None, default=None): return Predicate(name, ClassIndex, get_key, fallback, default) -def match_class(name, func=None, fallback=None, default=None): +def match_class( + name: str, + func: Callable[..., Any] | None = None, + fallback: Any | None = None, + default: Any | None = None, +) -> Predicate[Any]: """Predicate that returns a class used for dispatching. :name: predicate name. @@ -96,6 +135,7 @@ def match_class(name, func=None, fallback=None, default=None): :returns: a :class:`Predicate`. """ + get_key: Callable[[dict[str, Any]], Any] if func is None: get_key = itemgetter(name) else: @@ -103,17 +143,17 @@ def match_class(name, func=None, fallback=None, default=None): return Predicate(name, ClassIndex, get_key, fallback, default) -_emptyset = frozenset() +_emptyset: frozenset[Any] = frozenset() -class KeyIndex(dict): - def __init__(self, fallback=None): +class KeyIndex(dict[Any, set[_ValueT]]): + def __init__(self, fallback: _ValueT | None = None) -> None: self.fallback = fallback - def __missing__(self, key): + def __missing__(self, key: Any) -> frozenset[Any]: return _emptyset - def permutations(self, key): + def permutations(self, key: Any) -> Iterator[Any]: """Permutations for a simple immutable key. There is only a single permutation: the key itself. @@ -121,42 +161,37 @@ def permutations(self, key): yield key -class ClassIndex(KeyIndex): - def permutations(self, key): +class ClassIndex(KeyIndex[_ValueT]): + def permutations(self, key: type[Any]) -> Iterator[type[Any]]: """Permutations for class key. - Returns class and its base classes in mro order. If a classic - class in Python 2, smuggle in ``object`` as the base class - anyway to make lookups consistent. + Returns class and its base classes in mro order. """ - for class_ in inspect.getmro(key): - yield class_ - if class_ is not object: - yield object # pragma: no cover + yield from inspect.getmro(key) -class PredicateRegistry: - def __init__(self, *predicates): - self.known_keys = set() - self.known_values = set() +class PredicateRegistry(Generic[_ValueT]): + def __init__(self, *predicates: Predicate[_ValueT]) -> None: + self.known_keys: set[Any] = set() + self.known_values: set[_ValueT] = set() self.predicates = predicates self.indexes = [predicate.create_index() for predicate in predicates] key_getters = [p.get_key for p in predicates] if len(predicates) == 0: - self.key = lambda **kw: () + self.key = lambda **kw: () # type: ignore elif len(predicates) == 1: (p,) = key_getters - self.key = lambda **kw: (p(kw),) + self.key = lambda **kw: (p(kw),) # type: ignore elif len(predicates) == 2: p, q = key_getters - self.key = lambda **kw: (p(kw), q(kw)) + self.key = lambda **kw: (p(kw), q(kw)) # type: ignore elif len(predicates) == 3: p, q, r = key_getters - self.key = lambda **kw: (p(kw), q(kw), r(kw)) + self.key = lambda **kw: (p(kw), q(kw), r(kw)) # type: ignore else: - self.key = lambda **kw: tuple([p(kw) for p in key_getters]) + self.key = lambda **kw: tuple([p(kw) for p in key_getters]) # type: ignore - def register(self, key, value): + def register(self, key: Any, value: _ValueT) -> None: if key in self.known_keys: raise RegistrationError(f"Already have registration for key: {key}") for index, key_item in zip(self.indexes, key): @@ -164,7 +199,7 @@ def register(self, key, value): self.known_keys.add(key) self.known_values.add(value) - def get(self, keys): + def get(self, keys: Sequence[Any]) -> set[_ValueT]: # do an intersection of all sets that result from index lookup # this code is a bit convoluted for performance reasons. sets = (index[key] for index, key in zip(self.indexes, keys)) @@ -172,12 +207,12 @@ def get(self, keys): # this returns the known values if there are no indexes at all return next(sets, self.known_values).intersection(*sets) - def permutations(self, keys): + def permutations(self, keys: Sequence[Any]) -> Iterator[tuple[Any, ...]]: return product( *(index.permutations(key) for index, key in zip(self.indexes, keys)) ) - def key(self, **kw): + def key(self, **kw: Any) -> tuple[Any, ...]: # type: ignore[empty-body] """Construct a dispatch key from the arguments of a generic function. :param kw: a dictionary with the arguments passed to a generic @@ -187,7 +222,7 @@ def key(self, **kw): """ # Overwritten by init - def key_dict_to_predicate_key(self, d): + def key_dict_to_predicate_key(self, d: dict[str, Any]) -> tuple[Any, ...]: """Construct a dispatch key from predicate values. Uses ``name`` and ``default`` attributes of predicates to @@ -200,10 +235,10 @@ def key_dict_to_predicate_key(self, d): """ return tuple([p.key_by_predicate_name(d) for p in self.predicates]) - def component(self, keys): + def component(self, keys: Sequence[Any]) -> _ValueT | None: return next(self.all(keys), None) - def fallback(self, keys): + def fallback(self, keys: Sequence[Any]) -> _ValueT | None: result = None for index, key in zip(self.indexes, keys): for k in index.permutations(key): @@ -221,7 +256,8 @@ def fallback(self, keys): # match if not result: return index.fallback + return None - def all(self, key): + def all(self, key: Sequence[Any]) -> Iterator[_ValueT]: for p in self.permutations(key): yield from self.get(p) diff --git a/reg/py.typed b/reg/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/reg/tests/test_arginfo.py b/reg/tests/test_arginfo.py index 8dc665c..02e3b10 100644 --- a/reg/tests/test_arginfo.py +++ b/reg/tests/test_arginfo.py @@ -1,13 +1,19 @@ +from __future__ import annotations + import pytest +from typing import Any, TYPE_CHECKING from ..arginfo import arginfo +if TYPE_CHECKING: + from collections.abc import Callable + -def func_no_args(): +def func_no_args() -> None: pass class ObjNoArgs: - def __call__(self): + def __call__(self) -> None: pass @@ -15,7 +21,7 @@ def __call__(self): class MethodNoArgs: - def method(self): + def method(self) -> None: pass @@ -24,13 +30,13 @@ def method(self): class StaticMethodNoArgs: @staticmethod - def method(): + def method() -> None: pass class ClassMethodNoArgs: @classmethod - def method(cls): + def method(cls) -> None: pass @@ -39,7 +45,7 @@ class ClassNoInit: class ClassNoArgs: - def __init__(self): + def __init__(self) -> None: pass @@ -48,7 +54,7 @@ class ClassicNoInit: class ClassicNoArgs: - def __init__(self): + def __init__(self) -> None: pass @@ -86,20 +92,21 @@ class ClassicInheritedNoArgs(ClassicNoArgs): ClassicInheritedNoArgs, ], ) -def test_arginfo_no_args(callable): +def test_arginfo_no_args(callable: Callable[[], Any]) -> None: info = arginfo(callable) + assert info is not None assert info.args == [] assert info.varargs is None assert info.varkw is None assert info.defaults is None -def func_args(a): +def func_args(a: int) -> None: pass class ObjArgs: - def __call__(self, a): + def __call__(self, a: int) -> None: pass @@ -107,7 +114,7 @@ def __call__(self, a): class MethodArgs: - def method(self, a): + def method(self, a: int) -> None: pass @@ -116,23 +123,23 @@ def method(self, a): class StaticMethodArgs: @staticmethod - def method(a): + def method(a: int) -> None: pass class ClassMethodArgs: @classmethod - def method(cls, a): + def method(cls, a: int) -> None: pass class ClassArgs: - def __init__(self, a): + def __init__(self, a: int) -> None: pass class ClassicArgs: - def __init__(self, a): + def __init__(self, a: int) -> None: pass @@ -158,20 +165,21 @@ class ClassicInheritedArgs(ClassicArgs): ClassicInheritedArgs, ], ) -def test_arginfo_args(callable): +def test_arginfo_args(callable: Callable[[int], Any]) -> None: info = arginfo(callable) + assert info is not None assert info.args == ["a"] assert info.varargs is None assert info.varkw is None assert info.defaults is None -def func_varargs(*args): +def func_varargs(*args: int) -> None: pass class ObjVarargs: - def __call__(self, *args): + def __call__(self, *args: int) -> None: pass @@ -179,7 +187,7 @@ def __call__(self, *args): class MethodVarargs: - def method(self, *args): + def method(self, *args: int) -> None: pass @@ -187,12 +195,12 @@ def method(self, *args): class ClassVarargs: - def __init__(self, *args): + def __init__(self, *args: int) -> None: pass class ClassicVarargs: - def __init__(self, *args): + def __init__(self, *args: int) -> None: pass @@ -216,20 +224,21 @@ class ClassicInheritedVarargs(ClassicVarargs): ClassicInheritedVarargs, ], ) -def test_arginfo_varargs(callable): +def test_arginfo_varargs(callable: Callable[..., Any]) -> None: info = arginfo(callable) + assert info is not None assert info.args == [] assert info.varargs == "args" assert info.varkw is None assert info.defaults is None -def func_keywords(**kw): +def func_keywords(**kw: int) -> None: pass class ObjKeywords: - def __call__(self, **kw): + def __call__(self, **kw: int) -> None: pass @@ -237,7 +246,7 @@ def __call__(self, **kw): class MethodKeywords: - def method(self, **kw): + def method(self, **kw: int) -> None: pass @@ -245,12 +254,12 @@ def method(self, **kw): class ClassKeywords: - def __init__(self, **kw): + def __init__(self, **kw: int) -> None: pass class ClassicKeywords: - def __init__(self, **kw): + def __init__(self, **kw: int) -> None: pass @@ -274,20 +283,21 @@ class ClassicInheritedKeywords(ClassicKeywords): ClassicInheritedKeywords, ], ) -def test_arginfo_keywords(callable): +def test_arginfo_keywords(callable: Callable[..., Any]) -> None: info = arginfo(callable) + assert info is not None assert info.args == [] assert info.varargs is None assert info.varkw == "kw" assert info.defaults is None -def func_defaults(a=1): +def func_defaults(a: int = 1) -> None: pass class ObjDefaults: - def __call__(self, a=1): + def __call__(self, a: int = 1) -> None: pass @@ -295,7 +305,7 @@ def __call__(self, a=1): class MethodDefaults: - def method(self, a=1): + def method(self, a: int = 1) -> None: pass @@ -303,12 +313,12 @@ def method(self, a=1): class ClassDefaults: - def __init__(self, a=1): + def __init__(self, a: int = 1) -> None: pass class ClassicDefaults: - def __init__(self, a=1): + def __init__(self, a: int = 1) -> None: pass @@ -332,20 +342,21 @@ class ClassicInheritedDefaults(ClassicDefaults): ClassicInheritedDefaults, ], ) -def test_arginfo_defaults(callable): +def test_arginfo_defaults(callable: Callable[[int], Any]) -> None: info = arginfo(callable) + assert info is not None assert info.args == ["a"] assert info.varargs is None assert info.varkw is None assert info.defaults == (1,) -def func_kwonlydefaults(*, a=1): +def func_kwonlydefaults(*, a: int = 1) -> None: pass class ObjKwOnlyDefaults: - def __call__(self, *, a=1): + def __call__(self, *, a: int = 1) -> None: pass @@ -353,7 +364,7 @@ def __call__(self, *, a=1): class MethodKwOnlyDefaults: - def method(self, *, a=1): + def method(self, *, a: int = 1) -> None: pass @@ -361,7 +372,7 @@ def method(self, *, a=1): class ClassKwOnlyDefaults: - def __init__(self, *, a=1): + def __init__(self, *, a: int = 1) -> None: pass @@ -379,8 +390,9 @@ class InheritedKwOnlyDefaults(ClassKwOnlyDefaults): InheritedKwOnlyDefaults, ], ) -def test_arginfo_kwonlydefaults(callable): +def test_arginfo_kwonlydefaults(callable: Callable[..., Any]) -> None: info = arginfo(callable) + assert info is not None assert not info.args assert info.varargs is None assert info.varkw is None @@ -428,27 +440,29 @@ class InheritedAnnotations(ClassAnnotations): InheritedAnnotations, ], ) -def test_arginfo_annotations(callable): +def test_arginfo_annotations(callable: Callable[[int], Any]) -> None: info = arginfo(callable) + assert info is not None assert info.args == ["a"] assert info.varargs is None assert info.varkw is None assert info.defaults is None - assert info.annotations == {"a": int, "return": None} + assert info.annotations == {"a": "int", "return": "None"} # Information on builtin functions is not reported. These can # still be called with mapply, but only using positional arguments. -def test_arginfo_builtin(): +def test_arginfo_builtin() -> None: info = arginfo(int) + assert info is not None assert info.args == [] assert info.varargs is None assert info.varkw is None assert info.defaults is None -def test_arginfo_cache(): - def foo(a): +def test_arginfo_cache() -> None: + def foo(a: object) -> None: pass assert not arginfo.is_cached(foo) @@ -456,9 +470,9 @@ def foo(a): assert arginfo.is_cached(foo) -def test_arginfo_cache_callable(): +def test_arginfo_cache_callable() -> None: class Foo: - def __call__(self): + def __call__(self) -> None: pass foo = Foo() diff --git a/reg/tests/test_classdispatch.py b/reg/tests/test_classdispatch.py index 3e0f06b..8acec87 100644 --- a/reg/tests/test_classdispatch.py +++ b/reg/tests/test_classdispatch.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import Any from ..dispatch import dispatch from ..predicate import match_class @@ -11,21 +14,21 @@ class SpecialClass: class Foo: - def __repr__(self): + def __repr__(self) -> str: return "" class Bar: - def __repr__(self): + def __repr__(self) -> str: return "" -def test_dispatch_basic(): +def test_dispatch_basic() -> None: @dispatch(match_class("cls")) - def something(cls): + def something(cls: type[Any]) -> str: raise NotImplementedError() - def something_for_object(cls): + def something_for_object(cls: type[object]) -> str: return "Something for %s" % cls something.register(something_for_object, cls=object) @@ -36,15 +39,15 @@ def something_for_object(cls): assert something.by_args(DemoClass).all_matches == [something_for_object] -def test_classdispatch_multidispatch(): +def test_classdispatch_multidispatch() -> None: @dispatch(match_class("cls"), "other") - def something(cls, other): + def something(cls: type[Any], other: Any) -> str: raise NotImplementedError() - def something_for_object_and_object(cls, other): + def something_for_object_and_object(cls: type[object], other: object) -> str: return "Something, other is object: %s" % other - def something_for_object_and_foo(cls, other): + def something_for_object_and_foo(cls: type[object], other: Foo) -> str: return "Something, other is Foo: %s" % other something.register(something_for_object_and_object, cls=object, other=object) @@ -57,12 +60,12 @@ def something_for_object_and_foo(cls, other): assert something(DemoClass, Foo()) == ("Something, other is Foo: ") -def test_classdispatch_extra_arguments(): +def test_classdispatch_extra_arguments() -> None: @dispatch(match_class("cls")) - def something(cls, extra): + def something(cls: type[Any], extra: str) -> str: raise NotImplementedError() - def something_for_object(cls, extra): + def something_for_object(cls: type[object], extra: str) -> str: return "Extra: %s" % extra something.register(something_for_object, cls=object) @@ -70,12 +73,12 @@ def something_for_object(cls, extra): assert something(DemoClass, "foo") == "Extra: foo" -def test_classdispatch_no_arguments(): +def test_classdispatch_no_arguments() -> None: @dispatch() - def something(): + def something() -> str: raise NotImplementedError() - def something_impl(): + def something_impl() -> str: return "Something!" something.register(something_impl) @@ -83,15 +86,15 @@ def something_impl(): assert something() == "Something!" -def test_classdispatch_override(): +def test_classdispatch_override() -> None: @dispatch(match_class("cls")) - def something(cls): + def something(cls: type[Any]) -> str: raise NotImplementedError() - def something_for_object(cls): + def something_for_object(cls: type[object]) -> str: return "Something for %s" % cls - def something_for_special(cls): + def something_for_special(cls: type[SpecialClass]) -> str: return "Special for %s" % cls something.register(something_for_object, cls=object) @@ -100,9 +103,9 @@ def something_for_special(cls): assert something(SpecialClass) == (f"Special for ") -def test_classdispatch_fallback(): +def test_classdispatch_fallback() -> None: @dispatch() - def something(cls): + def something(cls: type[Any]) -> str: return "Fallback" assert something(DemoClass) == "Fallback" diff --git a/reg/tests/test_dispatch.py b/reg/tests/test_dispatch.py index b17efc1..020507d 100644 --- a/reg/tests/test_dispatch.py +++ b/reg/tests/test_dispatch.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import pytest +from typing import Any from ..predicate import match_instance, match_key, match_class from ..dispatch import dispatch from ..error import RegistrationError @@ -21,23 +24,23 @@ class Beta(IBeta): pass -def test_dispatch_argname(): +def test_dispatch_argname() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: Any) -> str | None: pass - def for_bar(obj): + def for_bar(obj: Bar) -> str: return obj.method() - def for_qux(obj): + def for_qux(obj: Qux) -> str: return obj.method() class Bar: - def method(self): + def method(self) -> str: return "bar's method" class Qux: - def method(self): + def method(self) -> str: return "qux's method" foo.register(for_bar, obj=Bar) @@ -47,23 +50,23 @@ def method(self): assert foo(Qux()) == "qux's method" -def test_dispatch_match_instance(): +def test_dispatch_match_instance() -> None: @dispatch(match_instance("obj")) - def foo(obj): + def foo(obj: Any) -> str | None: pass - def for_bar(obj): + def for_bar(obj: Bar) -> str: return obj.method() - def for_qux(obj): + def for_qux(obj: Qux) -> str: return obj.method() class Bar: - def method(self): + def method(self) -> str: return "bar's method" class Qux: - def method(self): + def method(self) -> str: return "qux's method" foo.register(for_bar, obj=Bar) @@ -73,12 +76,12 @@ def method(self): assert foo(Qux()) == "qux's method" -def test_dispatch_no_arguments(): +def test_dispatch_no_arguments() -> None: @dispatch() - def foo(): + def foo() -> str | None: pass - def special_foo(): + def special_foo() -> str: return "special" foo.register(special_foo) @@ -89,7 +92,7 @@ def special_foo(): assert foo.by_args().fallback is None -def test_all(): +def test_all() -> None: class Base: pass @@ -97,13 +100,13 @@ class Sub(Base): pass @dispatch("obj") - def target(obj): + def target(obj: Any) -> None: pass - def registered_for_sub(obj): + def registered_for_sub(obj: Sub) -> None: pass - def registered_for_base(obj): + def registered_for_base(obj: Base) -> None: pass target.register(registered_for_sub, obj=Sub) @@ -119,7 +122,7 @@ def registered_for_base(obj): assert target.by_args(base).all_matches == [registered_for_base] -def test_all_by_keys(): +def test_all_by_keys() -> None: class Base: pass @@ -127,13 +130,13 @@ class Sub(Base): pass @dispatch("obj") - def target(obj): + def target(obj: Any) -> None: pass - def registered_for_sub(obj): + def registered_for_sub(obj: Sub) -> None: pass - def registered_for_base(obj): + def registered_for_base(obj: Base) -> None: pass target.register(registered_for_sub, obj=Sub) @@ -146,36 +149,36 @@ def registered_for_base(obj): assert target.by_predicates(obj=Base).all_matches == [registered_for_base] -def test_component_no_source(): +def test_component_no_source() -> None: @dispatch() - def target(): + def target() -> None: pass - def foo(): + def foo() -> None: pass target.register(foo) assert target.by_args().component is foo -def test_component_no_source_key_dict(): +def test_component_no_source_key_dict() -> None: @dispatch() - def target(): + def target() -> None: pass - def foo(): + def foo() -> None: pass target.register(foo) assert target.by_predicates().component is foo -def test_component_one_source(): +def test_component_one_source() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> None: pass - def foo(obj): + def foo(obj: Alpha) -> None: pass target.register(foo, obj=Alpha) @@ -184,12 +187,12 @@ def foo(obj): assert target.by_args(alpha).component is foo -def test_component_one_source_key_dict(): +def test_component_one_source_key_dict() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> None: pass - def foo(obj): + def foo(obj: Alpha) -> None: pass target.register(foo, obj=Alpha) @@ -197,12 +200,12 @@ def foo(obj): assert target.by_predicates(obj=Alpha).component is foo -def test_component_two_sources(): +def test_component_two_sources() -> None: @dispatch("a", "b") - def target(a, b): + def target(a: Any, b: Any) -> None: pass - def foo(a, b): + def foo(a: IAlpha, b: IBeta) -> None: pass target.register(foo, a=IAlpha, b=IBeta) @@ -212,7 +215,7 @@ def foo(a, b): assert target.by_args(alpha, beta).component is foo -def test_component_inheritance(): +def test_component_inheritance() -> None: class Gamma: pass @@ -220,10 +223,10 @@ class Delta(Gamma): pass @dispatch("obj") - def target(obj): + def target(obj: Any) -> None: pass - def foo(obj): + def foo(obj: Gamma) -> None: pass target.register(foo, obj=Gamma) @@ -233,7 +236,7 @@ def foo(obj): assert target.by_args(delta).component is foo -def test_component_inheritance_old_style_class(): +def test_component_inheritance_old_style_class() -> None: class Gamma: pass @@ -241,10 +244,10 @@ class Delta(Gamma): pass @dispatch("obj") - def target(obj): + def target(obj: Any) -> None: pass - def foo(obj): + def foo(obj: Gamma) -> None: pass target.register(foo, obj=Gamma) @@ -258,14 +261,14 @@ def foo(obj): assert target.by_args(delta).component is foo -def test_call_no_source(): +def test_call_no_source() -> None: foo = object() @dispatch() - def target(): + def target() -> object: pass - def factory(): + def factory() -> object: return foo target.register(factory) @@ -273,15 +276,15 @@ def factory(): assert target() is foo -def test_call_one_source(): +def test_call_one_source() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> str | None: pass - def foo(obj): + def foo(obj: IAlpha) -> str: return "foo" - def bar(obj): + def bar(obj: IBeta) -> str: return "bar" target.register(foo, obj=IAlpha) @@ -291,15 +294,15 @@ def bar(obj): assert target(Beta()) == "bar" -def test_call_two_sources(): +def test_call_two_sources() -> None: @dispatch("a", "b") - def target(a, b): + def target(a: Any, b: Any) -> str | None: pass - def foo(a, b): + def foo(a: IAlpha, b: IBeta) -> str: return "foo" - def bar(a, b): + def bar(a: IBeta, b: IAlpha) -> str: return "bar" target.register(foo, a=IAlpha, b=IBeta) @@ -311,119 +314,119 @@ def bar(a, b): assert target(beta, alpha) == "bar" -def test_component_not_found_no_sources(): +def test_component_not_found_no_sources() -> None: @dispatch() - def target(): + def target() -> None: pass assert target.by_args().component is None -def test_call_not_found_no_sources(): +def test_call_not_found_no_sources() -> None: @dispatch() - def target(): + def target() -> str: return "default" assert target() == "default" -def test_component_not_found_one_source(): +def test_component_not_found_one_source() -> None: @dispatch("obj") - def target(obj): + def target(obj: str) -> None: pass assert target.by_args("dummy").component is None -def test_call_not_found_one_source(): +def test_call_not_found_one_source() -> None: @dispatch("obj") - def target(obj): + def target(obj: str) -> str: return "default: %s" % obj assert target("dummy") == "default: dummy" -def test_component_not_found_two_sources(): +def test_component_not_found_two_sources() -> None: @dispatch("a", "b") - def target(a, b): + def target(a: str, b: str) -> None: pass assert target.by_args("dummy", "dummy").component is None -def test_call_not_found_two_sources(): +def test_call_not_found_two_sources() -> None: @dispatch("a", "b") - def target(a, b): + def target(a: str, b: str) -> str: return f"a: {a} b: {b}" assert target("dummy1", "dummy2") == "a: dummy1 b: dummy2" -def test_wrong_callable_registered(): +def test_wrong_callable_registered() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> Any: pass - def callable(a, b): + def callable(a: Any, b: Any) -> Any: pass with pytest.raises(RegistrationError): - target.register(callable, a=Alpha) + target.register(callable, a=Alpha) # type: ignore -def test_non_callable_registered(): +def test_non_callable_registered() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> None: pass non_callable = 42 with pytest.raises(RegistrationError): - target.register(non_callable, a=Alpha) + target.register(non_callable, a=Alpha) # type: ignore -def test_call_with_no_args_while_arg_expected(): +def test_call_with_no_args_while_arg_expected() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> str | None: pass - def specific(obj): + def specific(obj: Alpha) -> str: return "specific" target.register(specific, obj=Alpha) # we are not allowed to call target without arguments with pytest.raises(TypeError): - target() + target() # type: ignore with pytest.raises(TypeError): - target.by_args().component + target.by_args().component # type: ignore -def test_call_with_wrong_args(): +def test_call_with_wrong_args() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> str | None: pass - def specific(obj): + def specific(obj: Alpha) -> str: return "specific" target.register(specific, obj=Alpha) # we are not allowed to call target without arguments with pytest.raises(TypeError): - target(wrong=1) + target(wrong=1) # type: ignore with pytest.raises(TypeError): - target.by_args(wrong=1) + target.by_args(wrong=1) # type: ignore -def test_extra_arg_for_call(): +def test_extra_arg_for_call() -> None: @dispatch("obj") - def target(obj, extra): + def target(obj: Any, extra: str) -> str: return "General: %s" % extra - def specific(obj, extra): + def specific(obj: Alpha, extra: str) -> str: return "Specific: %s" % extra target.register(specific, obj=Alpha) @@ -437,15 +440,15 @@ def specific(obj, extra): assert target(beta, "allowed") == "General: allowed" -def test_fallback_to_fallback(): - def fallback(obj): +def test_fallback_to_fallback() -> None: + def fallback(obj: Any) -> str: return "fallback!" @dispatch(match_instance("obj", fallback=fallback)) - def target(obj): + def target(obj: Any) -> str: return "not the fallback we want" - def specific_target(obj): + def specific_target(obj: Alpha) -> str: return "specific" target.register(specific_target, obj=Alpha) @@ -459,12 +462,12 @@ def specific_target(obj): assert target.by_args(Alpha()).fallback is None -def test_fallback_to_dispatch(): +def test_fallback_to_dispatch() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> str: return "fallback" - def specific_target(obj): + def specific_target(obj: Alpha) -> str: return "specific" target.register(specific_target, obj=Alpha) @@ -475,15 +478,15 @@ def specific_target(obj): assert target.by_args(beta).fallback is None -def test_calling_twice(): +def test_calling_twice() -> None: @dispatch("obj") - def target(obj): + def target(obj: Any) -> str: return "fallback" - def a(obj): + def a(obj: Alpha) -> str: return "a" - def b(obj): + def b(obj: Beta) -> str: return "b" target.register(a, obj=Alpha) @@ -493,12 +496,12 @@ def b(obj): assert target(Beta()) == "b" -def test_different_defaults_in_specific_non_dispatch_arg(): +def test_different_defaults_in_specific_non_dispatch_arg() -> None: @dispatch("obj") - def target(obj, blah="default"): + def target(obj: Any, blah: str = "default") -> str: return "fallback: %s" % blah - def a(obj, blah="default 2"): + def a(obj: Any, blah: str = "default 2") -> str: return "a: %s" % blah target.register(a, obj=Alpha) @@ -506,12 +509,12 @@ def a(obj, blah="default 2"): assert target(Alpha()) == "a: default" -def test_different_defaults_in_specific_dispatch_arg(): +def test_different_defaults_in_specific_dispatch_arg() -> None: @dispatch(match_key("key")) - def target(key="default"): + def target(key: str = "default") -> str: return "fallback: %s" % key - def a(key="default 2"): + def a(key: str = "default 2") -> str: return "a: %s" % key target.register(a, key="foo") @@ -521,12 +524,12 @@ def a(key="default 2"): assert target() == "fallback: default" -def test_different_defaults_in_specific_dispatch_arg_causes_dispatch(): +def test_different_defaults_in_specific_dispatch_arg_causes_dispatch() -> None: @dispatch(match_key("key")) - def target(key="foo"): + def target(key: str = "foo") -> str: return "fallback: %s" % key - def a(key="default 2"): + def a(key: str = "default 2") -> str: return "a: %s" % key target.register(a, key="foo") @@ -536,33 +539,38 @@ def a(key="default 2"): assert target() == "a: foo" -def test_add_predicates_no_defaults(): +def test_add_predicates_no_defaults() -> None: class Foo: pass class FooSub(Foo): pass + class Request: + def __init__(self, name: str, request_method: str) -> None: + self.name = name + self.request_method = request_method + @dispatch() - def view(self, request): + def view(self: Any, request: Request) -> str: raise NotImplementedError() - def get_model(self, request): + def get_model(self: Any, request: Request) -> Any: return self - def get_name(self, request): + def get_name(self: Any, request: Request) -> str: return request.name - def get_request_method(self, request): + def get_request_method(self: Any, request: Request) -> str: return request.request_method - def model_fallback(self, request): + def model_fallback(self: Any, request: Request) -> Any: return "Model fallback" - def name_fallback(self, request): + def name_fallback(self: Any, request: Request) -> str: return "Name fallback" - def request_method_fallback(self, request): + def request_method_fallback(self: Any, request: Request) -> str: return "Request method fallback" view.add_predicates( @@ -573,24 +581,19 @@ def request_method_fallback(self, request): ] ) - def foo_default(self, request): + def foo_default(self: Foo, request: Request) -> str: return "foo default" - def foo_post(self, request): + def foo_post(self: Foo, request: Request) -> str: return "foo default post" - def foo_edit(self, request): + def foo_edit(self: Foo, request: Request) -> str: return "foo edit" view.register(foo_default, model=Foo, name="", request_method="GET") view.register(foo_post, model=Foo, name="", request_method="POST") view.register(foo_edit, model=Foo, name="edit", request_method="POST") - class Request: - def __init__(self, name, request_method): - self.name = name - self.request_method = request_method - assert view(Foo(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("edit", "POST")) == "foo edit" @@ -604,33 +607,38 @@ class Bar: assert view(FooSub(), Request("dummy", "GET")) == "Name fallback" -def test_dispatch_external_predicates(): +def test_dispatch_external_predicates() -> None: class Foo: pass class FooSub(Foo): pass + class Request: + def __init__(self, name: str, request_method: str) -> None: + self.name = name + self.request_method = request_method + @dispatch() - def view(self, request): + def view(self: Any, request: Request) -> str: raise NotImplementedError() - def get_model(self, request): + def get_model(self: Any, request: Request) -> Any: return self - def get_name(self, request): + def get_name(self: Any, request: Request) -> str: return request.name - def get_request_method(self, request): + def get_request_method(self: Any, request: Request) -> str: return request.request_method - def model_fallback(self, request): + def model_fallback(self: Any, request: Request) -> str: return "Model fallback" - def name_fallback(self, request): + def name_fallback(self: Any, request: Request) -> str: return "Name fallback" - def request_method_fallback(self, request): + def request_method_fallback(self: Any, request: Request) -> str: return "Request method fallback" view.add_predicates( @@ -641,24 +649,19 @@ def request_method_fallback(self, request): ] ) - def foo_default(self, request): + def foo_default(self: Foo, request: Request) -> str: return "foo default" - def foo_post(self, request): + def foo_post(self: Foo, request: Request) -> str: return "foo default post" - def foo_edit(self, request): + def foo_edit(self: Foo, request: Request) -> str: return "foo edit" view.register(foo_default, model=Foo, name="", request_method="GET") view.register(foo_post, model=Foo, name="", request_method="POST") view.register(foo_edit, model=Foo, name="edit", request_method="POST") - class Request: - def __init__(self, name, request_method): - self.name = name - self.request_method = request_method - assert view(Foo(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("edit", "POST")) == "foo edit" @@ -673,33 +676,38 @@ class Bar: assert view.by_args(Bar(), Request("", "GET")).fallback is model_fallback -def test_dispatch_predicates_register_defaults(): +def test_dispatch_predicates_register_defaults() -> None: class Foo: pass class FooSub(Foo): pass + class Request: + def __init__(self, name: str, request_method: str) -> None: + self.name = name + self.request_method = request_method + @dispatch() - def view(self, request): + def view(self: Any, request: Request) -> str: raise NotImplementedError() - def get_model(self, request): + def get_model(self: Any, request: Request) -> Any: return self - def get_name(self, request): + def get_name(self: Any, request: Request) -> str: return request.name - def get_request_method(self, request): + def get_request_method(self: Any, request: Request) -> str: return request.request_method - def model_fallback(self, request): + def model_fallback(self: Any, request: Request) -> Any: return "Model fallback" - def name_fallback(self, request): + def name_fallback(self: Any, request: Request) -> str: return "Name fallback" - def request_method_fallback(self, request): + def request_method_fallback(self: Any, request: Request) -> str: return "Request method fallback" view.add_predicates( @@ -715,24 +723,19 @@ def request_method_fallback(self, request): ] ) - def foo_default(self, request): + def foo_default(self: Foo, request: Request) -> str: return "foo default" - def foo_post(self, request): + def foo_post(self: Foo, request: Request) -> str: return "foo default post" - def foo_edit(self, request): + def foo_edit(self: Foo, request: Request) -> str: return "foo edit" view.register(foo_default, model=Foo) view.register(foo_post, model=Foo, request_method="POST") view.register(foo_edit, model=Foo, name="edit", request_method="POST") - class Request: - def __init__(self, name, request_method): - self.name = name - self.request_method = request_method - assert view(Foo(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("edit", "POST")) == "foo edit" @@ -746,24 +749,24 @@ class Bar: assert view(FooSub(), Request("dummy", "GET")) == "Name fallback" -def test_key_dict_to_predicate_key(): +def test_key_dict_to_predicate_key() -> None: @dispatch( match_key("foo", default="default foo"), match_key("bar", default="default bar"), ) - def view(self, request): + def view(self: Any, request: Any) -> Any: raise NotImplementedError() assert view.by_predicates(foo="FOO", bar="BAR").key == ("FOO", "BAR") assert view.by_predicates().key == ("default foo", "default bar") -def test_key_dict_to_predicate_key_unknown_keys(): +def test_key_dict_to_predicate_key_unknown_keys() -> None: @dispatch( match_key("foo", default="default foo"), match_key("bar", default="default bar"), ) - def view(self, request): + def view(self: Any, request: Any) -> Any: raise NotImplementedError() # unknown keys are just ignored @@ -773,7 +776,7 @@ def view(self, request): ) -def test_register_dispatch_key_dict(): +def test_register_dispatch_key_dict() -> None: class Foo: pass @@ -781,25 +784,25 @@ class FooSub(Foo): pass @dispatch() - def view(self, request): + def view(self: Any, request: Any) -> Any: raise NotImplementedError() - def get_model(self, request): + def get_model(self: Any, request: Any) -> Any: return self - def get_name(self, request): + def get_name(self: Any, request: Any) -> Any: return request.name - def get_request_method(self, request): + def get_request_method(self: Any, request: Any) -> Any: return request.request_method - def model_fallback(self, request): + def model_fallback(self: Any, request: Any) -> Any: return "Model fallback" - def name_fallback(self, request): + def name_fallback(self: Any, request: Any) -> Any: return "Name fallback" - def request_method_fallback(self, request): + def request_method_fallback(self: Any, request: Any) -> Any: return "Request method fallback" view.add_predicates( @@ -818,35 +821,35 @@ def request_method_fallback(self, request): assert view.by_predicates().key == (None, "", "GET") -def test_fallback_should_already_use_subset(): +def test_fallback_should_already_use_subset() -> None: class Request: - def __init__(self, name, request_method, body_obj): + def __init__(self, name: str, request_method: str, body_obj: Any) -> None: self.name = name self.request_method = request_method self.body_obj = body_obj - def get_model(self, request): + def get_model(self: Any, request: Request) -> Any: return self - def get_name(self, request): + def get_name(self: Any, request: Request) -> str: return request.name - def get_request_method(self, request): + def get_request_method(self: Any, request: Request) -> str: return request.request_method - def get_body_model(self, request): + def get_body_model(self: Any, request: Request) -> Any: return request.body_obj.__class__ - def model_fallback(self, request): + def model_fallback(self: Any, request: Request) -> Any: return "Model fallback" - def name_fallback(self, request): + def name_fallback(self: Any, request: Request) -> str: return "Name fallback" - def request_method_fallback(self, request): + def request_method_fallback(self: Any, request: Request) -> str: return "Request method fallback" - def body_model_fallback(self, request): + def body_model_fallback(self: Any, request: Request) -> Any: return "Body model fallback" @dispatch( @@ -860,10 +863,10 @@ def body_model_fallback(self, request): ), match_class("body_model", get_body_model, body_model_fallback, default=object), ) - def view(self, request): + def view(self: Any, request: Request) -> str: return "view fallback" - def exception_view(self, request): + def exception_view(self: Exception, request: Request) -> str: return "exception view" view.register(exception_view, model=Exception) @@ -877,7 +880,7 @@ class Item: class Item2: pass - def collection_add(self, request): + def collection_add(self: Collection, request: Request) -> str: return "collection add" view.register( @@ -900,12 +903,12 @@ def collection_add(self, request): ) -def test_dispatch_missing_argument(): +def test_dispatch_missing_argument() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: object) -> Any: pass - def for_bar(obj): + def for_bar(obj: object) -> Any: return "for bar" class Bar: @@ -914,18 +917,18 @@ class Bar: foo.register(for_bar, obj=Bar) with pytest.raises(TypeError): - assert foo() + assert foo() # type: ignore -def test_register_dispatch_predicates_twice(): +def test_register_dispatch_predicates_twice() -> None: @dispatch() - def foo(a, b): + def foo(a: Any, b: Any) -> Any: pass - def for_bar(a, b): + def for_bar(a: Any, b: Any) -> Any: return "for bar" - def for_qux(a, b): + def for_qux(a: Any, b: Any) -> Any: return "for qux" class Bar: @@ -943,31 +946,31 @@ class Qux: assert foo(Qux(), Qux()) == "for qux" -def test_dict_to_predicate_key_for_no_dispatch(): +def test_dict_to_predicate_key_for_no_dispatch() -> None: @dispatch() - def foo(): + def foo() -> None: pass assert foo.by_predicates().key == () -def test_dispatch_clean(): +def test_dispatch_clean() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: Any) -> str: return "default" - def for_bar(obj): + def for_bar(obj: Bar) -> str: return obj.method() - def for_qux(obj): + def for_qux(obj: Qux) -> str: return obj.method() class Bar: - def method(self): + def method(self) -> str: return "bar's method" class Qux: - def method(self): + def method(self) -> str: return "qux's method" foo.register(for_bar, obj=Bar) @@ -982,23 +985,23 @@ def method(self): assert foo(Qux()) == "default" -def test_dispatch_clean_add_predicates(): +def test_dispatch_clean_add_predicates() -> None: @dispatch() - def foo(obj): + def foo(obj: Any) -> str: return "default" - def for_bar(obj): + def for_bar(obj: Bar) -> str: return obj.method() - def for_qux(obj): + def for_qux(obj: Qux) -> str: return obj.method() class Bar: - def method(self): + def method(self) -> str: return "bar's method" class Qux: - def method(self): + def method(self) -> str: return "qux's method" foo.add_predicates([match_instance("obj")]) @@ -1017,9 +1020,9 @@ def method(self): foo.register(for_qux) -def test_dispatch_introspection(): +def test_dispatch_introspection() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: object) -> str: "return the foo of an object." return "default" @@ -1028,25 +1031,25 @@ def foo(obj): assert foo.__module__ == __name__ -def test_dispatch_argname_with_decorator(): +def test_dispatch_argname_with_decorator() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: Any) -> Any: pass class Bar: - def method(self): + def method(self) -> str: return "bar's method" class Qux: - def method(self): + def method(self) -> str: return "qux's method" @foo.register(obj=Bar) - def for_bar(obj): + def for_bar(obj: Bar) -> str: return obj.method() @foo.register(obj=Qux) - def for_qux(obj): + def for_qux(obj: Qux) -> str: return obj.method() assert foo(Bar()) == "bar's method" @@ -1056,9 +1059,9 @@ def for_qux(obj): assert foo(Qux()) == for_qux(Qux()) -def test_component_lookup_before_call_and_no_registrations(): +def test_component_lookup_before_call_and_no_registrations() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: Any) -> Any: pass class Bar: @@ -1067,46 +1070,46 @@ class Bar: assert foo.by_args(Bar()).component is None -def test_predicate_key_too_few_arguments_gives_typeerror(): +def test_predicate_key_too_few_arguments_gives_typeerror() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: Any) -> Any: pass - def for_bar(obj): + def for_bar(obj: Any) -> Any: return obj.method() - def for_qux(obj): + def for_qux(obj: Any) -> Any: return obj.method() with pytest.raises(TypeError): - assert foo.by_args() + assert foo.by_args() # type: ignore -def test_predicate_key_too_many_arguments_gives_typeerror(): +def test_predicate_key_too_many_arguments_gives_typeerror() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: Any) -> Any: pass - def for_bar(obj): + def for_bar(obj: Any) -> Any: return obj.method() - def for_qux(obj): + def for_qux(obj: Any) -> Any: return obj.method() with pytest.raises(TypeError): - assert foo.by_args(1, 2) + assert foo.by_args(1, 2) # type: ignore -def test_predicate_key_wrong_keyword_argument_gives_typeerror(): +def test_predicate_key_wrong_keyword_argument_gives_typeerror() -> None: @dispatch("obj") - def foo(obj): + def foo(obj: Any) -> Any: pass - def for_bar(obj): + def for_bar(obj: Any) -> Any: return obj.method() - def for_qux(obj): + def for_qux(obj: Any) -> Any: return obj.method() with pytest.raises(TypeError): - assert foo.by_args(wrong=1) + assert foo.by_args(wrong=1) # type: ignore diff --git a/reg/tests/test_dispatch_method.py b/reg/tests/test_dispatch_method.py index 8ef4780..7c7f4fc 100644 --- a/reg/tests/test_dispatch_method.py +++ b/reg/tests/test_dispatch_method.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from types import FunctionType +from typing import TYPE_CHECKING, Any import pytest +from ..dispatch import dispatch from ..context import ( - dispatch, dispatch_method, methodify, clean_dispatch_methods, @@ -9,14 +12,17 @@ from ..predicate import match_instance from ..error import RegistrationError +if TYPE_CHECKING: + from collections.abc import Callable + -def test_dispatch_method_explicit_fallback(): - def obj_fallback(self, obj): +def test_dispatch_method_explicit_fallback() -> None: + def obj_fallback(self: Foo, obj: object) -> str: return "Obj fallback" class Foo: @dispatch_method(match_instance("obj", fallback=obj_fallback)) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -37,10 +43,10 @@ class Beta: assert foo.bar(None) == "Obj fallback" -def test_dispatch_method_without_fallback(): +def test_dispatch_method_without_fallback() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -61,10 +67,10 @@ class Beta: assert foo.bar(None) == "default" -def test_dispatch_method_string_predicates(): +def test_dispatch_method_string_predicates() -> None: class Foo: @dispatch_method("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -85,10 +91,10 @@ class Beta: assert foo.bar(None) == "default" -def test_dispatch_method_add_predicates(): +def test_dispatch_method_add_predicates() -> None: class Foo: @dispatch_method() - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" Foo.bar.add_predicates([match_instance("obj")]) @@ -111,10 +117,10 @@ class Beta: assert foo.bar(None) == "default" -def test_dispatch_method_register_function(): +def test_dispatch_method_register_function() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -135,56 +141,56 @@ class Beta: assert foo.bar(None) == "default" -def test_dispatch_method_register_function_wrong_signature_too_long(): +def test_dispatch_method_register_function_wrong_signature_too_long() -> None: class Foo: @dispatch_method("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: pass with pytest.raises(RegistrationError): - Foo.bar.register(methodify(lambda obj, extra: "Alpha"), obj=Alpha) + Foo.bar.register(methodify(lambda obj, extra: "Alpha"), obj=Alpha) # type: ignore -def test_dispatch_method_register_function_wrong_signature_too_short(): +def test_dispatch_method_register_function_wrong_signature_too_short() -> None: class Foo: @dispatch_method("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: pass with pytest.raises(RegistrationError): - Foo.bar.register(methodify(lambda: "Alpha"), obj=Alpha) + Foo.bar.register(methodify(lambda: "Alpha"), obj=Alpha) # type: ignore -def test_dispatch_method_register_non_callable(): +def test_dispatch_method_register_non_callable() -> None: class Foo: @dispatch_method("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: pass with pytest.raises(RegistrationError): - Foo.bar.register("cannot call this", obj=Alpha) + Foo.bar.register("cannot call this", obj=Alpha) # type: ignore -def test_dispatch_method_methodify_non_callable(): +def test_dispatch_method_methodify_non_callable() -> None: with pytest.raises(TypeError): - methodify("cannot call this") + methodify("cannot call this") # type: ignore -def test_dispatch_method_register_auto(): +def test_dispatch_method_register_auto() -> None: class Foo: x = "X" @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -205,10 +211,10 @@ class Beta: assert foo.bar(None) == "default" -def test_dispatch_method_class_method_accessed_first(): +def test_dispatch_method_class_method_accessed_first() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -227,13 +233,13 @@ class Beta: assert foo.bar(None) == "default" -def test_dispatch_method_accesses_instance(): +def test_dispatch_method_accesses_instance() -> None: class Foo: - def __init__(self, x): + def __init__(self, x: str) -> None: self.x = x @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default %s" % self.x class Alpha: @@ -252,10 +258,10 @@ class Beta: assert foo.bar(None) == "default hello" -def test_dispatch_method_inheritance_register_on_subclass(): +def test_dispatch_method_inheritance_register_on_subclass() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Sub(Foo): @@ -279,10 +285,10 @@ class Beta: assert sub.bar(None) == "default" -def test_dispatch_method_inheritance_separation(): +def test_dispatch_method_inheritance_separation() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Sub(Foo): @@ -315,14 +321,14 @@ class Beta: assert sub.bar(None) == "default" -def test_dispatch_method_inheritance_separation_multiple(): +def test_dispatch_method_inheritance_separation_multiple() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "bar default" @dispatch_method(match_instance("obj")) - def qux(self, obj): + def qux(self, obj: Any) -> str: return "qux default" class Sub(Foo): @@ -364,13 +370,13 @@ class Beta: assert sub.qux(None) == "qux default" -def test_dispatch_method_api_available(): - def obj_fallback(self, obj): +def test_dispatch_method_api_available() -> None: + def obj_fallback(self: Any, obj: Any) -> str: return "Obj fallback" class Foo: @dispatch_method(match_instance("obj", fallback=obj_fallback)) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -381,10 +387,10 @@ class Beta: foo = Foo() - def alpha_func(self, obj): + def alpha_func(self: Foo, obj: Alpha) -> str: return "Alpha" - def beta_func(self, obj): + def beta_func(self: Foo, obj: Beta) -> str: return "Beta" Foo.bar.register(alpha_func, obj=Alpha) @@ -400,10 +406,10 @@ def beta_func(self, obj): assert foo.bar.by_args(None).all_matches == [] -def test_dispatch_method_with_register_function_value(): +def test_dispatch_method_with_register_function_value() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -416,10 +422,10 @@ class Beta: assert foo.bar(Alpha()) == "default" - def alpha_func(obj): + def alpha_func(obj: Alpha) -> str: return "Alpha" - def beta_func(obj): + def beta_func(obj: Beta) -> str: return "Beta" Foo.bar.register(methodify(alpha_func), obj=Alpha) @@ -428,10 +434,10 @@ def beta_func(obj): assert unmethodify(foo.bar.by_args(Alpha()).component) is alpha_func -def test_dispatch_method_with_register_auto_value(): +def test_dispatch_method_with_register_auto_value() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -444,10 +450,10 @@ class Beta: assert foo.bar(Alpha()) == "default" - def alpha_func(obj): + def alpha_func(obj: Alpha) -> str: return "Alpha" - def beta_func(app, obj): + def beta_func(app: Foo, obj: Beta) -> str: return "Beta" Foo.bar.register(methodify(alpha_func, "app"), obj=Alpha) @@ -459,11 +465,11 @@ def beta_func(app, obj): assert foo.bar.by_args(Beta()).component is beta_func -def test_install_method(): +def test_install_method() -> None: class Target: - pass + m: Callable[..., str] - def f(self, a): + def f(self: Target, a: str) -> str: return a Target.m = f @@ -473,11 +479,11 @@ def f(self, a): assert t.m("A") == "A" -def test_install_auto_method_function_no_app_arg(): +def test_install_auto_method_function_no_app_arg() -> None: class Target: - pass + m: Callable[..., str] - def f(a): + def f(a: str) -> str: return a Target.m = methodify(f, "app") @@ -488,11 +494,11 @@ def f(a): assert unmethodify(t.m) is f -def test_install_auto_method_function_app_arg(): +def test_install_auto_method_function_app_arg() -> None: class Target: - pass + m: Callable[..., str] - def g(app, a): + def g(app: Target, a: str) -> str: assert isinstance(app, Target) return a @@ -503,12 +509,12 @@ def g(app, a): assert unmethodify(t.m) is g -def test_install_auto_method_method_no_app_arg(): +def test_install_auto_method_method_no_app_arg() -> None: class Target: - pass + m: Callable[..., str] class Foo: - def f(self, a): + def f(self, a: str) -> str: return a f = Foo().f @@ -521,12 +527,12 @@ def f(self, a): assert unmethodify(t.m) is f -def test_install_auto_method_method_app_arg(): +def test_install_auto_method_method_app_arg() -> None: class Target: - pass + m: Callable[..., str] class Bar: - def g(self, app, a): + def g(self, app: Target, a: str) -> str: assert isinstance(app, Target) return a @@ -540,12 +546,12 @@ def g(self, app, a): assert unmethodify(t.m) is g -def test_install_instance_method(): +def test_install_instance_method() -> None: class Target: - pass + m: Callable[..., str] class Bar: - def g(self, a): + def g(self, a: str) -> str: assert isinstance(self, Bar) return a @@ -559,10 +565,10 @@ def g(self, a): assert unmethodify(t.m) is g -def test_dispatch_method_introspection(): +def test_dispatch_method_introspection() -> None: class Foo: @dispatch_method("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: "Return the bar of an object." return "default" @@ -571,10 +577,10 @@ def bar(self, obj): assert Foo.bar.__module__ == __name__ -def test_dispatch_method_clean(): +def test_dispatch_method_clean() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Qux(Foo): @@ -609,10 +615,10 @@ class Beta: assert qux.bar(Alpha()) == "Qux Alpha" -def test_clean_dispatch_methods(): +def test_clean_dispatch_methods() -> None: class Foo: @dispatch_method(match_instance("obj")) - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Qux(Foo): @@ -646,10 +652,10 @@ class Beta: assert qux.bar(Alpha()) == "Qux Alpha" -def test_replacing_with_normal_method(): +def test_replacing_with_normal_method() -> None: class Foo: @dispatch_method("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class Alpha: @@ -664,10 +670,12 @@ class Beta: # Simply using Foo.bar wouldn't have worked here, as it would # invoke the descriptor: - assert isinstance(Foo.bar, FunctionType) + if not TYPE_CHECKING: + # NOTE: mypy does not like the following two statements + assert isinstance(Foo.bar, FunctionType) - # We now replace the descriptor with the actual unbound method: - Foo.bar = Foo.bar + # We now replace the descriptor with the actual unbound method: + Foo.bar = Foo.bar # Now the descriptor is gone assert isinstance(vars(Foo)["bar"], FunctionType) @@ -682,10 +690,10 @@ class Beta: assert foo.bar(None) == "default" -def test_replacing_with_normal_method_and_its_effect_on_inheritance(): +def test_replacing_with_normal_method_and_its_effect_on_inheritance_1() -> None: class Foo: @dispatch_method("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class SubFoo(Foo): @@ -712,7 +720,7 @@ class Beta: assert subfoo.bar(None) == "default" # We now replace the descriptor with the actual unbound method: - Foo.bar = Foo.bar + Foo.bar = Foo.bar # type: ignore # Now the descriptor is gone assert isinstance(vars(Foo)["bar"], FunctionType) @@ -729,49 +737,70 @@ class Beta: assert subfoo.bar(Beta()) == "Beta" assert subfoo.bar(None) == "default" + +def test_replacing_with_normal_method_and_its_effect_on_inheritance_2() -> None: # This is exactly the same behavior we'd get by using dispatch # instead of dispatch_method: - del Foo, SubFoo class Foo: @dispatch("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class SubFoo(Foo): pass + class Alpha: + pass + + class Beta: + pass + # Foo and SubFoo share the same registry: Foo.bar.register(obj=Alpha)(lambda self, obj: "Alpha") SubFoo.bar.register(obj=Beta)(lambda self, obj: "Beta") - foo = Foo() - assert foo.bar(Alpha()) == "Alpha" - assert foo.bar(Beta()) == "Beta" - assert foo.bar(None) == "default" - - subfoo = SubFoo() - assert subfoo.bar(Alpha()) == "Alpha" - assert subfoo.bar(Beta()) == "Beta" - assert subfoo.bar(None) == "default" - + # NOTE: Using dispatch instead of dispatch_method makes type checkers + # not understand, that the first argument can be omitted, it's + # technically a solveable problem to a degree, but it takes a + # a lot of extra descriptor protocols and complex overloads to + # make work. It's probably not worth making worth at the moment + # since you can just use dispatch_method + if not TYPE_CHECKING: + foo = Foo() + assert foo.bar(Alpha()) == "Alpha" + assert foo.bar(Beta()) == "Beta" + assert foo.bar(None) == "default" + + subfoo = SubFoo() + assert subfoo.bar(Alpha()) == "Alpha" + assert subfoo.bar(Beta()) == "Beta" + assert subfoo.bar(None) == "default" + + +def test_replacing_with_normal_method_and_its_effect_on_inheritance_3() -> None: # Now we start again, and do the replacement for both subclass and # parent class, in this order: - del Foo, SubFoo class Foo: @dispatch_method("obj") - def bar(self, obj): + def bar(self, obj: Any) -> str: return "default" class SubFoo(Foo): pass + class Alpha: + pass + + class Beta: + pass + Foo.bar.register(obj=Alpha)(lambda self, obj: "Alpha") Foo.bar.register(obj=Beta)(lambda self, obj: "Beta") - SubFoo.bar = SubFoo.bar - Foo.bar = Foo.bar + SubFoo.bar = SubFoo.bar # type: ignore + Foo.bar = Foo.bar # type: ignore # This has kept two separate registries: foo = Foo() @@ -785,7 +814,7 @@ class SubFoo(Foo): assert subfoo.bar(None) == "default" -def unmethodify(func): +def unmethodify(func: Any) -> Any: """Reverses methodify operation. Given an object that is returned from a call to diff --git a/reg/tests/test_docgen.py b/reg/tests/test_docgen.py index 6754b05..b899c06 100644 --- a/reg/tests/test_docgen.py +++ b/reg/tests/test_docgen.py @@ -1,15 +1,19 @@ +from __future__ import annotations + import pydoc +import pytest import sys +from pathlib import Path from sphinx.application import Sphinx from .fixtures.module import Foo, foo -def rstrip_lines(s): +def rstrip_lines(s: str) -> str: "Delete trailing spaces from each line in s." return "\n".join(line.rstrip() for line in s.splitlines()) -def test_dispatch_method_class_help(capsys): +def test_dispatch_method_class_help(capsys: pytest.CaptureFixture[str]) -> None: pydoc.help(Foo) out, err = capsys.readouterr() assert ( @@ -43,7 +47,7 @@ class Foo({builtins}.object) ) -def test_dispatch_method_help(capsys): +def test_dispatch_method_help(capsys: pytest.CaptureFixture[str]) -> None: pydoc.help(Foo.bar) out, err = capsys.readouterr() assert rstrip_lines(out) == """\ @@ -54,7 +58,7 @@ def test_dispatch_method_help(capsys): """ -def test_dispatch_help(capsys): +def test_dispatch_help(capsys: pytest.CaptureFixture[str]) -> None: pydoc.help(foo) out, err = capsys.readouterr() assert rstrip_lines(out) == """\ @@ -65,10 +69,10 @@ def test_dispatch_help(capsys): """ -def test_autodoc(tmpdir): - root = str(tmpdir) - tmpdir.join("conf.py").write("extensions = ['sphinx.ext.autodoc']\n") - tmpdir.join("contents.rst").write( +def test_autodoc(tmp_path: Path) -> None: + root = str(tmp_path) + (tmp_path / "conf.py").write_text("extensions = ['sphinx.ext.autodoc']\n") + (tmp_path / "contents.rst").write_text( ".. automodule:: reg.tests.fixtures.module\n" " :members:\n" ) # status=None makes Sphinx completely quiet, in case you run @@ -76,7 +80,7 @@ def test_autodoc(tmpdir): # remove it. app = Sphinx(root, root, root + "/build", root, "text", status=None) app.build() - assert tmpdir.join("build/contents.txt").read() == """\ + assert (tmp_path / "build/contents.txt").read_text() == """\ Sample module for testing autodoc. class reg.tests.fixtures.module.Foo diff --git a/reg/tests/test_predicate.py b/reg/tests/test_predicate.py index 3dad15e..0de7a4a 100644 --- a/reg/tests/test_predicate.py +++ b/reg/tests/test_predicate.py @@ -9,12 +9,12 @@ import pytest -def test_key_index_permutations(): +def test_key_index_permutations() -> None: i = KeyIndex() assert list(i.permutations("GET")) == ["GET"] -def test_class_index_permutations(): +def test_class_index_permutations() -> None: class Foo: pass @@ -31,7 +31,7 @@ class Qux: assert list(i.permutations(Qux)) == [Qux, object] -def test_multi_class_predicate_permutations(): +def test_multi_class_predicate_permutations() -> None: class ABase: pass @@ -59,7 +59,7 @@ class BSub(BBase): ] -def test_multi_key_predicate_permutations(): +def test_multi_key_predicate_permutations() -> None: i = PredicateRegistry( match_key("a"), match_key("b"), @@ -69,7 +69,7 @@ def test_multi_key_predicate_permutations(): assert list(i.permutations(["A", "B", "C"])) == [("A", "B", "C")] -def test_registry_single_key_predicate(): +def test_registry_single_key_predicate() -> None: r = PredicateRegistry(match_key("a")) r.register(("A",), "A value") @@ -80,7 +80,7 @@ def test_registry_single_key_predicate(): assert list(r.all(("B",))) == [] -def test_registry_single_class_predicate(): +def test_registry_single_class_predicate() -> None: r = PredicateRegistry(match_instance("a")) class Foo: @@ -102,7 +102,7 @@ class Qux: assert list(r.all((Qux,))) == [] -def test_registry_single_classic_class_predicate(): +def test_registry_single_classic_class_predicate() -> None: r = PredicateRegistry(match_instance("a")) class Foo: @@ -124,7 +124,7 @@ class Qux: assert list(r.all((Qux,))) == [] -def test_registry_single_class_predicate_also_sub(): +def test_registry_single_class_predicate_also_sub() -> None: r = PredicateRegistry(match_instance("a")) class Foo: @@ -147,7 +147,7 @@ class Qux: assert list(r.all((Qux,))) == [] -def test_registry_multi_class_predicate(): +def test_registry_multi_class_predicate() -> None: r = PredicateRegistry( match_instance("a"), match_instance("b"), @@ -182,7 +182,7 @@ class BB(B): assert list(r.all((object, B))) == [] -def test_registry_multi_mixed_predicate_class_key(): +def test_registry_multi_mixed_predicate_class_key() -> None: r = PredicateRegistry( match_instance("a"), match_key("b"), @@ -212,7 +212,7 @@ class Unknown: assert list(r.all((Unknown, "B"))) == [] -def test_registry_multi_mixed_predicate_key_class(): +def test_registry_multi_mixed_predicate_key_class() -> None: r = PredicateRegistry( match_key("a"), match_instance("b"), @@ -240,20 +240,21 @@ class Unknown: assert list(r.all(("unknown", Unknown))) == [] -def test_single_predicate_get_key(): - def get_key(foo): +def test_single_predicate_get_key() -> None: + def get_key(foo: str) -> str: return foo p = match_key("a", get_key) + assert p.get_key is not None assert p.get_key({"foo": "value"}) == "value" -def test_multi_predicate_get_key(): - def a_key(**d): +def test_multi_predicate_get_key() -> None: + def a_key(**d: str) -> str: return d["a"] - def b_key(**d): + def b_key(**d: str) -> str: return d["b"] p = PredicateRegistry(match_key("a", a_key), match_key("b", b_key)) @@ -261,7 +262,7 @@ def b_key(**d): assert p.key(a="A", b="B") == ("A", "B") -def test_single_predicate_fallback(): +def test_single_predicate_fallback() -> None: r = PredicateRegistry(match_key("a", fallback="fallback")) r.register(("A",), "A value") @@ -271,7 +272,7 @@ def test_single_predicate_fallback(): assert r.fallback(("B",)) == "fallback" -def test_multi_predicate_fallback(): +def test_multi_predicate_fallback() -> None: r = PredicateRegistry( match_key("a", fallback="fallback1"), match_key("b", fallback="fallback2"), @@ -290,7 +291,7 @@ def test_multi_predicate_fallback(): assert list(r.all(("C", "B"))) == [] -def test_predicate_self_request(): +def test_predicate_self_request() -> None: m = PredicateRegistry(match_key("a"), match_key("b", fallback="registered for all")) m.register(("foo", "POST"), "registered for post") @@ -304,14 +305,14 @@ def test_predicate_self_request(): # XXX using an incomplete key returns undefined results -def test_predicate_duplicate_key(): +def test_predicate_duplicate_key() -> None: m = PredicateRegistry(match_key("a"), match_key("b", fallback="registered for all")) m.register(("foo", "POST"), "registered for post") with pytest.raises(RegistrationError): m.register(("foo", "POST"), "registered again") -def test_name_request_method_body_model_registered_for_base(): +def test_name_request_method_body_model_registered_for_base() -> None: m = PredicateRegistry( match_key("name", fallback="name fallback"), match_key("request_method", fallback="request_method fallback"), @@ -336,7 +337,7 @@ class Bar(Foo): assert m.component(("foo", "POST", Bar)) == "post foo" -def test_name_request_method_body_model_registered_for_base_and_sub(): +def test_name_request_method_body_model_registered_for_base_and_sub() -> None: m = PredicateRegistry( match_key("name", fallback="name fallback"), match_key("request", fallback="request_method fallback"), @@ -365,14 +366,14 @@ class Bar(Foo): assert m.component(("foo", "POST", Bar)) == "post bar" -def test_key_by_predicate_name(): +def test_key_by_predicate_name() -> None: p = match_key("foo", default="default") assert p.key_by_predicate_name({"foo": "value"}) == "value" assert p.key_by_predicate_name({}) == "default" -def test_multi_key_by_predicate_name(): +def test_multi_key_by_predicate_name() -> None: p = PredicateRegistry( match_key("foo", default="default foo"), match_key("bar", default="default bar"), @@ -384,7 +385,7 @@ def test_multi_key_by_predicate_name(): assert p.key_dict_to_predicate_key({}) == ("default foo", "default bar") -def test_nameless_predicate_key(): +def test_nameless_predicate_key() -> None: p = match_key("a") assert p.key_by_predicate_name({}) is None diff --git a/reg/tests/test_registry.py b/reg/tests/test_registry.py index 6cd5bca..4150f08 100644 --- a/reg/tests/test_registry.py +++ b/reg/tests/test_registry.py @@ -1,44 +1,61 @@ +from __future__ import annotations + +import pytest +from typing import Any, ParamSpec, TypeVar, TYPE_CHECKING from ..predicate import PredicateRegistry, match_instance, match_key from ..cache import DictCachingKeyLookup, LruCachingKeyLookup from ..error import RegistrationError from ..dispatch import dispatch -import pytest + +if TYPE_CHECKING: + from collections.abc import Callable + from ..types import DispatchCall + + _T = TypeVar("_T") + _P = ParamSpec("_P") -def register_value(generic, key, value): +def register_value( + generic: DispatchCall[_P, _T], key: Any, value: Callable[_P, _T] +) -> None: """Low-level function that directly uses the internal registry of the generic function to register an implementation. """ - generic.register.__self__.registry.register(key, value) + generic.register.__self__.registry.register(key, value) # type: ignore[attr-defined] -def test_registry(): +def test_registry() -> None: class Foo: pass class FooSub(Foo): pass + class Request: + def __init__(self, name: str, request_method: str) -> None: + self.name = name + self.request_method = request_method + @dispatch() - def view(self, request): + def view(self: Any, request: Request) -> str: raise NotImplementedError() - def get_model(self, request): + def get_model(self: Any, request: Request) -> Any: return self - def get_name(self, request): + def get_name(self: Any, request: Request) -> str: return request.name - def get_request_method(self, request): + def get_request_method(self: Any, request: Request) -> str: return request.request_method - def model_fallback(self, request): + def model_fallback(self: Any, request: Request) -> str: return "Model fallback" - def name_fallback(self, request): + def name_fallback(self: Any, request: Request) -> str: return "Name fallback" - def request_method_fallback(self, request): + def request_method_fallback(self: Any, request: Request) -> str: return "Request method fallback" view.add_predicates( @@ -49,13 +66,13 @@ def request_method_fallback(self, request): ] ) - def foo_default(self, request): + def foo_default(self: Foo, request: Request) -> str: return "foo default" - def foo_post(self, request): + def foo_post(self: Foo, request: Request) -> str: return "foo default post" - def foo_edit(self, request): + def foo_edit(self: Foo, request: Request) -> str: return "foo edit" register_value(view, (Foo, "", "GET"), foo_default) @@ -69,11 +86,6 @@ def foo_edit(self, request): assert key_lookup.component((FooSub, "", "GET")) is foo_default assert key_lookup.component((FooSub, "", "POST")) is foo_post - class Request: - def __init__(self, name, request_method): - self.name = name - self.request_method = request_method - assert view(Foo(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("edit", "POST")) == "foo edit" @@ -87,7 +99,7 @@ class Bar: assert view(FooSub(), Request("dummy", "GET")) == "Name fallback" -def test_predicate_registry_class_lookup(): +def test_predicate_registry_class_lookup() -> None: reg = PredicateRegistry(match_instance("obj")) class Document: @@ -114,7 +126,7 @@ class Other: assert reg.component((Other,)) is None -def test_predicate_registry_target_find_specific(): +def test_predicate_registry_target_find_specific() -> None: reg = PredicateRegistry(match_instance("obj")) reg2 = PredicateRegistry(match_instance("obj")) @@ -124,12 +136,6 @@ class Document: class SpecialDocument(Document): pass - def linecount(obj): - pass - - def special_linecount(obj): - pass - reg.register((Document,), "line count") reg2.register((Document,), "special line count") @@ -140,8 +146,8 @@ def special_linecount(obj): assert reg2.component((SpecialDocument,)) == "special line count" -def test_registry_no_sources(): - reg = PredicateRegistry() +def test_registry_no_sources() -> None: + reg = PredicateRegistry[str]() class Animal: pass @@ -150,7 +156,7 @@ class Animal: assert reg.component(()) == "elephant" -def test_register_twice_with_predicate(): +def test_register_twice_with_predicate() -> None: reg = PredicateRegistry(match_instance("obj")) class Document: @@ -161,40 +167,45 @@ class Document: reg.register((Document,), "another line count") -def test_register_twice_without_predicates(): - reg = PredicateRegistry() +def test_register_twice_without_predicates() -> None: + reg = PredicateRegistry[str]() reg.register((), "once") with pytest.raises(RegistrationError): reg.register((), "twice") -def test_dict_caching_registry(): +def test_dict_caching_registry() -> None: class Foo: pass class FooSub(Foo): pass - def get_model(self, request): + class Request: + def __init__(self, name: str, request_method: str) -> None: + self.name = name + self.request_method = request_method + + def get_model(self: Any, request: Request) -> Any: return self - def get_name(self, request): + def get_name(self: Any, request: Request) -> str: return request.name - def get_request_method(self, request): + def get_request_method(self: Any, request: Request) -> str: return request.request_method - def model_fallback(self, request): + def model_fallback(self: Any, request: Request) -> str: return "Model fallback" - def name_fallback(self, request): + def name_fallback(self: Any, request: Request) -> str: return "Name fallback" - def request_method_fallback(self, request): + def request_method_fallback(self: Any, request: Request) -> str: return "Request method fallback" - def get_caching_key_lookup(r): + def get_caching_key_lookup(r: PredicateRegistry) -> DictCachingKeyLookup: return DictCachingKeyLookup(r) @dispatch( @@ -203,27 +214,22 @@ def get_caching_key_lookup(r): match_key("request_method", get_request_method, request_method_fallback), get_key_lookup=get_caching_key_lookup, ) - def view(self, request): + def view(self: Any, request: Request) -> str: raise NotImplementedError() - def foo_default(self, request): + def foo_default(self: Foo, request: Request) -> str: return "foo default" - def foo_post(self, request): + def foo_post(self: Foo, request: Request) -> str: return "foo default post" - def foo_edit(self, request): + def foo_edit(self: Foo, request: Request) -> str: return "foo edit" register_value(view, (Foo, "", "GET"), foo_default) register_value(view, (Foo, "", "POST"), foo_post) register_value(view, (Foo, "edit", "POST"), foo_edit) - class Request: - def __init__(self, name, request_method): - self.name = name - self.request_method = request_method - assert view(Foo(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("edit", "POST")) == "foo edit" @@ -234,9 +240,9 @@ def __init__(self, name, request_method): ) # use a bit of inside knowledge to check the cache is filled - assert view.key_lookup.component.__self__.get((Foo, "", "GET")) is not None - assert view.key_lookup.component.__self__.get((FooSub, "", "GET")) is not None - assert view.key_lookup.component.__self__.get((FooSub, "edit", "POST")) is not None + assert view.key_lookup.component.__self__.get((Foo, "", "GET")) is not None # type: ignore[attr-defined] + assert view.key_lookup.component.__self__.get((FooSub, "", "GET")) is not None # type: ignore[attr-defined] + assert view.key_lookup.component.__self__.get((FooSub, "edit", "POST")) is not None # type: ignore[attr-defined] # now let's do this again. this time things come from the component cache assert view(Foo(), Request("", "GET")) == "foo default" @@ -246,7 +252,7 @@ def __init__(self, name, request_method): key_lookup = view.key_lookup # prime and check the all cache assert view.by_args(Foo(), Request("", "GET")).all_matches == [foo_default] - assert key_lookup.all.__self__.get((Foo, "", "GET")) is not None + assert key_lookup.all.__self__.get((Foo, "", "GET")) is not None # type: ignore[attr-defined] # should be coming from cache now assert view.by_args(Foo(), Request("", "GET")).all_matches == [foo_default] @@ -259,7 +265,7 @@ class Bar: assert view(FooSub(), Request("dummy", "GET")) == "Name fallback" # fallbacks get cached too - assert key_lookup.fallback.__self__.get((Bar, "", "GET")) is model_fallback + assert key_lookup.fallback.__self__.get((Bar, "", "GET")) is model_fallback # type: ignore[attr-defined] # these come from the fallback cache now assert view(Bar(), Request("", "GET")) == "Model fallback" @@ -268,32 +274,37 @@ class Bar: assert view(FooSub(), Request("dummy", "GET")) == "Name fallback" -def test_lru_caching_registry(): +def test_lru_caching_registry() -> None: class Foo: pass class FooSub(Foo): pass - def get_model(self, request): + class Request: + def __init__(self, name: str, request_method: str) -> None: + self.name = name + self.request_method = request_method + + def get_model(self: Any, request: Request) -> Any: return self - def get_name(self, request): + def get_name(self: Any, request: Request) -> str: return request.name - def get_request_method(self, request): + def get_request_method(self: Any, request: Request) -> str: return request.request_method - def model_fallback(self, request): + def model_fallback(self: Any, request: Request) -> str: return "Model fallback" - def name_fallback(self, request): + def name_fallback(self: Any, request: Request) -> str: return "Name fallback" - def request_method_fallback(self, request): + def request_method_fallback(self: Any, request: Request) -> str: return "Request method fallback" - def get_caching_key_lookup(r): + def get_caching_key_lookup(r: PredicateRegistry) -> LruCachingKeyLookup: return LruCachingKeyLookup(r, 100, 100, 100) @dispatch( @@ -302,27 +313,22 @@ def get_caching_key_lookup(r): match_key("request_method", get_request_method, request_method_fallback), get_key_lookup=get_caching_key_lookup, ) - def view(self, request): + def view(self: Any, request: Request) -> str: raise NotImplementedError() - def foo_default(self, request): + def foo_default(self: Foo, request: Request) -> str: return "foo default" - def foo_post(self, request): + def foo_post(self: Foo, request: Request) -> str: return "foo default post" - def foo_edit(self, request): + def foo_edit(self: Foo, request: Request) -> str: return "foo edit" register_value(view, (Foo, "", "GET"), foo_default) register_value(view, (Foo, "", "POST"), foo_post) register_value(view, (Foo, "edit", "POST"), foo_edit) - class Request: - def __init__(self, name, request_method): - self.name = name - self.request_method = request_method - assert view(Foo(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("edit", "POST")) == "foo edit" @@ -333,6 +339,7 @@ def __init__(self, name, request_method): ) # use a bit of inside knowledge to check the cache is filled + assert view.key_lookup.component.__closure__ is not None component_cache = view.key_lookup.component.__closure__[0].cell_contents assert component_cache.get(((Foo, "", "GET"),)) is not None assert component_cache.get(((FooSub, "", "GET"),)) is not None @@ -343,6 +350,7 @@ def __init__(self, name, request_method): assert view(FooSub(), Request("", "GET")) == "foo default" assert view(FooSub(), Request("edit", "POST")) == "foo edit" + assert view.key_lookup.all.__closure__ is not None all_cache = view.key_lookup.all.__closure__[0].cell_contents # prime and check the all cache assert view.by_args(Foo(), Request("", "GET")).all_matches == [foo_default] @@ -359,6 +367,7 @@ class Bar: assert view(FooSub(), Request("dummy", "GET")) == "Name fallback" # fallbacks get cached too + assert view.key_lookup.fallback.__closure__ is not None fallback_cache = view.key_lookup.fallback.__closure__[0].cell_contents assert fallback_cache.get(((Bar, "", "GET"),)) is model_fallback diff --git a/reg/types.py b/reg/types.py new file mode 100644 index 0000000..3f1efae --- /dev/null +++ b/reg/types.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + ParamSpec, + Protocol, + TypeAlias, + overload, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Sequence + from inspect import FullArgSpec + from typing_extensions import TypeVar + from .dispatch import LookupEntry + from .predicate import Predicate, PredicateRegistry + + _ValueT = TypeVar("_ValueT", covariant=True, default=Callable[..., Any]) +else: + from typing import TypeVar + + _ValueT = TypeVar("_ValueT", covariant=True) + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_R = TypeVar("_R") +_P = ParamSpec("_P") + + +# NOTE: Dispatch.call serves as a proxy for most of the public interface +# of Dispatch. We model this using a protocol, since there isn't +# a proper class for this for this object. We duplicate the docs +# so language servers can give better hints. +class DispatchCall(Protocol[_P, _T]): + __name__: str + __qualname__: str + __defaults__: tuple[Any, ...] | None + __globals__: dict[str, Any] + wrapped_func: Callable[_P, _T] + get_key_lookup: GetKeyLookup + key_lookup: KeyLookup + + def clean(self) -> None: + """Clean up implementations and added predicates. + + This restores the dispatch function to its original state, + removing registered implementations and predicates added + using :meth:`reg.Dispatch.add_predicates`. + """ + raise NotImplementedError + + def add_predicates(self, predicates: list[Predicate]) -> None: + """Add new predicates. + + Extend the predicates used by this predicates. This can be + used to add predicates that are configured during startup time. + + Note that this clears up any registered implementations. + + :param predicates: a list of predicates to add. + """ + raise NotImplementedError + + @overload + def register(self, func: Callable[_P, _T], **key_dict: Any) -> Callable[_P, _T]: ... + @overload + def register( + self, func: None = None, **key_dict: Any + ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + raise NotImplementedError + + def register( + self, func: Callable[_P, _T] | None = None, **key_dict: Any + ) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """Register an implementation. + + If ``func`` is not specified, this method can be used as a + decorator and the decorated function will be used as the + actual ``func`` argument. + + :param func: a function that implements behavior for this + dispatch function. It needs to have the same signature as + the original dispatch function. If this is a + :class:`reg.DispatchMethod`, then this means it needs to + take a first context argument. + :param key_dict: keyword arguments describing the registration, + with as keys predicate name and as values predicate values. + :returns: ``func``. + """ + raise NotImplementedError + + def by_args(self, *args: _P.args, **kw: _P.kwargs) -> LookupEntry[Callable[_P, _T]]: + """Lookup an implementation by invocation arguments. + + :param args: positional arguments used in invocation. + :param kw: named arguments used in invocation. + :returns: a :class:`reg.LookupEntry`. + """ + raise NotImplementedError + + def by_predicates(self, **predicate_values: Any) -> LookupEntry[Callable[_P, _T]]: + """Lookup an implementation by predicate values. + + :param predicate_values: the values of the predicates to lookup. + :returns: a :class:`reg.LookupEntry`. + """ + raise NotImplementedError + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + raise NotImplementedError + + +class DispatchMethodCall(DispatchCall[Concatenate[_T, _P], _R], Protocol[_P, _T, _R]): + def by_args(self, *args: _P.args, **kw: _P.kwargs) -> LookupEntry[Callable[Concatenate[_T, _P], _R]]: # type: ignore[override] + """Lookup an implementation by invocation arguments. + + :param args: positional arguments used in invocation. + :param kw: named arguments used in invocation. + :returns: a :class:`reg.LookupEntry`. + """ + raise NotImplementedError + + +# NOTE: This is so we can expose the original DispatchMethod through __func__ +# and have the correct signature for `__call__`. +class BoundDispatchMethodCall(DispatchMethodCall[_P, _T, _R], Protocol[_P, _T, _R]): + @property + def __self__(self) -> _T: + raise NotImplementedError + + @property + def __func__(self) -> DispatchMethodCall[_P, _T, _R]: + raise NotImplementedError + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # type: ignore[override] + raise NotImplementedError + + +# NOTE: arginfo serves as a proxy for the is_cached function, we duplicate +# the docs for better language server support. +class ArgInfo(Protocol): + __name__: str + __qualname__: str + _cache: dict[Callable[..., Any], FullArgSpec] + is_cached: Callable[[Callable[..., Any]], bool] + + def __call__(self, callable: Callable[..., Any]) -> FullArgSpec | None: + """Get information about the arguments of a callable. + + Returns a :class:`inspect.FullArgSpec` object as for + :func:`inspect.getfullargspec`. + + :func:`inspect.getfullargspec` returns information about the arguments + of a function. arginfo also works for classes and instances with a + __call__ defined. Unlike getfullargspec, arginfo treats bound methods + like functions, so that the self argument is not reported. Another + difference is the handling of decorated functions. This will return + the original signature, rather than the signature of the wrapper, if + wrapped via :func:`functools.wraps`. + + arginfo returns ``None`` if given something that is not callable. + + arginfo caches previous calls (except for instances with a + __call__), making calling it repeatedly cheap. + + This was originally inspired by the pytest.core varnames() function, + but has been completely rewritten to handle class constructors, + also show other getarginfo() information, and for readability. + """ + + +class KeyLookup(Protocol[_ValueT]): + def component(self, key: Sequence[Any], /) -> _ValueT | None: + raise NotImplementedError + + def fallback(self, key: Sequence[Any], /) -> _ValueT | None: + raise NotImplementedError + + def all(self, key: Sequence[Any], /) -> Iterable[_ValueT]: + raise NotImplementedError + + +GetKeyLookup: TypeAlias = Callable[[PredicateRegistry[_ValueT]], KeyLookup[_ValueT]]