Source code for arviz.stats.stats_refitting

"""Stats functions that require refitting the model."""

import logging
import warnings

import numpy as np

from .stats import loo
from .stats_utils import logsumexp as _logsumexp

__all__ = ["reloo"]

_log = logging.getLogger(__name__)


def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True):
    """Recalculate exact Leave-One-Out cross validation refitting where the approximation fails.

    ``az.loo`` estimates the values of Leave-One-Out (LOO) cross validation using Pareto
    Smoothed Importance Sampling (PSIS) to approximate its value. PSIS works well when
    the posterior and the posterior_i (excluding observation i from the data used to fit)
    are similar. In some cases, there are highly influential observations for which PSIS
    cannot approximate the LOO-CV, and a warning of a large Pareto shape is sent by ArviZ.
    These cases typically have a handful of bad or very bad Pareto shapes and a majority of
    good or ok shapes. Therefore, this may not indicate that the model is not robust enough
    nor that these observations are inherently bad, only that PSIS cannot approximate
    LOO-CV correctly. Thus, we can use PSIS for all observations where the Pareto shape is
    below a threshold and refit the model to perform exact cross validation for the handful
    of observations where PSIS cannot be used. This approach allows us to properly
    approximate LOO-CV with only a handful of refits, which in most cases is still much
    less computationally expensive than exact LOO-CV, which needs one refit per observation.

    Parameters
    ----------
    wrapper : SamplingWrapper-like
        Class (preferably a subclass of ``az.SamplingWrapper``, see :ref:`wrappers_api`
        for details) implementing the methods described in the SamplingWrapper docs. This
        allows ArviZ to call **any** sampling backend (like PyStan or emcee) using always
        the same syntax.
    loo_orig : ELPDData, optional
        ELPDData instance with pointwise loo results. The pareto_k attribute will be
        checked for values above the threshold.
    k_thresh : float, optional
        Pareto shape threshold. Each pareto shape value above ``k_thresh`` will trigger
        a refit excluding that observation.
    scale : str, optional
        Only taken into account when loo_orig is None. See ``az.loo`` for valid options.

    Returns
    -------
    ELPDData
        ELPDData instance containing the PSIS approximation where possible and the exact
        LOO-CV result where PSIS failed. The Pareto shape of the observations where exact
        LOO-CV was performed is artificially set to 0, but as PSIS is not performed, it
        should be ignored.

    Notes
    -----
    It is strongly recommended to first compute ``az.loo`` on the inference results to
    confirm that the number of values above the threshold is small enough. Otherwise,
    prohibitive computation time may be needed to perform all required refits. As an
    extreme case, artificially assigning all ``pareto_k`` values to something larger than
    the threshold would make ``reloo`` perform the whole exact LOO-CV. This is not
    generally recommended nor intended; however, if needed, this function can be used to
    achieve the result.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """
    required_methods = ("sel_observations", "sample", "get_inference_data", "log_likelihood__i")
    not_implemented = wrapper.check_implemented_methods(required_methods)
    if not_implemented:
        raise TypeError(
            "Passed wrapper instance does not implement all methods required for reloo "
            f"to work. Check the documentation of SamplingWrapper. {not_implemented} must be "
            "implemented and were not found."
        )
    if loo_orig is None:
        loo_orig = loo(wrapper.idata_orig, pointwise=True, scale=scale)
    loo_refitted = loo_orig.copy()
    khats = loo_refitted.pareto_k
    loo_i = loo_refitted.loo_i
    scale = loo_orig.scale

    if scale.lower() == "deviance":
        scale_value = -2
    elif scale.lower() == "log":
        scale_value = 1
    elif scale.lower() == "negative_log":
        scale_value = -1
    lppd_orig = loo_orig.p_loo + loo_orig.elpd_loo / scale_value
    n_data_points = loo_orig.n_data_points

    if verbose:
        warnings.warn("reloo is an experimental and untested feature", UserWarning)

    if np.any(khats > k_thresh):
        for idx in np.argwhere(khats.values > k_thresh):
            if verbose:
                _log.info("Refitting model excluding observation %d", idx)
            new_obs, excluded_obs = wrapper.sel_observations(idx)
            fit = wrapper.sample(new_obs)
            idata_idx = wrapper.get_inference_data(fit)
            log_like_idx = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten()
            loo_lppd_idx = scale_value * _logsumexp(log_like_idx, b_inv=len(log_like_idx))
            khats[idx] = 0
            loo_i[idx] = loo_lppd_idx
        loo_refitted.elpd_loo = loo_i.values.sum()
        loo_refitted.se = (n_data_points * np.var(loo_i.values)) ** 0.5
        loo_refitted.p_loo = lppd_orig - loo_refitted.elpd_loo / scale_value
        return loo_refitted
    else:
        _log.info("No problematic observations")
        return loo_orig
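The core per-observation computation above is ``scale_value * _logsumexp(log_like_idx, b_inv=len(log_like_idx))``: with ``b_inv`` set to the number of posterior draws, this evaluates the log of the *mean* likelihood of the held-out observation over the refitted posterior, i.e. ``log(mean(exp(log_like)))``, computed stably. A minimal self-contained sketch of that quantity, using a plain NumPy log-sum-exp instead of ArviZ's internal ``_logsumexp`` (the function name ``exact_elpd_i`` is illustrative, not part of ArviZ):

```python
import numpy as np


def exact_elpd_i(log_like_i, scale_value=1):
    """Exact pointwise elpd for one held-out observation.

    log_like_i: 1-D array of log p(y_i | theta_s) evaluated at posterior
    draws theta_s from a refit that excluded observation i. The max-shift
    makes the exponentiation numerically stable; dividing by the number of
    draws inside the log mirrors the ``b_inv=len(log_like_idx)`` argument
    in the ArviZ source above.
    """
    log_like_i = np.asarray(log_like_i)
    max_ll = np.max(log_like_i)
    log_mean_lik = max_ll + np.log(np.mean(np.exp(log_like_i - max_ll)))
    return scale_value * log_mean_lik


# Example: simulated pointwise log-likelihood over 4000 posterior draws.
rng = np.random.default_rng(0)
log_like = rng.normal(-1.0, 0.1, size=4000)
elpd_i_log = exact_elpd_i(log_like)            # "log" scale (scale_value = 1)
elpd_i_dev = exact_elpd_i(log_like, scale_value=-2)  # "deviance" scale
```

The same ``scale_value`` (1, -1, or -2) that ``reloo`` derives from ``loo_orig.scale`` is applied here, so the exact values slot directly into ``loo_i`` alongside the PSIS-approximated ones.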