diff --git a/src/figure.js b/src/figure.js index f14884e7b6..dbc85e6334 100644 --- a/src/figure.js +++ b/src/figure.js @@ -1,12 +1,10 @@ -// Wrap the plot in a figure with a caption, if desired. -export function figureWrap(svg, {width}, caption) { - if (caption == null) return svg; +// Wrap the plot in a figure with a caption and legends, if desired. +export function figure(decorations, {width}) { + decorations = decorations.filter(d => d instanceof Node); + if (decorations.length === 1) return decorations[0]; const figure = document.createElement("figure"); figure.style = `max-width: ${width}px`; - figure.appendChild(svg); - const figcaption = document.createElement("figcaption"); - figcaption.appendChild(caption instanceof Node ? caption : document.createTextNode(caption)); - figure.appendChild(figcaption); + for (const d of decorations) figure.appendChild(d); return figure; } diff --git a/src/legends.js b/src/legends.js index 43994734a1..ab1394b7a9 100644 --- a/src/legends.js +++ b/src/legends.js @@ -1,5 +1,31 @@ -import {legendColor} from "./legends/color.js"; +import {Scale} from "./scales.js"; +import {legendRamp} from "./legends/ramp.js"; +import {legendSwatches} from "./legends/swatches.js"; export function legend({color, ...options}) { - if (color) return legendColor({...color, ...options}); + const type = color ? "color" : "unknown"; + const scale = Scale(type, undefined, color); + options = {...color, ...options}; + let legend = options.legend; + if (type === "color") { + if (legend === undefined || legend === true) { + legend = ["ordinal", "categorical"].includes(scale.type) ? "swatches" : "ramp"; + } + if (legend === "swatches") { + legend = legendSwatches; + } else if (legend === "ramp") { + legend = legendRamp; + } + } + + if (typeof legend !== "function") { + throw new Error(`unknown legend type ${legend}`); + } + + // todo: remove scale.scale, add scale.apply and scale.invert? + return legend({...scale, ...options}); +} + +export function exposeLegends(scale) { + return (type, options = {}) => legend({[type]: scale(type), ...options}); } diff --git a/src/legends/color.js b/src/legends/color.js deleted file mode 100644 index 0fbe274cfc..0000000000 --- a/src/legends/color.js +++ /dev/null @@ -1,16 +0,0 @@ -import {Scale} from "../scales.js"; -import {legendRamp} from "./ramp.js"; -import {legendSwatches} from "./swatches.js"; - -export function legendColor({legend, ...options}) { - const scale = Scale("color", undefined, options); - if (legend === undefined) legend = scale.type === "ordinal" || scale.type === "categorical" ? "swatches" : "ramp"; - switch (legend) { - case "swatches": - return legendSwatches({...scale, ...options}); - case "ramp": - return legendRamp({...scale, ...options}); - default: - throw new Error(`unknown legend type ${legend}`); - } -} diff --git a/src/plot.js b/src/plot.js index 9e9805c6c4..6391ec3031 100644 --- a/src/plot.js +++ b/src/plot.js @@ -1,8 +1,8 @@ import {create} from "d3"; import {Axes, autoAxisTicks, autoScaleLabels} from "./axes.js"; import {facets} from "./facet.js"; -import {figureWrap} from "./figure.js"; -import { legend } from "./legends.js"; +import {figure} from "./figure.js"; +import {exposeLegends} from "./legends.js"; import {markify} from "./mark.js"; import {Scales, autoScaleRange, applyScales, exposeScales, isOrdinalScale} from "./scales.js"; import {filterStyles, maybeClassName, offset} from "./style.js"; @@ -109,11 +109,19 @@ export function plot(options = {}) { if (node != null) svg.appendChild(node); } - // Wrap the plot in a figure with a caption, if desired. - const figure = figureWrap(svg, dimensions, caption); - figure.scale = exposeScales(scaleDescriptors); - figure.legend = (type, options = {}) => legend({[type]: figure.scale(type), ...options}); - return figure; + // Wrap the plot in a figure with a caption and legends, if desired. + const decorations = [svg]; + const scale = exposeScales(scaleDescriptors); + const legend = exposeLegends(scale); + if (options.color?.extra?.legend) { + decorations.unshift(legend("color", options.color.extra)); + } + if (caption != null) { + const figcaption = document.createElement("figcaption"); + figcaption.appendChild(caption instanceof Node ? caption : document.createTextNode(caption)); + decorations.push(figcaption); + } + return Object.assign(figure(decorations, dimensions), {scale, legend}); } function Dimensions( diff --git a/src/scales/diverging.js b/src/scales/diverging.js index 081868f394..8b454c5f66 100644 --- a/src/scales/diverging.js +++ b/src/scales/diverging.js @@ -23,7 +23,8 @@ function ScaleD(key, scale, transform, channels, { range, symmetric = true, interpolate = registry.get(key) === color ? (scheme == null && range !== undefined ? interpolateRgb : quantitativeScheme(scheme !== undefined ? scheme : "rdbu")) : interpolateNumber, - reverse + reverse, + ...extra }) { pivot = +pivot; let [min, max] = domain; @@ -61,7 +62,7 @@ function ScaleD(key, scale, transform, channels, { scale.domain([min, pivot, max]).unknown(unknown).interpolator(interpolate); if (clamp) scale.clamp(clamp); if (nice) scale.nice(nice); - return {type, interpolate, scale}; + return {type, interpolate, scale, extra}; } export function ScaleDiverging(key, channels, options) { diff --git a/src/scales/ordinal.js b/src/scales/ordinal.js index fcd7ced28a..71648ee619 100644 --- a/src/scales/ordinal.js +++ b/src/scales/ordinal.js @@ -8,7 +8,8 @@ export function ScaleO(scale, channels, { type, domain = inferDomain(channels), range, - reverse + reverse, + ...extra }) { if (type === "categorical") type = "ordinal"; // shorthand for color schemes if (reverse) domain = reverseof(domain); @@ -18,7 +19,7 @@ export function ScaleO(scale, channels, { if (typeof range === "function") range = range(domain); scale.range(range); } - return {type, domain, range, scale}; + return {type, domain, range, scale, extra}; } export function ScaleOrdinal(key, channels, { diff --git a/src/scales/quantitative.js b/src/scales/quantitative.js index de831962af..dfabeb3f39 100644 --- a/src/scales/quantitative.js +++ b/src/scales/quantitative.js @@ -58,7 +58,8 @@ export function ScaleQ(key, scale, channels, { scheme, range = registry.get(key) === radius ? inferRadialRange(channels, domain) : registry.get(key) === opacity ? unit : undefined, interpolate = registry.get(key) === color ? (scheme == null && range !== undefined ? interpolateRgb : quantitativeScheme(scheme !== undefined ? scheme : type === "cyclical" ? "rainbow" : "turbo")) : round ? interpolateRound : interpolateNumber, - reverse + reverse, + ...extra }) { if (type === "cyclical" || type === "sequential") type = "linear"; // shorthand for color schemes reverse = !!reverse; @@ -104,7 +105,7 @@ export function ScaleQ(key, scale, channels, { if (nice) scale.nice(nice === true ? undefined : nice); if (range !== undefined) scale.range(range); if (clamp) scale.clamp(clamp); - return {type, domain, range, scale, interpolate}; + return {type, domain, range, scale, interpolate, extra}; } export function ScaleLinear(key, channels, options) {