Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
fail_fast: false
default_language_version:
python: python3
python: python3.14
default_stages:
- pre-commit
- pre-push
Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.1.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [0.0.7]

### Added

- A `reset` method for `Settings` to reset settings to their default values.

## [0.0.6]

### Added
Expand Down Expand Up @@ -55,6 +61,7 @@ and this project adheres to [Semantic Versioning][].

- Initial release

[0.0.7]: https://github.com/scverse/scverse-misc/releases/tag/v0.0.7
[0.0.6]: https://github.com/scverse/scverse-misc/releases/tag/v0.0.6
[0.0.5]: https://github.com/scverse/scverse-misc/releases/tag/v0.0.5
[0.0.4]: https://github.com/scverse/scverse-misc/releases/tag/v0.0.4
Expand Down
1 change: 1 addition & 0 deletions docs/api/settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ scverse\_misc.Settings
.. autoclass:: Settings

.. automethod:: override
.. automethod:: reset
60 changes: 56 additions & 4 deletions src/scverse_misc/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager
from types import FunctionType, GenericAlias
from typing import Literal, Self
from typing import Literal, LiteralString, Self

import dotenv
from pydantic.fields import FieldInfo
Expand Down Expand Up @@ -112,20 +112,45 @@ def __init_subclass__(subcls, *, exported_object_name: str, docstring_style: Lit
super().__init_subclass__()

@contextmanager
def override(self, **kwargs: object) -> Generator[None]:
def override(self, **overrides: object) -> Generator[None]:
"""Context manager for local setting overrides.

Subclasses will get a version with a docstring detailing the available parameters.
"""
oldsettings = {argname: getattr(self, argname) for argname in kwargs.keys()}
oldsettings = {argname: getattr(self, argname) for argname in overrides.keys()}
try:
for argname, argval in kwargs.items():
for argname, argval in overrides.items():
setattr(self, argname, argval)
yield
finally:
for argname, argval in reversed(oldsettings.items()):
setattr(self, argname, argval)

def reset(self, *names: LiteralString) -> AbstractContextManager[frozenset[LiteralString]]:
"""Reset passed settings to their default values.

Can be used as a context manager to make the resets temporary.
On `__enter__`, the context manager returns the settings that have been changed.
"""
prev_values = {name: getattr(self, name) for name in names if name in self.model_fields_set}

# since we want to allow using this method imperatively,
# eagerly do the reset here instead of returning a context manager with a lazy `__enter__`.
for name in prev_values:
default = type(self).model_fields[name].get_default()
setattr(self, name, default)
self.model_fields_set.remove(name)

class Cm(AbstractContextManager[frozenset[str]]):
def __enter__(_self) -> frozenset[str]:
return frozenset(prev_values)

def __exit__(_self, *_: object) -> None:
for arg, value in prev_values.items():
setattr(self, arg, value)

return Cm()

@classmethod
def __pydantic_init_subclass__( # type: ignore[override]
subcls: type[Self], *, exported_object_name: str, docstring_style: Literal["google", "numpy", "scverse"]
Expand Down Expand Up @@ -168,6 +193,7 @@ def __pydantic_init_subclass__( # type: ignore[override]
subcls.override = _copy_override( # type: ignore[method-assign,type-var]
subcls, subcls.override, override_doc, return_annotation=AbstractContextManager[None]
)
subcls.reset = _copy_reset(subcls, subcls.reset) # type: ignore[method-assign,type-var]


class CustomRepr(str):
Expand Down Expand Up @@ -206,3 +232,29 @@ def _copy_override[F: FunctionType](cls: type[Settings], func: F, doc: str, retu
)

return copy_func(func, **overrides)


def _copy_reset[F: FunctionType](cls: type[Settings], func: F) -> F:
from ._utils import Overrides

args_t = Literal[tuple(cls.model_fields.keys())] # type: ignore[valid-type]
parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY),
inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL, annotation=args_t),
]
return_annotation = AbstractContextManager[frozenset[args_t]] # type: ignore[valid-type]
overrides = Overrides(
__module__=cls.__module__,
__qualname__=f"{cls.__qualname__}.{func.__name__}",
__signature__=inspect.Signature(parameters, return_annotation=return_annotation),
__annotations__={"args": args_t, "return": return_annotation},
)
if sys.version_info >= (3, 14):
from annotationlib import Format

str_annotations = {n: str(t) for n, t in overrides["__annotations__"].items()}
overrides["__annotate__"] = lambda fmt: (
overrides["__annotations__"] if fmt != Format.STRING else str_annotations
)

return copy_func(func, **overrides)
33 changes: 32 additions & 1 deletion tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import inspect
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Literal, cast
from typing import TYPE_CHECKING, Annotated, Literal, cast, get_args

import pytest
from pydantic import Field, ValidationError
Expand Down Expand Up @@ -87,13 +89,42 @@ def test_override(settings: DummySettings) -> None:
assert settings.field_bool is True
assert settings.field_bool is False


def test_override_error(settings: DummySettings) -> None:
with pytest.raises(ValidationError):
with settings.override(field_int_range=3, field_no_docstring=1.1):
pass
assert settings.field_no_docstring == 42
assert settings.field_int_range == 1


@pytest.mark.parametrize("temp", [True, False], ids=["temporary", "permanent"])
def test_reset(settings: DummySettings, temp: bool) -> None:
default = settings.field_bool
settings.field_bool = not default
undo_reset = settings.reset("field_bool")
with undo_reset if temp else nullcontext():
assert settings.field_bool is default
assert settings.field_bool is (not default if temp else default)


def test_reset_signature(settings: DummySettings) -> None:
sig = inspect.signature(settings.reset)
assert get_args(sig.parameters["args"].annotation) == ("field_bool", "field_no_docstring", "field_int_range")


@pytest.mark.skipif(sys.version_info < (3, 14), reason="requires annotationlib")
def test_reset_annotations(settings: DummySettings) -> None:
from contextlib import AbstractContextManager

import annotationlib

assert annotationlib.get_annotations(settings.reset) == {
"args": Literal["field_bool", "field_no_docstring", "field_int_range"],
"return": AbstractContextManager[frozenset[Literal["field_bool", "field_no_docstring", "field_int_range"]]],
}


@pytest.mark.parametrize("docstring_style", ["google", "numpy", "scverse"], indirect=True)
def test_docs(docstring_style: Literal["google", "numpy"], settings: DummySettings) -> None:
parser = GoogleDocstring if docstring_style == "google" else NumpyDocstring
Expand Down
Loading