arviz.InferenceData.map

InferenceData.map(fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs)

Apply a function to multiple groups.

Applies fun groupwise to the selected InferenceData groups and overwrites each selected group with the function's result.
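To make the "apply groupwise, then overwrite" behaviour concrete, here is a minimal sketch that uses a plain dict of lists as a stand-in for InferenceData groups. The helper `map_groups` and the toy function `scale_and_shift` are hypothetical names for illustration only, not ArviZ code. The sketch follows the fun(dataset, *args, **kwargs) calling convention described under Parameters.

```python
from typing import Any, Callable, Dict, Iterable, Optional, Sequence


def map_groups(
    data: Dict[str, Any],
    fun: Callable[..., Any],
    groups: Iterable[str],
    args: Optional[Sequence[Any]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `data` with fun applied to selected groups."""
    if args is None:
        args = ()
    out = dict(data)  # shallow copy: unselected groups are carried over unchanged
    for name in groups:
        # Each selected group is replaced by fun(group, *args, **kwargs).
        out[name] = fun(data[name], *args, **kwargs)
    return out


def scale_and_shift(values, factor, shift=0.0):
    """Toy group function: multiply every value by `factor`, then add `shift`."""
    return [v * factor + shift for v in values]


data = {"posterior": [1.0, 2.0], "observed_data": [10.0, 20.0]}
result = map_groups(data, scale_and_shift, ["observed_data"], args=[2.0], shift=3.0)
print(result["observed_data"])  # [23.0, 43.0]
print(result["posterior"])      # [1.0, 2.0]
```

One consequence of forwarding through **kwargs: a keyword argument whose name matches one of the mapping function's own parameters (here `data`, `fun`, `groups`, `args`) is bound to that parameter by Python and never reaches fun.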

Parameters
fun : callable

Function to apply to each selected group. It is called as fun(dataset, *args, **kwargs), where dataset is the group's xarray.Dataset.

groups : str or list of str, optional

Groups to which fun is applied. Can be either group names or metagroup names.

filter_groups : {None, "like", "regex"}, optional

If None (default), interpret groups as the real group or metagroup names. If "like", interpret groups as substrings of the real group or metagroup names. If "regex", interpret groups as regular expressions on the real group or metagroup names. Similar to pandas.DataFrame.filter.

inplace : bool, optional

If True, modify the InferenceData object in place; otherwise, return a modified copy.

args : array_like, optional

Positional arguments passed to fun.

**kwargs : mapping, optional

Keyword arguments passed to fun.
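The interplay between groups (group or metagroup names) and the three filter_groups modes can be sketched as below. This is an illustration under stated assumptions, not ArviZ's implementation: the helper `select_groups` and the `METAGROUPS` table are hypothetical, and the table only lists "observed_vars" with the members named in the first example on this page. In this sketch, metagroup names are expanded first and the resulting names are then matched exactly, as substrings ("like"), or with re.search ("regex", the way pandas.DataFrame.filter(regex=...) matches labels). Requested groups that are absent are skipped.

```python
import re
from typing import Callable, Dict, List, Optional, Union

# Hypothetical metagroup table (assumption): only "observed_vars" is sketched.
METAGROUPS: Dict[str, List[str]] = {
    "observed_vars": ["observed_data", "prior_predictive", "posterior_predictive"],
}

# One matching rule per filter_groups mode.
MATCHERS: Dict[Optional[str], Callable[[str, str], bool]] = {
    None: lambda pattern, group: pattern == group,  # exact names
    "like": lambda pattern, group: pattern in group,  # substring match
    "regex": lambda pattern, group: re.search(pattern, group) is not None,  # regex search
}


def select_groups(
    groups: Union[str, List[str]],
    available: List[str],
    filter_groups: Optional[str] = None,
) -> List[str]:
    """Hypothetical helper: resolve `groups` to names present in `available`."""
    if filter_groups not in MATCHERS:
        raise ValueError(f"unknown filter_groups: {filter_groups!r}")
    if isinstance(groups, str):
        groups = [groups]  # a single name behaves like a one-element list
    # Expand metagroup names into their member groups; plain names pass through.
    patterns: List[str] = []
    for name in groups:
        patterns.extend(METAGROUPS.get(name, [name]))
    matches = MATCHERS[filter_groups]
    # Keep the order of `available`; groups that are not present are skipped.
    return [g for g in available if any(matches(p, g) for p in patterns)]


available = ["posterior", "posterior_predictive", "prior", "prior_predictive", "observed_data"]
print(select_groups("observed_vars", available))
# ['posterior_predictive', 'prior_predictive', 'observed_data']
print(select_groups("predictive", available, filter_groups="like"))
# ['posterior_predictive', 'prior_predictive']
print(select_groups("^prior", available, filter_groups="regex"))
# ['prior', 'prior_predictive']
```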

Returns
InferenceData

A new InferenceData object by default. When inplace=True, the selected groups are modified in place and None is returned.
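The difference between the two return modes can be illustrated with a small stand-in class. `GroupContainer` is a hypothetical name, and copying with copy.deepcopy is an assumption made for this sketch; the text above only states that a modified copy is returned by default and that None is returned when inplace=True.

```python
import copy
from typing import Any, Callable, Dict, List, Optional


class GroupContainer:
    """Hypothetical stand-in for InferenceData: named groups stored in a dict."""

    def __init__(self, groups: Dict[str, Any]):
        self.groups = groups

    def map(
        self, fun: Callable[[Any], Any], groups: List[str], inplace: bool = False
    ) -> Optional["GroupContainer"]:
        # Work on self when inplace=True, otherwise on an independent deep copy.
        target = self if inplace else copy.deepcopy(self)
        for name in groups:
            target.groups[name] = fun(target.groups[name])
        # Mirror the documented contract: a new object by default, None when in place.
        return None if inplace else target


c = GroupContainer({"observed_data": [1, 2]})
new = c.map(lambda xs: [x + 3 for x in xs], ["observed_data"])
print(c.groups["observed_data"], new.groups["observed_data"])  # [1, 2] [4, 5]
ret = c.map(lambda xs: [x + 3 for x in xs], ["observed_data"], inplace=True)
print(ret, c.groups["observed_data"])  # None [4, 5]
```

Returning None in the in-place case follows the common Python convention that mutating methods do not return the mutated object. This makes accidental chaining such as `idata.map(..., inplace=True).posterior` fail loudly.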

Examples

Shift observed_data, prior_predictive and posterior_predictive (the groups selected by the observed_vars metagroup) by 3.

In [1]: import arviz as az
   ...: idata = az.load_arviz_data("non_centered_eight")
   ...: idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
   ...: print(idata_shifted_obs.observed_data)
   ...: print(idata_shifted_obs.posterior_predictive)
   ...: 
<xarray.Dataset>
Dimensions:  (school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    obs      (school) float64 31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0
Attributes:
    created_at:                 2019-06-21T17:36:37.491073
    inference_library:          pymc3
    inference_library_version:  3.7
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    obs      (chain, draw, school) float64 12.42 1.107 11.0 ... 5.087 31.48
Attributes:
    created_at:                 2019-06-21T17:36:37.487547
    inference_library:          pymc3
    inference_library_version:  3.7

Rename the g_coef dimension to uranium_coefs and update its coordinate values in both the posterior and prior groups.

In [2]: idata = az.load_arviz_data("radon")
   ...: idata = idata.map(
   ...:     lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
   ...:         uranium_coefs=["intercept", "u_slope"]
   ...:     ),
   ...:     groups=["posterior", "prior"]
   ...: )
   ...: idata.posterior
   ...: 
Out[2]: 
<xarray.Dataset>
Dimensions:        (chain: 4, draw: 500, uranium_coefs: 2, County: 85)
Coordinates:
  * chain          (chain) int64 0 1 2 3
  * draw           (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * uranium_coefs  (uranium_coefs) <U9 'intercept' 'u_slope'
  * County         (County) object 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
Data variables:
    g              (chain, draw, uranium_coefs) float64 1.499 0.5436 ... 0.6087
    za_county      (chain, draw, County) float64 -0.1139 0.1347 ... 1.025 0.3687
    b              (chain, draw) float64 -0.5965 -0.5848 ... -0.7949 -0.5402
    sigma_a        (chain, draw) float64 0.2061 0.1526 0.2201 ... 0.1764 0.1121
    a              (chain, draw, County) float64 1.125 1.039 ... 1.452 1.723
    a_county       (chain, draw, County) float64 1.101 1.066 ... 1.567 1.764
    sigma          (chain, draw) float64 0.7271 0.7161 0.7189 ... 0.7267 0.7324
Attributes:
    created_at:                 2020-07-24T18:15:12.191355
    arviz_version:              0.9.0
    inference_library:          pymc3
    inference_library_version:  3.9.2
    sampling_time:              18.096983432769775
    tuning_steps:               1000

Add home_team and away_team coordinates along the match dimension to all groups containing observed variables.

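In [3]: import numpy as np
   ...: idata = az.load_arviz_data("rugby")
   ...: # Each match label is "<home> <away>"; split it and transpose into two arrays.
   ...: home_team, away_team = np.array([
   ...:     m.split() for m in idata.observed_data.match.values
   ...: ]).T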
   ...: idata = idata.map(
   ...:     lambda ds, **kwargs: ds.assign_coords(**kwargs),
   ...:     groups="observed_vars",
   ...:     home_team=("match", home_team),
   ...:     away_team=("match", away_team),
   ...: )
   ...: print(idata.posterior_predictive)
   ...: print(idata.observed_data)
   ...: 
<xarray.Dataset>
Dimensions:      (chain: 4, draw: 500, match: 60)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * match        (match) object 'Wales Italy' ... 'Ireland England'
    home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
    away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
Data variables:
    home_points  (chain, draw, match) int64 43 9 27 24 12 30 ... 17 23 49 17 27
    away_points  (chain, draw, match) int64 7 16 9 11 24 11 ... 23 24 9 18 28 12
Attributes:
    created_at:                 2019-07-12T20:31:53.563854
    inference_library:          pymc3
    inference_library_version:  3.7
<xarray.Dataset>
Dimensions:      (match: 60)
Coordinates:
  * match        (match) object 'Wales Italy' ... 'Ireland England'
    home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
    away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
Data variables:
    home_points  (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0
    away_points  (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0
Attributes:
    created_at:                 2019-07-12T20:31:53.581293
    inference_library:          pymc3
    inference_library_version:  3.7
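The split-and-transpose step in the rugby example (`np.array([m.split() ...]).T`) can be checked in isolation with plain Python, where zip(*pairs) plays the role of the transpose. The three match labels below are taken from the first entries of the match coordinate shown in the output above. The explicit length check is an addition for this sketch: it guards the assumption that every team name is a single word, which the whitespace split relies on.

```python
# First three match labels, in the "<home> <away>" form used by the rugby data.
matches = ["Wales Italy", "France England", "Ireland Scotland"]

pairs = [m.split() for m in matches]  # [['Wales', 'Italy'], ['France', 'England'], ...]
for label, pair in zip(matches, pairs):
    # A multi-word team name would break the simple whitespace split.
    if len(pair) != 2:
        raise ValueError(f"cannot split {label!r} into exactly two team names")

home_team, away_team = zip(*pairs)  # transpose: first column, second column
print(home_team)  # ('Wales', 'France', 'Ireland')
print(away_team)  # ('Italy', 'England', 'Scotland')
```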