arviz.psislw

arviz.psislw(log_weights, reff=1.0)

Pareto smoothed importance sampling (PSIS).

Parameters:
log_weights : DataArray or (…, N) array_like

Array of size (n_observations, n_samples)

reff : float, default 1

Relative MCMC efficiency, ess / n (effective sample size divided by the number of samples).

Returns:
lw_out : DataArray or (…, N) ndarray

Smoothed, truncated and normalized log weights.

kss : DataArray or (…) ndarray

Estimates of the shape parameter k of the generalized Pareto distribution.

See also

loo

Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).

Notes

If the log_weights input is a DataArray with a dimension named __sample__ (recommended), psislw interprets that dimension as the samples and all other dimensions as dimensions of the observed data, looping over them to compute the PSIS log weights of each observation. If no __sample__ dimension is present, or the input is a NumPy array, the last dimension is interpreted as __sample__.
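The NumPy case described above can be sketched as follows. This is a minimal illustration, not part of the official examples: the log-likelihood values are synthetic, the names `log_lik`, `lw` and `k` are hypothetical, and it assumes an ArviZ version that exposes `az.psislw` with the signature documented here.

```python
import numpy as np
import arviz as az

# Synthetic (made-up) pointwise log-likelihood values:
# 3 observations along the first axis, 400 posterior samples along the last axis.
rng = np.random.default_rng(0)
log_lik = rng.normal(loc=-1.0, scale=0.5, size=(3, 400))

# With a plain NumPy array there is no __sample__ dimension,
# so the last axis is treated as the sample axis.
lw, k = az.psislw(-log_lik, reff=1.0)

print(lw.shape)  # same shape as the input: one row of log weights per observation
print(k.shape)   # one Pareto shape estimate per observation
```

Putting the samples on the last axis matters here: passing the transposed array would make psislw treat each posterior sample as an observation.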

References

  • Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine Learning Research, 25(72):1-58.

Examples

Get Pareto smoothed importance sampling (PSIS) log weights:

In [1]: import arviz as az
   ...: data = az.load_arviz_data("non_centered_eight")
   ...: log_likelihood = data.log_likelihood["obs"].stack(
   ...:     __sample__=["chain", "draw"]
   ...: )
   ...: az.psislw(-log_likelihood, reff=0.8)
   ...: 
Out[1]: 
(<xarray.DataArray 'log_weights' (obs_dim_0: 8, __sample__: 2000)> Size: 128kB
 array([[-7.534175  , -7.69549252, -7.58716519, ..., -8.0273235 ,
         -7.71899444, -7.62608193],
        [-7.71657247, -7.82465868, -7.82375075, ..., -7.79389977,
         -7.57652873, -7.77960084],
        [-7.6031945 , -7.58032753, -7.64441205, ..., -7.68993465,
         -7.66796043, -7.64657297],
        ...,
        [-7.24915492, -7.68576857, -7.64220523, ..., -7.64873977,
         -7.76098937, -7.7035236 ],
        [-7.93253675, -7.99252542, -7.9338617 , ..., -8.05837208,
         -7.013536  , -7.73175308],
        [-7.63099257, -7.71489387, -7.64352179, ..., -7.62964354,
         -7.53068102, -7.67534435]])
 Coordinates:
   * obs_dim_0   (obs_dim_0) int64 64B 0 1 2 3 4 5 6 7
   * __sample__  (__sample__) object 16kB MultiIndex
   * chain       (__sample__) int64 16kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3
   * draw        (__sample__) int64 16kB 0 1 2 3 4 5 ... 494 495 496 497 498 499,
 <xarray.DataArray 'pareto_shape' (obs_dim_0: 8)> Size: 64B
 array([0.3397522 , 0.78348601, 0.44186589, 0.57020058, 0.40413962,
        0.53690878, 0.66410432, 0.5290121 ])
 Coordinates:
   * obs_dim_0  (obs_dim_0) int64 64B 0 1 2 3 4 5 6 7)
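
The returned log weights are described above as normalized. A minimal sketch of what that means in practice, again on synthetic data and with hypothetical variable names (`fake_log_lik`, `log_w`, `pareto_k`, `weight_sums`), is to exponentiate them and check that the weights of each observation sum to one:

```python
import numpy as np
import arviz as az

# Synthetic (made-up) log-likelihood values: 4 observations, 500 samples.
rng = np.random.default_rng(1)
fake_log_lik = rng.normal(loc=-1.0, scale=0.5, size=(4, 500))

log_w, pareto_k = az.psislw(-fake_log_lik)

# Back on the natural scale, the weights of each observation
# should sum to 1 along the sample axis (up to floating-point error).
weight_sums = np.exp(log_w).sum(axis=-1)
print(np.allclose(weight_sums, 1.0))
```

This also explains the magnitude of the values in the output above: with 2000 samples and fairly even weights, each log weight is close to log(1/2000), roughly -7.6.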