arviz.extract

arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)

Extract an InferenceData group or a subset of it.

Parameters:
data : InferenceData or InferenceData_like

InferenceData from which to extract the data.

group : str, optional

Name of the InferenceData group to extract data from. Defaults to "posterior".

combined : bool, optional

Combine chain and draw dimensions into sample. Won’t work if a dimension named sample already exists.

var_names : str or list of str, optional

Variables to be extracted. Prefix the variables by ~ when you want to exclude them.

filter_vars : {None, "like", "regex"}, optional

If None (default), interpret var_names as the real variable names. If "like", interpret var_names as substrings of the real variable names. If "regex", interpret var_names as regular expressions on the real variable names. Similar to pandas.filter. As with plotting, it is sometimes easier to specify what to exclude than what to include.

num_samples : int, optional

Extract only a subset of the samples. Only valid if combined=True.

keep_dataset : bool, optional

If True, always return a Dataset. If False (default), return a DataArray when there is a single variable.

rng : bool, int, numpy.random.Generator, optional

Shuffle the samples, only valid if combined=True. By default, samples are shuffled if num_samples is not None, and are left in the same order otherwise. This ensures that subsetting the samples doesn’t return only samples from a single chain and consecutive draws.

Returns:
xarray.DataArray or xarray.Dataset

Examples

The default behaviour is to return the posterior group after stacking the chain and draw dimensions.

import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
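
As a rough illustration of what the stacking means, the sketch below compares the default output with a manual xarray stack of the posterior group. That the two match in sample order is an assumption based on the behaviour described above (no shuffling unless num_samples or rng is given), not a guarantee of the implementation. The variable names stacked, manual, same_values and theta_mean are only illustrative.

```python
import arviz as az
import numpy as np

idata = az.load_arviz_data("centered_eight")

# Default call: posterior group, chain and draw stacked into "sample"
stacked = az.extract(idata)

# Manual equivalent (assumed): stack the two dimensions with xarray directly
manual = idata.posterior.stack(sample=("chain", "draw"))

# With no shuffling requested, both should hold the draws in the same order
same_values = np.array_equal(stacked["mu"].values, manual["mu"].values)

# A single "sample" dimension makes reductions over all draws straightforward
theta_mean = stacked["theta"].mean("sample")
print(same_values, theta_mean.dims)
```

Working with one sample dimension avoids having to reduce over chain and draw separately.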

You can also request a subset to be returned, both in variables and in samples:

az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[  7.25006163,   1.20870081,   5.36521047,   6.09019222,
          4.96887564,  10.3335495 ,   8.26200915,   6.56672708,
          6.79867321,   3.25808233,  18.13102836,   3.17531159,
         -1.44923689, -14.18963813,   8.07073012,  13.64463487,
         19.03480507,  -3.43617208,  10.33118838,   0.8882626 ,
          7.38045201,   2.47020244,   0.02669979,   2.48775525,
          7.24359936,   8.64601026,   2.23945711,  10.22682075,
          4.78763947,  13.82900004,  11.45158356,   4.97176681,
         12.32068558, -11.03564233,   4.00072564,   7.55391569,
          0.47002133,   5.36521047,   0.23545152,   3.25808233,
         12.83957038,  14.59400503,   3.72633567,  16.14922101,
          5.78332001,   9.69873422,   9.4436181 ,  17.97989162,
          2.37221256,  -0.78947236,  -0.39105714,  11.83895627,
         -0.87626159,   6.06428092,  10.2958492 ,  -0.144637  ,
          0.19295578,   4.80087665,  -3.66628453,   6.96225585,
          3.8715119 ,  14.00020113,   8.19749048,   2.92802933,
         11.10766799,   6.19166018,   6.0042965 ,  14.37538352,
         -0.84905156,   4.40498925,   9.01554086,   7.55249033,
         13.92582666,   4.43196277,   8.33570972,   6.88210567,
          0.39614811,  -0.43516545,   3.59774091,   2.43743399,
...
          6.09071623,   5.08789837,   2.76069798,   2.84025088,
          5.45279475,   2.22912138,   3.23201494,   8.26737119,
          1.81028014,   9.2949271 ,   3.28762851,   0.97295524,
         15.06136584,   1.56608555,   5.70045771,  -2.4119179 ,
          1.91912907,   6.99697482,   6.6080452 ,   3.35132225,
          4.90514523,   9.68439852,   1.8573399 ,  13.50380373,
          5.69639395,   0.92936424,   9.82982661,   2.56285034,
          4.58559411,   2.98326158,  -1.86901085,   8.56424468,
         -1.83024329,  12.12807101,   4.67353194,   1.52136583,
          1.2950506 ,   3.83319884,   3.09944712,   6.22269309,
          2.00424   ,   8.23112776,   0.27514975,   7.98384939,
          9.87222713,   5.29491415,   5.08930303,   3.3242612 ,
          0.58012835,   0.2817112 ,   7.90169447,  13.70728902,
         20.84920247,   4.62591046,   8.17396635,   5.91405091,
          3.98680602,  11.65332513,  15.52576531,   5.62894307,
        -14.99429294,   3.93695355,  -3.09055561,   2.75953018,
          8.05718631,   5.54969423,   8.33094714,  -1.02702293,
          4.28127982,  15.3553629 ,   8.79855496,   4.61650021,
         -0.23919795,  13.28979809,  -0.10090309, -10.70077322,
          1.84434317,   5.37551996,  -1.72640641,   6.99697482]])
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 3 1 3 2 2 1 0 1 2 1 2 0 ... 0 1 3 2 3 1 3 3 1 1 0 3
  * draw     (sample) int64 199 48 116 466 410 217 36 ... 150 11 158 237 434 103
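
Because the subset above is drawn after shuffling, repeated calls generally return different samples. The following sketch shows one way to make the subset reproducible by passing an integer seed through rng, and how rng=False keeps the original order. The seed value 0 and the variable names are arbitrary choices for illustration.

```python
import arviz as az
import numpy as np

idata = az.load_arviz_data("centered_eight")

# The same integer seed should select the same 100 samples on every call
first = az.extract(idata, var_names="theta", num_samples=100, rng=0)
second = az.extract(idata, var_names="theta", num_samples=100, rng=0)
same_subset = np.array_equal(first.values, second.values)

# rng=False disables shuffling, so samples stay in chain/draw order
ordered = az.extract(idata, var_names="theta", rng=False)
first_draws = ordered["draw"].values[:3]
print(same_subset, first_draws)
```

A numpy.random.Generator instance can be passed instead of an integer when the random state is managed elsewhere in the analysis.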

To keep the chain and draw dimensions, use combined=False.

az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2
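
The variable selection options can be combined with keep_dataset to control the return type. The sketch below is a minimal illustration, assuming the posterior group of centered_eight contains the variables mu, theta and tau as shown in the first example; the variable names on the left-hand side are illustrative.

```python
import arviz as az
import xarray as xr

idata = az.load_arviz_data("centered_eight")

# Exclude a variable by prefixing its name with ~
no_theta = az.extract(idata, var_names="~theta")

# Substring matching: "th" should match only theta here,
# so a single variable comes back as a DataArray
by_substring = az.extract(idata, var_names="th", filter_vars="like")

# keep_dataset=True returns a Dataset even for a single variable
as_dataset = az.extract(idata, var_names="theta", keep_dataset=True)

print(list(no_theta.data_vars), type(by_substring).__name__, type(as_dataset).__name__)
```

Forcing a Dataset can simplify downstream code that should not branch on whether one or several variables were selected.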