"""Density Comparison plot."""
import warnings
from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..utils import _var_names, get_coords
from .plot_utils import get_plotting_function
from ..sel_utils import xarray_var_iter, xarray_sel_iter
def plot_dist_comparison(
    data,
    kind="latent",
    figsize=None,
    textsize=None,
    var_names=None,
    coords=None,
    combine_dims=None,
    transform=None,
    legend=True,
    labeller=None,
    ax=None,
    prior_kwargs=None,
    posterior_kwargs=None,
    observed_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    """Plot to compare fitted and unfitted distributions.

    The resulting plots will show the compared distributions both on
    separate axes (particularly useful when one of them is substantially tighter
    than another), and plotted together, so three plots per distribution.

    Parameters
    ----------
    data : InferenceData object
        :class:`arviz.InferenceData` object containing the posterior/prior data.
    kind : str
        kind of plot to display {"latent", "observed"}, defaults to 'latent'.
        "latent" includes {"prior", "posterior"} and "observed" includes
        {"observed_data", "prior_predictive", "posterior_predictive"}.
        Any value other than "observed" is treated as "latent".
    figsize : tuple
        Figure size. If None it will be defined automatically.
    textsize : float
        Text size scaling factor for labels, titles and lines. If None it will be
        autoscaled based on ``figsize``.
    var_names : str, list, list of lists
        if str, plot the variable. if list, plot all the variables in list
        of all groups. if list of lists, plot the vars of groups in respective lists.
        If None, all the variables of the first available group are used.
    coords : dict
        Dictionary mapping dimensions to selected coordinates to be plotted.
        Dimensions without a mapping specified will include all coordinates for
        that dimension.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw"
        dimensions. See :ref:`this section <common_combine_dims>` for usage examples.
    transform : callable
        Function to transform data (defaults to None i.e. the identity function).
        It is applied to each group dataset before ``coords`` are selected.
    legend : bool
        Add legend to figure. By default True.
    labeller : labeller instance, optional
        Class providing the method ``make_pp_label`` to generate the labels in the plot.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : axes, optional
        Matplotlib axes: The ax argument should have shape (nvars, 3), where the
        last column is for the combined before/after plots and columns 0 and 1 are
        for the before and after plots, respectively.
    prior_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for prior/predictive groups.
    posterior_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for posterior/predictive groups.
    observed_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for observed_data group.
    backend : str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`. For additional documentation
        check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : a numpy 2D array of matplotlib axes. Returned object will have shape (nvars, 3),
        where the last column is the combined plot and the first columns are the single plots.

    See Also
    --------
    plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.

    Examples
    --------
    Plot the prior/posterior plot for specified vars and coords.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('rugby')
        >>> az.plot_dist_comparison(data, var_names=["defs"], coords={"team" : ["Italy"]})

    """
    all_groups = ["prior", "posterior"]
    if kind == "observed":
        all_groups = ["observed_data", "prior_predictive", "posterior_predictive"]

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    # Keep only the groups that are actually present in ``data``; ``datasets``
    # and ``groups`` stay index-aligned.
    datasets = []
    groups = []
    for group in all_groups:
        try:
            dataset = getattr(data, group)
        except AttributeError:
            # Missing groups are skipped on purpose; any other error propagates.
            continue
        datasets.append(dataset)
        groups.append(group)

    if var_names is None:
        var_names = list(datasets[0].data_vars)

    # Normalize ``var_names`` to one list of names per dataset.
    if isinstance(var_names, str):
        var_names = [var_names]

    if isinstance(var_names[0], str):
        var_names = [var_names for _ in datasets]

    var_names = [_var_names(names, dataset) for names, dataset in zip(var_names, datasets)]

    if transform is not None:
        datasets = [transform(dataset) for dataset in datasets]

    datasets = get_coords(datasets, coords)

    # Each plotted variable uses one subplot per group plus one combined subplot,
    # so this is how many variables fit within ``plot.max_subplots`` (at least one).
    subplots_per_var = len(groups) + 1
    len_plots = rcParams["plot.max_subplots"] // subplots_per_var
    len_plots = len_plots or 1

    dc_plotters = [
        list(xarray_var_iter(dataset, var_names=names, combined=True, skip_dims=combine_dims))[
            :len_plots
        ]
        for dataset, names in zip(datasets, var_names)
    ]

    # Count the subplots that would be needed without truncation. ``skip_dims`` must
    # match the one used to build ``dc_plotters``; otherwise reduced dimensions are
    # still counted and the warning below reports an inflated number.
    total_plots = (
        sum(
            1
            for _ in xarray_sel_iter(
                datasets[0], var_names=var_names[0], combined=True, skip_dims=combine_dims
            )
        )
        * subplots_per_var
    )
    maxplots = len(dc_plotters[0]) * subplots_per_var

    if total_plots > rcParams["plot.max_subplots"]:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({rcParam}) is smaller than the number "
            "of subplots to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(
                rcParam=rcParams["plot.max_subplots"], len_plotters=total_plots, max_plots=maxplots
            ),
            UserWarning,
        )

    nvars = len(dc_plotters[0])
    ngroups = len(groups)

    distcomparisonplot_kwargs = dict(
        ax=ax,
        nvars=nvars,
        ngroups=ngroups,
        figsize=figsize,
        dc_plotters=dc_plotters,
        legend=legend,
        groups=groups,
        textsize=textsize,
        labeller=labeller,
        prior_kwargs=prior_kwargs,
        posterior_kwargs=posterior_kwargs,
        observed_kwargs=observed_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_dist_comparison", "distcomparisonplot", backend)
    axes = plot(**distcomparisonplot_kwargs)
    return axes