arviz.psislw

arviz.psislw(log_weights, reff=1.0)

Pareto smoothed importance sampling (PSIS).
- Parameters
- log_weights: array
Array of shape (n_observations, n_samples).
- reff: float
Relative MCMC efficiency, ess / n.
- Returns
- lw_out: array
Smoothed log weights.
- kss: array
Pareto tail indices (estimated shape parameters k), one per observation.
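The docstring does not spell out what the smoothing step does. As a rough guide, here is a minimal single-observation sketch in NumPy, modeled on the procedure described in Vehtari et al.: the largest M = ceil(min(0.2 S, 3 sqrt(S / reff))) weights (S samples) are treated as the tail, a generalized Pareto distribution (GPD) is fitted to them with a Zhang and Stephens style estimator, and the tail weights are replaced by quantiles of that fit, truncated at the largest raw weight and renormalized. The function names `psis_1d`, `gpd_fit` and `gpd_quantile`, the 1/3 smoothing threshold and the weak prior on k are assumptions made for illustration; this is not ArviZ's source, and its numbers may differ from `az.psislw`.

```python
import numpy as np


def gpd_fit(x):
    """Hypothetical helper: estimate GPD shape k and scale sigma from sorted,
    positive exceedances x (Zhang and Stephens style, sketch only)."""
    n = len(x)
    n_grid = 30 + int(np.sqrt(n))
    # Grid of candidate values for b = -k / sigma, all below 1 / max(x)
    b = 1 - np.sqrt(n_grid / (np.arange(1, n_grid + 1) - 0.5))
    b /= 3 * x[int(n / 4 + 0.5) - 1]
    b += 1 / x[-1]
    # Profile estimate of k for every candidate b
    k = np.log1p(-b[:, None] * x).mean(axis=1)
    # Profile log-likelihood of each candidate, converted to normalized weights
    loglik = n * (np.log(-b / k) - k - 1)
    w = 1 / np.exp(loglik - loglik[:, None]).sum(axis=1)
    w /= w.sum()
    b_post = np.sum(b * w)
    k_post = np.log1p(-b_post * x).mean()
    sigma = -k_post / b_post
    # Weak prior pulling k toward 0.5 (assumed weight: 10 pseudo-observations)
    k_post = (n * k_post + 10 * 0.5) / (n + 10)
    return k_post, sigma


def gpd_quantile(p, k, sigma):
    """Hypothetical helper: inverse CDF of a GPD with location 0."""
    if abs(k) < np.finfo(float).eps:
        return -sigma * np.log1p(-p)
    return sigma * np.expm1(-k * np.log1p(-p)) / k


def psis_1d(log_w, reff=1.0):
    """Hypothetical sketch: smooth the log weights of one observation."""
    x = np.asarray(log_w, dtype=float).copy()
    x -= x.max()  # shift so the largest raw weight is exp(0) = 1
    n = x.size
    tail_len = int(np.ceil(min(0.2 * n, 3 * np.sqrt(n / reff))))
    order = np.argsort(x)
    cutoff = max(x[order[-tail_len - 1]], np.log(np.finfo(float).tiny))
    tail_idx = np.flatnonzero(x > cutoff)
    n_tail = tail_idx.size
    if n_tail <= 4:
        k = np.inf  # too few tail draws to fit a GPD
    else:
        tail_order = np.argsort(x[tail_idx])
        # Sorted exceedances of the tail weights over the cutoff weight
        exceed = np.exp(x[tail_idx][tail_order]) - np.exp(cutoff)
        k, sigma = gpd_fit(exceed)
        if k >= 1 / 3:  # assumed threshold below which no smoothing is done
            probs = (np.arange(1, n_tail + 1) - 0.5) / n_tail
            smoothed = np.log(gpd_quantile(probs, k, sigma) + np.exp(cutoff))
            x[tail_idx[tail_order]] = smoothed
            x[x > 0] = 0  # truncate at the largest raw weight
    # Normalize in log space so that the weights sum to one
    x_max = x.max()
    x -= x_max + np.log(np.exp(x - x_max).sum())
    return x, k
```

Calling `psis_1d` on each row of an (n_observations, n_samples) array gives per-observation smoothed log weights and k values, which loosely mirrors the two returned arrays `lw_out` and `kss`.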
Notes
If the log_weights input is a DataArray with a dimension named __sample__ (recommended), psislw will interpret this dimension as samples and all other dimensions as dimensions of the observed data, looping over them to compute the smoothed log weights of each observation. If no __sample__ dimension is present or the input is a NumPy array, the last dimension will be interpreted as __sample__.
References
Vehtari et al. (2015). Pareto smoothed importance sampling. https://arxiv.org/abs/1507.02646
Examples
Get Pareto smoothed importance sampling (PSIS) log weights:
In [1]: import arviz as az
   ...: data = az.load_arviz_data("centered_eight")
   ...: log_likelihood = data.sample_stats.log_likelihood.stack(
   ...:     __sample__=("chain", "draw")
   ...: )
   ...: az.psislw(-log_likelihood, reff=0.8)
   ...:
Out[1]:
(<xarray.DataArray 'log_weights' (school: 8, __sample__: 2000)>
 array([[-7.35137283, -6.31188687, -7.41236566, ..., -8.46808714,
         -6.96348725, -7.14365799],
        [-6.42972359, -7.55712511, -7.20082771, ..., -7.72636166,
         -7.07301723, -7.48221468],
        [-7.76454863, -7.76032659, -7.76434977, ..., -7.64950566,
         -6.94343535, -7.54502347],
        ...,
        [-7.3256938 , -7.75981448, -7.70411157, ..., -7.73833836,
         -7.61736749, -7.74297445],
        [-7.1045607 , -6.2092645 , -6.43506702, ..., -6.96679107,
         -8.67364049, -8.42241583],
        [-7.19870332, -7.47992943, -7.48333094, ..., -7.48056398,
         -7.64676005, -7.65892997]])
 Coordinates:
   * school      (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
   * __sample__  (__sample__) MultiIndex
   - chain       (__sample__) int64 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
   - draw        (__sample__) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499,
 <xarray.DataArray 'pareto_shape' (school: 8)>
 array([0.35683758, 0.32524967, 0.53342172, 0.33519276, 0.25373991,
        0.64466503, 0.71238247, 0.28943932])
 Coordinates:
   * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon')
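When the input is a plain NumPy array, psislw reads the last axis as the sample axis. Pointwise log-likelihood values from MCMC are often stored as (chain, draw, observation), so they need rearranging first. A minimal sketch, assuming that layout; the array below is random placeholder data, not the centered_eight values:

```python
import numpy as np

rng = np.random.default_rng(0)
n_chain, n_draw, n_obs = 4, 500, 8

# Placeholder pointwise log-likelihood with assumed layout (chain, draw, observation)
log_lik = rng.normal(loc=-4.0, scale=0.5, size=(n_chain, n_draw, n_obs))

# Move the observation axis to the front, then merge chain and draw into one
# sample axis. C-order flattening lists all draws of chain 0 first, then chain 1,
# and so on, matching the ("chain", "draw") stacking in the example above.
log_lik_2d = np.moveaxis(log_lik, -1, 0).reshape(n_obs, n_chain * n_draw)

# As in the example above, the log weights are the negated log-likelihood
log_weights = -log_lik_2d

# Shape is now (n_observations, n_samples), ready for az.psislw(log_weights, reff=...)
print(log_weights.shape)  # (8, 2000)
```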