# Source code for arviz.data.io_emcee

"""emcee-specific conversion code."""

import warnings
from collections import OrderedDict

import numpy as np
import xarray as xr

from .. import utils
from .base import dict_to_dataset, generate_dims_coords, make_attrs
from .inference_data import InferenceData


def _verify_names(sampler, var_names, arg_names, slices):
    """Make sure var_names and arg_names are assigned reasonably.

    This is meant to run before loading emcee objects into InferenceData.
    In case var_names or arg_names is None, will provide defaults. If they are
    not None, it verifies there are the right number of them.

    Parameters
    ----------
    sampler : emcee.EnsembleSampler
        Fitted emcee sampler
    var_names : list[str] or None
        Names for the emcee parameters
    arg_names : list[str] or None
        Names for the args/observations provided to emcee
    slices : list[seq] or None
        slices to select the variables (used for multidimensional variables)

    Returns
    -------
    list[str], list[str], list[seq]
        Defaults for var_names, arg_names and slices

    Raises
    ------
    ValueError
        If the number of ``var_names`` differs from the number of variables
        (``len(slices)``, or the chain's last dimension when ``slices`` is None),
        or the number of ``arg_names`` differs from the number of sampler args.

    Warns
    -----
    UserWarning
        If ``slices`` leave some chain parameters uncaptured, or if they overlap.
    """
    # There are 3 possible cases: emcee2, emcee3 and sampler read from h5 file (emcee3 only)
    if hasattr(sampler, "args"):
        # emcee2: the chain and the args are plain attributes of the sampler
        ndim = sampler.chain.shape[-1]
        num_args = len(sampler.args)
    elif hasattr(sampler, "log_prob_fn"):
        # emcee3 sampler: args are stored on the wrapped log-probability function
        ndim = sampler.get_chain().shape[-1]
        num_args = len(sampler.log_prob_fn.args)
    else:
        ndim = sampler.get_chain().shape[-1]
        num_args = 0  # emcee only stores the posterior samples

    if slices is None:
        # one scalar variable per parameter column
        slices = utils.arange(ndim)
        num_vars = ndim
    else:
        num_vars = len(slices)

    # Flatten every slice into the parameter indices it selects, so that gaps and
    # overlaps between user-provided slices can be detected.
    indices = utils.arange(ndim)
    slicing_try = np.concatenate([utils.one_de(indices[idx]) for idx in slices])
    # Count unique indices: duplicated ones (overlaps) must not hide missing ones.
    num_captured = len(set(slicing_try))
    if num_captured != ndim:
        warnings.warn(
            "Check slices: Not all parameters in chain captured. "
            f"{ndim} are present, and {num_captured} have been captured.",
            UserWarning,
        )
    if len(slicing_try) != num_captured:
        warnings.warn(f"Overlapping slices. Check the index present: {slicing_try}", UserWarning)

    if var_names is None:
        var_names = [f"var_{idx}" for idx in range(num_vars)]
    if arg_names is None:
        arg_names = [f"arg_{idx}" for idx in range(num_args)]

    # The messages must not say "only": the mismatch can be in either direction.
    if len(var_names) != num_vars:
        raise ValueError(
            f"The sampler has {num_vars} variables, "
            f"but {len(var_names)} var_names were provided!"
        )

    if len(arg_names) != num_args:
        raise ValueError(
            f"The sampler has {num_args} args, "
            f"but {len(arg_names)} arg_names were provided!"
        )
    return var_names, arg_names, slices


# pylint: disable=too-many-instance-attributes
class EmceeConverter:
    """Encapsulate emcee specific logic.

    Builds the ``posterior``, ``observed_data``/``constant_data`` and blob-derived
    groups (plus ``sample_stats.lp``) of an :class:`InferenceData` object from a
    fitted emcee sampler. emcee2 samplers, emcee3 samplers and emcee3 objects that
    only expose ``get_chain`` (e.g. read from an h5 file) are distinguished by
    attribute checks.
    """

    def __init__(
        self,
        sampler,
        var_names=None,
        slices=None,
        arg_names=None,
        arg_groups=None,
        blob_names=None,
        blob_groups=None,
        index_origin=None,
        coords=None,
        dims=None,
    ):
        # fills in default names/slices and validates user-provided ones
        var_names, arg_names, slices = _verify_names(sampler, var_names, arg_names, slices)
        self.sampler = sampler
        self.var_names = var_names
        self.slices = slices
        self.arg_names = arg_names
        self.arg_groups = arg_groups
        self.blob_names = blob_names
        self.blob_groups = blob_groups
        self.index_origin = index_origin
        self.coords = coords
        self.dims = dims
        # imported here rather than at module level, presumably so emcee is only
        # required when this converter is actually used
        import emcee

        self.emcee = emcee

    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset.

        Each entry of ``self.slices`` selects the columns of the last (parameter)
        axis of the chain that make up the matching entry of ``self.var_names``.
        """
        # Use emcee3 syntax, else use emcee2
        if hasattr(self.sampler, "get_chain"):
            # swap the first two axes so the layout matches emcee2's ``chain``
            samples_ary = self.sampler.get_chain().swapaxes(0, 1)
        else:
            samples_ary = self.sampler.chain

        data = {
            var_name: (samples_ary[(..., idx)])
            for idx, var_name in zip(self.slices, self.var_names)
        }
        return dict_to_dataset(
            data,
            library=self.emcee,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )

    def args_to_xarray(self):
        """Convert emcee args to observed and constant_data xarray Datasets.

        Returns
        -------
        dict
            ``{group_name: xr.Dataset}`` for every group named in ``arg_groups``.

        Raises
        ------
        ValueError
            If ``arg_names`` and ``arg_groups`` differ in length.
        SyntaxError
            If a group other than ``observed_data``/``constant_data`` is requested.
        """
        dims = {} if self.dims is None else self.dims
        if self.arg_groups is None:
            self.arg_groups = ["observed_data" for _ in self.arg_names]
        if len(self.arg_names) != len(self.arg_groups):
            raise ValueError(
                "arg_names and arg_groups must have the same length, or arg_groups be None"
            )
        arg_groups_set = set(self.arg_groups)
        bad_groups = [
            group for group in arg_groups_set if group not in ("observed_data", "constant_data")
        ]
        if bad_groups:
            raise SyntaxError(
                "all arg_groups values should be either 'observed_data' or 'constant_data' , "
                f"not {bad_groups}"
            )
        obs_const_dict = {group: OrderedDict() for group in arg_groups_set}
        for idx, (arg_name, group) in enumerate(zip(self.arg_names, self.arg_groups)):
            # Use emcee3 syntax, else use emcee2. The lookup stays inside the loop so
            # that samplers without args (e.g. read from file) never touch ``.args``.
            arg_array = np.atleast_1d(
                self.sampler.log_prob_fn.args[idx]
                if hasattr(self.sampler, "log_prob_fn")
                else self.sampler.args[idx]
            )
            arg_dims = dims.get(arg_name)
            arg_dims, coords = generate_dims_coords(
                arg_array.shape,
                arg_name,
                dims=arg_dims,
                coords=self.coords,
                index_origin=self.index_origin,
            )
            # filter coords based on the dims
            coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in arg_dims}
            obs_const_dict[group][arg_name] = xr.DataArray(arg_array, dims=arg_dims, coords=coords)
        for key, values in obs_const_dict.items():
            obs_const_dict[key] = xr.Dataset(data_vars=values, attrs=make_attrs(library=self.emcee))
        return obs_const_dict

    def blobs_to_dict(self):
        """Convert blobs to dictionary {groupname: xr.Dataset}.

        It also stores lp values in sample_stats group. Blobs are only read when
        ``blob_names`` is a non-empty list; ``None`` or ``[]`` skip them.

        Raises
        ------
        ValueError
            If ``blob_names`` and ``blob_groups`` differ in length, if blobs were
            requested but the sampler has none, or if the number of names does not
            match the number of blobs (when more than one name is given).
        SyntaxError
            If a blob group would overwrite ``posterior``, ``observed_data`` or
            ``constant_data``.
        """
        # An empty list requests no blobs, exactly like None. Treating it as "store
        # blobs" used to crash on ``blobs.shape`` when the sampler had no blobs.
        store_blobs = bool(self.blob_names)
        self.blob_names = [] if self.blob_names is None else self.blob_names
        if self.blob_groups is None:
            self.blob_groups = ["log_likelihood" for _ in self.blob_names]
        if len(self.blob_names) != len(self.blob_groups):
            raise ValueError(
                "blob_names and blob_groups must have the same length, or blob_groups be None"
            )
        if store_blobs:
            # Detect emcee3 by attribute, as done for the chain and args elsewhere,
            # instead of parsing the first character of the version string (which
            # would misread multi-digit major versions).
            if hasattr(self.sampler, "get_blobs"):
                blobs = self.sampler.get_blobs()
            else:
                blobs = np.array(self.sampler.blobs, dtype=object)
            if blobs is None or blobs.size == 0:
                raise ValueError("No blobs in sampler, blob_names must be None")
            if blobs.ndim == 2:
                # one scalar blob per (draw, walker): add a trailing blob axis
                blobs = np.expand_dims(blobs, axis=-1)
            # move the blob axis first: (nblobs, nwalkers, ndraws, ...)
            blobs = blobs.swapaxes(0, 2)
            nblobs, nwalkers, ndraws, *_ = blobs.shape
            if len(self.blob_names) != nblobs and len(self.blob_names) > 1:
                raise ValueError(
                    "Incorrect number of blob names. "
                    f"Expected {nblobs}, found {len(self.blob_names)}"
                )
        blob_groups_set = set(self.blob_groups)
        blob_groups_set.add("sample_stats")
        idata_groups = ("posterior", "observed_data", "constant_data")
        if blob_groups_set.intersection(idata_groups):
            raise SyntaxError(
                f"{idata_groups} groups should not come from blobs. "
                "Using them here would overwrite their actual values"
            )
        blob_dict = {group: OrderedDict() for group in blob_groups_set}
        if len(self.blob_names) == 1:
            # a single name receives every blob, as (nwalkers, ndraws, nblobs, ...)
            blob_dict[self.blob_groups[0]][self.blob_names[0]] = blobs.swapaxes(0, 2).swapaxes(0, 1)
        else:
            # Only reached with blobs loaded when blob_names has 2+ entries; with no
            # names the loop body never runs.
            for i_blob, (name, group) in enumerate(zip(self.blob_names, self.blob_groups)):
                # for coherent blobs (all having the same dimensions) one line is enough
                blob = blobs[i_blob]
                # for blobs of different size, we get an array of arrays, which we convert
                # to an ndarray per blob_name
                if blob.dtype == object:
                    blob = blob.reshape(-1)
                    blob = np.stack(blob)
                    blob = blob.reshape((nwalkers, ndraws, -1))
                blob_dict[group][name] = np.squeeze(blob)

        # store lp in sample_stats group
        blob_dict["sample_stats"]["lp"] = (
            self.sampler.get_log_prob().swapaxes(0, 1)
            if hasattr(self.sampler, "get_log_prob")
            else self.sampler.lnprobability
        )
        for key, values in blob_dict.items():
            blob_dict[key] = dict_to_dataset(
                values,
                library=self.emcee,
                coords=self.coords,
                dims=self.dims,
                index_origin=self.index_origin,
            )
        return blob_dict

    def to_inference_data(self):
        """Convert all available data to an InferenceData object."""
        blobs_dict = self.blobs_to_dict()
        obs_const_dict = self.args_to_xarray()
        return InferenceData(
            **{"posterior": self.posterior_to_xarray(), **obs_const_dict, **blobs_dict}
        )


def from_emcee(
    sampler=None,
    var_names=None,
    slices=None,
    arg_names=None,
    arg_groups=None,
    blob_names=None,
    blob_groups=None,
    index_origin=None,
    coords=None,
    dims=None,
):
    """Convert emcee data into an InferenceData object.

    For a usage example read :ref:`emcee_conversion`

    Parameters
    ----------
    sampler : emcee.EnsembleSampler
        Fitted sampler from emcee.
    var_names : list of str, optional
        A list of names for variables in the sampler
    slices : list of array-like or slice, optional
        A list containing the indexes of each variable. Should only be used
        for multidimensional variables.
    arg_names : list of str, optional
        A list of names for args in the sampler
    arg_groups : list of str, optional
        A list of the group names (either ``observed_data`` or ``constant_data``) where
        args in the sampler are stored. If None, all args will be stored in observed
        data group.
    blob_names : list of str, optional
        A list of names for blobs in the sampler. When None,
        blobs are omitted, independently of them being present
        in the sampler or not.
    blob_groups : list of str, optional
        A list of the groups where blob_names variables
        should be assigned respectively. If blob_names!=None
        and blob_groups is None, all variables are assigned
        to log_likelihood group
    index_origin : int, optional
        Starting value of integer coordinate values. It is passed unchanged
        to the helpers that build the datasets; if None, their default is used.
    coords : dict of {str : array_like}, optional
        Map of dimensions to coordinates
    dims : dict of {str : list of str}, optional
        Map variable names to their coordinates

    Returns
    -------
    arviz.InferenceData
    """
    return EmceeConverter(
        sampler=sampler,
        var_names=var_names,
        slices=slices,
        arg_names=arg_names,
        arg_groups=arg_groups,
        blob_names=blob_names,
        blob_groups=blob_groups,
        index_origin=index_origin,
        coords=coords,
        dims=dims,
    ).to_inference_data()