arviz.InferenceData.map

InferenceData.map(fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs)

Apply a function to multiple groups.

Applies fun to each selected InferenceData group and overwrites that group with the value the function returns.

Parameters
fun : callable

Function to be applied to each group. It is called as fun(dataset, *args, **kwargs), where dataset is the xarray.Dataset stored in the group.

groups : str or list of str, optional

Groups to which fun is applied. Can be either group names or metagroup names. If None (default), fun is applied to all groups.

filter_groups : {None, “like”, “regex”}, optional

If None (default), interpret groups as the real group or metagroup names. If “like”, interpret groups as substrings of the real group or metagroup names. If “regex”, interpret groups as regular expressions on the real group or metagroup names. Similar to the like and regex options of pandas.DataFrame.filter.

inplace : bool, optional

If True, modify the InferenceData object in place; otherwise, return a modified copy.

args : array_like, optional

Positional arguments passed to fun.

**kwargs : mapping, optional

Keyword arguments passed to fun.

Returns
InferenceData

A new InferenceData object by default. When inplace==True, apply fun in place and return None.

Examples

Shift observed_data, prior_predictive and posterior_predictive.

import arviz as az
import numpy as np
idata = az.load_arviz_data("non_centered_eight")
idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
idata_shifted_obs
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 6.424 6.275 2.003 2.664 ... 5.339 5.551 6.124
          theta_t  (chain, draw, school) float64 -0.000439 0.06986 ... -0.8657 2.295
          tau      (chain, draw) float64 2.094 3.41 7.218 0.7277 ... 1.71 2.084 2.642
          theta    (chain, draw, school) float64 6.424 6.571 4.541 ... 3.837 12.19
      Attributes:
          created_at:                 2019-06-21T17:36:37.382566
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (chain, draw, school) float64 12.42 1.107 11.0 ... 5.087 31.48
      Attributes:
          created_at:                 2019-06-21T17:36:37.487547
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * school            (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          tune              (chain, draw) bool True False False ... False False False
          depth             (chain, draw) int64 3 3 3 3 3 3 3 3 3 ... 3 3 3 3 3 3 3 3
          tree_size         (chain, draw) float64 7.0 7.0 7.0 7.0 ... 7.0 7.0 7.0 7.0
          lp                (chain, draw) float64 -43.18 -45.53 ... -45.33 -46.47
          energy_error      (chain, draw) float64 0.009026 0.3015 ... 0.02968
          step_size_bar     (chain, draw) float64 0.543 0.543 0.543 ... 0.5398 0.5398
          max_energy_error  (chain, draw) float64 0.1725 0.3265 3.684 ... 1.004 0.4816
          energy            (chain, draw) float64 46.75 48.95 56.69 ... 52.13 51.57
          mean_tree_accept  (chain, draw) float64 0.9032 0.8249 ... 0.7628 0.8899
          step_size         (chain, draw) float64 0.5235 0.5235 ... 0.4142 0.4142
          diverging         (chain, draw) bool False False False ... False False False
          log_likelihood    (chain, draw, school) float64 -4.662 -3.232 ... -3.809
      Attributes:
          created_at:                 2019-06-21T17:36:37.484480
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 500, school: 8)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          obs        (chain, draw, school) float64 -7.299 -4.477 ... -7.388 -0.3322
          theta_t    (chain, draw, school) float64 0.5478 0.4233 ... -0.7948 -0.5729
          tau        (chain, draw) float64 0.009301 2.146 33.01 ... 18.43 2.301 1.718
          tau_log__  (chain, draw) float64 -4.678 0.7637 3.497 ... 2.914 0.8335 0.5413
          theta      (chain, draw, school) float64 -3.626 -3.627 ... -4.299 -3.918
          mu         (chain, draw) float64 -3.631 -1.434 -2.172 ... 0.1286 -2.934
      Attributes:
          created_at:                 2019-06-21T17:36:37.489185
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (school) float64 31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0
      Attributes:
          created_at:                 2019-06-21T17:36:37.491073
          inference_library:          pymc3
          inference_library_version:  3.7
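
Select groups by substring with filter_groups="like". This is a minimal sketch rather than an official example: it assumes an ArviZ release that still provides InferenceData and az.from_dict with one keyword per group, and the data, the variable names mu and y, and the name shifted are made up for illustration.

```python
import arviz as az
import numpy as np

# Synthetic data: every array here is invented purely for illustration.
rng = np.random.default_rng(0)
idata = az.from_dict(
    posterior={"mu": rng.normal(size=(2, 50))},
    posterior_predictive={"y": rng.normal(size=(2, 50, 3))},
    prior_predictive={"y": rng.normal(size=(1, 50, 3))},
    observed_data={"y": rng.normal(size=3)},
)

# With filter_groups="like", "predictive" is matched as a substring of the
# group names, so posterior_predictive and prior_predictive are selected.
shifted = idata.map(lambda ds: ds + 10, groups="predictive", filter_groups="like")

# posterior and observed_data do not contain "predictive" in their names,
# so they are carried over to the returned copy unchanged.
```

Because inplace is left at its default of False, idata itself is not modified and the shifted groups live only in the returned object.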

Rename the g_coef dimension to uranium_coefs and set its coordinate values in both the posterior and prior groups.

idata = az.load_arviz_data("radon")
idata = idata.map(
    lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
        uranium_coefs=["intercept", "u_slope"]
    ),
    groups=["posterior", "prior"]
)
idata
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:        (chain: 4, draw: 500, uranium_coefs: 2, County: 85)
      Coordinates:
        * chain          (chain) int64 0 1 2 3
        * draw           (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * uranium_coefs  (uranium_coefs) <U9 'intercept' 'u_slope'
        * County         (County) object 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
      Data variables:
          g              (chain, draw, uranium_coefs) float64 1.499 0.5436 ... 0.6087
          za_county      (chain, draw, County) float64 ...
          b              (chain, draw) float64 -0.5965 -0.5848 ... -0.7949 -0.5402
          sigma_a        (chain, draw) float64 0.2061 0.1526 0.2201 ... 0.1764 0.1121
          a              (chain, draw, County) float64 ...
          a_county       (chain, draw, County) float64 ...
          sigma          (chain, draw) float64 0.7271 0.7161 0.7189 ... 0.7267 0.7324
      Attributes:
          created_at:                 2020-07-24T18:15:12.191355
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.449843
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.448264
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          step_size_bar     (chain, draw) float64 0.1102 0.1102 ... 0.1193 0.1193
          diverging         (chain, draw) bool False False False ... False False False
          energy            (chain, draw) float64 1.186e+03 1.184e+03 ... 1.189e+03
          tree_size         (chain, draw) float64 31.0 31.0 31.0 ... 31.0 31.0 31.0
          mean_tree_accept  (chain, draw) float64 0.9877 0.9969 1.0 ... 0.9996 0.9872
          step_size         (chain, draw) float64 0.1006 0.1006 0.1006 ... 0.117 0.117
          depth             (chain, draw) int64 5 5 5 5 5 5 5 5 5 ... 5 5 5 5 5 5 5 5
          energy_error      (chain, draw) float64 0.0289 -0.02917 ... 0.04322
          lp                (chain, draw) float64 -1.145e+03 -1.133e+03 ... -1.152e+03
          max_energy_error  (chain, draw) float64 -0.05888 -0.07148 ... 0.04451
      Attributes:
          created_at:                 2020-07-24T18:15:12.197697
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:        (chain: 1, draw: 500, County: 85, uranium_coefs: 2)
      Coordinates:
        * chain          (chain) int64 0
        * draw           (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * County         (County) object 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
        * uranium_coefs  (uranium_coefs) <U9 'intercept' 'u_slope'
      Data variables:
          a_county       (chain, draw, County) float64 8.536 5.285 ... 1.061 6.237
          sigma_log__    (chain, draw) float64 -0.9709 -1.344 ... -0.8501 -2.235
          sigma_a        (chain, draw) float64 2.991 2.237 0.3554 ... 0.4417 0.1448
          a              (chain, draw, County) float64 6.139 4.951 ... 0.9329 6.152
          b              (chain, draw) float64 0.5647 1.237 -0.6316 ... -0.01409 0.197
          za_county      (chain, draw, County) float64 0.8012 0.1114 ... 0.8862 0.5826
          sigma          (chain, draw) float64 0.3787 0.2608 0.6214 ... 0.4274 0.107
          g              (chain, draw, uranium_coefs) float64 11.31 7.504 ... 11.72
          sigma_a_log__  (chain, draw) float64 1.096 0.805 -1.035 ... -0.8171 -1.932
      Attributes:
          created_at:                 2020-07-24T18:15:12.454586
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.457652
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:  (obs_id: 919)
      Coordinates:
        * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918
      Data variables:
          y        (obs_id) float64 0.8329 0.8329 1.099 0.09531 ... 1.629 1.335 1.099
      Attributes:
          created_at:                 2020-07-24T18:15:12.458415
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset>
      Dimensions:     (obs_id: 919, County: 85)
      Coordinates:
        * obs_id      (obs_id) int64 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
        * County      (County) object 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'
      Data variables:
          floor_idx   (obs_id) int32 1 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 1 0 0 0 0 0 0 0
          county_idx  (obs_id) int32 0 0 0 0 1 1 1 1 1 ... 83 83 83 83 83 83 83 84 84
          uranium     (County) float64 -0.689 -0.8473 -0.1135 ... -0.09002 0.3553
      Attributes:
          created_at:                 2020-07-24T18:15:12.459832
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
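
Modify a group in place. The sketch below illustrates the return value described above for inplace=True; it assumes an ArviZ release that still provides InferenceData and az.from_dict with one keyword per group, and the small arrays are invented for illustration.

```python
import arviz as az
import numpy as np

# Synthetic data, invented for illustration only.
idata = az.from_dict(
    posterior={"mu": np.zeros((2, 50))},
    observed_data={"y": np.array([1.0, 2.0, 3.0])},
)

# inplace=True overwrites the selected group on idata itself; the call
# returns None, so the result should not be assigned back to idata.
result = idata.map(lambda ds: ds * 2, groups="observed_data", inplace=True)

print(result)  # None
```

Writing idata = idata.map(..., inplace=True) would therefore replace idata with None, which is a common slip when switching between the two modes.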

Add extra coordinates to all groups containing observed variables.

idata = az.load_arviz_data("rugby")
home_team, away_team = np.array([
    m.split() for m in idata.observed_data.match.values
]).T
idata = idata.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    home_team=("match", home_team),
    away_team=("match", away_team),
)
idata
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265
          intercept  (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892
          atts_star  (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878
          defs_star  (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649
          sd_att     (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591
          sd_def     (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849
          atts       (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029
          defs       (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
          away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 -0.07666 -0.4523 ... 0.115 -0.07691
          energy            (chain, draw) float64 540.2 545.3 542.3 ... 544.0 545.6
          tree_size         (chain, draw) float64 15.0 63.0 31.0 ... 63.0 31.0 31.0
          tune              (chain, draw) bool True False False ... False False False
          mean_tree_accept  (chain, draw) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539
          lp                (chain, draw) float64 -536.4 -536.0 ... -536.1 -536.4
          depth             (chain, draw) int64 4 6 5 4 4 4 5 5 5 ... 6 4 6 5 3 6 5 5
          max_energy_error  (chain, draw) float64 -0.5361 -0.5871 ... 0.7109 1.014
          step_size         (chain, draw) float64 0.2469 0.2469 ... 0.2459 0.2459
          step_size_bar     (chain, draw) float64 0.2313 0.2313 ... 0.2488 0.2488
          diverging         (chain, draw) bool False False False ... False False False
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 1.322 -2.014 1.588 ... -0.8585 -0.1922
          intercept     (chain, draw) float64 4.464 3.352 1.567 ... 4.363 4.128 1.049
          atts_star     (chain, draw, team) float64 -2.64 4.172 ... -0.2874 -0.8538
          defs_star     (chain, draw, team) float64 -0.7817 -0.1478 ... 0.1655 0.01067
          away_points   (chain, draw, match) int64 11308 0 11 1 0 21442 ... 11 1 2 2 0
          sd_att        (chain, draw) float64 3.752 0.1334 4.896 ... 0.4238 0.8251
          sd_def_log__  (chain, draw) float64 -0.2662 0.2411 0.6071 ... 1.402 -1.981
          home          (chain, draw) float64 -1.511 -0.001582 ... -0.02416 0.2651
          atts          (chain, draw, team) float64 -4.667 2.145 ... -0.2702 -0.8365
          sd_def        (chain, draw) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138
          home_points   (chain, draw, match) int64 0 47 11899 3262 1 ... 3 2 1 12 13
          defs          (chain, draw, team) float64 -0.2517 0.3823 ... 0.089 -0.06586
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
          away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
      Data variables:
          home_points  (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0
          away_points  (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7
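
Pass positional arguments through args. This sketch complements the keyword-argument example above; scale_and_shift is a hypothetical helper, the data are synthetic, and an ArviZ release that still provides InferenceData and az.from_dict with one keyword per group is assumed.

```python
import arviz as az
import numpy as np


def scale_and_shift(ds, factor, offset=0.0):
    """Hypothetical helper: rescale every variable in a group's Dataset."""
    return ds * factor + offset


# Synthetic data, invented for illustration only.
idata = az.from_dict(
    posterior={"mu": np.ones((2, 50))},
    prior={"mu": np.ones((1, 50))},
)

# For each selected group the call is scale_and_shift(dataset, 2.0, offset=1.0):
# entries of args follow the dataset positionally, and extra keyword
# arguments are forwarded unchanged.
scaled = idata.map(
    scale_and_shift,
    groups=["posterior", "prior"],
    args=(2.0,),
    offset=1.0,
)
```

Keeping the transformation in a named function rather than a lambda makes it reusable across several InferenceData objects with different factor and offset values.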