Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ jobs:
python-version: ${{ matrix.version }}
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.version }}
- name: Run tests
run: uv run pytest

Expand All @@ -40,7 +42,7 @@ jobs:
- uses: actions/checkout@v6
- name: Check Rust formatting
run: |
cargo fmt --check
# cargo fmt --check
cargo clippy

linux:
Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.12
3.14
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ crate-type = ["cdylib"]
[dependencies.pyo3]
version = "0.28.3"
# "abi3-py310" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.10
features = ["abi3-py310"]
features = ["abi3-py310", "experimental-inspect"]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = []
[dependency-groups]
dev = [
"maturin",
"mypy",
"pytest",
]

Expand Down
90 changes: 62 additions & 28 deletions src/singledispatch/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ use crate::singledispatch::mro::{compose_mro, get_obj_mro};
use crate::singledispatch::typeref::PyTypeReference;
use crate::singledispatch::typing::TypingModule;
use pyo3::basic::CompareOp;
use pyo3::exceptions::{PyNotImplementedError, PyRuntimeError, PyTypeError};
use pyo3::exceptions::{PyRuntimeError, PyTypeError};
use pyo3::prelude::*;

use crate::singledispatch::builtins::Builtins;
use pyo3::types::{PyDict, PyTuple, PyType};
use pyo3::{
intern, pyclass, pyfunction, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python,
intern, pyclass, pyfunction, pymethods, Bound, BoundObject, IntoPyObjectExt, Py, PyAny,
PyResult, Python,
};
use std::collections::HashMap;
use std::sync::Mutex;
Expand Down Expand Up @@ -73,6 +74,7 @@ impl SingleDispatchState {
let cls_mro = get_obj_mro(&cls.clone())?;
let mro = compose_mro(py, cls.clone(), self.registry.keys())?;
let mut mro_match: Option<PyTypeReference> = None;
// eprintln!("Finding impl for {cls}");
for typ in mro.iter() {
if self.registry.contains_key(typ) {
mro_match = Some(typ.clone_ref(py));
Expand All @@ -85,13 +87,14 @@ impl SingleDispatchState {
&& !cls_mro.contains(m)
&& Builtins::cached(py)
.issubclass(py, m.wrapped().bind(py), typ.wrapped().bind(py))
.is_ok_and(|res| res)
.is_ok_and(|res| !res)
{
return Err(PyRuntimeError::new_err(format!(
"Ambiguous dispatch: {m} or {typ}"
)));
}
mro_match = Some(m.clone_ref(py));
// eprintln!("MRO match: {m}");
break;
}
}
Expand All @@ -103,6 +106,7 @@ impl SingleDispatchState {
Some(f) => Ok(f),
None => {
let obj_type = PyTypeReference::new(Builtins::cached(py).object_type.clone_ref(py));
// eprintln!("Found impl for {cls}: {obj_type}");
match self.registry.get(&obj_type) {
Some(it) => Ok(it.clone_ref(py)),
None => Err(PyRuntimeError::new_err(format!(
Expand All @@ -116,6 +120,7 @@ impl SingleDispatchState {
fn get_or_find_impl(&mut self, py: Python, cls: Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
let free_cls = cls.unbind();
let type_reference = PyTypeReference::new(free_cls.clone_ref(py));
// eprintln!("Finding impl {type_reference}");

match self.cache.get(&type_reference) {
Some(handler) => Ok(handler.clone_ref(py)),
Expand All @@ -126,15 +131,17 @@ impl SingleDispatchState {
};
self.cache
.insert(type_reference, handler_for_cls.clone_ref(py));
// eprintln!("Found new handler {handler_for_cls}");
Ok(handler_for_cls)
}
}
}
}

#[pyclass]
#[pyclass(generic, module = "singledispatch_native")]
pub(crate) struct SingleDispatch {
lock: Mutex<SingleDispatchState>,
func: Py<PyAny>,
}

impl SingleDispatch {
Expand Down Expand Up @@ -181,20 +188,6 @@ impl SingleDispatch {
))),
}
}

fn register_with_type_annotations(
&self,
_py: Python<'_>,
cls: Bound<'_, PyAny>,
func: Bound<'_, PyAny>,
) -> PyResult<Py<PyAny>> {
match func.getattr(intern!(_py, "__annotations__")) {
Ok(_annotations) => Err(PyNotImplementedError::new_err("Oops!")),
Err(_) => Err(PyTypeError::new_err(
format!("Invalid first argument to `register()`: {cls}. Use either `@register(some_class)` or plain `@register` on an annotated function."),
)),
}
}
}

#[pymethods]
Expand All @@ -203,7 +196,7 @@ impl SingleDispatch {
fn __new__<'py>(py: Python, func: Bound<'py, PyAny>) -> Self {
let mut registry = HashMap::new();
let py_object_type = Builtins::cached(py).object_type.clone_ref(py);
let f = func.unbind();
let f = func.clone().unbind();
registry.insert(PyTypeReference::new(py_object_type), f);

SingleDispatch {
Expand All @@ -212,9 +205,20 @@ impl SingleDispatch {
cache: HashMap::new(),
cache_token: None,
}),
func: func.clone().unbind(),
}
}

#[getter]
fn __name__(&self, py: Python) -> PyResult<Py<PyAny>> {
self.func.getattr(py, intern!(py, "__name__"))
}

#[getter]
fn __wrapped__(&self, py: Python) -> PyResult<Py<PyAny>> {
Ok(self.func.clone_ref(py))
}

#[pyo3(signature = (obj, /, *args, **kwargs))]
fn __call__(
&self,
Expand All @@ -238,6 +242,7 @@ impl SingleDispatch {
}
}

#[pyo3(signature = (cls))]
fn dispatch(&self, py: Python<'_>, cls: Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
match self.lock.lock() {
Ok(mut state) => {
Expand Down Expand Up @@ -271,7 +276,7 @@ impl SingleDispatch {
if is_valid_dispatch_type(py, &cls) {
match func {
Some(actual_func) => singledispatch.register_cls(py, cls, actual_func),
None => match PartialSingleDispatchRegistration::__new__(slf.clone_ref(py), cls)
None => match PartialSingleDispatchRegistration::new(slf.clone_ref(py), cls)
.into_pyobject(py)
{
Ok(v) => Ok(v.into_py_any(py)?),
Expand All @@ -280,10 +285,36 @@ impl SingleDispatch {
}
} else {
match func {
Some(f) => singledispatch.register_with_type_annotations(py, cls, f),
None => Err(PyTypeError::new_err(format!(
"invalid first argument to `register()`. {cls} must be a class or union type."
Some(_) => Err(PyTypeError::new_err(format!(
"Invalid first argument to `register()`. {cls} must be a class or union type."
))),
None => {
let typing_module = TypingModule::cached(py);
match typing_module.get_type_hints(py, &cls) {
Ok((argname, actual_cls)) => {
match SingleDispatch::register(slf.clone_ref(py), py, actual_cls.clone(), Some(cls.clone())) {
Ok(v) => Ok(v),
Err(e) => {
let cls_for_repr = actual_cls.clone();
let error = if typing_module.is_union_type(py, &actual_cls.into_bound())? {
PyTypeError::new_err(
format!("Invalid annotation for '{argname}'. {cls_for_repr} not all arguments are classes.")
)
} else {
PyTypeError::new_err(
format!("Invalid annotation for '{argname}'. {cls_for_repr} is not a class.")
)
};
error.set_cause(py, Some(e));
Err(error)
},
}
},
Err(_) => Err(PyTypeError::new_err(
format!("Invalid first argument to `register()`: {cls}. Use either `@register(some_class)` or plain `@register` on an annotated function."),
)),
}
}
}
}
}
Expand All @@ -295,20 +326,23 @@ struct PartialSingleDispatchRegistration {
cls: Py<PyAny>,
}

#[pymethods]
impl PartialSingleDispatchRegistration {
#[new]
fn __new__<'py>(singledispatch: Py<SingleDispatch>, cls: Bound<'py, PyAny>) -> Self {
fn new(singledispatch: Py<SingleDispatch>, cls: Bound<'_, PyAny>) -> Self {
PartialSingleDispatchRegistration {
singledispatch,
cls: cls.unbind(),
}
}
}

#[pymethods]
impl PartialSingleDispatchRegistration {
#[pyo3(signature = (func))]
fn __call__(&self, py: Python<'_>, func: Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
let singledispatch = self.singledispatch.borrow(py);
singledispatch.register_cls(py, self.cls.clone_ref(py).into_bound(py), func)
self.singledispatch
.borrow(py)
.register_cls(py, self.cls.clone_ref(py).into_bound(py), func)?
.into_py_any(py)
}
}

Expand Down
Loading
Loading