Skip to content

Commit 6882ed3

Browse files
Merge pull request #187 from maxfischer2781/typefix/neutralcall
Fix async neutral type call type hints
2 parents 846a206 + 6c1255f commit 6882ed3

8 files changed

Lines changed: 91 additions & 18 deletions

File tree

asyncstdlib/_typing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,19 @@
5555
#: Hashable Key
5656
HK = TypeVar("HK", bound=Hashable)
5757

58+
59+
# bool(...)
60+
class SupportsBool(Protocol):
61+
def __bool__(self) -> bool:
62+
raise NotImplementedError
63+
64+
5865
# LT < LT
5966
LT = TypeVar("LT", bound="SupportsLT")
6067

6168

6269
class SupportsLT(Protocol):
63-
def __lt__(self: LT, other: LT) -> bool:
70+
def __lt__(self, __other: Any) -> SupportsBool:
6471
raise NotImplementedError
6572

6673

@@ -69,7 +76,7 @@ def __lt__(self: LT, other: LT) -> bool:
6976

7077

7178
class SupportsAdd(Protocol):
72-
def __add__(self: ADD, other: ADD, /) -> ADD:
79+
def __add__(self, __other: Any, /) -> Any:
7380
raise NotImplementedError
7481

7582

asyncstdlib/builtins.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ async def anext(
5555

5656

5757
def iter(
58-
subject: Union[AnyIterable[T], Callable[[], Awaitable[T]]],
58+
subject: Union[AnyIterable[T], Callable[[], Awaitable[T]], Callable[[], T]],
5959
sentinel: Union[Sentinel, T] = __ITER_DEFAULT,
6060
) -> AsyncIterator[T]:
6161
"""
@@ -84,13 +84,12 @@ def iter(
8484
raise TypeError("iter(v, w): v must be callable")
8585
else:
8686
assert not isinstance(sentinel, Sentinel)
87-
return acallable_iterator(subject, sentinel)
87+
return acallable_iterator(_awaitify(subject), sentinel)
8888

8989

9090
async def acallable_iterator(
9191
subject: Callable[[], Awaitable[T]], sentinel: T
9292
) -> AsyncIterator[T]:
93-
subject = _awaitify(subject)
9493
value = await subject()
9594
while value != sentinel:
9695
yield value
@@ -306,7 +305,7 @@ async def _min_max(
306305
raise ValueError(f"{name}() arg is an empty sequence")
307306
elif key is None:
308307
async for item in item_iter:
309-
if invert ^ (item < best):
308+
if invert ^ bool(item < best):
310309
best = item
311310
else:
312311
key = _awaitify(key)

asyncstdlib/builtins.pyi

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from typing import Any, AsyncIterator, Awaitable, Callable, overload
22
from typing_extensions import TypeGuard
33
import builtins
44

5-
from ._typing import ADD, AnyIterable, HK, LT, R, T, T1, T2, T3, T4, T5
5+
from ._typing import ADD, AnyIterable, HK, LT, R, T, T1, T2, T3, T4, T5, SupportsLT
66

77
@overload
88
async def anext(iterator: AsyncIterator[T]) -> T: ...
@@ -16,6 +16,10 @@ def iter(
1616
) -> AsyncIterator[T]: ...
1717
@overload
1818
def iter(subject: Callable[[], Awaitable[T]], sentinel: T) -> AsyncIterator[T]: ...
19+
@overload
20+
def iter(subject: Callable[[], T | None], sentinel: None) -> AsyncIterator[T]: ...
21+
@overload
22+
def iter(subject: Callable[[], T], sentinel: T) -> AsyncIterator[T]: ...
1923
async def all(iterable: AnyIterable[Any]) -> bool: ...
2024
async def any(iterable: AnyIterable[Any]) -> bool: ...
2125
@overload
@@ -180,20 +184,42 @@ async def max(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...
180184
@overload
181185
async def max(iterable: AnyIterable[LT], *, key: None = ..., default: T) -> LT | T: ...
182186
@overload
183-
async def max(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...
187+
async def max(
188+
iterable: AnyIterable[T1], *, key: Callable[[T1], Awaitable[SupportsLT]]
189+
) -> T1: ...
190+
@overload
191+
async def max(
192+
iterable: AnyIterable[T1],
193+
*,
194+
key: Callable[[T1], Awaitable[SupportsLT]],
195+
default: T2,
196+
) -> T1 | T2: ...
197+
@overload
198+
async def max(iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT]) -> T1: ...
184199
@overload
185200
async def max(
186-
iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
201+
iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT], default: T2
187202
) -> T1 | T2: ...
188203
@overload
189204
async def min(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...
190205
@overload
191206
async def min(iterable: AnyIterable[LT], *, key: None = ..., default: T) -> LT | T: ...
192207
@overload
193-
async def min(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...
208+
async def min(
209+
iterable: AnyIterable[T1], *, key: Callable[[T1], Awaitable[SupportsLT]]
210+
) -> T1: ...
194211
@overload
195212
async def min(
196-
iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
213+
iterable: AnyIterable[T1],
214+
*,
215+
key: Callable[[T1], Awaitable[SupportsLT]],
216+
default: T2,
217+
) -> T1 | T2: ...
218+
@overload
219+
async def min(iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT]) -> T1: ...
220+
@overload
221+
async def min(
222+
iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT], default: T2
197223
) -> T1 | T2: ...
198224
@overload
199225
def filter(
@@ -247,5 +273,12 @@ async def sorted(
247273
) -> builtins.list[LT]: ...
248274
@overload
249275
async def sorted(
250-
iterable: AnyIterable[T], *, key: Callable[[T], LT], reverse: bool = ...
276+
iterable: AnyIterable[T],
277+
*,
278+
key: Callable[[T], Awaitable[SupportsLT]],
279+
reverse: bool = ...,
280+
) -> builtins.list[T]: ...
281+
@overload
282+
async def sorted(
283+
iterable: AnyIterable[T], *, key: Callable[[T], SupportsLT], reverse: bool = ...
251284
) -> builtins.list[T]: ...

asyncstdlib/functools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ async def data(self):
257257
if iscoroutinefunction(type_or_getter):
258258
return CachedProperty(type_or_getter)
259259
elif isinstance(type_or_getter, type) and issubclass(
260-
type_or_getter, AsyncContextManager
260+
type_or_getter, AsyncContextManager # pyright: ignore[reportGeneralTypeIssues]
261261
):
262262

263263
def decorator(

asyncstdlib/functools.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ def cached_property(
3333
asynccontextmanager_type: type[AsyncContextManager[Any]], /
3434
) -> Callable[[Callable[[T], Awaitable[R]]], CachedProperty[T, R]]: ...
3535
@overload
36+
async def reduce(
37+
function: Callable[[T1, T2], Awaitable[T1]], iterable: AnyIterable[T2], initial: T1
38+
) -> T1: ...
39+
@overload
40+
async def reduce(
41+
function: Callable[[T, T], Awaitable[T]], iterable: AnyIterable[T]
42+
) -> T: ...
43+
@overload
3644
async def reduce(
3745
function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1
3846
) -> T1: ...

asyncstdlib/heapq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ async def pull_head(self) -> bool:
9292
return True
9393

9494
def __lt__(self, other: _KeyIter[LT]) -> bool:
95-
return self.reverse ^ (self.head_key < other.head_key)
95+
return self.reverse ^ bool(self.head_key < other.head_key)
9696

9797
def __eq__(self, other: _KeyIter[LT]) -> bool: # type: ignore[override]
9898
return not (self.head_key < other.head_key or other.head_key < self.head_key)
@@ -161,7 +161,7 @@ def __init__(self, key: LT):
161161
self.key = key
162162

163163
def __lt__(self, other: ReverseLT[LT]) -> bool:
164-
return other.key < self.key
164+
return bool(other.key < self.key)
165165

166166

167167
# Python's heapq provides a *min*-heap

typetests/test_builtins.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import TypeVar
2+
from asyncstdlib import builtins
3+
4+
T = TypeVar("T")
5+
6+
7+
def identity(v: T) -> T:
8+
return v
9+
10+
11+
async def async_identity(v: T) -> T:
12+
return v
13+
14+
15+
async def test_min_asyncneutral() -> None:
16+
await builtins.min([1, 2, 3], key=identity)
17+
await builtins.min([1, 2, 3], key=async_identity)

typetests/test_functools.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from asyncstdlib import lru_cache
1+
from asyncstdlib import functools
22

33

4-
@lru_cache()
4+
@functools.lru_cache()
55
async def lru_function(a: int) -> int:
66
return a
77

@@ -16,7 +16,7 @@ class TestLRUMethod:
1616
Test that `lru_cache` works on methods
1717
"""
1818

19-
@lru_cache()
19+
@functools.lru_cache()
2020
async def cached(self, a: int = 0) -> int:
2121
return a
2222

@@ -26,3 +26,12 @@ async def test_implicit_self(self) -> int:
2626
async def test_method_parameters(self) -> int:
2727
await self.cached("wrong parameter type") # type: ignore[arg-type]
2828
return await self.cached(12)
29+
30+
31+
async def aadd(a: int, b: int) -> int:
32+
return a + b
33+
34+
35+
async def test_reduce() -> None:
36+
await functools.reduce(aadd, [1, 2, 3, 4])
37+
await functools.reduce(aadd, [1, 2, 3, 4], initial=1)

0 commit comments

Comments
 (0)