arviz.extract

arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)

Extract an InferenceData group or subset of it.

Parameters:
data: InferenceData or InferenceData_like

InferenceData from which to extract the data.

group: str, optional

Which InferenceData data group to extract data from.

combined: bool, optional

Combine chain and draw dimensions into sample. Won’t work if a dimension named sample already exists.

var_names: str or list of str, optional

Variables to be extracted. Prefix the variables by ~ when you want to exclude them.

filter_vars: {None, “like”, “regex”}, optional

If None (default), interpret var_names as the real variable names. If “like”, interpret var_names as substrings of the real variable names. If “regex”, interpret var_names as regular expressions on the real variable names. This mirrors pandas.DataFrame.filter. As with plotting, it is sometimes easier to subset by saying what to exclude than what to include.

num_samples: int, optional

Extract only a subset of the samples. Only valid if combined=True.

keep_dataset: bool, optional

If True, always return a Dataset. If False (default), return a DataArray when there is a single variable.

rng: bool, int, numpy.random.Generator, optional

Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None and are left in the same order otherwise. Shuffling ensures that a subset of the samples does not consist only of consecutive draws from a single chain.

Returns:
xarray.DataArray or xarray.Dataset

Examples

The default behaviour is to return the posterior group after stacking the chain and draw dimensions.

import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
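
As a rough cross-check of the default behaviour, the result can be compared with stacking the posterior group by hand in xarray. This is a minimal sketch, assuming an ArviZ version in which load_arviz_data returns an InferenceData object and relying on the statement above that samples keep their order when num_samples is None. The names stacked, manual, same_length, same_values and theta_mean are illustrative only.

```python
import arviz as az
import numpy as np

idata = az.load_arviz_data("centered_eight")

# Default call: posterior group, with chain and draw stacked into "sample".
stacked = az.extract(idata)

# Manual counterpart using xarray directly on the posterior group.
manual = idata.posterior.stack(sample=("chain", "draw"))

# The stacked dimension should hold one entry per (chain, draw) pair.
n_chain = idata.posterior.sizes["chain"]
n_draw = idata.posterior.sizes["draw"]
same_length = stacked.sizes["sample"] == n_chain * n_draw

# With no subsetting requested, the values should appear in the same order.
same_values = np.array_equal(stacked["mu"].values, manual["mu"].values)

# Reductions over all posterior draws now use the single "sample" dimension.
theta_mean = stacked["theta"].mean(dim="sample")
```

Working with one sample dimension is convenient when the chain a draw came from does not matter, for example when computing posterior means or passing draws to a plotting function.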

You can also indicate a subset to be returned, both in variables and in samples:

az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[ 9.86049002e+00,  1.02329413e+00,  4.16945829e+00,
         7.48425135e+00,  8.51708491e+00,  9.76585447e+00,
         6.83541506e+00,  7.57793225e-01,  4.46941839e+00,
         1.35807636e+01,  5.82519719e+00,  2.92802933e+00,
         7.61099342e+00,  6.78132405e+00,  7.62948922e+00,
         7.80435210e+00,  5.22521699e+00,  1.61296613e+01,
         8.27358881e-01,  6.59766721e+00,  1.90232253e+01,
         4.98946139e+00,  1.67363399e+00,  6.54443324e+00,
        -1.43865360e-01,  8.47941062e+00, -4.08489960e+00,
         6.49799401e+00,  5.19324091e+00,  1.23043911e+01,
         6.46680742e+00,  9.58212542e+00,  1.05693241e+01,
         9.46772141e+00,  4.03663890e+00,  6.31107231e+00,
         1.07845202e+01,  5.36521047e+00,  6.18293721e+00,
         8.65007191e+00,  7.02777773e+00,  8.34450834e+00,
        -1.45632453e+00,  1.38641174e+01,  1.05594375e+01,
         3.28786132e+00,  3.59774091e+00,  1.03335495e+01,
         3.96148108e-01,  7.25006163e+00,  1.55558885e+01,
         6.96225585e+00,  9.75288135e+00,  2.28772366e+00,
        -3.52066245e-01,  8.82885706e+00,  1.34153595e+01,
        -4.84024225e-01,  4.34307692e+00,  1.25101358e+01,
...
         3.82693259e+00,  7.20084454e+00,  7.62386038e+00,
         1.39880952e+01,  1.55257653e+01,  5.42678601e+00,
         3.98680602e+00,  5.19428493e+00,  1.19954042e+01,
         6.22269309e+00,  6.05209366e+00, -2.39197951e-01,
         1.76314439e+01, -2.14242173e+00,  1.39608395e+01,
        -5.93250093e-01,  6.24901341e+00,  9.32159816e+00,
         9.82982661e+00,  8.14707248e+00,  3.35132225e+00,
         1.12827635e+01,  1.09661285e+01,  1.88956072e+00,
         5.12543482e+00,  1.06649000e+01,  1.05839477e+00,
        -1.23504922e+00,  8.26737119e+00,  6.99697482e+00,
         2.84025088e+00,  7.83577170e+00, -6.25581924e+00,
         2.12757214e+01,  1.52136583e+00,  7.92068098e+00,
         4.87831767e+00,  1.04302347e+01,  1.59645417e+00,
         5.75992457e+00, -4.12807622e+00,  1.71592678e+00,
         3.25087694e+00,  1.85733990e+00,  5.29491415e+00,
        -3.57597379e+00,  8.86523760e+00,  1.27955904e+01,
         6.61190575e-02,  3.70049478e+00,  9.89508146e+00,
         1.22944607e+01,  9.16972625e+00,  8.57206490e+00,
         6.93932071e-01,  5.55416881e+00,  2.22029000e+00,
        -9.14445769e-02]])
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 1 2 1 0 1 0 3 0 2 3 3 3 ... 0 1 3 0 0 1 3 1 3 1 2 0
  * draw     (sample) int64 353 67 42 445 122 183 446 ... 308 211 250 491 86 120
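
The other subsetting parameters described above can be combined in the same way. The sketch below assumes the centered_eight posterior holds the variables mu, theta and tau, as in the output shown earlier; the result names (no_theta, like_t, regex_m, first, second) are hypothetical and chosen only for illustration.

```python
import arviz as az
import numpy as np
import xarray as xr

idata = az.load_arviz_data("centered_eight")

# Exclude a variable by prefixing its name with "~".
no_theta = az.extract(idata, var_names="~theta")

# Substring match: "t" is contained in both "theta" and "tau".
like_t = az.extract(idata, var_names="t", filter_vars="like")

# Regular expression match: names that start with "m".
# keep_dataset=True keeps a Dataset even though only one variable matches.
regex_m = az.extract(idata, var_names="^m", filter_vars="regex", keep_dataset=True)

# Passing the same integer seed as rng should give the same random subset.
first = az.extract(idata, var_names="theta", num_samples=100, rng=0)
second = az.extract(idata, var_names="theta", num_samples=100, rng=0)
reproducible = np.array_equal(first.values, second.values)
```

Fixing rng to an integer is mainly useful in tests and tutorials, where the same subset of draws should be selected on every run.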

To keep the chain and draw dimensions, use combined=False.

az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2
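
Keeping chain and draw separate is useful when per-chain quantities are needed. A minimal sketch, assuming the posterior group of centered_eight and using the illustrative names post and mu_per_chain:

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

# Keep chain and draw as separate dimensions; a single variable
# is returned as a DataArray because keep_dataset defaults to False.
post = az.extract(idata, var_names="mu", combined=False)

# One posterior mean of mu per chain, for a quick between-chain comparison.
mu_per_chain = post.mean(dim="draw")
```

Since num_samples and rng are only valid with combined=True, they are left at their defaults here.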