arviz.InferenceData.map
- InferenceData.map(fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs)
Apply a function to multiple groups.
Applies fun groupwise to the selected InferenceData groups and overwrites each group with the result of the function.
- Parameters
- fun : callable
Function to be applied to each group. Assumes the function is called as fun(dataset, *args, **kwargs).
- groups : str or list of str, optional
Groups where the function is to be applied. Can be either group names or metagroup names.
- filter_groups : {None, “like”, “regex”}, optional
If None (default), interpret groups as the real group or metagroup names. If “like”, interpret groups as substrings of the real group or metagroup names. If “regex”, interpret groups as regular expressions on the real group or metagroup names. À la pandas.filter.
- inplace : bool, optional
If True, modify the InferenceData object in place; otherwise, return the modified copy.
- args : array_like, optional
Positional arguments passed to fun.
- **kwargs : mapping, optional
Keyword arguments passed to fun.
- Returns
- InferenceData
A new InferenceData object by default. When inplace=True, the function is applied in place and None is returned.
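The calling convention can be illustrated with a minimal sketch, assuming a small synthetic InferenceData built with az.from_dict; the variable names x and y and the helper scale_and_shift are hypothetical. Each selected dataset is passed as the first argument, followed by the entries of args and then any extra keyword arguments, while filter_groups="like" turns the group name "posterior" into a substring pattern.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)

# Synthetic InferenceData; the variable names "x" and "y" are hypothetical.
idata = az.from_dict(
    posterior={"x": rng.normal(size=(2, 50))},
    posterior_predictive={"y": rng.normal(size=(2, 50, 3))},
    observed_data={"y": np.array([0.5, -1.0, 2.0])},
)


def scale_and_shift(ds, factor, shift=0.0):
    # Hypothetical helper, invoked by map as fun(dataset, *args, **kwargs).
    return ds * factor + shift


out = idata.map(
    scale_and_shift,
    groups="posterior",    # with filter_groups="like" this is a substring that
    filter_groups="like",  # matches both posterior and posterior_predictive
    args=[10],             # fills `factor` positionally
    shift=1,               # forwarded to scale_and_shift as a keyword argument
)

# "observed_data" does not contain "posterior", so this group is not modified.
print(out.observed_data["y"].values)
```

Because inplace defaults to False, out is a modified copy and idata keeps its original values.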
Examples
Shift the observed_data, prior_predictive and posterior_predictive groups (the observed_vars metagroup) by 3.
import arviz as az
import numpy as np
idata = az.load_arviz_data("non_centered_eight")
idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
idata_shifted_obs
Output (abridged): an arviz.InferenceData with the same groups as idata. In the observed_data group, obs now reads 31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0, the original observations shifted by 3, whereas the constant_data group still holds scores 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0. The posterior_predictive and prior_predictive draws of obs are shifted in the same way, and groups outside the metagroup, such as posterior and sample_stats, are unchanged.
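The same behaviour can be checked without relying on an example dataset. This minimal sketch assumes a tiny synthetic InferenceData (the variable names theta, obs and scores are hypothetical) and shows that only the groups covered by the observed_vars metagroup are shifted.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(1)

# Synthetic InferenceData; "theta", "obs" and "scores" are hypothetical names.
idata = az.from_dict(
    posterior={"theta": rng.normal(size=(2, 20))},
    posterior_predictive={"obs": rng.normal(size=(2, 20, 4))},
    prior_predictive={"obs": rng.normal(size=(1, 20, 4))},
    observed_data={"obs": np.array([1.0, 2.0, 3.0, 4.0])},
    constant_data={"scores": np.array([1.0, 2.0, 3.0, 4.0])},
)

shifted = idata.map(lambda x: x + 3, groups="observed_vars")

# Groups in the observed_vars metagroup are shifted ...
print(shifted.observed_data["obs"].values)     # [4. 5. 6. 7.]
# ... while groups outside it keep their original values.
print(shifted.constant_data["scores"].values)  # [1. 2. 3. 4.]
```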
Rename the g_coef dimension to uranium_coefs and update its coordinate values in both the posterior and prior groups.
idata = az.load_arviz_data("radon")
idata = idata.map(
    lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
        uranium_coefs=["intercept", "u_slope"]
    ),
    groups=["posterior", "prior"],
)
idata
Output (abridged): in both the posterior and the prior group, g is now indexed by a uranium_coefs dimension of length 2 with coordinate values 'intercept' and 'u_slope'. All other groups are returned unchanged.
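A self-contained variant of the radon example is sketched below, assuming synthetic draws for a coefficient g indexed by a generic g_coef dimension (the numbers are made up). It uses xarray's Dataset.assign_coords, swapped in as an alternative to the assign call above, to relabel the renamed dimension.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(2)

# Synthetic posterior and prior with a coefficient "g" indexed by "g_coef";
# the names mirror the radon example but the data is made up.
idata = az.from_dict(
    posterior={"g": rng.normal(size=(4, 100, 2))},
    prior={"g": rng.normal(size=(1, 100, 2))},
    dims={"g": ["g_coef"]},
    coords={"g_coef": [0, 1]},
)


def relabel(ds):
    # Rename the dimension, then replace its integer labels with strings.
    return ds.rename({"g_coef": "uranium_coefs"}).assign_coords(
        uranium_coefs=["intercept", "u_slope"]
    )


idata = idata.map(relabel, groups=["posterior", "prior"])
print(idata.posterior["g"].dims)                    # ('chain', 'draw', 'uranium_coefs')
print(idata.prior["uranium_coefs"].values.tolist())  # ['intercept', 'u_slope']
```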
Add extra coordinates to all groups containing observed variables.
idata = az.load_arviz_data("rugby")
home_team, away_team = np.array([
    m.split() for m in idata.observed_data.match.values
]).T
idata = idata.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    home_team=("match", home_team),
    away_team=("match", away_team),
)
idata
Output (abridged): the posterior_predictive and observed_data groups, which hold the observed variables home_points and away_points, gain two non-dimension coordinates along match: home_team ('Wales' 'France' 'Ireland' ...) and away_team ('Italy' 'England' 'Scotland' ...). The prior group of this dataset also contains home_points and away_points, but because it is not part of the observed_vars metagroup it does not receive the new coordinates.
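Instead of reassigning the result, the coordinates can also be added with inplace=True. The minimal sketch below assumes a hypothetical three-match fixture list in the "<home> <away>" label format of the rugby data; map then modifies idata directly and returns None.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical fixtures; each label is "<home> <away>" as in the rugby data.
matches = ["Wales Italy", "France England", "Ireland Scotland"]

idata = az.from_dict(
    posterior={"home": rng.normal(size=(2, 30))},
    posterior_predictive={"home_points": rng.poisson(20, size=(2, 30, 3))},
    observed_data={"home_points": np.array([23.0, 26.0, 28.0])},
    dims={"home_points": ["match"]},
    coords={"match": matches},
)

# Split every label into its home and away team.
home_team, away_team = np.array([m.split() for m in matches]).T

# With inplace=True the selected groups of idata are replaced directly.
result = idata.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    inplace=True,
    home_team=("match", home_team),
    away_team=("match", away_team),
)

print(result)  # None
print(idata.observed_data["home_team"].values.tolist())  # ['Wales', 'France', 'Ireland']
```

The posterior group is not in the observed_vars metagroup, so it does not receive the new coordinates.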