arviz.extract
- arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)
Extract an InferenceData group or subset of it.
- Parameters
  - idata : InferenceData or InferenceData_like
    InferenceData from which to extract the data.
  - group : str, optional
    Which InferenceData data group to extract data from.
  - combined : bool, optional
    Combine chain and draw dimensions into sample. Won’t work if a dimension named sample already exists.
  - var_names : str or list of str, optional
    Variables to be extracted. Prefix the variables by ~ when you want to exclude them.
  - filter_vars : {None, "like", "regex"}, optional
    If None (default), interpret var_names as the real variable names. If "like", interpret var_names as substrings of the real variable names. If "regex", interpret var_names as regular expressions on the real variable names. A la pandas.filter. Like with plotting, sometimes it’s easier to subset by saying what to exclude instead of what to include.
  - num_samples : int, optional
    Extract only a subset of the samples. Only valid if combined=True.
  - keep_dataset : bool, optional
    If true, always return a Dataset. If false (default), return a DataArray when there is a single variable.
  - rng : bool, int, numpy.Generator, optional
    Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None, and are left in the same order otherwise. This ensures that subsetting the samples doesn’t return only samples from a single chain and consecutive draws.
- Returns
  xarray.DataArray or xarray.Dataset
Examples
The default behaviour is to return the posterior group after stacking the chain and draw dimensions.
import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
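As a rough illustration of what the stacking does, the sketch below compares the result of az.extract with a manual xarray stack. It uses a small synthetic InferenceData rather than the dataset above: the variable names, shapes, school labels, and the keyword-style az.from_dict call (as found in pre-1.0 ArviZ releases) are all assumptions made for the example.

```python
import numpy as np
import arviz as az

# Synthetic draws, assumed for illustration: 4 chains, 500 draws, 8 schools.
rng = np.random.default_rng(0)
idata = az.from_dict(
    posterior={
        "mu": rng.normal(size=(4, 500)),
        "theta": rng.normal(size=(4, 500, 8)),
    },
    dims={"theta": ["school"]},
    coords={"school": list("ABCDEFGH")},  # hypothetical school labels
)

# Default call: posterior group, chain and draw stacked into sample.
post = az.extract(idata)

# Manual counterpart using plain xarray on the posterior group.
manual = idata.posterior.stack(sample=("chain", "draw"))

# Number of stacked samples (chains * draws).
print(post.sizes["sample"])
# Whether the two approaches give the same values for mu.
print(np.array_equal(post["mu"].values, manual["mu"].values))
```

Because no num_samples or rng is passed here, the samples should stay in their original chain-then-draw order, which is why a direct comparison with the manual stack is meaningful.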
You can also indicate a subset to be returned, both in variables and in samples:
az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[ 1.16356393e+01,  9.80188050e+00,  6.93102223e+00,  4.40479805e+00,
         3.37453765e+00,  1.66027661e+00,  6.88803121e+00,  7.68194332e+00,
         3.42527056e+00,  3.24161727e+00,  9.76390798e+00,  2.66587615e+00,
         1.25933432e+01,  1.20017602e+01, -2.56500961e+00, -1.19450497e+00,
         3.16324840e-01, -4.35165446e-01, -1.45575080e+00,  7.57620852e+00,
         7.81921367e-01,  9.80188050e+00,  2.61377333e-01, -5.10846749e+00,
         6.25495846e+00,  9.58014932e+00, -1.32636459e+00,  8.27358881e-01,
         1.92743101e+01,  9.85349670e-01,  2.67717800e+00,  9.80188050e+00,
        -8.82173225e-01,  4.27281294e+00,  9.71751869e+00,  5.21907152e+00,
         3.23315640e+00,  7.25006163e+00,  1.35503511e+01,  1.12856232e+01,
         3.88495610e+00, -6.98493081e-01,  1.32968861e+01,  9.14079559e+00,
         6.90074411e+00, -5.37572195e-01,  3.25808233e+00, -6.24775331e-01,
         9.76585447e+00,  9.88063755e+00,  1.04382282e+01,  1.47359226e+01,
         7.34034281e+00,  1.59326273e+00,  2.37221256e+00,  5.29726546e-01,
         1.31022009e+01, -1.33919123e-02,  5.36521047e+00, -3.75264208e+00,
...
         2.76188750e+00,  1.13577324e+01, -7.89320888e-01,  5.26780466e+00,
         3.35132225e+00, -3.55332084e+00,  5.44159038e+00,  5.04095986e+00,
         8.86523760e+00,  1.25813350e+01,  4.83350971e+00,  2.22113452e+00,
         4.58559411e+00,  1.94926476e+00,  1.58236330e+01,  2.86838389e-01,
         6.99697482e+00,  2.27260207e+00, -2.26780401e+00,  5.83475486e-01,
         6.55331619e+00, -3.15352465e+00,  5.49126949e+00,  4.12192432e+00,
         3.35132225e+00,  3.90203385e+00,  4.11797530e+00,  9.31886541e+00,
         9.03110639e+00,  7.63726943e+00,  1.59962412e+00,  3.58661575e+00,
         9.12635811e+00,  3.47889525e+00,  3.94770392e+00,  4.47140798e-03,
         5.24838051e+00,  7.03611228e+00, -2.23165528e+00,  1.84003191e+00,
         5.69639395e+00,  5.33233277e+00,  9.98761823e+00,  6.99697482e+00,
         3.35132225e+00,  4.05865793e+00,  1.05857124e+01, -9.67889623e+00,
         1.30644225e+00,  1.97055711e+00, -4.34713310e+00,  2.09180643e+00,
         4.40295366e+00,  9.07594450e+00,  6.25408810e-01,  2.73537789e+00,
         7.98394952e+00,  7.45065714e+00]])
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 3 0 1 2 2 0 0 1 2 2 3 ... 1 1 0 2 2 0 0 1 0 3 2 2
  * draw     (sample) int64 188 28 456 337 350 328 ... 323 277 196 275 474 434
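The other subsetting options described under Parameters can be combined in a similar way. The following is a minimal sketch on synthetic data (the variables mu, tau and theta, their shapes, the school labels, and the keyword-style az.from_dict call from pre-1.0 ArviZ are assumptions for the example). It shows exclusion with the ~ prefix, substring matching with filter_vars="like", a repeatable random subset via an integer rng, and keep_dataset.

```python
import numpy as np
import arviz as az

# Synthetic posterior, assumed for illustration only.
gen = np.random.default_rng(1)
idata = az.from_dict(
    posterior={
        "mu": gen.normal(size=(4, 500)),
        "tau": gen.gamma(2.0, size=(4, 500)),
        "theta": gen.normal(size=(4, 500, 8)),
    },
    dims={"theta": ["school"]},
    coords={"school": list("ABCDEFGH")},  # hypothetical school labels
)

# Exclude a variable by prefixing its name with ~.
no_theta = az.extract(idata, var_names="~theta")

# Substring matching: "t" is contained in both "tau" and "theta".
t_vars = az.extract(idata, var_names="t", filter_vars="like")

# Passing the same integer as rng seeds the shuffle, so the two
# random subsets of 100 samples should be identical.
a = az.extract(idata, var_names="theta", num_samples=100, rng=0)
b = az.extract(idata, var_names="theta", num_samples=100, rng=0)

# keep_dataset=True returns a Dataset even for a single variable.
ds = az.extract(idata, var_names="mu", keep_dataset=True)

print(sorted(no_theta.data_vars), sorted(t_vars.data_vars))
print(type(a).__name__, type(ds).__name__)
```

Fixing rng to an integer is mainly useful when a downstream plot or test needs the same subset on every run; leaving it unset gives a fresh shuffle each time num_samples is used.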
To keep the chain and draw dimensions, use combined=False.
az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2
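Keeping chain and draw separate is convenient when a quantity needs to be computed per chain. The sketch below is one possible use, again on synthetic data (the single variable mu, the shapes, and the keyword-style az.from_dict call from pre-1.0 ArviZ are assumptions for the example): it extracts a variable unstacked from both the posterior and prior groups and averages over draws within each chain.

```python
import numpy as np
import arviz as az

# Synthetic groups, assumed for illustration: 4 posterior chains, 1 prior chain.
gen = np.random.default_rng(2)
idata = az.from_dict(
    posterior={"mu": gen.normal(size=(4, 500))},
    prior={"mu": gen.normal(size=(1, 500))},
)

# combined=False leaves chain and draw as separate dimensions.
unstacked = az.extract(idata, var_names="mu", combined=False)

# One mean per chain: reduce over draw only.
per_chain_mean = unstacked.mean("draw")

# The same call works on other groups, here the prior.
prior_mu = az.extract(idata, group="prior", var_names="mu", combined=False)

print(unstacked.dims)
print(per_chain_mean.sizes["chain"], prior_mu.sizes["draw"])
```

Note that num_samples and rng are documented above as valid only with combined=True, so they are left out here.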