Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/_numtype/_dtype.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,33 @@ __all__ = [
_T = TypeVar("_T")
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, covariant=True, default=Any)
_DTypeT = TypeVar("_DTypeT", bound=np.dtype)
_DTypeT_co = TypeVar("_DTypeT_co", bound=np.dtype, covariant=True, default=np.dtype)

@type_check_only
class _HasDType(Protocol[_DTypeT_co]):
class _HasDTypeOld(Protocol[_DTypeT_co]):
@property
def dtype(self) -> _DTypeT_co: ...

@type_check_only
class _HasDTypeOf(Protocol[_ScalarT_co]):
class _HasDTypeNew(Protocol[_DTypeT_co]):
@property
def __numpy_dtype__(self) -> _DTypeT_co: ...

_HasDType = TypeAliasType("_HasDType", _HasDTypeNew[_DTypeT] | _HasDTypeOld[_DTypeT], type_params=(_DTypeT,))

@type_check_only
class _HasDTypeOldOf(Protocol[_ScalarT_co]):
@property
def dtype(self) -> np.dtype[_ScalarT_co]: ...

@type_check_only
class _HasDTypeNewOf(Protocol[_ScalarT_co]):
@property
def __numpy_dtype__(self) -> np.dtype[_ScalarT_co]: ...

_HasDTypeOf = TypeAliasType("_HasDTypeOf", _HasDTypeNewOf[_ScalarT] | _HasDTypeOldOf[_ScalarT], type_params=(_ScalarT,))

_ToDType = TypeAliasType(
"_ToDType", type[_ScalarT] | np.dtype[_ScalarT] | _HasDTypeOf[_ScalarT], type_params=(_ScalarT,)
)
Expand Down
18 changes: 13 additions & 5 deletions src/numpy-stubs/_typing/_dtype_like.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from _typeshed import Incomplete
from collections.abc import Sequence
from typing import Any, Protocol, Required, TypeAlias, TypedDict, runtime_checkable
from typing import Any, Protocol, Required, TypeAlias, TypedDict, type_check_only
from typing_extensions import TypeVar

import numpy as np
Expand All @@ -21,11 +21,13 @@ from ._char_codes import (
)

_ScalarT = TypeVar("_ScalarT", bound=np.generic)
_DTypeT = TypeVar("_DTypeT", bound=np.dtype)
_DTypeT_co = TypeVar("_DTypeT_co", covariant=True, bound=np.dtype)

# TODO(jorenham): Actually annotate this
_DTypeLikeNested: TypeAlias = Incomplete

@type_check_only
class _DTypeDict(TypedDict, total=False):
names: Required[Sequence[str]]
formats: Required[Sequence[_DTypeLikeNested]]
Expand All @@ -36,11 +38,17 @@ class _DTypeDict(TypedDict, total=False):
# but `titles` can in principle accept any object
titles: Sequence[Any]

# A protocol for anything with the dtype attribute
@runtime_checkable
class _SupportsDType(Protocol[_DTypeT_co]):
@type_check_only
class _HasDTypeLegacy(Protocol[_DTypeT_co]):
@property
def dtype(self) -> _DTypeT_co: ...
def dtype(self, /) -> _DTypeT_co: ...

@type_check_only
class _HasDType(Protocol[_DTypeT_co]):
@property
def __numpy_dtype__(self, /) -> _DTypeT_co: ...

_SupportsDType: TypeAlias = _HasDType[_DTypeT] | _HasDTypeLegacy[_DTypeT]

# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
_DTypeLike: TypeAlias = type[_ScalarT] | np.dtype[_ScalarT] | _SupportsDType[np.dtype[_ScalarT]]
Expand Down