arviz.extract_dataset
- arviz.extract_dataset(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, rng=None)
Extract an InferenceData group or subset of it as an xarray.Dataset.
- Parameters
- data : InferenceData or InferenceData_like
InferenceData from which to extract the data.
- group : str, optional
Which InferenceData group to extract data from. Defaults to 'posterior'.
- combined : bool, optional
Combine the chain and draw dimensions into a single sample dimension. Won’t work if a dimension named sample already exists.
- var_names : str or list of str, optional
Variables to be extracted. Prefix a variable name with ~ to exclude it.
- filter_vars : {None, “like”, “regex”}, optional
If None (default), interpret var_names as the real variable names. If “like”, interpret var_names as substrings of the real variable names. If “regex”, interpret var_names as regular expressions on the real variable names, à la pandas.filter. Sometimes it is easier to say what to exclude than what to include.
- num_samples : int, optional
Extract only a subset of the samples. Only valid if combined=True.
- rng : bool, int, numpy.random.Generator, optional
Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None, and are left in the same order otherwise. Shuffling ensures that a subset of the samples does not consist only of consecutive draws from a single chain.
- Returns
- xarray.Dataset
Examples
The default behaviour is to return the posterior group after stacking the chain and draw dimensions.
import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract_dataset(idata)
<xarray.Dataset>
Dimensions:  (school: 8, sample: 2000)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) MultiIndex
  - chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3
  - draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 -3.477 -2.456 -2.826 -1.996 ... 4.597 5.899 0.1614
    theta    (school, sample) float64 1.669 -6.239 2.195 ... -1.095 4.013 4.523
    tau      (sample) float64 3.73 2.075 3.703 4.146 ... 8.589 8.346 7.711 5.407
Attributes:
    created_at:                 2019-06-21T17:36:34.398087
    inference_library:          pymc3
    inference_library_version:  3.7
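As a rough illustration of what stacking chain and draw into sample means, the sketch below compares the default output with a manual xarray stack. The synthetic InferenceData (4 chains, 50 draws, one variable named mu) is made up for this example, and the comparison with Dataset.stack is an assumption about equivalent behaviour rather than a statement about how the function is implemented.

```python
import numpy as np
import arviz as az

# Hypothetical synthetic posterior: 4 chains x 50 draws of a scalar "mu".
gen = np.random.default_rng(0)
idata = az.from_dict(posterior={"mu": gen.normal(size=(4, 50))})

# Default call: chain and draw are combined into one "sample" dimension.
extracted = az.extract_dataset(idata)

# Manual equivalent (assumed) using plain xarray stacking.
manual = idata.posterior.stack(sample=("chain", "draw"))

print(extracted.sizes["sample"])  # 4 chains * 50 draws

# The sample MultiIndex keeps chain and draw, so the stack can be undone.
restored = extracted.unstack("sample")
print(sorted(restored["mu"].dims))
```

Because sample is a MultiIndex over chain and draw, the origin of every sample stays available after stacking.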
You can also request a subset to be returned, both in variables and in samples:
az.extract_dataset(idata, var_names="theta", num_samples=100)
<xarray.Dataset>
Dimensions:  (school: 8, sample: 100)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) MultiIndex
  - chain    (sample) int64 2 1 1 1 3 0 0 0 3 1 0 1 ... 1 1 1 3 1 2 1 0 1 2 0 2
  - draw     (sample) int64 21 292 439 86 441 11 209 ... 156 92 168 2 180 278
Data variables:
    theta    (school, sample) float64 1.609 3.862 1.92 ... 6.226 4.652 8.521
Attributes:
    created_at:                 2019-06-21T17:36:34.398087
    inference_library:          pymc3
    inference_library_version:  3.7
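The var_names and filter_vars options can also be combined to select variables by exclusion, substring, or regular expression. The following is a minimal sketch on a synthetic InferenceData; the variable names (mu, tau, theta, theta_tilde) and shapes are invented for illustration and are not part of the centered_eight dataset.

```python
import numpy as np
import arviz as az

# Hypothetical synthetic posterior with four variables.
gen = np.random.default_rng(0)
idata = az.from_dict(
    posterior={
        "mu": gen.normal(size=(4, 50)),
        "tau": gen.normal(size=(4, 50)),
        "theta": gen.normal(size=(4, 50, 8)),
        "theta_tilde": gen.normal(size=(4, 50, 8)),
    }
)

# Exclusion: the "~" prefix drops "tau" and keeps everything else.
no_tau = az.extract_dataset(idata, var_names="~tau")

# Substring match: every variable whose name contains "theta".
thetas = az.extract_dataset(idata, var_names="theta", filter_vars="like")

# Regular expression: every variable whose name starts with "t".
starts_with_t = az.extract_dataset(idata, var_names="^t", filter_vars="regex")

for ds in (no_tau, thetas, starts_with_t):
    print(sorted(ds.data_vars))
```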
To keep the chain and draw dimensions, use combined=False.
az.extract_dataset(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:    (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain      (chain) int64 0
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
  * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    tau        (chain, draw) float64 6.561 1.016 68.91 ... 1.56 5.949 0.7631
    tau_log__  (chain, draw) float64 1.881 0.01593 4.233 ... 1.783 -0.2704
    mu         (chain, draw) float64 5.293 0.8137 0.7122 ... -1.658 -3.273
    theta      (chain, draw, school) float64 2.357 7.371 7.251 ... -3.775 -3.555
    obs        (chain, draw, school) float64 -3.54 6.769 19.68 ... -21.16 -6.071
Attributes:
    created_at:                 2019-06-21T17:36:34.490387
    inference_library:          pymc3
    inference_library_version:  3.7
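The interaction between num_samples and rng can be sketched as follows. This assumes that an integer rng is used as a seed, so that repeated calls return the same subset, and that rng=False takes the leading samples without shuffling; the synthetic data (4 chains, 50 draws, one variable named mu) is hypothetical.

```python
import numpy as np
import arviz as az

# Hypothetical synthetic posterior: 4 chains x 50 draws of a scalar "mu".
gen = np.random.default_rng(0)
idata = az.from_dict(posterior={"mu": gen.normal(size=(4, 50))})

# Same integer seed twice: the two shuffled subsets should match.
sub_a = az.extract_dataset(idata, num_samples=20, rng=42)
sub_b = az.extract_dataset(idata, num_samples=20, rng=42)
print(np.array_equal(sub_a["mu"].values, sub_b["mu"].values))

# Shuffling disabled: the subset is taken in the stacked order, so with
# 50 draws per chain the 20 samples all come from the first chain.
head = az.extract_dataset(idata, num_samples=20, rng=False)
print(np.unique(head["chain"].values))
```

Passing a fixed seed is useful when a subsample feeds into a plot or summary that should be reproducible across runs.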