arviz.extract
- arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)
Extract an InferenceData group or subset of it.
- Parameters:
  - data : InferenceData or InferenceData_like
    InferenceData from which to extract the data.
  - group : str, optional
    Which InferenceData group to extract data from. Defaults to "posterior".
  - combined : bool, optional
    Combine the chain and draw dimensions into a single sample dimension. Won't work if a dimension named sample already exists. Defaults to True.
  - var_names : str or list of str, optional
    Variables to be extracted. Prefix a variable name with ~ to exclude it instead.
  - filter_vars : {None, "like", "regex"}, optional
    If None (default), interpret var_names as the real variable names. If "like", interpret var_names as substrings of the real variable names. If "regex", interpret var_names as regular expressions on the real variable names. This works à la pandas.filter. As with plotting, it is sometimes easier to subset by saying what to exclude instead of what to include.
  - num_samples : int, optional
    Extract only a subset of the samples. Only valid if combined=True.
  - keep_dataset : bool, optional
    If True, always return a Dataset. If False (default), return a DataArray when there is a single variable.
  - rng : bool, int, numpy.Generator, optional
    Shuffle the samples; only valid if combined=True. By default, samples are shuffled if num_samples is not None and are left in their original order otherwise. Shuffling ensures that subsetting the samples doesn't return only consecutive draws from a single chain.
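- Returns:
  - xarray.DataArray or xarray.Dataset
    The extracted data. A DataArray is returned only when a single variable is selected and keep_dataset=False; otherwise the result is a Dataset.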
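The selection rules for var_names and filter_vars can be illustrated with a small pure-Python sketch. The helper select_var_names below is hypothetical and is not part of ArviZ: it assumes the list is either entirely inclusions or entirely ~-prefixed exclusions, and it ignores details such as warnings for names that are not found. The names mu, theta and tau are the posterior variables of the centered_eight dataset used in the examples.

```python
import re


def select_var_names(all_names, var_names=None, filter_vars=None):
    """Hypothetical helper sketching the var_names / filter_vars rules.

    Not ArviZ's implementation: it only handles lists that are entirely
    inclusions or entirely "~"-prefixed exclusions.
    """
    if filter_vars not in (None, "like", "regex"):
        raise ValueError(f"filter_vars must be None, 'like' or 'regex', got {filter_vars!r}")
    if var_names is None:
        return list(all_names)  # no selection: keep every variable
    if isinstance(var_names, str):
        var_names = [var_names]

    def matches(pattern, name):
        if filter_vars is None:
            return pattern == name  # exact variable name
        if filter_vars == "like":
            return pattern in name  # substring of the variable name
        return re.search(pattern, name) is not None  # regular expression

    excludes = [v[1:] for v in var_names if v.startswith("~")]
    if excludes and len(excludes) != len(var_names):
        raise ValueError("this sketch does not handle mixed include/exclude lists")
    if excludes:
        # Keep every variable that matches none of the exclusion patterns.
        return [n for n in all_names if not any(matches(p, n) for p in excludes)]
    # Keep every variable that matches at least one inclusion pattern.
    return [n for n in all_names if any(matches(p, n) for p in var_names)]


names = ["mu", "theta", "tau"]  # posterior variables of centered_eight
print(select_var_names(names, "~theta"))                        # ['mu', 'tau']
print(select_var_names(names, "t", filter_vars="like"))         # ['theta', 'tau']
print(select_var_names(names, "^t.*u$", filter_vars="regex"))   # ['tau']
```

The exclusion branch starts from all variables and removes matches, which is why saying what to exclude is often shorter than listing what to include.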
Examples
The default behaviour is to return the posterior group after stacking the chain and draw dimensions.
import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
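With combined=True, the 4 chains of 500 draws in this example become a single sample dimension of length 2000, and the MultiIndex records which (chain, draw) pair each sample came from. The NumPy sketch below reproduces that index arithmetic on synthetic data; ArviZ itself works on labeled xarray objects (this corresponds to xarray's stack operation), so treat it as an illustration of the layout rather than ArviZ's implementation.

```python
import numpy as np

n_chains, n_draws, n_schools = 4, 500, 8
rng = np.random.default_rng(0)
# Synthetic stand-in for posterior["theta"] with dims (chain, draw, school)
theta = rng.normal(size=(n_chains, n_draws, n_schools))

# Stack chain and draw into one "sample" axis (C order: draw varies fastest)
stacked = theta.reshape(n_chains * n_draws, n_schools)
# (chain, draw) label of every sample, playing the role of the MultiIndex
chain_idx, draw_idx = np.unravel_index(np.arange(n_chains * n_draws), (n_chains, n_draws))

# The stacked dimension goes last, matching theta's (school, sample) dims above
theta_sample = stacked.T
print(theta_sample.shape)           # (8, 2000)
print(chain_idx[:3], draw_idx[:3])  # [0 0 0] [0 1 2]
print(chain_idx[-1], draw_idx[-1])  # 3 499
```

Because draw varies fastest, the chain coordinate reads 0 0 0 ... 3 3 3 and the draw coordinate cycles 0 to 499 within each chain, as in the output above.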
You can also request a subset to be returned, both in variables and in samples:
az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[ 9.86049002e+00,  1.02329413e+00,  4.16945829e+00, ...],
       ...,
       [...,  6.93932071e-01,  5.55416881e+00,  2.22029000e+00, -9.14445769e-02]])
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 1 2 1 0 1 0 3 0 2 3 3 3 ... 0 1 3 0 0 1 3 1 3 1 2 0
  * draw     (sample) int64 353 67 42 445 122 183 446 ... 308 211 250 491 86 120
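The chain and draw coordinates above are no longer in order because the samples were shuffled before the first 100 were kept. Below is a minimal NumPy sketch of that shuffle-then-truncate idea, assuming the same 4 chains of 500 draws as the centered_eight posterior. The seed is arbitrary, so the selected indices will not match the output above, and the sketch is an illustration rather than ArviZ's exact code path.

```python
import numpy as np

n_chains, n_draws, num_samples = 4, 500, 100
chain_idx, draw_idx = np.unravel_index(np.arange(n_chains * n_draws), (n_chains, n_draws))

# Without shuffling, the first 100 samples are consecutive draws of chain 0
head = np.arange(num_samples)
print(np.unique(chain_idx[head]))  # [0]

# Shuffle all sample positions (seeded for reproducibility), then keep the first num_samples
rng = np.random.default_rng(42)
subset = rng.permutation(n_chains * n_draws)[:num_samples]
print(len(subset), np.unique(chain_idx[subset]))  # subset size and chains represented
```

To get a reproducible subset from az.extract itself, rng also accepts an integer seed or a numpy.Generator, for example az.extract(idata, num_samples=100, rng=42).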
To keep the chain and draw dimensions, use combined=False. This example also extracts the prior group instead of the posterior:
az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2