diff --git a/pandas-stubs/io/formats/style.pyi b/pandas-stubs/io/formats/style.pyi index b61553603..648f5a64b 100644 --- a/pandas-stubs/io/formats/style.pyi +++ b/pandas-stubs/io/formats/style.pyi @@ -57,6 +57,11 @@ class _SeriesFunc(Protocol): self, series: Series, /, *args: Any, **kwargs: Any ) -> list[Any] | Series: ... +class _SeriesStrFunc(Protocol): + def __call__( + self, series: Series[str], /, *args: Any, **kwargs: Any + ) -> list[str] | Series[str]: ... + class _DataFrameFunc(Protocol): def __call__( self, series: DataFrame, /, *args: Any, **kwargs: Any @@ -277,14 +282,17 @@ class Styler(StylerRenderer): ) -> Styler: ... def apply_index( self, - func: Callable[[Series], list[str] | np_ndarray_str | Series[str]], + func: ( + _SeriesStrFunc + | Callable[[Series], list[str] | np_ndarray_str | Series[str]] + ), axis: Axis = ..., level: Level | list[Level] | None = ..., **kwargs: Any, ) -> Styler: ... def map_index( self, - func: Callable[[Scalar], str | None], + func: _MapCallable | Callable[[Scalar], str | None], axis: Axis = ..., level: Level | list[Level] | None = ..., **kwargs: Any, diff --git a/tests/test_styler.py b/tests/test_styler.py index ba41c5180..e0b0e22c1 100644 --- a/tests/test_styler.py +++ b/tests/test_styler.py @@ -76,6 +76,14 @@ def f1(s: Series) -> Series[str]: check(assert_type(DF.style.apply_index(f1), Styler), Styler) + # GH 1723 + def highlight_odd(index: pd.Series, /, color: str) -> list[str]: + return [f"color: {color}" if x % 2 else "" for x in index] + + check( + assert_type(DF.style.apply_index(highlight_odd, color="purple"), Styler), Styler + ) + def test_map_index() -> None: def f(s: Scalar) -> str | None: @@ -83,6 +91,11 @@ def f(s: Scalar) -> str | None: check(assert_type(DF.style.map_index(f), Styler), Styler) + def f1(s: Scalar, /, color: str) -> str | None: + return f"background-color: {color};" if s == "b" else None + + check(assert_type(DF.style.map_index(f1, color="pink"), Styler), Styler) + def test_background_gradient() -> None: check(assert_type(DF.style.background_gradient(), Styler), Styler)