Working with InferenceData#
Here we present a collection of common manipulations you can use while working with InferenceData.
import arviz as az
import numpy as np
import xarray as xr
xr.set_options(display_expand_data=False, display_expand_attrs=False);
display_expand_data=False makes the default view for xarray.DataArray fold the data values to a single line. To explore the values, click on the icon on the left of the view, right under the xarray.DataArray text. It has no effect on Dataset objects, which already default to folded views.
display_expand_attrs=False folds the attributes in both DataArray and Dataset objects to keep the views shorter. On this page we print DataArrays and Datasets several times and they always have the same attributes.
idata = az.load_arviz_data("centered_eight")
idata
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: mu (chain, draw) float64 ... theta (chain, draw, school) float64 ... tau (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: (12/16) max_energy_error (chain, draw) float64 ... energy_error (chain, draw) float64 ... lp (chain, draw) float64 ... index_in_trajectory (chain, draw) int64 ... acceptance_rate (chain, draw) float64 ... diverging (chain, draw) bool ... ... ... smallest_eigval (chain, draw) float64 ... step_size_bar (chain, draw) float64 ... step_size (chain, draw) float64 ... energy (chain, draw) float64 ... tree_depth (chain, draw) int64 ... perf_counter_diff (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: tau (chain, draw) float64 ... theta (chain, draw, school) float64 ... mu (chain, draw) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: scores (school) float64 ... Attributes: (4)
Get the dataset corresponding to a single group#
post = idata.posterior
post
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: mu (chain, draw) float64 ... theta (chain, draw, school) float64 ... tau (chain, draw) float64 ... Attributes: (6)
Tip
You’ll have noticed we stored the posterior group in a new variable: post. Because .copy() was not called, idata.posterior and post now refer to the same object, so using either one is equivalent.
Use this to keep your code short yet easy to read. Store the groups you’ll need very often as separate variables to use explicitly, but don’t delete the InferenceData parent. You’ll need it for many ArviZ functions to work properly. For example: plot_pair() needs data from the sample_stats group to show divergences, compare() needs data from both the log_likelihood and posterior groups, and plot_loo_pit() needs not 2 but 3 groups: log_likelihood, posterior_predictive and posterior.
Add a new variable#
post["log_tau"] = np.log(post["tau"])
idata.posterior
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: mu (chain, draw) float64 ... theta (chain, draw, school) float64 ... tau (chain, draw) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461 log_tau (chain, draw) float64 1.553 1.363 1.578 ... 1.008 1.076 1.495 Attributes: (6)
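The posterior printed above gained log_tau even though the assignment was made on post. A minimal sketch of that aliasing behaviour, using a small synthetic xarray.Dataset rather than the centered_eight data (the names ds, alias and independent are illustrative assumptions, not part of the original example):

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for a posterior group: 2 chains, 10 draws of tau.
rng = np.random.default_rng(0)
ds = xr.Dataset({"tau": (("chain", "draw"), rng.uniform(0.5, 5.0, size=(2, 10)))})

alias = ds               # plain assignment: same object, no copy made
independent = ds.copy()  # a separate Dataset object

# Adding a variable through the alias also changes ds.
alias["log_tau"] = np.log(alias["tau"])

print("log_tau" in ds)           # True: ds and alias are the same object
print("log_tau" in independent)  # False: the copy was taken before the assignment
```

If you want to experiment without touching the original group, work on a copy instead.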
Combine chains and draws#
stacked = az.extract(idata)
stacked
<xarray.Dataset> Dimensions: (sample: 2000, school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' * sample (sample) object MultiIndex * chain (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3 * draw (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 Data variables: mu (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404 theta (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295 tau (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461 log_tau (sample) float64 1.553 1.363 1.578 0.6188 ... 1.008 1.076 1.495 Attributes: (6)
You can also use xarray.Dataset.stack() if you only want to combine the chain and draw dimensions. arviz.extract() is a convenience function aimed at taking care of the most common subsetting operations with MCMC samples. It can:
Combine chains and draws
Return a subset of variables (with optional filtering with regular expressions or string matching)
Return a subset of samples. By default the subset is drawn at random, to avoid getting non-representative samples due to bad mixing.
Access any group
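As a rough illustration of the xarray.Dataset.stack() alternative mentioned above, here is a sketch on a small synthetic dataset (the sizes, variable values and school labels are made up for the example):

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
ds = xr.Dataset(
    {
        "mu": (("chain", "draw"), rng.normal(size=(4, 5))),
        "theta": (("chain", "draw", "school"), rng.normal(size=(4, 5, 3))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(5), "school": ["a", "b", "c"]},
)

# Combine chain and draw into a single "sample" dimension.
stacked = ds.stack(sample=("chain", "draw"))

print(stacked.sizes["sample"])  # 20 (4 chains x 5 draws)
print(stacked["theta"].dims)    # ('school', 'sample')
```

The stacked dimension is placed last, which matches the (school, sample) layout of theta in the az.extract() output above.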
Get a random subset of the samples#
az.extract(idata, num_samples=100)
<xarray.Dataset> Dimensions: (sample: 100, school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' * sample (sample) object MultiIndex * chain (sample) int64 3 1 1 3 2 3 3 3 3 1 0 0 ... 2 0 2 2 3 2 2 3 2 0 3 2 * draw (sample) int64 214 202 176 487 371 27 ... 207 411 170 102 70 385 Data variables: mu (sample) float64 10.14 0.2007 12.31 0.4654 ... -5.88 12.93 4.57 theta (school, sample) float64 10.01 24.78 14.16 ... -4.128 14.04 3.417 tau (sample) float64 3.008 19.69 1.974 3.611 ... 2.457 1.675 2.081 log_tau (sample) float64 1.101 2.98 0.6802 1.284 ... 0.8991 0.5161 0.733 Attributes: (6)
Tip
Use a random seed to get the same subset from multiple groups: az.extract(idata, num_samples=100, rng=3) and az.extract(idata, group="log_likelihood", num_samples=100, rng=3) will return matching samples.
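The idea behind the seed can be sketched with NumPy alone. This is only an illustration of why a shared seed yields matching subsets; it is not a description of how ArviZ implements the subsetting internally, and the names idx_a and idx_b are assumptions made for the example:

```python
import numpy as np

n_samples = 2000  # 4 chains x 500 draws, as in the example above

# Two generators created from the same seed draw the same indices.
idx_a = np.random.default_rng(3).choice(n_samples, size=100, replace=False)
idx_b = np.random.default_rng(3).choice(n_samples, size=100, replace=False)

print(np.array_equal(idx_a, idx_b))  # True
```

Indexing two arrays that share the sample dimension with the same indices keeps their samples paired.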
Obtain a NumPy array for a given parameter#
Let’s say we want to get the values for mu as a NumPy array.
stacked.mu.values
array([7.87179637, 3.38455431, 9.10047569, ..., 1.76673325, 3.48611194,
3.40446391])
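For variables with more than one dimension, the axis order of the NumPy array follows the order of the named dimensions. A small sketch, assuming a synthetic theta array, showing how to fix that order explicitly before extracting the values:

```python
import numpy as np
import xarray as xr

# Synthetic array with 2 chains, 3 draws and 4 schools.
theta = xr.DataArray(
    np.arange(24.0).reshape(2, 3, 4),
    dims=("chain", "draw", "school"),
)

# Choose the axis order by name, then drop down to NumPy.
arr = theta.transpose("school", "chain", "draw").values

print(arr.shape)  # (4, 2, 3)
```

Transposing by name first avoids relying on whatever dimension order the variable happens to have.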
Get the dimension lengths#
Let’s check how many groups are in our hierarchical model.
len(idata.observed_data.school)
8
Get coordinate values#
What are the names of the groups in our hierarchical model? You can access them from the coordinate, named school in this case.
idata.observed_data.school
<xarray.DataArray 'school' (school: 8)> 'Choate' 'Deerfield' 'Phillips Andover' ... "St. Paul's" 'Mt. Hermon' Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
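Both pieces of information, dimension lengths and coordinate values, can also be read as plain Python objects. A minimal sketch on a synthetic dataset (the observed values here are placeholders, not the real data):

```python
import numpy as np
import xarray as xr

obs = xr.Dataset(
    {"obs": (("school",), np.array([1.0, 2.0, 3.0]))},  # placeholder values
    coords={"school": ["Choate", "Deerfield", "Phillips Andover"]},
)

# .sizes maps dimension names to lengths.
n_schools = obs.sizes["school"]

# Coordinate labels as a regular Python list.
school_names = obs["school"].values.tolist()

print(n_schools)     # 3
print(school_names)  # ['Choate', 'Deerfield', 'Phillips Andover']
```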
Get a subset of chains#
Let’s keep only chains 0 and 2 here. For the subset to take effect on all relevant InferenceData groups (posterior, sample_stats, log_likelihood, posterior_predictive), we will use arviz.InferenceData.sel(), the method of InferenceData, instead of xarray.Dataset.sel().
idata.sel(chain=[0, 2])
-
<xarray.Dataset> Dimensions: (chain: 2, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 2 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: mu (chain, draw) float64 7.872 3.385 9.1 7.304 ... 2.871 4.096 1.776 theta (chain, draw, school) float64 12.32 9.905 14.95 ... 2.363 -2.968 tau (chain, draw) float64 4.726 3.909 4.844 1.857 ... 4.09 2.72 1.917 log_tau (chain, draw) float64 1.553 1.363 1.578 ... 1.408 1.001 0.6508 Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 2, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 2 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 2, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 2 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 2, draw: 500) Coordinates: * chain (chain) int64 0 2 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: (12/16) max_energy_error (chain, draw) float64 ... energy_error (chain, draw) float64 ... lp (chain, draw) float64 ... index_in_trajectory (chain, draw) int64 ... acceptance_rate (chain, draw) float64 ... diverging (chain, draw) bool ... ... ... smallest_eigval (chain, draw) float64 ... step_size_bar (chain, draw) float64 ... step_size (chain, draw) float64 ... energy (chain, draw) float64 ... tree_depth (chain, draw) int64 ... perf_counter_diff (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: tau (chain, draw) float64 ... theta (chain, draw, school) float64 ... mu (chain, draw) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: scores (school) float64 ... Attributes: (4)
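Conceptually, the groupwise selection amounts to applying xarray.Dataset.sel() to every group that has the chain dimension and leaving the rest alone. The sketch below mimics that with a plain dictionary of datasets; the helper select_chains is hypothetical and is not ArviZ's implementation:

```python
import numpy as np
import xarray as xr

# Stand-ins for two InferenceData groups.
groups = {
    "posterior": xr.Dataset(
        {"mu": (("chain", "draw"), np.zeros((4, 5)))},
        coords={"chain": np.arange(4), "draw": np.arange(5)},
    ),
    "observed_data": xr.Dataset(
        {"obs": (("school",), np.zeros(3))},
        coords={"school": ["a", "b", "c"]},
    ),
}


def select_chains(datasets, chains):
    """Hypothetical helper: subset chains only where the dimension exists."""
    return {
        name: ds.sel(chain=chains) if "chain" in ds.dims else ds
        for name, ds in datasets.items()
    }


subset = select_chains(groups, [0, 2])
print(subset["posterior"]["chain"].values.tolist())  # [0, 2]
```

Using arviz.InferenceData.sel() saves you from writing this kind of loop by hand.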
Remove the first n draws (burn-in)#
Let’s say we want to remove the first 100 samples from all the chains and all InferenceData groups with draws.
idata.sel(draw=slice(100, None))
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 400, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: mu (chain, draw) float64 11.7 8.118 -5.88 -7.149 ... 1.767 3.486 3.404 theta (chain, draw, school) float64 14.23 9.72 9.195 ... 6.762 1.295 tau (chain, draw) float64 4.289 2.765 2.457 1.719 ... 2.741 2.932 4.461 log_tau (chain, draw) float64 1.456 1.017 0.8991 ... 1.008 1.076 1.495 Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 400, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 400, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 400) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 100 101 102 103 104 ... 496 497 498 499 Data variables: (12/16) max_energy_error (chain, draw) float64 ... energy_error (chain, draw) float64 ... lp (chain, draw) float64 ... index_in_trajectory (chain, draw) int64 ... acceptance_rate (chain, draw) float64 ... diverging (chain, draw) bool ... ... ... smallest_eigval (chain, draw) float64 ... step_size_bar (chain, draw) float64 ... step_size (chain, draw) float64 ... energy (chain, draw) float64 ... tree_depth (chain, draw) int64 ... perf_counter_diff (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 400, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: tau (chain, draw) float64 ... theta (chain, draw, school) float64 ... mu (chain, draw) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 400, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: scores (school) float64 ... Attributes: (4)
If you inspect the returned object you will see that the groups posterior, posterior_predictive, prior and sample_stats have 400 draws compared to idata, which has 500. The group observed_data has not been affected because it does not have the draw dimension. Alternatively, you can specify which group or groups you want to change.
idata.sel(draw=slice(100, None), groups="posterior")
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 400, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: mu (chain, draw) float64 11.7 8.118 -5.88 -7.149 ... 1.767 3.486 3.404 theta (chain, draw, school) float64 14.23 9.72 9.195 ... 6.762 1.295 tau (chain, draw) float64 4.289 2.765 2.457 1.719 ... 2.741 2.932 4.461 log_tau (chain, draw) float64 1.456 1.017 0.8991 ... 1.008 1.076 1.495 Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: (12/16) max_energy_error (chain, draw) float64 ... energy_error (chain, draw) float64 ... lp (chain, draw) float64 ... index_in_trajectory (chain, draw) int64 ... acceptance_rate (chain, draw) float64 ... diverging (chain, draw) bool ... ... ... smallest_eigval (chain, draw) float64 ... step_size_bar (chain, draw) float64 ... step_size (chain, draw) float64 ... energy (chain, draw) float64 ... tree_depth (chain, draw) int64 ... perf_counter_diff (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: tau (chain, draw) float64 ... theta (chain, draw, school) float64 ... mu (chain, draw) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: scores (school) float64 ... Attributes: (4)
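Note that sel() slices by coordinate label, and label-based slices include both endpoints, whereas isel() slices by position like NumPy. With slice(100, None) on draws labelled 0 to 499 both give the same result, but they differ as soon as an upper bound is involved. A small sketch on a synthetic dataset:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"mu": (("chain", "draw"), np.zeros((2, 10)))},
    coords={"chain": [0, 1], "draw": np.arange(10)},
)

by_label = ds.sel(draw=slice(3, 6))      # labels 3, 4, 5 and 6
by_position = ds.isel(draw=slice(3, 6))  # positions 3, 4 and 5

print(by_label.sizes["draw"])     # 4
print(by_position.sizes["draw"])  # 3
```

After removing burn-in, the draw labels start at 100 (as in the output above), so label-based and positional selection no longer coincide on the result.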
Compute posterior mean values along draw and chain dimensions#
To compute the mean value of the posterior samples, do the following:
post.mean()
<xarray.Dataset> Dimensions: () Data variables: mu float64 4.486 theta float64 4.912 tau float64 4.124 log_tau float64 1.173
This computes the mean along all dimensions. This is probably what you want for mu and tau, which have two dimensions (chain and draw), but maybe not what you expected for theta, which has one more dimension, school.
You can specify along which dimension you want to compute the mean (or other functions).
post.mean(dim=["chain", "draw"])
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: mu float64 4.486 theta (school) float64 6.46 5.028 3.938 4.872 3.667 3.975 6.581 4.772 tau float64 4.124 log_tau float64 1.173
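The difference between the two calls is easiest to see on data where the answer is known in advance. In this sketch, a synthetic theta takes the constant values 0, 10 and 20 for three made-up schools:

```python
import numpy as np
import xarray as xr

# Every chain and draw holds the same per-school values: 0, 10, 20.
theta = xr.DataArray(
    np.broadcast_to(np.array([0.0, 10.0, 20.0]), (4, 5, 3)).copy(),
    dims=("chain", "draw", "school"),
    coords={"school": ["a", "b", "c"]},
)

overall = theta.mean()                          # averages over schools too
per_school = theta.mean(dim=["chain", "draw"])  # keeps the school dimension

print(float(overall))              # 10.0
print(per_school.values.tolist())  # [0.0, 10.0, 20.0]
```

The same dim argument works for other reductions such as std() or median().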
Compute and store posterior pushforward quantities#
We use “posterior pushforward quantities” to refer to quantities that are not variables in the posterior but deterministic computations using posterior variables.
You can use xarray for these pushforward operations and store them as a new variable in the posterior group. You’ll then be able to plot them with ArviZ functions, calculate stats and diagnostics on them (like mcse()), or save and share the InferenceData object with the pushforward quantities included.
Compute the rolling mean of \(\log(\tau)\) with xarray.DataArray.rolling(), storing the result in the posterior:
post["mlogtau"] = post["log_tau"].rolling({"draw": 50}).mean()
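A rolling mean over a window of 50 draws leaves the first 49 entries of each chain as NaN, because the window is not yet full there. The sketch below shows the same effect on a tiny synthetic array with a window of 3, together with the min_periods argument, which relaxes that requirement:

```python
import numpy as np
import xarray as xr

draws = xr.DataArray(np.array([1.0, 2.0, 3.0, 4.0, 5.0]), dims="draw")

# Default: a value is produced only once the window holds 3 draws,
# so the first two entries are NaN.
strict = draws.rolling({"draw": 3}).mean()

# min_periods=1 averages whatever is available in the window.
relaxed = draws.rolling({"draw": 3}, min_periods=1).mean()

print(strict.values)
print(relaxed.values)
```

Keep these leading NaN values in mind when plotting or summarizing a rolling quantity.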
Using xarray for pushforward calculations has all the advantages of working with xarray. It also inherits the disadvantages of working with xarray, but we believe those to be outweighed by the advantages, and we have already shown how to extract the data as NumPy arrays. Working with InferenceData means working mainly with xarray objects, and this is what is shown in this guide.
Some examples of these advantages are specifying operations with named dimensions instead of positional ones (as seen in some previous sections), automatic alignment and broadcasting of arrays (as we’ll see now), or integration with Dask (as shown in the Dask for ArviZ guide).
In this cell you will compute pairwise differences between schools on their mean effects (variable theta).
To do so, subtract from the original variable theta a version of it whose school dimension has been renamed.
Xarray then aligns and broadcasts the two variables because they have different dimensions, and the result is a 4d variable with all the pairwise differences.
Finally, store the result in the theta_school_diff variable:
post["theta_school_diff"] = post.theta - post.theta.rename(school="school_bis")
Note
This same operation using NumPy would require manual alignment of the two arrays to make sure they broadcast correctly. The code would be something like:
theta_school_diff = theta[:, :, :, None] - theta[:, :, None, :]
The theta_school_diff variable in the posterior has kept the named dimensions and coordinates:
post
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8, school_bis: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' * school_bis (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' Data variables: mu (chain, draw) float64 7.872 3.385 9.1 ... 3.486 3.404 theta (chain, draw, school) float64 12.32 9.905 ... 6.762 1.295 tau (chain, draw) float64 4.726 3.909 4.844 ... 2.932 4.461 log_tau (chain, draw) float64 1.553 1.363 1.578 ... 1.076 1.495 mlogtau (chain, draw) float64 nan nan nan ... 1.494 1.496 1.511 theta_school_diff (chain, draw, school, school_bis) float64 0.0 ... 0.0 Attributes: (6)
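To check that the xarray subtraction and the manual NumPy broadcasting from the note agree, here is a sketch on a small synthetic theta (shapes and school labels are made up for the example):

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
theta_np = rng.normal(size=(2, 5, 3))  # (chain, draw, school)
theta = xr.DataArray(
    theta_np,
    dims=("chain", "draw", "school"),
    coords={"school": ["a", "b", "c"]},
)

# xarray: broadcasting is driven by dimension names.
diff_xr = theta - theta.rename(school="school_bis")

# NumPy: the new axes have to be inserted by hand.
diff_np = theta_np[:, :, :, None] - theta_np[:, :, None, :]

print(diff_xr.dims)  # ('chain', 'draw', 'school', 'school_bis')
print(np.allclose(diff_xr.values, diff_np))  # True
```

Each school's difference with itself is zero, which is why the repr above starts and ends with 0.0.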
Advanced subsetting#
To select the value corresponding to the difference between the Choate and Deerfield schools do:
post["theta_school_diff"].sel(school="Choate", school_bis="Deerfield")
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500)> 2.415 2.156 -0.04943 1.228 3.384 9.662 ... -1.656 -0.4021 1.524 -3.372 -6.305 Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 school <U16 'Choate' school_bis <U16 'Deerfield'
For more advanced subsetting (the equivalent to what is sometimes called “fancy indexing” in NumPy) you need to provide the indices as DataArray objects:
school_idx = xr.DataArray(["Choate", "Hotchkiss", "Mt. Hermon"], dims=["pairwise_school_diff"])
school_bis_idx = xr.DataArray(
    ["Deerfield", "Choate", "Lawrenceville"], dims=["pairwise_school_diff"]
)
post["theta_school_diff"].sel(school=school_idx, school_bis=school_bis_idx)
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500, pairwise_school_diff: 3)> 2.415 -6.741 -1.84 2.156 -3.474 3.784 ... -2.619 6.923 -6.305 1.667 -6.641 Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 school (pairwise_school_diff) <U16 'Choate' 'Hotchkiss' 'Mt. Hermon' school_bis (pairwise_school_diff) <U16 'Deerfield' 'Choate' 'Lawrenceville' Dimensions without coordinates: pairwise_school_diff
Using lists or NumPy arrays instead of DataArrays does column/row based indexing. As you can see, the result has 9 values of theta_school_diff instead of the 3 pairs of difference we selected in the previous cell:
post["theta_school_diff"].sel(
    school=["Choate", "Hotchkiss", "Mt. Hermon"],
    school_bis=["Deerfield", "Choate", "Lawrenceville"],
)
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500, school: 3, school_bis: 3)> 2.415 0.0 -4.581 -4.326 -6.741 -11.32 ... 1.667 -6.077 -5.203 1.102 -6.641 Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Hotchkiss' 'Mt. Hermon' * school_bis (school_bis) <U16 'Deerfield' 'Choate' 'Lawrenceville'
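The relationship between the two indexing modes can be verified on a small table. In this sketch (synthetic 3x3 data, illustrative labels and a made-up pair dimension), the pointwise selection equals the diagonal of the column/row based selection:

```python
import numpy as np
import xarray as xr

table = xr.DataArray(
    np.arange(9).reshape(3, 3),
    dims=("school", "school_bis"),
    coords={"school": ["a", "b", "c"], "school_bis": ["a", "b", "c"]},
)

# DataArray indexers sharing a dimension: one value per (school, school_bis) pair.
rows = xr.DataArray(["a", "c"], dims="pair")
cols = xr.DataArray(["b", "a"], dims="pair")
pointwise = table.sel(school=rows, school_bis=cols)

# Plain lists: every combination of the requested rows and columns.
outer = table.sel(school=["a", "c"], school_bis=["b", "a"])

print(pointwise.values.tolist())  # [1, 6]
print(outer.values.tolist())      # [[1, 0], [7, 6]]
```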
Add new chains using concat#
After checking the mcse() and realizing you need more samples, you rerun the model with two chains and obtain an idata_rerun object.
idata_rerun = (
    idata.sel(chain=[0, 1])
    .copy()
    .assign_coords(coords={"chain": [4, 5]}, groups="posterior_groups")
)
You can combine the two into a single InferenceData object using arviz.concat():
idata_complete = az.concat(idata, idata_rerun, dim="chain")
idata_complete.posterior.sizes["chain"]
6
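For a single xarray.Dataset, the analogous operation is xarray.concat(). The sketch below repeats the relabel-then-concatenate pattern on a synthetic dataset; it illustrates the xarray counterpart and is not meant to describe how arviz.concat() works internally:

```python
import numpy as np
import xarray as xr

first = xr.Dataset(
    {"mu": (("chain", "draw"), np.zeros((4, 5)))},
    coords={"chain": np.arange(4), "draw": np.arange(5)},
)

# Pretend chains 0 and 1 are a rerun, relabelled so the chain ids do not clash.
rerun = first.sel(chain=[0, 1]).assign_coords(chain=[4, 5])

combined = xr.concat([first, rerun], dim="chain")

print(combined.sizes["chain"])            # 6
print(combined["chain"].values.tolist())  # [0, 1, 2, 3, 4, 5]
```

Relabelling the new chains first keeps the chain coordinate unique after concatenation.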
Add groups to InferenceData objects#
You can also add new groups to InferenceData objects with extend() (if the new groups are already in an InferenceData object) or with add_groups() (if the new groups are dictionaries or xarray.Dataset objects).
rng = np.random.default_rng(3)
idata.add_groups(
    {"predictions": {"obs": rng.normal(size=(4, 500, 2))}},
    dims={"obs": ["new_school"]},
    coords={"new_school": ["Essex College", "Moordale"]},
)
idata
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8, school_bis: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' * school_bis (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' Data variables: mu (chain, draw) float64 7.872 3.385 9.1 ... 3.486 3.404 theta (chain, draw, school) float64 12.32 9.905 ... 6.762 1.295 tau (chain, draw) float64 4.726 3.909 4.844 ... 2.932 4.461 log_tau (chain, draw) float64 1.553 1.363 1.578 ... 1.076 1.495 mlogtau (chain, draw) float64 nan nan nan ... 1.494 1.496 1.511 theta_school_diff (chain, draw, school, school_bis) float64 0.0 ... 0.0 Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, new_school: 2) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 * new_school (new_school) <U13 'Essex College' 'Moordale' Data variables: obs (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822 Attributes: (2)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: (12/16) max_energy_error (chain, draw) float64 ... energy_error (chain, draw) float64 ... lp (chain, draw) float64 ... index_in_trajectory (chain, draw) int64 ... acceptance_rate (chain, draw) float64 ... diverging (chain, draw) bool ... ... ... smallest_eigval (chain, draw) float64 ... step_size_bar (chain, draw) float64 ... step_size (chain, draw) float64 ... energy (chain, draw) float64 ... tree_depth (chain, draw) int64 ... perf_counter_diff (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: tau (chain, draw) float64 ... theta (chain, draw, school) float64 ... mu (chain, draw) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: scores (school) float64 ... Attributes: (4)
Add Transformations to Multiple Groups#
You can also apply transformations to multiple InferenceData groups using arviz.InferenceData.map(). It takes a function as input, applies it groupwise to the selected InferenceData groups, and overwrites each group with the result of the function.
selected_groups = ["posterior", "prior"]
def calc_mean(dataset, *args, **kwargs):
    result = dataset.mean(dim="chain", *args, **kwargs)
    return result
means = idata.map(calc_mean, groups=selected_groups, inplace=False)
means
-
<xarray.Dataset> Dimensions: (draw: 500, school: 8, school_bis: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' * school_bis (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' Data variables: mu (draw) float64 5.974 5.096 7.177 ... 3.284 4.739 3.146 theta (draw, school) float64 9.519 5.554 6.118 ... 5.595 3.773 tau (draw) float64 4.068 3.156 3.603 ... 2.725 3.225 2.979 log_tau (draw) float64 1.322 1.118 1.234 ... 0.958 1.035 0.9508 mlogtau (draw) float64 nan nan nan nan ... 0.993 1.002 1.01 1.021 theta_school_diff (draw, school, school_bis) float64 0.0 3.965 ... 0.0
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, new_school: 2) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 * new_school (new_school) <U13 'Essex College' 'Moordale' Data variables: obs (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822 Attributes: (2)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: (12/16) max_energy_error (chain, draw) float64 ... energy_error (chain, draw) float64 ... lp (chain, draw) float64 ... index_in_trajectory (chain, draw) int64 ... acceptance_rate (chain, draw) float64 ... diverging (chain, draw) bool ... ... ... smallest_eigval (chain, draw) float64 ... step_size_bar (chain, draw) float64 ... step_size (chain, draw) float64 ... energy (chain, draw) float64 ... tree_depth (chain, draw) int64 ... perf_counter_diff (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (draw: 500, school: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: tau (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145 theta (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045 mu (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: scores (school) float64 ... Attributes: (4)
You can also pass a lambda function to map. Here it adds 3 to every variable in the posterior group:
idata_shifted_obs = idata.map(lambda x: x + 3, groups="posterior")
idata_shifted_obs
-
<xarray.Dataset> Dimensions: (draw: 500, school: 8, school_bis: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' * school_bis (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' Data variables: mu (draw) float64 8.974 8.096 10.18 ... 6.284 7.739 6.146 theta (draw, school) float64 12.52 8.554 9.118 ... 8.595 6.773 tau (draw) float64 7.068 6.156 6.603 ... 5.725 6.225 5.979 log_tau (draw) float64 4.322 4.118 4.234 ... 3.958 4.035 3.951 mlogtau (draw) float64 nan nan nan nan ... 3.993 4.002 4.01 4.021 theta_school_diff (draw, school, school_bis) float64 3.0 6.965 ... 3.0
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, new_school: 2) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 * new_school (new_school) <U13 'Essex College' 'Moordale' Data variables: obs (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822 Attributes: (2)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: (12/16) max_energy_error (chain, draw) float64 ... energy_error (chain, draw) float64 ... lp (chain, draw) float64 ... index_in_trajectory (chain, draw) int64 ... acceptance_rate (chain, draw) float64 ... diverging (chain, draw) bool ... ... ... smallest_eigval (chain, draw) float64 ... step_size_bar (chain, draw) float64 ... step_size (chain, draw) float64 ... energy (chain, draw) float64 ... tree_depth (chain, draw) int64 ... perf_counter_diff (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (draw: 500, school: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: tau (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145 theta (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045 mu (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: scores (school) float64 ... Attributes: (4)
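To make the effect of passing a lambda to map easier to check, here is a minimal sketch on a small synthetic object rather than on centered_eight. The names toy and toy_shifted and all the values are made up for illustration, and the sketch assumes an ArviZ version in which az.from_dict returns an InferenceData.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)

# Toy InferenceData with made-up values (an assumption for illustration,
# not the centered_eight data used on this page)
toy = az.from_dict(
    posterior={"mu": rng.normal(size=(2, 50))},  # 2 chains, 50 draws
    observed_data={"obs": np.array([28.0, 8.0, -3.0])},
)

# Apply the lambda to the posterior group only.
# By default map returns a new object and leaves `toy` untouched.
toy_shifted = toy.map(lambda ds: ds + 3, groups="posterior")

# Difference between shifted and original posterior draws
posterior_diff = (toy_shifted.posterior["mu"] - toy.posterior["mu"]).values
posterior_is_shifted = np.allclose(posterior_diff, 3.0)

# Groups that were not selected keep their original values
observed_unchanged = np.array_equal(
    toy_shifted.observed_data["obs"].values,
    toy.observed_data["obs"].values,
)
print(posterior_is_shifted, observed_unchanged)
```

Only the groups selected through groups are passed to the function; the remaining groups are copied over as they are.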
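You can also add extra coordinates using map. Keyword arguments that map does not use itself are forwarded to the function, so upper reaches assign_coords: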
_upper = np.array([
    x.upper() for x in idata.observed_data.school.values
]).T
idata_with_upper = idata.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    upper=("Upper", _upper),
)
idata_with_upper
-
<xarray.Dataset> Dimensions: (draw: 500, school: 8, school_bis: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' * school_bis (school_bis) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon' Data variables: mu (draw) float64 5.974 5.096 7.177 ... 3.284 4.739 3.146 theta (draw, school) float64 9.519 5.554 6.118 ... 5.595 3.773 tau (draw) float64 4.068 3.156 3.603 ... 2.725 3.225 2.979 log_tau (draw) float64 1.322 1.118 1.234 ... 0.958 1.035 0.9508 mlogtau (draw) float64 nan nan nan nan ... 0.993 1.002 1.01 1.021 theta_school_diff (draw, school, school_bis) float64 0.0 3.965 ... 0.0
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8, Upper: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' upper (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON' Dimensions without coordinates: Upper Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, new_school: 2) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 * new_school (new_school) <U13 'Essex College' 'Moordale' Data variables: obs (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822 Attributes: (2)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, school: 8) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: (12/16) max_energy_error (chain, draw) float64 ... energy_error (chain, draw) float64 ... lp (chain, draw) float64 ... index_in_trajectory (chain, draw) int64 ... acceptance_rate (chain, draw) float64 ... diverging (chain, draw) bool ... ... ... smallest_eigval (chain, draw) float64 ... step_size_bar (chain, draw) float64 ... step_size (chain, draw) float64 ... energy (chain, draw) float64 ... tree_depth (chain, draw) int64 ... perf_counter_diff (chain, draw) float64 ... Attributes: (6)
-
<xarray.Dataset> Dimensions: (draw: 500, school: 8) Coordinates: * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: tau (draw) float64 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145 theta (draw, school) float64 4.866 4.59 -0.7404 ... 3.33 -2.031 6.045 mu (draw) float64 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869
-
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, school: 8, Upper: 8) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499 * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' upper (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON' Dimensions without coordinates: Upper Data variables: obs (chain, draw, school) float64 ... Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8, Upper: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' upper (Upper) <U16 'CHOATE' 'DEERFIELD' ... "ST. PAUL'S" 'MT. HERMON' Dimensions without coordinates: Upper Data variables: obs (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0 Attributes: (4)
-
<xarray.Dataset> Dimensions: (school: 8) Coordinates: * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' Data variables: scores (school) float64 ... Attributes: (4)
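The forwarding of keyword arguments through map can be sketched on a small synthetic object. This is a variation on the cell above, not a copy of it: the uppercase labels are attached along the existing school dimension instead of a new Upper dimension. The names toy, schools, upper_labels and toy_with_upper and all the values are hypothetical, and the sketch assumes an ArviZ version in which az.from_dict returns an InferenceData.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Hotchkiss"]

# Toy InferenceData with made-up values (an assumption for illustration)
toy = az.from_dict(
    posterior={"mu": rng.normal(size=(2, 50))},
    posterior_predictive={"obs": rng.normal(size=(2, 50, 3))},
    observed_data={"obs": np.array([28.0, 8.0, -3.0])},
    coords={"school": schools},
    dims={"obs": ["school"]},
)

upper_labels = np.array([name.upper() for name in schools])

# `upper=...` is not a parameter of map, so it is forwarded to the lambda,
# which hands it on to Dataset.assign_coords.
# "observed_vars" selects the observed-variable groups present in `toy`.
toy_with_upper = toy.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    upper=("school", upper_labels),
)

print(toy_with_upper.observed_data.coords)
```

In this sketch the new coordinate appears in observed_data and posterior_predictive, while the posterior group, which is not part of "observed_vars", is left without it.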