# Source code for arviz.plots.bpvplot

"""Bayesian p-value Posterior/Prior predictive plot."""

import numpy as np

from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..utils import _var_names
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
from ..sel_utils import xarray_var_iter

def plot_bpv(
    data,
    kind="u_value",
    t_stat="median",
    bpv=True,
    plot_mean=True,
    reference="analytical",
    mse=False,
    n_ref=100,
    hdi_prob=0.94,
    color="C0",
    grid=None,
    figsize=None,
    textsize=None,
    labeller=None,
    data_pairs=None,
    var_names=None,
    filter_vars=None,
    coords=None,
    flatten=None,
    flatten_pp=None,
    ax=None,
    backend=None,
    plot_ref_kwargs=None,
    backend_kwargs=None,
    group="posterior",
    show=None,
):
    r"""Plot Bayesian p-value for observed data and Posterior/Prior predictive.

    Parameters
    ----------
    data : InferenceData
        :class:`arviz.InferenceData` object containing the observed and
        posterior/prior predictive data.
    kind : {"u_value", "p_value", "t_stat"}, default "u_value"
        Specify the kind of plot (matched case-insensitively):

        * The ``kind="p_value"`` computes :math:`p := p(y* \leq y | y)`.
          This is the probability of the data y being larger or equal than the
          predicted data y*. The ideal value is 0.5 (half the predictions below
          and half above the data).
        * The ``kind="u_value"`` argument computes
          :math:`p_i := p(y_i* \leq y_i | y)`, i.e. like a p_value but per
          observation :math:`y_i`. This is also known as marginal p_value. The
          ideal distribution is uniform. This is similar to the LOO-PIT
          calculation/plot, the difference is that in the LOO-PIT plot we
          compute :math:`p_i = p(y_i* \leq y_i | y_{-i})`, where
          :math:`y_{-i}` is all other data except :math:`y_i`.
        * The ``kind="t_stat"`` argument computes
          :math:`p := p(T(y)* \leq T(y) | y)` where T is any test statistic.
          See ``t_stat`` argument below for details of available options.

    t_stat : str, float, or callable, default "median"
        Test statistics to compute from the observations and predictive
        distributions. Allowed strings are "mean", "median" or "std".
        Alternatively a quantile can be passed as a float (or str) in the
        interval (0, 1). Finally a user defined function is also accepted, see
        examples section for details.
    bpv : bool, default True
        If True add the Bayesian p_value to the legend when ``kind = t_stat``.
    plot_mean : bool, default True
        Whether or not to plot the mean test statistic.
    reference : {"analytical", "samples", None}, default "analytical"
        How to compute the distributions used as reference for
        ``kind=u_values`` or ``kind=p_values``. Use `None` to not plot any
        reference. Strings are matched case-insensitively.
    mse : bool, default False
        Show scaled mean square error between uniform distribution and
        marginal p_value distribution.
    n_ref : int, default 100
        Number of reference distributions to sample when
        ``reference=samples``.
    hdi_prob : float, default 0.94
        Probability for the highest density interval for the analytical
        reference distribution when ``kind=u_values``. Should be in the
        interval (0, 1]. If None, the rcParam ``stats.hdi_prob`` is used.
        See :ref:`this section <common_hdi_prob>` for usage examples.
    color : str, optional
        Matplotlib color
    grid : tuple, optional
        Number of rows and columns. By default, the rows and columns are
        automatically inferred. See :ref:`this section <common_grid>` for
        usage examples.
    figsize : (float, float), optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will
        be autoscaled based on `figsize`.
    labeller : Labeller, optional
        Class providing the method ``make_pp_label`` to generate the labels in
        the plot titles. Read the :ref:`label_guide` for more details and
        usage examples.
    data_pairs : dict, optional
        Dictionary containing relations between observed data and
        posterior/prior predictive data. Dictionary structure:

        - key = data var_name
        - value = posterior/prior predictive var_name

        For example, ``data_pairs = {'y' : 'y_hat'}``.
        If None, it will assume that the observed data and the posterior/prior
        predictive data have the same variable name.
    var_names : list of str, optional
        Variables to be plotted. If `None` all variables are plotted. Prefix
        the variables by ``~`` when you want to exclude them from the plot.
        See :ref:`this section <common_var_names>` for usage examples.
    filter_vars : {None, "like", "regex"}, default None
        If `None` (default), interpret `var_names` as the real variables
        names. If "like", interpret `var_names` as substrings of the real
        variables names. If "regex", interpret `var_names` as regular
        expressions on the real variables names.
        See :ref:`this section <common_filter_vars>` for usage examples.
    coords : dict, optional
        Dictionary mapping dimensions to selected coordinates to be plotted.
        Dimensions without a mapping specified will include all coordinates
        for that dimension. Defaults to including all coordinates for all
        dimensions if None. The dictionary passed in is not modified.
        See :ref:`this section <common_coords>` for usage examples.
    flatten : list, optional
        List of dimensions to flatten in observed_data. Only flattens across
        the coordinates specified in the coords argument. Defaults to
        flattening all of the dimensions.
    flatten_pp : list, optional
        List of dimensions to flatten in
        posterior_predictive/prior_predictive. Only flattens across the
        coordinates specified in the coords argument. Defaults to flattening
        all of the dimensions. Dimensions should match flatten excluding
        dimensions for data_pairs parameters. If `flatten` is defined and
        `flatten_pp` is None, then ``flatten_pp=flatten``.
    ax : 2D array-like of matplotlib_axes or bokeh_figure, optional
        A 2D array of locations into which to plot the densities. If not
        supplied, ArviZ will create its own array of plot areas (and return
        it).
    backend : str, optional
        Select plotting backend {"matplotlib", "bokeh"}. If None, the rcParam
        ``plot.backend`` is used.
    plot_ref_kwargs : dict, optional
        Extra keyword arguments to control how reference is represented.
        Passed to :meth:`matplotlib.axes.Axes.plot` or
        :meth:`matplotlib.axes.Axes.axhspan` (when ``kind=u_value`` and
        ``reference=analytical``).
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :class:`bokeh.plotting.figure`. For additional documentation check the
        plotting method of the backend.
    group : {"posterior", "prior"}, default "posterior"
        Specifies which InferenceData group should be plotted. If
        "posterior", then the values in `posterior_predictive` group are
        compared to the ones in `observed_data`, if "prior" then the same
        comparison happens, but with the values in `prior_predictive` group.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : 2D ndarray of matplotlib_axes or bokeh_figure

    Raises
    ------
    TypeError
        If `group`, `kind` or `reference` is not one of the accepted values,
        or if `data` lacks the ``observed_data`` group or the predictive group
        matching `group`.
    ValueError
        If `hdi_prob` is not in the interval (0, 1].

    See Also
    --------
    plot_ppc : Plot for posterior/prior predictive checks.
    plot_loo_pit : Plot Leave-One-Out probability integral transformation
        (PIT) predictive checks.
    plot_dist_comparison : Plot to compare fitted and unfitted distributions.

    References
    ----------
    * Gelman et al. (2013) see pages 151-153 for details

    Notes
    -----
    Discrete data is smoothed before computing either p-values or u-values
    using the function :func:`~arviz.smooth_data`

    Examples
    --------
    Plot Bayesian p_values.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data("regression1d")
        >>> az.plot_bpv(data, kind="p_value")

    Plot custom test statistic comparison.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data("regression1d")
        >>> az.plot_bpv(data, kind="t_stat", t_stat=lambda x:np.percentile(x, q=50, axis=-1))

    """
    # --- argument validation (runs before any data is touched) -------------
    if group not in ("posterior", "prior"):
        raise TypeError("`group` argument must be either `posterior` or `prior`")

    for required_group in (f"{group}_predictive", "observed_data"):
        if not hasattr(data, required_group):
            raise TypeError(f'`data` argument must have the group "{required_group}"')

    if kind.lower() not in ("t_stat", "u_value", "p_value"):
        raise TypeError("`kind` argument must be either `t_stat`, `u_value`, or `p_value`")
    # Forward the same spelling that was validated, so a mixed-case value
    # cannot pass the check above and then reach the backend unnormalized.
    kind = kind.lower()

    if reference is not None:
        if reference.lower() not in ("analytical", "samples"):
            raise TypeError(
                "`reference` argument must be either `analytical`, `samples`, or `None`"
            )
        reference = reference.lower()

    if hdi_prob is None:
        hdi_prob = rcParams["stats.hdi_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    # --- defaults ------------------------------------------------------------
    if data_pairs is None:
        data_pairs = {}

    if labeller is None:
        labeller = BaseLabeller()

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # --- select the datasets to compare --------------------------------------
    observed = data.observed_data

    # `group` was validated above, so anything that is not "posterior" is "prior".
    if group == "posterior":
        predictive_dataset = data.posterior_predictive
    else:
        predictive_dataset = data.prior_predictive

    if var_names is None:
        var_names = list(observed.data_vars)
    var_names = _var_names(var_names, observed, filter_vars)
    # Map each observed variable to its predictive counterpart (same name
    # unless `data_pairs` says otherwise).
    pp_var_names = [data_pairs.get(var, var) for var in var_names]
    pp_var_names = _var_names(pp_var_names, predictive_dataset, filter_vars)

    # `flatten_pp` must be resolved before `flatten` receives its default,
    # because it falls back to the user-supplied `flatten` only.
    if flatten_pp is None:
        if flatten is None:
            flatten_pp = list(predictive_dataset.dims)
        else:
            flatten_pp = flatten
    if flatten is None:
        flatten = list(observed.dims)

    if coords is None:
        coords = {}

    total_pp_samples = predictive_dataset.sizes["chain"] * predictive_dataset.sizes["draw"]

    # Translate coordinate labels into integer positions for ``isel``. A new
    # dict is built so that the caller's `coords` mapping is left untouched.
    coords = {
        dim: np.where(np.isin(observed[dim], values))[0] for dim, values in coords.items()
    }

    # --- build the per-panel plotters ----------------------------------------
    obs_plotters = filter_plotters_list(
        list(
            xarray_var_iter(
                observed.isel(coords),
                skip_dims=set(flatten),
                var_names=var_names,
                combined=True,
            )
        ),
        "plot_t_stats",
    )
    length_plotters = len(obs_plotters)
    # Take at most `length_plotters` predictive plotters; `range` comes first
    # in `zip` so the iterator is never advanced past what is needed.
    pp_plotters = [
        tup
        for _, tup in zip(
            range(length_plotters),
            xarray_var_iter(
                predictive_dataset.isel(coords),
                var_names=pp_var_names,
                skip_dims=set(flatten_pp),
                combined=True,
            ),
        )
    ]
    rows, cols = default_grid(length_plotters, grid=grid)

    bpvplot_kwargs = dict(
        ax=ax,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        obs_plotters=obs_plotters,
        pp_plotters=pp_plotters,
        total_pp_samples=total_pp_samples,
        kind=kind,
        bpv=bpv,
        t_stat=t_stat,
        reference=reference,
        mse=mse,
        n_ref=n_ref,
        hdi_prob=hdi_prob,
        plot_mean=plot_mean,
        color=color,
        figsize=figsize,
        textsize=textsize,
        labeller=labeller,
        plot_ref_kwargs=plot_ref_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_bpv", "bpvplot", backend)
    axes = plot(**bpvplot_kwargs)
    return axes