Source code for arviz.plots.compareplot

"""Summary plot for model comparison."""
import numpy as np

from ..labels import BaseLabeller
from ..rcparams import rcParams
from .plot_utils import get_plotting_function


[docs]def plot_compare( comp_df, insample_dev=True, plot_standard_error=True, plot_ic_diff=True, order_by_rank=True, figsize=None, textsize=None, labeller=None, plot_kwargs=None, ax=None, backend=None, backend_kwargs=None, show=None, ): """ Summary plot for model comparison. This plot is in the style of the one used in the book Statistical Rethinking (Chapter 6) by Richard McElreath. Notes ----- Defaults to comparing Leave-one-out (psis-loo) if present in comp_df column, otherwise compares Widely Applicable Information Criterion (WAIC) Parameters ---------- comp_df : pd.DataFrame Result of the :func:`arviz.compare` method insample_dev : bool, optional Plot in-sample deviance, that is the value of the information criteria without the penalization given by the effective number of parameters (pIC). Defaults to True plot_standard_error : bool, optional Plot the standard error of the information criteria estimate. Defaults to True plot_ic_diff : bool, optional Plot standard error of the difference in information criteria between each model and the top-ranked model. Defaults to True order_by_rank : bool If True (default) ensure the best model is used as reference. figsize : tuple, optional If None, size is (6, num of models) inches textsize: float Text size scaling factor for labels, titles and lines. If None it will be autoscaled based on ``figsize``. labeller : labeller instance, optional Class providing the method ``model_name_to_str`` to generate the labels in the plot. Read the :ref:`label_guide` for more details and usage examples. plot_kwargs : dict, optional Optional arguments for plot elements. Currently accepts 'color_ic', 'marker_ic', 'color_insample_dev', 'marker_insample_dev', 'color_dse', 'marker_dse', 'ls_min_ic' 'color_ls_min_ic', 'fontsize' ax: axes, optional Matplotlib axes or bokeh figures. backend: str, optional Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". backend_kwargs: bool, optional These are kwargs specific to the backend being used, passed to :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`. show : bool, optional Call backend show function. Returns ------- axes : matplotlib axes or bokeh figures See Also -------- plot_elpd : Plot pointwise elpd differences between two or more models. compare : Compare models based on PSIS-LOO loo or WAIC waic cross-validation. loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV). waic : Compute the widely applicable information criterion. Examples -------- Show default compare plot .. plot:: :context: close-figs >>> import arviz as az >>> model_compare = az.compare({'Centered 8 schools': az.load_arviz_data('centered_eight'), >>> 'Non-centered 8 schools': az.load_arviz_data('non_centered_eight')}) >>> az.plot_compare(model_compare) Plot standard error and information criteria difference only .. plot:: :context: close-figs >>> az.plot_compare(model_compare, insample_dev=False) """ if plot_kwargs is None: plot_kwargs = {} if labeller is None: labeller = BaseLabeller() yticks_pos, step = np.linspace(0, -1, (comp_df.shape[0] * 2) - 1, retstep=True) yticks_pos[1::2] = yticks_pos[1::2] + step / 2 labels = [labeller.model_name_to_str(model_name) for model_name in comp_df.index] if plot_ic_diff: yticks_labels = [""] * len(yticks_pos) yticks_labels[0] = labels[0] yticks_labels[2::2] = labels[1:] else: yticks_labels = labels _information_criterion = ["loo", "waic"] column_index = [c.lower() for c in comp_df.columns] for information_criterion in _information_criterion: if information_criterion in column_index: break else: raise ValueError( "comp_df must contain one of the following" " information criterion: {}".format(_information_criterion) ) if order_by_rank: comp_df.sort_values(by="rank", inplace=True) compareplot_kwargs = dict( ax=ax, comp_df=comp_df, figsize=figsize, plot_ic_diff=plot_ic_diff, plot_standard_error=plot_standard_error, insample_dev=insample_dev, yticks_pos=yticks_pos, yticks_labels=yticks_labels, plot_kwargs=plot_kwargs, information_criterion=information_criterion, textsize=textsize, step=step, backend_kwargs=backend_kwargs, show=show, ) if backend is None: backend = rcParams["plot.backend"] backend = backend.lower() # TODO: Add backend kwargs plot = get_plotting_function("plot_compare", "compareplot", backend) ax = plot(**compareplot_kwargs) return ax