arviz.InferenceData.stack
- InferenceData.stack(dimensions=None, groups=None, filter_groups=None, inplace=False, **kwargs)
Perform an xarray stacking on all groups.
Stack any number of existing dimensions into a single new dimension. Loops over the groups and performs Dataset.stack(key=value) for every kwarg whose listed dimensions are present in the dataset. The stacking is performed on all relevant groups (like posterior, prior, sample stats), while groups that lack those dimensions, like observed data, are left unchanged. See xarray.Dataset.stack().
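The per-group behaviour described above can be illustrated with plain xarray. This is a minimal sketch, assuming a dict of Datasets stands in for the InferenceData groups; `stack_groups` is a hypothetical helper written only for illustration, not part of ArviZ, and ArviZ's own implementation may differ in details.

```python
import numpy as np
import xarray as xr


def stack_groups(groups, **dimensions):
    """Hypothetical helper: stack `dimensions` in every Dataset that has all of them."""
    out = {}
    for name, ds in groups.items():
        # Keep only the mappings whose source dimensions all exist in this group.
        applicable = {
            new: old for new, old in dimensions.items() if set(old).issubset(ds.sizes)
        }
        # Groups without any matching dimensions are passed through unchanged.
        out[name] = ds.stack(**applicable) if applicable else ds
    return out


posterior = xr.Dataset(
    {"mu": (("chain", "draw"), np.zeros((2, 5)))},
    coords={"chain": [0, 1], "draw": np.arange(5)},
)
observed_data = xr.Dataset({"y": ("obs", np.ones(3))}, coords={"obs": np.arange(3)})

stacked = stack_groups(
    {"posterior": posterior, "observed_data": observed_data},
    sample=["chain", "draw"],
)
print(dict(stacked["posterior"].sizes))      # {'sample': 10}
print(dict(stacked["observed_data"].sizes))  # {'obs': 3}
```

The observed-data stand-in has no `chain` or `draw` dimension, so it is returned as-is, mirroring how the observed data group is left untouched in the examples.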
- Parameters:
- dimensions: dict, optional
Names of new dimensions, and the existing dimensions that they replace.
- groups: str or list of str, optional
Groups where the stacking is to be applied. Can either be group names or metagroup names.
- filter_groups: {None, "like", "regex"}, optional
If None (default), interpret groups as the real group or metagroup names. If "like", interpret groups as substrings of the real group or metagroup names. If "regex", interpret groups as regular expressions on the real group or metagroup names. À la pandas.filter.
- inplace: bool, optional
If True, modify the InferenceData object inplace, otherwise, return the modified copy.
- kwargs: dict, optional
Further new-dimension mappings in the same form as dimensions (for example sample=["chain", "draw"]). They must be accepted by xarray.Dataset.stack().
- Returns:
InferenceData
A new InferenceData object by default. When inplace==True, perform the stacking in place and return None.
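The parameters and return value above can be exercised on a small synthetic object. This is a sketch under assumptions: the variable name `x` and dimension `dim_x` are arbitrary, and the `filter_groups="like"` line assumes substring matching against group names as the parameter description states.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)
data = {"x": rng.normal(size=(2, 50, 3))}  # (chain, draw, dim_x)
idata = az.from_dict(posterior=data, prior=data, dims={"x": ["dim_x"]})

# New dimensions can be passed as a dict through `dimensions` ...
via_dict = idata.stack(dimensions={"sample": ["chain", "draw"]})
# ... or as keyword arguments.
via_kwargs = idata.stack(sample=["chain", "draw"])

# `groups` restricts the stacking to the listed groups; the prior is untouched here.
only_post = idata.stack(sample=["chain", "draw"], groups="posterior")
print(only_post.posterior["x"].dims)  # ('dim_x', 'sample')
print(only_post.prior["x"].dims)      # ('chain', 'draw', 'dim_x')

# With filter_groups="like", "post" is matched as a substring of group names.
like_post = idata.stack(sample=["chain", "draw"], groups="post", filter_groups="like")
print(like_post.prior["x"].dims)  # ('chain', 'draw', 'dim_x')

# inplace=True modifies `idata` itself and returns None.
result = idata.stack(sample=["chain", "draw"], inplace=True)
print(result)  # None
```

The stacked dimension is appended after the remaining ones, which is why `x` ends up as `('dim_x', 'sample')`.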
See also
xarray.Dataset.stack
Stack any number of existing dimensions into a single new dimension.
unstack
Perform an xarray unstacking on all groups of InferenceData object.
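Since unstack reverses stack, a round trip should recover the original values up to dimension order. A minimal sketch with synthetic data (the variable name `mu` is arbitrary), assuming unstack is called with its default `dim=None` so that every stacked dimension is unstacked:

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(1)
idata = az.from_dict(posterior={"mu": rng.normal(size=(2, 10))})

stacked = idata.stack(sample=["chain", "draw"])
restored = stacked.unstack()  # default dim=None: unstack every stacked dimension

mu_orig = idata.posterior["mu"]
# Transpose explicitly so the comparison does not depend on dimension order.
mu_back = restored.posterior["mu"].transpose("chain", "draw")
print(mu_back.shape)                                # (2, 10)
print(np.allclose(mu_orig.values, mu_back.values))  # True
```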
Examples
Use stack to stack any number of existing dimensions into a single new dimension. We first check the original object:
```python
import arviz as az

idata = az.load_arviz_data("rugby")
idata
```
```text
arviz.InferenceData

<xarray.Dataset>
Dimensions:    (chain: 4, draw: 500, team: 6)
Coordinates:
  * chain      (chain) int64 0 1 2 3
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
  * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
Data variables:
    home       (chain, draw) float64 ...
    intercept  (chain, draw) float64 ...
    atts_star  (chain, draw, team) float64 ...
    defs_star  (chain, draw, team) float64 ...
    sd_att     (chain, draw) float64 ...
    sd_def     (chain, draw) float64 ...
    atts       (chain, draw, team) float64 ...
    defs       (chain, draw, team) float64 ...
Attributes:
    created_at:                 2019-07-12T20:31:53.545143
    inference_library:          pymc3
    inference_library_version:  3.7

<xarray.Dataset>
Dimensions:      (chain: 4, draw: 500, match: 60)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * match        (match) object 'Wales Italy' ... 'Ireland England'
Data variables:
    home_points  (chain, draw, match) int64 ...
    away_points  (chain, draw, match) int64 ...
Attributes:
    created_at:                 2019-07-12T20:31:53.563854
    inference_library:          pymc3
    inference_library_version:  3.7

<xarray.Dataset>
Dimensions:           (chain: 4, draw: 500)
Coordinates:
  * chain             (chain) int64 0 1 2 3
  * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
Data variables:
    energy_error      (chain, draw) float64 ...
    energy            (chain, draw) float64 ...
    tree_size         (chain, draw) float64 ...
    tune              (chain, draw) bool ...
    mean_tree_accept  (chain, draw) float64 ...
    lp                (chain, draw) float64 ...
    depth             (chain, draw) int64 ...
    max_energy_error  (chain, draw) float64 ...
    step_size         (chain, draw) float64 ...
    step_size_bar     (chain, draw) float64 ...
    diverging         (chain, draw) bool ...
Attributes:
    created_at:                 2019-07-12T20:31:53.555203
    inference_library:          pymc3
    inference_library_version:  3.7

<xarray.Dataset>
Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
Coordinates:
  * chain         (chain) int64 0
  * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
  * match         (match) object 'Wales Italy' ... 'Ireland England'
Data variables:
    sd_att_log__  (chain, draw) float64 ...
    intercept     (chain, draw) float64 ...
    atts_star     (chain, draw, team) float64 ...
    defs_star     (chain, draw, team) float64 ...
    away_points   (chain, draw, match) int64 ...
    sd_att        (chain, draw) float64 ...
    sd_def_log__  (chain, draw) float64 ...
    home          (chain, draw) float64 ...
    atts          (chain, draw, team) float64 ...
    sd_def        (chain, draw) float64 ...
    home_points   (chain, draw, match) int64 ...
    defs          (chain, draw, team) float64 ...
Attributes:
    created_at:                 2019-07-12T20:31:53.573731
    inference_library:          pymc3
    inference_library_version:  3.7

<xarray.Dataset>
Dimensions:      (match: 60)
Coordinates:
  * match        (match) object 'Wales Italy' ... 'Ireland England'
Data variables:
    home_points  (match) float64 ...
    away_points  (match) float64 ...
Attributes:
    created_at:                 2019-07-12T20:31:53.581293
    inference_library:          pymc3
    inference_library_version:  3.7
```
In order to stack the two dimensions chain and draw into sample, we can use:
```python
idata.stack(sample=["chain", "draw"], inplace=True)
idata
```
```text
arviz.InferenceData

<xarray.Dataset>
Dimensions:    (team: 6, sample: 2000)
Coordinates:
  * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
  * sample     (sample) object MultiIndex
  * chain      (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
  * draw       (sample) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
Data variables:
    home       (sample) float64 0.1642 0.1162 0.09299 ... 0.1452 0.148 0.2265
    intercept  (sample) float64 2.893 2.941 2.939 2.977 ... 2.951 2.903 2.892
    atts_star  (team, sample) float64 0.1673 0.226 0.1959 ... -0.01013 0.02878
    defs_star  (team, sample) float64 -0.03638 0.01689 ... -0.06325 -0.0649
    sd_att     (sample) float64 0.4854 0.1438 0.2139 ... 0.4472 0.2883 0.4591
    sd_def     (sample) float64 0.2747 1.033 0.6363 ... 0.3294 0.5574 0.2849
    atts       (team, sample) float64 0.1063 0.1538 0.1781 ... 0.2923 0.2029
    defs       (team, sample) float64 -0.06765 -0.1792 ... -0.2033 -0.1986
Attributes:
    created_at:                 2019-07-12T20:31:53.545143
    inference_library:          pymc3
    inference_library_version:  3.7

<xarray.Dataset>
Dimensions:      (match: 60, sample: 2000)
Coordinates:
  * match        (match) object 'Wales Italy' ... 'Ireland England'
  * sample       (sample) object MultiIndex
  * chain        (sample) int64 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
  * draw         (sample) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
Data variables:
    home_points  (match, sample) int64 43 43 42 45 43 49 ... 20 14 21 29 26 27
    away_points  (match, sample) int64 7 14 9 15 10 12 8 ... 12 14 16 18 20 12
Attributes:
    created_at:                 2019-07-12T20:31:53.563854
    inference_library:          pymc3
    inference_library_version:  3.7

<xarray.Dataset>
Dimensions:           (sample: 2000)
Coordinates:
  * sample            (sample) object MultiIndex
  * chain             (sample) int64 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
  * draw              (sample) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
Data variables:
    energy_error      (sample) float64 -0.07666 -0.4523 ... 0.115 -0.07691
    energy            (sample) float64 540.2 545.3 542.3 ... 544.0 544.0 545.6
    tree_size         (sample) float64 15.0 63.0 31.0 15.0 ... 63.0 31.0 31.0
    tune              (sample) bool True False False False ... False False False
    mean_tree_accept  (sample) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539
    lp                (sample) float64 -536.4 -536.0 -533.8 ... -536.1 -536.4
    depth             (sample) int64 4 6 5 4 4 4 5 5 5 3 ... 4 6 6 4 6 5 3 6 5 5
    max_energy_error  (sample) float64 -0.5361 -0.5871 0.3981 ... 0.7109 1.014
    step_size         (sample) float64 0.2469 0.2469 0.2469 ... 0.2459 0.2459
    step_size_bar     (sample) float64 0.2313 0.2313 0.2313 ... 0.2488 0.2488
    diverging         (sample) bool False False False ... False False False
Attributes:
    created_at:                 2019-07-12T20:31:53.555203
    inference_library:          pymc3
    inference_library_version:  3.7

<xarray.Dataset>
Dimensions:       (team: 6, match: 60, sample: 500)
Coordinates:
  * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
  * match         (match) object 'Wales Italy' ... 'Ireland England'
  * sample        (sample) object MultiIndex
  * chain         (sample) int64 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0
  * draw          (sample) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
Data variables:
    sd_att_log__  (sample) float64 1.322 -2.014 1.588 ... 0.7974 -0.8585 -0.1922
    intercept     (sample) float64 4.464 3.352 1.567 3.897 ... 4.363 4.128 1.049
    atts_star     (team, sample) float64 -2.64 -0.04968 ... -0.01713 -0.8538
    defs_star     (team, sample) float64 -0.7817 0.9151 2.07 ... -1.922 0.01067
    away_points   (match, sample) int64 11308 61 405 346 41 ... 4 15 102 259 0
    sd_att        (sample) float64 3.752 0.1334 4.896 ... 2.22 0.4238 0.8251
    sd_def_log__  (sample) float64 -0.2662 0.2411 0.6071 ... 1.367 1.402 -1.981
    home          (sample) float64 -1.511 -0.001582 1.75 ... -0.02416 0.2651
    atts          (team, sample) float64 -4.667 0.03653 ... -0.1798 -0.8365
    sd_def        (sample) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138
    home_points   (match, sample) int64 0 16 0 3 66 16 27 ... 71 0 0 5 50 71 13
    defs          (team, sample) float64 -0.2517 0.8887 ... -0.1544 -0.06586
Attributes:
    created_at:                 2019-07-12T20:31:53.573731
    inference_library:          pymc3
    inference_library_version:  3.7

<xarray.Dataset>
Dimensions:      (match: 60)
Coordinates:
  * match        (match) object 'Wales Italy' ... 'Ireland England'
Data variables:
    home_points  (match) float64 ...
    away_points  (match) float64 ...
Attributes:
    created_at:                 2019-07-12T20:31:53.581293
    inference_library:          pymc3
    inference_library_version:  3.7
```
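Once chain and draw are merged into sample, summaries over all draws from all chains reduce to a single-dimension operation. A minimal sketch on synthetic data, assuming a stand-in for the rugby atts variable (random values and placeholder team labels "A" to "F" rather than the real dataset, which load_arviz_data downloads):

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(2)
# Synthetic stand-in for the rugby posterior: 4 chains, 500 draws, 6 teams.
idata = az.from_dict(
    posterior={"atts": rng.normal(size=(4, 500, 6))},
    coords={"team": ["A", "B", "C", "D", "E", "F"]},
    dims={"atts": ["team"]},
)
idata.stack(sample=["chain", "draw"], inplace=True)

atts = idata.posterior["atts"]
print(atts.sizes["sample"])  # 2000
# Averaging over every draw of every chain now needs a single dimension name.
team_means = atts.mean("sample")
print(team_means.dims)  # ('team',)
# The original chain and draw labels remain available as MultiIndex levels.
print(atts["chain"].values[:3], atts["draw"].values[:3])  # [0 0 0] [0 1 2]
```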
We can also stack dimensions of a custom InferenceData object. We first check the original object:
```python
import arviz as az
import numpy as np

datadict = {
    "a": np.random.randn(100),
    "b": np.random.randn(1, 100, 10),
    "c": np.random.randn(1, 100, 3, 4),
}
coords = {
    "c1": np.arange(3),
    "c99": np.arange(4),
    "b1": np.arange(10),
}
dims = {"c": ["c1", "c99"], "b": ["b1"]}
idata = az.from_dict(
    posterior=datadict, posterior_predictive=datadict, coords=coords, dims=dims
)
idata
```
```text
arviz.InferenceData

<xarray.Dataset>
Dimensions:  (chain: 1, draw: 100, b1: 10, c1: 3, c99: 4)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
  * b1       (b1) int64 0 1 2 3 4 5 6 7 8 9
  * c1       (c1) int64 0 1 2
  * c99      (c99) int64 0 1 2 3
Data variables:
    a        (chain, draw) float64 -2.177 -0.1614 1.233 ... -0.495 -0.1453
    b        (chain, draw, b1) float64 1.731 -1.426 -0.2903 ... -0.04606 -0.5633
    c        (chain, draw, c1, c99) float64 0.8478 1.754 ... -1.571 -1.495
Attributes:
    created_at:     2023-11-17T19:39:33.325350
    arviz_version:  0.17.0.dev0

<xarray.Dataset>
Dimensions:  (chain: 1, draw: 100, b1: 10, c1: 3, c99: 4)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
  * b1       (b1) int64 0 1 2 3 4 5 6 7 8 9
  * c1       (c1) int64 0 1 2
  * c99      (c99) int64 0 1 2 3
Data variables:
    a        (chain, draw) float64 -2.177 -0.1614 1.233 ... -0.495 -0.1453
    b        (chain, draw, b1) float64 1.731 -1.426 -0.2903 ... -0.04606 -0.5633
    c        (chain, draw, c1, c99) float64 0.8478 1.754 ... -1.571 -1.495
Attributes:
    created_at:     2023-11-17T19:39:33.328993
    arviz_version:  0.17.0.dev0
```
In order to stack the two dimensions c1 and c99 into z, we can use:
```python
idata.stack(z=["c1", "c99"], inplace=True)
idata
```
```text
arviz.InferenceData

<xarray.Dataset>
Dimensions:  (chain: 1, draw: 100, b1: 10, z: 12)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
  * b1       (b1) int64 0 1 2 3 4 5 6 7 8 9
  * z        (z) object MultiIndex
  * c1       (z) int64 0 0 0 0 1 1 1 1 2 2 2 2
  * c99      (z) int64 0 1 2 3 0 1 2 3 0 1 2 3
Data variables:
    a        (chain, draw) float64 -2.177 -0.1614 1.233 ... -0.495 -0.1453
    b        (chain, draw, b1) float64 1.731 -1.426 -0.2903 ... -0.04606 -0.5633
    c        (chain, draw, z) float64 0.8478 1.754 -1.463 ... -1.571 -1.495
Attributes:
    created_at:     2023-11-17T19:39:33.325350
    arviz_version:  0.17.0.dev0

<xarray.Dataset>
Dimensions:  (chain: 1, draw: 100, b1: 10, z: 12)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
  * b1       (b1) int64 0 1 2 3 4 5 6 7 8 9
  * z        (z) object MultiIndex
  * c1       (z) int64 0 0 0 0 1 1 1 1 2 2 2 2
  * c99      (z) int64 0 1 2 3 0 1 2 3 0 1 2 3
Data variables:
    a        (chain, draw) float64 -2.177 -0.1614 1.233 ... -0.495 -0.1453
    b        (chain, draw, b1) float64 1.731 -1.426 -0.2903 ... -0.04606 -0.5633
    c        (chain, draw, z) float64 0.8478 1.754 -1.463 ... -1.571 -1.495
Attributes:
    created_at:     2023-11-17T19:39:33.328993
    arviz_version:  0.17.0.dev0
```
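The c1 and c99 levels of z in the output (c1 changing slowest, c99 fastest) suggest that stacking matches a row-major reshape of the trailing axes. A small check on synthetic data, assuming a fixed random seed in place of the unseeded arrays used above:

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(3)
c = rng.normal(size=(1, 100, 3, 4))  # (chain, draw, c1, c99)
idata = az.from_dict(
    posterior={"c": c},
    coords={"c1": np.arange(3), "c99": np.arange(4)},
    dims={"c": ["c1", "c99"]},
)
flat = idata.stack(z=["c1", "c99"]).posterior["c"]

print(flat.dims)  # ('chain', 'draw', 'z')
# c1 varies slowest and c99 fastest, matching a C-order (row-major) reshape.
print(np.array_equal(flat.values, c.reshape(1, 100, 12)))  # True
print(flat["c1"].values)   # [0 0 0 0 1 1 1 1 2 2 2 2]
print(flat["c99"].values)  # [0 1 2 3 0 1 2 3 0 1 2 3]
```

The order of the names in the list passed to stack determines which level varies slowest; listing c99 first would swap the roles.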