"""Utilities for selecting and iterating on xarray objects."""
from itertools import product, tee
import numpy as np
import xarray as xr
from .labels import BaseLabeller
# Public names exported by ``from <module> import *``; the label helpers
# ``selection_to_string`` and ``make_label`` are defined here but not listed.
__all__ = ["xarray_sel_iter", "xarray_var_iter", "xarray_to_ndarray"]
def selection_to_string(selection):
    """Convert a dictionary of coordinates to a string for labels.

    Only the coordinate *values* are rendered; the coordinate names are dropped.

    Parameters
    ----------
    selection : dict[Any] -> Any
        Mapping of coordinate names to coordinate values.

    Returns
    -------
    str
        The values, each formatted as by an f-string, joined with ``", "`` in the
        mapping's iteration order, e.g. ``"value1, value2"``. An empty mapping
        gives ``""``.
    """
    return ", ".join(f"{value}" for value in selection.values())
def make_label(var_name, selection, position="below"):
    """Consistent labelling for plots.

    Parameters
    ----------
    var_name : str
        Name of the variable.
    selection : dict[Any] -> Any
        Coordinates of the variable. When empty (or otherwise falsy) only the
        variable name is returned and ``position`` is not checked.
    position : str
        Whether to position the coordinates' label "below" (default) or "beside"
        the name of the variable. "below" puts the coordinate values on a new
        line under the name; "beside" puts them in square brackets after it.

    Returns
    -------
    str
        A text representation of the label.

    Raises
    ------
    ValueError
        If ``selection`` is non-empty and ``position`` is neither "below" nor
        "beside" (this previously surfaced as an UnboundLocalError).
    """
    if not selection:
        # No coordinates to show: the label is just the (formatted) variable name.
        return f"{var_name}"
    if position == "below":
        template = "{}\n{}"
    elif position == "beside":
        template = "{}[{}]"
    else:
        raise ValueError(f"position must be 'below' or 'beside', got {position!r}")
    return template.format(var_name, selection_to_string(selection))
def _dims(data, var_name, skip_dims):
    """Return the dims of ``data[var_name]``, in their original order, minus ``skip_dims``."""
    all_dims = data[var_name].dims
    return [dim_name for dim_name in all_dims if dim_name not in skip_dims]
def _zip_dims(new_dims, vals):
    """Build one ``{dim: value}`` dict per element of the Cartesian product of ``vals``.

    ``new_dims`` and ``vals`` are paired positionally; combinations follow
    ``itertools.product`` order (last dimension varies fastest).
    """
    combinations = product(*vals)
    return [dict(zip(new_dims, combo)) for combo in combinations]
[docs]
def xarray_sel_iter(data, var_names=None, combined=False, skip_dims=None, reverse_selections=False):
    """Convert xarray data to an iterator over variable names and selections.

    Iterates over each var_name and all of its coordinates, returning the variable
    names and selections that allow to properly obtain the data from ``data`` as desired.

    Parameters
    ----------
    data : xarray.Dataset or xarray.DataArray
        Posterior data in an xarray. A DataArray only gets special handling when
        ``var_names`` is None, in which case it is keyed by its ``name``.
    var_names : iterator of strings (optional)
        Should be a subset of data.data_vars. Defaults to all of them.
        Names not found in ``data`` are skipped.
    combined : bool
        Whether to combine chains or leave them separate. When True both "chain"
        and "draw" are skipped; otherwise only "draw" is.
    skip_dims : set (or any iterable of dimension names)
        dimensions to not iterate over. The caller's object is never modified.
    reverse_selections : bool
        Whether to reverse selections before iterating.

    Yields
    ------
    tuple of (str, dict, dict)
        ``(var_name, selection, iselection)``: the variable name, a mapping of
        coordinate names to coordinate values, and a mapping of the same
        coordinate names to the index of each value among that dimension's
        unique coordinate values. To get the values of the variable at these
        coordinates, do ``data[var_name].sel(**selection)``.
    """
    # Work on a private copy: the previous ``skip_dims.add("draw")`` mutated the
    # caller's set in place (and failed on frozensets).
    skip_dims = set() if skip_dims is None else set(skip_dims)
    skip_dims.add("draw")
    if combined:
        skip_dims.add("chain")

    if var_names is None:
        if isinstance(data, xr.Dataset):
            var_names = list(data.data_vars)
        elif isinstance(data, xr.DataArray):
            var_names = [data.name]
            data = {data.name: data}

    for var_name in var_names:
        if var_name not in data:
            continue
        new_dims = _dims(data, var_name, skip_dims)
        # Unique coordinate values per dimension, in order of first appearance.
        vals = [list(dict.fromkeys(data[var_name][dim].values)) for dim in new_dims]
        dims = _zip_dims(new_dims, vals)
        idims = _zip_dims(new_dims, [range(len(v)) for v in vals])
        if reverse_selections:
            dims = reversed(dims)
            idims = reversed(idims)
        for selection, iselection in zip(dims, idims):
            yield var_name, selection, iselection
[docs]
def xarray_var_iter(
    data, var_names=None, combined=False, skip_dims=None, reverse_selections=False, dim_order=None
):
    """Convert xarray data to an iterator over vectors.

    For every (variable, coordinate-combination) pair produced by
    ``xarray_sel_iter``, select that slice of the variable and yield its values.

    Parameters
    ----------
    data : xarray.Dataset or xarray.DataArray
        Posterior data in an xarray
    var_names : iterator of strings (optional)
        Should be a subset of data.data_vars. Defaults to all of them.
    combined : bool
        Whether to combine chains or leave them separate
    skip_dims : set
        dimensions to not iterate over
    reverse_selections : bool
        Whether to reverse selections before iterating.
    dim_order : str or list of str, optional
        Dimensions to move to the front of each selected slice, in this order.
        Dimensions not present in the slice are ignored.

    Yields
    ------
    tuple of (str, dict, dict, np.ndarray)
        ``(var_name, selection, iselection, values)``: the variable name, the
        coordinate-value selection and matching index selection from
        ``xarray_sel_iter``, and the values of the variable at those coordinates.
    """
    # xarray_sel_iter keys a bare DataArray by its name when var_names is None,
    # so build the same lookup here to resolve the yielded names.
    if var_names is None and isinstance(data, xr.DataArray):
        lookup = {data.name: data}
    else:
        lookup = data

    if isinstance(dim_order, str):
        dim_order = [dim_order]

    selections = xarray_sel_iter(
        data,
        var_names=var_names,
        combined=combined,
        skip_dims=skip_dims,
        reverse_selections=reverse_selections,
    )
    for name, sel, isel in selections:
        subset = lookup[name].sel(**sel)
        if dim_order is not None:
            leading = [dim for dim in dim_order if dim in subset.dims]
            if leading:
                # ``...`` keeps every remaining dimension after the requested ones.
                subset = subset.transpose(*leading, ...)
        yield name, sel, isel, subset.values
[docs]
def xarray_to_ndarray(data, *, var_names=None, combined=True, label_fun=None):
    """Take xarray data and unpack it into a list of labels and a 2D numpy array.

    Assumes that chain and draw are in coordinates.

    Parameters
    ----------
    data: xarray.Dataset or xarray.DataArray
        Data in an xarray from an InferenceData object. Examples include posterior or sample_stats
    var_names: iter
        Should be a subset of data.data_vars not including chain and draws. Defaults to all of them
    combined: bool
        Whether to combine chain into one array
    label_fun: callable, optional
        Called as ``label_fun(var_name, selection, isel)`` to build each label.
        Defaults to ``BaseLabeller().make_label_vert``.

    Returns
    -------
    var_names: list
        One label per selection yielded by ``xarray_sel_iter``.
    data: np.array
        Array of shape ``(len(var_names), n)``. Row ``i`` holds the flattened
        values of the ``i``-th selection. ``n`` and the dtype are taken from the
        first selection, and all selections are assumed to have that size.
        When nothing is selected an empty ``(0, 0)`` array is returned.
    """
    if label_fun is None:
        label_fun = BaseLabeller().make_label_vert
    data_to_sel = data
    if var_names is None and isinstance(data, xr.DataArray):
        # Mirror xarray_sel_iter, which keys a bare DataArray by its name.
        data_to_sel = {data.name: data}

    # Materialize once and reuse the list for labels and values; a tee() here
    # only duplicated the buffering that list() already does.
    vars_and_sel = list(xarray_sel_iter(data, var_names=var_names, combined=combined))
    unpacked_var_names = [
        label_fun(var_name, selection, isel) for var_name, selection, isel in vars_and_sel
    ]
    if not vars_and_sel:
        # Previously crashed with IndexError on vars_and_sel[0].
        return unpacked_var_names, np.empty((0, 0))

    # Merge chains and variables, check dtype to be compatible with divergences data
    first_name, first_selection, _ = vars_and_sel[0]
    data0 = data_to_sel[first_name].sel(**first_selection)
    unpacked_data = np.empty((len(vars_and_sel), data0.size), dtype=data0.dtype)
    for idx, (var_name, selection, _) in enumerate(vars_and_sel):
        # ravel avoids the unconditional copy of flatten; the row assignment copies anyway.
        unpacked_data[idx] = data_to_sel[var_name].sel(**selection).values.ravel()
    return unpacked_var_names, unpacked_data