arviz.InferenceData.map

InferenceData.map(fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs)

Apply a function to multiple groups.

Applies fun to each of the selected InferenceData groups and replaces that group with the value the function returns.

Parameters:
fun : callable

Function to be applied to each group. Assumes the function is called as fun(dataset, *args, **kwargs).

groups : str or list of str, optional

Groups to which fun is applied. Can be either group names or metagroup names.

filter_groups : {None, “like”, “regex”}, optional

If None (default), interpret groups as the exact group names. If “like”, interpret groups as substrings of the group names. If “regex”, interpret groups as regular expressions matched against the group names. This mirrors the like and regex options of pandas.DataFrame.filter.

inplace : bool, optional

If True, modify the InferenceData object in place; otherwise, return a modified copy.

args : array_like, optional

Positional arguments passed to fun.

**kwargs : mapping, optional

Keyword arguments passed to fun.
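
The three filter_groups modes can be pictured with a small standard-library sketch. It is a toy model of the matching rules described above, not ArviZ's own code: select_groups is a hypothetical helper, metagroup names are not modelled, and returning every group when groups is None is an assumption.

```python
import re


def select_groups(available, groups=None, filter_groups=None):
    """Toy model of group-name selection (hypothetical helper, not ArviZ code)."""
    if groups is None:
        # Assumption: no selection means every available group.
        return list(available)
    if isinstance(groups, str):
        groups = [groups]
    if filter_groups is None:
        # Exact group names only.
        return [name for name in available if name in groups]
    if filter_groups == "like":
        # Each entry is treated as a substring of the group name.
        return [name for name in available if any(g in name for g in groups)]
    if filter_groups == "regex":
        # Each entry is treated as a regular expression searched in the name.
        return [name for name in available if any(re.search(g, name) for g in groups)]
    raise ValueError(f"unknown filter_groups value: {filter_groups!r}")


available = [
    "posterior",
    "posterior_predictive",
    "prior",
    "prior_predictive",
    "observed_data",
]

exact = select_groups(available, "posterior")
like = select_groups(available, "predictive", filter_groups="like")
regex = select_groups(available, "^prior", filter_groups="regex")
```

Here exact contains only "posterior", like contains the two predictive groups, and regex contains the two groups whose names start with "prior".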

Returns:
InferenceData

A new InferenceData object by default. When inplace is True, fun is applied in place and None is returned.
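
The calling convention and the inplace behaviour can be illustrated without ArviZ. In the sketch below, groups are entries of a plain dict and datasets are lists of floats. The names map_groups and shift_and_scale and the numeric values are made up for illustration. It models the documented behaviour only and is not the library's implementation.

```python
from copy import deepcopy


def map_groups(store, fun, groups, inplace=False, args=None, **kwargs):
    """Toy model of a groupwise map over a dict of groups (hypothetical helper)."""
    args = [] if args is None else args
    out = store if inplace else deepcopy(store)
    for name in groups:
        # Documented calling convention: fun(dataset, *args, **kwargs).
        out[name] = fun(store[name], *args, **kwargs)
    # Mirror the documented return value: None in place, a modified copy otherwise.
    return None if inplace else out


def shift_and_scale(values, offset, scale=1):
    """Example function taking one positional and one keyword argument."""
    return [(v + offset) * scale for v in values]


store = {"posterior": [0.5, 1.5], "observed_data": [28.0, 8.0]}

# Default: a modified copy is returned and the original is left untouched.
shifted = map_groups(store, shift_and_scale, ["observed_data"], args=(3,), scale=2)

# inplace=True: the original is modified and None is returned.
result = map_groups(store, shift_and_scale, ["observed_data"], inplace=True, args=(3,))
```

Positional extras go through args as a sequence, while keyword extras are passed directly to the call and forwarded to fun unchanged.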

Examples

Shift observed_data, prior_predictive and posterior_predictive.

import arviz as az
import numpy as np
idata = az.load_arviz_data("non_centered_eight")
idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
idata_shifted_obs
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 ...
          theta_t  (chain, draw, school) float64 ...
          tau      (chain, draw) float64 ...
          theta    (chain, draw, school) float64 ...
      Attributes:
          created_at:                 2022-10-13T14:37:26.351883
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2
          sampling_time:              4.738754749298096
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 -8.912 0.4851 ... 13.55 27.95
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:34.333731
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 ...
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:26.571887
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset>
      Dimensions:              (chain: 4, draw: 500)
      Coordinates:
        * chain                (chain) int64 0 1 2 3
        * draw                 (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
      Data variables: (12/16)
          lp                   (chain, draw) float64 ...
          largest_eigval       (chain, draw) float64 ...
          perf_counter_start   (chain, draw) float64 ...
          perf_counter_diff    (chain, draw) float64 ...
          step_size            (chain, draw) float64 ...
          diverging            (chain, draw) bool ...
          ...                   ...
          max_energy_error     (chain, draw) float64 ...
          n_steps              (chain, draw) float64 ...
          step_size_bar        (chain, draw) float64 ...
          energy_error         (chain, draw) float64 ...
          smallest_eigval      (chain, draw) float64 ...
          index_in_trajectory  (chain, draw) int64 ...
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:26.362154
          inference_library:          pymc
          inference_library_version:  4.2.2
          sampling_time:              4.738754749298096
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 ...
          theta_t  (chain, draw, school) float64 ...
          theta    (chain, draw, school) float64 ...
          tau      (chain, draw) float64 ...
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:18.108887
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 25.87 -10.84 ... -4.658 17.32
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:18.111951
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset>
      Dimensions:    (obs_dim_0: 8)
      Coordinates:
        * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          obs        (obs_dim_0) float64 31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:18.113060
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset>
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          scores   (school) float64 ...
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:18.114126
          inference_library:          pymc
          inference_library_version:  4.2.2

Rename and update the coordinate values in both posterior and prior groups.

idata = az.load_arviz_data("radon")
idata = idata.map(
    lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
        uranium_coefs=["intercept", "u_slope"]
    ),
    groups=["posterior", "prior"]
)
idata
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:        (chain: 4, draw: 500, uranium_coefs: 2, County: 85)
      Coordinates:
        * chain          (chain) int64 0 1 2 3
        * draw           (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * uranium_coefs  (uranium_coefs) <U9 'intercept' 'u_slope'
        * County         (County) object 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
      Data variables:
          g              (chain, draw, uranium_coefs) float64 ...
          za_county      (chain, draw, County) float64 ...
          b              (chain, draw) float64 ...
          sigma_a        (chain, draw) float64 ...
          a              (chain, draw, County) float64 ...
          a_county       (chain, draw, County) float64 ...
          sigma          (chain, draw) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.191355
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.449843
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.448264
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          step_size_bar     (chain, draw) float64 ...
          diverging         (chain, draw) bool ...
          energy            (chain, draw) float64 ...
          tree_size         (chain, draw) float64 ...
          mean_tree_accept  (chain, draw) float64 ...
          step_size         (chain, draw) float64 ...
          depth             (chain, draw) int64 ...
          energy_error      (chain, draw) float64 ...
          lp                (chain, draw) float64 ...
          max_energy_error  (chain, draw) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.197697
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:        (chain: 1, draw: 500, County: 85, uranium_coefs: 2)
      Coordinates:
        * chain          (chain) int64 0
        * draw           (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * County         (County) object 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
        * uranium_coefs  (uranium_coefs) <U9 'intercept' 'u_slope'
      Data variables:
          a_county       (chain, draw, County) float64 ...
          sigma_log__    (chain, draw) float64 ...
          sigma_a        (chain, draw) float64 ...
          a              (chain, draw, County) float64 ...
          b              (chain, draw) float64 ...
          za_county      (chain, draw, County) float64 ...
          sigma          (chain, draw) float64 ...
          g              (chain, draw, uranium_coefs) float64 ...
          sigma_a_log__  (chain, draw) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.454586
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.457652
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:  (obs_id: 919)
      Coordinates:
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.458415
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:     (obs_id: 919, County: 85)
      Coordinates:
        * obs_id      (obs_id) int64 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
        * County      (County) object 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'
      Data variables:
          floor_idx   (obs_id) int32 ...
          county_idx  (obs_id) int32 ...
          uranium     (County) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.459832
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

Add extra coordinates to all groups containing observed variables.

idata = az.load_arviz_data("rugby")
home_team, away_team = np.array([
    m.split() for m in idata.observed_data.match.values
]).T
idata = idata.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    home_team=("match", home_team),
    away_team=("match", away_team),
)
idata
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 ...
          intercept  (chain, draw) float64 ...
          atts_star  (chain, draw, team) float64 ...
          defs_star  (chain, draw, team) float64 ...
          sd_att     (chain, draw) float64 ...
          sd_def     (chain, draw) float64 ...
          atts       (chain, draw, team) float64 ...
          defs       (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
          away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 ...
          energy            (chain, draw) float64 ...
          tree_size         (chain, draw) float64 ...
          tune              (chain, draw) bool ...
          mean_tree_accept  (chain, draw) float64 ...
          lp                (chain, draw) float64 ...
          depth             (chain, draw) int64 ...
          max_energy_error  (chain, draw) float64 ...
          step_size         (chain, draw) float64 ...
          step_size_bar     (chain, draw) float64 ...
          diverging         (chain, draw) bool ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 ...
          intercept     (chain, draw) float64 ...
          atts_star     (chain, draw, team) float64 ...
          defs_star     (chain, draw, team) float64 ...
          away_points   (chain, draw, match) int64 ...
          sd_att        (chain, draw) float64 ...
          sd_def_log__  (chain, draw) float64 ...
          home          (chain, draw) float64 ...
          atts          (chain, draw, team) float64 ...
          sd_def        (chain, draw) float64 ...
          home_points   (chain, draw, match) int64 ...
          defs          (chain, draw, team) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
          away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
      Data variables:
          home_points  (match) float64 ...
          away_points  (match) float64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7
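
The home_team and away_team arrays in the rugby example come from splitting each match label on whitespace and transposing the result. A minimal sketch of that step with plain Python, using only the two match labels visible in the output above and assuming every label consists of exactly two single-word team names:

```python
# Only the two labels visible in the printed output; the full coordinate has 60 entries.
matches = ["Wales Italy", "Ireland England"]

# Split each label into [home, away], then transpose rows into two columns.
home_team, away_team = zip(*(m.split() for m in matches))
```

A team name containing a space would split into more than two parts, so both this sketch and the NumPy transpose in the example rely on single-word names.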