arviz.extract
- arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)
Extract an InferenceData group or subset of it.
- Parameters:
- idata : InferenceData or InferenceData_like
  InferenceData from which to extract the data.
- group : str, optional
  Which InferenceData data group to extract data from.
- combined : bool, optional
  Combine the chain and draw dimensions into sample. Won’t work if a dimension named sample already exists.
- var_names : str or list of str, optional
  Variables to be extracted. Prefix the variables by ~ when you want to exclude them.
- filter_vars : {None, “like”, “regex”}, optional
  If None (default), interpret var_names as the real variable names. If “like”, interpret var_names as substrings of the real variable names. If “regex”, interpret var_names as regular expressions on the real variable names. A la pandas.filter. Like with plotting, sometimes it’s easier to subset by saying what to exclude instead of what to include.
- num_samples : int, optional
  Extract only a subset of the samples. Only valid if combined=True.
- keep_dataset : bool, optional
  If true, always return a Dataset. If false (default), return a DataArray when there is a single variable.
- rng : bool, int, numpy.Generator, optional
  Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None, and are left in the same order otherwise. This ensures that subsetting the samples doesn’t return only samples from a single chain and consecutive draws.
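- Returns:
  - xarray.DataArray or xarray.Dataset
    A DataArray when a single variable is extracted and keep_dataset is false, a Dataset otherwise.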
Examples
The default behaviour is to return the posterior group after stacking the chain and draw dimensions.
import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset> Size: 209kB
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
  * sample   (sample) object 16kB MultiIndex
  * chain    (sample) int64 16kB 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 16kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 16kB 7.872 3.385 9.1 7.304 ... 1.767 3.486 3.404
    theta    (school, sample) float64 128kB 12.32 11.29 5.709 ... 8.452 1.295
    tau      (sample) float64 16kB 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes:
    created_at:                 2022-10-13T14:37:37.315398
    arviz_version:              0.13.0.dev0
    inference_library:          pymc
    inference_library_version:  4.2.2
    sampling_time:              7.480114936828613
    tuning_steps:               1000
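As a rough illustration of what stacking means here, the sketch below builds a small synthetic InferenceData and checks that the sample dimension has one entry per (chain, draw) pair. It assumes the `az.from_dict(posterior=...)` signature of the 0.x ArviZ releases. The variable values, the school labels and the 4 x 50 chain/draw layout are made up for the example.

```python
import numpy as np
import arviz as az

# Hypothetical synthetic posterior: 4 chains, 50 draws, 3 schools.
gen = np.random.default_rng(0)
idata = az.from_dict(
    posterior={
        "mu": gen.normal(size=(4, 50)),
        "theta": gen.normal(size=(4, 50, 3)),
    },
    coords={"school": ["A", "B", "C"]},
    dims={"theta": ["school"]},
)

# Default call: posterior group, chain and draw stacked into sample.
stacked = az.extract(idata)
print(stacked.sizes["sample"])  # 4 chains * 50 draws = 200

# The sample coordinate is a MultiIndex over chain and draw,
# so plain xarray can undo the stacking again.
unstacked = stacked.unstack("sample")
print(dict(unstacked.sizes))
```

Because no subsetting or shuffling is requested, the stacked samples should keep chain-major order, which matches the chain and draw coordinates shown in the output above.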
You can also indicate a subset to be returned, both in variables and in samples:
az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)> Size: 6kB
array([[ 3.61820136e+00,  1.01596619e+01,  6.60623407e+00,  7.22883426e+00,
         6.33178880e+00,  1.05281528e+00,  8.64506211e-01,  3.91569975e+00,
         8.02796674e+00,  1.35882507e+01,  3.02179357e+00, -1.14147644e+00,
         2.58651619e+00,  6.74667610e+00, -4.08489960e+00,  3.72633567e+00,
         7.61099342e+00, -5.16033022e+00,  2.69821045e+00,  2.28772366e+00,
         7.25006163e+00, -5.10846749e+00,  2.63102226e+00,  1.09034661e+01,
         9.01452429e-01,  9.76734589e+00,  1.00708740e+01,  1.20889907e+01,
         6.72461144e+00, -2.56500961e+00,  1.23317934e+01,  1.55558885e+01,
         5.87573480e-01,  4.24604215e+00,  1.47320820e+00,  1.31998205e+01,
         9.47489537e+00,  9.58212542e+00, -3.39398393e+00,  4.32460450e+00,
         6.19187589e-01,  4.93254022e+00,  8.75895674e+00,  7.25006163e+00,
         1.90348051e+01,  5.20472312e+00,  2.85987166e+00,  6.89414945e+00,
         1.21647283e+01,  6.85122480e+00,  2.19751912e+00,  9.97022761e+00,
         3.62729697e+00,  7.67252955e+00,  2.66967253e+00,  3.19241432e+00,
        -3.94180636e+00,  1.55896104e+01,  5.29515274e+00,  9.29485423e+00,
...
         1.21913193e+01,  5.19428493e+00,  5.23148216e-01,  3.26190553e+00,
         6.79472198e+00,  5.42326935e+00,  1.16330776e+01,  1.00620073e+01,
         3.83504623e+00,  4.44278431e-04,  6.25408810e-01,  3.34606063e+00,
         2.73537789e+00,  3.93680253e+00, -3.98684048e+00,  5.19515568e+00,
         1.12827635e+01,  1.07388301e+01,  3.35132225e+00,  3.21563028e+00,
         6.99697482e+00, -1.50008511e-01,  1.00424576e+01,  4.16591560e-01,
         5.24838051e+00,  1.12300530e+01,  7.35932193e+00,  9.62883356e+00,
         8.82829191e+00, -1.00903091e-01,  1.12702553e+01,  5.29807819e+00,
         2.50337198e+00, -3.66902984e+00,  6.99697482e+00, -4.43059825e+00,
         5.29491415e+00,  1.15104032e+01,  3.94983056e+00,  2.44083197e+00,
         8.57206490e+00,  1.00048799e+01,  1.06649000e+01,  1.94926476e+00,
         3.24934784e+00,  3.35132225e+00,  5.42326935e+00,  6.25408810e-01,
         8.58671448e+00,  8.38931576e+00, -1.14874960e+00, -9.14445769e-02,
         5.72597307e+00,  5.05685363e+00,  7.82905101e+00,  1.25813350e+01,
         2.09029102e+00,  3.82917093e-01]])
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
  * sample   (sample) object 800B MultiIndex
  * chain    (sample) int64 800B 3 2 1 2 2 3 1 0 1 2 1 ... 0 2 3 2 0 2 0 2 0 3 1
  * draw     (sample) int64 800B 483 435 153 294 290 168 ... 44 123 5 171 481
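Since num_samples shuffles by default, repeated calls generally return different subsets. The sketch below shows one way to make the subset reproducible by passing an integer as rng, on the assumption that the integer is used as a seed. The data is synthetic and hypothetical, and `az.from_dict(posterior=...)` assumes the 0.x ArviZ API.

```python
import numpy as np
import arviz as az

# Hypothetical synthetic posterior: 4 chains, 50 draws.
gen = np.random.default_rng(1)
idata = az.from_dict(posterior={"mu": gen.normal(size=(4, 50))})

# Same integer rng on both calls: the two subsets are expected to match.
first = az.extract(idata, var_names="mu", num_samples=20, rng=42)
second = az.extract(idata, var_names="mu", num_samples=20, rng=42)
print(np.array_equal(first.values, second.values))

# rng=False is expected to disable shuffling while still subsetting.
ordered = az.extract(idata, var_names="mu", num_samples=20, rng=False)
print(ordered.sizes["sample"])
```

Passing a numpy.Generator instead of an integer is the other documented option and is useful when a single generator is shared across an analysis.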
To keep the chain and draw dimensions, use combined=False.
az.extract(idata, group="prior", combined=False)
<xarray.Dataset> Size: 45kB
Dimensions:  (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 8B 0
  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    tau      (chain, draw) float64 4kB ...
    theta    (chain, draw, school) float64 32kB ...
    mu       (chain, draw) float64 4kB ...
Attributes:
    arviz_version:              0.13.0.dev0
    created_at:                 2022-10-13T14:37:26.602116
    inference_library:          pymc
    inference_library_version:  4.2.2
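The var_names, filter_vars and keep_dataset parameters described under Parameters can be combined as in the following sketch. It uses a small synthetic InferenceData with hypothetical variables mu, tau and theta, and assumes the `az.from_dict(posterior=...)` signature of the 0.x ArviZ releases.

```python
import numpy as np
import arviz as az
import xarray as xr

# Hypothetical synthetic posterior: 2 chains, 10 draws, 3 schools.
gen = np.random.default_rng(2)
idata = az.from_dict(
    posterior={
        "mu": gen.normal(size=(2, 10)),
        "tau": gen.normal(size=(2, 10)),
        "theta": gen.normal(size=(2, 10, 3)),
    },
    coords={"school": ["A", "B", "C"]},
    dims={"theta": ["school"]},
)

# Exclude a variable by prefixing its name with "~".
without_theta = az.extract(idata, var_names="~theta")
print(sorted(without_theta.data_vars))  # ['mu', 'tau']

# Substring matching: both "tau" and "theta" contain "ta".
like_ta = az.extract(idata, var_names="ta", filter_vars="like")
print(sorted(like_ta.data_vars))  # ['tau', 'theta']

# Regular expression matching: names starting with "t".
regex_t = az.extract(idata, var_names="^t", filter_vars="regex")
print(sorted(regex_t.data_vars))  # ['tau', 'theta']

# A single variable comes back as a DataArray unless keep_dataset=True.
single = az.extract(idata, var_names="mu")
kept = az.extract(idata, var_names="mu", keep_dataset=True)
print(type(single).__name__, type(kept).__name__)  # DataArray Dataset
```

Setting keep_dataset=True can simplify downstream code that always expects a Dataset, regardless of how many variables were selected.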