arviz.InferenceData.add_groups

InferenceData.add_groups(group_dict=None, coords=None, dims=None, **kwargs)

Add new groups to an InferenceData object in place.

Parameters
group_dict : dict of {str : dict or xarray.Dataset}, optional

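Groups to be added, keyed by group name. Dict values are converted to xarray.Dataset objects, with the first two dimensions of each array taken as chain and draw.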

coords : dict of {str : array_like}, optional

Coordinates for the dataset, keyed by dimension name. Used when converting dict values to datasets.

dims : dict of {str : list of str}, optional

Dimensions of each variable, excluding chain and draw. The keys are variable names; the values are lists of dimension names.

kwargs : dict, optional

The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.
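A minimal sketch of the two calling conventions, assuming a small synthetic InferenceData built with az.from_dict; the names mu, y and obs are hypothetical. Each dict value is an array whose first two axes are chain and draw, dims names the remaining axes, and coords labels them. Both calls are expected to produce equal log_likelihood groups.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Hypothetical posterior: 2 chains x 50 draws of a scalar parameter "mu".
idata_a = az.from_dict(posterior={"mu": rng.normal(size=(2, 50))})
idata_b = idata_a.copy()

# Hypothetical pointwise log likelihood with shape (chain, draw, obs).
ll_values = rng.normal(size=(2, 50, 3))
dims = {"y": ["obs"]}              # names for the axes after (chain, draw)
coords = {"obs": ["a", "b", "c"]}  # labels for the "obs" dimension

# group_dict form: all new groups are passed in a single dict.
idata_a.add_groups({"log_likelihood": {"y": ll_values}}, coords=coords, dims=dims)

# kwargs form: each new group is passed as a keyword argument.
idata_b.add_groups(log_likelihood={"y": ll_values}, coords=coords, dims=dims)

print(idata_a.groups())
print(idata_a.log_likelihood["y"].dims)
```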

See also

extend

Extend InferenceData with groups from another InferenceData.

concat

Concatenate InferenceData objects.
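To contrast add_groups with extend, a sketch assuming synthetic data (the names mu, lp and y are hypothetical): add_groups takes arrays or an xarray.Dataset, while extend takes the groups of another InferenceData. Both are expected to modify the target object in place.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(1)

# Hypothetical target object with only a posterior group.
idata = az.from_dict(posterior={"mu": rng.normal(size=(2, 50))})

# add_groups: the new group is given as raw arrays (or an xarray.Dataset).
idata.add_groups(sample_stats={"lp": rng.normal(size=(2, 50))})

# extend: the new group comes from another InferenceData object.
other = az.from_dict(log_likelihood={"y": rng.normal(size=(2, 50, 3))})
idata.extend(other)

print(idata.groups())
```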

Examples

Add a log_likelihood group to the “rugby” example InferenceData, which does not include one when loaded. A copy, idata2, is kept for the second example:

import arviz as az
idata = az.load_arviz_data("rugby")
idata2 = idata.copy()
post = idata.posterior
obs = idata.observed_data
idata
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 ...
          intercept  (chain, draw) float64 ...
          atts_star  (chain, draw, team) float64 ...
          defs_star  (chain, draw, team) float64 ...
          sd_att     (chain, draw) float64 ...
          sd_def     (chain, draw) float64 ...
          atts       (chain, draw, team) float64 ...
          defs       (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 ...
          energy            (chain, draw) float64 ...
          tree_size         (chain, draw) float64 ...
          tune              (chain, draw) bool ...
          mean_tree_accept  (chain, draw) float64 ...
          lp                (chain, draw) float64 ...
          depth             (chain, draw) int64 ...
          max_energy_error  (chain, draw) float64 ...
          step_size         (chain, draw) float64 ...
          step_size_bar     (chain, draw) float64 ...
          diverging         (chain, draw) bool ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior
      <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 ...
          intercept     (chain, draw) float64 ...
          atts_star     (chain, draw, team) float64 ...
          defs_star     (chain, draw, team) float64 ...
          away_points   (chain, draw, match) int64 ...
          sd_att        (chain, draw) float64 ...
          sd_def_log__  (chain, draw) float64 ...
          home          (chain, draw) float64 ...
          atts          (chain, draw, team) float64 ...
          sd_def        (chain, draw) float64 ...
          home_points   (chain, draw, match) int64 ...
          defs          (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data
      <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (match) float64 ...
          away_points  (match) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7

Knowing the model, we could compute the pointwise log likelihood manually. Here, however, we simply generate random values with the right shape, (chain, draw, match):

import numpy as np
rng = np.random.default_rng(73)
ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
idata.add_groups(
    log_likelihood={"home_points": ary},
    dims={"home_points": ["match"]},
)
idata
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 ...
          intercept  (chain, draw) float64 ...
          atts_star  (chain, draw, team) float64 ...
          defs_star  (chain, draw, team) float64 ...
          sd_att     (chain, draw) float64 ...
          sd_def     (chain, draw) float64 ...
          atts       (chain, draw, team) float64 ...
          defs       (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • log_likelihood
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) int64 0 1 2 3 4 5 6 7 8 ... 51 52 53 54 55 56 57 58 59
      Data variables:
          home_points  (chain, draw, match) float64 -1.093 0.7781 ... 0.2405 1.643
      Attributes:
          created_at:     2022-11-16T10:14:21.758528
          arviz_version:  0.14.0

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 ...
          energy            (chain, draw) float64 ...
          tree_size         (chain, draw) float64 ...
          tune              (chain, draw) bool ...
          mean_tree_accept  (chain, draw) float64 ...
          lp                (chain, draw) float64 ...
          depth             (chain, draw) int64 ...
          max_energy_error  (chain, draw) float64 ...
          step_size         (chain, draw) float64 ...
          step_size_bar     (chain, draw) float64 ...
          diverging         (chain, draw) bool ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior
      <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 ...
          intercept     (chain, draw) float64 ...
          atts_star     (chain, draw, team) float64 ...
          defs_star     (chain, draw, team) float64 ...
          away_points   (chain, draw, match) int64 ...
          sd_att        (chain, draw) float64 ...
          sd_def_log__  (chain, draw) float64 ...
          home          (chain, draw) float64 ...
          atts          (chain, draw, team) float64 ...
          sd_def        (chain, draw) float64 ...
          home_points   (chain, draw, match) int64 ...
          defs          (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data
      <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (match) float64 ...
          away_points  (match) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7
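For comparison, a sketch of the manual computation mentioned above, passing coords as well as dims so the match labels are kept. It assumes a Poisson likelihood on synthetic stand-ins (hypothetical log-rates, counts, match labels and variable name y), not the actual rugby model.

```python
import numpy as np
import arviz as az
from scipy.stats import poisson

rng = np.random.default_rng(73)
n_chain, n_draw = 4, 500
match_labels = ["Match A", "Match B", "Match C"]  # hypothetical labels

# Hypothetical posterior log-rates per match, shape (chain, draw, match).
log_rate = rng.normal(1.0, 0.1, size=(n_chain, n_draw, len(match_labels)))
# Hypothetical observed counts, shape (match,).
observed = np.array([2, 0, 3])

# Pointwise log likelihood: (match,) broadcasts against (chain, draw, match),
# giving one value per chain, draw and match.
ll = poisson.logpmf(observed, np.exp(log_rate))

idata = az.from_dict(
    posterior={"log_rate": log_rate},
    coords={"match": match_labels},
    dims={"log_rate": ["match"]},
)
# Passing coords keeps the labels instead of default integer coordinates.
idata.add_groups(
    log_likelihood={"y": ll},
    dims={"y": ["match"]},
    coords={"match": match_labels},
)
print(idata.log_likelihood["match"].values)
```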

Passing raw arrays works, but it is inconvenient when starting from labeled data: in the first example, match ended up with integer coordinates because no coords were passed, and supplying dims and coords by hand repeats information the posterior and observed_data groups already hold. Instead, let’s build a fake log likelihood directly from those labeled groups (it does not match the model, but it serves to illustrate the API):

import xarray as xr
from xarray_einstats.stats import XrDiscreteRV
from scipy.stats import poisson
dist = XrDiscreteRV(poisson)
log_lik = xr.Dataset()
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
idata2.add_groups({"log_likelihood": log_lik})
idata2
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 ...
          intercept  (chain, draw) float64 ...
          atts_star  (chain, draw, team) float64 ...
          defs_star  (chain, draw, team) float64 ...
          sd_att     (chain, draw) float64 ...
          sd_def     (chain, draw) float64 ...
          atts       (chain, draw, team) float64 ...
          defs       (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • log_likelihood
      <xarray.Dataset>
      Dimensions:      (match: 60, chain: 4, draw: 500, team: 6)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team         (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home_points  (match, chain, draw, team) float64 -50.27 -53.03 ... -21.14

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 ...
          energy            (chain, draw) float64 ...
          tree_size         (chain, draw) float64 ...
          tune              (chain, draw) bool ...
          mean_tree_accept  (chain, draw) float64 ...
          lp                (chain, draw) float64 ...
          depth             (chain, draw) int64 ...
          max_energy_error  (chain, draw) float64 ...
          step_size         (chain, draw) float64 ...
          step_size_bar     (chain, draw) float64 ...
          diverging         (chain, draw) bool ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior
      <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 ...
          intercept     (chain, draw) float64 ...
          atts_star     (chain, draw, team) float64 ...
          defs_star     (chain, draw, team) float64 ...
          away_points   (chain, draw, match) int64 ...
          sd_att        (chain, draw) float64 ...
          sd_def_log__  (chain, draw) float64 ...
          home          (chain, draw) float64 ...
          atts          (chain, draw, team) float64 ...
          sd_def        (chain, draw) float64 ...
          home_points   (chain, draw, match) int64 ...
          defs          (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data
      <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (match) float64 ...
          away_points  (match) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7

Note that the first example passes the new group through kwargs (log_likelihood=...), while the second passes it through group_dict.
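In the second example, home_points ended up with dimensions (match, chain, draw, team), because broadcasting places the observed data's dimension first. ArviZ's data conventions place chain and draw first, so reordering before adding the group may be worthwhile. The sketch below swaps in plain xarray.apply_ufunc with scipy.stats.poisson in place of xarray_einstats, so only xarray and SciPy are needed; the inputs are synthetic stand-ins with hypothetical values and labels.

```python
import numpy as np
import xarray as xr
import arviz as az
from scipy.stats import poisson

rng = np.random.default_rng(0)
teams = ["Wales", "France"]  # hypothetical subset of team labels

# Hypothetical labeled inputs mimicking the posterior and observed_data groups.
atts = xr.DataArray(
    rng.normal(0.0, 0.3, size=(4, 100, len(teams))),
    dims=("chain", "draw", "team"),
    coords={"chain": np.arange(4), "draw": np.arange(100), "team": teams},
)
home_points = xr.DataArray(
    np.array([3.0, 1.0, 0.0]),
    dims=("match",),
    coords={"match": ["m1", "m2", "m3"]},
)

# apply_ufunc broadcasts by dimension name; with match appearing first in
# the inputs, the result has dims (match, chain, draw, team).
log_lik = xr.Dataset()
log_lik["home_points"] = xr.apply_ufunc(poisson.logpmf, home_points, np.exp(atts))

# Reorder so chain and draw come first, keeping the other dims in order.
log_lik = log_lik.transpose("chain", "draw", ...)

idata = az.from_dict(
    posterior={"atts": atts.values},
    coords={"team": teams},
    dims={"atts": ["team"]},
)
idata.add_groups({"log_likelihood": log_lik})
print(idata.log_likelihood["home_points"].dims)
```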