# Source code for arviz.plots.densityplot

"""KDE and histogram plots for multiple variables."""
import warnings

from ..data import convert_to_dataset
from ..labels import BaseLabeller
from ..sel_utils import (
    xarray_var_iter,
)
from ..rcparams import rcParams
from ..utils import _var_names
from .plot_utils import default_grid, get_plotting_function


# pylint:disable-msg=too-many-function-args
def plot_density(
    data,
    group="posterior",
    data_labels=None,
    var_names=None,
    filter_vars=None,
    combine_dims=None,
    transform=None,
    hdi_prob=None,
    point_estimate="auto",
    colors="cycle",
    outline=True,
    hdi_markers="",
    shade=0.0,
    bw="default",
    circular=False,
    grid=None,
    figsize=None,
    textsize=None,
    labeller=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    """Generate KDE plots for continuous variables and histograms for discrete ones.

    Plots are truncated at their ``100 * hdi_prob``% highest density intervals. Plots are
    grouped per variable and colors assigned to models.

    Parameters
    ----------
    data : Union[Object, List[Object], Tuple[Object, ...]]
        Any object that can be converted to an :class:`arviz.InferenceData` object, or a list
        or tuple of such objects (one per model).
        Refer to documentation of :func:`arviz.convert_to_dataset` for details about such
        objects.
    group : Optional[str]
        Specifies which :class:`arviz.InferenceData` group should be plotted.
        Defaults to 'posterior'.
        Alternative values include 'prior' and any other strings used as dataset keys in the
        :class:`arviz.InferenceData`.
    data_labels : Optional[List[str]]
        List with names for the datasets passed as ``data``. Useful when plotting more than one
        dataset. Must have the same length as the number of datasets passed in ``data``.
        Defaults to None, which labels multiple datasets "0", "1", ... and leaves a single
        dataset unlabelled.
    var_names : Optional[List[str]]
        List of variables to plot. If multiple datasets are supplied and var_names is not None,
        the same set of variables will be plotted for each dataset. Defaults to None, which
        results in all the variables being plotted.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw"
        dimensions. See :ref:`this section <common_combine_dims>` for usage examples.
    transform : callable
        Function applied to each converted dataset before plotting (defaults to None, i.e. the
        identity function).
    hdi_prob : float
        Probability for the highest density interval. Should be in the interval (0, 1].
        Defaults to ``rcParams["stats.hdi_prob"]``.
    point_estimate : Optional[str]
        Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None.
        Defaults to 'auto' i.e. it falls back to default set in ``rcParams``.
    colors : Optional[Union[List[str], str]]
        List with valid matplotlib colors, one color per model. Alternatively, a single string
        can be passed. If the string is `cycle`, it will automatically choose a color per model
        from matplotlib's cycle. If a single color is passed, e.g. 'k', 'C2' or 'red', this
        color will be used for all models. Defaults to `cycle`.
    outline : bool
        Use a line to draw KDEs and histograms. Defaults to True.
    hdi_markers : str
        A valid `matplotlib.markers` like 'v', used to indicate the limits of the highest
        density interval. Defaults to empty string (no marker).
    shade : Optional[float]
        Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
        (opaque). Defaults to 0.
    bw : Optional[float or str]
        If numeric, indicates the bandwidth and must be positive.
        If str, indicates the method to estimate the bandwidth and must be
        one of "scott", "silverman", "isj" or "experimental" when `circular` is False
        and "taylor" (for now) when `circular` is True.
        Defaults to "default" which means "experimental" when variable is not circular
        and "taylor" when it is.
    circular : Optional[bool]
        If True, it interprets the values passed as coming from a circular variable measured in
        radians and a circular KDE is used. Only valid for 1D KDE. Defaults to False.
    grid : tuple
        Number of rows and columns. Defaults to None, the rows and columns are
        automatically inferred.
    figsize : Optional[Tuple[int, int]]
        Figure size. If None it will be defined automatically.
    textsize : Optional[float]
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled
        based on ``figsize``.
    labeller : labeller instance, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot
        titles. Read the :ref:`label_guide` for more details and usage examples.
    ax : numpy array-like of matplotlib axes or bokeh figures, optional
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ will
        create its own array of plot areas (and return it).
    backend : str, optional
        Select plotting backend {"matplotlib", "bokeh"}. Defaults to
        ``rcParams["plot.backend"]``.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``data_labels`` does not have one entry per dataset, or if ``hdi_prob`` is not in
        the interval (0, 1].

    Warns
    -----
    UserWarning
        If the number of distinct subplots exceeds ``rcParams["plot.max_subplots"]``; only the
        first ``plot.max_subplots`` subplots are drawn.

    See Also
    --------
    plot_dist : Plot distribution as histogram or kernel density estimates.
    plot_posterior : Plot Posterior densities in the style of John K. Kruschke's book.

    Examples
    --------
    Plot default density plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> centered = az.load_arviz_data('centered_eight')
        >>> non_centered = az.load_arviz_data('non_centered_eight')
        >>> az.plot_density([centered, non_centered])

    Plot variables in a 4x5 grid

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], grid=(4, 5))

    Plot subset variables by specifying variable name exactly

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"])

    Plot a specific `az.InferenceData` group

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], group="prior")

    Specify highest density interval

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], hdi_prob=.5)

    Shade plots and/or remove outlines

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], outline=False, shade=.8)

    Specify bandwidth for kernel density estimation

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], bw=.9)
    """
    # Normalise input: a single object becomes a one-element list of datasets.
    if not isinstance(data, (list, tuple)):
        datasets = [convert_to_dataset(data, group=group)]
    else:
        datasets = [convert_to_dataset(datum, group=group) for datum in data]

    if transform is not None:
        datasets = [transform(dataset) for dataset in datasets]

    if labeller is None:
        labeller = BaseLabeller()

    var_names = _var_names(var_names, datasets, filter_vars)
    n_data = len(datasets)

    if data_labels is None:
        data_labels = [f"{idx}" for idx in range(n_data)] if n_data > 1 else [""]
    elif len(data_labels) != n_data:
        raise ValueError(
            "The number of names for the models ({}) "
            "does not match the number of models ({})".format(len(data_labels), n_data)
        )

    if hdi_prob is None:
        hdi_prob = rcParams["stats.hdi_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    # One list of (var_name, selection, isel, values) entries per dataset.
    to_plot = [
        list(xarray_var_iter(dataset, var_names, combined=True, skip_dims=combine_dims))
        for dataset in datasets
    ]

    # Distinct subplot labels across all datasets, in first-seen order. Each label is one
    # subplot; a set keeps the de-duplication linear instead of quadratic.
    all_labels = []
    seen_labels = set()
    for plotters in to_plot:
        for var_name, selection, isel, _ in plotters:
            label = labeller.make_label_vert(var_name, selection, isel)
            if label not in seen_labels:
                seen_labels.add(label)
                all_labels.append(label)

    length_plotters = len(all_labels)
    max_plots = rcParams["plot.max_subplots"]
    max_plots = length_plotters if max_plots is None else max_plots
    if length_plotters > max_plots:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            "of variables to plot ({len_plotters}) in plot_density, generating only "
            "{max_plots} plots".format(max_plots=max_plots, len_plotters=length_plotters),
            UserWarning,
        )
        all_labels = all_labels[:max_plots]
        kept_labels = set(all_labels)
        # Keep the full 4-tuple: the untruncated path passes (var_name, selection, isel,
        # values), so the truncated path must hand the backend entries of the same shape.
        to_plot = [
            [
                (var_name, selection, isel, values)
                for var_name, selection, isel, values in plotters
                if labeller.make_label_vert(var_name, selection, isel) in kept_labels
            ]
            for plotters in to_plot
        ]
        length_plotters = len(all_labels)

    rows, cols = default_grid(length_plotters, grid=grid, max_cols=3)

    if bw == "default":
        bw = "taylor" if circular else "experimental"

    plot_density_kwargs = dict(
        ax=ax,
        all_labels=all_labels,
        to_plot=to_plot,
        colors=colors,
        bw=bw,
        circular=circular,
        figsize=figsize,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        textsize=textsize,
        labeller=labeller,
        hdi_prob=hdi_prob,
        point_estimate=point_estimate,
        hdi_markers=hdi_markers,
        outline=outline,
        shade=shade,
        n_data=n_data,
        data_labels=data_labels,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_density", "densityplot", backend)
    ax = plot(**plot_density_kwargs)
    return ax