arviz.extract

arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)

Extract an InferenceData group or subset of it.

Parameters:
data : InferenceData or InferenceData_like

InferenceData from which to extract the data.

group : str, optional

InferenceData group from which to extract data. Defaults to "posterior".

combined : bool, optional

Combine the chain and draw dimensions into a single sample dimension (default True). This fails if a dimension named sample already exists.
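To see what combining does to the sample order, here is a minimal NumPy sketch; it is not ArviZ's code, which builds an xarray MultiIndex instead. It assumes chain-major ordering, which is what the chain and draw coordinates in the first example below show (chain 0 0 0 ..., draw 0 1 2 ...). The array sizes are hypothetical.

```python
import numpy as np

# Hypothetical draws for one scalar variable: 4 chains x 500 draws.
n_chain, n_draw = 4, 500
draws = np.arange(n_chain * n_draw, dtype=float).reshape(n_chain, n_draw)

# Merging chain and draw into one "sample" axis is a row-major reshape:
# sample index k corresponds to chain k // n_draw and draw k % n_draw.
samples = draws.reshape(-1)
chain_of = np.repeat(np.arange(n_chain), n_draw)  # chain label per sample
draw_of = np.tile(np.arange(n_draw), n_chain)     # draw label per sample

print(samples.shape)              # (2000,)
print(chain_of[:3], draw_of[:3])  # [0 0 0] [0 1 2]
```

The two label arrays play the role of the chain and draw levels of the sample MultiIndex, so each combined sample can still be traced back to its chain and draw.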

var_names : str or list of str, optional

Variables to extract. Prefix a variable name with ~ to exclude it instead. By default, all variables in the group are extracted.
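A short usage sketch of exclusion with ~, assuming ArviZ is installed. The toy posterior built with az.from_dict is hypothetical data (4 chains x 100 draws), not the centered_eight dataset used in the examples below.

```python
import numpy as np
import arviz as az

# Hypothetical toy posterior: 4 chains x 100 draws, three variables.
rng = np.random.default_rng(0)
idata = az.from_dict(
    posterior={
        "mu": rng.normal(size=(4, 100)),
        "tau": rng.gamma(2.0, size=(4, 100)),
        "theta": rng.normal(size=(4, 100, 8)),
    }
)

# Keep every variable except theta by prefixing it with "~".
no_theta = az.extract(idata, var_names="~theta")
print(sorted(no_theta.data_vars))  # ['mu', 'tau']
```

Because two variables remain, the result is a Dataset rather than a DataArray.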

filter_vars : {None, “like”, “regex”}, optional

If None (default), interpret var_names as the exact variable names. If “like”, interpret var_names as substrings of the variable names. If “regex”, interpret var_names as regular expressions matched against the variable names. This works as in pandas.filter. As with plotting, it is sometimes easier to subset by saying what to exclude than what to include.
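The matching rules can be made concrete with a pure-Python sketch. The helper match_var_names is hypothetical and only mimics the behavior described above (exact names, substrings, or regular expressions searched with re.search, plus ~ for exclusion); it is not ArviZ's internal implementation.

```python
import re


def match_var_names(patterns, names, filter_vars=None):
    """Hypothetical helper mimicking the filter_vars rules described above."""
    if isinstance(patterns, str):
        patterns = [patterns]

    def hits(pattern):
        if filter_vars is None:  # exact variable names
            return [n for n in names if n == pattern]
        if filter_vars == "like":  # substrings of variable names
            return [n for n in names if pattern in n]
        if filter_vars == "regex":  # regular expressions on variable names
            return [n for n in names if re.search(pattern, n)]
        raise ValueError(f"unknown filter_vars: {filter_vars!r}")

    include = [p for p in patterns if not p.startswith("~")]
    exclude = [p[1:] for p in patterns if p.startswith("~")]
    # With only exclusions, start from every variable.
    selected = set(names) if not include else {n for p in include for n in hits(p)}
    for p in exclude:
        selected -= set(hits(p))
    # Preserve the original variable order.
    return [n for n in names if n in selected]


names = ["mu", "tau", "theta", "theta_tilde"]
print(match_var_names("theta", names))                      # ['theta']
print(match_var_names("theta", names, filter_vars="like"))  # ['theta', 'theta_tilde']
print(match_var_names("^t", names, filter_vars="regex"))    # ['tau', 'theta', 'theta_tilde']
print(match_var_names("~the", names, filter_vars="like"))   # ['mu', 'tau']
```

The last call shows why exclusion is convenient: one short pattern drops every theta-like variable without listing the others.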

num_samples : int, optional

Extract only a subset of num_samples samples. Only valid if combined=True.

keep_dataset : bool, optional

If True, always return a Dataset. If False (default), return a DataArray when only a single variable is extracted.
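The effect of keep_dataset when selecting a single variable, sketched with a hypothetical two-variable toy posterior (ArviZ and xarray assumed installed):

```python
import numpy as np
import xarray as xr
import arviz as az

rng = np.random.default_rng(0)
# Hypothetical toy posterior: 2 chains x 50 draws, two variables.
idata = az.from_dict(
    posterior={"mu": rng.normal(size=(2, 50)), "tau": rng.gamma(2.0, size=(2, 50))}
)

as_array = az.extract(idata, var_names="mu")                       # single variable -> DataArray
as_dataset = az.extract(idata, var_names="mu", keep_dataset=True)  # always a Dataset

print(isinstance(as_array, xr.DataArray))  # True
print(isinstance(as_dataset, xr.Dataset))  # True
```

Keeping a Dataset can be handy in generic code that always indexes the result by variable name.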

rng : bool, int, numpy.Generator, optional

Shuffle the samples; only valid if combined=True. By default, samples are shuffled if num_samples is not None and left in their original order otherwise. Shuffling ensures that a subset of the samples does not come only from a single chain and consecutive draws. Pass an int seed or a numpy.Generator to make the shuffle reproducible.
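Why shuffling matters can be shown with plain NumPy. This sketch assumes the chain-major sample order seen in the examples below; permuting once and keeping the first num_samples entries is one way to draw a subset without replacement, not necessarily ArviZ's exact code. The sizes and seed are hypothetical.

```python
import numpy as np

n_chain, n_draw, num_samples = 4, 500, 100
# Chain label of every sample after combining chain and draw (chain-major order).
chain_of = np.repeat(np.arange(n_chain), n_draw)

# Without shuffling, the first 100 samples are draws 0-99 of chain 0 only.
print(np.unique(chain_of[:num_samples]))  # [0]

# Shuffle once with a seeded Generator, then keep the first num_samples.
rng = np.random.default_rng(42)
subset = rng.permutation(chain_of.size)[:num_samples]
print(np.unique(chain_of[subset]))  # chains represented in the shuffled subset
```

Reusing the same seed reproduces the same subset, which is what an int or Generator passed as rng is for.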

Returns:
xarray.DataArray or xarray.Dataset

Examples

The default behaviour is to return the posterior group after stacking the chain and draw dimensions.

import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000

You can also request a subset of the result, both in variables and in samples:

az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[ 3.46919003e+00,  5.26285351e+00, -6.24775331e-01,
         7.03405245e+00,  9.12796258e+00, -1.86233178e+00,
         3.60916448e+00,  7.91077802e+00,  7.42543182e+00,
         1.05772005e+00,  6.83541506e+00,  1.49021096e+00,
         6.88210567e+00,  7.44738170e+00,  3.25399384e+00,
         2.89325055e+00,  9.15377229e+00,  3.25808233e+00,
         3.25808233e+00,  7.39820257e+00,  1.08583159e+01,
         8.02717401e+00,  5.58777112e+00,  8.76114257e+00,
        -2.99568677e+00,  1.33632598e+01,  1.11144133e+01,
         1.95554018e+00,  7.88942950e+00, -3.57120797e+00,
         7.24692402e+00, -1.56217939e+00, -7.76001689e+00,
         8.46576818e-01,  1.21442118e+01,  1.18937085e+01,
         6.19166018e+00,  8.51708491e+00,  6.35679201e+00,
         8.07411402e+00,  9.44361810e+00,  2.39036944e+01,
         5.51853692e+00,  6.72461144e+00, -2.56500961e+00,
         6.46054644e+00,  2.19751912e+00,  7.31039212e+00,
         2.97785044e+00,  8.50116237e+00,  4.30549864e+00,
         4.56875939e+00,  2.92802933e+00,  1.04382282e+01,
         4.24604215e+00,  1.45832141e-01,  7.19066476e-01,
         8.37610562e+00,  4.62526861e+00,  3.12961854e+00,
...
         7.49338364e+00,  5.48779218e+00, -1.96408397e+00,
         7.26641119e+00,  3.83504623e+00,  1.33404856e+01,
         6.23728051e+00,  9.21082747e+00,  3.36798595e+00,
        -8.07343345e+00,  7.98384939e+00,  8.86523760e+00,
         6.49636495e+00, -1.78243701e+00,  1.75189786e+00,
         1.41540312e+01, -4.29066328e-01, -4.06985166e+00,
         3.35132225e+00,  6.99697482e+00, -8.20216060e-01,
         1.03714332e+01,  1.74432212e-01, -1.42720848e+00,
        -1.35336102e+01,  4.30486506e+00,  1.89424693e+00,
         1.26077965e+01,  6.67596484e+00,  8.33094714e+00,
         3.35132225e+00,  1.13856246e+00,  3.35132225e+00,
         6.21490495e+00,  6.06495726e+00, -1.76985435e+00,
         1.01640899e+01, -8.14360121e+00,  9.24872843e+00,
         7.45065714e+00,  9.39119464e+00,  4.44811027e+00,
         9.01668501e+00,  6.40498626e+00,  8.79855496e+00,
         2.23116742e-01,  3.34820319e+00, -1.45149265e+00,
         5.26780466e+00,  7.25149063e+00,  9.27806411e-01,
         1.15331870e+01,  9.73067897e+00,  9.68254667e+00,
         3.14131866e+00, -3.58944527e-01,  1.71635294e+00,
         8.33041019e+00]])
Coordinates:
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 1 2 2 0 3 1 0 0 1 2 3 3 ... 0 2 2 2 0 2 2 1 3 1 2 3
  * draw     (sample) int64 395 278 335 248 278 18 381 ... 467 7 305 220 26 25

To keep the chain and draw dimensions, use combined=False.

az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2