arviz.InferenceData.add_groups
- InferenceData.add_groups(group_dict=None, coords=None, dims=None, **kwargs)
Add new groups to an InferenceData object.
- Parameters
  - group_dict : dict of {str : dict or xarray.Dataset}, optional
    Groups to be added.
  - coords : dict of {str : array_like}, optional
    Coordinates for the dataset.
  - dims : dict of {str : list of str}, optional
    Dimensions of each variable. The keys are variable names and the values are lists of dimension names (for dict values, chain and draw are prepended automatically).
  - kwargs : dict, optional
    The keyword-argument form of group_dict. One of group_dict or kwargs must be provided.
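A minimal sketch of one call that adds two groups at once, assuming a small synthetic InferenceData built with az.from_dict (the variable names mu, y, x and the dimension obs_id are hypothetical). The log_likelihood value is a plain dict, so its first two axes are read as chain and draw while dims and coords label the remaining axis; the constant_data value is an xarray.Dataset that already carries its own labels.

```python
import numpy as np
import xarray as xr
import arviz as az

rng = np.random.default_rng(0)
n_chain, n_draw = 2, 50
obs_labels = ["a", "b", "c"]  # hypothetical observation labels

# Small synthetic InferenceData with only a posterior group
idata = az.from_dict(posterior={"mu": rng.normal(size=(n_chain, n_draw))})

# Plain dict value: axes 0 and 1 become (chain, draw); dims and coords label the rest
log_lik = {"y": rng.normal(size=(n_chain, n_draw, len(obs_labels)))}

# xarray.Dataset value: already labeled, so it needs no dims or coords of its own
constant = xr.Dataset(
    {"x": ("obs_id", np.array([1.0, 2.0, 3.0]))},
    coords={"obs_id": obs_labels},
)

idata.add_groups(
    {"log_likelihood": log_lik, "constant_data": constant},
    coords={"obs_id": obs_labels},
    dims={"y": ["obs_id"]},
)
print(idata.groups())
print(idata.log_likelihood["y"].dims)
```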
See also
Examples
Add a log_likelihood group to the "rugby" example InferenceData after loading. It originally does not have a log_likelihood group:

```python
import arviz as az

idata = az.load_arviz_data("rugby")
idata2 = idata.copy()  # untouched copy, used in the second example
post = idata.posterior
obs = idata.observed_data
idata
```
Displaying idata shows five groups, none of which is log_likelihood:
- posterior: dims (chain: 4, draw: 500, team: 6); variables home, intercept, atts_star, defs_star, sd_att, sd_def, atts, defs
- posterior_predictive: dims (chain: 4, draw: 500, match: 60); variables home_points, away_points (int64)
- sample_stats: dims (chain: 4, draw: 500); variables energy_error, energy, tree_size, tune, mean_tree_accept, lp, depth, max_energy_error, step_size, step_size_bar, diverging
- prior: dims (chain: 1, draw: 500, team: 6, match: 60); variables sd_att_log__, intercept, atts_star, defs_star, away_points, sd_att, sd_def_log__, home, atts, sd_def, home_points, defs
- observed_data: dims (match: 60); variables home_points, away_points (float64)
The team coordinate holds team names ('Wales', 'France', 'Ireland', ..., 'Italy', 'England') and match holds pairs such as 'Wales Italy'. All five groups were created on 2019-07-12 by pymc3 3.7.
Knowing the model, we could compute the log likelihood ourselves. Here, however, we simply generate random samples with the right shape.
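```python
import numpy as np

rng = np.random.default_rng(73)
# Random values shaped (chain, draw, match); not a real log likelihood
ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
idata.add_groups(
    log_likelihood={"home_points": ary},
    dims={"home_points": ["match"]},  # chain and draw are added automatically
)
idata
```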
The output now includes a log_likelihood group between posterior_predictive and sample_stats; the other five groups are unchanged:
- log_likelihood: dims (chain: 4, draw: 500, match: 60), where match is labeled with int64 values 0 to 59; variable home_points (chain, draw, match) float64, values -1.093 0.7781 ... 0.2405 1.643; attributes created_at 2023-07-18T19:52:42.552711, arviz_version 0.16.1
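To keep the team-pair labels of observed_data instead of integer indices, the existing coordinate values can be passed through coords. This is a sketch on a small synthetic stand-in for the rugby data, built with az.from_dict; the match labels, observed values and the posterior variable mu are hypothetical.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(73)
matches = ["A B", "C D", "E F"]  # hypothetical home/away pairs standing in for rugby's match labels

idata = az.from_dict(
    posterior={"mu": rng.normal(size=(4, 100))},
    observed_data={"home_points": np.array([23.0, 17.0, 30.0])},
    coords={"match": matches},
    dims={"home_points": ["match"]},
)
post = idata.posterior
obs = idata.observed_data

# Random stand-in for a log likelihood, shaped (chain, draw, match)
ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))

idata.add_groups(
    log_likelihood={"home_points": ary},
    coords={"match": obs["match"].values},  # reuse the labels already in observed_data
    dims={"home_points": ["match"]},
)
print(idata.log_likelihood["match"].values)
```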
This works well for raw arrays, but it is inconvenient when the data is already labeled: there is no reason to provide dims and coords again. Let's generate a fake log likelihood directly from the labeled posterior and observed_data groups (it does not match the model, but it serves the same illustrative purpose):
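```python
import xarray as xr
from xarray_einstats.stats import XrDiscreteRV
from scipy.stats import poisson

# Wrap scipy's Poisson so its methods accept and return labeled DataArrays
dist = XrDiscreteRV(poisson)
log_lik = xr.Dataset()
# Broadcasting happens by dimension name, so no dims or coords are needed here
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
idata2.add_groups({"log_likelihood": log_lik})
idata2
```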
The output differs from the original only by a new log_likelihood group, placed between posterior_predictive and sample_stats:
- log_likelihood: dims (match: 60, chain: 4, draw: 500, team: 6), with match and team keeping their labels from observed_data and posterior; variable home_points (match, chain, draw, team) float64, values -50.27 -53.03 ... -21.14
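Because xarray broadcasts by dimension name, the fake log likelihood picks up the team dimension of atts and lists match first. The sketch below assumes a hypothetical toy model, y ~ Poisson(rate) with one scalar rate per draw, built with az.from_dict. It writes the Poisson log pmf k log(rate) - rate - log(k!) with labeled arithmetic, moves chain and draw to the front (the usual ArviZ ordering), and cross-checks the values against scipy.stats.poisson.

```python
import numpy as np
import xarray as xr
import arviz as az
from scipy.special import gammaln
from scipy.stats import poisson

rng = np.random.default_rng(1)
# Hypothetical toy model: y[i] ~ Poisson(rate), one rate per (chain, draw)
idata = az.from_dict(
    posterior={"rate": rng.gamma(2.0, 2.0, size=(2, 100))},
    observed_data={"y": np.array([3, 0, 5, 2])},
    dims={"y": ["obs_id"]},
)
rate = idata.posterior["rate"]  # dims: (chain, draw)
y = idata.observed_data["y"]    # dims: (obs_id,)

# Poisson log pmf; xarray aligns and broadcasts the operands by dimension name
log_lik_y = y * np.log(rate) - rate - xr.apply_ufunc(gammaln, y + 1)

# Broadcasting put obs_id first; move the sample dims to the front
log_lik = xr.Dataset({"y": log_lik_y.transpose("chain", "draw", ...)})
idata.add_groups({"log_likelihood": log_lik})

# Cross-check against scipy's implementation of the same log pmf
reference = poisson.logpmf(y.values[None, None, :], rate.values[:, :, None])
assert np.allclose(idata.log_likelihood["y"].values, reference)
print(idata.log_likelihood["y"].dims)
```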
Note that the first example uses the kwargs form and the second uses the group_dict argument.
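A short sketch checking that both call styles produce the same group, assuming a synthetic InferenceData from az.from_dict (the names mu, y and obs_id are hypothetical). Dataset.equals compares values, dimensions and coordinates but ignores attributes such as created_at, which differ between the two calls.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(2)
ary = rng.normal(size=(2, 50, 3))
base = az.from_dict(posterior={"mu": rng.normal(size=(2, 50))})

# kwargs form: the group name is the keyword
idata_kw = base.copy()
idata_kw.add_groups(log_likelihood={"y": ary}, dims={"y": ["obs_id"]})

# group_dict form: the group name is a dictionary key
idata_gd = base.copy()
idata_gd.add_groups({"log_likelihood": {"y": ary}}, dims={"y": ["obs_id"]})

print(idata_kw.log_likelihood.equals(idata_gd.log_likelihood))
```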