arviz.psislw
- arviz.psislw(log_weights, reff=1.0)
Pareto smoothed importance sampling (PSIS).
- Parameters
- log_weights: array
Array of shape (n_observations, n_samples).
- reff: float
Relative MCMC efficiency, i.e. the effective sample size divided by the total number of draws (ess / n).
- Returns
- lw_out: array
Smoothed log weights
- kss: array
Estimated Pareto shape parameters (k), one per observation.
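The Parameters and Returns above describe one smoothing pass per observation. The following is a minimal pure-Python sketch of that pass for a single observation, assuming the procedure of Vehtari et al.: take a right tail of M = ceil(min(0.2 S, 3 sqrt(S / reff))) draws, fit a generalized Pareto distribution to the tail exceedances (here with a Zhang and Stephens style estimator), replace the tail by the fitted expected order statistics, truncate at the largest raw weight, and renormalize. The names `psis_1d`, `gpd_fit`, `gpd_quantile` and `log_sum_exp` are hypothetical, and the tail-size rule and prior constants are assumptions written to resemble ArviZ's behaviour, not a copy of its source.

```python
import math
import sys


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def gpd_fit(x):
    """Generalized Pareto fit to sorted (ascending), positive exceedances x.

    Zhang & Stephens style estimator plus a weak prior that pulls k toward
    0.5 (constants assumed). Returns (k, sigma).
    """
    n = len(x)
    prior_bs, prior_k = 3, 10
    m = 30 + int(math.sqrt(n))
    quartile = x[int(n / 4 + 0.5) - 1]
    # Grid of candidate theta = -k / sigma values, all below 1 / max(x).
    thetas = [1 / x[-1] + (1 - math.sqrt(m / (j - 0.5))) / (prior_bs * quartile)
              for j in range(1, m + 1)]
    # Profile estimate of k and profile log-likelihood for each candidate.
    ks = [sum(math.log1p(-t * xi) for xi in x) / n for t in thetas]
    ll = [n * (math.log(-t / k) - k - 1) for t, k in zip(thetas, ks)]
    # Likelihood-weighted mean of theta, dropping negligible weights.
    norm = log_sum_exp(ll)
    pairs = [(math.exp(l - norm), t) for l, t in zip(ll, thetas)]
    pairs = [(w, t) for w, t in pairs if w >= 10 * sys.float_info.epsilon]
    total = sum(w for w, _ in pairs)
    theta = sum(w * t for w, t in pairs) / total
    k = sum(math.log1p(-theta * xi) for xi in x) / n
    sigma = -k / theta
    k = (n * k + prior_k * 0.5) / (n + prior_k)  # shrink k toward 0.5
    return k, sigma


def gpd_quantile(p, k, sigma):
    """Inverse CDF of a generalized Pareto distribution with location 0."""
    if abs(k) < sys.float_info.epsilon:
        return -sigma * math.log1p(-p)  # exponential limit as k -> 0
    return sigma * math.expm1(-k * math.log1p(-p)) / k


def psis_1d(log_weights, reff=1.0):
    """Pareto-smooth one observation's log weights; returns (log_weights, k)."""
    n = len(log_weights)
    top = max(log_weights)
    lw = [v - top for v in log_weights]  # largest raw log weight becomes 0
    tail_len = math.ceil(min(0.2 * n, 3 * math.sqrt(n / reff)))
    order = sorted(range(n), key=lw.__getitem__)  # ascending
    cutoff = max(lw[order[-tail_len - 1]], math.log(sys.float_info.min))
    tail_idx = [i for i in order if lw[i] > cutoff]  # tail draws, ascending
    if len(tail_idx) <= 4:
        k = math.inf  # too few tail draws to fit a Pareto tail
    else:
        exp_cut = math.exp(cutoff)
        k, sigma = gpd_fit([math.exp(lw[i]) - exp_cut for i in tail_idx])
        if math.isfinite(k):
            m = len(tail_idx)
            for rank, i in enumerate(tail_idx):
                q = gpd_quantile((rank + 0.5) / m, k, sigma)
                # Expected order statistic, truncated at the max raw weight (0).
                lw[i] = min(math.log(q + exp_cut), 0.0)
    norm = log_sum_exp(lw)
    return [v - norm for v in lw], k
```

Calling `psis_1d(row, reff=0.8)` on one row of log weights returns the smoothed, normalized log weights and the k estimate for that row, analogous to one entry of lw_out and kss.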
See also
loo
Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
Notes
If the log_weights input is a DataArray with a dimension named __sample__ (recommended), psislw will interpret this dimension as samples and all other dimensions as dimensions of the observed data, looping over them to compute the smoothed log weights and Pareto shape of each observation. If no __sample__ dimension is present or the input is a numpy array, the last dimension will be interpreted as __sample__.
References
Vehtari et al. (2015), see https://arxiv.org/abs/1507.02646
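The looping behaviour described in the Notes (last axis taken as samples, every other axis treated as an observation dimension) can be sketched in pure Python with nested lists. `apply_over_observations` and `log_normalize` are hypothetical names, and `log_normalize` is only a stand-in that self-normalizes each observation's log weights without any Pareto smoothing.

```python
import math


def log_normalize(log_weights):
    """Stand-in per-observation step: self-normalize log weights (no smoothing)."""
    top = max(log_weights)
    total = top + math.log(sum(math.exp(v - top) for v in log_weights))
    return [v - total for v in log_weights]


def apply_over_observations(nested, per_observation):
    """Apply per_observation along the last axis of a nested list.

    Every outer level is treated as an observation dimension; the innermost
    lists are treated as the sample dimension, mirroring the numpy fallback
    described in the Notes.
    """
    if nested and isinstance(nested[0], list):
        return [apply_over_observations(sub, per_observation) for sub in nested]
    return per_observation(nested)


# Hypothetical log weights with shape (2 groups, 2 observations, 3 samples).
log_w = [
    [[0.0, 1.0, 2.0], [-1.0, -1.0, -1.0]],
    [[5.0, 0.0, 0.0], [0.5, 0.25, 0.0]],
]
result = apply_over_observations(log_w, log_normalize)
```

Swapping `log_normalize` for a real per-observation smoothing step gives the shape-preserving loop the Notes describe; a DataArray with a __sample__ dimension makes the sample axis explicit instead of relying on position.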
Examples
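Get Pareto smoothed importance sampling (PSIS) log weights for the centered_eight example. The log weights passed in are the negative pointwise log likelihood, because the leave-one-out importance ratios are 1 / p(y_i | θ):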
    In [1]: import arviz as az
       ...: data = az.load_arviz_data("centered_eight")
       ...: log_likelihood = data.sample_stats.log_likelihood.stack(
       ...:     __sample__=("chain", "draw")
       ...: )
       ...: az.psislw(-log_likelihood, reff=0.8)
       ...:
    Out[1]:
    (<xarray.DataArray 'log_weights' (school: 8, __sample__: 2000)>
     array([[-7.34995066, -6.32418143, -7.41094349, ..., -8.46666497,
             -6.96363409, -7.14223582],
            [-6.44840009, -7.55679071, -7.18955445, ..., -7.72602726,
             -7.07920345, -7.48188027],
            [-7.76466076, -7.76043872, -7.7644619 , ..., -7.64961779,
             -6.9420515 , -7.5451356 ],
            ...,
            [-7.32584608, -7.76025117, -7.70454826, ..., -7.73877505,
             -7.61780418, -7.74341114],
            [-7.10727871, -6.20091776, -6.42891471, ..., -6.96829636,
             -8.6763585 , -8.42513384],
            [-7.19682912, -7.48004722, -7.48344873, ..., -7.48068176,
             -7.64687783, -7.65904775]])
     Coordinates:
       * school      (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
       * __sample__  (__sample__) MultiIndex
       - chain       (__sample__) int64 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
       - draw        (__sample__) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499,
     <xarray.DataArray 'pareto_shape' (school: 8)>
     array([0.35683758, 0.32524967, 0.53342172, 0.33519276, 0.25373991,
            0.64466503, 0.71238247, 0.28943932])
     Coordinates:
       * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon')
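The second array in the output holds one Pareto shape estimate k per school. A common way to read it, assumed here, uses the cutoffs 0.5 and 0.7 from the PSIS-LOO literature: below 0.5 the importance sampling estimate is considered reliable, between 0.5 and 0.7 usually acceptable, and above 0.7 unreliable. The sketch below applies those assumed cutoffs to the values printed in the example; `classify_k` is a hypothetical helper, not an ArviZ function.

```python
# Pareto shape estimates copied from the example output (one per school).
pareto_k = [0.35683758, 0.32524967, 0.53342172, 0.33519276,
            0.25373991, 0.64466503, 0.71238247, 0.28943932]


def classify_k(k, good=0.5, ok=0.7):
    """Bucket a Pareto k estimate using conventional (assumed) cutoffs."""
    if k < good:
        return "good"
    if k < ok:
        return "ok"
    if k < 1.0:
        return "bad"
    return "very bad"


labels = [classify_k(k) for k in pareto_k]
flagged = [i for i, label in enumerate(labels) if label in ("bad", "very bad")]
print(labels)
print("observations with unreliable importance weights:", flagged)
```

With these cutoffs only index 6 is flagged, which the coordinate listing shows as the second-to-last school, St. Paul's.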