Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ python:
install:
- pip install coverage
- pip install --upgrade pytest pytest-benchmark
- pip install pytypes

script:
- |
Expand Down
6 changes: 4 additions & 2 deletions multipledispatch/conflict.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from .utils import _toposort, groupby
from pytypes import is_subtype


class AmbiguityWarning(Warning):
pass


def supercedes(a, b):
""" A is consistent and strictly more specific than B """
return len(a) == len(b) and all(map(issubclass, a, b))
return len(a) == len(b) and all(map(is_subtype, a, b))


def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
return (len(a) == len(b) and
all(issubclass(aa, bb) or issubclass(bb, aa)
all(is_subtype(aa, bb) or is_subtype(bb, aa)
for aa, bb in zip(a, b)))


Expand Down
20 changes: 13 additions & 7 deletions multipledispatch/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import inspect
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
import itertools as itl

import itertools as itl
import pytypes
import typing

class MDNotImplementedError(NotImplementedError):
""" A NotImplementedError for multiple dispatch """
Expand Down Expand Up @@ -111,7 +113,7 @@ def get_func_params(cls, func):

@classmethod
def get_func_annotations(cls, func):
""" get annotations of function positional paremeters
""" get annotations of function positional parameters
"""
params = cls.get_func_params(func)
if params:
Expand All @@ -135,13 +137,17 @@ def add(self, signature, func, on_ambiguity=ambiguity_warn):
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D.add((typing.Optional[str], ), lambda x: x)

>>> D(1, 2)
3
>>> D(1, 2.0)
>>> D('1', 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
NotImplementedError: Could not find signature for add: <str, float>
>>> D('s')
's'
>>> D(None)

When ``add`` detects a warning it calls the ``on_ambiguity`` callback
with a dispatcher/itself, and a set of ambiguous type signature pairs
Expand All @@ -154,7 +160,7 @@ def add(self, signature, func, on_ambiguity=ambiguity_warn):
signature = annotations

# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
if any(isinstance(typ, tuple) or pytypes.is_Union(typ) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func, on_ambiguity)
return
Expand Down Expand Up @@ -182,7 +188,7 @@ def reorder(self, on_ambiguity=ambiguity_warn):
_unresolved_dispatchers.add(self)

def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
types = tuple([pytypes.deep_type(arg) for arg in args])
try:
func = self._cache[types]
except KeyError:
Expand Down Expand Up @@ -244,7 +250,7 @@ def dispatch(self, *types):
def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if len(signature) == n and all(map(issubclass, types, signature)):
if len(signature) == n and all(map(pytypes.is_subtype, types, signature)):
result = self.funcs[signature]
yield result

Expand Down
16 changes: 16 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from multipledispatch import dispatch
from multipledispatch.dispatcher import Dispatcher
import typing


def test_function_annotation_register():
Expand All @@ -30,8 +31,23 @@ def inc(x: int):
def inc(x: float):
return x - 1

@dispatch()
def inc(x: typing.Optional[str]):
return x

@dispatch()
def inc(x: typing.List[int]):
return x[0] * 4

@dispatch()
def inc(x: typing.List[str]):
return x[0] + 'b'

assert inc(1) == 2
assert inc(1.0) == 0.0
assert inc('a') == 'a'
assert inc([8]) == 32
assert inc(['a']) == 'ab'


def test_function_annotation_dispatch_custom_namespace():
Expand Down
22 changes: 17 additions & 5 deletions multipledispatch/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@

import pytypes
import typing


def raises(err, lamda):
try:
lamda()
Expand All @@ -14,15 +19,22 @@ def expand_tuples(L):

>>> expand_tuples([1, 2])
[(1, 2)]

>>> expand_tuples([1, typing.Optional[str]]) #doctest: +ELLIPSIS
[(1, <... 'str'>), (1, <... 'NoneType'>)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
if pytypes.is_Union(L[0]):
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in pytypes.get_Union_params(L[0])]
elif not pytypes.is_of_type(L[0], tuple):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

somehow this doesn't actually get hit when L[0] is not a tuple:

(Pdb) p L[0]
<type 'numpy.dtype'>
(Pdb) p pytypes.is_of_type(L[0], tuple)
True

This breaks importing datashape.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The root cause of this is actually more concerning than this one bug. Why doesn't pytypes.is_of_type work correctly for type objects?

rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]


# Taken from theano/theano/gof/sched.py
Expand Down