Refitting PyMC3 models with ArviZ (and xarray)
ArviZ is backend agnostic and therefore does not sample directly. To support algorithms that require refitting a model several times, ArviZ uses SamplingWrappers to translate the API of each sampling backend into a common set of functions. Hence, functions like Leave Future Out Cross Validation can be used in ArviZ regardless of the sampling backend.
Below is an example of SamplingWrapper usage for PyMC3.
Before starting, note that PyMC3 cannot change the shapes of the input data of an already compiled model. Thus, each refit requires recompiling the model.
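To make "refitting" concrete outside of any PPL, here is a minimal NumPy-only sketch of brute-force leave-one-out: one refit per observation, scoring the held-out point each time. It rests on loud assumptions: least-squares point estimates stand in for posterior sampling (so this is a plug-in score, not Bayesian LOO), and `normal_logpdf`, `fit_ols` and `brute_force_loo` are hypothetical helpers, not ArviZ functions. A SamplingWrapper is what lets ArviZ run this kind of loop with a real sampler instead.

```python
import numpy as np


def normal_logpdf(y, mu, sigma):
    """Log density of a normal distribution, written out with NumPy."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


def fit_ols(x, y):
    """Stand-in 'fit': least-squares intercept, slope and residual sd."""
    X = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ coef
    sigma = np.sqrt(resid @ resid / (len(y) - 2))
    return coef[0], coef[1], sigma


def brute_force_loo(x, y):
    """Refit once per observation and score the held-out point each time."""
    n = len(y)
    pointwise = np.empty(n)
    for i in range(n):
        keep = np.arange(n) != i                    # drop observation i
        b0, b1, sigma = fit_ols(x[keep], y[keep])   # refit on the rest
        pointwise[i] = normal_logpdf(y[i], b0 + b1 * x[i], sigma)
    return pointwise


# Same generating parameters as this notebook, different random numbers.
rng = np.random.default_rng(26)
x = np.linspace(0, 50, 100)
y = rng.normal(loc=1 * x - 2, scale=3)

pointwise = brute_force_loo(x, y)
print(pointwise.shape, pointwise.sum())
```

With a real sampler, each `fit_ols` call becomes a full recompile-and-sample cycle, which is why the cost of refitting matters so much in what follows.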
import arviz as az
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import xarray as xr
For this example, we will use a linear regression.
np.random.seed(26)
xdata = np.linspace(0, 50, 100)
b0, b1, sigma = -2, 1, 3
ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)
plt.plot(xdata, ydata);
Now we will write the PyMC3 model, keeping in mind that 1) the data must be modifiable (both x and y) and 2) the model must be recompiled in order to be refitted with the modified data. We therefore have to create a function that recompiles the model when it's called. Luckily for us, compilation in PyMC3 is generally quite fast.
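def compile_linreg_model(xdata, ydata):
    # A new pm.Model is built (and compiled) every time this is called
    with pm.Model() as model:
        # Registering x as data stores it in the constant_data group later on
        x = pm.Data("x", xdata)
        b0 = pm.Normal("b0", 0, 10)
        b1 = pm.Normal("b1", 0, 10)
        sigma_e = pm.HalfNormal("sigma_e", 10)

        y = pm.Normal("y", b0 + b1 * x, sigma_e, observed=ydata)
    return model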
sample_kwargs = {"draws": 500, "tune": 500, "chains": 4}
with compile_linreg_model(xdata, ydata) as linreg_model:
    trace = pm.sample(**sample_kwargs)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma_e, b1, b0]
Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 3 seconds.
The acceptance probability does not match the target. It is 0.917825834816141, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8850799498280131, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8818306765045102, but should be close to 0.8. Try to increase the number of tuning steps.
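Because every refit sees a different subset of x and y, something has to split the data into the part used for refitting and the held-out part. The sketch below is a hypothetical helper (`split_observations` is not an ArviZ function); it only loosely mirrors the role that the `sel_observations` method plays in ArviZ's SamplingWrapper, and the `ydata` values here are placeholders.

```python
import numpy as np


def split_observations(x, y, idx):
    """Hypothetical helper: return (data for the refit, held-out data).

    idx can be a single integer position or a sequence of positions.
    """
    keep = np.ones(len(y), dtype=bool)
    keep[idx] = False                            # mark excluded observation(s)
    modified = {"x": x[keep], "y": y[keep]}      # passed to the recompiled model
    excluded = {"x": x[~keep], "y": y[~keep]}    # scored after the refit
    return modified, excluded


xdata = np.linspace(0, 50, 100)
ydata = xdata - 2  # placeholder values; any array of the same length works

modified, excluded = split_observations(xdata, ydata, 3)
```

In this notebook, each refit would then call `compile_linreg_model(modified["x"], modified["y"])` and sample with the shared `sample_kwargs`.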
We have defined a dictionary sample_kwargs that will be passed to the SamplingWrapper in order to make sure that all refits use the same sampler parameters.
We follow the same pattern with az.from_pymc3.
Note, however, that coords are not set. This prevents errors caused by coordinates and values having incompatible shapes during refits. Otherwise, we would have to handle subsetting of the coordinate values, even though the refits are never used outside refitting functions such as reloo.
We also leave the model out of idata_kwargs because the model, like the trace, is different for every refit. This may seem counterintuitive or even plain wrong, but remember that the pm.Model object contains information like the observed data.
dims = {"y": ["time"], "x": ["time"]}
idata_kwargs = {
"dims": dims,
"log_likelihood": False,
}
idata = az.from_pymc3(trace, model=linreg_model, **idata_kwargs)
idata
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 500)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
Data variables:
    b0       (chain, draw) float64 -2.545 -2.677 -2.088 ... -3.023 -2.545 -2.89
    b1       (chain, draw) float64 1.016 1.02 0.9981 1.022 ... 1.034 1.023 1.012
    sigma_e  (chain, draw) float64 2.772 2.993 2.867 2.833 ... 2.8 2.838 3.039
Attributes:
    created_at:                 2020-10-06T00:56:59.863225
    arviz_version:              0.10.0
    inference_library:          pymc3
    inference_library_version:  3.9.3
    sampling_time:              3.375001907348633
    tuning_steps:               500

<xarray.Dataset>
Dimensions:             (chain: 4, draw: 500)
Coordinates:
  * chain               (chain) int64 0 1 2 3
  * draw                (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
Data variables:
    step_size_bar       (chain, draw) float64 0.4844 0.4844 ... 0.5218 0.5218
    step_size           (chain, draw) float64 0.6579 0.6579 ... 0.5572 0.5572
    energy_error        (chain, draw) float64 0.1089 -0.01555 ... -0.1793 1.522
    process_time_diff   (chain, draw) float64 0.0005479 0.001 ... 0.001683
    max_energy_error    (chain, draw) float64 0.1089 -0.1235 ... -0.1793 2.335
    perf_counter_start  (chain, draw) float64 3.84e+04 3.84e+04 ... 3.84e+04
    depth               (chain, draw) int64 2 3 4 4 4 4 3 3 ... 3 3 3 2 3 3 3 3
    diverging           (chain, draw) bool False False False ... False False
    perf_counter_diff   (chain, draw) float64 0.0005474 0.0009994 ... 0.001682
    energy              (chain, draw) float64 256.5 256.4 256.8 ... 257.8 260.0
    mean_tree_accept    (chain, draw) float64 0.935 1.0 0.9774 ... 0.9993 0.4627
    lp                  (chain, draw) float64 -256.3 -256.3 ... -255.9 -258.6
    tree_size           (chain, draw) float64 3.0 7.0 15.0 11.0 ... 7.0 7.0 7.0
Attributes:
    created_at:                 2020-10-06T00:56:59.868643
    arviz_version:              0.10.0
    inference_library:          pymc3
    inference_library_version:  3.9.3
    sampling_time:              3.375001907348633
    tuning_steps:               500

<xarray.Dataset>
Dimensions:  (time: 100)
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
Data variables:
    y        (time) float64 -1.412 -7.319 1.151 1.502 ... 48.49 48.52 46.03
Attributes:
    created_at:                 2020-10-06T00:56:59.872278
    arviz_version:              0.10.0
    inference_library:          pymc3
    inference_library_version:  3.9.3

<xarray.Dataset>
Dimensions:  (time: 100)
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
Data variables:
    x        (time) float64 0.0 0.5051 1.01 1.515 ... 48.48 48.99 49.49 50.0
Attributes:
    created_at:                 2020-10-06T00:56:59.872872
    arviz_version:              0.10.0
    inference_library:          pymc3
    inference_library_version:  3.9.3
The log_likelihood group is missing because we set log_likelihood=False in idata_kwargs. We do this to ease the job of the sampling wrapper. Instead of going out of our way to get PyMC3 to calculate the pointwise log likelihood values for each refit and for the excluded observation at every refit, we compromise and manually write a function that calculates the pointwise log likelihood.
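For the excluded observation i, the quantity a refit ultimately feeds into LOO is the log of the average likelihood of y_i over the refit's posterior draws, log((1/S) Σ_s p(y_i | θ_s)). A minimal NumPy sketch follows; the helper names are hypothetical, and the "posterior draws" are made-up random numbers standing in for the draws of a model refitted without observation i.

```python
import numpy as np


def normal_logpdf(y, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


def log_lik_excluded(x_ex, y_ex, b0, b1, sigma_e):
    """Log likelihood of one held-out point at every posterior draw."""
    return normal_logpdf(y_ex, b0 + b1 * x_ex, sigma_e)


def log_mean_exp(a):
    """Numerically stable log(mean(exp(a)))."""
    a_max = np.max(a)
    return a_max + np.log(np.mean(np.exp(a - a_max)))


# Made-up stand-in for refit posterior draws, shape (chain, draw).
rng = np.random.default_rng(0)
b0 = rng.normal(-2.5, 0.5, size=(4, 500))
b1 = rng.normal(1.02, 0.02, size=(4, 500))
sigma_e = np.abs(rng.normal(2.9, 0.2, size=(4, 500)))

ll = log_lik_excluded(25.0, 23.0, b0, b1, sigma_e)  # one value per draw
elpd_i = log_mean_exp(ll)                           # contribution of point i
```

Subtracting the maximum before exponentiating matters: with log likelihoods around -1000, a naive `np.log(np.mean(np.exp(a)))` underflows to `-inf`.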
Even though it is not ideal to lose part of PyMC3's out-of-the-box capabilities, this should generally not be a problem. In fact, other PPLs such as Stan always require writing the pointwise log likelihood manually (either within the Stan code or in Python). Moreover, computing the pointwise log likelihood in Python with xarray can be computationally more efficient than extracting it automatically from PyMC3.
It could even be written to be compatible with Dask, so it would keep working when the number of observations is so large that the pointwise log likelihood values (with shape n_samples * n_observations) cannot be stored in memory.
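To put numbers on the n_samples * n_observations size, here is a small helper, assuming a dense float64 array (8 bytes per value); the one-million-observation case is a hypothetical scaling, not something from this notebook.

```python
def log_lik_nbytes(n_chains, n_draws, n_obs, itemsize=8):
    """Bytes needed to hold a dense pointwise log-likelihood array."""
    return n_chains * n_draws * n_obs * itemsize


# This notebook: 4 chains x 500 draws x 100 observations
print(log_lik_nbytes(4, 500, 100))               # 1600000 bytes, about 1.6 MB
# Same sampler settings, one million observations
print(log_lik_nbytes(4, 500, 1_000_000) / 1e9)   # 16.0 GB
```

The array grows linearly with the number of observations, which is the case where a lazily evaluated, chunked (e.g. Dask-backed) computation becomes attractive.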
def calculate_log_lik(x, y, b0, b1, sigma_e):
    mu = b0 + b1 * x
    return stats.norm(mu, sigma_e).logpdf(y)
This function should work for input arrays of any shape, as long as their shapes are compatible and can be broadcast together. There is no need to loop over each draw and compute the pointwise log likelihood from scalars.
Therefore, we can use xr.apply_ufunc to handle the broadcasting and preserve the dimension names:
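log_lik = xr.apply_ufunc(
    calculate_log_lik,
    idata.constant_data["x"],
    idata.observed_data["y"],
    idata.posterior["b0"],
    idata.posterior["b1"],
    idata.posterior["sigma_e"],
)
# The result has dims (time, chain, draw) and no meaningful name: put the
# sample dims first and store it under the observed variable's name, y
log_lik = log_lik.transpose("chain", "draw", "time")
idata.add_groups(log_likelihood=log_lik.to_dataset(name="y"))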
The first argument is the function, followed by as many positional arguments as the function needs, 5 in our case. As this case does not involve many different dimensions or combinations of them, we do not need any of the extra keyword arguments accepted by xr.apply_ufunc.
We are passing the arguments to calculate_log_lik as xr.DataArray objects. Behind the scenes, xr.apply_ufunc broadcasts and aligns the dimensions of all the DataArrays involved and afterwards passes NumPy arrays to calculate_log_lik. Everything works automagically.
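A toy version of that alignment step can be sketched in NumPy, under loud simplifying assumptions: `align_by_dims` is a hypothetical helper, not xarray's implementation, and it ignores coordinate labels, core dimensions and everything else apply_ufunc handles. It only lines axes up by name before the computation sees plain arrays; the data and draws below are made-up.

```python
import numpy as np


def align_by_dims(*named_arrays):
    """Each argument is (array, dims). Returns the union of dim names and the
    arrays transposed/reshaped so NumPy broadcasting matches axes by name."""
    all_dims = []
    for _, dims in named_arrays:
        for d in dims:
            if d not in all_dims:
                all_dims.append(d)
    aligned = []
    for arr, dims in named_arrays:
        present = [d for d in all_dims if d in dims]
        arr = np.transpose(arr, [dims.index(d) for d in present])  # reorder axes
        sizes = dict(zip(present, arr.shape))
        aligned.append(arr.reshape([sizes.get(d, 1) for d in all_dims]))  # add length-1 axes
    return all_dims, aligned


rng = np.random.default_rng(1)
x = np.linspace(0, 50, 100)
y = rng.normal(x - 2, 3)
b0 = rng.normal(-2, 0.5, size=(4, 500))
b1 = rng.normal(1, 0.02, size=(4, 500))
sigma_e = np.abs(rng.normal(3, 0.2, size=(4, 500)))

dims, (x_a, y_a, b0_a, b1_a, s_a) = align_by_dims(
    (x, ("time",)), (y, ("time",)),
    (b0, ("chain", "draw")), (b1, ("chain", "draw")), (sigma_e, ("chain", "draw")),
)
mu = b0_a + b1_a * x_a
log_lik = -0.5 * np.log(2 * np.pi * s_a**2) - 0.5 * ((y_a - mu) / s_a) ** 2
```

Here `dims` comes out as `["time", "chain", "draw"]`, the same dimension order that xr.apply_ufunc produced for the log likelihood in this notebook.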
Now let's see what happens if we pass the arrays directly to calculate_log_lik instead:
calculate_log_lik(
idata.constant_data["x"].values,
idata.observed_data["y"].values,
idata.posterior["b0"].values,
idata.posterior["b1"].values,
idata.posterior["sigma_e"].values
)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-10-fc2d553bde92> in <module>
----> 1 calculate_log_lik(
2 idata.constant_data["x"].values,
3 idata.observed_data["y"].values,
4 idata.posterior["b0"].values,
5 idata.posterior["b1"].values,
<ipython-input-8-e6777d985e1f> in calculate_log_lik(x, y, b0, b1, sigma_e)
1 def calculate_log_lik(x, y, b0, b1, sigma_e):
----> 2 mu = b0 + b1 * x
3 return stats.norm(mu, sigma_e).logpdf(y)
ValueError: operands could not be broadcast together with shapes (4,500) (100,)
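The error appears because NumPy broadcasts by position: the trailing axes of `(4, 500)` and `(100,)` are 500 and 100, which do not match. A minimal NumPy-only sketch of the manual fix adds a trailing length-1 axis to the posterior arrays. Assumptions: `calculate_log_lik` is reimplemented with a NumPy normal log density instead of `scipy.stats` so the block is self-contained, and the data and draws are made-up.

```python
import numpy as np


def calculate_log_lik(x, y, b0, b1, sigma_e):
    # NumPy-only stand-in for the scipy.stats version above
    mu = b0 + b1 * x
    return -0.5 * np.log(2 * np.pi * sigma_e**2) - 0.5 * ((y - mu) / sigma_e) ** 2


rng = np.random.default_rng(2)
x = np.linspace(0, 50, 100)                     # shape (100,)
y = rng.normal(x - 2, 3)                        # shape (100,)
b0 = rng.normal(-2, 0.5, size=(4, 500))         # shape (chain, draw)
b1 = rng.normal(1, 0.02, size=(4, 500))
sigma_e = np.abs(rng.normal(3, 0.2, size=(4, 500)))

# (4, 500, 1) broadcasts against (100,) to (4, 500, 100)
log_lik = calculate_log_lik(x, y, b0[..., None], b1[..., None], sigma_e[..., None])
```

This works, but it is exactly the bookkeeping that named dimensions make unnecessary: with xr.apply_ufunc, the axes are matched by name rather than by manual reshaping.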
If you are still curious about the magic of xarray and xr.apply_ufunc, you can also try modifying the dims used to generate the InferenceData a couple of cells above:
dims = {"y": ["time"], "x": ["time"]}
What happens to the result if you use a different name for the dimension of x?
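As a hint: xarray broadcasts by dimension name, so if x used a different dimension name (say `obs`) than y (`time`), the two would no longer share an axis. The NumPy sketch below mimics both cases on tiny made-up arrays, using residuals y - mu instead of the full log likelihood. With a shared name, each y_j is paired with its own x_j; with different names, every y_j is compared against every x_i, which adds a whole extra dimension and is no longer the pointwise log likelihood.

```python
import numpy as np

x = np.linspace(0, 50, 5)          # imagine dims ("obs",) or ("time",)
y = x - 2                          # dims ("time",); exactly on the line
b0 = np.full((2, 3), -2.0)         # dims ("chain", "draw")
b1 = np.full((2, 3), 1.0)

mu = b0[None, :, :] + b1[None, :, :] * x[:, None, None]  # (n_x, chain, draw)

# Same dimension name: x and y share one axis -> (time, chain, draw)
resid_same = y[:, None, None] - mu
# Different names: an outer product over x and y -> (obs, time, chain, draw)
resid_diff = y[None, :, None, None] - mu[:, None, :, :]
```

Only the diagonal `resid_diff[i, i]` corresponds to the intended pointwise values; all off-diagonal entries pair an observation with the wrong x.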
idata
-
- chain: 4
- draw: 500
- chain(chain)int640 1 2 3
array([0, 1, 2, 3])
- draw(draw)int640 1 2 3 4 5 ... 495 496 497 498 499
array([ 0, 1, 2, ..., 497, 498, 499])
- b0(chain, draw)float64-2.545 -2.677 ... -2.545 -2.89
array([[-2.54545049, -2.67741862, -2.08761456, ..., -3.09110505, -2.77156713, -1.82933111], [-2.93703463, -3.08086967, -2.67962367, ..., -2.89842373, -2.65133177, -2.49119487], [-3.04911494, -3.04911494, -3.40933706, ..., -2.3201444 , -1.36054491, -1.38022529], [-2.48907267, -1.4442294 , -1.15349324, ..., -3.02343449, -2.54474655, -2.89029106]])
- b1(chain, draw)float641.016 1.02 0.9981 ... 1.023 1.012
array([[1.01622524, 1.02045325, 0.99806651, ..., 1.03522927, 1.04665622, 1.00653981], [1.0456625 , 1.04366705, 1.03454016, ..., 1.00894854, 1.01218477, 1.01874484], [1.05923249, 1.05923249, 1.03547951, ..., 1.01666033, 0.98471165, 0.9845449 ], [1.02269411, 0.98692974, 0.97328788, ..., 1.03426398, 1.02267379, 1.01200097]])
- sigma_e(chain, draw)float642.772 2.993 2.867 ... 2.838 3.039
array([[2.77234909, 2.9932014 , 2.86667421, ..., 3.10750173, 3.22347532, 2.87163973], [3.21155286, 3.32651749, 2.71789539, ..., 2.92579964, 3.04690959, 2.90988686], [3.00019757, 3.00019757, 2.84787176, ..., 2.78722998, 2.75353244, 2.75335359], [2.79974824, 2.82024413, 2.6589687 , ..., 2.80048563, 2.83832588, 3.03888294]])
- created_at :
- 2020-10-06T00:56:59.863225
- arviz_version :
- 0.10.0
- inference_library :
- pymc3
- inference_library_version :
- 3.9.3
- sampling_time :
- 3.375001907348633
- tuning_steps :
- 500
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 Data variables: b0 (chain, draw) float64 -2.545 -2.677 -2.088 ... -3.023 -2.545 -2.89 b1 (chain, draw) float64 1.016 1.02 0.9981 1.022 ... 1.034 1.023 1.012 sigma_e (chain, draw) float64 2.772 2.993 2.867 2.833 ... 2.8 2.838 3.039 Attributes: created_at: 2020-10-06T00:56:59.863225 arviz_version: 0.10.0 inference_library: pymc3 inference_library_version: 3.9.3 sampling_time: 3.375001907348633 tuning_steps: 500
xarray.Dataset -
- chain: 4
- draw: 500
- chain(chain)int640 1 2 3
array([0, 1, 2, 3])
- draw(draw)int640 1 2 3 4 5 ... 495 496 497 498 499
array([ 0, 1, 2, ..., 497, 498, 499])
- step_size_bar(chain, draw)float640.4844 0.4844 ... 0.5218 0.5218
array([[0.48441936, 0.48441936, 0.48441936, ..., 0.48441936, 0.48441936, 0.48441936], [0.48473255, 0.48473255, 0.48473255, ..., 0.48473255, 0.48473255, 0.48473255], [0.51576623, 0.51576623, 0.51576623, ..., 0.51576623, 0.51576623, 0.51576623], [0.52181729, 0.52181729, 0.52181729, ..., 0.52181729, 0.52181729, 0.52181729]])
- step_size(chain, draw)float640.6579 0.6579 ... 0.5572 0.5572
array([[0.65787679, 0.65787679, 0.65787679, ..., 0.65787679, 0.65787679, 0.65787679], [0.44637199, 0.44637199, 0.44637199, ..., 0.44637199, 0.44637199, 0.44637199], [0.48842612, 0.48842612, 0.48842612, ..., 0.48842612, 0.48842612, 0.48842612], [0.55721811, 0.55721811, 0.55721811, ..., 0.55721811, 0.55721811, 0.55721811]])
- energy_error(chain, draw)float640.1089 -0.01555 ... -0.1793 1.522
array([[ 0.108865 , -0.0155471 , 0.07155788, ..., 0.04296501, 0.15742812, -0.33332631], [-0.02959823, -0.00236713, -0.09980043, ..., 1.42175158, -1.087928 , -0.45451202], [ 0.04905213, 0. , 0.19387839, ..., 0.00296038, 0.07487397, -0.00271838], [-0.21919158, 0.06004846, -0.05758929, ..., 0.17061896, -0.17928858, 1.52185661]])
- process_time_diff(chain, draw)float640.0005479 0.001 ... 0.0017 0.001683
array([[0.00054787, 0.00099997, 0.00185223, ..., 0.00099977, 0.00054254, 0.00100001], [0.00084784, 0.00108633, 0.0020409 , ..., 0.00052323, 0.00051445, 0.00103059], [0.00201625, 0.00059325, 0.00093283, ..., 0.00102077, 0.0011406 , 0.00054941], [0.0010606 , 0.00102543, 0.00055379, ..., 0.00172862, 0.00169956, 0.00168284]])
- max_energy_error(chain, draw)float640.1089 -0.1235 ... -0.1793 2.335
array([[ 0.108865 , -0.12354286, -0.14185415, ..., 0.09598741, 0.53482382, -0.33332631], [-0.16844692, 0.2400077 , 0.11562474, ..., 1.42175158, -1.59470083, -0.5532092 ], [-0.36648334, 1.47743899, -0.46971553, ..., 0.05945543, 0.27683305, 0.12141349], [ 0.64090009, 0.1046269 , 1.42605221, ..., 0.17061896, -0.17928858, 2.33507852]])
- perf_counter_start(chain, draw)float643.84e+04 3.84e+04 ... 3.84e+04
array([[38398.32515598, 38398.32584412, 38398.32698033, ..., 38398.92606505, 38398.92720005, 38398.92787772], [38397.30786484, 38397.30896574, 38397.31030007, ..., 38397.94424016, 38397.94489865, 38397.94554733], [38397.31111097, 38397.31335491, 38397.31413613, ..., 38397.95092883, 38397.95580027, 38397.95710175], [38397.10391414, 38397.10512732, 38397.10629472, ..., 38397.77145488, 38397.77339404, 38397.77530902]])
- depth(chain, draw)int642 3 4 4 4 4 3 3 ... 3 3 3 2 3 3 3 3
array([[2, 3, 4, ..., 3, 2, 3], [2, 2, 3, ..., 2, 2, 3], [4, 2, 2, ..., 3, 3, 2], [3, 3, 2, ..., 3, 3, 3]])
- diverging(chain, draw)boolFalse False False ... False False
array([[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False]])
- perf_counter_diff(chain, draw)float640.0005474 0.0009994 ... 0.001682
array([[0.00054738, 0.00099937, 0.00185167, ..., 0.00099918, 0.00054206, 0.00099944], [0.00084672, 0.00108562, 0.00204043, ..., 0.00052271, 0.00051389, 0.0010299 ], [0.00201547, 0.00059255, 0.00093194, ..., 0.00101994, 0.00114038, 0.00054889], [0.00106005, 0.00102457, 0.00055308, ..., 0.00172803, 0.00169889, 0.00168203]])
- energy(chain, draw)float64256.5 256.4 256.8 ... 257.8 260.0
array([[256.53901717, 256.41152341, 256.81453713, ..., 259.09599776, 259.21878784, 258.71725041], [257.79794939, 259.30051126, 258.76059782, ..., 259.99556199, 259.3207446 , 257.320887 ], [259.358463 , 260.9352891 , 259.32214557, ..., 256.03515402, 259.49267325, 258.08906042], [257.4111622 , 258.72059479, 261.61061224, ..., 257.04211674, 257.7631332 , 260.04492158]])
- mean_tree_accept(chain, draw)float640.935 1.0 0.9774 ... 0.9993 0.4627
array([[0.93504528, 1. , 0.97735748, ..., 0.96845534, 0.81157022, 0.96331078], [1. , 0.89644474, 0.96742349, ..., 0.24129101, 1. , 0.99619178], [0.94804989, 0.40516219, 0.96041165, ..., 0.97067921, 0.84005685, 0.96174938], [0.85091472, 0.93431431, 0.75282966, ..., 0.91894568, 0.99930385, 0.46272935]])
- lp(chain, draw)float64-256.3 -256.3 ... -255.9 -258.6
array([[-256.30865316, -256.26391717, -256.53366773, ..., -257.10543575, -258.10563783, -256.1625731 ], [-257.6338782 , -258.21964153, -256.62948676, ..., -259.38536792, -257.10688257, -255.8261813 ], [-258.56978045, -258.56978045, -258.66976508, ..., -255.86326273, -257.55053431, -257.52311936], [-255.88563267, -257.0566285 , -259.34080164, ..., -256.73582417, -255.85048569, -258.61354926]])
- tree_size(chain, draw)float643.0 7.0 15.0 11.0 ... 7.0 7.0 7.0
array([[ 3., 7., 15., ..., 7., 3., 7.], [ 3., 3., 7., ..., 3., 3., 7.], [15., 3., 3., ..., 7., 7., 3.], [ 7., 7., 3., ..., 7., 7., 7.]])
- created_at :
- 2020-10-06T00:56:59.868643
- arviz_version :
- 0.10.0
- inference_library :
- pymc3
- inference_library_version :
- 3.9.3
- sampling_time :
- 3.375001907348633
- tuning_steps :
- 500
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: step_size_bar (chain, draw) float64 0.4844 0.4844 ... 0.5218 0.5218 step_size (chain, draw) float64 0.6579 0.6579 ... 0.5572 0.5572 energy_error (chain, draw) float64 0.1089 -0.01555 ... -0.1793 1.522 process_time_diff (chain, draw) float64 0.0005479 0.001 ... 0.001683 max_energy_error (chain, draw) float64 0.1089 -0.1235 ... -0.1793 2.335 perf_counter_start (chain, draw) float64 3.84e+04 3.84e+04 ... 3.84e+04 depth (chain, draw) int64 2 3 4 4 4 4 3 3 ... 3 3 3 2 3 3 3 3 diverging (chain, draw) bool False False False ... False False perf_counter_diff (chain, draw) float64 0.0005474 0.0009994 ... 0.001682 energy (chain, draw) float64 256.5 256.4 256.8 ... 257.8 260.0 mean_tree_accept (chain, draw) float64 0.935 1.0 0.9774 ... 0.9993 0.4627 lp (chain, draw) float64 -256.3 -256.3 ... -255.9 -258.6 tree_size (chain, draw) float64 3.0 7.0 15.0 11.0 ... 7.0 7.0 7.0 Attributes: created_at: 2020-10-06T00:56:59.868643 arviz_version: 0.10.0 inference_library: pymc3 inference_library_version: 3.9.3 sampling_time: 3.375001907348633 tuning_steps: 500
xarray.Dataset -
- time: 100
- time(time)int640 1 2 3 4 5 6 ... 94 95 96 97 98 99
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
- y(time)float64-1.412 -7.319 1.151 ... 48.52 46.03
array([-1.41202037, -7.3186587 , 1.15145745, 1.50159596, -0.66638434, 1.340505 , 0.94309618, -3.74339279, -0.34243761, 4.41332204, 3.52852595, -0.38735502, 6.92937569, 2.17738437, 2.29506712, 2.479826 , 4.32780469, 14.8236344 , 8.58376674, 13.86029246, 8.30748541, 9.42697384, 6.20130931, 12.74674786, 14.49628457, 12.37415809, 13.04994867, 12.27711609, 13.04606435, 15.07724923, 16.25148031, 10.18710661, 12.24013837, 13.51964002, 9.40025182, 20.11401051, 19.57963549, 17.26609676, 16.39293544, 15.4848222 , 19.07510716, 19.94747454, 18.05554676, 18.95703705, 21.97194249, 18.55533794, 21.38972486, 17.64270549, 22.38207915, 20.23227438, 22.3752402 , 22.69176278, 25.10069955, 22.29368553, 25.38288326, 27.6663142 , 26.60546597, 20.45069871, 27.89511126, 27.62673933, 25.45690863, 25.41379887, 28.68450485, 36.54353412, 25.65553597, 29.01507728, 30.97776362, 35.17952383, 31.07761309, 38.35764652, 32.82119153, 32.72583667, 34.15217468, 34.91746821, 39.98665155, 32.85510289, 35.76383771, 37.94409775, 37.02228539, 37.94068802, 42.21713708, 36.9942534 , 36.55582315, 36.82877733, 42.81862081, 37.00939662, 41.67250008, 37.28144053, 44.59191824, 44.05883374, 42.87139157, 47.30850894, 48.57268519, 46.25413295, 51.7918344 , 48.79686829, 51.28945751, 48.48599342, 48.52212075, 46.03052542])
- created_at :
- 2020-10-06T00:56:59.872278
- arviz_version :
- 0.10.0
- inference_library :
- pymc3
- inference_library_version :
- 3.9.3
<xarray.Dataset> Dimensions: (time: 100) Coordinates: * time (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99 Data variables: y (time) float64 -1.412 -7.319 1.151 1.502 ... 48.49 48.52 46.03 Attributes: created_at: 2020-10-06T00:56:59.872278 arviz_version: 0.10.0 inference_library: pymc3 inference_library_version: 3.9.3
<xarray.Dataset>
Dimensions:  (time: 100)
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
Data variables:
    x        (time) float64 0.0 0.5051 1.01 1.515 ... 48.48 48.99 49.49 50.0
Attributes:
    created_at:                 2020-10-06T00:56:59.872872
    arviz_version:              0.10.0
    inference_library:          pymc3
    inference_library_version:  3.9.3
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 500, time: 100)
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
Data variables:
    x        (time, chain, draw) float64 -2.022 -2.105 -2.0 ... -2.368 -2.183
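We will create a subclass of az.SamplingWrapper that implements the three methods needed here. The sample method refits the model on the given data. The get_inference_data method converts the result of a refit to InferenceData. The sel_observations method splits the data into the part used for refitting and the excluded observations.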
class PyMC3LinRegWrapper(az.SamplingWrapper):
    def sample(self, modified_observed_data):
        # Recompile the model on the subsetted data and refit it
        with self.model(*modified_observed_data) as linreg_model:
            idata = pm.sample(
                **self.sample_kwargs,
                return_inferencedata=True,
                idata_kwargs=self.idata_kwargs
            )
        return idata

    def get_inference_data(self, idata):
        # sample already returns InferenceData, so there is nothing to convert
        return idata

    def sel_observations(self, idx):
        xdata = self.idata_orig.constant_data["x"]
        ydata = self.idata_orig.observed_data["y"]
        # Boolean mask: True for the observations being excluded
        mask = np.isin(np.arange(len(xdata)), idx)
        data__i = [ary[~mask] for ary in (xdata, ydata)]  # data used for the refit
        data_ex = [ary[mask] for ary in (xdata, ydata)]   # held-out data
        return data__i, data_ex
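The masking in sel_observations works for a single index as well as for a list of indices, because np.isin checks every position against idx. The following NumPy-only sketch reproduces that logic with plain arrays in place of the DataArrays. The name split_observations is a hypothetical helper used only for illustration.

```python
import numpy as np


def split_observations(arrays, idx):
    """Hypothetical helper mirroring sel_observations: split each array into
    the part kept for refitting and the excluded (held-out) part."""
    n_obs = len(arrays[0])
    # True at the positions listed in idx (a single int or a list of ints)
    mask = np.isin(np.arange(n_obs), idx)
    data__i = [ary[~mask] for ary in arrays]  # used to refit the model
    data_ex = [ary[mask] for ary in arrays]   # used to evaluate the log likelihood
    return data__i, data_ex


xdata = np.linspace(0, 50, 100)
ydata = xdata - 2  # made-up values, only the shapes matter here

data__i, data_ex = split_observations((xdata, ydata), 13)
print(len(data__i[0]), len(data_ex[0]))  # 99 1
```

Passing a list such as [13, 42] holds out both observations at once.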
loo_orig = az.loo(idata, pointwise=True)
loo_orig
Computed from 2000 by 100 log-likelihood matrix

          Estimate      SE
elpd_loo   -250.78    7.13
p_loo         2.96       -
------

Pareto k diagnostic values:
                          Count    Pct.
(-Inf, 0.5]   (good)        100  100.0%
 (0.5, 0.7]   (ok)            0    0.0%
   (0.7, 1]   (bad)           0    0.0%
   (1, Inf)   (very bad)      0    0.0%
In this case, the Leave-One-Out Cross Validation (LOO-CV) approximation using Pareto Smoothed Importance Sampling (PSIS) works for all observations, since every Pareto k value is at most 0.5. We will therefore modify loo_orig so that az.reloo believes PSIS failed for some observations. This also serves as a validation of our wrapper, because the PSIS LOO-CV already returned the correct value.
loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])
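Only observations whose Pareto k exceeds a threshold are refitted, and az.reloo uses k_thresh=0.7 by default. The sketch below applies that rule to a stand-in pareto_k array. The 0.3 filler values are made up; according to the output above, the real values are all at most 0.5. The sketch also reproduces the "bad" and "very bad" counts that loo_orig reports after the modification.

```python
import numpy as np

# Stand-in for loo_orig.pareto_k: made-up "good" values for all 100
# observations, then the four values overwritten above.
pareto_k = np.full(100, 0.3)
pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])

# Observations above the threshold are the ones reloo refits
# (0.7 is assumed to match the default k_thresh of az.reloo).
k_thresh = 0.7
to_refit = np.flatnonzero(pareto_k > k_thresh)
print(to_refit)  # [13 42 56 73]

# Diagnostic categories used in the LOO summary
n_bad = int(np.sum((pareto_k > 0.7) & (pareto_k <= 1)))
n_very_bad = int(np.sum(pareto_k > 1))
print(n_bad, n_very_bad)  # 2 2
```

These are the four observations that the refitting log of az.reloo reports excluding.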
We initialize our sampling wrapper. Let’s stop and analyze each of the arguments.
We’d generally use model to pass a model object of some kind, already compiled and re-executable. However, as we saw before, we need to recompile the model every time we use it, so we pass the model-generating function instead. Close enough.
We then use the log_lik_fun and posterior_vars arguments to tell the wrapper how to call xr.apply_ufunc. log_lik_fun is the function to be called, and it receives the following positional arguments:
log_lik_fun(*data_ex, *[idata__i.posterior[var_name] for var_name in posterior_vars])
Here data_ex is the second element returned by sel_observations. idata__i is the InferenceData object returned by get_inference_data, which contains the fit on the subsetted data. We have generated data_ex as a list of DataArrays so that it plays nicely with this call signature.
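The following NumPy-only sketch makes the call signature concrete. Several parts are stand-ins:

- calculate_log_lik_sketch is a hypothetical Gaussian log-likelihood for this linear regression. It is not necessarily identical to the calculate_log_lik passed to the wrapper.
- The posterior is a dict of random draws rather than idata__i.posterior.
- Bare NumPy arrays carry no dimension names, so the broadcasting that the xarray-based call takes care of is written out by hand with [..., None].

The sketch also shows that the order of posterior_vars must match the order of the parameters of log_lik_fun that follow the data arguments.

```python
import numpy as np


def calculate_log_lik_sketch(x, y, b0, b1, sigma_e):
    """Hypothetical pointwise Gaussian log-likelihood of the linear regression.

    x, y: held-out data, shape (n_ex,); b0, b1, sigma_e: posterior draws,
    shape (chain, draw). Returns an array of shape (chain, draw, n_ex).
    """
    # Append an axis to the draws so they broadcast against the data
    mu = b0[..., None] + b1[..., None] * x
    sigma = sigma_e[..., None]
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


rng = np.random.default_rng(0)
# One excluded observation, laid out like data_ex: [x values, y values]
data_ex = [np.array([6.5]), np.array([4.0])]
# Made-up stand-in for idata__i.posterior: 4 chains x 500 draws per variable
posterior = {
    "b0": rng.normal(-2, 0.5, size=(4, 500)),
    "b1": rng.normal(1, 0.05, size=(4, 500)),
    "sigma_e": np.abs(rng.normal(3, 0.2, size=(4, 500))),
}
posterior_vars = ("b0", "b1", "sigma_e")

# Same positional-argument layout as the call signature described above
log_lik = calculate_log_lik_sketch(*data_ex, *[posterior[v] for v in posterior_vars])
print(log_lik.shape)  # (4, 500, 1)
```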
We use idata_orig as a starting point, and mostly as a source of observed and constant data, which is then subsetted in sel_observations.
Finally, sample_kwargs and idata_kwargs are used to make sure all refits and the corresponding InferenceData objects are generated with the same properties.
pymc3_wrapper = PyMC3LinRegWrapper(
    model=compile_linreg_model,
    log_lik_fun=calculate_log_lik,
    posterior_vars=("b0", "b1", "sigma_e"),
    idata_orig=idata,
    sample_kwargs=sample_kwargs,
    idata_kwargs=idata_kwargs,
)
We can now use this wrapper to call az.reloo and compare its results with the PSIS LOO-CV results.
loo_relooed = az.reloo(pymc3_wrapper, loo_orig=loo_orig)
/home/oriol/miniconda3/envs/arviz/lib/python3.8/site-packages/arviz/stats/stats_refitting.py:99: UserWarning: reloo is an experimental and untested feature
warnings.warn("reloo is an experimental and untested feature", UserWarning)
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 13
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma_e, b1, b0]
Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 2 seconds.
The acceptance probability does not match the target. It is 0.9084390959319811, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8833232031335186, but should be close to 0.8. Try to increase the number of tuning steps.
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 42
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma_e, b1, b0]
Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 2 seconds.
The acceptance probability does not match the target. It is 0.8788024509211416, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.900598064671754, but should be close to 0.8. Try to increase the number of tuning steps.
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 56
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma_e, b1, b0]
Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 2 seconds.
The acceptance probability does not match the target. It is 0.8949149672236311, but should be close to 0.8. Try to increase the number of tuning steps.
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 73
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma_e, b1, b0]
Sampling 4 chains for 500 tune and 500 draw iterations (2_000 + 2_000 draws total) took 2 seconds.
The acceptance probability does not match the target. It is 0.8797995668769552, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.882380854441132, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8936869082173754, but should be close to 0.8. Try to increase the number of tuning steps.
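For each flagged observation i, the log above shows reloo refitting the model without it. Conceptually, the refit posterior draws θ_s (s = 1, ..., S) give the exact LOO value elpd_i = log((1/S) Σ_s p(y_i | θ_s)), which replaces the PSIS estimate for that observation. The sketch below shows only that bookkeeping. The names reloo_sketch and refit_log_lik are hypothetical, and the refit itself is replaced by a callable that returns log-likelihood draws. This is not ArviZ's implementation.

```python
import numpy as np


def logmeanexp(a):
    """Numerically stable log(mean(exp(a))) over all elements of a."""
    a = np.ravel(a)
    a_max = a.max()
    return a_max + np.log(np.mean(np.exp(a - a_max)))


def reloo_sketch(elpd_i, pareto_k, refit_log_lik, k_thresh=0.7):
    """Hypothetical sketch of the reloo bookkeeping.

    refit_log_lik(idx) stands in for "refit without observation idx and return
    the log-likelihood draws of that observation" (shape (chain, draw)).
    """
    elpd_i = np.array(elpd_i, dtype=float)      # copies: inputs stay untouched
    pareto_k = np.array(pareto_k, dtype=float)
    for idx in np.flatnonzero(pareto_k > k_thresh):
        # Exact LOO: average the predictive density over the refit draws
        elpd_i[idx] = logmeanexp(refit_log_lik(int(idx)))
        # The value no longer comes from importance sampling
        pareto_k[idx] = 0.0
    return elpd_i, pareto_k


# Toy example: five observations, observation 2 flagged by PSIS
elpd_psis = np.array([-2.5, -2.4, -9.0, -2.6, -2.5])
k = np.array([0.1, 0.2, 1.5, 0.3, 0.1])
fake_refits = {2: np.full((4, 500), -2.45)}  # made-up refit output
elpd_new, k_new = reloo_sketch(elpd_psis, k, lambda i: fake_refits[i])
print(elpd_new.tolist())  # [-2.5, -2.4, -2.45, -2.6, -2.5]
```

Resetting the refitted Pareto k values is consistent with the reloo summary below, which reports all 100 observations as good.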
loo_relooed
Computed from 2000 by 100 log-likelihood matrix

          Estimate      SE
elpd_loo   -250.77    7.13
p_loo         2.95       -
------

Pareto k diagnostic values:
                          Count    Pct.
(-Inf, 0.5]   (good)        100  100.0%
 (0.5, 0.7]   (ok)            0    0.0%
   (0.7, 1]   (bad)           0    0.0%
   (1, Inf)   (very bad)      0    0.0%
loo_orig
Computed from 2000 by 100 log-likelihood matrix

          Estimate      SE
elpd_loo   -250.78    7.13
p_loo         2.96       -
------

Pareto k diagnostic values:
                          Count    Pct.
(-Inf, 0.5]   (good)         96   96.0%
 (0.5, 0.7]   (ok)            0    0.0%
   (0.7, 1]   (bad)           2    2.0%
   (1, Inf)   (very bad)      2    2.0%
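Both summaries report elpd_loo as the sum of the pointwise elpd values. As we understand ArviZ's loo implementation, the SE is computed as sqrt(n · Var(elpd_i)). Replacing four pointwise PSIS values with exact refit values therefore changes only those four terms of the sum. The totals here differ by just 0.01, which is consistent with the PSIS values already being accurate. A minimal sketch of the aggregation follows; summarize_pointwise_elpd is a hypothetical helper.

```python
import numpy as np


def summarize_pointwise_elpd(elpd_i):
    """Hypothetical helper: aggregate pointwise elpd values into the
    Estimate (sum) and SE (sqrt(n * var)) columns of the LOO summary."""
    elpd_i = np.asarray(elpd_i, dtype=float)
    n = elpd_i.size
    estimate = elpd_i.sum()
    se = np.sqrt(n * np.var(elpd_i))  # np.var uses ddof=0 by default
    return estimate, se


estimate, se = summarize_pointwise_elpd([-2.0, -3.0, -2.5, -2.5])
print(estimate, round(se, 4))  # -10.0 0.7071
```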