Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pandas-stubs/_libs/interval.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ VALID_CLOSED: frozenset[str]

_OrderableScalarT = TypeVar("_OrderableScalarT", bound=int | float)
_OrderableTimesT = TypeVar("_OrderableTimesT", bound=Timestamp | Timedelta)
_OrderableT = TypeVar("_OrderableT", bound=int | float | Timestamp | Timedelta)
_OrderableT = TypeVar(
"_OrderableT", bound=int | float | Timestamp | Timedelta, default=Any
)

@type_check_only
class _LengthDescriptor:
Expand Down
64 changes: 38 additions & 26 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ from pandas.tseries.offsets import (
P = ParamSpec("P")

HashableT = TypeVar("HashableT", bound=Hashable)
HashableT0 = TypeVar("HashableT0", bound=Hashable, default=Any)
HashableT1 = TypeVar("HashableT1", bound=Hashable)
HashableT2 = TypeVar("HashableT2", bound=Hashable)
HashableT3 = TypeVar("HashableT3", bound=Hashable)
Expand Down Expand Up @@ -774,7 +775,7 @@ XMLParsers: TypeAlias = Literal["lxml", "etree"]
HTMLFlavors: TypeAlias = Literal["lxml", "html5lib", "bs4"]

# Interval closed type
IntervalT = TypeVar("IntervalT", bound=Interval)
IntervalT = TypeVar("IntervalT", bound=Interval, default=Interval)
IntervalLeftRight: TypeAlias = Literal["left", "right"]
IntervalClosedType: TypeAlias = IntervalLeftRight | Literal["both", "neither"]

Expand Down Expand Up @@ -872,7 +873,11 @@ ExcelWriterMergeCells: TypeAlias = bool | Literal["columns"]

# read_csv: usecols
UsecolsArgType: TypeAlias = (
SequenceNotStr[Hashable] | range | AnyArrayLike | Callable[[HashableT], bool] | None
SequenceNotStr[Hashable]
| range
| AnyArrayLike
| Callable[[HashableT0], bool]
| None
)

# maintain the sub-type of any hashable sequence
Expand Down Expand Up @@ -918,6 +923,7 @@ PyArrowNotStrDtypeArg: TypeAlias = (
StrLike: TypeAlias = str | np.str_

ScalarT = TypeVar("ScalarT", bound=Scalar)
ScalarT0 = TypeVar("ScalarT0", bound=Scalar, default=Scalar)
# Refine the definitions below in 3.9 to use the specialized type.
np_num: TypeAlias = np.bool | np.integer | np.floating | np.complexfloating
np_ndarray_intp: TypeAlias = npt.NDArray[np.intp]
Expand Down Expand Up @@ -959,7 +965,10 @@ np_1darray_dt: TypeAlias = np_1darray[np.datetime64]
np_1darray_td: TypeAlias = np_1darray[np.timedelta64]
np_2darray: TypeAlias = np.ndarray[tuple[int, int], np.dtype[GenericT]]

NDArrayT = TypeVar("NDArrayT", bound=np.ndarray)
if sys.version_info >= (3, 11):
NDArrayT = TypeVar("NDArrayT", bound=np.ndarray)
else:
NDArrayT = TypeVar("NDArrayT", bound=np.ndarray[Any, Any])

DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
KeysArgType: TypeAlias = Any
Expand All @@ -969,7 +978,7 @@ ListLikeExceptSeriesAndStr: TypeAlias = (
)
ListLikeU: TypeAlias = Sequence[Any] | np_1darray | Series | Index
ListLikeHashable: TypeAlias = (
MutableSequence[HashableT] | np_1darray | tuple[HashableT, ...] | range
MutableSequence[HashableT0] | np_1darray | tuple[HashableT0, ...] | range
)

class SupportsDType(Protocol[GenericT_co]):
Expand Down Expand Up @@ -1010,8 +1019,9 @@ SeriesDType: TypeAlias = (
| datetime.datetime # includes pd.Timestamp
| datetime.timedelta # includes pd.Timedelta
)
S0 = TypeVar("S0", bound=SeriesDType, default=Any)
S1 = TypeVar("S1", bound=SeriesDType, default=Any)
# Like S1, but without `default=Any`.
# Like S0 and S1, but without `default=Any`.
S2 = TypeVar("S2", bound=SeriesDType)
S2_contra = TypeVar("S2_contra", bound=SeriesDType, contravariant=True)
S2_NDT_contra = TypeVar(
Expand Down Expand Up @@ -1045,14 +1055,14 @@ IndexingInt: TypeAlias = (
)

# AxesData is used for data for Index
AxesData: TypeAlias = Mapping[S3, Any] | Axes | KeysView[S3]
AxesData: TypeAlias = Mapping[S0, Any] | Axes | KeysView[S0]

# Any plain Python or numpy function
Function: TypeAlias = np.ufunc | Callable[..., Any]
# Use a distinct HashableT in shared types to avoid conflicts with
# shared HashableT and HashableT#. This one can be used if the identical
# type is need in a function that uses GroupByObjectNonScalar
_HashableTa = TypeVar("_HashableTa", bound=Hashable)
_HashableTa = TypeVar("_HashableTa", bound=Hashable, default=Any)
if TYPE_CHECKING: # noqa: PYI002
ByT = TypeVar(
"ByT",
Expand All @@ -1070,7 +1080,7 @@ if TYPE_CHECKING: # noqa: PYI002
| Scalar
| Period
| Interval[int | float | Timestamp | Timedelta]
| tuple,
| tuple[Any, ...],
)
# Use a distinct SeriesByT when using groupby with Series of known dtype.
# Essentially, an intersection between Series S1 TypeVar, and ByT TypeVar
Expand All @@ -1088,21 +1098,23 @@ if TYPE_CHECKING: # noqa: PYI002
| Period
| Interval[int | float | Timestamp | Timedelta],
)
GroupByObjectNonScalar: TypeAlias = (
tuple[_HashableTa, ...]
| list[_HashableTa]
| Function
| list[Function]
| list[Series]
| np_ndarray
| list[np_ndarray]
| Mapping[Label, Any]
| list[Mapping[Label, Any]]
| list[Index]
| Grouper
| list[Grouper]
)
GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar | Series
GroupByObjectNonScalar: TypeAlias = (
tuple[_HashableTa, ...]
| list[_HashableTa]
| Function
| list[Function]
| list[Series]
| np_ndarray
| list[np_ndarray]
| Mapping[Label, Any]
| list[Mapping[Label, Any]]
| list[Index]
| Grouper
| list[Grouper]
)
GroupByObject: TypeAlias = (
Scalar | Index | GroupByObjectNonScalar[_HashableTa] | Series
)

StataDateFormat: TypeAlias = Literal[
"tc",
Expand All @@ -1125,10 +1137,10 @@ StataDateFormat: TypeAlias = Literal[
# `DataFrame.replace` also accepts mappings of these.
ReplaceValue: TypeAlias = (
Scalar
| Pattern
| Pattern[Any]
| NAType
| Sequence[Scalar | Pattern]
| Mapping[HashableT, ScalarT]
| Sequence[Scalar | Pattern[Any]]
| Mapping[HashableT0, ScalarT0]
| Series
| None
)
Expand Down
92 changes: 51 additions & 41 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ from pandas.core.indexing import (
_LocIndexer,
)
from pandas.core.reshape.pivot import (
_PivotAggFunc,
_PivotAggFuncTypes,
_PivotTableColumnsTypes,
_PivotTableIndexTypes,
_PivotTableValuesTypes,
Expand Down Expand Up @@ -178,7 +178,7 @@ from pandas.plotting import PlotAccessor
from pandas.plotting._core import _BoxPlotT

_T_MUTABLE_MAPPING_co = TypeVar(
"_T_MUTABLE_MAPPING_co", bound=MutableMapping, covariant=True
"_T_MUTABLE_MAPPING_co", bound=MutableMapping[Any, Any], covariant=True
)

class _iLocIndexerFrame(_iLocIndexer, Generic[_T]):
Expand Down Expand Up @@ -361,28 +361,23 @@ class _AtIndexerFrame(_AtIndexer):
),
) -> None: ...

# With mypy 1.14.1 and python 3.12, the second overload needs a type-ignore statement
if sys.version_info >= (3, 12):
class _GetItemHack:
@overload
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
class _GetItemHack:
@overload
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
# With mypy 1.14.1 and python 3.12, the second overload needs a type-ignore statement
if sys.version_info >= (3, 12):
@overload
def __getitem__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self, key: Iterable[Hashable] | slice
) -> Self: ...
@overload
def __getitem__(self, key: Hashable) -> Series: ...

else:
class _GetItemHack:
@overload
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
else:
@overload
def __getitem__( # pyright: ignore[reportOverlappingOverload]
self, key: Iterable[Hashable] | slice
) -> Self: ...
@overload
def __getitem__(self, key: Hashable) -> Series: ...

@overload
def __getitem__(self, key: Hashable) -> Series: ...

_AstypeArgExt: TypeAlias = (
AstypeArg
Expand Down Expand Up @@ -484,7 +479,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
self,
orient: str = ...,
*,
into: type[defaultdict],
into: type[defaultdict[Any, Any]],
index: Literal[True] = True,
) -> Never: ...
@overload
Expand All @@ -500,7 +495,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
self,
orient: Literal["records"],
*,
into: type[dict] = ...,
into: type[dict[Any, Any]] = ...,
index: Literal[True] = True,
) -> list[dict[Hashable, Any]]: ...
@overload
Expand All @@ -516,23 +511,23 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
self,
orient: Literal["index"],
*,
into: OrderedDict | type[OrderedDict],
into: OrderedDict[Any, Any] | type[OrderedDict[Any, Any]],
index: Literal[True] = True,
) -> OrderedDict[Hashable, dict[Hashable, Any]]: ...
@overload
def to_dict(
self,
orient: Literal["index"],
*,
into: type[MutableMapping],
into: type[MutableMapping[Any, Any]],
index: Literal[True] = True,
) -> MutableMapping[Hashable, dict[Hashable, Any]]: ...
@overload
def to_dict(
self,
orient: Literal["index"],
*,
into: type[dict] = ...,
into: type[dict[Any, Any]] = ...,
index: Literal[True] = True,
) -> dict[Hashable, dict[Hashable, Any]]: ...
@overload
Expand All @@ -548,23 +543,23 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
self,
orient: Literal["dict", "list", "series"] = ...,
*,
into: type[dict] = ...,
into: type[dict[Any, Any]] = ...,
index: Literal[True] = True,
) -> dict[Hashable, Any]: ...
@overload
def to_dict(
self,
orient: Literal["split", "tight"],
*,
into: MutableMapping[Any, Any] | type[MutableMapping],
into: MutableMapping[Any, Any] | type[MutableMapping[Any, Any]],
index: bool = ...,
) -> MutableMapping[str, list[Any]]: ...
@overload
def to_dict(
self,
orient: Literal["split", "tight"],
*,
into: type[dict] = ...,
into: type[dict[Any, Any]] = ...,
index: bool = ...,
) -> dict[str, list[Any]]: ...
@classmethod
Expand All @@ -583,16 +578,29 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
coerce_float: bool = False,
nrows: int | None = None,
) -> Self: ...
def to_records(
self,
index: _bool = True,
column_dtypes: (
_str | npt.DTypeLike | Mapping[HashableT1, npt.DTypeLike] | None
) = None,
index_dtypes: (
_str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None
) = None,
) -> np.recarray: ...
if sys.version_info >= (3, 11):
def to_records(
self,
index: _bool = True,
column_dtypes: (
_str | npt.DTypeLike | Mapping[HashableT1, npt.DTypeLike] | None
) = None,
index_dtypes: (
_str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None
) = None,
) -> np.recarray: ...
else:
def to_records(
self,
index: _bool = True,
column_dtypes: (
_str | npt.DTypeLike | Mapping[HashableT1, npt.DTypeLike] | None
) = None,
index_dtypes: (
_str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None
) = None,
) -> np.recarray[Any, Any]: ...

@overload
def to_stata(
self,
Expand Down Expand Up @@ -1381,7 +1389,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
dropna: _bool = ...,
) -> DataFrameGroupBy[Period, Literal[False]]: ...
@overload
def groupby( # pyright: ignore reportOverlappingOverload
def groupby(
self,
by: IntervalIndex[IntervalT],
level: IndexLabel | None = ...,
Expand All @@ -1394,7 +1402,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
@overload
def groupby(
self,
by: IntervalIndex[IntervalT],
by: IntervalIndex,
level: IndexLabel | None = ...,
as_index: Literal[False] = False,
sort: _bool = ...,
Expand Down Expand Up @@ -1480,9 +1488,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
values: _PivotTableValuesTypes = None,
index: _PivotTableIndexTypes = None,
columns: _PivotTableColumnsTypes = None,
aggfunc: (
_PivotAggFunc | Sequence[_PivotAggFunc] | Mapping[Hashable, _PivotAggFunc]
) = "mean",
aggfunc: _PivotAggFuncTypes[Scalar] = "mean",
fill_value: Scalar | None = None,
margins: _bool = False,
dropna: _bool = True,
Expand Down Expand Up @@ -2863,8 +2869,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def __rfloordiv__(
self, other: float | DataFrame | Series[int] | Series[float] | Sequence[float]
) -> Self: ...
def __truediv__(self, other: float | DataFrame | Series | Sequence) -> Self: ...
def __rtruediv__(self, other: float | DataFrame | Series | Sequence) -> Self: ...
def __truediv__(
self, other: float | DataFrame | Series | Sequence[Any]
) -> Self: ...
def __rtruediv__(
self, other: float | DataFrame | Series | Sequence[Any]
) -> Self: ...
@final
def __bool__(self) -> NoReturn: ...

Expand Down
8 changes: 4 additions & 4 deletions pandas-stubs/core/groupby/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
random_state: RandomState | None = ...,
) -> NDFrameT: ...

_GroupByT = TypeVar("_GroupByT", bound=GroupBy)
_GroupByT = TypeVar("_GroupByT", bound=GroupBy[Any])

# GroupByPlot does not really inherit from PlotAccessor but it delegates
# to it using __call__ and __getattr__. We lie here to avoid repeating the
Expand Down Expand Up @@ -383,15 +383,15 @@ class BaseGroupBy(SelectionMixin[NDFrameT], GroupByIndexingMixin):
@final
def __iter__(self) -> Iterator[tuple[Hashable, NDFrameT]]: ...
@overload
def __getitem__(self: BaseGroupBy[DataFrame], key: Scalar) -> generic.SeriesGroupBy: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
def __getitem__(self: BaseGroupBy[DataFrame], key: Scalar) -> generic.SeriesGroupBy[Any, Any]: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
@overload
def __getitem__(
self: BaseGroupBy[DataFrame], key: Iterable[Hashable]
) -> generic.DataFrameGroupBy: ...
) -> generic.DataFrameGroupBy[Any, Any]: ...
@overload
def __getitem__(
self: BaseGroupBy[Series[S1]],
idx: list[str] | Index | Series[S1] | MaskType | tuple[Hashable | slice, ...],
) -> generic.SeriesGroupBy: ...
) -> generic.SeriesGroupBy[Any, Any]: ...
@overload
def __getitem__(self: BaseGroupBy[Series[S1]], idx: Scalar) -> S1: ...
Loading