arviz.extract
- arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)
Extract an InferenceData group or subset of it.
- Parameters:
- data : InferenceData or InferenceData_like
InferenceData from which to extract the data.
- group : str, optional
Which InferenceData group to extract data from. Defaults to "posterior".
- combined : bool, optional
Combine the chain and draw dimensions into a single sample dimension. Defaults to True. Won’t work if a dimension named sample already exists.
- var_names : str or list of str, optional
Variables to be extracted. Prefix the variables with ~ when you want to exclude them.
- filter_vars : {None, "like", "regex"}, optional
If None (default), interpret var_names as the real variable names. If "like", interpret var_names as substrings of the real variable names. If "regex", interpret var_names as regular expressions on the real variable names. This mirrors pandas.filter. As with plotting, it is sometimes easier to subset by stating what to exclude instead of what to include.
- num_samples : int, optional
Extract only a subset of the samples. Only valid if combined=True.
- keep_dataset : bool, optional
If True, always return a Dataset. If False (default), return a DataArray when there is a single variable.
- rng : bool, int or numpy.Generator, optional
Shuffle the samples; only valid if combined=True. By default, samples are shuffled if num_samples is not None and are left in the same order otherwise. Shuffling before subsetting ensures that the subset doesn’t consist only of consecutive draws from a single chain.
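- Returns:
xarray.Dataset or xarray.DataArray
A DataArray is returned only when a single variable is extracted and keep_dataset=False.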
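The interplay of var_names, the ~ prefix and filter_vars can be sketched in plain Python. The helper select_var_names below is hypothetical and simplified: it assumes a var_names list is either all inclusions or all exclusions, silently drops patterns that match nothing, and returns names in dataset order. It only illustrates the documented rules and does not reproduce ArviZ's internal logic. The names mu, theta and tau are the variables of the centered_eight posterior used in the examples.

```python
import re


def select_var_names(all_vars, var_names, filter_vars=None):
    """Hypothetical, simplified resolution of var_names against all_vars."""
    if isinstance(var_names, str):
        var_names = [var_names]

    def matches(pattern, name):
        # filter_vars decides how a user-supplied pattern is compared to a real name
        if filter_vars is None:
            return pattern == name  # exact variable name
        if filter_vars == "like":
            return pattern in name  # substring match
        if filter_vars == "regex":
            return re.search(pattern, name) is not None  # regular expression
        raise ValueError(f"filter_vars must be None, 'like' or 'regex', got {filter_vars!r}")

    excluded = [name[1:] for name in var_names if name.startswith("~")]
    if excluded:
        # Exclusion mode: keep every variable that no "~" pattern matches
        return [v for v in all_vars if not any(matches(p, v) for p in excluded)]
    # Inclusion mode: keep every variable matched by at least one pattern
    return [v for v in all_vars if any(matches(p, v) for p in var_names)]


names = ["mu", "theta", "tau"]  # data variables of the centered_eight posterior
print(select_var_names(names, "~theta"))  # ['mu', 'tau']
print(select_var_names(names, "t", filter_vars="like"))  # ['theta', 'tau']
print(select_var_names(names, "~^t", filter_vars="regex"))  # ['mu']
```

Excluding with a "like" or "regex" pattern is where subsetting by what to drop is most convenient: one short pattern can remove a whole family of related variables.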
Examples
The default behaviour is to return the posterior group after stacking the chain and draw dimensions.
import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 7.872 3.385 9.1 7.304 ... 1.859 1.767 3.486 3.404
    theta    (school, sample) float64 12.32 11.29 5.709 ... -2.623 8.452 1.295
    tau      (sample) float64 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
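To picture what stacking does to the array layout, here is a plain NumPy sketch. Synthetic random numbers stand in for the theta draws of centered_eight, using the 4 chains × 500 draws that produce the 2000 samples in the output above. ArviZ works on labelled xarray objects, so this only illustrates the resulting layout and the (chain, draw) labels behind the sample MultiIndex, not the library's code.

```python
import numpy as np

n_chains, n_draws, n_schools = 4, 500, 8

# Synthetic stand-in for posterior["theta"], laid out as (chain, draw, school)
theta = np.random.default_rng(0).normal(size=(n_chains, n_draws, n_schools))

# Merge chain and draw into one axis (C order keeps chain as the outer level),
# then move that axis last to get the (school, sample) layout
stacked = theta.reshape(n_chains * n_draws, n_schools).T

# Recover the (chain, draw) pair behind each sample position
chain_idx, draw_idx = np.divmod(np.arange(n_chains * n_draws), n_draws)

print(stacked.shape)  # (8, 2000)
print(chain_idx[:3], draw_idx[:3])  # [0 0 0] [0 1 2]
print(chain_idx[-1], draw_idx[-1])  # 3 499
```

Because chain is the outer level, the chain labels run 0 0 0 ... 3 3 3 while draw cycles through 0 ... 499, which is the ordering of the chain and draw coordinates in the default output.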
You can also request a subset, both in variables and in samples:
az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[ 7.25006163, 1.20870081, 5.36521047, 6.09019222, 4.96887564, 10.3335495 ,
         8.26200915, 6.56672708, 6.79867321, 3.25808233, 18.13102836, 3.17531159,
         -1.44923689, -14.18963813, 8.07073012, 13.64463487, 19.03480507, -3.43617208,
         10.33118838, 0.8882626 , 7.38045201, 2.47020244, 0.02669979, 2.48775525,
         7.24359936, 8.64601026, 2.23945711, 10.22682075, 4.78763947, 13.82900004,
         11.45158356, 4.97176681, 12.32068558, -11.03564233, 4.00072564, 7.55391569,
         0.47002133, 5.36521047, 0.23545152, 3.25808233, 12.83957038, 14.59400503,
         3.72633567, 16.14922101, 5.78332001, 9.69873422, 9.4436181 , 17.97989162,
         2.37221256, -0.78947236, -0.39105714, 11.83895627, -0.87626159, 6.06428092,
         10.2958492 , -0.144637 , 0.19295578, 4.80087665, -3.66628453, 6.96225585,
         3.8715119 , 14.00020113, 8.19749048, 2.92802933, 11.10766799, 6.19166018,
         6.0042965 , 14.37538352, -0.84905156, 4.40498925, 9.01554086, 7.55249033,
         13.92582666, 4.43196277, 8.33570972, 6.88210567, 0.39614811, -0.43516545,
         3.59774091, 2.43743399, ...
         6.09071623, 5.08789837, 2.76069798, 2.84025088, 5.45279475, 2.22912138,
         3.23201494, 8.26737119, 1.81028014, 9.2949271 , 3.28762851, 0.97295524,
         15.06136584, 1.56608555, 5.70045771, -2.4119179 , 1.91912907, 6.99697482,
         6.6080452 , 3.35132225, 4.90514523, 9.68439852, 1.8573399 , 13.50380373,
         5.69639395, 0.92936424, 9.82982661, 2.56285034, 4.58559411, 2.98326158,
         -1.86901085, 8.56424468, -1.83024329, 12.12807101, 4.67353194, 1.52136583,
         1.2950506 , 3.83319884, 3.09944712, 6.22269309, 2.00424 , 8.23112776,
         0.27514975, 7.98384939, 9.87222713, 5.29491415, 5.08930303, 3.3242612 ,
         0.58012835, 0.2817112 , 7.90169447, 13.70728902, 20.84920247, 4.62591046,
         8.17396635, 5.91405091, 3.98680602, 11.65332513, 15.52576531, 5.62894307,
         -14.99429294, 3.93695355, -3.09055561, 2.75953018, 8.05718631, 5.54969423,
         8.33094714, -1.02702293, 4.28127982, 15.3553629 , 8.79855496, 4.61650021,
         -0.23919795, 13.28979809, -0.10090309, -10.70077322, 1.84434317, 5.37551996,
         -1.72640641, 6.99697482]])
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 3 1 3 2 2 1 0 1 2 1 2 0 ... 0 1 3 2 3 1 3 3 1 1 0 3
  * draw     (sample) int64 199 48 116 466 410 217 36 ... 150 11 158 237 434 103
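The reason for shuffling before taking num_samples can be shown with another NumPy sketch over the same 4 × 500 stacked sample positions. It is an illustration under stated assumptions: the hypothetical helper subsample draws positions without replacement via numpy.random.Generator.choice, which is not necessarily the procedure ArviZ uses internally.

```python
import numpy as np

n_chains, n_draws, num_samples = 4, 500, 100
chain_idx, draw_idx = np.divmod(np.arange(n_chains * n_draws), n_draws)

# Unshuffled: the first 100 stacked samples are draws 0..99 of chain 0 only
print(np.unique(chain_idx[:num_samples]))  # [0]


def subsample(seed):
    """Hypothetical helper: pick num_samples distinct sample positions at random."""
    rng = np.random.default_rng(seed)
    return rng.choice(n_chains * n_draws, size=num_samples, replace=False)


picked = subsample(seed=0)
# Shuffled: the subset mixes chains and non-consecutive draws
print(np.unique(chain_idx[picked]))
# Reusing a seed reproduces the same subset
print(np.array_equal(subsample(seed=0), subsample(seed=0)))  # True
```

The mixed chain and draw labels in the num_samples=100 output above reflect the same effect; fixing the seed is what makes such a random subset repeatable.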
To keep the chain and draw dimensions, use combined=False. This example also extracts the prior group instead of the posterior.
az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 ...
    theta    (chain, draw, school) float64 ...
    mu       (chain, draw) float64 ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2