Source code for arviz.plots.elpdplot

"""Plot pointwise elpd estimations of inference data."""

import numpy as np

from ..rcparams import rcParams
from ..stats import _calculate_ics
from ..utils import get_coords
from .plot_utils import format_coords_as_labels, get_plotting_function


[docs] def plot_elpd( compare_dict, color="C0", xlabels=False, figsize=None, textsize=None, coords=None, legend=False, threshold=None, ax=None, ic=None, scale=None, var_name=None, plot_kwargs=None, backend=None, backend_kwargs=None, show=None, ): r"""Plot pointwise elpd differences between two or more models. Pointwise model comparison based on their expected log pointwise predictive density (ELPD). Notes ----- The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out cross-validation (LOO) or using the widely applicable information criterion (WAIC). We recommend LOO in line with the work presented by [1]_. Parameters ---------- compare_dict : mapping of {str : ELPDData or InferenceData} A dictionary mapping the model name to the object containing inference data or the result of :func:`arviz.loo` or :func:`arviz.waic` functions. Refer to :func:`arviz.convert_to_inference_data` for details on possible dict items. color : str or array_like, default "C0" Colors of the scatter plot. If color is a str all dots will have the same color. If it is the size of the observations, each dot will have the specified color. Otherwise, it will be interpreted as a list of the dims to be used for the color code. xlabels : bool, default False Use coords as xticklabels. figsize : (float, float), optional If `None`, size is (8 + numvars, 8 + numvars). textsize : float, optional Text size for labels. If `None` it will be autoscaled based on `figsize`. coords : mapping, optional Coordinates of points to plot. **All** values are used for computation, but only a subset can be plotted for convenience. See :ref:`this section <common_coords>` for usage examples. legend : bool, default False Include a legend to the plot. Only taken into account when color argument is a dim name. threshold : float, optional If some elpd difference is larger than ``threshold * elpd.std()``, show its label. If `None`, no observations will be highlighted. ic : str, optional Information Criterion ("loo" for PSIS-LOO, "waic" for WAIC) used to compare models. Defaults to ``rcParams["stats.information_criterion"]``. Only taken into account when input is :class:`arviz.InferenceData`. scale : str, optional Scale argument passed to :func:`arviz.loo` or :func:`arviz.waic`, see their docs for details. Only taken into account when values in ``compare_dict`` are :class:`arviz.InferenceData`. var_name : str, optional Argument passed to to :func:`arviz.loo` or :func:`arviz.waic`, see their docs for details. Only taken into account when values in ``compare_dict`` are :class:`arviz.InferenceData`. plot_kwargs : dicts, optional Additional keywords passed to :meth:`matplotlib.axes.Axes.scatter`. ax : axes, optional :class:`matplotlib.axes.Axes` or :class:`bokeh.plotting.Figure`. backend : {"matplotlib", "bokeh"}, default "matplotlib" Select plotting backend. backend_kwargs : dict, optional These are kwargs specific to the backend being used, passed to :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`. For additional documentation check the plotting method of the backend. show : bool, optional Call backend show function. Returns ------- axes : matplotlib_axes or bokeh_figure See Also -------- plot_compare : Summary plot for model comparison. loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV). waic : Compute the widely applicable information criterion. References ---------- .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC https://arxiv.org/abs/1507.04544 Examples -------- Compare pointwise PSIS-LOO for centered and non centered models of the 8-schools problem using matplotlib. .. plot:: :context: close-figs >>> import arviz as az >>> idata1 = az.load_arviz_data("centered_eight") >>> idata2 = az.load_arviz_data("non_centered_eight") >>> az.plot_elpd( >>> {"centered model": idata1, "non centered model": idata2}, >>> xlabels=True >>> ) .. bokeh-plot:: :source-position: above import arviz as az idata1 = az.load_arviz_data("centered_eight") idata2 = az.load_arviz_data("non_centered_eight") az.plot_elpd( {"centered model": idata1, "non centered model": idata2}, backend="bokeh" ) """ try: (compare_dict, _, ic) = _calculate_ics(compare_dict, scale=scale, ic=ic, var_name=var_name) except Exception as e: raise e.__class__("Encountered error in ic computation of plot_elpd.") from e if backend is None: backend = rcParams["plot.backend"] backend = backend.lower() numvars = len(compare_dict) models = list(compare_dict.keys()) if coords is None: coords = {} pointwise_data = [get_coords(compare_dict[model][f"{ic}_i"], coords) for model in models] xdata = np.arange(pointwise_data[0].size) coord_labels = format_coords_as_labels(pointwise_data[0]) if xlabels else None if numvars < 2: raise ValueError("Number of models to compare must be 2 or greater.") elpd_plot_kwargs = dict( ax=ax, models=models, pointwise_data=pointwise_data, numvars=numvars, figsize=figsize, textsize=textsize, plot_kwargs=plot_kwargs, xlabels=xlabels, coord_labels=coord_labels, xdata=xdata, threshold=threshold, legend=legend, color=color, backend_kwargs=backend_kwargs, show=show, ) plot = get_plotting_function("plot_elpd", "elpdplot", backend) ax = plot(**elpd_plot_kwargs) return ax