diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d764382..07dc5fe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ fail_fast: false default_language_version: - python: python3 + python: python3.14 default_stages: - pre-commit - pre-push diff --git a/CHANGELOG.md b/CHANGELOG.md index 94b6758..824decc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.1.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [0.0.7] + +### Added + +- A `reset` method for `Settings` to reset settings to their default values. + ## [0.0.6] ### Added @@ -55,6 +61,7 @@ and this project adheres to [Semantic Versioning][]. - Initial release +[0.0.7]: https://github.com/scverse/scverse-misc/releases/tag/v0.0.7 [0.0.6]: https://github.com/scverse/scverse-misc/releases/tag/v0.0.6 [0.0.5]: https://github.com/scverse/scverse-misc/releases/tag/v0.0.5 [0.0.4]: https://github.com/scverse/scverse-misc/releases/tag/v0.0.4 diff --git a/docs/api/settings.rst b/docs/api/settings.rst index d667d04..b2dfc57 100644 --- a/docs/api/settings.rst +++ b/docs/api/settings.rst @@ -6,3 +6,4 @@ scverse\_misc.Settings .. autoclass:: Settings .. automethod:: override + .. automethod:: reset diff --git a/src/scverse_misc/_settings.py b/src/scverse_misc/_settings.py index c9c94b2..dc67d54 100644 --- a/src/scverse_misc/_settings.py +++ b/src/scverse_misc/_settings.py @@ -7,7 +7,7 @@ from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager from types import FunctionType, GenericAlias -from typing import Literal, Self +from typing import Literal, LiteralString, Self import dotenv from pydantic.fields import FieldInfo @@ -112,20 +112,45 @@ def __init_subclass__(subcls, *, exported_object_name: str, docstring_style: Lit super().__init_subclass__() @contextmanager - def override(self, **kwargs: object) -> Generator[None]: + def override(self, **overrides: object) -> Generator[None]: """Context manager for local setting overrides. Subclasses will get a version with a docstring detailing the available parameters. """ - oldsettings = {argname: getattr(self, argname) for argname in kwargs.keys()} + oldsettings = {argname: getattr(self, argname) for argname in overrides.keys()} try: - for argname, argval in kwargs.items(): + for argname, argval in overrides.items(): setattr(self, argname, argval) yield finally: for argname, argval in reversed(oldsettings.items()): setattr(self, argname, argval) + def reset(self, *names: LiteralString) -> AbstractContextManager[frozenset[LiteralString]]: + """Reset passed settings to their default values. + + Can be used as a context manager to make the resets temporary. + On `__enter__`, the context manager returns the settings that have been changed. + """ + prev_values = {name: getattr(self, name) for name in names if name in self.model_fields_set} + + # since we want to allow using this method imperatively, + # eagerly do the reset here instead of returning a context manager with a lazy `__enter__`. + for name in prev_values: + default = type(self).model_fields[name].get_default() + setattr(self, name, default) + self.model_fields_set.remove(name) + + class Cm(AbstractContextManager[frozenset[str]]): + def __enter__(_self) -> frozenset[str]: + return frozenset(prev_values) + + def __exit__(_self, *_: object) -> None: + for arg, value in prev_values.items(): + setattr(self, arg, value) + + return Cm() + @classmethod def __pydantic_init_subclass__( # type: ignore[override] subcls: type[Self], *, exported_object_name: str, docstring_style: Literal["google", "numpy", "scverse"] @@ -168,6 +193,7 @@ def __pydantic_init_subclass__( # type: ignore[override] subcls.override = _copy_override( # type: ignore[method-assign,type-var] subcls, subcls.override, override_doc, return_annotation=AbstractContextManager[None] ) + subcls.reset = _copy_reset(subcls, subcls.reset) # type: ignore[method-assign,type-var] class CustomRepr(str): @@ -206,3 +232,29 @@ def _copy_override[F: FunctionType](cls: type[Settings], func: F, doc: str, retu ) return copy_func(func, **overrides) + + +def _copy_reset[F: FunctionType](cls: type[Settings], func: F) -> F: + from ._utils import Overrides + + args_t = Literal[tuple(cls.model_fields.keys())] # type: ignore[valid-type] + parameters = [ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY), + inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL, annotation=args_t), + ] + return_annotation = AbstractContextManager[frozenset[args_t]] # type: ignore[valid-type] + overrides = Overrides( + __module__=cls.__module__, + __qualname__=f"{cls.__qualname__}.{func.__name__}", + __signature__=inspect.Signature(parameters, return_annotation=return_annotation), + __annotations__={"args": args_t, "return": return_annotation}, + ) + if sys.version_info >= (3, 14): + from annotationlib import Format + + str_annotations = {n: str(t) for n, t in overrides["__annotations__"].items()} + overrides["__annotate__"] = lambda fmt: ( + overrides["__annotations__"] if fmt != Format.STRING else str_annotations + ) + + return copy_func(func, **overrides) diff --git a/tests/test_settings.py b/tests/test_settings.py index 197caf0..e92bc60 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,8 +1,10 @@ from __future__ import annotations import inspect +import sys +from contextlib import nullcontext from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Literal, cast +from typing import TYPE_CHECKING, Annotated, Literal, cast, get_args import pytest from pydantic import Field, ValidationError @@ -87,6 +89,8 @@ def test_override(settings: DummySettings) -> None: assert settings.field_bool is True assert settings.field_bool is False + +def test_override_error(settings: DummySettings) -> None: with pytest.raises(ValidationError): with settings.override(field_int_range=3, field_no_docstring=1.1): pass @@ -94,6 +98,33 @@ def test_override(settings: DummySettings) -> None: assert settings.field_int_range == 1 +@pytest.mark.parametrize("temp", [True, False], ids=["temporary", "permanent"]) +def test_reset(settings: DummySettings, temp: bool) -> None: + default = settings.field_bool + settings.field_bool = not default + undo_reset = settings.reset("field_bool") + with undo_reset if temp else nullcontext(): + assert settings.field_bool is default + assert settings.field_bool is (not default if temp else default) + + +def test_reset_signature(settings: DummySettings) -> None: + sig = inspect.signature(settings.reset) + assert get_args(sig.parameters["args"].annotation) == ("field_bool", "field_no_docstring", "field_int_range") + + +@pytest.mark.skipif(sys.version_info < (3, 14), reason="requires annotationlib") +def test_reset_annotations(settings: DummySettings) -> None: + from contextlib import AbstractContextManager + + import annotationlib + + assert annotationlib.get_annotations(settings.reset) == { + "args": Literal["field_bool", "field_no_docstring", "field_int_range"], + "return": AbstractContextManager[frozenset[Literal["field_bool", "field_no_docstring", "field_int_range"]]], + } + + @pytest.mark.parametrize("docstring_style", ["google", "numpy", "scverse"], indirect=True) def test_docs(docstring_style: Literal["google", "numpy"], settings: DummySettings) -> None: parser = GoogleDocstring if docstring_style == "google" else NumpyDocstring