"""Low level converters usually used by other functions."""
import datetime
import functools
import importlib.metadata
import re
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import numpy as np
import xarray as xr
try:
import tree
except ImportError:
tree = None
try:
import ujson as json
except ImportError:
# mypy struggles with conditional imports expressed as catching ImportError:
# https://github.com/python/mypy/issues/1153
import json # type: ignore
from .. import __version__, utils
from ..rcparams import rcParams
CoordSpec = Dict[str, List[Any]]
DimSpec = Dict[str, List[str]]
RequiresArgTypeT = TypeVar("RequiresArgTypeT")
RequiresReturnTypeT = TypeVar("RequiresReturnTypeT")
class requires: # pylint: disable=invalid-name
"""Decorator to return None if an object does not have the required attribute.
If the decorator is applied several times to the same function with different
attributes, it will return None if any one of them is missing. If instead a list
of attributes is passed, it will return None if all attributes in the list are
missing. Both functionalities can be combined as desired.
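Examples
--------
A minimal sketch of typical usage (the ``Converter`` class here is hypothetical)::

    class Converter:
        posterior = None

        @requires("posterior")
        def posterior_to_xarray(self):
            return self.posterior

``Converter().posterior_to_xarray()`` then returns ``None`` because the required
``posterior`` attribute is ``None``; the body of ``posterior_to_xarray`` is never executed.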
"""
def __init__(self, *props: Union[str, List[str]]) -> None:
self.props: Tuple[Union[str, List[str]], ...] = props
# Until typing.ParamSpec (https://www.python.org/dev/peps/pep-0612/) is available
# in all our supported Python versions, there is no way to simultaneously express
# the following two properties:
# - the input function may take arbitrary args/kwargs, and
# - the output function takes those same arbitrary args/kwargs, but has a different return type.
# We either have to limit the input function to e.g. only allowing a "self" argument,
# or we have to adopt the current approach of annotating the returned function as if
# it was defined as "def f(*args: Any, **kwargs: Any) -> Optional[RequiresReturnTypeT]".
#
# Since all functions decorated with @requires currently only accept a single argument,
# we choose to limit application of @requires to only functions of one argument.
# When typing.ParamSpec is available, this definition can be updated to use it.
# See https://github.com/arviz-devs/arviz/pull/1504 for more discussion.
def __call__(
self, func: Callable[[RequiresArgTypeT], RequiresReturnTypeT]
) -> Callable[[RequiresArgTypeT], Optional[RequiresReturnTypeT]]: # noqa: D202
"""Wrap the decorated function."""
def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
"""Return None if not all props are available."""
for prop in self.props:
prop = [prop] if isinstance(prop, str) else prop
if all((getattr(cls, prop_i) is None for prop_i in prop)):
return None
return func(cls)
return wrapped
def _yield_flat_up_to(shallow_tree, input_tree, path=()):
"""Yields (path, value) pairs of input_tree flattened up to shallow_tree.
Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
lists as leaves.
Args:
shallow_tree: Nested structure. Traverse no further than its leaf nodes.
input_tree: Nested structure. Return the paths and values from this tree.
Must have the same upper structure as shallow_tree.
path: Tuple. Optional argument, only used when recursing. The path from the
root of the original shallow_tree, down to the root of the shallow_tree
arg of this recursive call.
Yields:
Pairs of (path, value), where path is the tuple path of a leaf node in
shallow_tree, and value is the value of the corresponding node in
input_tree.
"""
# pylint: disable=protected-access
if tree is None:
raise ImportError("Missing optional dependency 'dm-tree'. Use pip or conda to install it")
if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
isinstance(shallow_tree, tree.collections_abc.Mapping)
or tree._is_namedtuple(shallow_tree)
or tree._is_attrs(shallow_tree)
):
yield (path, input_tree)
else:
input_tree = dict(tree._yield_sorted_items(input_tree))
for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree):
subpath = path + (shallow_key,)
input_subtree = input_tree[shallow_key]
for leaf_path, leaf_value in _yield_flat_up_to(
shallow_subtree, input_subtree, path=subpath
):
yield (leaf_path, leaf_value)
# pylint: enable=protected-access
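# For example (illustrative, requires dm-tree): _flatten_with_path({"a": {"b": 1}, "c": 2})
# returns [(("a", "b"), 1), (("c",), 2)], i.e. one (path, leaf) pair per leaf of the pytree.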
def _flatten_with_path(structure):
return list(_yield_flat_up_to(structure, structure))
def generate_dims_coords(
shape,
var_name,
dims=None,
coords=None,
default_dims=None,
index_origin=None,
skip_event_dims=None,
):
"""Generate default dimensions and coordinates for a variable.
Parameters
----------
shape : tuple[int]
Shape of the variable
var_name : str
Name of the variable. If no dimension name(s) is provided, ArviZ
will generate a default dimension name using ``var_name``, e.g.,
``"foo_dim_0"`` for the first dimension if ``var_name`` is ``"foo"``.
dims : list
List of dimensions for the variable
coords : dict[str] -> list[str]
Map of dimensions to coordinates
default_dims : list[str]
Dimension names that are not part of the variable's shape. For example,
when manipulating Monte Carlo traces, the ``default_dims`` would be
``["chain", "draw"]`` which ArviZ uses as its own names for dimensions
of MCMC traces.
index_origin : int, optional
Starting value of integer coordinate values. Defaults to the value in rcParam
``data.index_origin``.
skip_event_dims : bool, default False
If True, cut extra dims whenever present to match the shape of the data.
Returns
-------
list[str]
Default dims
dict[str] -> list[str]
Default coords
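Examples
--------
A minimal sketch of the default behaviour (illustrative), assuming the
``data.index_origin`` rcParam is left at its default of 0::

    dims, coords = generate_dims_coords((2, 3), "x")
    # dims   -> ["x_dim_0", "x_dim_1"]
    # coords -> {"x_dim_0": array([0, 1]), "x_dim_1": array([0, 1, 2])}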
"""
if index_origin is None:
index_origin = rcParams["data.index_origin"]
if default_dims is None:
default_dims = []
if dims is None:
dims = []
if skip_event_dims is None:
skip_event_dims = False
if coords is None:
coords = {}
coords = deepcopy(coords)
dims = deepcopy(dims)
ndims = len([dim for dim in dims if dim not in default_dims])
if ndims > len(shape):
if skip_event_dims:
dims = dims[: len(shape)]
else:
warnings.warn(
(
"In variable {var_name}, there are "
+ "more dims ({dims_len}) given than exist ({shape_len}). "
+ "Passed array should have shape ({defaults}*shape)"
).format(
var_name=var_name,
dims_len=len(dims),
shape_len=len(shape),
defaults=",".join(default_dims) + ", " if default_dims is not None else "",
),
UserWarning,
)
if skip_event_dims:
# this is needed in case the reduction keeps the dimension with size 1
for i, (dim, dim_size) in enumerate(zip(dims, shape)):
if (dim in coords) and (dim_size != len(coords[dim])):
dims = dims[:i]
break
for i, dim_len in enumerate(shape):
idx = i + len([dim for dim in default_dims if dim in dims])
if len(dims) < idx + 1:
dim_name = f"{var_name}_dim_{idx}"
dims.append(dim_name)
elif dims[idx] is None:
dim_name = f"{var_name}_dim_{idx}"
dims[idx] = dim_name
dim_name = dims[idx]
if dim_name not in coords:
coords[dim_name] = np.arange(index_origin, dim_len + index_origin)
coords = {
key: coord
for key, coord in coords.items()
if any(key == dim for dim in dims + default_dims)
}
return dims, coords
def numpy_to_data_array(
ary,
*,
var_name="data",
coords=None,
dims=None,
default_dims=None,
index_origin=None,
skip_event_dims=None,
):
"""Convert a numpy array to an xarray.DataArray.
By default, the first two dimensions will be (chain, draw), and any remaining
dimensions will be "shape".
* If the numpy array is 1D, that dimension is interpreted as draw
* If it is 2D, it is interpreted as (chain, draw)
* If it has 3 or more dimensions, the first two are interpreted as (chain, draw) and the
  remaining dimensions are kept as the variable's shape.
To modify this behaviour, use ``default_dims``.
Parameters
----------
ary : np.ndarray
A numpy array. If it has 2 or more dimensions, the first dimension should be
independent chains from a simulation. Use `np.expand_dims(ary, 0)` to add a
single dimension to the front if there is only 1 chain.
var_name : str
If there are no dims passed, this string is used to name dimensions
coords : dict[str, iterable]
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
dims : list of str
A list of dimension names for the variable
default_dims : list of str, optional
Passed to :py:func:`generate_dims_coords`. Defaults to ``["chain", "draw"]``, and
an empty list is accepted
index_origin : int, optional
Passed to :py:func:`generate_dims_coords`
skip_event_dims : bool, optional
Passed to :py:func:`generate_dims_coords`
Returns
-------
xr.DataArray
Will have the same data as passed, but with coordinates and dimensions
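Examples
--------
A minimal sketch (illustrative)::

    ary = np.random.randn(4, 100, 3)  # (chain, draw, *shape)
    da = numpy_to_data_array(ary, var_name="x")
    # da.dims -> ("chain", "draw", "x_dim_0")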
"""
# manage and transform copies
if default_dims is None:
default_dims = ["chain", "draw"]
if "chain" in default_dims and "draw" in default_dims:
ary = utils.two_de(ary)
n_chains, n_samples, *_ = ary.shape
if n_chains > n_samples:
warnings.warn(
"More chains ({n_chains}) than draws ({n_samples}). "
"Passed array should have shape (chains, draws, *shape)".format(
n_chains=n_chains, n_samples=n_samples
),
UserWarning,
)
else:
ary = utils.one_de(ary)
dims, coords = generate_dims_coords(
ary.shape[len(default_dims) :],
var_name,
dims=dims,
coords=coords,
default_dims=default_dims,
index_origin=index_origin,
skip_event_dims=skip_event_dims,
)
# reversed order for default dims: 'chain', 'draw'
if "draw" not in dims and "draw" in default_dims:
dims = ["draw"] + dims
if "chain" not in dims and "chain" in default_dims:
dims = ["chain"] + dims
index_origin = rcParams["data.index_origin"]
if "chain" not in coords and "chain" in default_dims:
coords["chain"] = np.arange(index_origin, n_chains + index_origin)
if "draw" not in coords and "draw" in default_dims:
coords["draw"] = np.arange(index_origin, n_samples + index_origin)
# filter coords based on the dims
coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
return xr.DataArray(ary, coords=coords, dims=dims)
def dict_to_dataset(
data,
*,
attrs=None,
library=None,
coords=None,
dims=None,
default_dims=None,
index_origin=None,
skip_event_dims=None,
):
"""Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
ArviZ itself supports conversion of flat dictionaries.
Support for pytrees requires ``dm-tree``, which is an optional dependency.
See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
this includes at least dictionaries and tuple types.
Parameters
----------
data : dict of {str : array_like or dict} or pytree
Data to convert. Keys are variable names.
attrs : dict, optional
Json serializable metadata to attach to the dataset, in addition to defaults.
library : module, optional
Library used for performing inference. Will be attached to the attrs metadata.
coords : dict of {str : ndarray}, optional
Coordinates for the dataset
dims : dict of {str : list of str}, optional
Dimensions of each variable. The keys are variable names, values are lists of
coordinates.
default_dims : list of str, optional
Passed to :py:func:`numpy_to_data_array`
index_origin : int, optional
Passed to :py:func:`numpy_to_data_array`
skip_event_dims : bool, optional
If True, cut extra dims whenever present to match the shape of the data.
Necessary for PPLs which have the same name in both observed data and log
likelihood groups, to account for their different shapes when observations are
multivariate.
Returns
-------
xarray.Dataset
In case of nested pytrees, the variable name will be a tuple of individual names.
Notes
-----
This function is available under two aliases: ``dict_to_dataset`` and ``pytree_to_dataset``.
Examples
--------
Convert a dictionary with two 2D variables to a Dataset.
.. ipython::
In [1]: import arviz as az
...: import numpy as np
...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra
information to the generated Dataset such as default dimension names for sampled
dimensions and some attributes.
The function is also general enough to work on pytrees such as nested dictionaries:
.. ipython::
In [1]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
which has two variables (as many as leaves) named ``('top', 'second')`` and ``top2``.
Dimensions and coordinates can be defined as usual:
.. ipython::
In [1]: datadict = {
...: "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
...: "d": np.random.randn(100),
...: }
...: az.dict_to_dataset(
...: datadict,
...: coords={"c": np.arange(10)},
...: dims={("top", "b"): ["c"]}
...: )
"""
if dims is None:
dims = {}
if tree is not None:
try:
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
except TypeError: # probably unsortable keys -- the function will still work if
pass # it is an honest dictionary.
data_vars = {
key: numpy_to_data_array(
values,
var_name=key,
coords=coords,
dims=dims.get(key),
default_dims=default_dims,
index_origin=index_origin,
skip_event_dims=skip_event_dims,
)
for key, values in data.items()
}
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
pytree_to_dataset = dict_to_dataset
def make_attrs(attrs=None, library=None):
"""Make standard attributes to attach to xarray datasets.
Parameters
----------
attrs : dict, optional
Additional attributes to add or overwrite
library : module, optional
Library used for performing inference. Will be attached to the attrs metadata.
Returns
-------
dict
attrs
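Examples
--------
A minimal sketch (illustrative; the ``model`` attribute is hypothetical)::

    import numpy as np
    make_attrs(attrs={"model": "centered_eight"}, library=np)
    # -> {"created_at": ..., "arviz_version": ..., "inference_library": "numpy",
    #     "inference_library_version": ..., "model": "centered_eight"}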
"""
default_attrs = {
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"arviz_version": __version__,
}
if library is not None:
library_name = library.__name__
default_attrs["inference_library"] = library_name
try:
version = importlib.metadata.version(library_name)
default_attrs["inference_library_version"] = version
except importlib.metadata.PackageNotFoundError:
if hasattr(library, "__version__"):
version = library.__version__
default_attrs["inference_library_version"] = version
if attrs is not None:
default_attrs.update(attrs)
return default_attrs
def _extend_xr_method(func, doc=None, description="", examples="", see_also=""):
"""Make wrapper to extend methods from xr.Dataset to InferenceData Class.
Parameters
----------
func : callable
An xr.Dataset function
doc : str
docstring for the func
description : str
the description of the func to be added in docstring
examples : str
the examples of the func to be added in docstring
see_also : str, list
the similar methods of func to be included in See Also section of docstring
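Examples
--------
A hypothetical sketch of how an xarray method could be attached to ``InferenceData``
(the attribute name and the call below are illustrative, not the actual ArviZ API)::

    mean = _extend_xr_method(xr.Dataset.mean)
    # idata.mean(dim="draw") then applies xr.Dataset.mean to each selected group,
    # honouring the extra ``groups``, ``filter_groups`` and ``inplace`` keywords.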
"""
# pydocstyle requires a non-empty line
@functools.wraps(func)
def wrapped(self, *args, **kwargs):
_filter = kwargs.pop("filter_groups", None)
_groups = kwargs.pop("groups", None)
_inplace = kwargs.pop("inplace", False)
out = self if _inplace else deepcopy(self)
groups = self._group_names(_groups, _filter) # pylint: disable=protected-access
for group in groups:
xr_data = getattr(out, group)
xr_data = func(xr_data, *args, **kwargs) # pylint: disable=not-callable
setattr(out, group, xr_data)
return None if _inplace else out
description_default = """{method_name} method is extended from xarray.Dataset methods.
{description}
For more info see :meth:`xarray:xarray.Dataset.{method_name}`.
In addition to the arguments available in the original method, the following
ones are added by ArviZ to adapt the method to being called on an ``InferenceData`` object.
""".format(
description=description, method_name=func.__name__ # pylint: disable=no-member
)
params = """
Other Parameters
----------------
groups: str or list of str, optional
Groups where the selection is to be applied. Can either be group names
or metagroup names.
filter_groups: {None, "like", "regex"}, optional, default=None
If `None` (default), interpret groups as the real group or metagroup names.
If "like", interpret groups as substrings of the real group or metagroup names.
If "regex", interpret groups as regular expressions on the real group or
metagroup names. A la `pandas.filter`.
inplace: bool, optional
If ``True``, modify the InferenceData object inplace,
otherwise, return the modified copy.
"""
if not isinstance(see_also, str):
see_also = "\n".join(see_also)
see_also_basic = """
See Also
--------
xarray.Dataset.{method_name}
{custom_see_also}
""".format(
method_name=func.__name__, custom_see_also=see_also # pylint: disable=no-member
)
wrapped.__doc__ = (
description_default + params + examples + see_also_basic if doc is None else doc
)
return wrapped
def _make_json_serializable(data: dict) -> dict:
"""Convert `data` with numpy.ndarray-like values to JSON-serializable form."""
ret = {}
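# For example (illustrative): given {"mu": np.zeros(2), "n": 3}, the loop below keeps
# "n" as-is and converts the ndarray via .tolist(), returning {"mu": [0.0, 0.0], "n": 3}.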
for key, value in data.items():
try:
json.dumps(value)
except (TypeError, OverflowError):
pass
else:
ret[key] = value
continue
if isinstance(value, dict):
ret[key] = _make_json_serializable(value)
elif isinstance(value, np.ndarray):
ret[key] = np.asarray(value).tolist()
else:
raise TypeError(
f"Value of type {type(value)} associated with variable `{key}` is not JSON serializable."
)
return ret
def infer_stan_dtypes(stan_code):
"""Infer Stan integer variables from generated quantities block."""
# Remove deprecated "#" comments (old Stan comment syntax)
stan_code = "\n".join(
line if "#" not in line else line[: line.find("#")] for line in stan_code.splitlines()
)
pattern_remove_comments = re.compile(
r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE
)
stan_code = re.sub(pattern_remove_comments, "", stan_code)
# Check generated quantities
if "generated quantities" not in stan_code:
return {}
# Extract generated quantities block
gen_quantities_location = stan_code.index("generated quantities")
block_start = gen_quantities_location + stan_code[gen_quantities_location:].index("{")
curly_bracket_count = 0
block_end = None
for block_end, char in enumerate(stan_code[block_start:], block_start + 1):
if char == "{":
curly_bracket_count += 1
elif char == "}":
curly_bracket_count -= 1
if curly_bracket_count == 0:
break
stan_code = stan_code[block_start:block_end]
stan_integer = r"int"
stan_limits = r"(?:\<[^\>]+\>)*" # ignore group: 0 or more <....>
stan_param = r"([^;=\s\[]+)" # capture group: ends= ";", "=", "[" or whitespace
stan_ws = r"\s*" # 0 or more whitespace
stan_ws_one = r"\s+" # 1 or more whitespace
pattern_int = re.compile(
"".join((stan_integer, stan_ws_one, stan_limits, stan_ws, stan_param)), re.IGNORECASE
)
dtypes = {key.strip(): "int" for key in re.findall(pattern_int, stan_code)}
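# For example (illustrative): a generated quantities block declaring "int y_rep[10];"
# produces dtypes == {"y_rep": "int"}.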
return dtypes