Skip to content
This repository was archived by the owner on Jun 3, 2024. It is now read-only.

Support implicit dataframe argument. #87

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions plotly_express/_chart_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def scatter(
data_frame,
data_frame=None,
x=None,
y=None,
color=None,
Expand Down Expand Up @@ -56,7 +56,7 @@ def scatter(


def density_contour(
data_frame,
data_frame=None,
x=None,
y=None,
z=None,
Expand Down Expand Up @@ -166,7 +166,7 @@ def density_heatmap(


def line(
data_frame,
data_frame=None,
x=None,
y=None,
line_group=None,
Expand Down Expand Up @@ -210,7 +210,7 @@ def line(


def area(
data_frame,
data_frame=None,
x=None,
y=None,
line_group=None,
Expand Down Expand Up @@ -254,7 +254,7 @@ def area(


def bar(
data_frame,
data_frame=None,
x=None,
y=None,
color=None,
Expand Down Expand Up @@ -303,7 +303,7 @@ def bar(


def histogram(
data_frame,
data_frame=None,
x=None,
y=None,
color=None,
Expand Down Expand Up @@ -360,7 +360,7 @@ def histogram(


def violin(
data_frame,
data_frame=None,
x=None,
y=None,
color=None,
Expand Down Expand Up @@ -410,7 +410,7 @@ def violin(


def box(
data_frame,
data_frame=None,
x=None,
y=None,
color=None,
Expand Down Expand Up @@ -504,7 +504,7 @@ def strip(


def scatter_3d(
data_frame,
data_frame=None,
x=None,
y=None,
z=None,
Expand Down Expand Up @@ -554,7 +554,7 @@ def scatter_3d(


def line_3d(
data_frame,
data_frame=None,
x=None,
y=None,
z=None,
Expand Down Expand Up @@ -599,7 +599,7 @@ def line_3d(


def scatter_ternary(
data_frame,
data_frame=None,
a=None,
b=None,
c=None,
Expand Down Expand Up @@ -637,7 +637,7 @@ def scatter_ternary(


def line_ternary(
data_frame,
data_frame=None,
a=None,
b=None,
c=None,
Expand Down Expand Up @@ -671,7 +671,7 @@ def line_ternary(


def scatter_polar(
data_frame,
data_frame=None,
r=None,
theta=None,
color=None,
Expand Down Expand Up @@ -714,7 +714,7 @@ def scatter_polar(


def line_polar(
data_frame,
data_frame=None,
r=None,
theta=None,
color=None,
Expand Down Expand Up @@ -753,7 +753,7 @@ def line_polar(


def bar_polar(
data_frame,
data_frame=None,
r=None,
theta=None,
color=None,
Expand Down Expand Up @@ -790,7 +790,7 @@ def bar_polar(


def choropleth(
data_frame,
data_frame=None,
lat=None,
lon=None,
locations=None,
Expand Down Expand Up @@ -829,7 +829,7 @@ def choropleth(


def scatter_geo(
data_frame,
data_frame=None,
lat=None,
lon=None,
locations=None,
Expand Down Expand Up @@ -872,7 +872,7 @@ def scatter_geo(


def line_geo(
data_frame,
data_frame=None,
lat=None,
lon=None,
locations=None,
Expand Down Expand Up @@ -913,7 +913,7 @@ def line_geo(


def scatter_mapbox(
data_frame,
data_frame=None,
lat=None,
lon=None,
color=None,
Expand Down Expand Up @@ -948,7 +948,7 @@ def scatter_mapbox(


def line_mapbox(
data_frame,
data_frame=None,
lat=None,
lon=None,
color=None,
Expand Down Expand Up @@ -978,7 +978,7 @@ def line_mapbox(


def scatter_matrix(
data_frame,
data_frame=None,
dimensions=None,
color=None,
symbol=None,
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def scatter_matrix(


def parallel_coordinates(
data_frame,
data_frame=None,
dimensions=None,
color=None,
labels={},
Expand All @@ -1039,7 +1039,7 @@ def parallel_coordinates(


def parallel_categories(
data_frame,
data_frame=None,
dimensions=None,
color=None,
labels={},
Expand Down
32 changes: 27 additions & 5 deletions plotly_express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,28 @@ def apply_default_cascade(args):
if args["color_discrete_sequence"] is None:
args["color_discrete_sequence"] = qualitative.Plotly

def has_value(collection, key):
return collection.get(key, None) is not None

def build_dataframe(args, attrables):
"""
Constructs an implicit dataframe and modifies `args` in-place.

`attrables` is a list of keys into `args`, all of whose corresponding
values are converted into columns of a dataframe.

Used to be support calls to plotting function that elide a dataframe argument;
for example `scatter(x=[1,2], y=[3,4])`.
"""
data_frame_columns = {}
for field in attrables:
if not has_value(args, field):
continue
data_frame_columns[field] = args[field]
# This sets the label of an attribute to be the name of the attribute.
args[field] = field
args["data_frame"] = pandas.DataFrame(data_frame_columns)
return args

def infer_config(args, constructor, trace_patch):
attrables = (
Expand All @@ -669,11 +691,12 @@ def infer_config(args, constructor, trace_patch):
)
array_attrables = ["dimensions", "hover_data"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]

all_attrables = attrables + group_attrables + ["color"]
if not has_value(args, "data_frame"):
build_dataframe(args, all_attrables)
df_columns = args["data_frame"].columns

for attr in attrables + group_attrables + ["color"]:
if attr in args and args[attr] is not None:
for attr in all_attrables:
if has_value(args, attr):
maybe_col_list = [args[attr]] if attr not in array_attrables else args[attr]
for maybe_col in maybe_col_list:
try:
Expand Down Expand Up @@ -790,7 +813,6 @@ def get_orderings(args, grouper, grouped):

return orders, group_names


def make_figure(args, constructor, trace_patch={}, layout_patch={}):
apply_default_cascade(args)
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
Expand Down