arviz.InferenceData.add_groups

InferenceData.add_groups(group_dict=None, coords=None, dims=None, **kwargs)

Add new groups to InferenceData object.

Parameters:
group_dict : dict of {str : dict or xarray.Dataset}, optional

Groups to be added

coords : dict of {str : array_like}, optional

Coordinates for the dataset

dims : dict of {str : list of str}, optional

Dimensions of each variable. The keys are variable names, the values are lists of dimension names (in the examples below, chain and draw are not listed and are added automatically).

kwargs : dict, optional

The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.
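
A minimal sketch of the argument rules above, using a small synthetic InferenceData built with az.from_dict (the variable name mu and the array shapes are made up for illustration). It assumes, as the parameter description suggests, that calling add_groups with neither group_dict nor kwargs raises a ValueError. It also assumes that adding a group that already exists is rejected the same way, which is worth confirming against your installed ArviZ version.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(3)
# Synthetic data: 2 chains, 50 draws of a hypothetical variable "mu"
idata_small = az.from_dict(posterior={"mu": rng.normal(size=(2, 50))})

errors = []

# Neither group_dict nor kwargs provided
try:
    idata_small.add_groups()
except ValueError as err:
    errors.append(str(err))

# Group name that is already present (assumed to be rejected)
try:
    idata_small.add_groups(posterior={"mu": rng.normal(size=(2, 50))})
except ValueError as err:
    errors.append(str(err))

for message in errors:
    print(message)
```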

See also

extend

Extend InferenceData with groups from another InferenceData.

concat

Concatenate InferenceData objects.

Examples

Add a log_likelihood group to the “rugby” example InferenceData after loading.

import arviz as az
idata = az.load_arviz_data("rugby")
del idata.log_likelihood
idata2 = idata.copy()
post = idata.posterior
obs = idata.observed_data
idata
arviz.InferenceData
    • <xarray.Dataset> Size: 452kB
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 16kB ...
          intercept  (chain, draw) float64 16kB ...
          atts_star  (chain, draw, team) float64 96kB ...
          defs_star  (chain, draw, team) float64 96kB ...
          sd_att     (chain, draw) float64 16kB ...
          sd_def     (chain, draw) float64 16kB ...
          atts       (chain, draw, team) float64 96kB ...
          defs       (chain, draw, team) float64 96kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:23.841916
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9
          sampling_time:              8.503105401992798
          tuning_steps:               1000

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          home_points  (chain, draw, match) int64 960kB ...
          away_points  (chain, draw, match) int64 960kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:25.689246
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 260kB
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 16kB ...
          sd_att     (chain, draw) float64 16kB ...
          sd_def     (chain, draw) float64 16kB ...
          intercept  (chain, draw) float64 16kB ...
          atts_star  (chain, draw, team) float64 96kB ...
          defs_star  (chain, draw, team) float64 96kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:24.377610
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 248kB
      Dimensions:                (chain: 4, draw: 500)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
      Data variables: (12/17)
          max_energy_error       (chain, draw) float64 16kB ...
          index_in_trajectory    (chain, draw) int64 16kB ...
          smallest_eigval        (chain, draw) float64 16kB ...
          perf_counter_start     (chain, draw) float64 16kB ...
          largest_eigval         (chain, draw) float64 16kB ...
          step_size              (chain, draw) float64 16kB ...
          ...                     ...
          reached_max_treedepth  (chain, draw) bool 2kB ...
          perf_counter_diff      (chain, draw) float64 16kB ...
          tree_depth             (chain, draw) int64 16kB ...
          process_time_diff      (chain, draw) float64 16kB ...
          step_size_bar          (chain, draw) float64 16kB ...
          energy                 (chain, draw) float64 16kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:23.854033
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9
          sampling_time:              8.503105401992798
          tuning_steps:               1000

    • <xarray.Dataset> Size: 116kB
      Dimensions:    (chain: 1, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 8B 0
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          atts_star  (chain, draw, team) float64 24kB ...
          sd_att     (chain, draw) float64 4kB ...
          atts       (chain, draw, team) float64 24kB ...
          sd_def     (chain, draw) float64 4kB ...
          defs       (chain, draw, team) float64 24kB ...
          intercept  (chain, draw) float64 4kB ...
          home       (chain, draw) float64 4kB ...
          defs_star  (chain, draw, team) float64 24kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.475945
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 492kB
      Dimensions:      (chain: 1, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 8B 0
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          away_points  (chain, draw, match) int64 240kB ...
          home_points  (chain, draw, match) int64 240kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.479330
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 9kB
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          home_points  (match) int64 480B ...
          away_points  (match) int64 480B ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.480812
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 36kB
      Dimensions:  (chain: 4, draw: 500)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          sd_att   (chain, draw) float64 16kB ...
          sd_def   (chain, draw) float64 16kB ...
      Attributes:
          sd_att:   pymc.logprob.transforms.LogTransform
          sd_def:   pymc.logprob.transforms.LogTransform

Knowing the model, we could compute the pointwise log likelihood manually. Here, however, we simply generate random samples with the right shape.

import numpy as np
rng = np.random.default_rng(73)
ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
idata.add_groups(
    log_likelihood={"home_points": ary},
    dims={"home_points": ["match"]},
)
idata
arviz.InferenceData
    • <xarray.Dataset> Size: 452kB
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 16kB ...
          intercept  (chain, draw) float64 16kB ...
          atts_star  (chain, draw, team) float64 96kB ...
          defs_star  (chain, draw, team) float64 96kB ...
          sd_att     (chain, draw) float64 16kB ...
          sd_def     (chain, draw) float64 16kB ...
          atts       (chain, draw, team) float64 96kB ...
          defs       (chain, draw, team) float64 96kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:23.841916
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9
          sampling_time:              8.503105401992798
          tuning_steps:               1000

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          home_points  (chain, draw, match) int64 960kB ...
          away_points  (chain, draw, match) int64 960kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:25.689246
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 965kB
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) int64 480B 0 1 2 3 4 5 6 7 ... 52 53 54 55 56 57 58 59
      Data variables:
          home_points  (chain, draw, match) float64 960kB -1.093 0.7781 ... 1.643
      Attributes:
          created_at:     2024-03-14T16:19:33.875702
          arviz_version:  0.18.0.dev0

    • <xarray.Dataset> Size: 260kB
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 16kB ...
          sd_att     (chain, draw) float64 16kB ...
          sd_def     (chain, draw) float64 16kB ...
          intercept  (chain, draw) float64 16kB ...
          atts_star  (chain, draw, team) float64 96kB ...
          defs_star  (chain, draw, team) float64 96kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:24.377610
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 248kB
      Dimensions:                (chain: 4, draw: 500)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
      Data variables: (12/17)
          max_energy_error       (chain, draw) float64 16kB ...
          index_in_trajectory    (chain, draw) int64 16kB ...
          smallest_eigval        (chain, draw) float64 16kB ...
          perf_counter_start     (chain, draw) float64 16kB ...
          largest_eigval         (chain, draw) float64 16kB ...
          step_size              (chain, draw) float64 16kB ...
          ...                     ...
          reached_max_treedepth  (chain, draw) bool 2kB ...
          perf_counter_diff      (chain, draw) float64 16kB ...
          tree_depth             (chain, draw) int64 16kB ...
          process_time_diff      (chain, draw) float64 16kB ...
          step_size_bar          (chain, draw) float64 16kB ...
          energy                 (chain, draw) float64 16kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:23.854033
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9
          sampling_time:              8.503105401992798
          tuning_steps:               1000

    • <xarray.Dataset> Size: 116kB
      Dimensions:    (chain: 1, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 8B 0
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          atts_star  (chain, draw, team) float64 24kB ...
          sd_att     (chain, draw) float64 4kB ...
          atts       (chain, draw, team) float64 24kB ...
          sd_def     (chain, draw) float64 4kB ...
          defs       (chain, draw, team) float64 24kB ...
          intercept  (chain, draw) float64 4kB ...
          home       (chain, draw) float64 4kB ...
          defs_star  (chain, draw, team) float64 24kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.475945
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 492kB
      Dimensions:      (chain: 1, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 8B 0
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          away_points  (chain, draw, match) int64 240kB ...
          home_points  (chain, draw, match) int64 240kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.479330
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 9kB
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          home_points  (match) int64 480B ...
          away_points  (match) int64 480B ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.480812
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 36kB
      Dimensions:  (chain: 4, draw: 500)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          sd_att   (chain, draw) float64 16kB ...
          sd_def   (chain, draw) float64 16kB ...
      Attributes:
          sd_att:   pymc.logprob.transforms.LogTransform
          sd_def:   pymc.logprob.transforms.LogTransform
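
In the output above, the match coordinate of the new group is a plain integer range rather than the match labels, because only dims was passed. The sketch below shows how the coords argument could supply labels at the same time. It uses a small synthetic InferenceData instead of the rugby data so that it runs on its own; the match labels and the variable mu are hypothetical.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(1)
matches = ["A vs B", "A vs C", "B vs C"]  # hypothetical labels

# Synthetic posterior: 2 chains, 50 draws
idata_labeled = az.from_dict(posterior={"mu": rng.normal(size=(2, 50))})

idata_labeled.add_groups(
    log_likelihood={"home_points": rng.normal(size=(2, 50, len(matches)))},
    coords={"match": matches},          # labels for the extra dimension
    dims={"home_points": ["match"]},    # chain and draw are added automatically
)

print(idata_labeled.log_likelihood["home_points"].dims)
print(idata_labeled.log_likelihood["match"].values)
```

With the rugby data, the equivalent would presumably be to pass the match values from observed_data as coords.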

This is fine if we start from raw arrays, but inconvenient if the data is already labeled: why provide dims and coords manually again? Let’s generate a fake log likelihood (it doesn’t match the model, but it serves just as well for illustration) directly from the posterior and observed_data groups:

import xarray as xr
from xarray_einstats.stats import XrDiscreteRV
from scipy.stats import poisson
dist = XrDiscreteRV(poisson)
log_lik = xr.Dataset()
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
idata2.add_groups({"log_likelihood": log_lik})
idata2
arviz.InferenceData
    • <xarray.Dataset> Size: 452kB
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 16kB ...
          intercept  (chain, draw) float64 16kB ...
          atts_star  (chain, draw, team) float64 96kB ...
          defs_star  (chain, draw, team) float64 96kB ...
          sd_att     (chain, draw) float64 16kB ...
          sd_def     (chain, draw) float64 16kB ...
          atts       (chain, draw, team) float64 96kB ...
          defs       (chain, draw, team) float64 96kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:23.841916
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9
          sampling_time:              8.503105401992798
          tuning_steps:               1000

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          home_points  (chain, draw, match) int64 960kB ...
          away_points  (chain, draw, match) int64 960kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:25.689246
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 6MB
      Dimensions:      (match: 60, chain: 4, draw: 500, team: 6)
      Coordinates:
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * team         (team) <U8 192B 'Wales' 'France' ... 'Italy' 'England'
      Data variables:
          home_points  (match, chain, draw, team) float64 6MB -48.59 -53.93 ... -19.41

    • <xarray.Dataset> Size: 260kB
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 16kB ...
          sd_att     (chain, draw) float64 16kB ...
          sd_def     (chain, draw) float64 16kB ...
          intercept  (chain, draw) float64 16kB ...
          atts_star  (chain, draw, team) float64 96kB ...
          defs_star  (chain, draw, team) float64 96kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:24.377610
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 248kB
      Dimensions:                (chain: 4, draw: 500)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
      Data variables: (12/17)
          max_energy_error       (chain, draw) float64 16kB ...
          index_in_trajectory    (chain, draw) int64 16kB ...
          smallest_eigval        (chain, draw) float64 16kB ...
          perf_counter_start     (chain, draw) float64 16kB ...
          largest_eigval         (chain, draw) float64 16kB ...
          step_size              (chain, draw) float64 16kB ...
          ...                     ...
          reached_max_treedepth  (chain, draw) bool 2kB ...
          perf_counter_diff      (chain, draw) float64 16kB ...
          tree_depth             (chain, draw) int64 16kB ...
          process_time_diff      (chain, draw) float64 16kB ...
          step_size_bar          (chain, draw) float64 16kB ...
          energy                 (chain, draw) float64 16kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:23.854033
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9
          sampling_time:              8.503105401992798
          tuning_steps:               1000

    • <xarray.Dataset> Size: 116kB
      Dimensions:    (chain: 1, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 8B 0
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          atts_star  (chain, draw, team) float64 24kB ...
          sd_att     (chain, draw) float64 4kB ...
          atts       (chain, draw, team) float64 24kB ...
          sd_def     (chain, draw) float64 4kB ...
          defs       (chain, draw, team) float64 24kB ...
          intercept  (chain, draw) float64 4kB ...
          home       (chain, draw) float64 4kB ...
          defs_star  (chain, draw, team) float64 24kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.475945
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 492kB
      Dimensions:      (chain: 1, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 8B 0
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          away_points  (chain, draw, match) int64 240kB ...
          home_points  (chain, draw, match) int64 240kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.479330
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 9kB
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          home_points  (match) int64 480B ...
          away_points  (match) int64 480B ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.480812
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 36kB
      Dimensions:  (chain: 4, draw: 500)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          sd_att   (chain, draw) float64 16kB ...
          sd_def   (chain, draw) float64 16kB ...
      Attributes:
          sd_att:   pymc.logprob.transforms.LogTransform
          sd_def:   pymc.logprob.transforms.LogTransform
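
The same pattern can be sketched on a small synthetic InferenceData without the xarray_einstats dependency. Everything here is an assumption made for illustration: the variables mu and y, the obs_id labels, and the unit-variance Gaussian log density. The point is only that a labeled xarray.Dataset keeps its dims and coords when passed to add_groups, so none need to be given again. The result is also transposed so that chain and draw come first, which is the layout the other groups use.

```python
import numpy as np
import xarray as xr
import arviz as az

rng = np.random.default_rng(2)
idata_xr = az.from_dict(
    posterior={"mu": rng.normal(size=(2, 50))},
    observed_data={"y": rng.normal(size=3)},
    coords={"obs_id": ["a", "b", "c"]},
    dims={"y": ["obs_id"]},
)
post_xr = idata_xr.posterior
obs_xr = idata_xr.observed_data

# Gaussian log density with unit standard deviation (illustrative only).
# xarray broadcasts obs_id against chain and draw by dimension name.
log_dens = -0.5 * np.log(2 * np.pi) - 0.5 * (obs_xr["y"] - post_xr["mu"]) ** 2

log_lik_xr = xr.Dataset({"y": log_dens.transpose("chain", "draw", ...)})
idata_xr.add_groups({"log_likelihood": log_lik_xr})

print(idata_xr.log_likelihood["y"].dims)
```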

Note that the first example used the kwargs form and the second used the group_dict form.
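
As a sketch of that equivalence, the snippet below adds the same array to two copies of a small synthetic InferenceData, once through keyword arguments and once through a dictionary. The variable names mu and y and the dimension name obs_id are made up for illustration.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
posterior = {"mu": rng.normal(size=(2, 50))}  # 2 chains, 50 draws
idata_kw = az.from_dict(posterior=posterior)
idata_gd = az.from_dict(posterior=posterior)

log_lik_ary = rng.normal(size=(2, 50, 3))

# Keyword form: the group name is the keyword
idata_kw.add_groups(log_likelihood={"y": log_lik_ary}, dims={"y": ["obs_id"]})

# Dictionary form: the group name is the dictionary key
idata_gd.add_groups({"log_likelihood": {"y": log_lik_ary}}, dims={"y": ["obs_id"]})

same = np.array_equal(
    idata_kw.log_likelihood["y"].values,
    idata_gd.log_likelihood["y"].values,
)
print(same)
```

The dictionary form is convenient when group names are built programmatically; the keyword form is shorter when they are known up front.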