Source code for arviz.plots.densityplot

"""KDE and histogram plots for multiple variables."""

import warnings

from ..data import convert_to_dataset
from ..labels import BaseLabeller
from ..sel_utils import (
    xarray_var_iter,
)
from ..rcparams import rcParams
from ..utils import _var_names
from .plot_utils import default_grid, get_plotting_function


# pylint:disable-msg=too-many-function-args
[docs] def plot_density( data, group="posterior", data_labels=None, var_names=None, filter_vars=None, combine_dims=None, transform=None, hdi_prob=None, point_estimate="auto", colors="cycle", outline=True, hdi_markers="", shade=0.0, bw="default", circular=False, grid=None, figsize=None, textsize=None, labeller=None, ax=None, backend=None, backend_kwargs=None, show=None, ): r"""Generate KDE plots for continuous variables and histograms for discrete ones. Plots are truncated at their 100*(1-alpha)% highest density intervals. Plots are grouped per variable and colors assigned to models. Parameters ---------- data : InferenceData or iterable of InferenceData Any object that can be converted to an :class:`arviz.InferenceData` object, or an Iterator returning a sequence of such objects. Refer to documentation of :func:`arviz.convert_to_dataset` for details. group : {"posterior", "prior"}, default "posterior" Specifies which InferenceData group should be plotted. If "posterior", then the values in `posterior_predictive` group are compared to the ones in `observed_data`, if "prior" then the same comparison happens, but with the values in `prior_predictive` group. data_labels : list of str, default None List with names for the datasets passed as "data." Useful when plotting more than one dataset. Must be the same shape as the data parameter. var_names : list of str, optional List of variables to plot. If multiple datasets are supplied and `var_names` is not None, will print the same set of variables for each dataset. Defaults to None, which results in all the variables being plotted. filter_vars : {None, "like", "regex"}, default None If `None` (default), interpret `var_names` as the real variables names. If "like", interpret `var_names` as substrings of the real variables names. If "regex", interpret `var_names` as regular expressions on the real variables names. See :ref:`this section <common_filter_vars>` for usage examples. combine_dims : set_like of str, optional List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions. See :ref:`this section <common_combine_dims>` for usage examples. transform : callable Function to transform data (defaults to `None` i.e. the identity function). hdi_prob : float, default 0.94 Probability for the highest density interval. Should be in the interval (0, 1]. See :ref:`this section <common_hdi_prob>` for usage examples. point_estimate : str, optional Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None. Defaults to 'auto' i.e. it falls back to default set in ``rcParams``. colors : str or list of str, optional List with valid matplotlib colors, one color per model. Alternative a string can be passed. If the string is `cycle`, it will automatically choose a color per model from matplotlib's cycle. If a single color is passed, e.g. 'k', 'C2' or 'red' this color will be used for all models. Defaults to `cycle`. outline : bool, default True Use a line to draw KDEs and histograms. hdi_markers : str A valid `matplotlib.markers` like 'v', used to indicate the limits of the highest density interval. Defaults to empty string (no marker). shade : float, default 0 Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1 (opaque). bw : float or str, optional If numeric, indicates the bandwidth and must be positive. If str, indicates the method to estimate the bandwidth and must be one of "scott", "silverman", "isj" or "experimental" when `circular` is False and "taylor" (for now) when `circular` is True. Defaults to "default" which means "experimental" when variable is not circular and "taylor" when it is. circular : bool, default False If True, it interprets the values passed are from a circular variable measured in radians and a circular KDE is used. Only valid for 1D KDE. grid : tuple, optional Number of rows and columns. Defaults to ``None``, the rows and columns are automatically inferred. See :ref:`this section <common_grid>` for usage examples. figsize : (float, float), optional Figure size. If `None` it will be defined automatically. textsize : float, optional Text size scaling factor for labels, titles and lines. If `None` it will be autoscaled based on `figsize`. labeller : Labeller, optional Class providing the method ``make_label_vert`` to generate the labels in the plot titles. Read the :ref:`label_guide` for more details and usage examples. ax : 2D array-like of matplotlib_axes or bokeh_figure, optional A 2D array of locations into which to plot the densities. If not supplied, ArviZ will create its own array of plot areas (and return it). backend : {"matplotlib", "bokeh"}, default "matplotlib" Select plotting backend. backend_kwargs : dict, optional These are kwargs specific to the backend being used, passed to :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`. For additional documentation check the plotting method of the backend. show : bool, optional Call backend show function. Returns ------- axes : 2D ndarray of matplotlib_axes or bokeh_figure See Also -------- plot_dist : Plot distribution as histogram or kernel density estimates. plot_posterior : Plot Posterior densities in the style of John K. Kruschke's book. Examples -------- Plot default density plot .. plot:: :context: close-figs >>> import arviz as az >>> centered = az.load_arviz_data('centered_eight') >>> non_centered = az.load_arviz_data('non_centered_eight') >>> az.plot_density([centered, non_centered]) Plot variables in a 4x5 grid .. plot:: :context: close-figs >>> az.plot_density([centered, non_centered], grid=(4, 5)) Plot subset variables by specifying variable name exactly .. plot:: :context: close-figs >>> az.plot_density([centered, non_centered], var_names=["mu"]) Plot a specific `az.InferenceData` group .. plot:: :context: close-figs >>> az.plot_density([centered, non_centered], var_names=["mu"], group="prior") Specify highest density interval .. plot:: :context: close-figs >>> az.plot_density([centered, non_centered], var_names=["mu"], hdi_prob=.5) Shade plots and/or remove outlines .. plot:: :context: close-figs >>> az.plot_density([centered, non_centered], var_names=["mu"], outline=False, shade=.8) Specify binwidth for kernel density estimation .. plot:: :context: close-figs >>> az.plot_density([centered, non_centered], var_names=["mu"], bw=.9) """ if isinstance(data, (list, tuple)): datasets = [convert_to_dataset(datum, group=group) for datum in data] else: datasets = [convert_to_dataset(data, group=group)] if transform is not None: datasets = [transform(dataset) for dataset in datasets] if labeller is None: labeller = BaseLabeller() var_names = _var_names(var_names, datasets, filter_vars) n_data = len(datasets) if data_labels is None: data_labels = [f"{idx}" for idx in range(n_data)] if n_data > 1 else [""] elif len(data_labels) != n_data: raise ValueError( f"The number of names for the models ({len(data_labels)}) " f"does not match the number of models ({n_data})" ) if hdi_prob is None: hdi_prob = rcParams["stats.hdi_prob"] elif not 1 >= hdi_prob > 0: raise ValueError("The value of hdi_prob should be in the interval (0, 1]") to_plot = [ list(xarray_var_iter(data, var_names, combined=True, skip_dims=combine_dims)) for data in datasets ] all_labels = [] length_plotters = [] for plotters in to_plot: length_plotters.append(len(plotters)) for var_name, selection, isel, _ in plotters: label = labeller.make_label_vert(var_name, selection, isel) if label not in all_labels: all_labels.append(label) length_plotters = len(all_labels) max_plots = rcParams["plot.max_subplots"] max_plots = length_plotters if max_plots is None else max_plots if length_plotters > max_plots: warnings.warn( "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number " "of variables to plot ({len_plotters}) in plot_density, generating only " "{max_plots} plots".format(max_plots=max_plots, len_plotters=length_plotters), UserWarning, ) all_labels = all_labels[:max_plots] to_plot = [ [ (var_name, selection, values) for var_name, selection, isel, values in plotters if labeller.make_label_vert(var_name, selection, isel) in all_labels ] for plotters in to_plot ] length_plotters = max_plots rows, cols = default_grid(length_plotters, grid=grid, max_cols=3) if bw == "default": bw = "taylor" if circular else "experimental" plot_density_kwargs = dict( ax=ax, all_labels=all_labels, to_plot=to_plot, colors=colors, bw=bw, circular=circular, figsize=figsize, length_plotters=length_plotters, rows=rows, cols=cols, textsize=textsize, labeller=labeller, hdi_prob=hdi_prob, point_estimate=point_estimate, hdi_markers=hdi_markers, outline=outline, shade=shade, n_data=n_data, data_labels=data_labels, backend_kwargs=backend_kwargs, show=show, ) if backend is None: backend = rcParams["plot.backend"] backend = backend.lower() # TODO: Add backend kwargs plot = get_plotting_function("plot_density", "densityplot", backend) ax = plot(**plot_density_kwargs) return ax