Working with InferenceData

Here we present a collection of common manipulations you can use while working with InferenceData.

import arviz as az
idata = az.load_arviz_data("centered_eight")
idata
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 -3.477 -2.456 -2.826 ... 4.597 5.899 0.1614
          theta    (chain, draw, school) float64 1.669 -8.537 -2.623 ... 10.59 4.523
          tau      (chain, draw) float64 3.73 2.075 3.703 4.146 ... 8.346 7.711 5.407
      Attributes:
          created_at:                 2019-06-21T17:36:34.398087
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (chain, draw, school) float64 7.85 -19.03 -22.5 ... 4.698 -15.07
      Attributes:
          created_at:                 2019-06-21T17:36:34.489022
          inference_library:          pymc3
          inference_library_version:  3.7

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * school            (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          tune              (chain, draw) bool True False False ... False False False
          depth             (chain, draw) int64 5 3 3 4 5 5 4 4 5 ... 4 4 4 5 5 5 5 5
          tree_size         (chain, draw) float64 31.0 7.0 7.0 15.0 ... 31.0 31.0 31.0
          lp                (chain, draw) float64 -59.05 -56.19 ... -63.62 -58.35
          energy_error      (chain, draw) float64 0.07387 -0.1841 ... -0.087 -0.003652
          step_size_bar     (chain, draw) float64 0.2417 0.2417 ... 0.1502 0.1502
          max_energy_error  (chain, draw) float64 0.131 -0.2067 ... -0.101 -0.1757
          energy            (chain, draw) float64 60.76 62.76 64.4 ... 67.77 67.21
          mean_tree_accept  (chain, draw) float64 0.9506 0.9906 ... 0.9875 0.9967
          step_size         (chain, draw) float64 0.1275 0.1275 ... 0.1064 0.1064
          diverging         (chain, draw) bool False False False ... False False False
          log_likelihood    (chain, draw, school) float64 -5.168 -4.589 ... -3.896
      Attributes:
          created_at:                 2019-06-21T17:36:34.485802
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior
      <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 500, school: 8)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          tau        (chain, draw) float64 6.561 1.016 68.91 ... 1.56 5.949 0.7631
          tau_log__  (chain, draw) float64 1.881 0.01593 4.233 ... 1.783 -0.2704
          mu         (chain, draw) float64 5.293 0.8137 0.7122 ... -1.658 -3.273
          theta      (chain, draw, school) float64 2.357 7.371 7.251 ... -3.775 -3.555
          obs        (chain, draw, school) float64 -3.54 6.769 19.68 ... -21.16 -6.071
      Attributes:
          created_at:                 2019-06-21T17:36:34.490387
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data
      <xarray.Dataset>
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:                 2019-06-21T17:36:34.491909
          inference_library:          pymc3
          inference_library_version:  3.7

idata.posterior
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 -3.477 -2.456 -2.826 ... 4.597 5.899 0.1614
    theta    (chain, draw, school) float64 1.669 -8.537 -2.623 ... 10.59 4.523
    tau      (chain, draw) float64 3.73 2.075 3.703 4.146 ... 8.346 7.711 5.407
Attributes:
    created_at:                 2019-06-21T17:36:34.398087
    inference_library:          pymc3
    inference_library_version:  3.7

Combine chains and draws

stacked = idata.posterior.stack(draws=("chain", "draw"))
stacked
<xarray.Dataset>
Dimensions:  (draws: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * draws    (draws) MultiIndex
  - chain    (draws) int64 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  - draw     (draws) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (draws) float64 -3.477 -2.456 -2.826 -1.996 ... 4.597 5.899 0.1614
    theta    (school, draws) float64 1.669 -6.239 2.195 ... -1.095 4.013 4.523
    tau      (draws) float64 3.73 2.075 3.703 4.146 ... 8.589 8.346 7.711 5.407
Attributes:
    created_at:                 2019-06-21T17:36:34.398087
    inference_library:          pymc3
    inference_library_version:  3.7
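
To make the effect of stack concrete, the sketch below builds a small synthetic dataset with the same dimension names as idata.posterior and checks that stacking only reshapes the samples. The random values, the integer school labels, and the names mu, theta, n_samples and same_order are assumptions for illustration, not the eight schools data.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.posterior: 4 chains, 500 draws, 8 schools.
rng = np.random.default_rng(0)
mu = rng.normal(size=(4, 500))
theta = rng.normal(size=(4, 500, 8))
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), mu),
        "theta": (("chain", "draw", "school"), theta),
    },
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": np.arange(8)},
)

# Collapse chain and draw into a single "draws" dimension.
stacked = posterior.stack(draws=("chain", "draw"))

# The stacked dimension has chain * draw entries.
n_samples = stacked.sizes["draws"]

# Stacking is a reshape: all draws of chain 0 come first, then chain 1, and so on.
same_order = np.array_equal(stacked["mu"].values, mu.reshape(-1))
```

Note that the stacked dimension is placed last, which is why theta ends up with dimensions (school, draws) in the output above.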

Obtain a NumPy array for a given parameter

Let’s say we want to get the values for mu as a NumPy array.

stacked.mu.values
array([-3.47698606, -2.45587061, -2.82625433, ...,  4.59705819,
        5.89850592,  0.16138927])
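
For a variable with extra dimensions, such as theta, the axis order of the resulting array follows the dimension order of the stacked variable. A minimal sketch, using a small synthetic dataset as a stand-in for idata.posterior (the sizes, values and the name theta_samples are assumptions), that transposes before calling .values so the array layout is explicit:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.posterior with 2 chains, 50 draws, 8 schools.
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {"theta": (("chain", "draw", "school"), rng.normal(size=(2, 50, 8)))},
    coords={"chain": [0, 1], "draw": np.arange(50), "school": np.arange(8)},
)
stacked = posterior.stack(draws=("chain", "draw"))

# Put samples on the first axis and schools on the second before converting.
theta_samples = stacked["theta"].transpose("draws", "school").values
```

Here theta_samples is a plain NumPy array with one row per sample and one column per school.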

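Get the length of a dimension

Let’s check how many schools (the groups of our hierarchical model, not to be confused with InferenceData groups) there are by taking the length of the school coordinate.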

len(idata.observed_data.school)
8
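
The lengths of all dimensions are also available at once through the sizes mapping of an xarray object. A small sketch on a synthetic dataset (the three-school subset and the names observed, n_schools and all_sizes are illustrative assumptions):

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.observed_data with only three schools.
schools = ["Choate", "Deerfield", "Phillips Andover"]
observed = xr.Dataset(
    {"obs": (("school",), np.array([28.0, 8.0, -3.0]))},
    coords={"school": schools},
)

# Length of one dimension; equivalent to len(observed.school).
n_schools = observed.sizes["school"]

# Mapping from every dimension name to its length.
all_sizes = dict(observed.sizes)
```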

Get the coordinate values

What are the names of the schools, the groups in our hierarchical model?

idata.observed_data.school.values
array(['Choate', 'Deerfield', 'Phillips Andover', 'Phillips Exeter',
       'Hotchkiss', 'Lawrenceville', "St. Paul's", 'Mt. Hermon'],
      dtype=object)
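
These coordinate values are also the labels used for selection. The sketch below, on a synthetic dataset with three schools (the values and the names names and theta_choate are assumptions), lists the labels and then picks the samples of a single school by name:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.posterior with named schools.
schools = ["Choate", "Deerfield", "Phillips Andover"]
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {"theta": (("chain", "draw", "school"), rng.normal(size=(2, 10, 3)))},
    coords={"chain": [0, 1], "draw": np.arange(10), "school": schools},
)

# Coordinate values as a plain Python list.
names = posterior.school.values.tolist()

# Select one school by its label; the school dimension is dropped.
theta_choate = posterior["theta"].sel(school="Choate")
```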

Get a subset of chains

Let’s keep only chains 0 and 2.

idata.sel(chain=[0, 2]).posterior
<xarray.Dataset>
Dimensions:  (chain: 2, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0 2
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 -3.477 -2.456 -2.826 ... -1.571 -4.435 9.763
    theta    (chain, draw, school) float64 1.669 -8.537 -2.623 ... 12.01 16.67
    tau      (chain, draw) float64 3.73 2.075 3.703 4.146 ... 2.812 12.18 4.453
Attributes:
    created_at:                 2019-06-21T17:36:34.398087
    inference_library:          pymc3
    inference_library_version:  3.7
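
Since each group is an xarray Dataset, the same label-based selection can also be applied to a single group directly. A minimal sketch on a synthetic dataset (sizes, values and the names kept and dropped are assumptions), which also shows the complementary drop_sel call:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.posterior with 4 chains and 20 draws.
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {"mu": (("chain", "draw"), rng.normal(size=(4, 20)))},
    coords={"chain": np.arange(4), "draw": np.arange(20)},
)

# Keep chains 0 and 2 by label.
kept = posterior.sel(chain=[0, 2])

# Equivalent formulation: drop chains 1 and 3 by label.
dropped = posterior.drop_sel(chain=[1, 3])
```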

Remove the first n draws (burn-in)

Let’s say we want to remove the first 100 draws from every chain, in all InferenceData groups that have a draw dimension.

burnin = idata.sel(draw=slice(100, None))

If you inspect the burnin object, you will see that the groups posterior, posterior_predictive, prior and sample_stats now have 400 draws, compared to 500 in idata. The observed_data group is unaffected because it has no draw dimension. Alternatively, you can restrict the selection to one or more groups with the groups argument.

burnin_posterior = idata.sel(draw=slice(100, None), groups="posterior")
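
Note that sel works on coordinate labels, not on positions. The sketch below uses a synthetic xarray Dataset as a stand-in for a single group (the values and variable names are assumptions) to contrast sel with the positional isel. The two agree on untouched data, where the draw labels run from 0 upward, but differ once draws have already been removed.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.posterior with draw labels 0..499.
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {"mu": (("chain", "draw"), rng.normal(size=(4, 500)))},
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)

# Label-based: keep draws whose label is 100 or larger.
by_label = posterior.sel(draw=slice(100, None))

# Position-based: skip the first 100 entries along draw.
by_position = posterior.isel(draw=slice(100, None))

# Applying each selection a second time to the result:
again_label = by_label.sel(draw=slice(100, None))      # labels already start at 100
again_position = by_label.isel(draw=slice(100, None))  # removes 100 more draws
```

The label-based version is therefore safe to apply twice, while the positional one discards additional draws each time.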

Compute posterior mean values along the chain and draw dimensions

If you want to compute the mean value of the posterior samples, you can simply do the following:

idata.posterior.mean()
<xarray.Dataset>
Dimensions:  ()
Data variables:
    mu       float64 4.093
    theta    float64 4.56
    tau      float64 4.089

This computes the mean over all dimensions. That is probably what you want for mu and tau, which have only two dimensions (chain and draw), but probably not for theta, which has an additional school dimension: its per-school means are collapsed into a single number. You can instead specify the dimensions along which to compute the mean (or any other reduction).

idata.posterior.mean(dim=['chain', 'draw'])
<xarray.Dataset>
Dimensions:  (school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    mu       float64 4.093
    theta    (school) float64 6.026 4.724 3.576 4.478 3.064 3.821 6.25 4.544
    tau      float64 4.089
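
The dim argument works the same way for other reductions. A minimal sketch on a synthetic dataset (the values, sizes and the names sample_dims, theta_mean, theta_sd and theta_interval are assumptions for illustration) computing per-school means, standard deviations and a 5% to 95% quantile interval:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.posterior: 4 chains, 100 draws, 8 schools.
rng = np.random.default_rng(0)
theta = rng.normal(size=(4, 100, 8))
posterior = xr.Dataset(
    {"theta": (("chain", "draw", "school"), theta)},
    coords={"chain": np.arange(4), "draw": np.arange(100), "school": np.arange(8)},
)

# Reduce over the sampling dimensions only, keeping one value per school.
sample_dims = ["chain", "draw"]
theta_mean = posterior["theta"].mean(dim=sample_dims)
theta_sd = posterior["theta"].std(dim=sample_dims)

# Quantiles add a leading "quantile" dimension to the result.
theta_interval = posterior["theta"].quantile([0.05, 0.95], dim=sample_dims)
```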