"""Summary plot for model comparison."""
import numpy as np
from ..labels import BaseLabeller
from ..rcparams import rcParams
from .plot_utils import get_plotting_function
def plot_compare(
comp_df,
insample_dev=False,
plot_standard_error=True,
plot_ic_diff=False,
order_by_rank=True,
legend=False,
title=True,
figsize=None,
textsize=None,
labeller=None,
plot_kwargs=None,
ax=None,
backend=None,
backend_kwargs=None,
show=None,
):
r"""Summary plot for model comparison.
Models are compared based on their expected log pointwise predictive density (ELPD).
This plot is in the style of the one used in [2]_ (Chapter 6 in the first
edition, or Chapter 7 in the second).
Notes
-----
The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out
cross-validation (LOO) or by the widely applicable information criterion (WAIC).
We recommend LOO, in line with the work presented in [1]_.
Parameters
----------
comp_df : pandas.DataFrame
Result of the :func:`arviz.compare` method.
insample_dev : bool, default False
Plot in-sample ELPD, that is, the value of the information criterion without
the penalization given by the effective number of parameters (p_loo or p_waic).
plot_standard_error : bool, default True
Plot the standard error of the ELPD.
plot_ic_diff : bool, default False
Plot standard error of the difference in ELPD between each model
and the top-ranked model.
order_by_rank : bool, default True
If True, sort the models by rank so the best model is used as the reference.
legend : bool, default False
Add legend to figure.
title : bool, default True
Show a title with a description of how to interpret the plot.
figsize : (float, float), optional
If `None`, size is (6, number of models) inches.
textsize : float, optional
Text size scaling factor for labels, titles and lines. If `None` it will be autoscaled based
on `figsize`.
labeller : Labeller, optional
Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
Read the :ref:`label_guide` for more details and usage examples.
plot_kwargs : dict, optional
Optional arguments for plot elements. Currently accepts 'color_ic',
'marker_ic', 'color_insample_dev', 'marker_insample_dev', 'color_dse',
'marker_dse', 'ls_min_ic', 'color_ls_min_ic' and 'fontsize'. See the last
example below for usage.
ax : matplotlib_axes or bokeh_figure, optional
Matplotlib axes or bokeh figure.
backend : {"matplotlib", "bokeh"}, default "matplotlib"
Select plotting backend.
backend_kwargs : dict, optional
These are kwargs specific to the backend being used, passed to
:func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
For additional documentation check the plotting method of the backend.
show : bool, optional
Call backend show function.
Returns
-------
axes : matplotlib_axes or bokeh_figure
Axes or figure object containing the plot.
See Also
--------
plot_elpd : Plot pointwise elpd differences between two or more models.
compare : Compare models based on PSIS-LOO-CV or WAIC cross-validation.
loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
waic : Compute the widely applicable information criterion.
References
----------
.. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
cross-validation and WAIC. https://arxiv.org/abs/1507.04544
.. [2] McElreath R. (2022). Statistical Rethinking: A Bayesian Course with Examples in
R and Stan, Second edition, CRC Press.
Examples
--------
Show default compare plot
.. plot::
:context: close-figs
>>> import arviz as az
>>> model_compare = az.compare({'Centered 8 schools': az.load_arviz_data('centered_eight'),
>>>                             'Non-centered 8 schools': az.load_arviz_data('non_centered_eight')})
>>> az.plot_compare(model_compare)
Include the in-sample ELPD
.. plot::
:context: close-figs
>>> az.plot_compare(model_compare, insample_dev=True)
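Plot the standard error of the ELPD difference relative to the top-ranked model
.. plot::
:context: close-figs
>>> az.plot_compare(model_compare, plot_ic_diff=True)
Customize plot elements through `plot_kwargs`; a minimal sketch using keys from
the list above with standard matplotlib color and marker codes
.. plot::
:context: close-figs
>>> az.plot_compare(model_compare,
>>>                 plot_kwargs={'color_ic': 'C1', 'marker_ic': 's', 'fontsize': 12})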
"""
if plot_kwargs is None:
plot_kwargs = {}
if labeller is None:
labeller = BaseLabeller()
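# Lay out 2n - 1 evenly spaced y positions: one row per model plus one row
# between consecutive models for the ELPD-difference error bars. The step is
# negative, and the in-between rows are shifted down by half a step.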
yticks_pos, step = np.linspace(0, -1, (comp_df.shape[0] * 2) - 1, retstep=True)
yticks_pos[1::2] = yticks_pos[1::2] + step / 2
labels = [labeller.model_name_to_str(model_name) for model_name in comp_df.index]
if plot_ic_diff:
yticks_labels = [""] * len(yticks_pos)
yticks_labels[0] = labels[0]
yticks_labels[2::2] = labels[1:]
else:
yticks_labels = labels
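# Detect which supported ELPD column comp_df provides (case-insensitive);
# the for-else below raises if neither LOO nor WAIC results are present.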
_information_criterion = ["elpd_loo", "elpd_waic"]
column_index = [c.lower() for c in comp_df.columns]
for information_criterion in _information_criterion:
if information_criterion in column_index:
break
else:
raise ValueError(
"comp_df must contain one of the following "
f"information criteria: {_information_criterion}"
)
if order_by_rank:
# Sort a copy so the caller's DataFrame is not modified in place.
comp_df = comp_df.sort_values(by="rank")
compareplot_kwargs = dict(
ax=ax,
comp_df=comp_df,
legend=legend,
title=title,
figsize=figsize,
plot_ic_diff=plot_ic_diff,
plot_standard_error=plot_standard_error,
insample_dev=insample_dev,
yticks_pos=yticks_pos,
yticks_labels=yticks_labels,
plot_kwargs=plot_kwargs,
information_criterion=information_criterion,
textsize=textsize,
step=step,
backend_kwargs=backend_kwargs,
show=show,
)
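# Fall back to the globally configured plotting backend when none is given.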
if backend is None:
backend = rcParams["plot.backend"]
backend = backend.lower()
# TODO: Add backend kwargs
plot = get_plotting_function("plot_compare", "compareplot", backend)
ax = plot(**compareplot_kwargs)
return ax