arviz.InferenceData.map
InferenceData.map(fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs)
Apply a function to multiple groups.

Applies fun groupwise to the selected InferenceData groups and overwrites each group with the result of the function.

Parameters
- fun : callable
  Function to be applied to each group. Assumes the function is called as fun(dataset, *args, **kwargs).
- groups : str or list of str, optional
  Groups to which fun is applied. Can be either group names or metagroup names. If None (default), fun is applied to all groups.
- filter_groups : {None, "like", "regex"}, optional
  If None (default), interpret groups as the real group names. If "like", interpret groups as substrings of the real group names. If "regex", interpret groups as regular expressions on the real group names. Analogous to pandas.DataFrame.filter.
- inplace : bool, optional
  If True, modify the InferenceData object in place; otherwise, return the modified copy.
- args : array_like, optional
  Positional arguments passed to fun.
- **kwargs : mapping, optional
  Keyword arguments passed to fun.
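A minimal sketch of how args, **kwargs and filter_groups work together. It uses a small synthetic InferenceData built with az.from_dict instead of a bundled dataset; the data and the scale helper are hypothetical.

```python
import numpy as np
import arviz as az

# Hypothetical synthetic data: 2 chains x 50 draws for the sampled groups.
rng = np.random.default_rng(0)
idata = az.from_dict(
    posterior={"mu": rng.normal(size=(2, 50))},
    prior={"mu": rng.normal(size=(2, 50))},
    prior_predictive={"y": rng.normal(size=(2, 50, 3))},
    observed_data={"y": np.array([1.0, 2.0, 3.0])},
)


def scale(ds, factor, offset=0.0):
    # map calls this as scale(dataset, *args, **kwargs).
    return ds * factor + offset


# args supplies `factor` positionally; `offset` is forwarded via **kwargs.
scaled = idata.map(scale, groups="posterior", args=[10.0], offset=1.0)

# With filter_groups="like", "prior" is matched as a substring, so both
# the "prior" and "prior_predictive" groups are selected.
shifted = idata.map(lambda ds: ds + 100.0, groups="prior", filter_groups="like")
```

Because inplace defaults to False, idata itself is left untouched by both calls.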
Returns
InferenceData
  A new InferenceData object by default. When inplace=True, the groups are modified in place and None is returned.
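The difference between the two return modes can be sketched on hypothetical synthetic data (built with az.from_dict). The default call returns a modified copy, while inplace=True overwrites the object and returns None:

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(1)
idata = az.from_dict(posterior={"mu": rng.normal(size=(2, 50))})
before = idata.posterior["mu"].values.copy()

# Default (inplace=False): a modified copy is returned, idata is untouched.
modified = idata.map(lambda ds: ds - 1.0, groups="posterior")

# inplace=True: idata itself is overwritten and the call returns None.
result = idata.map(lambda ds: ds - 1.0, groups="posterior", inplace=True)
```

Avoid writing idata = idata.map(..., inplace=True), since that would rebind idata to None.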
Examples
Shift observed_data, prior_predictive and posterior_predictive.
    import arviz as az
    import numpy as np
    idata = az.load_arviz_data("non_centered_eight")
    idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
    idata_shifted_obs
Output (rendered InferenceData repr, abridged): the repr lists eight groups. Only the groups holding observed variables (observed_data, prior_predictive and posterior_predictive) are shifted; for example, the observed obs values now read 31.0 11.0 0.0 10.0 2.0 4.0 21.0 15.0.
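To see which groups the observed_vars metagroup actually touches, each group can be compared before and after the shift. This is a sketch on hypothetical synthetic data built with az.from_dict, so it runs without the bundled dataset:

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(2)
idata = az.from_dict(
    posterior={"theta": rng.normal(size=(2, 50))},
    posterior_predictive={"obs": rng.normal(size=(2, 50, 4))},
    prior_predictive={"obs": rng.normal(size=(2, 50, 4))},
    observed_data={"obs": np.array([1.0, 2.0, 3.0, 4.0])},
)

shifted = idata.map(lambda ds: ds + 3, groups="observed_vars")

# Collect the groups whose contents differ after the map call.
changed = [
    group
    for group in idata.groups()
    if not getattr(shifted, group).equals(getattr(idata, group))
]
print(sorted(changed))  # ['observed_data', 'posterior_predictive', 'prior_predictive']
```

The posterior group is not part of observed_vars, so it comes through unchanged.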
Rename and update the coordinate values in both posterior and prior groups.
    import arviz as az
    idata = az.load_arviz_data("radon")
    idata = idata.map(
        lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
            uranium_coefs=["intercept", "u_slope"]
        ),
        groups=["posterior", "prior"]
    )
    idata
Output (rendered InferenceData repr, abridged): the repr lists eight groups. In the posterior and prior groups, the g_coef dimension is now named uranium_coefs with coordinate values 'intercept' and 'u_slope', so g has dimensions (chain, draw, uranium_coefs). The other groups are unchanged.
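The radon dataset is fetched remotely on first use. The same rename-and-relabel pattern can be tried offline on a hypothetical synthetic object where g lives on a g_coef dimension. This sketch uses assign_coords, which makes the coordinate intent explicit, instead of assign:

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(3)
# Hypothetical two-coefficient variable "g" on a dimension named "g_coef";
# from_dict gives g_coef default integer coordinates 0 and 1.
idata = az.from_dict(
    posterior={"g": rng.normal(size=(2, 50, 2))},
    prior={"g": rng.normal(size=(1, 50, 2))},
    dims={"g": ["g_coef"]},
)


def relabel(ds):
    # Rename the dimension, then replace its integer labels with strings.
    return ds.rename({"g_coef": "uranium_coefs"}).assign_coords(
        uranium_coefs=["intercept", "u_slope"]
    )


idata = idata.map(relabel, groups=["posterior", "prior"])
```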
Add extra coordinates to all groups containing observed variables.
    import arviz as az
    import numpy as np
    idata = az.load_arviz_data("rugby")
    home_team, away_team = np.array([
        m.split() for m in idata.observed_data.match.values
    ]).T
    idata = idata.map(
        lambda ds, **kwargs: ds.assign_coords(**kwargs),
        groups="observed_vars",
        home_team=("match", home_team),
        away_team=("match", away_team),
    )
    idata
Output (rendered InferenceData repr, abridged): the repr lists five groups. The groups holding observed variables now carry home_team and away_team as non-dimension coordinates along match (home_team: 'Wales' 'France' 'Ireland' ...; away_team: 'Italy' 'England' 'Scotland' ...). The remaining groups, including the one that also stores home_points and away_points under a sampled chain dimension, keep only the match coordinate.
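The rugby dataset is also fetched remotely. As an offline sketch of the same idea, the following uses three hypothetical match labels and synthetic draws. Extra keyword arguments to map are forwarded to the lambda, which passes them on to xarray's assign_coords:

```python
import numpy as np
import arviz as az

# Hypothetical match labels of the form "<home> <away>".
matches = ["Wales Italy", "France England", "Ireland Scotland"]
home_team, away_team = np.array([m.split() for m in matches]).T

rng = np.random.default_rng(4)
idata = az.from_dict(
    posterior={"atts": rng.normal(size=(2, 50))},
    posterior_predictive={"home_points": rng.poisson(20, size=(2, 50, 3))},
    observed_data={"home_points": np.array([23, 17, 30])},
    coords={"match": matches},
    dims={"home_points": ["match"]},
)

# home_team and away_team are not map parameters, so they land in **kwargs
# and every observed-variable group receives the same extra coordinates.
idata = idata.map(
    lambda ds, **kwargs: ds.assign_coords(**kwargs),
    groups="observed_vars",
    home_team=("match", home_team),
    away_team=("match", away_team),
)
```

No prior_predictive group exists here; groups listed in a metagroup but absent from the object are skipped.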