Refitting PyStan (2.x) models with ArviZ (and xarray)
ArviZ is backend agnostic and therefore does not sample directly. To take advantage of algorithms that require refitting a model several times, ArviZ uses SamplingWrappers to convert the API of the sampling backend to a common set of functions. Hence, functions like Leave Future Out Cross Validation can be used in ArviZ independently of the sampling backend used.
Below is an example of SamplingWrapper usage for PyStan (2.x).
import arviz as az
import pystan
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import xarray as xr
For the example, we will use a linear regression model.
np.random.seed(26)
xdata = np.linspace(0, 50, 100)
b0, b1, sigma = -2, 1, 3
ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)
plt.plot(xdata, ydata)
[<matplotlib.lines.Line2D at 0x7eff3db05e80>]
Now we will write the Stan code, keeping in mind that array shapes (here N) must be passed in as data rather than hard-coded, so that the same compiled model can be refit on subsets of the data.
refit_lr_code = """
data {
    // Define data for fitting
    int<lower=0> N;
    vector[N] x;
    vector[N] y;
}
parameters {
    real b0;
    real b1;
    real<lower=0> sigma_e;
}
model {
    b0 ~ normal(0, 10);
    b1 ~ normal(0, 10);
    sigma_e ~ normal(0, 10);
    for (i in 1:N) {
        y[i] ~ normal(b0 + b1 * x[i], sigma_e);  // use only data for fitting
    }
}
generated quantities {
    vector[N] y_hat;
    for (i in 1:N) {
        // pointwise log likelihood will be calculated outside Stan,
        // posterior predictive however will be generated here, there are
        // no restrictions on adding more generated quantities
        y_hat[i] = normal_rng(b0 + b1 * x[i], sigma_e);
    }
}
"""
sm = pystan.StanModel(model_code=refit_lr_code)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_2cdc9d1f1db425bb7186f919c45c9b36 NOW.
data_dict = {
    "N": len(ydata),
    "y": ydata,
    "x": xdata,
}
sample_kwargs = {"iter": 1000, "chains": 4}
fit = sm.sampling(data=data_dict, **sample_kwargs)
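Because the array length N is passed as data, refitting on a subset of the observations only needs a new data dictionary. The following is a minimal sketch of that step, assuming a hypothetical helper subset_data (it is not part of ArviZ or PyStan) and the same keys as data_dict above; a toy dictionary stands in for the real data so the snippet runs on its own.

```python
import numpy as np


def subset_data(data, idx):
    """Return a data dictionary restricted to the observations in idx.

    Hypothetical helper for illustration only; the keys "N", "x" and "y"
    mirror the data block of the Stan model above.
    """
    idx = np.asarray(idx)
    return {
        "N": int(idx.size),  # the array shape is updated together with the data
        "x": np.asarray(data["x"])[idx],
        "y": np.asarray(data["y"])[idx],
    }


# Toy stand-in for data_dict (values are arbitrary).
toy_data = {"N": 5, "x": np.arange(5.0), "y": 2.0 * np.arange(5.0)}
first_three = subset_data(toy_data, np.arange(3))
```

A dictionary built this way could then be passed to sm.sampling(data=..., **sample_kwargs) in the same way as data_dict.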
We have defined a dictionary sample_kwargs that will be passed to the SamplingWrapper in order to make sure that all refits use the same sampler parameters. We follow the same pattern with from_pystan().
dims = {"y": ["time"], "x": ["time"], "y_hat": ["time"]}
idata_kwargs = {
    "posterior_predictive": ["y_hat"],
    "observed_data": "y",
    "constant_data": "x",
    "dims": dims,
}
idata = az.from_pystan(posterior=fit, **idata_kwargs)
The log_likelihood group is now missing because we did not pass the log_likelihood argument in idata_kwargs. This is deliberate, to ease the job of the sampling wrapper: instead of going out of our way to get Stan to calculate the pointwise log likelihood values for each refit and for the excluded observation at every refit, we will compromise and manually write a Python function to calculate the pointwise log likelihood.
Even though it is not ideal to lose part of the out-of-the-box capabilities of the PyStan-ArviZ integration, this should generally not be a problem. We are basically moving the pointwise log likelihood calculation from the Stan code to the Python code; in both cases, we need to write the function that calculates the pointwise log likelihood by hand.
Moreover, the Python computation could even be written to be compatible with Dask. It would then work even when the number of observations is so large that the pointwise log likelihood values (with shape n_samples * n_observations) cannot be stored in memory.
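To get a feeling for the n_samples * n_observations shape mentioned above, the memory footprint can be estimated with simple arithmetic. This is a rough sketch that assumes float64 values (8 bytes each); the helper name loglik_nbytes and the ten-million-observation scenario are hypothetical and only serve as illustration.

```python
def loglik_nbytes(n_chains, n_draws, n_obs, itemsize=8):
    """Bytes needed to hold float64 pointwise log likelihood values."""
    return n_chains * n_draws * n_obs * itemsize


# This example: 4 chains x 500 draws x 100 observations.
small = loglik_nbytes(4, 500, 100)
# Same sampler settings with ten million observations (hypothetical).
large = loglik_nbytes(4, 500, 10_000_000)

print(f"{small / 1e6:.1f} MB")  # 1.6 MB
print(f"{large / 1e9:.1f} GB")  # 160.0 GB
```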
def calculate_log_lik(x, y, b0, b1, sigma_e):
    # mean of the linear regression; relies on the inputs broadcasting
    mu = b0 + b1 * x
    return stats.norm(mu, sigma_e).logpdf(y)
This function should work for any shape of the input arrays as long as their shapes are compatible and can broadcast. There is no need to loop over each draw in order to calculate the pointwise log likelihood using scalars.
Therefore, we can use xr.apply_ufunc to handle the broadcasting and preserve the dimension names:
log_lik = xr.apply_ufunc(
    calculate_log_lik,
    idata.constant_data["x"],
    idata.observed_data["y"],
    idata.posterior["b0"],
    idata.posterior["b1"],
    idata.posterior["sigma_e"],
)
idata.add_groups(log_likelihood=log_lik)
The first argument is the function, followed by as many positional arguments as the function needs, 5 in our case. As this case does not have many different dimensions nor combinations of these, we do not need to pass any extra kwargs to xarray.apply_ufunc.
We are now passing the arguments to calculate_log_lik as xarray.DataArrays. Behind the scenes, apply_ufunc() broadcasts and aligns the dimensions of all the DataArrays involved and then passes NumPy arrays to calculate_log_lik. Everything works automagically.
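Broadcasting by dimension name can be made visible with xr.broadcast. The sketch below uses small stand-in arrays with arbitrary values instead of the real posterior, and it only illustrates name-based broadcasting; it is not meant as a description of the internals of apply_ufunc.

```python
import numpy as np
import xarray as xr

# Stand-ins for a posterior variable and a data variable (values are arbitrary).
b0 = xr.DataArray(np.zeros((4, 500)), dims=("chain", "draw"))
x = xr.DataArray(np.linspace(0, 50, 100), dims=("time",))

# xr.broadcast expands both arrays to the union of their named dimensions.
b0_b, x_b = xr.broadcast(b0, x)
print(dict(b0_b.sizes))
print(dict(x_b.sizes))
```

Both results carry the chain, draw and time dimensions, so element-wise operations between them are well defined without any manual reshaping.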
Now let's see what happens if we were to pass the arrays directly to calculate_log_lik instead:
calculate_log_lik(
    idata.constant_data["x"].values,
    idata.observed_data["y"].values,
    idata.posterior["b0"].values,
    idata.posterior["b1"].values,
    idata.posterior["sigma_e"].values
)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-10-fc2d553bde92> in <module>
----> 1 calculate_log_lik(
2 idata.constant_data["x"].values,
3 idata.observed_data["y"].values,
4 idata.posterior["b0"].values,
5 idata.posterior["b1"].values,
<ipython-input-8-e6777d985e1f> in calculate_log_lik(x, y, b0, b1, sigma_e)
1 def calculate_log_lik(x, y, b0, b1, sigma_e):
----> 2 mu = b0 + b1 * x
3 return stats.norm(mu, sigma_e).logpdf(y)
ValueError: operands could not be broadcast together with shapes (4,500) (100,)
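The failure comes from NumPy broadcasting by position rather than by name: trailing axes of length 500 and 100 are not compatible. As a minimal sketch with random stand-in arrays of the same shapes (not the actual posterior), appending a length-1 axis to the parameter arrays is one manual way to make the shapes compatible.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 50, 100)     # shape (100,), like the "time" dimension
b0 = rng.normal(size=(4, 500))  # shape (chain, draw)
b1 = rng.normal(size=(4, 500))

try:
    mu = b0 + b1 * x            # (4, 500) against (100,) cannot broadcast
except ValueError as err:
    print("broadcast failed:", err)

# Appending a length-1 axis lets NumPy broadcast to (chain, draw, time).
mu = b0[..., np.newaxis] + b1[..., np.newaxis] * x
print(mu.shape)                 # (4, 500, 100)
```

With named dimensions this bookkeeping is unnecessary, which is the convenience the apply_ufunc call provides.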
If you are still curious about the magic of xarray and apply_ufunc, you can also try to modify the dims used to generate the InferenceData a couple of cells before:
dims = {"y": ["time"], "x": ["time"]}
What happens to the result if you use a different name for the dimension of x?
idata
Inference data with groups (HTML output abridged):
    > posterior: b0, b1, sigma_e with dimensions (chain: 4, draw: 500)
    > posterior_predictive: y_hat with dimensions (chain: 4, draw: 500, time: 100)
    > sample_stats: accept_stat, stepsize, treedepth, n_leapfrog, diverging, energy, lp with dimensions (chain: 4, draw: 500)
- step_size :
- [0.376451, 0.417677, 0.382013, 0.282232]
- metric :
- ['diag_e', 'diag_e', 'diag_e', 'diag_e']
- inv_metric :
- [[0.360859, 0.000475137, 0.00443736], [0.34879, 0.000404285, 0.00467239], [0.280667, 0.000337829, 0.00543657], [0.425118, 0.000556571, 0.00522525]]
- adaptation_info :
- ['# Adaptation terminated\n# Step size = 0.376451\n# Diagonal elements of inverse mass matrix:\n# 0.360859, 0.000475137, 0.00443736\n', '# Adaptation terminated\n# Step size = 0.417677\n# Diagonal elements of inverse mass matrix:\n# 0.34879, 0.000404285, 0.00467239\n', '# Adaptation terminated\n# Step size = 0.382013\n# Diagonal elements of inverse mass matrix:\n# 0.280667, 0.000337829, 0.00543657\n', '# Adaptation terminated\n# Step size = 0.282232\n# Diagonal elements of inverse mass matrix:\n# 0.425118, 0.000556571, 0.00522525\n']
- stan_code :
- data { // Define data for fitting int<lower=0> N; vector[N] x; vector[N] y; } parameters { real b0; real b1; real<lower=0> sigma_e; } model { b0 ~ normal(0, 10); b1 ~ normal(0, 10); sigma_e ~ normal(0, 10); for (i in 1:N) { y[i] ~ normal(b0 + b1 * x[i], sigma_e); // use only data for fitting } } generated quantities { vector[N] y_hat; for (i in 1:N) { // pointwise log likelihood will be calculated outside stan, // posterior predictive however will be generated here, there are // no restrictions on adding more generated quantities y_hat[i] = normal_rng(b0 + b1 * x[i], sigma_e); } }
<xarray.Dataset>
Dimensions:      (chain: 4, draw: 500)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
Data variables:
    accept_stat  (chain, draw) float64 0.9181 0.9607 0.9892 ... 0.9131 0.8482
    stepsize     (chain, draw) float64 0.3765 0.3765 0.3765 ... 0.2822 0.2822
    treedepth    (chain, draw) int64 3 3 4 3 2 1 3 4 4 3 ... 1 3 2 3 3 4 4 1 4 4
    n_leapfrog   (chain, draw) int64 7 11 15 15 3 3 7 15 ... 11 7 15 15 3 15 15
    diverging    (chain, draw) bool False False False ... False False False
    energy       (chain, draw) float64 159.8 160.4 160.1 ... 156.7 157.6 162.1
    lp           (chain, draw) float64 -156.2 -157.9 -156.6 ... -156.5 -159.7
Attributes:
    created_at:                 2020-10-06T01:27:54.851944
    arviz_version:              0.10.0
    inference_library:          pystan
    inference_library_version:  2.19.1.1
    args:                       [{'random_seed': '345470392', 'chain_id': 0, ...
    inits:                      [[1.7126106914526407, -0.030748403931205592, ...
    step_size:                  [0.376451, 0.417677, 0.382013, 0.282232]
    metric:                     ['diag_e', 'diag_e', 'diag_e', 'diag_e']
    inv_metric:                 [[0.360859, 0.000475137, 0.00443736], [0.3487...
    adaptation_info:            ['# Adaptation terminated\n# Step size = 0.37...
    stan_code:                  \ndata {\n // Define data for fitting\n int...
<xarray.Dataset>
Dimensions:  (time: 100)
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
Data variables:
    y        (time) float64 -1.412 -7.319 1.151 1.502 ... 48.49 48.52 46.03
Attributes:
    created_at:                 2020-10-06T01:27:54.841473
    arviz_version:              0.10.0
    inference_library:          pystan
    inference_library_version:  2.19.1.1
<xarray.Dataset>
Dimensions:  (time: 100)
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
Data variables:
    x        (time) float64 0.0 0.5051 1.01 1.515 ... 48.48 48.99 49.49 50.0
Attributes:
    created_at:                 2020-10-06T01:27:54.842743
    arviz_version:              0.10.0
    inference_library:          pystan
    inference_library_version:  2.19.1.1
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 500, time: 100)
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
Data variables:
    x        (time, chain, draw) float64 -2.066 -2.272 -1.931 ... -2.53 -2.504
We will create a subclass of SamplingWrapper. This way, instead of having to implement all the functions required by reloo(), we only have to implement sel_observations() (we are cloning sample() and get_inference_data() from PyStan2SamplingWrapper in order to use apply_ufunc instead of assuming the log likelihood is calculated within Stan).
Let’s check the two outputs of sel_observations.
data__i is a dictionary because it is an argument of sample, which will pass it as is to model.sampling.
data_ex is a list because it is an argument to log_likelihood__i, which will pass it as *data_ex to apply_ufunc.
More on the integration of data_ex and apply_ufunc is given below.
class LinearRegressionWrapper(az.SamplingWrapper):
    def sel_observations(self, idx):
        xdata = self.idata_orig.constant_data["x"]
        ydata = self.idata_orig.observed_data["y"]
        # Boolean mask that is True for the observations to exclude.
        mask = np.isin(np.arange(len(xdata)), idx)
        # Data for the refit: everything except the excluded observations.
        data__i = {"x": xdata[~mask], "y": ydata[~mask], "N": len(ydata[~mask])}
        # Excluded observations, in the positional order log_lik_fun expects.
        data_ex = [ary[mask] for ary in (xdata, ydata)]
        return data__i, data_ex

    def sample(self, modified_observed_data):
        # Cloned from PyStan2SamplingWrapper.
        fit = self.model.sampling(data=modified_observed_data, **self.sample_kwargs)
        return fit

    def get_inference_data(self, fit):
        # Cloned from PyStan2SamplingWrapper.
        idata = az.from_pystan(posterior=fit, **self.idata_kwargs)
        return idata
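To see what sel_observations returns without running Stan, the same masking logic can be tried on plain NumPy arrays. The following is a minimal sketch: split_observations is a hypothetical stand-alone helper (not part of ArviZ), and the toy x and y arrays are made up for illustration.

```python
import numpy as np


def split_observations(x, y, idx):
    """Hypothetical helper mirroring the masking done in sel_observations."""
    # True at the positions listed in idx (a scalar or a sequence of ints).
    mask = np.isin(np.arange(len(x)), idx)
    # Data used for the refit: every observation except the excluded ones.
    data__i = {"x": x[~mask], "y": y[~mask], "N": int((~mask).sum())}
    # Excluded observations, in the positional order expected by log_lik_fun.
    data_ex = [x[mask], y[mask]]
    return data__i, data_ex


# Toy data (assumed values, noise-free for simplicity).
x = np.linspace(0, 50, 100)
y = -2.0 + 1.0 * x

data__i, data_ex = split_observations(x, y, 13)
print(data__i["N"], [ary.shape for ary in data_ex])
```

The dictionary keys match the names declared in the Stan data block, which is why data__i can be handed to the sampler unchanged, while data_ex only needs to preserve the argument order of the log likelihood function.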
loo_orig = az.loo(idata, pointwise=True)
loo_orig
Computed from 2000 by 100 log-likelihood matrix
Estimate SE
elpd_loo -250.79 7.12
p_loo 2.95 -
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.5] (good) 100 100.0%
(0.5, 0.7] (ok) 0 0.0%
(0.7, 1] (bad) 0 0.0%
(1, Inf) (very bad) 0 0.0%
In this case, the Leave-One-Out Cross Validation (LOO-CV) approximation using Pareto Smoothed Importance Sampling (PSIS) works for all observations, so we will modify loo_orig in order to make reloo() believe that PSIS failed for some observations. This will also serve as a validation of our wrapper, as the PSIS LOO-CV already returned the correct value.
loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])
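The effect of this modification can be sketched with plain NumPy. The example below assumes that observations are flagged for refitting when their Pareto k exceeds a threshold of 0.7 (the boundary between the "ok" and "bad" rows of the diagnostic table); the pareto_k array is a made-up stand-in for loo_orig.pareto_k, not the real values.

```python
import numpy as np

# Stand-in for loo_orig.pareto_k: 100 well-behaved values (assumed).
pareto_k = np.full(100, 0.1)
# Same modification as above: pretend PSIS failed for four observations.
pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])

k_thresh = 0.7  # assumed threshold separating "ok" from "bad"
refit_idx = np.flatnonzero(pareto_k > k_thresh)
n_bad = int(np.sum((pareto_k > 0.7) & (pareto_k <= 1)))
n_very_bad = int(np.sum(pareto_k > 1))
print(refit_idx, n_bad, n_very_bad)
```

Under this assumption, two observations fall in the "bad" bin and two in the "very bad" bin, so four refits are expected.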
We initialize our sampling wrapper. Let’s stop and analyze each of the arguments.
We use the log_lik_fun and posterior_vars arguments to tell the wrapper how to call apply_ufunc(). log_lik_fun is the function to be called, and it is called with the following positional arguments:
log_lik_fun(*data_ex, *[idata__i.posterior[var_name] for var_name in posterior_vars])
where data_ex is the second element returned by sel_observations and idata__i is the InferenceData object returned by get_inference_data, which contains the fit on the subsetted data. We have generated data_ex to be a list of DataArrays so it plays nicely with this call signature.
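The call signature above can be illustrated with NumPy arrays in place of DataArrays. This is only a sketch: normal_log_lik is a hypothetical function written for this example, the posterior dictionary holds fake draws standing in for idata__i.posterior, and plain broadcasting replaces apply_ufunc.

```python
import numpy as np


def normal_log_lik(x, y, b0, b1, sigma_e):
    """Hypothetical pointwise log likelihood of y ~ normal(b0 + b1 * x, sigma_e)."""
    mu = b0 + b1 * x
    return -0.5 * np.log(2 * np.pi * sigma_e**2) - (y - mu) ** 2 / (2 * sigma_e**2)


rng = np.random.default_rng(0)
# Stand-in for idata__i.posterior: 4 chains x 500 draws per variable (fake draws).
posterior = {
    "b0": rng.normal(-2, 0.5, size=(4, 500)),
    "b1": rng.normal(1, 0.02, size=(4, 500)),
    "sigma_e": np.abs(rng.normal(3, 0.2, size=(4, 500))),
}
posterior_vars = ("b0", "b1", "sigma_e")
# Stand-in for data_ex: the x and y values of one excluded observation.
data_ex = [np.array([6.57]), np.array([2.18])]

# Data first, then posterior variables in the order given by posterior_vars.
log_lik = normal_log_lik(*data_ex, *[posterior[var_name] for var_name in posterior_vars])
print(log_lik.shape)
```

The order of the elements in data_ex and of the names in posterior_vars has to match the order of the function's parameters, since everything is passed positionally.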
We use idata_orig as a starting point, and mostly as a source of observed and constant data, which is then subsetted in sel_observations.
Finally, sample_kwargs and idata_kwargs are used to make sure all refits and their corresponding InferenceData objects are generated with the same properties.
pystan_wrapper = LinearRegressionWrapper(
    sm,
    log_lik_fun=calculate_log_lik,
    posterior_vars=("b0", "b1", "sigma_e"),
    idata_orig=idata,
    sample_kwargs=sample_kwargs,
    idata_kwargs=idata_kwargs,
)
Now we can use this wrapper to call az.reloo and compare the results with the PSIS LOO-CV results.
loo_relooed = az.reloo(pystan_wrapper, loo_orig=loo_orig)
/home/oriol/miniconda3/envs/arviz/lib/python3.8/site-packages/arviz/stats/stats_refitting.py:99: UserWarning: reloo is an experimental and untested feature
warnings.warn("reloo is an experimental and untested feature", UserWarning)
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 13
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 13
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 42
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 42
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 56
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 56
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 73
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 73
loo_relooed
Computed from 2000 by 100 log-likelihood matrix
Estimate SE
elpd_loo -250.79 7.12
p_loo 2.95 -
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.5] (good) 100 100.0%
(0.5, 0.7] (ok) 0 0.0%
(0.7, 1] (bad) 0 0.0%
(1, Inf) (very bad) 0 0.0%
loo_orig
Computed from 2000 by 100 log-likelihood matrix
Estimate SE
elpd_loo -250.79 7.12
p_loo 2.95 -
------
Pareto k diagnostic values:
Count Pct.
(-Inf, 0.5] (good) 96 96.0%
(0.5, 0.7] (ok) 0 0.0%
(0.7, 1] (bad) 2 2.0%
(1, Inf) (very bad) 2 2.0%
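For each refitted observation, the PSIS estimate is replaced by a value computed directly from the refit. A minimal sketch of that step follows, assuming the standard definition of the exact pointwise elpd as the log of the posterior mean of the held-out likelihood; it does not reproduce ArviZ internals, and every array in it is fake data generated for illustration.

```python
import numpy as np


def log_mean_exp(log_lik):
    """Numerically stable log(mean(exp(log_lik))) over all draws."""
    log_lik = np.ravel(log_lik)
    shift = log_lik.max()
    return shift + np.log(np.mean(np.exp(log_lik - shift)))


rng = np.random.default_rng(1)
# Hypothetical PSIS pointwise elpd values for 100 observations.
elpd_i = rng.normal(-2.5, 0.5, size=100)
# Hypothetical held-out log likelihoods from each refit (4 chains x 500 draws).
refits = {idx: rng.normal(-2.5, 0.1, size=(4, 500)) for idx in (13, 42, 56, 73)}

# Overwrite the entries where PSIS was deemed unreliable with the exact values.
for idx, log_lik in refits.items():
    elpd_i[idx] = log_mean_exp(log_lik)

elpd_loo = elpd_i.sum()
print(elpd_loo)
```

When PSIS was in fact accurate for the replaced observations, as in this notebook, the total barely changes, which is consistent with loo_relooed and loo_orig reporting the same elpd_loo.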