Source code for arviz_plots.plots.loo_pit_plot

"""Plot loo pit."""
import warnings
from collections.abc import Mapping, Sequence
from typing import Any, Literal

import xarray as xr
from arviz_base import convert_to_datatree
from arviz_base.validate import validate_dict_argument
from arviz_stats.loo import loo_pit

from arviz_plots.plots.ecdf_plot import plot_ecdf_pit


def plot_loo_pit(
    dt,
    *,
    var_names=None,
    filter_vars=None,
    group="posterior_predictive",
    coords=None,
    sample_dims=None,
    method="pot_c",
    envelope_prob=None,
    coverage=False,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "ecdf_lines",
            "credible_interval",
            "xlabel",
            "ylabel",
            "title",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "ecdf_lines",
            "credible_interval",
            "xlabel",
            "ylabel",
            "title",
            "remove_axis",
        ],
        Mapping[str, Any] | bool,
    ] = None,
    stats: Mapping[Literal["ecdf_pit"], Mapping[str, Any] | xr.Dataset] = None,
    **pc_kwargs,
):
    r"""LOO-PIT Δ-ECDF uniformity test.

    For a calibrated model the LOO Probability Integral Transform (PIT) values,
    $p(\tilde{y}_i \le y_i \mid y_{-i})$, should be uniformly distributed.
    Where $y_i$ represents the observed data for index $i$ and $\tilde y_i$ represents
    the posterior predictive sample at index $i$. $y_{-i}$ indicates we have left out the
    $i$-th observation. LOO-PIT values are computed using the PSIS-LOO-CV method described
    in [1]_ and [2]_.

    This plot shows the empirical cumulative distribution function (ECDF) of the LOO-PIT values.
    To make the plot easier to interpret, we plot the Δ-ECDF, that is, the difference between
    the observed ECDF and the expected CDF. The points that contribute the most to deviations
    from uniformity are computed as described in [3]_ and highlighted in the plot.

    Alternatively, we can visualize the coverage of the central posterior credible intervals by
    setting ``coverage=True``. This allows us to assess whether the credible intervals include
    the observed values. We can obtain the coverage of the central intervals from the LOO-PIT by
    replacing the LOO-PIT with two times the absolute difference between the LOO-PIT values and
    0.5.

    For more details on how to interpret this plot,
    see https://arviz-devs.github.io/EABM/Chapters/Prior_posterior_predictive_checks.html#pit-ecdfs.

    Parameters
    ----------
    dt : DataTree
        Input data
    var_names : str or list of str, optional
        One or more variables to be plotted. Currently only one variable is supported.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If None (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior_predictive"
        Only "posterior_predictive" is supported, any other value raises a ``ValueError``.
    coords : dict, optional
        Coordinates to plot.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``.
        CURRENTLY NOT SUPPORTED: any value other than None is ignored and triggers a warning.
    method : {"pot_c", "prit_c", "piet_c", "envelope"}, optional
        Method to use for the uniformity test. Defaults to "pot_c".
        "envelope" is accepted but not recommended and triggers a warning.
        Check the documentation of :func:`~arviz_plots.plot_ecdf_pit` for more details.
    envelope_prob : float, optional
        Indicates the probability that should be contained within the envelope,
        which is also the probability threshold to highlight points.
        Defaults to ``rcParams["stats.envelope_prob"]``.
    coverage : bool, optional
        If True, plot the coverage of the central posterior credible intervals.
        Defaults to False.
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh", "plotly"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals` except for "remove_axis".
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * ecdf_lines -> passed to :func:`~arviz_plots.visuals.ecdf_line`
        * credible_interval -> passed to :func:`~arviz_plots.visuals.fill_between_y`
        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * remove_axis -> not passed anywhere, can only be a boolean to indicate
          whether to call this function. Defaults to ``False`` for plot_loo_pit

    stats : mapping, optional
        Valid keys are:

        * ecdf_pit -> passed to :func:`~xarray.Dataset.azstats.uniformity_test`.
          Default is ``{"gamma": 0}``.

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.grid`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `group` is not "posterior_predictive" or `method` is not one of the
        supported methods.

    See Also
    --------
    plot_ppc_pit : Predictive check using PIT Δ-ECDF uniformity test.
    plot_loo_interval : Predictive intervals and observed data.

    Examples
    --------
    Plot the ecdf-PIT for the radon dataset.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_loo_pit, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('radon')
        >>> plot_loo_pit(dt)

    Plot the coverage for the radon dataset.

    .. plot::
        :context: close-figs

        >>> plot_loo_pit(dt, coverage=True)

    .. minigallery:: plot_loo_pit

    References
    ----------
    .. [1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation
        and WAIC. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4
    .. [2] Vehtari et al. Pareto Smoothed Importance Sampling.
        Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
    .. [3] Tasso et al. *LOO-PIT predictive model checking* arXiv:2603.02928 (2026).
    """
    # Validate the plain arguments first so invalid calls fail before any other work is done.
    if group != "posterior_predictive":
        raise ValueError(f"Group {group} not supported. Only 'posterior_predictive' is supported.")

    # A tuple (not a set) keeps unhashable `method` values on the ValueError path.
    if method not in ("envelope", "pot_c", "prit_c", "piet_c"):
        raise ValueError(
            f"Method {method} not supported. Choose from 'envelope', 'pot_c', 'prit_c' or 'piet_c'."
        )

    if method == "envelope":
        warnings.warn(
            "Method 'envelope' is not recommended for the LOO-PIT plot, "
            "as it assumes PIT values are independent, which is not the case for LOO-PIT values.",
            UserWarning,
            stacklevel=2,
        )

    aes_by_visuals = validate_dict_argument(aes_by_visuals, (plot_loo_pit, "aes_by_visuals"))
    visuals = validate_dict_argument(visuals, (plot_loo_pit, "visuals"))
    stats = validate_dict_argument(stats, (plot_loo_pit, "stats"))

    # The user-provided `sample_dims` is never used: the dimensions of the LOO-PIT
    # result are forwarded to `plot_ecdf_pit` instead.
    if sample_dims is not None:
        warnings.warn(
            "'sample_dims' is currently not supported in plot_loo_pit and will be ignored",
            stacklevel=2,
        )

    # `pareto_pit` is only enabled for the "pot_c" and "piet_c" methods.
    pareto_pit = method in ("pot_c", "piet_c")
    lpv = loo_pit(dt, pareto_pit=pareto_pit)
    new_dt = convert_to_datatree(lpv, group="loo_pit")

    # plot_loo_pit specific defaults; user-provided entries take precedence.
    visuals.setdefault("ylabel", {})
    visuals.setdefault("remove_axis", False)
    visuals.setdefault("xlabel", {"text": "ETI %" if coverage else "LOO-PIT"})

    return plot_ecdf_pit(
        new_dt,
        var_names=var_names,
        filter_vars=filter_vars,
        group="loo_pit",
        coords=coords,
        sample_dims=lpv.dims,
        method=method,
        envelope_prob=envelope_prob,
        coverage=coverage,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=aes_by_visuals,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )