# Source code for arviz.data.io_pymc3_3x

# pylint: disable=unused-import
"""PyMC3-specific conversion code (PyMC3<4.0)."""
import logging
import warnings
from types import ModuleType
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import xarray as xr

from .. import utils
from ..rcparams import rcParams
from .base import CoordSpec, DimSpec, dict_to_dataset, generate_dims_coords, make_attrs, requires
from .inference_data import InferenceData, concat

if TYPE_CHECKING:
    from typing import Set  # pylint: disable=ungrouped-imports

    import pymc3 as pm

    try:
        import aesara  # pylint: disable=unused-import
    except ImportError:
        import theano as aesara  # pylint: disable=unused-import
    from pymc3 import Model, MultiTrace  # pylint: disable=invalid-name
else:
    MultiTrace = Any  # pylint: disable=invalid-name
    Model = Any  # pylint: disable=invalid-name

# NOTE(review): with three leading underscores this is NOT the ``__all__`` that
# ``from ... import *`` consults; it is an inert module attribute. Renaming it to
# ``__all__`` while keeping the value [""] would make star-imports fail (there is
# no attribute named ""), so confirm the intended public API before changing it.
___all__ = [""]

# Module-level logger; the converter uses it to report non-fatal problems
# (failing to find the model, incompatible posterior predictive shapes).
_log = logging.getLogger(__name__)

# Map of coordinate names to coordinate values.
Coords = Dict[str, List[Any]]
# Map of variable names to the ordered coordinate (dimension) names indexing them.
Dims = Dict[str, List[str]]
# Placeholder alias for a PyMC3 random variable object; pymc3 itself is only
# imported under TYPE_CHECKING or lazily inside the converter.
Var = Any  # pylint: disable=invalid-name


def _monkey_patch_pymc3(pm: ModuleType) -> None:  # pylint: disable=invalid-name
    assert pm.__name__ == "pymc3"

    def fixed_eq(self, other):
        """Use object identity for MultiObservedRV equality."""
        return self is other

    if tuple((int(x) for x in pm.__version__.split("."))) < (3, 9):  # type: ignore
        pm.model.MultiObservedRV.__eq__ = fixed_eq  # type: ignore


class PyMC3Converter:  # pylint: disable=too-many-instance-attributes
    """Encapsulate PyMC3 specific logic.

    Gathers a PyMC3 ``MultiTrace`` and/or dictionaries of prior, posterior predictive
    and out-of-sample prediction draws, and turns them into the xarray datasets that
    :meth:`to_inference_data` assembles into an :class:`InferenceData`.
    """

    model = None  # type: Optional[pm.Model]
    nchains = None  # type: int
    ndraws = None  # type: int
    posterior_predictive = None  # type: Optional[Dict[str, np.ndarray]]
    predictions = None  # type: Optional[Dict[str, np.ndarray]]
    prior = None  # type: Optional[Dict[str, np.ndarray]]

    def __init__(
        self,
        *,
        trace=None,
        prior=None,
        posterior_predictive=None,
        log_likelihood=None,
        predictions=None,
        coords: Optional[Coords] = None,
        dims: Optional[Dims] = None,
        model=None,
        save_warmup: Optional[bool] = None,
        density_dist_obs: bool = True,
    ):
        """Store the inputs and derive chain/draw counts, coords, dims and observations.

        Parameters
        ----------
        trace : pymc3.MultiTrace, optional
            MCMC trace. Passing an ``InferenceData`` here is rejected.
        prior, posterior_predictive, predictions : dict of {str: np.ndarray}, optional
            Draws keyed by variable name.
        log_likelihood : bool or list of str, optional
            ``True`` for every observed variable, or the names to compute the log
            likelihood for. Defaults to ``rcParams["data.log_likelihood"]``.
        coords : dict, optional
            Coordinate values; they take precedence over ``model.coords``.
        dims : dict, optional
            Variable dimension names; they take precedence over ``model.RV_dims``.
        model : pymc3.Model, optional
            Resolved through ``pymc3.modelcontext``; if that fails, the model attached
            to ``trace`` is used.
        save_warmup : bool, optional
            Defaults to ``rcParams["data.save_warmup"]``.
        density_dist_obs : bool
            Whether to collect the ``data`` of observed variables that expose ``.data``
            rather than ``.observations``.

        Raises
        ------
        ValueError
            If ``trace`` is an ``InferenceData``, or if none of ``trace``, ``prior``,
            ``posterior_predictive`` or ``predictions`` is given.
        """
        import pymc3

        try:
            import aesara  # pylint: disable=redefined-outer-name
        except ImportError:
            import theano as aesara

        _monkey_patch_pymc3(pymc3)

        self.pymc3 = pymc3
        self.aesara = aesara

        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
        self.trace = trace

        # this permits us to get the model from command-line argument or from with model:
        try:
            self.model = self.pymc3.modelcontext(model or self.model)
        except TypeError as e:
            _log.error("Got error %s trying to find log_likelihood in translation.", e)
            self.model = None

        if self.model is None:
            warnings.warn(
                "Using `from_pymc3` without the model will be deprecated in a future release. "
                "Not using the model will return less accurate and less useful results. "
                "Make sure you use the model argument or call from_pymc3 within a model context.",
                FutureWarning,
            )

        self.attrs = None
        if trace is not None:
            if isinstance(self.trace, InferenceData):
                raise ValueError(
                    "Using the `InferenceData` as a `trace` argument won't work. "
                    "Please use the `arviz.InferenceData.extend` method to extend the "
                    "`InferenceData` with groups from another `InferenceData`."
                )
            if self.model is None:
                # This next line is brittle and may not work forever, but is a secret
                # way to access the model from the trace.
                self.model = list(self.trace._straces.values())[  # pylint: disable=protected-access
                    0
                ].model
            self.nchains = trace.nchains if hasattr(trace, "nchains") else 1
            if hasattr(trace.report, "n_draws") and trace.report.n_draws is not None:
                # The sampler report knows how many iterations are posterior draws.
                self.ndraws = trace.report.n_draws
                self.attrs = {
                    "sampling_time": trace.report.t_sampling,
                    "tuning_steps": trace.report.n_tune,
                }
            else:
                # Without that information every iteration counts as a posterior draw.
                self.ndraws = len(trace)
                if self.save_warmup:
                    warnings.warn(
                        "Warmup samples will be stored in posterior group and will not be"
                        " excluded from stats and diagnostics."
                        " Please consider using PyMC3>=3.9 and do not slice the trace manually.",
                        UserWarning,
                    )
            # The leading ``len(trace) - ndraws`` iterations are warmup (see ``split_trace``).
            self.ntune = len(self.trace) - self.ndraws
            self.posterior_trace, self.warmup_trace = self.split_trace()
        else:
            self.nchains = self.ndraws = 0

        self.prior = prior
        self.posterior_predictive = posterior_predictive
        self.log_likelihood = (
            rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
        )
        self.predictions = predictions

        def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
            return next(iter(dct.values()))

        if trace is None:
            # if you have a posterior_predictive built with keep_dims,
            # you'll lose here, but there's nothing I can do about that.
            self.nchains = 1
            get_from = None
            if predictions is not None:
                get_from = predictions
            elif posterior_predictive is not None:
                get_from = posterior_predictive
            elif prior is not None:
                get_from = prior
            if get_from is None:
                raise ValueError(
                    "When constructing InferenceData must have at least"
                    " one of trace, prior, posterior_predictive or predictions."
                )

            # Without a trace, the draw count comes from the first axis of any array.
            aelem = arbitrary_element(get_from)
            self.ndraws = aelem.shape[0]

        self.coords = {} if coords is None else coords
        if hasattr(self.model, "coords"):
            self.coords = {**self.model.coords, **self.coords}

        self.dims = {} if dims is None else dims
        if hasattr(self.model, "RV_dims"):
            model_dims = {k: list(v) for k, v in self.model.RV_dims.items()}
            self.dims = {**model_dims, **self.dims}

        self.density_dist_obs = density_dist_obs
        self.observations, self.multi_observations = self.find_observations()

    def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]:
        """If there are observations available, return them as a dictionary.

        Returns
        -------
        tuple of (dict or None, dict or None)
            ``(observations, multi_observations)``. The first maps observed-RV names
            to their ``.observations``. The second holds the entries of ``.data`` of
            observed RVs that have no ``.observations`` (only if ``density_dist_obs``),
            evaluated when they have an ``eval`` method. Both are ``None`` without a model.
        """
        if self.model is None:
            return (None, None)
        observations = {}
        multi_observations = {}
        for obs in self.model.observed_RVs:
            if hasattr(obs, "observations"):
                observations[obs.name] = obs.observations
            elif hasattr(obs, "data") and self.density_dist_obs:
                for key, val in obs.data.items():
                    multi_observations[key] = val.eval() if hasattr(val, "eval") else val
        return observations, multi_observations

    def split_trace(self) -> Tuple[Union[None, MultiTrace], Union[None, MultiTrace]]:
        """Split MultiTrace object into posterior and warmup.

        Returns
        -------
        trace_posterior: pymc3.MultiTrace or None
            The slice of the trace corresponding to the posterior. If the posterior
            trace is empty, None is returned
        trace_warmup: pymc3.MultiTrace or None
            The slice of the trace corresponding to the warmup. If the warmup trace is
            empty or ``save_warmup=False``, None is returned
        """
        trace_posterior = None
        trace_warmup = None
        if self.save_warmup and self.ntune > 0:
            trace_warmup = self.trace[: self.ntune]
        if self.ndraws > 0:
            trace_posterior = self.trace[self.ntune :]
        return trace_posterior, trace_warmup

    def log_likelihood_vals_point(self, point, var, log_like_fun):
        """Compute log likelihood for each observed point.

        Entries whose observation is masked (``var.missing_values``) are set to NaN.
        When the mask has more dimensions than the log likelihood, it is reduced with
        ``any`` over its last axis first.
        """
        log_like_val = utils.one_de(log_like_fun(point))
        if var.missing_values:
            mask = var.observations.mask
            if np.ndim(mask) > np.ndim(log_like_val):
                mask = np.any(mask, axis=-1)
            log_like_val = np.where(mask, np.nan, log_like_val)
        return log_like_val

    def _extract_log_likelihood(self, trace):
        """Compute log likelihood of each observation.

        Parameters
        ----------
        trace : pymc3.MultiTrace
            Trace (typically the posterior or warmup slice) whose points are evaluated.

        Returns
        -------
        dict or None
            ``trace_dict`` of a pymc3 ``_DefaultTrace`` filled chain by chain, or
            ``None`` when the converter has no trace or no model.

        Raises
        ------
        AttributeError
            If ``pymc3.sampling._DefaultTrace`` is unavailable (PyMC3 < 3.8).
        TypeError
            If evaluating a variable's log likelihood raises ``TypeError``; the
            offending variable is prefixed to the error arguments.
        """
        if self.trace is None:
            return None
        if self.model is None:
            return None

        # ``log_likelihood`` is either True (every observed RV) or a collection of names.
        if self.log_likelihood is True:
            cached = [(var, var.logp_elemwise) for var in self.model.observed_RVs]
        else:
            cached = [
                (var, var.logp_elemwise)
                for var in self.model.observed_RVs
                if var.name in self.log_likelihood
            ]
        try:
            log_likelihood_dict = (
                self.pymc3.sampling._DefaultTrace(  # pylint: disable=protected-access
                    len(trace.chains)
                )
            )
        except AttributeError as err:
            raise AttributeError(
                "Installed version of ArviZ requires PyMC3>=3.8. Please upgrade with "
                "`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`."
            ) from err
        for var, log_like_fun in cached:
            try:
                for k, chain in enumerate(trace.chains):
                    log_like_chain = [
                        self.log_likelihood_vals_point(point, var, log_like_fun)
                        for point in trace.points([chain])
                    ]
                    log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
            except TypeError as e:
                # The prefix must be an f-string so the failing variable is named.
                raise TypeError(f"While computing log-likelihood for {var}: ", *e.args) from e
        return log_likelihood_dict.trace_dict

    @requires("trace")
    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset.

        Returns
        -------
        tuple of xarray.Dataset
            ``(posterior, warmup_posterior)``; transformed variables are excluded.
        """
        var_names = self.pymc3.util.get_default_varnames(
            self.trace.varnames, include_transformed=False
        )
        data = {}
        data_warmup = {}
        for var_name in var_names:
            if self.warmup_trace:
                data_warmup[var_name] = np.array(
                    self.warmup_trace.get_values(var_name, combine=False, squeeze=False)
                )
            if self.posterior_trace:
                data[var_name] = np.array(
                    self.posterior_trace.get_values(var_name, combine=False, squeeze=False)
                )
        return (
            dict_to_dataset(
                data, library=self.pymc3, coords=self.coords, dims=self.dims, attrs=self.attrs
            ),
            dict_to_dataset(
                data_warmup,
                library=self.pymc3,
                coords=self.coords,
                dims=self.dims,
                attrs=self.attrs,
            ),
        )

    @requires("trace")
    def sample_stats_to_xarray(self):
        """Extract sample_stats from PyMC3 trace.

        Sampler statistics are renamed to ArviZ conventions (e.g. ``model_logp`` ->
        ``lp``), and the ``tune`` flag is dropped.

        Returns
        -------
        tuple of xarray.Dataset
            ``(sample_stats, warmup_sample_stats)``.
        """
        rename_key = {
            "model_logp": "lp",
            "mean_tree_accept": "acceptance_rate",
            "depth": "tree_depth",
            "tree_size": "n_steps",
        }
        data = {}
        data_warmup = {}
        for stat in self.trace.stat_names:
            name = rename_key.get(stat, stat)
            if name == "tune":
                continue
            if self.warmup_trace:
                data_warmup[name] = np.array(
                    self.warmup_trace.get_sampler_stats(stat, combine=False)
                )
            if self.posterior_trace:
                data[name] = np.array(self.posterior_trace.get_sampler_stats(stat, combine=False))

        return (
            dict_to_dataset(
                data, library=self.pymc3, dims=None, coords=self.coords, attrs=self.attrs
            ),
            dict_to_dataset(
                data_warmup, library=self.pymc3, dims=None, coords=self.coords, attrs=self.attrs
            ),
        )

    @requires("trace")
    @requires("model")
    def log_likelihood_to_xarray(self):
        """Extract log likelihood and log_p data from PyMC3 trace.

        Returns
        -------
        tuple of xarray.Dataset or None
            ``(log_likelihood, warmup_log_likelihood)``. Returns ``None`` when
            predictions are being converted or ``log_likelihood`` is falsy. If the
            computation raises ``TypeError``, a warning is emitted and that part is
            left empty.
        """
        # If we have predictions, then we have a thinned trace which does not
        # support extracting a log likelihood.
        if self.predictions or not self.log_likelihood:
            return None
        data_warmup = {}
        data = {}
        warn_msg = (
            "Could not compute log_likelihood, it will be omitted. "
            "Check your model object or set log_likelihood=False"
        )
        if self.posterior_trace:
            try:
                data = self._extract_log_likelihood(self.posterior_trace)
            except TypeError:
                warnings.warn(warn_msg)
        if self.warmup_trace:
            try:
                data_warmup = self._extract_log_likelihood(self.warmup_trace)
            except TypeError:
                warnings.warn(warn_msg)
        return (
            dict_to_dataset(
                data, library=self.pymc3, dims=self.dims, coords=self.coords, skip_event_dims=True
            ),
            dict_to_dataset(
                data_warmup,
                library=self.pymc3,
                dims=self.dims,
                coords=self.coords,
                skip_event_dims=True,
            ),
        )

    def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
        """Take Dict of variables to numpy ndarrays (samples) and translate into dataset.

        Each array is used as-is when its leading axes are ``(nchains, ndraws)``. It is
        reshaped when its first axis has length ``nchains * ndraws``. Otherwise it gets
        a leading chain axis via ``utils.expand_dims`` and a warning is logged.
        """
        data = {}
        for k, ary in dct.items():
            shape = ary.shape
            # Guard on the rank so 0-d / 1-d arrays fall through instead of raising IndexError.
            if len(shape) >= 2 and shape[0] == self.nchains and shape[1] == self.ndraws:
                data[k] = ary
            elif len(shape) >= 1 and shape[0] == self.nchains * self.ndraws:
                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
            else:
                data[k] = utils.expand_dims(ary)
                # pylint: disable=line-too-long
                _log.warning(
                    "posterior predictive variable %s's shape not compatible with number of chains and draws. "
                    "This can mean that some draws or even whole chains are not represented.",
                    k,
                )
        return dict_to_dataset(data, library=self.pymc3, coords=self.coords, dims=self.dims)

    @requires(["posterior_predictive"])
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(self.posterior_predictive)

    @requires(["predictions"])
    def predictions_to_xarray(self):
        """Convert predictions (out of sample predictions) to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(self.predictions)

    def priors_to_xarray(self):
        """Convert prior samples (and if possible prior predictive too) to xarray.

        Returns
        -------
        dict
            ``{"prior": ..., "prior_predictive": ...}``. Observed variables present in
            ``self.prior`` go to ``prior_predictive`` and everything else to ``prior``.
            ``prior_predictive`` is ``None`` when there is no model (no observations).
        """
        if self.prior is None:
            return {"prior": None, "prior_predictive": None}
        if self.observations is not None:
            # Only observed variables that were actually sampled can be converted;
            # indexing ``self.prior`` with the others would raise KeyError.
            prior_predictive_vars = [key for key in self.observations if key in self.prior]
            prior_vars = [key for key in self.prior.keys() if key not in self.observations]
        else:
            prior_vars = list(self.prior.keys())
            prior_predictive_vars = None

        priors_dict = {}
        for group, var_names in zip(
            ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
        ):
            priors_dict[group] = (
                None
                if var_names is None
                else dict_to_dataset(
                    {k: utils.expand_dims(self.prior[k]) for k in var_names},
                    library=self.pymc3,
                    coords=self.coords,
                    dims=self.dims,
                )
            )
        return priors_dict

    @requires(["observations", "multi_observations"])
    @requires("model")
    def observed_data_to_xarray(self):
        """Convert observed data to xarray.

        Returns ``None`` when predictions are being converted. Shared variables are
        read through ``get_value``; dims and coords come from ``self.dims`` /
        ``self.coords`` via ``generate_dims_coords``.
        """
        if self.predictions:
            return None
        dims = {} if self.dims is None else self.dims
        observed_data = {}
        for name, vals in {**self.observations, **self.multi_observations}.items():
            if hasattr(vals, "get_value"):
                vals = vals.get_value()
            vals = utils.one_de(vals)
            val_dims = dims.get(name)
            val_dims, coords = generate_dims_coords(
                vals.shape, name, dims=val_dims, coords=self.coords
            )
            # filter coords based on the dims
            coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
            observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
        return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=self.pymc3))

    @requires(["trace", "predictions"])
    @requires("model")
    def constant_data_to_xarray(self):
        """Convert constant data to xarray.

        Collects deterministics with no random-variable ancestor, plus named model
        variables that are not deterministics, observed/free RVs, potentials or
        observations. Returns ``None`` when nothing qualifies.

        Raises
        ------
        ValueError
            If a variable cannot be turned into a ``DataArray`` with its dims/coords.
        """
        # For constant data, we are concerned only with deterministics and data.
        # The constant data vars must be either pm.Data (TensorSharedVariable) or pm.Deterministic
        constant_data_vars = {}  # type: Dict[str, Var]
        for var in self.model.deterministics:
            # theano exposes ``ancestors`` under ``gof``; aesara under ``graph.basic``.
            if hasattr(self.aesara, "gof"):
                ancestors_func = self.aesara.gof.graph.ancestors  # pylint: disable=no-member
            else:
                ancestors_func = self.aesara.graph.basic.ancestors  # pylint: disable=no-member
            ancestors = ancestors_func(var.owner.inputs)
            # no dependency on a random variable
            if not any(isinstance(a, self.pymc3.model.PyMC3Variable) for a in ancestors):
                constant_data_vars[var.name] = var

        def is_data(name, var) -> bool:
            assert self.model is not None
            return (
                var not in self.model.deterministics
                and var not in self.model.observed_RVs
                and var not in self.model.free_RVs
                and var not in self.model.potentials
                and (self.observations is None or name not in self.observations)
            )

        # I don't know how to find pm.Data, except that they are named variables that aren't
        # observed or free RVs, nor are they deterministics, and then we eliminate observations.
        for name, var in self.model.named_vars.items():
            if is_data(name, var):
                constant_data_vars[name] = var

        if not constant_data_vars:
            return None
        dims = {} if self.dims is None else self.dims
        constant_data = {}
        for name, vals in constant_data_vars.items():
            if hasattr(vals, "get_value"):
                vals = vals.get_value()
            # this might be a Deterministic, and must be evaluated
            elif hasattr(self.model[name], "eval"):
                vals = self.model[name].eval()
            vals = np.atleast_1d(vals)
            val_dims = dims.get(name)
            val_dims, coords = generate_dims_coords(
                vals.shape, name, dims=val_dims, coords=self.coords
            )
            # filter coords based on the dims
            coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
            try:
                constant_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
            except ValueError as err:
                raise ValueError(f"Error translating constant_data variable {name}: {err}") from err
        return xr.Dataset(data_vars=constant_data, attrs=make_attrs(library=self.pymc3))

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created (e.g., there is no `trace`, so
        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
        will not have those groups.
        """
        id_dict = {
            "posterior": self.posterior_to_xarray(),
            "sample_stats": self.sample_stats_to_xarray(),
            "log_likelihood": self.log_likelihood_to_xarray(),
            "posterior_predictive": self.posterior_predictive_to_xarray(),
            "predictions": self.predictions_to_xarray(),
            **self.priors_to_xarray(),
            "observed_data": self.observed_data_to_xarray(),
        }
        # Constant data belongs to the predictions when converting out-of-sample draws.
        if self.predictions:
            id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
        else:
            id_dict["constant_data"] = self.constant_data_to_xarray()
        return InferenceData(save_warmup=self.save_warmup, **id_dict)


def from_pymc3(
    trace=None,
    *,
    prior=None,
    posterior_predictive=None,
    log_likelihood=None,
    coords=None,
    dims=None,
    model=None,
    save_warmup=None,
    density_dist_obs=True,
):
    """Convert pymc3 data into an InferenceData object.

    ``trace``, ``prior`` and ``posterior_predictive`` are each optional, but at least
    one of them must be given; otherwise a ``ValueError`` is raised.

    For a usage example read the
    :ref:`Creating InferenceData section on from_pymc3 <creating_InferenceData>`

    Parameters
    ----------
    trace : pymc3.MultiTrace, optional
        Trace generated from MCMC sampling. Output of
        :py:func:`pymc3:pymc3.sampling.sample`.
    prior : dict, optional
        Dictionary with the variable names as keys, and values numpy arrays
        containing prior and prior predictive samples.
    posterior_predictive : dict, optional
        Dictionary with the variable names as keys, and values numpy arrays
        containing posterior predictive samples.
    log_likelihood : bool or array_like of str, optional
        List of variables to calculate `log_likelihood`. If True, `log_likelihood`
        is calculated for all observed variables. If set to False, log_likelihood
        is skipped. Defaults to the value of rcParam ``data.log_likelihood``.
    coords : dict of {str: array-like}, optional
        Map of coordinate names to coordinate values
    dims : dict of {str: list of str}, optional
        Map of variable names to the coordinate names to use to index its dimensions.
    model : pymc3.Model, optional
        Model used to generate ``trace``. It is not necessary to pass ``model`` if in
        ``with`` context.
    save_warmup : bool, optional
        Save warmup iterations in the InferenceData object. If not defined, use the
        default defined by the rcParams.
    density_dist_obs : bool, default True
        Store variables passed with ``observed`` arg to
        :class:`pymc3:pymc3.distributions.DensityDist` in the generated InferenceData.

    Returns
    -------
    InferenceData

    Raises
    ------
    ValueError
        If none of ``trace``, ``prior`` or ``posterior_predictive`` is given, or if
        ``trace`` is already an ``InferenceData``.
    """
    return PyMC3Converter(
        trace=trace,
        prior=prior,
        posterior_predictive=posterior_predictive,
        log_likelihood=log_likelihood,
        coords=coords,
        dims=dims,
        model=model,
        save_warmup=save_warmup,
        density_dist_obs=density_dist_obs,
    ).to_inference_data()
# Later I could have this return ``None`` if the ``idata_orig`` argument is supplied.
# But perhaps we should have an inplace argument?
def from_pymc3_predictions(
    predictions,
    posterior_trace=None,
    model=None,
    coords=None,
    dims=None,
    idata_orig=None,
    inplace=False,
):
    """Translate out-of-sample predictions into ``InferenceData``.

    Parameters
    ----------
    predictions: Dict[str, np.ndarray]
        The predictions are the return value of ``pymc3.sample_posterior_predictive``,
        a dictionary of strings (variable names) to numpy ndarrays (draws).
    posterior_trace: pm.MultiTrace
        This should be a trace that has been thinned appropriately for
        ``pymc3.sample_posterior_predictive``. Specifically, any variable whose shape is
        a deterministic function of the shape of any predictor (explanatory, independent, etc.)
        variables must be *removed* from this trace.
    model: pymc3.Model
        This argument is *not* optional, unlike in conventional uses of ``from_pymc3``.
        The reason is that the posterior_trace argument is likely to supply an incorrect
        value of model.
    coords: Dict[str, array-like[Any]]
        Coordinates for the variables. Map from coordinate names to coordinate values.
    dims: Dict[str, array-like[str]]
        Map from variable name to ordered set of coordinate names.
    idata_orig: InferenceData, optional
        If supplied, its groups are combined with the new ``predictions`` and (if
        available) ``predictions_constant_data`` groups. If this is not supplied,
        a fresh InferenceData is returned.
    inplace: boolean, optional
        If idata_orig is supplied and inplace is True, merge the predictions into
        idata_orig and return it. Otherwise the groups of idata_orig are merged into
        a fresh InferenceData object, which is returned.

    Returns
    -------
    InferenceData:
        May be modified ``idata_orig``.

    Raises
    ------
    ValueError
        If ``inplace`` is True but no ``idata_orig`` is supplied.
    """
    if inplace and idata_orig is None:
        raise ValueError(
            "Do not pass True for inplace unless passing "
            "an existing InferenceData as idata_orig"
        )
    new_idata = PyMC3Converter(
        trace=posterior_trace, predictions=predictions, model=model, coords=coords, dims=dims
    ).to_inference_data()
    if idata_orig is None:
        return new_idata
    if inplace:
        concat([idata_orig, new_idata], dim=None, inplace=True)
        return idata_orig
    # if we are not returning in place, then merge the old groups into the new inference
    # data and return that.
    concat([new_idata, idata_orig], dim=None, copy=True, inplace=True)
    return new_idata