Source code for arviz_plots.plots.ppc_dist_pit_plot

"""Predictive check using densities and PIT Δ-ECDFs."""

from collections.abc import Mapping, Sequence
from typing import Any, Literal

import xarray as xr
from arviz_base.labels import BaseLabeller
from arviz_base.validate import validate_dict_argument, validate_or_use_rcparam

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.ecdf_plot import plot_ecdf_pit
from arviz_plots.plots.utils import filter_aes, get_visual_kwargs, set_grid_layout
from arviz_plots.plots.utils_ppc import get_ppc_pit, get_suspicious_mask_ds, prepare_ppc_dist_data
from arviz_plots.visuals import trace_rug


[docs] def plot_ppc_dist_pit( dt, *, var_names=None, filter_vars=None, group="posterior_predictive", coords=None, sample_dims=None, kind=None, num_samples=50, method="pot_c", envelope_prob=None, coverage=False, plot_collection=None, backend=None, labeller=None, aes_by_visuals: Mapping[ Literal[ "predictive_dist", "observed_dist", "ecdf_lines", "credible_interval", "suspicious_points", "p_value_text", "title", ], Sequence[str], ] = None, visuals: Mapping[ Literal[ "predictive_dist", "observed_dist", "ecdf_lines", "credible_interval", "suspicious_points", "p_value_text", "xlabel_dist", "xlabel_pit", "ylabel", "title", "remove_axis", ], Mapping[str, Any] | bool, ] = None, stats: Mapping[ Literal["predictive_dist", "observed_dist", "ecdf_pit"], Mapping[str, Any] | xr.Dataset ] = None, **pc_kwargs, ): """1D marginals for the predictive distribution and PIT Δ-ECDF. The left column shows 1D marginals for the posterior predictive distribution overlaid on the observed data, identical to :func:`~arviz_plots.plot_ppc_dist`. The right column shows the empirical CDF (ECDF) of the PIT values minus the expected CDF, identical to :func:`~arviz_plots.plot_ppc_pit`. Suspicious observations are computed from the uniformity test and they are highlighted in both columns, either as rug marks at y=0 in the dist column or as points in ECDF for the PIT column. The suspicious observations are the ones that contribute the most to deviations from uniformity. Parameters ---------- dt : DataTree Input data with ``posterior_predictive`` and ``observed_data`` groups. var_names : str or list of str, optional Variables to plot. filter_vars : {None, "like", "regex"}, optional group : str, Group to be plotted. Defaults to "posterior_predictive". It could also be "prior_predictive". coords : dict, optional sample_dims : str or sequence of hashable, optional Defaults to ``rcParams["data.sample_dims"]``. kind : {"auto", "kde", "hist", "ecdf", "dot"}, optional Density kind for the dist column. Defaults to ``rcParams["plot.density_kind"]``. num_samples : int, default 50 Number of predictive draws to overlay in the dist column. method : {"pot_c", "prit_c", "piet_c", "envelope"}, default "pot_c" Uniformity-test method for the PIT column. envelope_prob : float, optional Probability inside the simultaneous envelope. Defaults to ``rcParams["stats.envelope_prob"]``. coverage : bool, default False If True, replace PIT with ``2|PIT - 0.5|`` to assess ETI coverage. plot_collection : PlotCollection, optional backend : {"matplotlib", "bokeh", "plotly"}, optional labeller : labeller, optional aes_by_visuals : mapping, optional Valid keys: ``predictive_dist``, ``observed_dist``, ``ecdf_lines``, ``credible_interval``, ``suspicious_points``, ``p_value_text``, ``title``. visuals : mapping, optional Valid keys: * predictive_dist -> density lines for predictive draws * observed_dist -> density line for observed data * ecdf_lines -> passed to :func:`~arviz_plots.visuals.ecdf_line` * credible_interval -> only when ``method="envelope"`` * suspicious_points -> passed to :func:`~arviz_plots.visuals.scatter_xy` * p_value_text -> passed to :func:`~arviz_plots.visuals.annotate_xy` * xlabel_dist -> x-axis label for the dist column * xlabel_pit -> x-axis label for the PIT column * ylabel -> y-axis label for the PIT column * title -> passed to :func:`~arviz_plots.visuals.labelled_title` * remove_axis -> set to ``False`` to skip axis removal stats : mapping, optional Valid keys: ``predictive_dist``, ``observed_dist``, ``ecdf_pit``. **pc_kwargs Passed to :class:`~arviz_plots.PlotCollection.grid`. Returns ------- PlotCollection See Also -------- plot_ppc_dist : Predictive density check only. plot_ppc_pit : PIT Δ-ECDF check only. Examples -------- .. plot:: :context: close-figs >>> from arviz_plots import plot_ppc_dist_pit, style >>> style.use("arviz-variat") >>> from arviz_base import load_arviz_data >>> dt = load_arviz_data('radon') >>> plot_ppc_dist_pit(dt) .. minigallery:: plot_ppc_dist_pit """ envelope_prob = validate_or_use_rcparam(envelope_prob, "stats.envelope_prob") aes_by_visuals = validate_dict_argument(aes_by_visuals, (plot_ppc_dist_pit, "aes_by_visuals")) visuals = validate_dict_argument(visuals, (plot_ppc_dist_pit, "visuals")) stats = validate_dict_argument(stats, (plot_ppc_dist_pit, "stats")) gamma = stats.get("ecdf_pit", {}).get("gamma", 0) alpha = 1 - envelope_prob if method not in {"envelope", "pot_c", "prit_c", "piet_c"}: raise ValueError( f"Method {method!r} not supported. " "Choose from 'envelope', 'pot_c', 'prit_c' or 'piet_c'." ) plot_bknd, pp_dims, sample_dims, predictive_dist, predictive_dist_sub, observed_dist = ( prepare_ppc_dist_data( dt, var_names=var_names, filter_vars=filter_vars, group=group, coords=coords, sample_dims=sample_dims, kind=kind, num_samples=num_samples, plot_collection=plot_collection, backend=backend, stats=stats, require_observed=True, ) ) pit_dt = get_ppc_pit(predictive_dist, observed_dist, sample_dims, coverage, method) pit_dims = pit_dt.ecdf_pit.dims if plot_collection is None: pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy() pc_kwargs["figure_kwargs"].setdefault("sharex", False) pc_kwargs["figure_kwargs"].setdefault("sharey", False) pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() pc_kwargs["aes"].setdefault("overlay_ppc", ["sample"]) pc_kwargs.setdefault("cols", ["column"]) pc_kwargs.setdefault("rows", "__variable__") pc_kwargs = set_grid_layout(pc_kwargs, plot_bknd, predictive_dist_sub, num_cols=2) plot_collection = PlotCollection.grid( predictive_dist_sub.expand_dims(column=2).assign_coords(column=["dist", "pit"]), backend=backend, **pc_kwargs, ) if labeller is None: labeller = BaseLabeller() # predictive density pred_density_kwargs = get_visual_kwargs(visuals, "predictive_dist") if pred_density_kwargs is not False: pred_density_kwargs.setdefault("alpha", 0.3) pred_visuals = { "dist": pred_density_kwargs, "credible_interval": False, "point_estimate": False, "point_estimate_text": False, "title": visuals.get("title", {}), "rug": False, "remove_axis": visuals.get("remove_axis", {}), } pred_aes_by_visuals = { k.replace("predictive_", ""): v for k, v in aes_by_visuals.items() if k != "observed_dist" } plot_collection.coords = {"column": "dist"} plot_collection = plot_dist( predictive_dist_sub, group=group, sample_dims=pp_dims, kind=kind, visuals=pred_visuals, aes_by_visuals=pred_aes_by_visuals, pc_kwargs=pc_kwargs, plot_collection=plot_collection, stats={"dist": stats.get("predictive_dist", {})}, ) plot_collection.coords = None plot_collection.rename_visuals(dist="predictive_dist") # observed density observed_density_kwargs = get_visual_kwargs( visuals, "observed_dist", False if group == "prior_predictive" else None ) if observed_density_kwargs is not False: observed_density_kwargs.setdefault("color", "B1") observed_visuals = { "dist": observed_density_kwargs, "credible_interval": False, "point_estimate": False, "point_estimate_text": False, "title": False, "rug": False, "remove_axis": False, } obs_aes_by_visuals = ( {"dist": aes_by_visuals["observed_dist"]} if "observed_dist" in aes_by_visuals else {} ) plot_collection.coords = {"column": "dist"} plot_collection = plot_dist( observed_dist, group="observed_data", sample_dims=pp_dims, kind=kind, visuals=observed_visuals, aes_by_visuals=obs_aes_by_visuals, plot_collection=plot_collection, stats={"dist": stats.get("observed_dist", {})}, ) plot_collection.coords = None plot_collection.rename_visuals(dist="observed_dist") # Plot suspicious obs from uniformity test rug_kwargs = get_visual_kwargs(visuals, "suspicious_points") if rug_kwargs is not False and method != "envelope": suspicious_mask_ds = get_suspicious_mask_ds(observed_dist, pit_dt, alpha, gamma, method) rug_kwargs.setdefault("color", "C1") rug_kwargs.setdefault("marker", "|") _, _, rug_ignore = filter_aes( plot_collection, aes_by_visuals, "suspicious_points", sample_dims ) plot_collection.map( trace_rug, "suspicious_points", data=observed_dist, mask=suspicious_mask_ds, ignore_aes=rug_ignore, xname=False, y=0, coords={"column": "dist"}, **rug_kwargs, ) pit_visuals = { "ylabel": visuals.get("ylabel", {}), "remove_axis": False, "xlabel": visuals.get("xlabel_pit", {"text": "ETI %" if coverage else "PIT"}), "title": visuals.get("title", {}), } for key in ("ecdf_lines", "credible_interval", "suspicious_points", "p_value_text"): if key in visuals: pit_visuals[key] = visuals[key] pit_aes_by_visuals = { k: v for k, v in aes_by_visuals.items() if k in ("ecdf_lines", "credible_interval", "suspicious_points", "p_value_text") } plot_collection.coords = {"column": "pit"} plot_collection = plot_ecdf_pit( pit_dt, var_names=var_names, filter_vars=filter_vars, group="ecdf_pit", coords=coords, sample_dims=pit_dims, method=method, envelope_prob=envelope_prob, coverage=coverage, plot_collection=plot_collection, backend=backend, labeller=labeller, aes_by_visuals=pit_aes_by_visuals, visuals=pit_visuals, stats={"ecdf_pit": stats.get("ecdf_pit", {})}, ) plot_collection.coords = None return plot_collection