-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Feat/bubble plot #22403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/bubble plot #22403
Changes from 41 commits
339aa59
33177e0
8e87e24
984d494
9737ca1
69e6662
e46414f
0b6e975
c9cafa1
7112759
15ec9a3
26ecd7f
66e2bcf
0a8c38f
6332e91
24beedf
c80d2c7
a2a1551
67b811f
1a23a6e
44313c1
e054459
04c58fe
d2ff59a
35ede52
bf797d4
a196c22
4d7fa1c
f62085b
e8f461f
2a8d0ac
52dfd1b
0d6cc89
9ffc00c
488ad33
2ceef55
9cd04ac
7620fe8
cd1a636
8b90ceb
cacf942
fb35a6a
6bf9699
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -24,6 +24,8 @@ | |||||
is_integer, | ||||||
is_number, | ||||||
is_hashable, | ||||||
is_numeric_dtype, | ||||||
is_categorical_dtype, | ||||||
is_iterator) | ||||||
from pandas.core.dtypes.generic import ( | ||||||
ABCSeries, ABCDataFrame, ABCPeriodIndex, ABCMultiIndex, ABCIndexClass) | ||||||
|
@@ -861,11 +863,22 @@ def _plot_colorbar(self, ax, **kwds): | |||||
class ScatterPlot(PlanePlot): | ||||||
_kind = 'scatter' | ||||||
|
||||||
def __init__(self, data, x, y, s=None, c=None, **kwargs): | ||||||
def __init__(self, data, x, y, s=None, c=None, size_factor=1, **kwargs): | ||||||
if s is None: | ||||||
# hide the matplotlib default for size, in case we want to change | ||||||
# the handling of this argument later | ||||||
s = 20 | ||||||
# Set default size if no argument is given. | ||||||
s = 20 * size_factor | ||||||
elif is_hashable(s) and s in data.columns: | ||||||
# Handle the case where s is a label of a column of the df. | ||||||
# The data is normalized to 200 * size_factor. | ||||||
self.size_title = s | ||||||
n_bubble_points = 200 | ||||||
size_data = data[s] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens |
||||||
s = self._get_plot_bubbles(size_data, n_bubble_points, size_factor) | ||||||
self.bubble_legend_sizes, self.bubble_legend_labels = ( | ||||||
self._get_legend_bubbles(size_data, | ||||||
n_bubble_points, | ||||||
size_factor) | ||||||
) | ||||||
super(ScatterPlot, self).__init__(data, x, y, s=s, **kwargs) | ||||||
if is_integer(c) and not self.data.columns.holds_integer(): | ||||||
c = self.data.columns[c] | ||||||
|
@@ -874,7 +887,6 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs): | |||||
def _make_plot(self): | ||||||
x, y, c, data = self.x, self.y, self.c, self.data | ||||||
ax = self.axes[0] | ||||||
|
||||||
c_is_column = is_hashable(c) and c in self.data.columns | ||||||
|
||||||
# plot a colorbar only if a colormap is provided or necessary | ||||||
|
@@ -919,6 +931,108 @@ def _make_plot(self): | |||||
ax.errorbar(data[x].values, data[y].values, | ||||||
linestyle='none', **err_kwds) | ||||||
|
||||||
@staticmethod | ||||||
def _get_plot_bubbles(size_data, n_bubble_points=200, size_factor=1): | ||||||
if is_categorical_dtype(size_data): | ||||||
if size_data.cat.ordered: | ||||||
size_data_codes = size_data.cat.codes + 1 | ||||||
s_data_max = size_data_codes.max() | ||||||
s = (n_bubble_points * size_factor | ||||||
* size_data_codes**2 / s_data_max**2) | ||||||
else: | ||||||
raise TypeError( | ||||||
"'s' must be numeric or ordered categorical dtype") | ||||||
elif is_numeric_dtype(size_data): | ||||||
s_data_max = size_data.max() | ||||||
s = n_bubble_points * size_factor * size_data / s_data_max | ||||||
else: | ||||||
raise TypeError("'s' must be numeric or ordered categorical dtype") | ||||||
return s | ||||||
|
||||||
@classmethod | ||||||
def _sci_notation(cls, num): | ||||||
""" | ||||||
Returns mantissa and exponent of the number passed in argument. | ||||||
Example: | ||||||
>>> _sci_notation(89278.8924) | ||||||
(8.9, 4.0) | ||||||
""" | ||||||
scientific_notation = '{:e}'.format(num) | ||||||
regexp = re.compile(r'^([+-]?\d\.\d).*e([+-]\d*)$') | ||||||
mantis, expnt = regexp.search(scientific_notation).groups() | ||||||
return float(mantis), float(expnt) | ||||||
|
||||||
@staticmethod | ||||||
def _get_legend_bubbles(size_data, n_bubble_points=200, size_factor=1): | ||||||
""" | ||||||
Computes and returns appropriate bubble sizes and labels for the | ||||||
legend of a bubble plot. | ||||||
|
||||||
If bubble size represents numerical data, creates 4 bubbles with | ||||||
round values for the labels, the largest of which is close to the | ||||||
maximum of the data. | ||||||
|
||||||
If bubble size represents ordered categorical data, creates one bubble | ||||||
per category in the data. Sizes are determined by category codes. | ||||||
""" | ||||||
if is_categorical_dtype(size_data): | ||||||
if size_data.cat.ordered: | ||||||
size_data_codes = size_data.cat.codes + 1 | ||||||
labels = list(size_data.cat.categories)[::-1] | ||||||
n_categories = len(labels) | ||||||
sizes = ((np.array(range(n_categories)) + 1)**2 | ||||||
* n_bubble_points * size_factor | ||||||
/ size_data_codes.max()**2) | ||||||
sizes = sizes[::-1] | ||||||
else: | ||||||
raise TypeError( | ||||||
"'s' must be numeric or ordered categorical dtype") | ||||||
elif is_numeric_dtype(size_data): | ||||||
s_data_max = size_data.max() | ||||||
coef, expnt = ScatterPlot._sci_notation(s_data_max) | ||||||
labels_catalog = { | ||||||
(9, 10): [10, 5, 2.5, 1], | ||||||
(7, 9): [8, 4, 2, 0.5], | ||||||
(5.5, 7): [6, 3, 1.5, 0.5], | ||||||
(4.5, 5.5): [5, 2, 1, 0.2], | ||||||
(3.5, 4.5): [4, 2, 1, 0.2], | ||||||
(2.5, 3.5): [3, 1, 0.5, 0.2], | ||||||
(1.5, 2.5): [2, 1, 0.5, 0.2], | ||||||
(0, 1.5): [1, 0.5, 0.25, 0.1] | ||||||
} | ||||||
for lower_bound, upper_bound in labels_catalog: | ||||||
if (coef >= lower_bound) and (coef < upper_bound): | ||||||
labels = 10**expnt * np.array(labels_catalog[lower_bound, | ||||||
upper_bound]) | ||||||
sizes = list(n_bubble_points * size_factor | ||||||
* labels / s_data_max) | ||||||
labels = ['{:g}'.format(l) for l in labels] | ||||||
|
||||||
else: | ||||||
raise TypeError("'s' must be numeric or ordered categorical dtype") | ||||||
return (sizes, labels) | ||||||
|
||||||
def _make_legend(self): | ||||||
if hasattr(self, "size_title"): | ||||||
ax = self.axes[0] | ||||||
import matplotlib.legend as legend | ||||||
from matplotlib.collections import CircleCollection | ||||||
sizes, labels = self.bubble_legend_sizes, self.bubble_legend_labels | ||||||
color = self.plt.rcParams['axes.facecolor'], | ||||||
edgecolor = self.plt.rcParams['axes.edgecolor'] | ||||||
bubbles = [] | ||||||
for size in sizes: | ||||||
bubbles.append(CircleCollection(sizes=[size], | ||||||
color=color, | ||||||
edgecolor=edgecolor)) | ||||||
bubble_legend = legend.Legend(ax, | ||||||
handles=bubbles, | ||||||
labels=labels, | ||||||
loc='best') | ||||||
bubble_legend.set_title(self.size_title) | ||||||
ax.add_artist(bubble_legend) | ||||||
super(ScatterPlot, self)._make_legend() | ||||||
|
||||||
|
||||||
class HexBinPlot(PlanePlot): | ||||||
_kind = 'hexbin' | ||||||
|
@@ -3458,7 +3572,7 @@ def pie(self, y=None, **kwds): | |||||
""" | ||||||
return self(kind='pie', y=y, **kwds) | ||||||
|
||||||
def scatter(self, x, y, s=None, c=None, **kwds): | ||||||
def scatter(self, x, y, s=None, c=None, size_factor=1, **kwds): | ||||||
""" | ||||||
Create a scatter plot with varying marker point size and color. | ||||||
|
||||||
|
@@ -3477,7 +3591,7 @@ def scatter(self, x, y, s=None, c=None, **kwds): | |||||
y : int or str | ||||||
The column name or column position to be used as vertical | ||||||
coordinates for each point. | ||||||
s : scalar or array_like, optional | ||||||
s : int, str, scalar or array_like, optional | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
The size of each point. Possible values are: | ||||||
|
||||||
- A single scalar so all points have the same size. | ||||||
|
@@ -3486,6 +3600,12 @@ def scatter(self, x, y, s=None, c=None, **kwds): | |||||
recursively. For instance, when passing [2,14] all points size | ||||||
will be either 2 or 14, alternatively. | ||||||
|
||||||
- .. versionadded:: 0.24.0 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure how sphinx will handle this... I would say just do
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||||||
s can now be the name of a column containing numeric or | ||||||
ordered categorical data that will be represented by the size | ||||||
of each point. This turns the scatter plot into a bubble plot. | ||||||
|
||||||
|
||||||
c : str, int or array_like, optional | ||||||
The color of each point. Possible values are: | ||||||
|
||||||
|
@@ -3500,6 +3620,12 @@ def scatter(self, x, y, s=None, c=None, **kwds): | |||||
- A column name or position whose values will be used to color the | ||||||
marker points according to a colormap. | ||||||
|
||||||
size_factor : scalar, optional | ||||||
A multiplication factor to change the size of bubbles | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this only apply when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It applies to all cases, not only when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
.. versionadded:: 0.24.0 | ||||||
|
||||||
|
||||||
**kwds | ||||||
Keyword arguments to pass on to :meth:`pandas.DataFrame.plot`. | ||||||
|
||||||
|
@@ -3537,7 +3663,8 @@ def scatter(self, x, y, s=None, c=None, **kwds): | |||||
... c='species', | ||||||
... colormap='viridis') | ||||||
""" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add an example here? I'd say just repeat the previous one with |
||||||
return self(kind='scatter', x=x, y=y, c=c, s=s, **kwds) | ||||||
return self(kind='scatter', x=x, y=y, c=c, s=s, | ||||||
size_factor=size_factor, **kwds) | ||||||
|
||||||
def hexbin(self, x, y, C=None, reduce_C_function=None, gridsize=None, | ||||||
**kwds): | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1237,6 +1237,41 @@ def test_scatter_colors(self): | |
tm.assert_numpy_array_equal(ax.collections[0].get_facecolor()[0], | ||
np.array([1, 1, 1, 1], dtype=np.float64)) | ||
|
||
@pytest.mark.slow | ||
def test_plot_scatter_with_s(self): | ||
data = np.array([[3.1, 4.2, 1.9], | ||
[1.9, 2.8, 3.1], | ||
[5.4, 4.32, 2.0], | ||
[0.4, 3.4, 0.46], | ||
[4.4, 4.9, 0.8], | ||
[2.7, 6.2, 1.49]]) | ||
df = DataFrame(data, | ||
columns=['x', 'y', 'z']) | ||
ax = df.plot.scatter(x='x', y='y', s='z', size_factor=4) | ||
bubbles = ax.collections[0] | ||
bubble_sizes = bubbles.get_sizes() | ||
max_data = df['z'].max() | ||
expected_sizes = 200 * 4 * df['z'].values / max_data | ||
tm.assert_numpy_array_equal(bubble_sizes, expected_sizes) | ||
|
||
@pytest.mark.slow | ||
def test_plot_scatter_with_categorical_s(self): | ||
data = np.array([[3.1, 4.2], | ||
[1.9, 2.8], | ||
[5.4, 4.32], | ||
[0.4, 3.4], | ||
[4.4, 4.9], | ||
[2.7, 6.2]]) | ||
df = DataFrame(data, columns=['x', 'y']) | ||
df['z'] = pd.Categorical(['a', 'b', 'c', 'a', 'b', 'c'], ordered=True) | ||
ax = df.plot.scatter(x='x', y='y', s='z', size_factor=4) | ||
bubbles = ax.collections[0] | ||
bubble_sizes = bubbles.get_sizes() | ||
max_data = df['z'].cat.codes.max() + 1.0 | ||
expected_sizes = (200.0 * 4 * (df['z'].cat.codes.values + 1)**2 | ||
/ max_data**2) | ||
tm.assert_numpy_array_equal(bubble_sizes, expected_sizes) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test that directly calls There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall I replace this test by a test that calls _get_plot_bubbles, or add such a test and keep both? |
||
@pytest.mark.slow | ||
def test_plot_bar(self): | ||
df = DataFrame(randn(6, 4), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you separate the bubble type and size logic from the getting / setting on self. I'd like for testing that the bubble sizes are computed correctly to be easier.
So a standalone function like
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Tom for the review,
I separated the bubble logic into two functions, one for the bubbles in the plot, one for the bubbles in the legend. For the bubbles in the plot, the sizes are passed to the parent class as an argument
s
. For the legend, I stored the sizes and labels as attributes of the scatter plot which are used by the legend building functions. Let me know if that's what you had in mind.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mainly, I'm interested in ease of testing. I'd like to have a dedicated method that, given an array of values, tells me what the marker size should be for each point.