# Source code for arviz.sel_utils

"""Utilities for selecting and iterating on xarray objects."""

from itertools import product, tee

import numpy as np
import xarray as xr

from .labels import BaseLabeller

__all__ = ["xarray_sel_iter", "xarray_var_iter", "xarray_to_ndarray"]


def selection_to_string(selection):
    """Convert dictionary of coordinates to a string for labels.

    Only the coordinate *values* are included; the keys (dimension names) are
    dropped.

    Parameters
    ----------
    selection : dict[Any] -> Any

    Returns
    -------
    str
        value1, value2, ...  (in the dictionary's iteration order)
    """
    return ", ".join([f"{value}" for value in selection.values()])


def make_label(var_name, selection, position="below"):
    """Consistent labelling for plots.

    Parameters
    ----------
    var_name : str
        Name of the variable
    selection : dict[Any] -> Any
        Coordinates of the variable
    position : str
        Whether to position the coordinates' label "below" (default) or "beside"
        the name of the variable. Only consulted when ``selection`` is non-empty.

    Returns
    -------
    label
        A text representation of the label. It is just ``var_name`` when
        ``selection`` is empty. Otherwise the coordinate values follow the name,
        either on a new line ("below") or in square brackets ("beside").

    Raises
    ------
    ValueError
        If ``selection`` is non-empty and ``position`` is neither "below" nor
        "beside".
    """
    if not selection:
        # Without coordinates there is nothing to place, so position is irrelevant.
        return f"{var_name}"
    if position == "below":
        base = "{}\n{}"
    elif position == "beside":
        base = "{}[{}]"
    else:
        # Previously this fell through with ``base`` unbound (UnboundLocalError).
        raise ValueError(f'position must be "below" or "beside", got {position!r}')
    return base.format(var_name, selection_to_string(selection))


def _dims(data, var_name, skip_dims):
    """Return the dims of ``data[var_name]``, in order, excluding those in ``skip_dims``."""
    all_dims = data[var_name].dims
    return [name for name in all_dims if name not in skip_dims]


def _zip_dims(new_dims, vals):
    """Build one ``{dim: value}`` dict per element of the Cartesian product of ``vals``."""
    combos = product(*vals)
    return [dict(zip(new_dims, combo)) for combo in combos]


def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False):
    """Convert xarray data to an iterator over variable names and selections.

    Iterates over each var_name and all of its coordinates, returning the variable
    names and selections that allow properly obtaining the data from ``data`` as desired.

    Parameters
    ----------
    data : xarray.Dataset or xarray.DataArray
        Posterior data in an xarray. When ``var_names`` is None, a DataArray is
        treated as a single variable named ``data.name``.
    var_names : iterator of strings (optional)
        Should be a subset of data.data_vars. Defaults to all of them. Names not
        found in ``data`` are skipped.
    combined : bool
        Whether to combine chains or leave them separate
    skip_dims : set
        dimensions to not iterate over. "draw" is always skipped, and "chain"
        is also skipped when ``combined`` is True. The caller's collection is
        not modified.
    reverse_selections : bool
        Whether to reverse selections before iterating.

    Returns
    -------
    Iterator of (var_name: str, selection: dict(str, any), iselection: dict(str, int))
        The string is the variable name. ``selection`` maps coordinate names to
        values. ``iselection`` maps the same names to each value's position among
        the unique values of that coordinate. To get the values of the variable
        at these coordinates, do ``data[var_name].sel(**selection)``.
    """
    # Copy into a fresh set so the caller's ``skip_dims`` is never mutated
    # (the old code called ``.add`` on the argument itself when combined=False).
    skip_dims = set() if skip_dims is None else set(skip_dims)
    skip_dims.add("draw")
    if combined:
        skip_dims.add("chain")

    if var_names is None:
        if isinstance(data, xr.Dataset):
            var_names = list(data.data_vars)
        elif isinstance(data, xr.DataArray):
            var_names = [data.name]
            data = {data.name: data}

    for var_name in var_names:
        if var_name not in data:
            continue
        new_dims = _dims(data, var_name, skip_dims)
        # Unique coordinate values per dim, kept in first-occurrence order.
        vals = [list(dict.fromkeys(data[var_name][dim].values)) for dim in new_dims]
        dims = _zip_dims(new_dims, vals)
        # Same combinations as ``dims``, expressed as positions into ``vals``.
        idims = _zip_dims(new_dims, [range(len(v)) for v in vals])
        if reverse_selections:
            dims = reversed(dims)
            idims = reversed(idims)
        for selection, iselection in zip(dims, idims):
            yield var_name, selection, iselection


def xarray_var_iter(
    data, var_names=None, combined=False, skip_dims=None, reverse_selections=False, dim_order=None
):
    """Convert xarray data to an iterator over vectors.

    Iterates over each var_name and all of its coordinates, returning the 1d data.

    Parameters
    ----------
    data : xarray.Dataset or xarray.DataArray
        Posterior data in an xarray
    var_names : iterator of strings (optional)
        Should be a subset of data.data_vars. Defaults to all of them.
    combined : bool
        Whether to combine chains or leave them separate
    skip_dims : set
        dimensions to not iterate over
    reverse_selections : bool
        Whether to reverse selections before iterating.
    dim_order: list or str
        Order for the first dimensions. Skips dimensions not found in the variable.
        A single string is treated as a one-element list.

    Returns
    -------
    Iterator of (str, dict(str, any), dict(str, int), np.array)
        The string is the variable name. The first dictionary maps coordinate
        names to values. The second maps the same names to integer positions
        (as produced by ``xarray_sel_iter``). The array holds the values of the
        variable at those coordinates.
    """
    data_to_sel = data
    if var_names is None and isinstance(data, xr.DataArray):
        # Mirror xarray_sel_iter, which wraps a lone DataArray as {name: array}.
        data_to_sel = {data.name: data}

    if isinstance(dim_order, str):
        dim_order = [dim_order]

    for var_name, selection, iselection in xarray_sel_iter(
        data,
        var_names=var_names,
        combined=combined,
        skip_dims=skip_dims,
        reverse_selections=reverse_selections,
    ):
        selected_data = data_to_sel[var_name].sel(**selection)
        if dim_order is not None:
            dim_order_selected = [dim for dim in dim_order if dim in selected_data.dims]
            if dim_order_selected:
                # ``...`` keeps every remaining dim after the requested ones.
                selected_data = selected_data.transpose(*dim_order_selected, ...)
        yield var_name, selection, iselection, selected_data.values


def xarray_to_ndarray(data, *, var_names=None, combined=True, label_fun=None):
    """Take xarray data and unpack variables and data into a list and numpy array respectively.

    Assumes that chain and draw are in coordinates

    Parameters
    ----------
    data: xarray.DataSet
        Data in an xarray from an InferenceData object. Examples include posterior or sample_stats
    var_names: iter
        Should be a subset of data.data_vars not including chain and draws. Defaults to all of them
    combined: bool
        Whether to combine chain into one array
    label_fun: callable, optional
        Called as ``label_fun(var_name, selection, isel)`` to build each label.
        Defaults to ``BaseLabeller().make_label_vert``.

    Returns
    -------
    var_names: list
        List of labels, one per (variable, coordinate selection) pair
    data: np.array
        2D array with one row per label, holding that selection's values flattened.
        Its dtype is the common NumPy dtype of all selected variables. If there is
        nothing to select, an empty list and a ``(0, 0)`` array are returned.
    """
    if label_fun is None:
        label_fun = BaseLabeller().make_label_vert

    data_to_sel = data
    if var_names is None and isinstance(data, xr.DataArray):
        # Mirror xarray_sel_iter, which wraps a lone DataArray as {name: array}.
        data_to_sel = {data.name: data}

    # Materialize once: the selections are needed for both the labels and the values.
    vars_and_sel = list(xarray_sel_iter(data, var_names=var_names, combined=combined))
    unpacked_var_names = [
        label_fun(var_name, selection, isel) for var_name, selection, isel in vars_and_sel
    ]
    if not vars_and_sel:
        # Previously an IndexError on ``vars_and_sel[0]``.
        return unpacked_var_names, np.empty((0, 0))

    # Promote across all selected variables instead of taking the first one's dtype.
    # Otherwise a float variable following an int/bool one is silently truncated.
    # A single bool variable (e.g. divergences) still yields a bool array.
    dtype = np.result_type(
        *dict.fromkeys(data_to_sel[var_name].dtype for var_name, _, _ in vars_and_sel)
    )
    first_name, first_selection, _ = vars_and_sel[0]
    n_values = data_to_sel[first_name].sel(**first_selection).size
    unpacked_data = np.empty((len(vars_and_sel), n_values), dtype=dtype)
    for idx, (var_name, selection, _) in enumerate(vars_and_sel):
        unpacked_data[idx] = data_to_sel[var_name].sel(**selection).values.ravel()
    return unpacked_var_names, unpacked_data