arviz.InferenceData.map
- InferenceData.map(fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs)
Apply a function to multiple groups.
Applies fun groupwise to the selected InferenceData groups and overwrites each group with the result of the function.
- Parameters
- fun : callable
Function to be applied to each group. Assumes the function is called as fun(dataset, *args, **kwargs).
- groups : str or list of str, optional
Groups to which fun is applied. Can be either group names or metagroup names. If None (default), fun is applied to all groups.
- filter_groups : {None, "like", "regex"}, optional
If None (default), interpret groups as the real group names. If "like", interpret groups as substrings of the real group names. If "regex", interpret groups as regular expressions on the real group names. À la pandas.filter.
- inplace : bool, optional
If True, modify the InferenceData object in place; otherwise, return a modified copy.
- args : array_like, optional
Positional arguments passed to fun.
- **kwargs : mapping, optional
Keyword arguments passed to fun.
- Returns
- InferenceData
A new InferenceData object by default. When inplace=True, the function is applied in place and None is returned.
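How groups and filter_groups might combine to pick the target groups can be sketched with the standard library alone. This is a hypothetical helper, not ArviZ's implementation: `resolve_groups` and the `METAGROUPS` table are illustrative names, the only metagroup listed is "observed_vars" with the members named in the first example, metagroup names are assumed to be expanded before filtering, and the "like"/"regex" modes are assumed to behave like pandas.DataFrame.filter (substring test and `re.search`, respectively). Unknown names are silently skipped here, which is a simplification.

```python
import re

# Hypothetical metagroup table for illustration only. The members of
# "observed_vars" are the groups named in the first example below;
# ArviZ's real table may contain more metagroups.
METAGROUPS = {
    "observed_vars": ("posterior_predictive", "observed_data", "prior_predictive"),
}


def resolve_groups(available, groups=None, filter_groups=None):
    """Return the names in ``available`` selected by ``groups`` (hypothetical helper).

    filter_groups=None    -> exact names
    filter_groups="like"  -> substring match, as in pandas.DataFrame.filter(like=...)
    filter_groups="regex" -> re.search match, as in pandas.DataFrame.filter(regex=...)
    """
    if groups is None:
        return list(available)  # no selection given: every group
    if isinstance(groups, str):
        groups = [groups]
    if filter_groups not in (None, "like", "regex"):
        raise ValueError("filter_groups must be None, 'like' or 'regex'")

    # Expand metagroup names into their member groups before filtering.
    patterns = []
    for name in groups:
        patterns.extend(METAGROUPS.get(name, (name,)))

    def matches(pattern, group):
        if filter_groups is None:
            return pattern == group
        if filter_groups == "like":
            return pattern in group
        return re.search(pattern, group) is not None

    # Keep the order of ``available`` and list each group at most once.
    return [g for g in available if any(matches(p, g) for p in patterns)]


available = ["posterior", "posterior_predictive", "sample_stats",
             "prior", "prior_predictive", "observed_data"]

print(resolve_groups(available, "observed_vars"))
# ['posterior_predictive', 'prior_predictive', 'observed_data']
print(resolve_groups(available, "prior", filter_groups="like"))
# ['prior', 'prior_predictive']
print(resolve_groups(available, "^post", filter_groups="regex"))
# ['posterior', 'posterior_predictive']
```

Note that "prior" with "like" does not match "posterior", since "prior" is not a substring of it; a regex anchored with "^" is the simplest way to restrict a match to a prefix.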
Examples
Shift observed_data, prior_predictive and posterior_predictive.
In [1]: import arviz as az
   ...: idata = az.load_arviz_data("non_centered_eight")
   ...: idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
   ...: print(idata_shifted_obs.observed_data)
   ...: print(idata_shifted_obs.posterior_predictive)
   ...:
<xarray.Dataset>
Dimensions:  (school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    obs      (school) float64 31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0
Attributes:
    created_at:                 2019-06-21T17:36:37.491073
    inference_library:          pymc3
    inference_library_version:  3.7
<xarray.Dataset>
Dimensions:  (chain: 4, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    obs      (chain, draw, school) float64 12.42 1.107 11.0 ... 5.087 31.48
Attributes:
    created_at:                 2019-06-21T17:36:37.487547
    inference_library:          pymc3
    inference_library_version:  3.7
Rename the g_coef dimension to uranium_coefs and set its coordinate values in both the posterior and prior groups.
In [2]: idata = az.load_arviz_data("radon")
   ...: idata = idata.map(
   ...:     lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
   ...:         uranium_coefs=["intercept", "u_slope"]
   ...:     ),
   ...:     groups=["posterior", "prior"]
   ...: )
   ...: idata.posterior
   ...:
Out[2]:
<xarray.Dataset>
Dimensions:        (chain: 4, draw: 500, uranium_coefs: 2, County: 85)
Coordinates:
  * chain          (chain) int64 0 1 2 3
  * draw           (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * uranium_coefs  (uranium_coefs) <U9 'intercept' 'u_slope'
  * County         (County) object 'AITKIN' 'ANOKA' ... 'YELLOW MEDICINE'
Data variables:
    g              (chain, draw, uranium_coefs) float64 1.499 0.5436 ... 0.6087
    za_county      (chain, draw, County) float64 -0.1139 0.1347 ... 1.025 0.3687
    b              (chain, draw) float64 -0.5965 -0.5848 ... -0.7949 -0.5402
    sigma_a        (chain, draw) float64 0.2061 0.1526 0.2201 ... 0.1764 0.1121
    a              (chain, draw, County) float64 1.125 1.039 ... 1.452 1.723
    a_county       (chain, draw, County) float64 1.101 1.066 ... 1.567 1.764
    sigma          (chain, draw) float64 0.7271 0.7161 0.7189 ... 0.7267 0.7324
Attributes:
    created_at:                 2020-07-24T18:15:12.191355
    arviz_version:              0.9.0
    inference_library:          pymc3
    inference_library_version:  3.9.2
    sampling_time:              18.096983432769775
    tuning_steps:               1000
Add extra coordinates to all groups containing observed variables.
In [3]: import numpy as np
   ...: idata = az.load_arviz_data("rugby")
   ...: home_team, away_team = np.array([
   ...:     m.split() for m in idata.observed_data.match.values
   ...: ]).T
   ...: idata = idata.map(
   ...:     lambda ds, **kwargs: ds.assign_coords(**kwargs),
   ...:     groups="observed_vars",
   ...:     home_team=("match", home_team),
   ...:     away_team=("match", away_team),
   ...: )
   ...: print(idata.posterior_predictive)
   ...: print(idata.observed_data)
   ...:
<xarray.Dataset>
Dimensions:      (chain: 4, draw: 500, match: 60)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * match        (match) object 'Wales Italy' ... 'Ireland England'
    home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
    away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
Data variables:
    home_points  (chain, draw, match) int64 43 9 27 24 12 30 ... 17 23 49 17 27
    away_points  (chain, draw, match) int64 7 16 9 11 24 11 ... 23 24 9 18 28 12
Attributes:
    created_at:                 2019-07-12T20:31:53.563854
    inference_library:          pymc3
    inference_library_version:  3.7
<xarray.Dataset>
Dimensions:      (match: 60)
Coordinates:
  * match        (match) object 'Wales Italy' ... 'Ireland England'
    home_team    (match) <U8 'Wales' 'France' 'Ireland' ... 'France' 'Ireland'
    away_team    (match) <U8 'Italy' 'England' 'Scotland' ... 'Wales' 'England'
Data variables:
    home_points  (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0
    away_points  (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0
Attributes:
    created_at:                 2019-07-12T20:31:53.581293
    inference_library:          pymc3
    inference_library_version:  3.7
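None of the examples above passes positional arguments through args or uses inplace=True. The documented calling convention, fun(dataset, *args, **kwargs), and the copy-versus-in-place behaviour can be sketched with plain dictionaries. This is a hypothetical stand-in, not ArviZ code: `map_groups` and `shift` are illustrative names, each group is a dict of lists rather than an xarray.Dataset, and group selection is reduced to exact names.

```python
import copy


def map_groups(data, fun, groups=None, inplace=False, args=None, **kwargs):
    """Apply ``fun`` to each selected group of ``data`` (hypothetical sketch).

    ``data`` maps group names to plain dicts standing in for datasets.
    """
    if args is None:
        args = []
    if groups is None:
        groups = list(data)
    elif isinstance(groups, str):
        groups = [groups]

    # Work on a deep copy unless the caller asked for in-place modification.
    out = data if inplace else copy.deepcopy(data)
    for group in groups:
        # Same calling convention as documented: fun(dataset, *args, **kwargs).
        out[group] = fun(out[group], *args, **kwargs)
    return None if inplace else out


def shift(dataset, offset, scale=1.0):
    """Add ``offset`` to every value, then multiply by ``scale`` (hypothetical fun)."""
    return {name: [(v + offset) * scale for v in values]
            for name, values in dataset.items()}


# Toy values, not taken from any ArviZ dataset.
data = {
    "posterior": {"mu": [1.0, 2.0]},
    "observed_data": {"obs": [28.0, 8.0]},
}

# Default: a modified copy is returned and ``data`` is left untouched.
shifted = map_groups(data, shift, groups="observed_data", args=[3.0])
print(shifted["observed_data"])  # {'obs': [31.0, 11.0]}
print(data["observed_data"])     # {'obs': [28.0, 8.0]}

# In place: ``data`` itself is modified and None is returned.
result = map_groups(data, shift, groups=["observed_data"], inplace=True,
                    args=[3.0], scale=2.0)
print(result)                    # None
print(data["observed_data"])     # {'obs': [62.0, 22.0]}
```

Here args supplies offset positionally while scale arrives through **kwargs, mirroring how extra keyword arguments such as home_team and away_team reach the lambda in the rugby example.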