# Source code for arviz.data.utils

"""Data specific utilities."""

import warnings
import numpy as np

from ..utils import _var_names
from .converters import convert_to_dataset


def extract_dataset(
    data,
    group="posterior",
    combined=True,
    var_names=None,
    filter_vars=None,
    num_samples=None,
    rng=None,
):
    """Extract an InferenceData group or subset of it as a Dataset.

    .. deprecated:: 0.13
            `extract_dataset` will be removed in ArviZ 0.14, it is replaced by
            `extract` because the latter allows to obtain both DataSets and DataArrays.

    Thin deprecated wrapper: it emits a ``FutureWarning`` and forwards every
    argument to ``extract`` with ``keep_dataset=True``. As the name promises,
    a Dataset is returned even when a single variable is selected.
    See ``extract`` for the meaning of each parameter.

    Returns
    -------
    xarray.Dataset
    """
    warnings.warn(
        "extract_dataset has been deprecated, please use extract", FutureWarning, stacklevel=2
    )

    # keep_dataset=True: without it a single selected variable would come back
    # as a DataArray, breaking the "dataset" contract of this legacy function.
    data = extract(
        data=data,
        group=group,
        combined=combined,
        var_names=var_names,
        filter_vars=filter_vars,
        num_samples=num_samples,
        keep_dataset=True,
        rng=rng,
    )
    return data


def extract(
    data,
    group="posterior",
    combined=True,
    var_names=None,
    filter_vars=None,
    num_samples=None,
    keep_dataset=False,
    rng=None,
):
    """Extract an InferenceData group or subset of it.

    Parameters
    ----------
    data : InferenceData or InferenceData_like
        InferenceData from which to extract the data.
    group : str, optional
        Which InferenceData data group to extract data from.
    combined : bool, optional
        Combine ``chain`` and ``draw`` dimensions into ``sample``. Won't work if a dimension
        named ``sample`` already exists.
    var_names : str or list of str, optional
        Variables to be extracted. Prefix the variables by `~` when you want to exclude them.
    filter_vars: {None, "like", "regex"}, optional
        If `None` (default), interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
        A la `pandas.filter`.
        Like with plotting, sometimes it's easier to subset saying what to exclude
        instead of what to include.
    num_samples : int, optional
        Extract only a subset of the samples. Only valid if ``combined=True``.
    keep_dataset : bool, optional
        If true, always return a DataSet. If false (default) return a DataArray
        when there is a single variable.
    rng : bool, int, numpy.Generator, optional
        Shuffle the samples, only valid if ``combined=True``. By default,
        samples are shuffled if ``num_samples`` is not ``None``, and are left
        in the same order otherwise. This ensures that subsetting the samples
        doesn't return only samples from a single chain and consecutive draws.

    Returns
    -------
    xarray.DataArray or xarray.Dataset

    Raises
    ------
    ValueError
        If ``combined=False`` is combined with ``num_samples`` or with any
        ``rng`` other than ``None``/``False``.
    TypeError
        If ``rng`` cannot be used to initialize a numpy random Generator.
    AttributeError
        If the generator built from ``rng`` cannot generate a permutation.

    Examples
    --------
    The default behaviour is to return the posterior group after stacking the chain and
    draw dimensions.

    .. jupyter-execute::

        import arviz as az
        idata = az.load_arviz_data("centered_eight")
        az.extract(idata)

    You can also indicate a subset to be returned, both in variables and in samples:

    .. jupyter-execute::

        az.extract(idata, var_names="theta", num_samples=100)

    To keep the chain and draw dimensions, use ``combined=False``.

    .. jupyter-execute::

        az.extract(idata, group="prior", combined=False)
    """
    if num_samples is not None and not combined:
        raise ValueError("num_samples is only compatible with combined=True")
    # By default, shuffle only when taking a subset of the samples.
    if rng is None:
        rng = num_samples is not None
    if rng is not False and not combined:
        raise ValueError("rng is only compatible with combined=True")

    data = convert_to_dataset(data, group=group)

    var_names = _var_names(var_names, data, filter_vars)
    if var_names is not None:
        # Indexing with a single name (not a list) returns a DataArray.
        if len(var_names) == 1 and not keep_dataset:
            var_names = var_names[0]
        data = data[var_names]

    if combined:
        data = data.stack(sample=("chain", "draw"))

    # 0 is a valid seed, so we need to check for rng being exactly boolean False.
    if rng is not False:
        # ``True`` means "shuffle with a fresh, unseeded generator"; any other
        # value is passed to default_rng, which takes ints, sequences of ints,
        # or an existing Generator (returned unaltered).
        seed = None if rng is True else rng
        try:
            generator = np.random.default_rng(seed)
            random_subset = generator.permutation(np.arange(len(data["sample"])))
        except TypeError as err:
            raise TypeError("Unable to initializate numpy random Generator from rng") from err
        except AttributeError as err:
            raise AttributeError("Unable to use rng to generate a permutation") from err
        data = data.isel(sample=random_subset)

    if num_samples is not None:
        data = data.isel(sample=slice(None, num_samples))
    return data