Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions asyncstdlib/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,19 @@
#: Hashable Key
HK = TypeVar("HK", bound=Hashable)


# bool(...)
class SupportsBool(Protocol):
def __bool__(self) -> bool:
raise NotImplementedError


# LT < LT
LT = TypeVar("LT", bound="SupportsLT")


class SupportsLT(Protocol):

Check warning

Code scanning / CodeQL

Incomplete ordering Warning

This class implements
__lt__
, but does not implement __le__ or __ge__.
def __lt__(self: LT, other: LT) -> bool:
def __lt__(self, __other: Any) -> SupportsBool:
raise NotImplementedError


Expand All @@ -69,7 +76,7 @@


class SupportsAdd(Protocol):
def __add__(self: ADD, other: ADD, /) -> ADD:
def __add__(self, __other: Any, /) -> Any:
raise NotImplementedError


Expand Down
7 changes: 3 additions & 4 deletions asyncstdlib/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def anext(


def iter(
subject: Union[AnyIterable[T], Callable[[], Awaitable[T]]],
subject: Union[AnyIterable[T], Callable[[], Awaitable[T]], Callable[[], T]],
sentinel: Union[Sentinel, T] = __ITER_DEFAULT,
) -> AsyncIterator[T]:
"""
Expand Down Expand Up @@ -84,13 +84,12 @@ def iter(
raise TypeError("iter(v, w): v must be callable")
else:
assert not isinstance(sentinel, Sentinel)
return acallable_iterator(subject, sentinel)
return acallable_iterator(_awaitify(subject), sentinel)


async def acallable_iterator(
subject: Callable[[], Awaitable[T]], sentinel: T
) -> AsyncIterator[T]:
subject = _awaitify(subject)
value = await subject()
while value != sentinel:
yield value
Expand Down Expand Up @@ -306,7 +305,7 @@ async def _min_max(
raise ValueError(f"{name}() arg is an empty sequence")
elif key is None:
async for item in item_iter:
if invert ^ (item < best):
if invert ^ bool(item < best):
best = item
else:
key = _awaitify(key)
Expand Down
45 changes: 39 additions & 6 deletions asyncstdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from typing import Any, AsyncIterator, Awaitable, Callable, overload
from typing_extensions import TypeGuard
import builtins

from ._typing import ADD, AnyIterable, HK, LT, R, T, T1, T2, T3, T4, T5
from ._typing import ADD, AnyIterable, HK, LT, R, T, T1, T2, T3, T4, T5, SupportsLT

@overload
async def anext(iterator: AsyncIterator[T]) -> T: ...
Expand All @@ -16,6 +16,10 @@ def iter(
) -> AsyncIterator[T]: ...
@overload
def iter(subject: Callable[[], Awaitable[T]], sentinel: T) -> AsyncIterator[T]: ...
@overload
def iter(subject: Callable[[], T | None], sentinel: None) -> AsyncIterator[T]: ...
@overload
def iter(subject: Callable[[], T], sentinel: T) -> AsyncIterator[T]: ...
async def all(iterable: AnyIterable[Any]) -> bool: ...
async def any(iterable: AnyIterable[Any]) -> bool: ...
@overload
Expand Down Expand Up @@ -180,20 +184,42 @@ async def max(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...
@overload
async def max(iterable: AnyIterable[LT], *, key: None = ..., default: T) -> LT | T: ...
@overload
async def max(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...
async def max(
iterable: AnyIterable[T1], *, key: Callable[[T1], Awaitable[SupportsLT]]
) -> T1: ...
@overload
async def max(
iterable: AnyIterable[T1],
*,
key: Callable[[T1], Awaitable[SupportsLT]],
default: T2,
) -> T1 | T2: ...
@overload
async def max(iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT]) -> T1: ...
@overload
async def max(
iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT], default: T2
) -> T1 | T2: ...
@overload
async def min(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...
@overload
async def min(iterable: AnyIterable[LT], *, key: None = ..., default: T) -> LT | T: ...
@overload
async def min(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...
async def min(
iterable: AnyIterable[T1], *, key: Callable[[T1], Awaitable[SupportsLT]]
) -> T1: ...
@overload
async def min(
iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
iterable: AnyIterable[T1],
*,
key: Callable[[T1], Awaitable[SupportsLT]],
default: T2,
) -> T1 | T2: ...
@overload
async def min(iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT]) -> T1: ...
@overload
async def min(
iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT], default: T2
) -> T1 | T2: ...
@overload
def filter(
Expand Down Expand Up @@ -247,5 +273,12 @@ async def sorted(
) -> builtins.list[LT]: ...
@overload
async def sorted(
iterable: AnyIterable[T], *, key: Callable[[T], LT], reverse: bool = ...
iterable: AnyIterable[T],
*,
key: Callable[[T], Awaitable[SupportsLT]],
reverse: bool = ...,
) -> builtins.list[T]: ...
@overload
async def sorted(
iterable: AnyIterable[T], *, key: Callable[[T], SupportsLT], reverse: bool = ...
) -> builtins.list[T]: ...
2 changes: 1 addition & 1 deletion asyncstdlib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async def data(self):
if iscoroutinefunction(type_or_getter):
return CachedProperty(type_or_getter)
elif isinstance(type_or_getter, type) and issubclass(
type_or_getter, AsyncContextManager
type_or_getter, AsyncContextManager # pyright: ignore[reportGeneralTypeIssues]
):

def decorator(
Expand Down
8 changes: 8 additions & 0 deletions asyncstdlib/functools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def cached_property(
asynccontextmanager_type: type[AsyncContextManager[Any]], /
) -> Callable[[Callable[[T], Awaitable[R]]], CachedProperty[T, R]]: ...
@overload
async def reduce(
function: Callable[[T1, T2], Awaitable[T1]], iterable: AnyIterable[T2], initial: T1
) -> T1: ...
@overload
async def reduce(
function: Callable[[T, T], Awaitable[T]], iterable: AnyIterable[T]
) -> T: ...
@overload
async def reduce(
function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1
) -> T1: ...
Expand Down
4 changes: 2 additions & 2 deletions asyncstdlib/heapq.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def pull_head(self) -> bool:
return True

def __lt__(self, other: _KeyIter[LT]) -> bool:
return self.reverse ^ (self.head_key < other.head_key)
return self.reverse ^ bool(self.head_key < other.head_key)

def __eq__(self, other: _KeyIter[LT]) -> bool: # type: ignore[override]
return not (self.head_key < other.head_key or other.head_key < self.head_key)
Expand Down Expand Up @@ -161,7 +161,7 @@ def __init__(self, key: LT):
self.key = key

def __lt__(self, other: ReverseLT[LT]) -> bool:
return other.key < self.key
return bool(other.key < self.key)


# Python's heapq provides a *min*-heap
Expand Down
17 changes: 17 additions & 0 deletions typetests/test_builtins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import TypeVar
from asyncstdlib import builtins

T = TypeVar("T")


def identity(v: T) -> T:
return v


async def async_identity(v: T) -> T:
return v


async def test_min_asyncneutral() -> None:
await builtins.min([1, 2, 3], key=identity)
await builtins.min([1, 2, 3], key=async_identity)
15 changes: 12 additions & 3 deletions typetests/test_functools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from asyncstdlib import lru_cache
from asyncstdlib import functools


@lru_cache()
@functools.lru_cache()
async def lru_function(a: int) -> int:
return a

Expand All @@ -16,7 +16,7 @@ class TestLRUMethod:
Test that `lru_cache` works on methods
"""

@lru_cache()
@functools.lru_cache()
async def cached(self, a: int = 0) -> int:
return a

Expand All @@ -26,3 +26,12 @@ async def test_implicit_self(self) -> int:
async def test_method_parameters(self) -> int:
await self.cached("wrong parameter type") # type: ignore[arg-type]
return await self.cached(12)


async def aadd(a: int, b: int) -> int:
return a + b


async def test_reduce() -> None:
await functools.reduce(aadd, [1, 2, 3, 4])
await functools.reduce(aadd, [1, 2, 3, 4], initial=1)