Skip to content

seaborn: fix and complete seaborn.regression #11043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions stubs/seaborn/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
seaborn._core.scales.(Pipeline|TransFuncs) # aliases defined in `if TYPE_CHECKING` block
seaborn.external.docscrape.ClassDoc.__init__ # stubtest doesn't like ABC class as default value
seaborn.external.docscrape.NumpyDocString.__str__ # weird signature

seaborn(\.regression)?\.lmplot # the `data` argument is required but it defaults to `None` at runtime
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a shame we have to allowlist the whole function and can't list only the offending parameter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100 changes: 78 additions & 22 deletions stubs/seaborn/seaborn/regression.pyi
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from _typeshed import Incomplete
from collections.abc import Iterable
from typing import Any
from typing_extensions import Literal
from collections.abc import Callable, Iterable
from typing import Any, overload
from typing_extensions import Literal, TypeAlias

import pandas as pd
from matplotlib.axes import Axes
from matplotlib.typing import ColorType
from numpy.typing import NDArray

from .axisgrid import FacetGrid
from .utils import _Palette, _Seed

__all__ = ["lmplot", "regplot", "residplot"]

_Vector: TypeAlias = list[Incomplete] | pd.Series[Incomplete] | pd.Index[Incomplete] | NDArray[Incomplete]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_Vector is defined according to this function. All parameters annotated with _Vector are handled by this function


def lmplot(
data: Incomplete | None = None,
data: pd.DataFrame,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*,
x: str | None = None,
y: str | None = None,
Expand All @@ -25,15 +28,15 @@ def lmplot(
height: float = 5,
aspect: float = 1,
markers: str = "o",
sharex: bool | Literal["col", "row"] | None = None,
sharey: bool | Literal["col", "row"] | None = None,
sharex: bool | Literal["col", "row"] | None = None, # deprecated
sharey: bool | Literal["col", "row"] | None = None, # deprecated
Comment on lines +31 to +32
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://github.com/mwaskom/seaborn/blob/d4b8de81f74a153bb3c15af760d956d32aad980b/seaborn/regression.py#L601-L603 for the runtime deprecation of these parameters. I'll handle seaborn's deprecations using typing_extensions.deprecated in another PR.

hue_order: Iterable[str] | None = None,
col_order: Iterable[str] | None = None,
row_order: Iterable[str] | None = None,
legend: bool = True,
legend_out: Incomplete | None = None,
x_estimator: Incomplete | None = None,
x_bins: Incomplete | None = None,
legend_out: bool | None = None, # deprecated
x_estimator: Callable[[Incomplete], Incomplete] | None = None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure how to type the argument and return type of this callable. It is documented as "callable that maps vector -> scalar", something like numpy.mean, but I think it would require higher-kinded TypeVars and a lot of overloads to get it right (unless I am missing something trivial).

x_bins: int | _Vector | None = None,
x_ci: Literal["ci", "sd"] | int | None = "ci",
scatter: bool = True,
fit_reg: bool = True,
Expand All @@ -55,27 +58,61 @@ def lmplot(
line_kws: dict[str, Any] | None = None,
facet_kws: dict[str, Any] | None = None,
) -> FacetGrid: ...
@overload
def regplot(
data: pd.DataFrame | None = None,
data: None = None,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
x_estimator: Incomplete | None = None,
x_bins: Incomplete | None = None,
x: _Vector | None = None,
y: _Vector | None = None,
x_estimator: Callable[[Incomplete], Incomplete] | None = None,
x_bins: int | _Vector | None = None,
x_ci: Literal["ci", "sd"] | int | None = "ci",
scatter: bool = True,
fit_reg: bool = True,
ci: int | None = 95,
n_boot: int = 1000,
units: str | None = None,
units: _Vector | None = None,
seed: _Seed | None = None,
order: int = 1,
logistic: bool = False,
lowess: bool = False,
robust: bool = False,
logx: bool = False,
x_partial: str | None = None,
y_partial: str | None = None,
x_partial: _Vector | None = None,
y_partial: _Vector | None = None,
truncate: bool = True,
dropna: bool = True,
x_jitter: float | None = None,
y_jitter: float | None = None,
label: str | None = None,
color: ColorType | None = None,
marker: str = "o",
scatter_kws: dict[str, Any] | None = None,
line_kws: dict[str, Any] | None = None,
ax: Axes | None = None,
) -> Axes: ...
@overload
def regplot(
data: pd.DataFrame,
*,
x: str | _Vector | None = None,
y: str | _Vector | None = None,
Comment on lines +98 to +99
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x, y, and the like can only be str if data is provided as a DataFrame; otherwise they must be vectors: https://seaborn.pydata.org/generated/seaborn.regplot.html

x_estimator: Callable[[Incomplete], Incomplete] | None = None,
x_bins: int | _Vector | None = None,
x_ci: Literal["ci", "sd"] | int | None = "ci",
scatter: bool = True,
fit_reg: bool = True,
ci: int | None = 95,
n_boot: int = 1000,
units: str | _Vector | None = None,
seed: _Seed | None = None,
order: int = 1,
logistic: bool = False,
lowess: bool = False,
robust: bool = False,
logx: bool = False,
x_partial: str | _Vector | None = None,
y_partial: str | _Vector | None = None,
truncate: bool = True,
dropna: bool = True,
x_jitter: float | None = None,
Expand All @@ -87,13 +124,32 @@ def regplot(
line_kws: dict[str, Any] | None = None,
ax: Axes | None = None,
) -> Axes: ...
@overload
def residplot(
data: None = None,
*,
x: _Vector | None = None,
y: _Vector | None = None,
x_partial: _Vector | None = None,
y_partial: _Vector | None = None,
lowess: bool = False,
order: int = 1,
robust: bool = False,
dropna: bool = True,
label: str | None = None,
color: ColorType | None = None,
scatter_kws: dict[str, Any] | None = None,
line_kws: dict[str, Any] | None = None,
ax: Axes | None = None,
) -> Axes: ...
@overload
def residplot(
data: Incomplete | None = None,
data: pd.DataFrame,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
x_partial: Incomplete | None = None,
y_partial: Incomplete | None = None,
x: str | _Vector | None = None,
y: str | _Vector | None = None,
x_partial: str | _Vector | None = None,
y_partial: str | _Vector | None = None,
lowess: bool = False,
order: int = 1,
robust: bool = False,
Expand Down