Skip to content

Commit 15257ea

Browse files
MarcoGorellijreback
authored andcommitted
ENH: accept a dictionary in plot colors (#31071)
1 parent 35df212 commit 15257ea

File tree

5 files changed

+136
-73
lines changed

5 files changed

+136
-73
lines changed

doc/source/whatsnew/v1.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ I/O
171171
Plotting
172172
^^^^^^^^
173173

174-
-
174+
- :func:`.plot` for line/bar now accepts color by dictonary (:issue:`8193`).
175175
-
176176

177177
Groupby/resample/rolling

pandas/plotting/_core.py

+103-70
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,45 @@ def hist_frame(
385385
"""
386386

387387

388+
_bar_or_line_doc = """
389+
Parameters
390+
----------
391+
x : label or position, optional
392+
Allows plotting of one column versus another. If not specified,
393+
the index of the DataFrame is used.
394+
y : label or position, optional
395+
Allows plotting of one column versus another. If not specified,
396+
all numerical columns are used.
397+
color : str, array_like, or dict, optional
398+
The color for each of the DataFrame's columns. Possible values are:
399+
400+
- A single color string referred to by name, RGB or RGBA code,
401+
for instance 'red' or '#a98d19'.
402+
403+
- A sequence of color strings referred to by name, RGB or RGBA
404+
code, which will be used for each column recursively. For
405+
instance ['green','yellow'] each column's %(kind)s will be filled in
406+
green or yellow, alternatively.
407+
408+
- A dict of the form {column name : color}, so that each column will be
409+
colored accordingly. For example, if your columns are called `a` and
410+
`b`, then passing {'a': 'green', 'b': 'red'} will color %(kind)ss for
411+
column `a` in green and %(kind)ss for column `b` in red.
412+
413+
.. versionadded:: 1.1.0
414+
415+
**kwargs
416+
Additional keyword arguments are documented in
417+
:meth:`DataFrame.plot`.
418+
419+
Returns
420+
-------
421+
matplotlib.axes.Axes or np.ndarray of them
422+
An ndarray is returned with one :class:`matplotlib.axes.Axes`
423+
per column when ``subplots=True``.
424+
"""
425+
426+
388427
@Substitution(backend="")
389428
@Appender(_boxplot_doc)
390429
def boxplot(
@@ -848,31 +887,8 @@ def __call__(self, *args, **kwargs):
848887

849888
__call__.__doc__ = __doc__
850889

851-
def line(self, x=None, y=None, **kwargs):
890+
@Appender(
852891
"""
853-
Plot Series or DataFrame as lines.
854-
855-
This function is useful to plot lines using DataFrame's values
856-
as coordinates.
857-
858-
Parameters
859-
----------
860-
x : int or str, optional
861-
Columns to use for the horizontal axis.
862-
Either the location or the label of the columns to be used.
863-
By default, it will use the DataFrame indices.
864-
y : int, str, or list of them, optional
865-
The values to be plotted.
866-
Either the location or the label of the columns to be used.
867-
By default, it will use the remaining DataFrame numeric columns.
868-
**kwargs
869-
Keyword arguments to pass on to :meth:`DataFrame.plot`.
870-
871-
Returns
872-
-------
873-
:class:`matplotlib.axes.Axes` or :class:`numpy.ndarray`
874-
Return an ndarray when ``subplots=True``.
875-
876892
See Also
877893
--------
878894
matplotlib.pyplot.plot : Plot y versus x as lines and/or markers.
@@ -907,6 +923,16 @@ def line(self, x=None, y=None, **kwargs):
907923
>>> type(axes)
908924
<class 'numpy.ndarray'>
909925
926+
.. plot::
927+
:context: close-figs
928+
929+
Let's repeat the same example, but specifying colors for
930+
each column (in this case, for each animal).
931+
932+
>>> axes = df.plot.line(
933+
... subplots=True, color={"pig": "pink", "horse": "#742802"}
934+
... )
935+
910936
.. plot::
911937
:context: close-figs
912938
@@ -915,36 +941,20 @@ def line(self, x=None, y=None, **kwargs):
915941
916942
>>> lines = df.plot.line(x='pig', y='horse')
917943
"""
918-
return self(kind="line", x=x, y=y, **kwargs)
919-
920-
def bar(self, x=None, y=None, **kwargs):
944+
)
945+
@Substitution(kind="line")
946+
@Appender(_bar_or_line_doc)
947+
def line(self, x=None, y=None, **kwargs):
921948
"""
922-
Vertical bar plot.
923-
924-
A bar plot is a plot that presents categorical data with
925-
rectangular bars with lengths proportional to the values that they
926-
represent. A bar plot shows comparisons among discrete categories. One
927-
axis of the plot shows the specific categories being compared, and the
928-
other axis represents a measured value.
929-
930-
Parameters
931-
----------
932-
x : label or position, optional
933-
Allows plotting of one column versus another. If not specified,
934-
the index of the DataFrame is used.
935-
y : label or position, optional
936-
Allows plotting of one column versus another. If not specified,
937-
all numerical columns are used.
938-
**kwargs
939-
Additional keyword arguments are documented in
940-
:meth:`DataFrame.plot`.
949+
Plot Series or DataFrame as lines.
941950
942-
Returns
943-
-------
944-
matplotlib.axes.Axes or np.ndarray of them
945-
An ndarray is returned with one :class:`matplotlib.axes.Axes`
946-
per column when ``subplots=True``.
951+
This function is useful to plot lines using DataFrame's values
952+
as coordinates.
953+
"""
954+
return self(kind="line", x=x, y=y, **kwargs)
947955

956+
@Appender(
957+
"""
948958
See Also
949959
--------
950960
DataFrame.plot.barh : Horizontal bar plot.
@@ -986,6 +996,17 @@ def bar(self, x=None, y=None, **kwargs):
986996
>>> axes = df.plot.bar(rot=0, subplots=True)
987997
>>> axes[1].legend(loc=2) # doctest: +SKIP
988998
999+
If you don't like the default colours, you can specify how you'd
1000+
like each column to be colored.
1001+
1002+
.. plot::
1003+
:context: close-figs
1004+
1005+
>>> axes = df.plot.bar(
1006+
... rot=0, subplots=True, color={"speed": "red", "lifespan": "green"}
1007+
... )
1008+
>>> axes[1].legend(loc=2) # doctest: +SKIP
1009+
9891010
Plot a single column.
9901011
9911012
.. plot::
@@ -999,32 +1020,24 @@ def bar(self, x=None, y=None, **kwargs):
9991020
:context: close-figs
10001021
10011022
>>> ax = df.plot.bar(x='lifespan', rot=0)
1023+
"""
1024+
)
1025+
@Substitution(kind="bar")
1026+
@Appender(_bar_or_line_doc)
1027+
def bar(self, x=None, y=None, **kwargs):
10021028
"""
1003-
return self(kind="bar", x=x, y=y, **kwargs)
1004-
1005-
def barh(self, x=None, y=None, **kwargs):
1006-
"""
1007-
Make a horizontal bar plot.
1029+
Vertical bar plot.
10081030
1009-
A horizontal bar plot is a plot that presents quantitative data with
1031+
A bar plot is a plot that presents categorical data with
10101032
rectangular bars with lengths proportional to the values that they
10111033
represent. A bar plot shows comparisons among discrete categories. One
10121034
axis of the plot shows the specific categories being compared, and the
10131035
other axis represents a measured value.
1036+
"""
1037+
return self(kind="bar", x=x, y=y, **kwargs)
10141038

1015-
Parameters
1016-
----------
1017-
x : label or position, default DataFrame.index
1018-
Column to be used for categories.
1019-
y : label or position, default All numeric columns in dataframe
1020-
Columns to be plotted from the DataFrame.
1021-
**kwargs
1022-
Keyword arguments to pass on to :meth:`DataFrame.plot`.
1023-
1024-
Returns
1025-
-------
1026-
:class:`matplotlib.axes.Axes` or numpy.ndarray of them
1027-
1039+
@Appender(
1040+
"""
10281041
See Also
10291042
--------
10301043
DataFrame.plot.bar: Vertical bar plot.
@@ -1054,6 +1067,13 @@ def barh(self, x=None, y=None, **kwargs):
10541067
... 'lifespan': lifespan}, index=index)
10551068
>>> ax = df.plot.barh()
10561069
1070+
We can specify colors for each column
1071+
1072+
.. plot::
1073+
:context: close-figs
1074+
1075+
>>> ax = df.plot.barh(color={"speed": "red", "lifespan": "green"})
1076+
10571077
Plot a column of the DataFrame to a horizontal bar plot
10581078
10591079
.. plot::
@@ -1079,6 +1099,19 @@ def barh(self, x=None, y=None, **kwargs):
10791099
>>> df = pd.DataFrame({'speed': speed,
10801100
... 'lifespan': lifespan}, index=index)
10811101
>>> ax = df.plot.barh(x='lifespan')
1102+
"""
1103+
)
1104+
@Substitution(kind="bar")
1105+
@Appender(_bar_or_line_doc)
1106+
def barh(self, x=None, y=None, **kwargs):
1107+
"""
1108+
Make a horizontal bar plot.
1109+
1110+
A horizontal bar plot is a plot that presents quantitative data with
1111+
rectangular bars with lengths proportional to the values that they
1112+
represent. A bar plot shows comparisons among discrete categories. One
1113+
axis of the plot shows the specific categories being compared, and the
1114+
other axis represents a measured value.
10821115
"""
10831116
return self(kind="barh", x=x, y=y, **kwargs)
10841117

pandas/plotting/_matplotlib/core.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,10 @@ def _apply_style_colors(self, colors, kwds, col_num, label):
726726
has_color = "color" in kwds or self.colormap is not None
727727
nocolor_style = style is None or re.match("[a-z]+", style) is None
728728
if (has_color or self.subplots) and nocolor_style:
729-
kwds["color"] = colors[col_num % len(colors)]
729+
if isinstance(colors, dict):
730+
kwds["color"] = colors[label]
731+
else:
732+
kwds["color"] = colors[col_num % len(colors)]
730733
return style, kwds
731734

732735
def _get_colors(self, num_colors=None, color_kwds="color"):
@@ -1347,6 +1350,8 @@ def _make_plot(self):
13471350
kwds = self.kwds.copy()
13481351
if self._is_series:
13491352
kwds["color"] = colors
1353+
elif isinstance(colors, dict):
1354+
kwds["color"] = colors[label]
13501355
else:
13511356
kwds["color"] = colors[i % ncolors]
13521357

pandas/plotting/_matplotlib/style.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ def _get_standard_colors(
2727
warnings.warn(
2828
"'color' and 'colormap' cannot be used simultaneously. Using 'color'"
2929
)
30-
colors = list(color) if is_list_like(color) else color
30+
colors = (
31+
list(color)
32+
if is_list_like(color) and not isinstance(color, dict)
33+
else color
34+
)
3135
else:
3236
if color_type == "default":
3337
# need to call list() on the result to copy so we don't

pandas/tests/plotting/test_misc.py

+21
Original file line numberDiff line numberDiff line change
@@ -406,3 +406,24 @@ def test_get_standard_colors_no_appending(self):
406406
color_list = cm.gnuplot(np.linspace(0, 1, 16))
407407
p = df.A.plot.bar(figsize=(16, 7), color=color_list)
408408
assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()
409+
410+
@pytest.mark.slow
411+
def test_dictionary_color(self):
412+
# issue-8193
413+
# Test plot color dictionary format
414+
data_files = ["a", "b"]
415+
416+
expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)]
417+
418+
df1 = DataFrame(np.random.rand(2, 2), columns=data_files)
419+
dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)}
420+
421+
# Bar color test
422+
ax = df1.plot(kind="bar", color=dic_color)
423+
colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]]
424+
assert all(color == expected[index] for index, color in enumerate(colors))
425+
426+
# Line color test
427+
ax = df1.plot(kind="line", color=dic_color)
428+
colors = [rect.get_color() for rect in ax.get_lines()[0:2]]
429+
assert all(color == expected[index] for index, color in enumerate(colors))

0 commit comments

Comments
 (0)