arviz.InferenceData.map

InferenceData.map(fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs)

Apply a function to multiple groups.

Applies fun to each of the selected InferenceData groups and replaces each group with the value that fun returns.
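A minimal sketch of that groupwise loop in plain Python, assuming a simplified model in which an InferenceData object is just a dict mapping group names to data. The helper map_groups is hypothetical and is not ArviZ's implementation; it only illustrates that each selected group is passed to fun as the first argument, followed by args and kwargs, and that the return value replaces that group while the other groups are left untouched.

```python
from __future__ import annotations

from typing import Any, Callable


def map_groups(data: dict[str, Any], fun: Callable[..., Any], groups: list[str],
               args: tuple = (), **kwargs: Any) -> dict[str, Any]:
    """Hypothetical model of InferenceData.map on a dict of groups.

    Returns a new dict; groups not listed in `groups` are kept unchanged.
    """
    out = dict(data)  # shallow copy, so the input dict itself is not modified
    for group in groups:
        # fun receives the group first, then the forwarded positional and keyword args
        out[group] = fun(data[group], *args, **kwargs)
    return out


data = {"posterior": [1.0, 2.0], "observed_data": [10.0, 20.0]}
result = map_groups(
    data,
    lambda ds, k, scale=1.0: [x * scale + k for x in ds],
    ["observed_data"],
    args=(3,),
    scale=2.0,
)
print(result)  # {'posterior': [1.0, 2.0], 'observed_data': [23.0, 43.0]}
```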

Parameters
fun : callable

Function to be applied to each group. Assumes the function is called as fun(dataset, *args, **kwargs).

groups : str or list of str, optional

Groups to which fun is applied. Can be either group names or metagroup names.

filter_groups : {None, "like", "regex"}, optional

If None (default), interpret groups as the real group or metagroup names. If "like", interpret groups as substrings of the real group or metagroup names. If "regex", interpret groups as regular expressions on the real group or metagroup names. This mirrors the like and regex options of pandas.DataFrame.filter.

inplace : bool, optional

If True, modify the InferenceData object in place; otherwise, return a modified copy.

args : array_like, optional

Positional arguments passed to fun.

**kwargs : mapping, optional

Keyword arguments passed to fun.
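A simplified sketch of how groups and filter_groups could select groups, written as a hypothetical select_groups helper. It assumes that groups=None selects every group, and its METAGROUPS table models only observed_vars, with the members that the first example below shows being shifted; ArviZ's own metagroup definitions may contain more entries. As with pandas.DataFrame.filter, "like" matches substrings and "regex" is modeled with re.search.

```python
from __future__ import annotations

import re

# Assumed metagroup table: only "observed_vars" is modeled, with the groups
# that the first example on this page shows being shifted.
METAGROUPS = {
    "observed_vars": ["posterior_predictive", "observed_data", "prior_predictive"],
}


def select_groups(available: list[str], groups: str | list[str] | None = None,
                  filter_groups: str | None = None) -> list[str]:
    """Hypothetical model of group selection; not ArviZ's implementation."""
    if groups is None:
        return list(available)  # assumption: no selection means all groups
    if isinstance(groups, str):
        groups = [groups]
    # Expand metagroup names into those member groups that are present.
    names: list[str] = []
    for name in groups:
        if name in METAGROUPS:
            names.extend(g for g in METAGROUPS[name] if g in available)
        else:
            names.append(name)
    if filter_groups is None:
        # Exact names: anything not present is an error.
        missing = [n for n in names if n not in available]
        if missing:
            raise KeyError(f"{missing} are not present")
        return names
    if filter_groups == "like":
        return [g for g in available if any(n in g for n in names)]
    if filter_groups == "regex":
        return [g for g in available if any(re.search(n, g) for n in names)]
    raise ValueError('filter_groups must be None, "like" or "regex"')


available = ["posterior", "posterior_predictive", "log_likelihood", "sample_stats",
             "prior", "prior_predictive", "observed_data", "constant_data"]
print(select_groups(available, "observed_vars"))
# ['posterior_predictive', 'observed_data', 'prior_predictive']
print(select_groups(available, "prior", filter_groups="like"))
# ['prior', 'prior_predictive']
print(select_groups(available, "^post", filter_groups="regex"))
# ['posterior', 'posterior_predictive']
```

Note that "like" with "prior" does not pick up posterior, because "prior" is not a substring of "posterior".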

Returns
InferenceData

A new InferenceData object by default. When inplace=True, the function is applied in place and None is returned.
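The two return conventions can be sketched with a hypothetical GroupContainer class that stores groups as attributes, standing in for InferenceData. Under this assumption, inplace=False works on a deep copy and returns it, while inplace=True modifies the object itself and returns None.

```python
from __future__ import annotations

import copy
from typing import Any, Callable


class GroupContainer:
    """Hypothetical stand-in for InferenceData: each group is an attribute."""

    def __init__(self, **groups: Any) -> None:
        for name, value in groups.items():
            setattr(self, name, value)

    def map(self, fun: Callable[..., Any], groups: list[str], inplace: bool = False,
            args: list | None = None, **kwargs: Any) -> GroupContainer | None:
        args = [] if args is None else args
        # Modify self directly when inplace=True, otherwise an independent deep copy.
        out = self if inplace else copy.deepcopy(self)
        for group in groups:
            setattr(out, group, fun(getattr(self, group), *args, **kwargs))
        # Mirror the documented convention: return None when modifying in place.
        return None if inplace else out


def add_three(ds):
    return [x + 3 for x in ds]


container = GroupContainer(observed_data=[1.0, 2.0])
copied = container.map(add_three, groups=["observed_data"])
print(container.observed_data, copied.observed_data)  # [1.0, 2.0] [4.0, 5.0]
returned = container.map(add_three, groups=["observed_data"], inplace=True)
print(returned, container.observed_data)  # None [4.0, 5.0]
```

A practical consequence of this convention: writing idata = idata.map(..., inplace=True) rebinds idata to None, so use either the assignment or inplace=True, not both.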

Examples

Shift the observed_data, prior_predictive and posterior_predictive groups (the observed_vars metagroup) by 3.

import arviz as az
import numpy as np
idata = az.load_arviz_data("non_centered_eight")
idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
idata_shifted_obs
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 ...
          theta_t  (chain, draw, school) float64 ...
          tau      (chain, draw) float64 ...
          theta    (chain, draw, school) float64 ...
      Attributes:
          created_at:                 2022-10-13T14:37:26.351883
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2
          sampling_time:              4.738754749298096
          tuning_steps:               1000

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 -8.912 0.4851 ... 13.55 27.95
      Attributes:
          created_at:                 2022-10-13T14:37:34.333731
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2

    • log_likelihood
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 ...
      Attributes:
          created_at:                 2022-10-13T14:37:26.571887
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2

    • sample_stats
      <xarray.Dataset>
      Dimensions:              (chain: 4, draw: 500)
      Coordinates:
        * chain                (chain) int64 0 1 2 3
        * draw                 (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
      Data variables: (12/16)
          lp                   (chain, draw) float64 ...
          largest_eigval       (chain, draw) float64 ...
          perf_counter_start   (chain, draw) float64 ...
          perf_counter_diff    (chain, draw) float64 ...
          step_size            (chain, draw) float64 ...
          diverging            (chain, draw) bool ...
          ...                   ...
          max_energy_error     (chain, draw) float64 ...
          n_steps              (chain, draw) float64 ...
          step_size_bar        (chain, draw) float64 ...
          energy_error         (chain, draw) float64 ...
          smallest_eigval      (chain, draw) float64 ...
          index_in_trajectory  (chain, draw) int64 ...
      Attributes:
          created_at:                 2022-10-13T14:37:26.362154
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2
          sampling_time:              4.738754749298096
          tuning_steps:               1000

    • prior
      <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 ...
          theta_t  (chain, draw, school) float64 ...
          theta    (chain, draw, school) float64 ...
          tau      (chain, draw) float64 ...
      Attributes:
          created_at:                 2022-10-13T14:37:18.108887
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2

    • prior_predictive
      <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 25.87 -10.84 ... -4.658 17.32
      Attributes:
          created_at:                 2022-10-13T14:37:18.111951
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2

    • observed_data
      <xarray.Dataset>
      Dimensions:    (obs_dim_0: 8)
      Coordinates:
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          obs        (obs_dim_0) float64 31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0
      Attributes:
          created_at:                 2022-10-13T14:37:18.113060
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2

    • constant_data
      <xarray.Dataset>
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          scores   (school) float64 ...
      Attributes:
          created_at:                 2022-10-13T14:37:18.114126
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2
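The observed_data values in the output above (31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0) are the eight-schools observations shifted by 3. The sketch below reproduces that arithmetic in plain Python with a named function in place of the lambda; the unshifted observations (28, 8, -3, 7, -1, 1, 18, 12) are the standard eight-schools data and are not printed on this page. With ArviZ, idata.map(shift, groups="observed_vars", args=[3]) should be equivalent to the lambda, since map calls fun(dataset, *args) and a Dataset supports adding a scalar.

```python
def shift(ds, offset):
    """Add `offset` to ds; works for plain numbers and for xarray Datasets."""
    return ds + offset


# Eight-schools observed effects before the shift (standard dataset values).
observed = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
args = [3]
# map calls fun(dataset, *args); here the same call is made per value.
shifted = [shift(y, *args) for y in observed]
print(shifted)  # [31.0, 11.0, 0.0, 10.0, 2.0, 4.0, 21.0, 15.0]
```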

Rename the g_coef dimension to uranium_coefs and set its coordinate values in both the posterior and prior groups.

idata = az.load_arviz_data("radon")
idata = idata.map(
    lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
        uranium_coefs=["intercept", "u_slope"]
    ),
    groups=["posterior", "prior"]
)
idata
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:        (chain: 4, draw: 500, uranium_coefs: 2, County: 85)
      Coordinates:
        * chain          (chain) int64 0 1 2 3
        * draw           (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * uranium_coefs  (uranium_coefs) <U9 'intercept' 'u_slope'
        * County         (County) object 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
      Data variables:
          g              (chain, draw, uranium_coefs) float64 ...
          za_county      (chain, draw, County) float64 ...
          b              (chain, draw) float64 ...
          sigma_a        (chain, draw) float64 ...
          a              (chain, draw, County) float64 ...
          a_county       (chain, draw, County) float64 ...
          sigma          (chain, draw) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.191355
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.449843
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • log_likelihood
      <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.448264
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          step_size_bar     (chain, draw) float64 ...
          diverging         (chain, draw) bool ...
          energy            (chain, draw) float64 ...
          tree_size         (chain, draw) float64 ...
          mean_tree_accept  (chain, draw) float64 ...
          step_size         (chain, draw) float64 ...
          depth             (chain, draw) int64 ...
          energy_error      (chain, draw) float64 ...
          lp                (chain, draw) float64 ...
          max_energy_error  (chain, draw) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.197697
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • prior
      <xarray.Dataset>
      Dimensions:        (chain: 1, draw: 500, County: 85, uranium_coefs: 2)
      Coordinates:
        * chain          (chain) int64 0
        * draw           (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * County         (County) object 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
        * uranium_coefs  (uranium_coefs) <U9 'intercept' 'u_slope'
      Data variables:
          a_county       (chain, draw, County) float64 ...
          sigma_log__    (chain, draw) float64 ...
          sigma_a        (chain, draw) float64 ...
          a              (chain, draw, County) float64 ...
          b              (chain, draw) float64 ...
          za_county      (chain, draw, County) float64 ...
          sigma          (chain, draw) float64 ...
          g              (chain, draw, uranium_coefs) float64 ...
          sigma_a_log__  (chain, draw) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.454586
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • prior_predictive
      <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.457652
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • observed_data
      <xarray.Dataset>
      Dimensions:  (obs_id: 919)
      Coordinates:
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.458415
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • constant_data
      <xarray.Dataset>
      Dimensions:     (obs_id: 919, County: 85)
      Coordinates:
        * obs_id      (obs_id) int64 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
        * County      (County) object 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'
      Data variables:
          floor_idx   (obs_id) int32 ...
          county_idx  (obs_id) int32 ...
          uranium     (County) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.459832
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

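What the lambda in this example does to a single group can be modeled in plain Python, treating a group as a mapping of dimension names to sizes plus a mapping of coordinate values. The helper rename_and_label is hypothetical. It reflects two constraints of the xarray calls it stands in for: Dataset.rename raises an error for a name the dataset does not contain, which is presumably why the example restricts groups to posterior and prior, and the assigned labels must match the dimension's length (2 here).

```python
from __future__ import annotations


def rename_and_label(dims: dict[str, int], coords: dict[str, list], old: str,
                     new: str, labels: list[str]) -> tuple[dict[str, int], dict[str, list]]:
    """Hypothetical model of ds.rename({old: new}).assign({new: labels})."""
    if old not in dims:
        # xarray's Dataset.rename also rejects names missing from the dataset.
        raise ValueError(f"{old!r} is not a dimension of this group")
    if len(labels) != dims[old]:
        raise ValueError(f"expected {dims[old]} labels, got {len(labels)}")
    # Rename the dimension, keeping its size and position.
    new_dims = {(new if name == old else name): size for name, size in dims.items()}
    # Drop any coordinate under the old name and attach the new labels.
    new_coords = {name: values for name, values in coords.items() if name != old}
    new_coords[new] = list(labels)
    return new_dims, new_coords


dims, coords = rename_and_label({"chain": 4, "draw": 500, "g_coef": 2}, {},
                                "g_coef", "uranium_coefs", ["intercept", "u_slope"])
print(dims)    # {'chain': 4, 'draw': 500, 'uranium_coefs': 2}
print(coords)  # {'uranium_coefs': ['intercept', 'u_slope']}
```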
Add home_team and away_team coordinates to the groups in the observed_vars metagroup.

idata = az.load_arviz_data("rugby")
home_team, away_team = np.array([
    m.split() for m in idata.observed_data.match.values
]).T
idata = idata.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    home_team=("match", home_team),
    away_team=("match", away_team),
)
idata
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 ...
          intercept  (chain, draw) float64 ...
          atts_star  (chain, draw, team) float64 ...
          defs_star  (chain, draw, team) float64 ...
          sd_att     (chain, draw) float64 ...
          sd_def     (chain, draw) float64 ...
          atts       (chain, draw, team) float64 ...
          defs       (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
          away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 ...
          energy            (chain, draw) float64 ...
          tree_size         (chain, draw) float64 ...
          tune              (chain, draw) bool ...
          mean_tree_accept  (chain, draw) float64 ...
          lp                (chain, draw) float64 ...
          depth             (chain, draw) int64 ...
          max_energy_error  (chain, draw) float64 ...
          step_size         (chain, draw) float64 ...
          step_size_bar     (chain, draw) float64 ...
          diverging         (chain, draw) bool ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior
      <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 ...
          intercept     (chain, draw) float64 ...
          atts_star     (chain, draw, team) float64 ...
          defs_star     (chain, draw, team) float64 ...
          away_points   (chain, draw, match) int64 ...
          sd_att        (chain, draw) float64 ...
          sd_def_log__  (chain, draw) float64 ...
          home          (chain, draw) float64 ...
          atts          (chain, draw, team) float64 ...
          sd_def        (chain, draw) float64 ...
          home_points   (chain, draw, match) int64 ...
          defs          (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data
      <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
          away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
      Data variables:
          home_points  (match) float64 ...
          away_points  (match) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7
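The team-splitting step at the start of this example is plain Python: each match label has the form "home away", so splitting on whitespace and transposing the pairs yields the two coordinate arrays. The sketch below uses zip(*...) in place of np.array(...).T, with the first three match labels implied by the output above. It assumes every team name is a single word, which holds for these six teams but not in general. The resulting (dimension, values) tuples are the keyword arguments that map forwards to assign_coords.

```python
# First three match labels, in the "home away" format shown in the output above.
matches = ["Wales Italy", "France England", "Ireland Scotland"]

# Split each label into (home, away) and transpose the pairs into two tuples.
# Assumes single-word team names; a multi-word name would break the split.
home_team, away_team = zip(*(m.split() for m in matches))
print(home_team)  # ('Wales', 'France', 'Ireland')
print(away_team)  # ('Italy', 'England', 'Scotland')

# Keyword arguments in the (dimension, values) form used by assign_coords.
coord_kwargs = {
    "home_team": ("match", list(home_team)),
    "away_team": ("match", list(away_team)),
}
```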