arviz.InferenceData.map

InferenceData.map(fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs)

Apply a function to multiple groups.

Applies fun to each selected InferenceData group and replaces that group with the function's result, either in a returned copy or, with inplace=True, in the original object.

Parameters:
fun : callable

Function to be applied to each group. It is called as fun(dataset, *args, **kwargs), where dataset is the group's xarray.Dataset.

groups : str or list of str, optional

Groups to which fun is applied. Can be either group names or metagroup names, such as observed_vars.

filter_groups : {None, “like”, “regex”}, optional

If None (default), interpret groups as the real group or metagroup names. If “like”, interpret groups as substrings of the real group or metagroup names. If “regex”, interpret groups as regular expressions on the real group or metagroup names. This mirrors the like and regex arguments of pandas.DataFrame.filter.

inplace : bool, optional

If True, modify the InferenceData object in place; otherwise, return a modified copy.

args : array_like, optional

Positional arguments passed to fun.

**kwargs : mapping, optional

Keyword arguments passed to fun.
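The way groups, metagroup names and filter_groups combine to pick the groups that fun runs on can be sketched in plain Python. This is a simplified model, not ArviZ's implementation: resolve_groups, METAGROUPS and EXISTING are hypothetical names, the metagroup table lists only observed_vars (with the three members named in the first example on this page), the group list is an assumed layout of the non_centered_eight dataset, treating groups=None as "every group" is an assumption of the sketch, and error handling is omitted.

```python
import re

# Hypothetical metagroup table. Only observed_vars is listed; its members are
# the three groups shifted in the first example on this page.
METAGROUPS = {
    "observed_vars": ["posterior_predictive", "observed_data", "prior_predictive"],
}


def resolve_groups(groups, existing, filter_groups=None):
    """Hypothetical helper: return the names in `existing` selected by `groups`."""
    if filter_groups not in (None, "like", "regex"):
        raise ValueError('filter_groups must be None, "like" or "regex"')
    if groups is None:
        # Assumption of this sketch: no selection means every group.
        return list(existing)
    if isinstance(groups, str):
        groups = [groups]
    # Expand metagroup names into their member groups first.
    patterns = []
    for name in groups:
        patterns.extend(METAGROUPS.get(name, [name]))
    selected = []
    for group in existing:  # keep the order in which groups are stored
        for pattern in patterns:
            if filter_groups is None:
                hit = group == pattern  # exact name
            elif filter_groups == "like":
                hit = pattern in group  # substring match
            else:
                hit = re.search(pattern, group) is not None  # regular expression
            if hit:
                selected.append(group)
                break
    return selected


# Assumed group names of the non_centered_eight example dataset.
EXISTING = [
    "posterior", "posterior_predictive", "log_likelihood", "sample_stats",
    "prior", "prior_predictive", "observed_data", "constant_data",
]
print(resolve_groups("observed_vars", EXISTING))
print(resolve_groups("prior", EXISTING, filter_groups="like"))
print(resolve_groups("^post", EXISTING, filter_groups="regex"))
```

With “like”, the pattern prior also matches prior_predictive, so a substring filter can select more groups than an exact name would.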

Returns:
InferenceData

A new InferenceData object by default. When inplace=True, fun is applied in place and None is returned.
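The calling convention and the copy versus in-place behaviour described above can be illustrated with a toy container. ToyInferenceData and scale_and_shift are hypothetical names; plain dicts stand in for xarray Datasets, and metagroups and filter_groups are deliberately left out, so this is a sketch of the documented semantics rather than ArviZ's code.

```python
import copy


class ToyInferenceData:
    """Hypothetical stand-in for InferenceData: groups are plain dicts, not Datasets."""

    def __init__(self, **groups):
        self.data = groups

    def map(self, fun, groups=None, inplace=False, args=None, **kwargs):
        # Simplified: no metagroup expansion and no filter_groups handling.
        args = () if args is None else args
        target = self if inplace else copy.deepcopy(self)
        if groups is None:
            names = list(target.data)
        elif isinstance(groups, str):
            names = [groups]
        else:
            names = list(groups)
        for name in names:
            # Documented calling convention: fun(dataset, *args, **kwargs).
            target.data[name] = fun(target.data[name], *args, **kwargs)
        return None if inplace else target


def scale_and_shift(dataset, factor, offset=0.0):
    """Example fun: multiply every variable by `factor`, then add `offset`."""
    return {var: value * factor + offset for var, value in dataset.items()}


idata = ToyInferenceData(posterior={"mu": 1.0}, observed_data={"obs": 2.0})

# Copy semantics: args supplies `factor` positionally, offset arrives via **kwargs.
shifted = idata.map(scale_and_shift, groups="observed_data", args=[10], offset=1.0)

# In-place semantics: the original object changes and None is returned.
result = idata.map(scale_and_shift, groups="posterior", inplace=True, args=[2])
```

Because extra keywords are collected by **kwargs in the signature shown at the top of this page, a keyword intended for fun that shares a name with one of map's own parameters (for example groups or inplace) is consumed by map instead of being forwarded.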

Examples

Shift the observed_data, prior_predictive and posterior_predictive groups (the observed_vars metagroup) by 3.

import arviz as az
import numpy as np
idata = az.load_arviz_data("non_centered_eight")
idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
idata_shifted_obs
arviz.InferenceData
    • <xarray.Dataset> Size: 293kB
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 16kB ...
          theta_t  (chain, draw, school) float64 128kB ...
          tau      (chain, draw) float64 16kB ...
          theta    (chain, draw, school) float64 128kB ...
      Attributes:
          created_at:                 2022-10-13T14:37:26.351883
          arviz_version:              0.13.0.dev0
          inference_library:          pymc
          inference_library_version:  4.2.2
          sampling_time:              4.738754749298096
          tuning_steps:               1000

    • <xarray.Dataset> Size: 132kB
      Dimensions:    (chain: 4, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 64B 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 128kB -8.912 0.4851 ... 27.95
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:34.333731
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset> Size: 132kB
      Dimensions:    (chain: 4, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 64B 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 128kB ...
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:26.571887
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset> Size: 246kB
      Dimensions:              (chain: 4, draw: 500)
      Coordinates:
        * chain                (chain) int64 32B 0 1 2 3
        * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
      Data variables: (12/16)
          lp                   (chain, draw) float64 16kB ...
          largest_eigval       (chain, draw) float64 16kB ...
          perf_counter_start   (chain, draw) float64 16kB ...
          perf_counter_diff    (chain, draw) float64 16kB ...
          step_size            (chain, draw) float64 16kB ...
          diverging            (chain, draw) bool 2kB ...
          ...                   ...
          max_energy_error     (chain, draw) float64 16kB ...
          n_steps              (chain, draw) float64 16kB ...
          step_size_bar        (chain, draw) float64 16kB ...
          energy_error         (chain, draw) float64 16kB ...
          smallest_eigval      (chain, draw) float64 16kB ...
          index_in_trajectory  (chain, draw) int64 16kB ...
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:26.362154
          inference_library:          pymc
          inference_library_version:  4.2.2
          sampling_time:              4.738754749298096
          tuning_steps:               1000

    • <xarray.Dataset> Size: 77kB
      Dimensions:  (chain: 1, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 4kB ...
          theta_t  (chain, draw, school) float64 32kB ...
          theta    (chain, draw, school) float64 32kB ...
          tau      (chain, draw) float64 4kB ...
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:18.108887
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset> Size: 36kB
      Dimensions:    (chain: 1, draw: 500, obs_dim_0: 8)
      Coordinates:
        * chain      (chain) int64 8B 0
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_dim_0  (obs_dim_0) int64 64B 0 1 2 3 4 5 6 7
      Data variables:
          obs        (chain, draw, obs_dim_0) float64 32kB 25.87 -10.84 ... 17.32
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:18.111951
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset> Size: 128B
      Dimensions:    (obs_dim_0: 8)
      Coordinates:
        * obs_dim_0  (obs_dim_0) int64 64B 0 1 2 3 4 5 6 7
      Data variables:
          obs        (obs_dim_0) float64 64B 31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:18.113060
          inference_library:          pymc
          inference_library_version:  4.2.2

    • <xarray.Dataset> Size: 576B
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          scores   (school) float64 64B ...
      Attributes:
          arviz_version:              0.13.0.dev0
          created_at:                 2022-10-13T14:37:18.114126
          inference_library:          pymc
          inference_library_version:  4.2.2
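In the output above, observed_data is the dataset whose only dimension is obs_dim_0. As a quick arithmetic check of what map did, its printed values can be compared with the classic eight schools effects; that these are the unshifted contents of observed_data is an assumption of this sketch.

```python
# Classic eight schools effects, assumed here to be the unshifted
# observed_data["obs"] values of the non_centered_eight example dataset.
original = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
# Values printed above for the dataset with only the obs_dim_0 dimension.
printed = [31.0, 11.0, 0.0, 10.0, 2.0, 4.0, 21.0, 15.0]

# map applied `lambda x: x + 3` elementwise, so every difference should be 3.
differences = [after - before for before, after in zip(original, printed)]
print(differences)
```

With ArviZ installed, the same check can be made on the objects themselves by comparing idata_shifted_obs.observed_data["obs"] with idata.observed_data["obs"] + 3.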

Rename the g_coef dimension to uranium_coefs and assign new coordinate values to it in both the posterior and prior groups.

idata = az.load_arviz_data("radon")
idata = idata.map(
    lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
        uranium_coefs=["intercept", "u_slope"]
    ),
    groups=["posterior", "prior"]
)
idata
arviz.InferenceData
    • <xarray.Dataset> Size: 4MB
      Dimensions:        (chain: 4, draw: 500, uranium_coefs: 2, County: 85)
      Coordinates:
        * chain          (chain) int64 32B 0 1 2 3
        * draw           (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
        * County         (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
        * uranium_coefs  (uranium_coefs) <U9 72B 'intercept' 'u_slope'
      Data variables:
          g              (chain, draw, uranium_coefs) float64 32kB ...
          za_county      (chain, draw, County) float64 1MB ...
          b              (chain, draw) float64 16kB ...
          sigma_a        (chain, draw) float64 16kB ...
          a              (chain, draw, County) float64 1MB ...
          a_county       (chain, draw, County) float64 1MB ...
          sigma          (chain, draw) float64 16kB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.191355
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • <xarray.Dataset> Size: 15MB
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 15MB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.449843
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset> Size: 15MB
      Dimensions:  (chain: 4, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 15MB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.448264
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset> Size: 150kB
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
      Data variables:
          step_size_bar     (chain, draw) float64 16kB ...
          diverging         (chain, draw) bool 2kB ...
          energy            (chain, draw) float64 16kB ...
          tree_size         (chain, draw) float64 16kB ...
          mean_tree_accept  (chain, draw) float64 16kB ...
          step_size         (chain, draw) float64 16kB ...
          depth             (chain, draw) int64 16kB ...
          energy_error      (chain, draw) float64 16kB ...
          lp                (chain, draw) float64 16kB ...
          max_energy_error  (chain, draw) float64 16kB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.197697
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2
          sampling_time:              18.096983432769775
          tuning_steps:               1000

    • <xarray.Dataset> Size: 1MB
      Dimensions:        (chain: 1, draw: 500, County: 85, uranium_coefs: 2)
      Coordinates:
        * chain          (chain) int64 8B 0
        * draw           (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499
        * County         (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
        * uranium_coefs  (uranium_coefs) <U9 72B 'intercept' 'u_slope'
      Data variables:
          a_county       (chain, draw, County) float64 340kB ...
          sigma_log__    (chain, draw) float64 4kB ...
          sigma_a        (chain, draw) float64 4kB ...
          a              (chain, draw, County) float64 340kB ...
          b              (chain, draw) float64 4kB ...
          za_county      (chain, draw, County) float64 340kB ...
          sigma          (chain, draw) float64 4kB ...
          g              (chain, draw, uranium_coefs) float64 8kB ...
          sigma_a_log__  (chain, draw) float64 4kB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.454586
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset> Size: 4MB
      Dimensions:  (chain: 1, draw: 500, obs_id: 919)
      Coordinates:
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
      Data variables:
          y        (chain, draw, obs_id) float64 4MB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.457652
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset> Size: 15kB
      Dimensions:  (obs_id: 919)
      Coordinates:
        * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
      Data variables:
          y        (obs_id) float64 7kB ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.458415
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

    • <xarray.Dataset> Size: 21kB
      Dimensions:     (obs_id: 919, County: 85)
      Coordinates:
        * obs_id      (obs_id) int64 7kB 0 1 2 3 4 5 6 ... 912 913 914 915 916 917 918
        * County      (County) <U17 6kB 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
      Data variables:
          floor_idx   (obs_id) int32 4kB ...
          county_idx  (obs_id) int32 4kB ...
          uranium     (County) float64 680B ...
      Attributes:
          created_at:                 2020-07-24T18:15:12.459832
          arviz_version:              0.9.0
          inference_library:          pymc3
          inference_library_version:  3.9.2

Add home_team and away_team as extra coordinates along the match dimension in all groups containing observed variables (the observed_vars metagroup).

idata = az.load_arviz_data("rugby")
home_team, away_team = np.array([
    m.split() for m in idata.observed_data.match.values
]).T
idata = idata.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    home_team=("match", home_team),
    away_team=("match", away_team),
)
idata
arviz.InferenceData
    • <xarray.Dataset> Size: 452kB
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 16kB ...
          intercept  (chain, draw) float64 16kB ...
          atts_star  (chain, draw, team) float64 96kB ...
          defs_star  (chain, draw, team) float64 96kB ...
          sd_att     (chain, draw) float64 16kB ...
          sd_def     (chain, draw) float64 16kB ...
          atts       (chain, draw, team) float64 96kB ...
          defs       (chain, draw, team) float64 96kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:23.841916
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9
          sampling_time:              8.503105401992798
          tuning_steps:               1000

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB 'Wales' 'France' ... 'France' 'Ireland'
          away_team    (match) <U8 2kB 'Italy' 'England' ... 'Wales' 'England'
      Data variables:
          home_points  (chain, draw, match) int64 960kB ...
          away_points  (chain, draw, match) int64 960kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:25.689246
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB ...
          away_team    (match) <U8 2kB ...
      Data variables:
          home_points  (chain, draw, match) float64 960kB ...
          away_points  (chain, draw, match) float64 960kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:24.120642
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 260kB
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 32B 0 1 2 3
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 16kB ...
          sd_att     (chain, draw) float64 16kB ...
          sd_def     (chain, draw) float64 16kB ...
          intercept  (chain, draw) float64 16kB ...
          atts_star  (chain, draw, team) float64 96kB ...
          defs_star  (chain, draw, team) float64 96kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:24.377610
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 248kB
      Dimensions:                (chain: 4, draw: 500)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
      Data variables: (12/17)
          max_energy_error       (chain, draw) float64 16kB ...
          index_in_trajectory    (chain, draw) int64 16kB ...
          smallest_eigval        (chain, draw) float64 16kB ...
          perf_counter_start     (chain, draw) float64 16kB ...
          largest_eigval         (chain, draw) float64 16kB ...
          step_size              (chain, draw) float64 16kB ...
          ...                     ...
          reached_max_treedepth  (chain, draw) bool 2kB ...
          perf_counter_diff      (chain, draw) float64 16kB ...
          tree_depth             (chain, draw) int64 16kB ...
          process_time_diff      (chain, draw) float64 16kB ...
          step_size_bar          (chain, draw) float64 16kB ...
          energy                 (chain, draw) float64 16kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:23.854033
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9
          sampling_time:              8.503105401992798
          tuning_steps:               1000

    • <xarray.Dataset> Size: 116kB
      Dimensions:    (chain: 1, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 8B 0
        * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team       (team) <U8 192B 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          atts_star  (chain, draw, team) float64 24kB ...
          sd_att     (chain, draw) float64 4kB ...
          atts       (chain, draw, team) float64 24kB ...
          sd_def     (chain, draw) float64 4kB ...
          defs       (chain, draw, team) float64 24kB ...
          intercept  (chain, draw) float64 4kB ...
          home       (chain, draw) float64 4kB ...
          defs_star  (chain, draw, team) float64 24kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.475945
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 492kB
      Dimensions:      (chain: 1, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 8B 0
        * draw         (draw) int64 4kB 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB 'Wales' 'France' ... 'France' 'Ireland'
          away_team    (match) <U8 2kB 'Italy' 'England' ... 'Wales' 'England'
      Data variables:
          away_points  (chain, draw, match) int64 240kB ...
          home_points  (chain, draw, match) int64 240kB ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.479330
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 9kB
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) <U16 4kB 'Wales Italy' ... 'Ireland England'
          home_team    (match) <U8 2kB 'Wales' 'France' ... 'France' 'Ireland'
          away_team    (match) <U8 2kB 'Italy' 'England' ... 'Wales' 'England'
      Data variables:
          home_points  (match) int64 480B ...
          away_points  (match) int64 480B ...
      Attributes:
          created_at:                 2024-03-06T20:46:09.480812
          arviz_version:              0.17.0
          inference_library:          pymc
          inference_library_version:  5.10.4+7.g34d2a5d9

    • <xarray.Dataset> Size: 36kB
      Dimensions:  (chain: 4, draw: 500)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          sd_att   (chain, draw) float64 16kB ...
          sd_def   (chain, draw) float64 16kB ...
      Attributes:
          sd_att:   pymc.logprob.transforms.LogTransform
          sd_def:   pymc.logprob.transforms.LogTransform
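The home_team and away_team arrays in the last example come from splitting each match label on whitespace and transposing the resulting pairs. The same step can be written without NumPy, as sketched below on two labels copied from the match coordinate; like the original, it relies on each team name being a single word.

```python
# Two labels copied from the match coordinate in the output above.
matches = ["Wales Italy", "Ireland England"]

# Split each "Home Away" label into a pair, then transpose the pairs with zip(*...);
# this mirrors np.array([m.split() for m in ...]).T in the example.
home_team, away_team = zip(*(m.split() for m in matches))
print(home_team)
print(away_team)
```

zip returns tuples rather than NumPy arrays; wrapping each in np.array gives string arrays like those built in the original example.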