arviz.InferenceData.stack

InferenceData.stack(dimensions=None, groups=None, filter_groups=None, inplace=False, **kwargs)

Perform an xarray stacking on all groups.

Stack any number of existing dimensions into a single new dimension. The method loops over the selected groups and performs Dataset.stack(key=value) for every new dimension whose existing dimensions (value) are all present in that group's dataset. Stacking is therefore applied to the relevant groups (such as posterior, prior, and sample_stats), while groups that lack those dimensions, such as observed_data, are left unchanged. See xarray.Dataset.stack().
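
The group loop described above can be sketched with plain xarray. This is a minimal illustration under stated assumptions, not ArviZ's implementation: stack_groups is a hypothetical helper, and a plain dict of xarray.Dataset objects stands in for an InferenceData object. A new dimension is stacked in a group only when every dimension it replaces is present there; other groups are passed through untouched.

```python
import numpy as np
import xarray as xr


def stack_groups(groups, dimensions=None, **kwargs):
    """Hypothetical helper mimicking the group loop of InferenceData.stack.

    `groups` is a plain dict mapping group names to xarray Datasets (a
    stand-in for an InferenceData object). New dimensions can be given as a
    `dimensions` dict, as keyword arguments, or both.
    """
    dimensions = {**(dimensions or {}), **kwargs}
    out = {}
    for name, ds in groups.items():
        # Keep only the new dimensions whose source dimensions all exist here.
        applicable = {
            new_dim: old_dims
            for new_dim, old_dims in dimensions.items()
            if set(old_dims).issubset(ds.dims)
        }
        # Groups without the requested dimensions are passed through as-is.
        out[name] = ds.stack(applicable) if applicable else ds
    return out


rng = np.random.default_rng(0)
groups = {
    "posterior": xr.Dataset({"mu": (("chain", "draw"), rng.normal(size=(4, 500)))}),
    "observed_data": xr.Dataset({"y": (("match",), rng.normal(size=60))}),
}

stacked = stack_groups(groups, sample=["chain", "draw"])
print(stacked["posterior"]["mu"].dims)       # ('sample',)
print(stacked["posterior"].sizes["sample"])  # 2000
print(stacked["observed_data"]["y"].dims)    # ('match',)
```

In this sketch, passing {"sample": ["chain", "draw"]} as the dimensions dict gives the same result as the keyword form.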

Parameters
dimensions : dict, optional

Names of new dimensions, and the existing dimensions that they replace, e.g. {"sample": ["chain", "draw"]}.

groups : str or list of str, optional

Groups where the stacking is to be applied. Can either be group names or metagroup names.

filter_groups : {None, “like”, “regex”}, optional

If None (default), interpret groups as the real group or metagroup names. If “like”, interpret groups as substrings of the real group or metagroup names. If “regex”, interpret groups as regular expressions on the real group or metagroup names. À la pandas.filter.

inplace : bool, optional

If True, modify the InferenceData object in place; otherwise, return a modified copy.

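**kwargs : dict, optional

Keyword arguments accepted by xarray.Dataset.stack(). New dimensions can also be given this way, as with sample=["chain", "draw"] in the examples below.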
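
The three filter_groups modes can be illustrated with the standard library alone. The sketch below is an assumption about the matching rules, modeled on pandas.DataFrame.filter (a substring test for “like”, re.search for “regex”); resolve_groups is a hypothetical name, and unlike ArviZ it does not handle metagroup names.

```python
import re


def resolve_groups(requested, available, filter_groups=None):
    """Hypothetical resolver for the `groups`/`filter_groups` pair.

    Assumed rules, modeled on pandas.DataFrame.filter:
      None    -> exact group names
      "like"  -> a requested string is a substring of the group name
      "regex" -> a requested string is a pattern found by re.search
    Metagroup names are not handled in this sketch.
    """
    if isinstance(requested, str):
        requested = [requested]
    if filter_groups is None:
        matches = lambda pattern, name: pattern == name
    elif filter_groups == "like":
        matches = lambda pattern, name: pattern in name
    elif filter_groups == "regex":
        matches = lambda pattern, name: re.search(pattern, name) is not None
    else:
        raise ValueError("filter_groups must be None, 'like' or 'regex'")
    # Preserve the order in which the groups are stored.
    return [name for name in available if any(matches(p, name) for p in requested)]


groups = ["posterior", "posterior_predictive", "sample_stats", "prior", "observed_data"]
print(resolve_groups("posterior", groups))                       # ['posterior']
print(resolve_groups("posterior", groups, filter_groups="like"))  # ['posterior', 'posterior_predictive']
print(resolve_groups("^p", groups, filter_groups="regex"))        # ['posterior', 'posterior_predictive', 'prior']
```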

Returns
InferenceData

A new InferenceData object by default. When inplace=True, the stacking is performed in place and None is returned.

See also

xarray.Dataset.stack

Stack any number of existing dimensions into a single new dimension.

unstack

Perform an xarray unstacking on all groups of an InferenceData object.

Examples

Use stack to stack any number of existing dimensions into a single new dimension. We first check the original object:

import arviz as az
idata = az.load_arviz_data("rugby")
idata
arviz.InferenceData
    • posterior: <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265
          intercept  (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892
          atts_star  (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878
          defs_star  (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649
          sd_att     (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591
          sd_def     (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849
          atts       (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029
          defs       (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive: <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • sample_stats: <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 -0.07666 -0.4523 ... 0.115 -0.07691
          energy            (chain, draw) float64 540.2 545.3 542.3 ... 544.0 545.6
          tree_size         (chain, draw) float64 15.0 63.0 31.0 ... 63.0 31.0 31.0
          tune              (chain, draw) bool True False False ... False False False
          mean_tree_accept  (chain, draw) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539
          lp                (chain, draw) float64 -536.4 -536.0 ... -536.1 -536.4
          depth             (chain, draw) int64 4 6 5 4 4 4 5 5 5 ... 6 4 6 5 3 6 5 5
          max_energy_error  (chain, draw) float64 -0.5361 -0.5871 ... 0.7109 1.014
          step_size         (chain, draw) float64 0.2469 0.2469 ... 0.2459 0.2459
          step_size_bar     (chain, draw) float64 0.2313 0.2313 ... 0.2488 0.2488
          diverging         (chain, draw) bool False False False ... False False False
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior: <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 1.322 -2.014 1.588 ... -0.8585 -0.1922
          intercept     (chain, draw) float64 4.464 3.352 1.567 ... 4.363 4.128 1.049
          atts_star     (chain, draw, team) float64 -2.64 4.172 ... -0.2874 -0.8538
          defs_star     (chain, draw, team) float64 -0.7817 -0.1478 ... 0.1655 0.01067
          away_points   (chain, draw, match) int64 11308 0 11 1 0 21442 ... 11 1 2 2 0
          sd_att        (chain, draw) float64 3.752 0.1334 4.896 ... 0.4238 0.8251
          sd_def_log__  (chain, draw) float64 -0.2662 0.2411 0.6071 ... 1.402 -1.981
          home          (chain, draw) float64 -1.511 -0.001582 ... -0.02416 0.2651
          atts          (chain, draw, team) float64 -4.667 2.145 ... -0.2702 -0.8365
          sd_def        (chain, draw) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138
          home_points   (chain, draw, match) int64 0 47 11899 3262 1 ... 3 2 1 12 13
          defs          (chain, draw, team) float64 -0.2517 0.3823 ... 0.089 -0.06586
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data: <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0
          away_points  (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7

To stack the two dimensions chain and draw into a single new dimension named sample, we can use:

idata.stack(sample=["chain", "draw"], inplace=True)
idata
arviz.InferenceData
    • posterior: <xarray.Dataset>
      Dimensions:    (team: 6, sample: 2000)
      Coordinates:
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * sample     (sample) object MultiIndex
        * chain      (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
        * draw       (sample) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          home       (sample) float64 0.1642 0.1162 0.09299 ... 0.1452 0.148 0.2265
          intercept  (sample) float64 2.893 2.941 2.939 2.977 ... 2.951 2.903 2.892
          atts_star  (team, sample) float64 0.1673 0.226 0.1959 ... -0.01013 0.02878
          defs_star  (team, sample) float64 -0.03638 0.01689 ... -0.06325 -0.0649
          sd_att     (sample) float64 0.4854 0.1438 0.2139 ... 0.4472 0.2883 0.4591
          sd_def     (sample) float64 0.2747 1.033 0.6363 ... 0.3294 0.5574 0.2849
          atts       (team, sample) float64 0.1063 0.1538 0.1781 ... 0.2923 0.2029
          defs       (team, sample) float64 -0.06765 -0.1792 ... -0.2033 -0.1986
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive: <xarray.Dataset>
      Dimensions:      (match: 60, sample: 2000)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
        * sample       (sample) object MultiIndex
        * chain        (sample) int64 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
        * draw         (sample) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          home_points  (match, sample) int64 43 43 42 45 43 49 ... 20 14 21 29 26 27
          away_points  (match, sample) int64 7 14 9 15 10 12 8 ... 12 14 16 18 20 12
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • sample_stats: <xarray.Dataset>
      Dimensions:           (sample: 2000)
      Coordinates:
        * sample            (sample) object MultiIndex
        * chain             (sample) int64 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
        * draw              (sample) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
      Data variables:
          energy_error      (sample) float64 -0.07666 -0.4523 ... 0.115 -0.07691
          energy            (sample) float64 540.2 545.3 542.3 ... 544.0 544.0 545.6
          tree_size         (sample) float64 15.0 63.0 31.0 15.0 ... 63.0 31.0 31.0
          tune              (sample) bool True False False False ... False False False
          mean_tree_accept  (sample) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539
          lp                (sample) float64 -536.4 -536.0 -533.8 ... -536.1 -536.4
          depth             (sample) int64 4 6 5 4 4 4 5 5 5 3 ... 4 6 6 4 6 5 3 6 5 5
          max_energy_error  (sample) float64 -0.5361 -0.5871 0.3981 ... 0.7109 1.014
          step_size         (sample) float64 0.2469 0.2469 0.2469 ... 0.2459 0.2459
          step_size_bar     (sample) float64 0.2313 0.2313 0.2313 ... 0.2488 0.2488
          diverging         (sample) bool False False False ... False False False
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior: <xarray.Dataset>
      Dimensions:       (team: 6, match: 60, sample: 500)
      Coordinates:
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
        * sample        (sample) object MultiIndex
        * chain         (sample) int64 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0
        * draw          (sample) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          sd_att_log__  (sample) float64 1.322 -2.014 1.588 ... 0.7974 -0.8585 -0.1922
          intercept     (sample) float64 4.464 3.352 1.567 3.897 ... 4.363 4.128 1.049
          atts_star     (team, sample) float64 -2.64 -0.04968 ... -0.01713 -0.8538
          defs_star     (team, sample) float64 -0.7817 0.9151 2.07 ... -1.922 0.01067
          away_points   (match, sample) int64 11308 61 405 346 41 ... 4 15 102 259 0
          sd_att        (sample) float64 3.752 0.1334 4.896 ... 2.22 0.4238 0.8251
          sd_def_log__  (sample) float64 -0.2662 0.2411 0.6071 ... 1.367 1.402 -1.981
          home          (sample) float64 -1.511 -0.001582 1.75 ... -0.02416 0.2651
          atts          (team, sample) float64 -4.667 0.03653 ... -0.1798 -0.8365
          sd_def        (sample) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138
          home_points   (match, sample) int64 0 16 0 3 66 16 27 ... 71 0 0 5 50 71 13
          defs          (team, sample) float64 -0.2517 0.8887 ... -0.1544 -0.06586
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data: <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0
          away_points  (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7
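
The stacked layout above follows xarray's stacking rules: the new sample dimension is appended after the remaining dimensions, which is why atts changes from (chain, draw, team) to (team, sample), and position i of sample corresponds to chain = i // 500 and draw = i % 500. The sketch below checks this with plain xarray on a synthetic array shaped like atts (random values, not the rugby data) and shows that unstacking recovers the original dimensions.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the rugby posterior variable `atts`:
# 4 chains, 500 draws, 6 teams (random values, not the real data).
rng = np.random.default_rng(0)
atts = xr.DataArray(
    rng.normal(size=(4, 500, 6)),
    dims=("chain", "draw", "team"),
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)

stacked = atts.stack(sample=["chain", "draw"])
print(stacked.dims)  # ('team', 'sample')

# Stacked values equal a C-order reshape of (chain, draw) into one axis,
# after moving `team` to the front: sample i <-> (i // 500, i % 500).
expected = atts.values.transpose(2, 0, 1).reshape(6, 2000)
print(np.array_equal(stacked.values, expected))  # True

# The MultiIndex keeps chain and draw as levels, so the stacking can be undone.
restored = stacked.unstack("sample").transpose("chain", "draw", "team")
print(np.array_equal(restored.values, atts.values))  # True
```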

We can also stack dimensions of a custom InferenceData object. We first check the original object:

import numpy as np
datadict = {
    "a": np.random.randn(100),
    "b": np.random.randn(1, 100, 10),
    "c": np.random.randn(1, 100, 3, 4),
}
coords = {
    "c1": np.arange(3),
    "c99": np.arange(4),
    "b1": np.arange(10),
}
dims = {"c": ["c1", "c99"], "b": ["b1"]}
idata = az.from_dict(
    posterior=datadict, posterior_predictive=datadict, coords=coords, dims=dims
)
idata
arviz.InferenceData
    • posterior: <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 100, b1: 10, c1: 3, c99: 4)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
        * b1       (b1) int64 0 1 2 3 4 5 6 7 8 9
        * c1       (c1) int64 0 1 2
        * c99      (c99) int64 0 1 2 3
      Data variables:
          a        (chain, draw) float64 0.4054 -0.836 -0.8187 ... 1.002 0.757 -0.3649
          b        (chain, draw, b1) float64 1.021 -1.392 0.1199 ... 0.06372 -0.5689
          c        (chain, draw, c1, c99) float64 1.321 0.3537 -0.708 ... 1.248 1.658
      Attributes:
          created_at:     2022-10-02T06:41:28.593808
          arviz_version:  0.13.0.dev0

    • posterior_predictive: <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 100, b1: 10, c1: 3, c99: 4)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
        * b1       (b1) int64 0 1 2 3 4 5 6 7 8 9
        * c1       (c1) int64 0 1 2
        * c99      (c99) int64 0 1 2 3
      Data variables:
          a        (chain, draw) float64 0.4054 -0.836 -0.8187 ... 1.002 0.757 -0.3649
          b        (chain, draw, b1) float64 1.021 -1.392 0.1199 ... 0.06372 -0.5689
          c        (chain, draw, c1, c99) float64 1.321 0.3537 -0.708 ... 1.248 1.658
      Attributes:
          created_at:     2022-10-02T06:41:28.597411
          arviz_version:  0.13.0.dev0

To stack the two dimensions c1 and c99 into a new dimension named z, we can use:

idata.stack(z=["c1", "c99"], inplace=True)
idata
arviz.InferenceData
    • posterior: <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 100, b1: 10, z: 12)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
        * b1       (b1) int64 0 1 2 3 4 5 6 7 8 9
        * z        (z) object MultiIndex
        * c1       (z) int64 0 0 0 0 1 1 1 1 2 2 2 2
        * c99      (z) int64 0 1 2 3 0 1 2 3 0 1 2 3
      Data variables:
          a        (chain, draw) float64 0.4054 -0.836 -0.8187 ... 1.002 0.757 -0.3649
          b        (chain, draw, b1) float64 1.021 -1.392 0.1199 ... 0.06372 -0.5689
          c        (chain, draw, z) float64 1.321 0.3537 -0.708 ... 1.248 1.658
      Attributes:
          created_at:     2022-10-02T06:41:28.593808
          arviz_version:  0.13.0.dev0

    • posterior_predictive: <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 100, b1: 10, z: 12)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
        * b1       (b1) int64 0 1 2 3 4 5 6 7 8 9
        * z        (z) object MultiIndex
        * c1       (z) int64 0 0 0 0 1 1 1 1 2 2 2 2
        * c99      (z) int64 0 1 2 3 0 1 2 3 0 1 2 3
      Data variables:
          a        (chain, draw) float64 0.4054 -0.836 -0.8187 ... 1.002 0.757 -0.3649
          b        (chain, draw, b1) float64 1.021 -1.392 0.1199 ... 0.06372 -0.5689
          c        (chain, draw, z) float64 1.321 0.3537 -0.708 ... 1.248 1.658
      Attributes:
          created_at:     2022-10-02T06:41:28.597411
          arviz_version:  0.13.0.dev0
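
Because c1 and c99 are the last two dimensions of c, stacking them into z amounts to a C-order reshape of those axes: position k of z corresponds to c1 = k // 4 and c99 = k % 4, matching the c1 and c99 coordinates listed above. The sketch below verifies this with plain xarray on an array with the same shape and dimension names as c (fresh random values, so the numbers differ from the printed output).

```python
import numpy as np
import xarray as xr

# Same shape and dimension names as variable "c" in the example above
# (new random values, not the ones printed there).
rng = np.random.default_rng(0)
c_values = rng.normal(size=(1, 100, 3, 4))
c = xr.DataArray(
    c_values,
    dims=("chain", "draw", "c1", "c99"),
    coords={"c1": np.arange(3), "c99": np.arange(4)},
)

z = c.stack(z=["c1", "c99"])
print(z.dims)  # ('chain', 'draw', 'z')

# Stacking the trailing axes equals reshaping them in C order.
print(np.array_equal(z.values, c_values.reshape(1, 100, 12)))  # True

# The level coordinates follow the same ordering: k -> (k // 4, k % 4).
print(z["c1"].values.tolist())   # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
print(z["c99"].values.tolist())  # [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
```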