arviz.psislw
- arviz.psislw(log_weights, reff=1.0)
Pareto smoothed importance sampling (PSIS).
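- Parameters:
  - log_weights: DataArray or (…, N) array_like
    Array of size (n_observations, n_samples) holding the raw log weights. For PSIS-LOO these are the negative pointwise log-likelihood values, as in the example below.
  - reff: float, default 1
    Relative MCMC efficiency, ess / n.
- Returns:
  - log_weights: DataArray or (…, N) ndarray
    Smoothed, truncated and normalized log weights.
  - pareto_shape: DataArray or (…) ndarray
    Estimates of the shape parameter k of the generalized Pareto distribution, one per observation.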
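To make the roles of `reff` and the two return values concrete, here is a minimal NumPy sketch of the PSIS procedure for inputs whose last axis holds the samples. It is an illustration under stated assumptions, not ArviZ's implementation: the tail length min(0.2 S, 3 sqrt(S / reff)), the generalized Pareto fit with the Zhang & Stephens (2009) estimator plus a weak prior on k, and the truncation at the largest raw weight are modeled on how ArviZ is understood to work and may differ between versions. The names `psis_sketch`, `_psis_1d` and `_gpd_fit` are hypothetical.

```python
import numpy as np


def _logsumexp(a):
    """Numerically stable log(sum(exp(a))) for a 1-D array."""
    a_max = np.max(a)
    return a_max + np.log(np.sum(np.exp(a - a_max)))


def _gpd_fit(x):
    """Fit a generalized Pareto distribution to sorted, positive exceedances x
    (Zhang & Stephens 2009 estimator). Returns (k, sigma).
    Assumption: k is shrunk towards 0.5 by a weak prior, as in ArviZ."""
    n = x.size
    m = 30 + int(np.sqrt(n))
    b = 1 - np.sqrt(m / (np.arange(1, m + 1) - 0.5))
    b /= 3 * x[int(n / 4 + 0.5) - 1]  # scale by the first quartile
    b += 1 / x[-1]
    k = np.log1p(-b[:, None] * x).mean(axis=1)
    log_lik = n * (np.log(-b / k) - k - 1)  # profile log-likelihood of each b
    with np.errstate(over="ignore"):
        w = 1 / np.exp(log_lik - log_lik[:, None]).sum(axis=1)
    w /= w.sum()
    b_post = np.sum(b * w)
    k_post = np.log1p(-b_post * x).mean()
    sigma = -k_post / b_post
    k_post = (n * k_post + 10 * 0.5) / (n + 10)
    return k_post, sigma


def _psis_1d(lw, reff):
    """Smooth the raw log weights of a single observation."""
    lw = lw - lw.max()  # largest raw weight becomes exp(0) = 1
    n = lw.size
    # Assumption: tail length min(0.2 S, 3 sqrt(S / reff))
    tail_len = int(np.ceil(min(0.2 * n, 3 * np.sqrt(n / reff))))
    cutoff = max(np.sort(lw)[-tail_len - 1], np.log(np.finfo(float).tiny))
    tail = np.flatnonzero(lw > cutoff)
    if tail.size <= 4:  # too few tail draws to fit a Pareto tail
        return lw - _logsumexp(lw), np.inf
    tail = tail[np.argsort(lw[tail])]  # tail indices in ascending weight order
    k, sigma = _gpd_fit(np.exp(lw[tail]) - np.exp(cutoff))
    if np.isfinite(k):
        # Replace the tail by quantiles of the fitted distribution
        p = (np.arange(1, tail.size + 1) - 0.5) / tail.size
        if abs(k) < np.finfo(float).eps:
            q = -sigma * np.log1p(-p)
        else:
            q = sigma * np.expm1(-k * np.log1p(-p)) / k
        lw[tail] = np.log(q + np.exp(cutoff))
        lw = np.minimum(lw, 0.0)  # truncate at the largest raw weight
    return lw - _logsumexp(lw), k  # normalize so the weights sum to 1


def psis_sketch(log_weights, reff=1.0):
    """Apply PSIS along the last axis (samples), looping over all other axes."""
    lw = np.asarray(log_weights, dtype=float)
    flat = lw.reshape(-1, lw.shape[-1])
    out = np.empty_like(flat)
    k = np.empty(flat.shape[0])
    for i, row in enumerate(flat):
        out[i], k[i] = _psis_1d(row, reff)
    return out.reshape(lw.shape), k.reshape(lw.shape[:-1])
```

For example, `lw, k = psis_sketch(raw, reff=0.8)` with `raw` of shape (8, 2000) returns smoothed log weights of the same shape and eight k estimates, mirroring the layout of the example output below.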
See also
- loo: Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
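In PSIS-LOO-CV the normalized log weights are combined with the pointwise log-likelihood: the estimate of log p(y_i | y_-i) is log Σ_s exp(lw_is + ll_is). The minimal NumPy sketch below shows only that final step; `elpd_loo_pointwise` is a hypothetical helper, not part of ArviZ. The demo uses plain, unsmoothed importance weights (w_s ∝ 1 / p(y_i | θ_s)) so the result can be checked against the harmonic-mean identity; with `psislw` the smoothed weights would take their place.

```python
import numpy as np


def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


def elpd_loo_pointwise(log_lik, log_weights):
    """Hypothetical helper: PSIS-LOO estimate of log p(y_i | y_-i) per observation.

    log_lik:     (n_obs, n_samples) pointwise log-likelihood
    log_weights: (n_obs, n_samples) log weights normalized per row,
                 i.e. logsumexp(log_weights, axis=-1) == 0
    """
    return logsumexp(log_weights + log_lik, axis=-1)


# Synthetic log-likelihood values for 3 observations and 500 draws
rng = np.random.default_rng(1)
log_lik = rng.normal(loc=-1.0, scale=0.3, size=(3, 500))
raw = -log_lik                                   # raw LOO log weights
lw = raw - logsumexp(raw, axis=-1)[:, None]      # normalize each row
elpd_i = elpd_loo_pointwise(log_lik, lw)
```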
Notes
If the `log_weights` input is a DataArray with a dimension named `__sample__` (recommended), `psislw` interprets this dimension as samples and all other dimensions as dimensions of the observed data, looping over them to compute the smoothed log weights and Pareto shape estimate of each observation. If no `__sample__` dimension is present, or the input is a NumPy array, the last dimension is interpreted as `__sample__`.
References
Vehtari et al. (2024). Pareto smoothed importance sampling. Journal of Machine Learning Research, 25(72):1-58.
Examples
Get Pareto smoothed importance sampling (PSIS) log weights:
```python
import arviz as az

data = az.load_arviz_data("non_centered_eight")
# Merge chain and draw into a single __sample__ dimension
log_likelihood = data.log_likelihood["obs"].stack(__sample__=["chain", "draw"])
# Raw leave-one-out log weights are the negative log-likelihood values
az.psislw(-log_likelihood, reff=0.8)
```

Output:

```
(<xarray.DataArray 'log_weights' (obs_dim_0: 8, __sample__: 2000)> Size: 128kB
array([[-7.534175  , -7.69549252, -7.58716519, ..., -8.0273235 , -7.71899444, -7.62608193],
       [-7.71657247, -7.82465868, -7.82375075, ..., -7.79389977, -7.57652873, -7.77960084],
       [-7.6031945 , -7.58032753, -7.64441205, ..., -7.68993465, -7.66796043, -7.64657297],
       ...,
       [-7.24915492, -7.68576857, -7.64220523, ..., -7.64873977, -7.76098937, -7.7035236 ],
       [-7.93253675, -7.99252542, -7.9338617 , ..., -8.05837208, -7.013536  , -7.73175308],
       [-7.63099257, -7.71489387, -7.64352179, ..., -7.62964354, -7.53068102, -7.67534435]])
Coordinates:
  * obs_dim_0   (obs_dim_0) int64 64B 0 1 2 3 4 5 6 7
  * __sample__  (__sample__) object 16kB MultiIndex
  * chain       (__sample__) int64 16kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3
  * draw        (__sample__) int64 16kB 0 1 2 3 4 5 ... 494 495 496 497 498 499,
 <xarray.DataArray 'pareto_shape' (obs_dim_0: 8)> Size: 64B
array([0.3397522 , 0.78348601, 0.44186589, 0.57020058, 0.40413962,
       0.53690878, 0.66410432, 0.5290121 ])
Coordinates:
  * obs_dim_0  (obs_dim_0) int64 64B 0 1 2 3 4 5 6 7)
```
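The `pareto_shape` values double as a reliability diagnostic for each observation's importance-sampling estimate. The short sketch below assumes the sample-size-dependent threshold min(1 - 1/log10(S), 0.7) proposed in Vehtari et al. (2024) and applies it to the k estimates printed in the example output. With S = 2000 draws the threshold is about 0.697, so only observation 1 (k ≈ 0.78) is flagged.

```python
import numpy as np

# Pareto shape estimates printed in the example output (8 observations)
pareto_shape = np.array([0.3397522, 0.78348601, 0.44186589, 0.57020058,
                         0.40413962, 0.53690878, 0.66410432, 0.5290121])
n_samples = 2000  # 4 chains x 500 draws, the size of the __sample__ dimension

# Assumption: sample-size-dependent threshold from Vehtari et al. (2024)
good_k = min(1 - 1 / np.log10(n_samples), 0.7)
bad_obs = np.flatnonzero(pareto_shape > good_k)
print(f"threshold={good_k:.3f}, unreliable observations: {bad_obs}")
```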