arviz.extract
- arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)
Extract an InferenceData group or subset of it.
- Parameters
- idata : InferenceData or InferenceData_like
InferenceData from which to extract the data.
- group : str, optional
Which InferenceData data group to extract data from.
- combined : bool, optional
Combine chain and draw dimensions into sample. Won’t work if a dimension named sample already exists.
- var_names : str or list of str, optional
Variables to be extracted. Prefix the variables by ~ when you want to exclude them.
- filter_vars : {None, "like", "regex"}, optional
If None (default), interpret var_names as the real variable names. If "like", interpret var_names as substrings of the real variable names. If "regex", interpret var_names as regular expressions on the real variable names. A la pandas.filter. As with plotting, it is sometimes easier to subset by saying what to exclude instead of what to include.
- num_samples : int, optional
Extract only a subset of the samples. Only valid if combined=True.
- keep_dataset : bool, optional
If True, always return a Dataset. If False (default), return a DataArray when there is a single variable.
- rng : bool, int, numpy.random.Generator, optional
Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None, and are left in the same order otherwise. This ensures that subsetting the samples doesn’t return only samples from a single chain and consecutive draws.
- Returns
- xarray.DataArray or xarray.Dataset
Examples
The default behaviour is to return the posterior group after stacking the chain and draw dimensions.
import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
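The relation between the stacked sample dimension and the original chain and draw dimensions can be checked directly. The following is a minimal sketch, assuming an ArviZ version whose extract matches the signature documented on this page and that the centered_eight example dataset is available locally; the variable names (stacked, unstacked, n_sample and so on) are illustrative only.

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

# Default: chain and draw are stacked into a single "sample" dimension.
stacked = az.extract(idata)
# combined=False keeps chain and draw as separate dimensions.
unstacked = az.extract(idata, combined=False)

n_chain = unstacked.sizes["chain"]
n_draw = unstacked.sizes["draw"]
n_sample = stacked.sizes["sample"]
print(n_chain, n_draw, n_sample)

# Each stacked sample keeps its originating chain and draw as coordinates.
print(stacked["chain"].values[:5])
print(stacked["draw"].values[:5])
```

Because no num_samples is requested here, the samples are not shuffled, so the stacked order follows the chains one after another.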
You can also indicate a subset to be returned, both in variables and in samples:
az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)> array([[ 3.73269095e+00, 1.13955608e+01, 1.61296613e+01, -1.12478279e+00, 9.31726415e+00, 9.78656961e+00, 5.58777112e+00, 9.67704275e+00, 1.67363399e+00, 7.57793225e-01, -3.22971769e+00, 2.74842711e+00, 7.34920348e+00, 2.47851682e-01, 2.86064613e+00, 9.79959946e+00, -2.90704859e+00, 6.31107231e+00, 3.11147235e+00, 1.67019493e+00, 8.82854249e+00, 8.93305013e+00, 3.30285590e+00, 6.76652313e+00, 3.49517632e+00, 3.53056375e+00, 9.51332511e+00, 8.13029737e+00, 1.08238009e+01, -3.58852681e+00, 4.71860289e+00, 1.40261609e+01, 5.19324091e+00, 9.97022761e+00, 5.17403134e+00, -6.42927544e-02, 2.99815135e+00, 3.60286283e+00, 9.44361810e+00, 3.09387429e+01, 6.33178880e+00, 7.19731695e+00, 6.32443017e+00, 4.87102786e+00, 1.67363399e+00, 5.26285351e+00, 1.24594188e+01, 8.36807219e+00, 7.56338176e+00, 8.17614149e+00, -7.62142522e+00, 8.27358881e-01, 6.19187589e-01, -5.50959584e+00, 1.35882507e+01, 3.48360917e+00, 3.96348025e+00, 8.74916305e+00, 1.47320820e+00, 6.07094583e+00, ... 
5.95381092e+00, 3.58898903e+00, 3.23104629e+00, 1.19331417e+01, 2.19066997e+00, 1.00048799e+01, -6.78154574e+00, -3.04307193e-01, 1.72850449e-01, 2.09180643e+00, 1.87963638e+00, -7.88124799e-01, 2.14561164e+00, 3.56817432e+00, 9.90728873e+00, 8.67071484e+00, 5.90635119e+00, 4.55080517e+00, 1.54158971e+01, -1.96408397e+00, 1.25679212e+00, -2.38039401e+00, 2.09180643e+00, 1.28825190e+00, 9.65910877e+00, 2.44083197e+00, 5.61311698e+00, 5.13629724e+00, 1.58508949e+01, 1.08249054e+00, 1.14247207e+01, 3.23104629e+00, 1.38433952e+01, 2.67585013e+00, 3.09944712e+00, -2.12121989e+00, 1.00207705e+01, 3.35132225e+00, -2.14982828e+00, -3.98684048e+00, 5.51627670e+00, 9.18527522e+00, 3.73976292e+00, 3.35132225e+00, 1.82138159e+01, 1.13402789e+01, -1.45535079e+01, 5.87431258e+00, 4.71064203e+00, -3.15352465e+00, 6.57878387e+00, -6.53343265e-01, -3.24711098e+00, 7.18026928e+00, 5.23739698e+00, -2.05161879e+00, -4.43059825e+00, 7.20509135e+00]]) Coordinates: * school (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon' * sample (sample) object MultiIndex * chain (sample) int64 0 1 0 2 2 1 0 0 1 0 0 3 ... 2 2 2 2 0 1 3 1 1 2 1 0 * draw (sample) int64 321 127 187 483 97 359 ... 356 300 250 136 355 426
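Since num_samples shuffles the samples by default, two calls generally return different subsets. As a hedged sketch (again assuming the extract signature documented above and the bundled centered_eight dataset), passing an integer seed through rng should make the subset reproducible, and keep_dataset=True should return a Dataset even for a single variable. The names a, b, same_subset and ds are illustrative.

```python
import arviz as az
import numpy as np

idata = az.load_arviz_data("centered_eight")

# The same integer seed passed as rng selects the same random subset twice.
a = az.extract(idata, var_names="theta", num_samples=100, rng=0)
b = az.extract(idata, var_names="theta", num_samples=100, rng=0)
same_subset = np.array_equal(a.values, b.values)
print(same_subset, a.shape)

# keep_dataset=True wraps the single variable in a Dataset instead of a DataArray.
ds = az.extract(idata, var_names="theta", num_samples=100, rng=0, keep_dataset=True)
print(type(a).__name__, type(ds).__name__)
```

Fixing the seed is mainly useful for tests and for figures that need to be regenerated identically.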
To keep the chain and draw dimensions, use combined=False.
az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    created_at:                 2022-10-13T14:37:26.602116
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
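The var_names and filter_vars parameters described above can also be combined to select variables by exclusion or by pattern. The sketch below assumes the posterior group of centered_eight contains exactly mu, theta and tau, as in the first example output; without_theta and t_vars are illustrative names.

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

# Prefixing a name with "~" excludes it; the remaining variables are returned.
without_theta = az.extract(idata, var_names="~theta")
print(sorted(without_theta.data_vars))

# filter_vars="like" treats var_names as substrings of the real variable names,
# so "t" matches both "theta" and "tau" but not "mu".
t_vars = az.extract(idata, var_names="t", filter_vars="like")
print(sorted(t_vars.data_vars))
```

Both calls match more than one variable, so they return a Dataset regardless of keep_dataset.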