Source code for arviz_plots.plots.loo_interval_plot

"""Predictive intervals plot."""

from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal

import numpy as np
import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_base.validate import (
    validate_dict_argument,
    validate_or_use_rcparam,
    validate_sample_dims,
)
from arviz_stats.loo import loo_expectations

from arviz_plots.plots.ppc_interval_plot import _plot_interval
from arviz_plots.plots.utils import process_group_variables_coords


[docs] def plot_loo_interval( dt, *, var_names=None, filter_vars=None, group="posterior_predictive", coords=None, sample_dims=None, point_estimate=None, ci_kind=None, ci_probs=None, plot_collection=None, backend=None, labeller=None, aes_by_visuals: Mapping[ Literal[ "trunk", "twig", "observed_markers", "prediction_markers", "xlabel", "ylabel", "title", ], Sequence[str] | bool, ] = None, visuals: Mapping[ Literal[ "trunk", "twig", "observed_markers", "prediction_markers", "xlabel", "ylabel", "title", ], Mapping[str, Any] | bool, ] = None, stats: Mapping[ Literal["trunk", "twig", "point_estimate"], Mapping[str, Any] | xr.Dataset ] = None, **pc_kwargs, ): """Plot LOO posterior predictive intervals with observed data overlaid. Displays observed data as a point and LOO predicted data as a point estimate plus two credible intervals. Parameters ---------- dt : DataTree Input data. It should contain the ``posterior_predictive``, the ``log_likelihood`` and ``observed_data`` groups. var_names : str or list of str, optional One or more variables to be plotted. Prefix the variables by ~ when you want to exclude them from the plot. filter_vars : {None, "like", "regex"}, default=None If None, interpret var_names as the real variables names. If "like", interpret var_names as substrings of the real variables names. If "regex", interpret var_names as regular expressions on the real variables names. group : str Group to be plotted. Defaults to "posterior_predictive". coords : dict, optional Coordinates of `var_names` to be plotted. sample_dims : str or sequence of hashable, optional Dimensions to reduce unless mapped to an aesthetic. Defaults to ``rcParams["data.sample_dims"]`` point_estimate : {"mean", "median"}, optional Which point estimate to plot for the predictive distribution. Defaults to rcParam ``stats.point_estimate``. ci_probs : (float, float), optional Indicates the probabilities for the inner (twig) and outer (trunk) credible intervals. Defaults to ``(0.5, rcParams["stats.ci_prob"])``. It's assumed that ``ci_probs[0] < ci_probs[1]``, otherwise they are sorted. plot_collection : PlotCollection, optional backend : {"matplotlib", "bokeh", "plotly", "none"}, optional labeller : labeller, optional aes_by_visuals : mapping of {str : sequence of str or False}, optional Mapping of visuals to aesthetics that should use their mapping in `plot_collection` when plotted. visuals : mapping of {str : mapping or bool}, optional Valid keys are: * trunk, twig -> passed to :func:`~arviz_plots.visuals.ci_bound_y` * observed_markers -> passed to :func:`~arviz_plots.visuals.point_y` * prediction_markers -> passed to :func:`~arviz_plots.visuals.point_y` * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` defaults to False stats : mapping, optional Valid keys are: * trunk, twig -> passed to loo_expectations for the trunk and twig intervals, respectively * point_estimate -> passed to loo_expectations **pc_kwargs Passed to :class:`arviz_plots.PlotCollection.grid` Returns ------- PlotCollection See Also -------- plot_ppc_interval : Plot prior/posterior predictive intervals and observed data. plot_ppc_dist : Plot 1D marginals for the posterior/prior predictive and observed data. plot_ppc_rootogram : Plot ppc rootogram for discrete (count) data Examples -------- Plot LOO posterior predictive intervals for the radon dataset, with custom styling. .. plot:: :context: close-figs >>> from arviz_base import load_arviz_data >>> import arviz_plots as azp >>> azp.style.use("arviz-variat") >>> data = load_arviz_data("rugby") >>> pc = azp.plot_loo_interval(data) .. minigallery:: plot_loo_interval """ aes_by_visuals = validate_dict_argument(aes_by_visuals, (plot_loo_interval, "aes_by_visuals")) visuals = validate_dict_argument(visuals, (plot_loo_interval, "visuals")) stats = validate_dict_argument(stats, (plot_loo_interval, "stats")) ci_kind = validate_or_use_rcparam(ci_kind, "stats.ci_kind") point_estimate = validate_or_use_rcparam(point_estimate, "stats.point_estimate") if ci_probs is None: rc_ci_prob = rcParams["stats.ci_prob"] ci_probs = (0.5, rc_ci_prob) ci_probs = np.array(ci_probs) if ci_probs.size != 2: raise ValueError("ci_probs must have two elements for twig and trunk intervals.") if np.any(ci_probs < 0) or np.any(ci_probs > 1): raise ValueError("ci_probs must be between 0 and 1.") ci_probs.sort() if backend is None: if plot_collection is None: backend = rcParams["plot.backend"] else: backend = plot_collection.backend if labeller is None: labeller = BaseLabeller() visuals.setdefault("title", False) plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") ds_predictive = process_group_variables_coords( dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords ) sample_dims = validate_sample_dims(sample_dims, data=ds_predictive) # Extract observed data if "observed_data" in dt: observed_data = process_group_variables_coords( dt, group="observed_data", var_names=var_names, filter_vars=filter_vars, coords=coords, ) else: observed_data = None probs_trunk = [(1 - ci_probs[1]) / 2, (1 + ci_probs[1]) / 2] probs_twig = [(1 - ci_probs[0]) / 2, (1 + ci_probs[0]) / 2] ci_trunk = ( loo_expectations( dt, kind="quantile", probs=probs_trunk, sample_dims=sample_dims, **stats.get("trunk", {}), )[0] .rename({"quantile": "ci_bound"}) .assign_coords(ci_bound=["lower", "upper"]) ) ci_twig = ( loo_expectations( dt, kind="quantile", probs=probs_twig, sample_dims=sample_dims, **stats.get("twig", {}) )[0] .rename({"quantile": "ci_bound"}) .assign_coords(ci_bound=["lower", "upper"]) ) pe_stats = stats.get("point_estimate", {}) if point_estimate == "median": point, _ = loo_expectations(dt, kind="median", sample_dims=sample_dims, **pe_stats) elif point_estimate == "mean": point, _ = loo_expectations(dt, kind="mean", sample_dims=sample_dims, **pe_stats) else: raise ValueError( f"point_estimate must be 'mean' or 'median', but {point_estimate} was passed." ) return _plot_interval( ds_predictive, ci_trunk, ci_twig, point, observed_data, sample_dims, group, plot_collection, backend, pc_kwargs, plot_bknd, visuals, aes_by_visuals, labeller, )