arviz.InferenceData.add_groups
- InferenceData.add_groups(group_dict=None, coords=None, dims=None, **kwargs)
Add new groups to InferenceData object.
- Parameters
- group_dict : dict of {str : dict or xarray.Dataset}, optional
Groups to be added, keyed by group name.
- coords : dict of {str : array_like}, optional
Coordinates for the dataset.
- dims : dict of {str : list of str}, optional
Dimensions of each variable. The keys are variable names, the values are lists of dimension names.
- kwargs : dict, optional
The keyword-argument form of group_dict. One of group_dict or kwargs must be provided.
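As a rough, self-contained sketch of how coords and dims interact when a group is passed as a dict: the toy InferenceData built with az.from_dict, the variable names mu and y, and the obs_id dimension are all made up for illustration, and the snippet assumes an ArviZ version that still provides InferenceData.add_groups. For dict inputs the first two array axes are taken as chain and draw, so dims only names the remaining ones.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Hypothetical toy posterior: 2 chains, 50 draws, one scalar variable.
idata = az.from_dict(posterior={"mu": rng.normal(size=(2, 50))})

# Add a log_likelihood group from a raw array of shape (chain, draw, obs_id).
# dims names only the axes after chain and draw; coords labels that axis.
idata.add_groups(
    log_likelihood={"y": rng.normal(size=(2, 50, 3))},
    coords={"obs_id": ["a", "b", "c"]},
    dims={"y": ["obs_id"]},
)

new_dims = idata.log_likelihood["y"].dims
new_labels = list(idata.log_likelihood["y"].coords["obs_id"].values)
print(idata.groups())
print(new_dims, new_labels)
```

If coords is left out, the named dimension is still created, but it gets integer labels instead of the supplied ones.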
Examples
Add a log_likelihood group to the “rugby” example InferenceData after loading. It originally doesn’t have the log_likelihood group:
    import arviz as az
    idata = az.load_arviz_data("rugby")
    idata2 = idata.copy()
    post = idata.posterior
    obs = idata.observed_data
    idata
arviz.InferenceData
- posterior
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, team: 6) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 * team (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England' Data variables: home (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265 intercept (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892 atts_star (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878 defs_star (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649 sd_att (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591 sd_def (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849 atts (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029 defs (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986 Attributes: created_at: 2019-07-12T20:31:53.545143 inference_library: pymc3 inference_library_version: 3.7
- posterior_predictive
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, match: 60) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: home_points (chain, draw, match) int64 ... away_points (chain, draw, match) int64 ... Attributes: created_at: 2019-07-12T20:31:53.563854 inference_library: pymc3 inference_library_version: 3.7
- sample_stats
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 Data variables: energy_error (chain, draw) float64 -0.07666 -0.4523 ... 0.115 -0.07691 energy (chain, draw) float64 540.2 545.3 542.3 ... 544.0 545.6 tree_size (chain, draw) float64 15.0 63.0 31.0 ... 63.0 31.0 31.0 tune (chain, draw) bool True False False ... False False False mean_tree_accept (chain, draw) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539 lp (chain, draw) float64 -536.4 -536.0 ... -536.1 -536.4 depth (chain, draw) int64 4 6 5 4 4 4 5 5 5 ... 6 4 6 5 3 6 5 5 max_energy_error (chain, draw) float64 -0.5361 -0.5871 ... 0.7109 1.014 step_size (chain, draw) float64 0.2469 0.2469 ... 0.2459 0.2459 step_size_bar (chain, draw) float64 0.2313 0.2313 ... 0.2488 0.2488 diverging (chain, draw) bool False False False ... False False False Attributes: created_at: 2019-07-12T20:31:53.555203 inference_library: pymc3 inference_library_version: 3.7
- prior
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, team: 6, match: 60) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 * team (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England' * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: sd_att_log__ (chain, draw) float64 1.322 -2.014 1.588 ... -0.8585 -0.1922 intercept (chain, draw) float64 4.464 3.352 1.567 ... 4.363 4.128 1.049 atts_star (chain, draw, team) float64 -2.64 4.172 ... -0.2874 -0.8538 defs_star (chain, draw, team) float64 -0.7817 -0.1478 ... 0.1655 0.01067 away_points (chain, draw, match) int64 11308 0 11 1 0 21442 ... 11 1 2 2 0 sd_att (chain, draw) float64 3.752 0.1334 4.896 ... 0.4238 0.8251 sd_def_log__ (chain, draw) float64 -0.2662 0.2411 0.6071 ... 1.402 -1.981 home (chain, draw) float64 -1.511 -0.001582 ... -0.02416 0.2651 atts (chain, draw, team) float64 -4.667 2.145 ... -0.2702 -0.8365 sd_def (chain, draw) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138 home_points (chain, draw, match) int64 0 47 11899 3262 1 ... 3 2 1 12 13 defs (chain, draw, team) float64 -0.2517 0.3823 ... 0.089 -0.06586 Attributes: created_at: 2019-07-12T20:31:53.573731 inference_library: pymc3 inference_library_version: 3.7
- observed_data
<xarray.Dataset> Dimensions: (match: 60) Coordinates: * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: home_points (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0 away_points (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0 Attributes: created_at: 2019-07-12T20:31:53.581293 inference_library: pymc3 inference_library_version: 3.7
Knowing the model, we could compute the log likelihood manually. In this case, however, we will simply generate random samples with the right shape.
    import numpy as np
    rng = np.random.default_rng(73)
    # Random values with shape (chain, draw, match) stand in for a real log likelihood
    ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
    idata.add_groups(
        log_likelihood={"home_points": ary},
        dims={"home_points": ["match"]},
    )
    idata
arviz.InferenceData
- posterior
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, team: 6) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 * team (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England' Data variables: home (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265 intercept (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892 atts_star (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878 defs_star (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649 sd_att (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591 sd_def (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849 atts (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029 defs (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986 Attributes: created_at: 2019-07-12T20:31:53.545143 inference_library: pymc3 inference_library_version: 3.7
- posterior_predictive
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, match: 60) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: home_points (chain, draw, match) int64 ... away_points (chain, draw, match) int64 ... Attributes: created_at: 2019-07-12T20:31:53.563854 inference_library: pymc3 inference_library_version: 3.7
- log_likelihood
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, match: 60) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 * match (match) int64 0 1 2 3 4 5 6 7 8 ... 51 52 53 54 55 56 57 58 59 Data variables: home_points (chain, draw, match) float64 -1.093 0.7781 ... 0.2405 1.643 Attributes: created_at: 2022-10-24T16:32:42.132111 arviz_version: 0.13.0
- sample_stats
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 Data variables: energy_error (chain, draw) float64 -0.07666 -0.4523 ... 0.115 -0.07691 energy (chain, draw) float64 540.2 545.3 542.3 ... 544.0 545.6 tree_size (chain, draw) float64 15.0 63.0 31.0 ... 63.0 31.0 31.0 tune (chain, draw) bool True False False ... False False False mean_tree_accept (chain, draw) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539 lp (chain, draw) float64 -536.4 -536.0 ... -536.1 -536.4 depth (chain, draw) int64 4 6 5 4 4 4 5 5 5 ... 6 4 6 5 3 6 5 5 max_energy_error (chain, draw) float64 -0.5361 -0.5871 ... 0.7109 1.014 step_size (chain, draw) float64 0.2469 0.2469 ... 0.2459 0.2459 step_size_bar (chain, draw) float64 0.2313 0.2313 ... 0.2488 0.2488 diverging (chain, draw) bool False False False ... False False False Attributes: created_at: 2019-07-12T20:31:53.555203 inference_library: pymc3 inference_library_version: 3.7
- prior
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, team: 6, match: 60) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 * team (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England' * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: sd_att_log__ (chain, draw) float64 1.322 -2.014 1.588 ... -0.8585 -0.1922 intercept (chain, draw) float64 4.464 3.352 1.567 ... 4.363 4.128 1.049 atts_star (chain, draw, team) float64 -2.64 4.172 ... -0.2874 -0.8538 defs_star (chain, draw, team) float64 -0.7817 -0.1478 ... 0.1655 0.01067 away_points (chain, draw, match) int64 11308 0 11 1 0 21442 ... 11 1 2 2 0 sd_att (chain, draw) float64 3.752 0.1334 4.896 ... 0.4238 0.8251 sd_def_log__ (chain, draw) float64 -0.2662 0.2411 0.6071 ... 1.402 -1.981 home (chain, draw) float64 -1.511 -0.001582 ... -0.02416 0.2651 atts (chain, draw, team) float64 -4.667 2.145 ... -0.2702 -0.8365 sd_def (chain, draw) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138 home_points (chain, draw, match) int64 0 47 11899 3262 1 ... 3 2 1 12 13 defs (chain, draw, team) float64 -0.2517 0.3823 ... 0.089 -0.06586 Attributes: created_at: 2019-07-12T20:31:53.573731 inference_library: pymc3 inference_library_version: 3.7
- observed_data
<xarray.Dataset> Dimensions: (match: 60) Coordinates: * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: home_points (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0 away_points (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0 Attributes: created_at: 2019-07-12T20:31:53.581293 inference_library: pymc3 inference_library_version: 3.7
This is fine if we have raw data, but a bit inconvenient if we already start with labeled data: why provide dims and coords manually again? Let’s generate a fake log likelihood (it doesn’t match the model, but it serves just as well for illustration) working directly from the labeled posterior and observed_data groups:
    import xarray as xr
    from xarray_einstats.stats import XrDiscreteRV
    from scipy.stats import poisson
    # Wrap scipy's Poisson distribution so its methods accept and return labeled xarray objects
    dist = XrDiscreteRV(poisson)
    log_lik = xr.Dataset()
    log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
    idata2.add_groups({"log_likelihood": log_lik})
    idata2
arviz.InferenceData
- posterior
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, team: 6) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499 * team (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England' Data variables: home (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265 intercept (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892 atts_star (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878 defs_star (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649 sd_att (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591 sd_def (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849 atts (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029 defs (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986 Attributes: created_at: 2019-07-12T20:31:53.545143 inference_library: pymc3 inference_library_version: 3.7
- posterior_predictive
<xarray.Dataset> Dimensions: (chain: 4, draw: 500, match: 60) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: home_points (chain, draw, match) int64 ... away_points (chain, draw, match) int64 ... Attributes: created_at: 2019-07-12T20:31:53.563854 inference_library: pymc3 inference_library_version: 3.7
- log_likelihood
<xarray.Dataset> Dimensions: (match: 60, chain: 4, draw: 500, team: 6) Coordinates: * match (match) object 'Wales Italy' ... 'Ireland England' * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 * team (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England' Data variables: home_points (match, chain, draw, team) float64 -50.27 -53.03 ... -21.14
- sample_stats
<xarray.Dataset> Dimensions: (chain: 4, draw: 500) Coordinates: * chain (chain) int64 0 1 2 3 * draw (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499 Data variables: energy_error (chain, draw) float64 -0.07666 -0.4523 ... 0.115 -0.07691 energy (chain, draw) float64 540.2 545.3 542.3 ... 544.0 545.6 tree_size (chain, draw) float64 15.0 63.0 31.0 ... 63.0 31.0 31.0 tune (chain, draw) bool True False False ... False False False mean_tree_accept (chain, draw) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539 lp (chain, draw) float64 -536.4 -536.0 ... -536.1 -536.4 depth (chain, draw) int64 4 6 5 4 4 4 5 5 5 ... 6 4 6 5 3 6 5 5 max_energy_error (chain, draw) float64 -0.5361 -0.5871 ... 0.7109 1.014 step_size (chain, draw) float64 0.2469 0.2469 ... 0.2459 0.2459 step_size_bar (chain, draw) float64 0.2313 0.2313 ... 0.2488 0.2488 diverging (chain, draw) bool False False False ... False False False Attributes: created_at: 2019-07-12T20:31:53.555203 inference_library: pymc3 inference_library_version: 3.7
- prior
<xarray.Dataset> Dimensions: (chain: 1, draw: 500, team: 6, match: 60) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 * team (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England' * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: sd_att_log__ (chain, draw) float64 1.322 -2.014 1.588 ... -0.8585 -0.1922 intercept (chain, draw) float64 4.464 3.352 1.567 ... 4.363 4.128 1.049 atts_star (chain, draw, team) float64 -2.64 4.172 ... -0.2874 -0.8538 defs_star (chain, draw, team) float64 -0.7817 -0.1478 ... 0.1655 0.01067 away_points (chain, draw, match) int64 11308 0 11 1 0 21442 ... 11 1 2 2 0 sd_att (chain, draw) float64 3.752 0.1334 4.896 ... 0.4238 0.8251 sd_def_log__ (chain, draw) float64 -0.2662 0.2411 0.6071 ... 1.402 -1.981 home (chain, draw) float64 -1.511 -0.001582 ... -0.02416 0.2651 atts (chain, draw, team) float64 -4.667 2.145 ... -0.2702 -0.8365 sd_def (chain, draw) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138 home_points (chain, draw, match) int64 0 47 11899 3262 1 ... 3 2 1 12 13 defs (chain, draw, team) float64 -0.2517 0.3823 ... 0.089 -0.06586 Attributes: created_at: 2019-07-12T20:31:53.573731 inference_library: pymc3 inference_library_version: 3.7
- observed_data
<xarray.Dataset> Dimensions: (match: 60) Coordinates: * match (match) object 'Wales Italy' ... 'Ireland England' Data variables: home_points (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0 away_points (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0 Attributes: created_at: 2019-07-12T20:31:53.581293 inference_library: pymc3 inference_library_version: 3.7
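The fake log likelihood in this output has dimensions (match, chain, draw, team) because xarray broadcasts by dimension name: observed_data only has match, the posterior variable has chain, draw and team, and the result carries the union of both. A small sketch with toy arrays (the sizes and variable names here are invented, not taken from the rugby data) shows the same behaviour and one way to move chain and draw back to the front:

```python
import numpy as np
import xarray as xr

# Toy stand-ins for an observed variable and a posterior variable.
obs_like = xr.DataArray(np.arange(3.0), dims=["match"])
post_like = xr.DataArray(np.zeros((2, 5, 4)), dims=["chain", "draw", "team"])

# Broadcasting is by dimension name, so the result has the union of dims,
# ordered by first appearance.
combined = obs_like + post_like
print(combined.dims)

# Put chain and draw first, keeping the remaining dims in their current order.
reordered = combined.transpose("chain", "draw", ...)
print(reordered.dims)
```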
Note that in the first example we have used the kwargs argument and in the second we have used the group_dict one.
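To make the relationship between the two calling forms concrete, here is a minimal sketch on a toy InferenceData (built with az.from_dict; the variable names and shapes are hypothetical). It checks that both forms produce the same group, and it also illustrates what appears to happen when a group with that name already exists: the expectation that a ValueError is raised is an assumption based on current ArviZ behaviour, not something stated on this page.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(1)
ll = {"y": rng.normal(size=(2, 20, 4))}

idata_kwargs = az.from_dict(posterior={"mu": rng.normal(size=(2, 20))})
idata_dict = idata_kwargs.copy()

# Keyword form and dictionary form of the same call.
idata_kwargs.add_groups(log_likelihood=ll)
idata_dict.add_groups({"log_likelihood": ll})

same = idata_kwargs.log_likelihood.equals(idata_dict.log_likelihood)
print(same)

# Assumption: adding a group that is already present is rejected.
duplicate_error = None
try:
    idata_dict.add_groups({"log_likelihood": ll})
except ValueError as err:
    duplicate_error = str(err)
print(duplicate_error)
```

The dictionary form is convenient when group names are computed at run time; the keyword form reads more naturally for a single, fixed group.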