arviz.extract
- arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)
Extract an InferenceData group or subset of it.
- Parameters
- data : InferenceData or InferenceData_like
InferenceData from which to extract the data.
- group : str, optional
Which InferenceData group to extract data from.
- combined : bool, optional
Combine the chain and draw dimensions into sample. Won’t work if a dimension named sample already exists.
- var_names : str or list of str, optional
Variables to be extracted. Prefix a variable with ~ to exclude it.
- filter_vars : {None, “like”, “regex”}, optional
If None (default), interpret var_names as the real variable names. If “like”, interpret var_names as substrings of the real variable names. If “regex”, interpret var_names as regular expressions on the real variable names, à la pandas.filter. As with plotting, it is sometimes easier to subset by saying what to exclude instead of what to include.
- num_samples : int, optional
Extract only a subset of the samples. Only valid if combined=True.
- keep_dataset : bool, optional
If true, always return a Dataset. If false (default), return a DataArray when there is a single variable.
- rng : bool, int, numpy.random.Generator, optional
Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None, and are left in the same order otherwise. This ensures that subsetting the samples doesn’t return only samples from a single chain and consecutive draws.
- Returns
- xarray.DataArray or xarray.Dataset
Examples
The default behaviour is to return the posterior group after stacking the chain and draw dimensions.
import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
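Because chain and draw are stacked into a single sample dimension, summaries over all posterior draws need only one reduction. The following is a minimal sketch of that idea, not part of the original examples: it uses synthetic data (made-up values with hypothetical variable names, not centered_eight) and assumes the ArviZ 0.x signature az.from_dict(posterior=...).

```python
import arviz as az
import numpy as np

# Synthetic posterior: 4 chains, 50 draws (illustrative values only).
# Assumes the ArviZ 0.x `from_dict(posterior=...)` signature.
gen = np.random.default_rng(0)
idata = az.from_dict(
    posterior={
        "mu": gen.normal(size=(4, 50)),
        "theta": gen.normal(size=(4, 50, 8)),
    }
)

# Default call: posterior group with chain and draw stacked into "sample".
stacked = az.extract(idata)
n_samples = stacked.sizes["sample"]

# A single reduction over "sample" covers every chain and draw.
theta_mean = stacked["theta"].mean("sample")

# The same summary computed on the unstacked group, for comparison.
reference = idata.posterior["theta"].mean(("chain", "draw"))
```

Here theta_mean and reference should agree up to floating-point error, since stacking only rearranges the draws.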
You can also request a subset, both in variables and in samples:
az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[ 6.79565454e+00,  1.11413987e+01, -2.06011629e+00,  6.28910373e+00,
         1.04489792e+01,  3.72488467e+00,  4.64574154e+00, -4.16905800e+00,
         7.26191618e+00, -2.40749939e+00,  1.09034661e+01,  5.15263045e+00,
        -1.32636459e+00,  1.04382282e+01,  8.06311068e+00,  6.27996076e+00,
        -3.58852681e+00,  4.45090398e+00,  8.80737716e+00,  6.96225585e+00,
         1.70405127e+01,  9.05461834e+00,  8.85972309e+00,  7.92304552e+00,
         7.51543835e+00,  7.74520291e+00,  9.58212542e+00,  6.89792949e+00,
         5.98588351e+00,  1.42499661e+01,  5.41534715e+00,  4.93254022e+00,
         1.38641174e+01,  1.04368426e+01,  2.86064613e+00,  6.80175067e+00,
         1.24760716e+01,  3.50860324e+00,  4.50692938e+00,  8.80547069e+00,
         7.34507204e+00,  1.51422919e+01,  6.36195093e+00,  3.44785980e-02,
         6.78132405e+00,  4.08789055e+00,  1.61384468e+01, -4.51536648e+00,
         7.23246870e+00,  5.35987796e+00,  9.23455863e+00,  7.57099034e+00,
         7.59179819e-01,  1.16627749e+01,  1.59326273e+00,  1.14832706e+01,
        -6.41664260e+00,  2.47771747e+01,  4.64202496e+00, -8.82173225e-01,
...
         2.85143502e+00, -5.53955074e+00,  8.28114677e+00,  2.87701498e+00,
         1.62643562e+00,  1.45713136e-02,  2.17014917e+00,  4.12215798e+00,
         8.14707248e+00,  3.11912417e+00,  1.21802431e+00,  1.18821725e+01,
         2.22113452e+00,  4.71104108e+00,  3.22485573e+00,  1.82598822e+01,
         5.08451764e+00, -8.14360121e+00,  7.39941330e+00,  4.18686491e+00,
         9.32159816e+00,  3.78957862e-01,  2.22113452e+00,  9.21082747e+00,
         2.99111323e+00,  1.39564414e+01,  6.40498626e+00,  1.22372245e+01,
         5.04912122e+00,  1.40122789e+01,  2.49875095e+00,  1.06071332e+00,
         2.65083405e+00,  1.94926476e+00,  8.34019084e+00,  1.88956072e+00,
        -3.66902984e+00,  4.70781258e+00,  1.13402789e+01,  8.86580029e+00,
         3.85491975e+00,  2.28266671e+00,  9.76957618e+00, -4.20468849e-01,
         1.49927153e+00,  1.05973281e+01,  1.08956780e+00,  9.92018135e+00,
         9.24872843e+00,  6.76390558e+00,  7.16213323e+00,  4.11506327e+00,
         1.11477976e+01,  1.67634160e+00,  3.35132225e+00,  3.32392687e+00,
         6.77777189e+00,  4.53714710e+00]])
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 2 3 2 0 1 0 0 1 2 3 1 2 ... 3 2 3 1 1 2 3 0 1 2 1 0
  * draw     (sample) int64 280 336 487 337 290 141 ... 416 278 457 485 499 341
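Since num_samples shuffles the samples by default, repeated calls generally return different subsets. Passing an integer seed as rng should make the subset reproducible. The sketch below illustrates this on synthetic data (made-up values; the ArviZ 0.x az.from_dict(posterior=...) signature is assumed) and also shows the combined=True requirement stated in the parameter description.

```python
import arviz as az
import numpy as np

# Synthetic posterior: 4 chains, 50 draws (illustrative values only).
# Assumes the ArviZ 0.x `from_dict(posterior=...)` signature.
gen = np.random.default_rng(0)
idata = az.from_dict(posterior={"theta": gen.normal(size=(4, 50, 8))})

# The same integer seed yields the same shuffled subset of 20 samples.
sub_a = az.extract(idata, var_names="theta", num_samples=20, rng=0)
sub_b = az.extract(idata, var_names="theta", num_samples=20, rng=0)
same_subset = np.array_equal(sub_a.values, sub_b.values)

# num_samples is only valid together with combined=True.
try:
    az.extract(idata, num_samples=20, combined=False)
    raised = False
except ValueError:
    raised = True
```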
To keep the chain and draw dimensions, use combined=False.
az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 1.941 3.388 4.208 ... 0.8353 0.06893 2.145
    theta    (chain, draw, school) float64 4.866 4.59 -0.7404 ... -2.031 6.045
    mu       (chain, draw) float64 3.903 3.915 -1.751 ... -2.294 0.7908 2.869
Attributes:
    created_at:                 2022-10-13T14:37:26.602116
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
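The var_names, filter_vars and keep_dataset parameters are not exercised in the examples above. The following sketch shows how they are expected to interact, again on synthetic data with made-up values and assuming the ArviZ 0.x az.from_dict(posterior=...) signature; the variable names mu, tau and theta simply mirror the ones in the outputs above.

```python
import arviz as az
import numpy as np
import xarray as xr

# Synthetic posterior: 4 chains, 50 draws (illustrative values only).
# Assumes the ArviZ 0.x `from_dict(posterior=...)` signature.
gen = np.random.default_rng(0)
idata = az.from_dict(
    posterior={
        "mu": gen.normal(size=(4, 50)),
        "tau": gen.normal(size=(4, 50)),
        "theta": gen.normal(size=(4, 50, 8)),
    }
)

# Exclude a variable by prefixing its name with "~".
no_theta = az.extract(idata, var_names="~theta")

# Substring matching: "t" matches both "theta" and "tau", but not "mu".
like_t = az.extract(idata, var_names="t", filter_vars="like")

# A single variable comes back as a DataArray unless keep_dataset=True.
single = az.extract(idata, var_names="mu")
single_ds = az.extract(idata, var_names="mu", keep_dataset=True)

# combined=False leaves the chain and draw dimensions unstacked.
unstacked = az.extract(idata, combined=False)
```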