"""Dictionary specific conversion code."""
import warnings
from typing import Optional
from ..rcparams import rcParams
from .base import dict_to_dataset, requires
from .inference_data import WARMUP_TAG, InferenceData
# pylint: disable=too-many-instance-attributes
class DictConverter:
    """Encapsulate Dictionary specific logic.

    Each constructor argument named after a group (``posterior``, ``prior``,
    ``observed_data``, ...) is stored as-is on the instance.  The
    ``*_to_xarray`` methods turn one group each into a dataset through
    ``dict_to_dataset`` and :meth:`to_inference_data` assembles all of them
    into an ``InferenceData`` object.

    Parameters
    ----------
    posterior, posterior_predictive, predictions, sample_stats, log_likelihood : dict, optional
        Sampled groups.  Each of them has a ``warmup_*`` counterpart.
    prior, prior_predictive, sample_stats_prior : dict, optional
        Prior groups (no warmup counterpart).
    observed_data, constant_data, predictions_constant_data : dict, optional
        Data groups, converted with ``default_dims=[]``.
    warmup_posterior, warmup_posterior_predictive, warmup_predictions : dict, optional
    warmup_log_likelihood, warmup_sample_stats : dict, optional
        Warmup counterparts of the sampled groups.
    save_warmup : bool, optional
        Forwarded to ``InferenceData``.  Falls back to
        ``rcParams["data.save_warmup"]`` when ``None``.
    index_origin : int, optional
        Forwarded unchanged to ``dict_to_dataset``.
    coords : dict, optional
        Coordinate values, keyed by dimension name.
    dims : dict, optional
        Dimension names, keyed by variable name.
    pred_dims : dict, optional
        Dimension names used for ``predictions`` and
        ``predictions_constant_data``.  Falls back to ``dims`` when ``None``.
    pred_coords : dict, optional
        Coordinate values used for ``predictions`` and
        ``predictions_constant_data``.  They are laid over ``coords``:
        on a name clash the value from ``pred_coords`` wins.
    attrs : dict, optional
        Attributes handed to ``InferenceData`` and to the data groups.  The
        keys ``"created_at"`` and ``"arviz_version"`` are dropped.  The
        dictionary passed in is copied, never modified.
    **kwargs
        Per-group attributes looked up as ``"<group>_attrs"`` and
        ``"<group>_warmup_attrs"``.
    """

    def __init__(
        self,
        *,
        posterior=None,
        posterior_predictive=None,
        predictions=None,
        sample_stats=None,
        log_likelihood=None,
        prior=None,
        prior_predictive=None,
        sample_stats_prior=None,
        observed_data=None,
        constant_data=None,
        predictions_constant_data=None,
        warmup_posterior=None,
        warmup_posterior_predictive=None,
        warmup_predictions=None,
        warmup_log_likelihood=None,
        warmup_sample_stats=None,
        save_warmup=None,
        index_origin=None,
        coords=None,
        dims=None,
        pred_dims=None,
        pred_coords=None,
        attrs=None,
        **kwargs,
    ):
        self.posterior = posterior
        self.posterior_predictive = posterior_predictive
        self.predictions = predictions
        self.sample_stats = sample_stats
        self.log_likelihood = log_likelihood
        self.prior = prior
        self.prior_predictive = prior_predictive
        self.sample_stats_prior = sample_stats_prior
        self.observed_data = observed_data
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.warmup_posterior = warmup_posterior
        self.warmup_posterior_predictive = warmup_posterior_predictive
        self.warmup_predictions = warmup_predictions
        self.warmup_log_likelihood = warmup_log_likelihood
        self.warmup_sample_stats = warmup_sample_stats
        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
        self.index_origin = index_origin
        self.coords = coords
        # The merged mapping used to be assigned to ``self.coords`` and was then
        # overwritten by the plain ``coords``, so ``pred_coords`` had no effect.
        # It is kept separately so that only the prediction groups see it and
        # the other groups keep using exactly the coords they always used.
        if pred_coords is None:
            self.pred_coords = coords
        elif coords is None:
            self.pred_coords = pred_coords
        else:
            self.pred_coords = {**coords, **pred_coords}
        self.dims = dims
        self.pred_dims = dims if pred_dims is None else pred_dims
        # Copy before popping so the caller's dictionary is left untouched.
        self.attrs = {} if attrs is None else dict(attrs)
        self.attrs.pop("created_at", None)
        self.attrs.pop("arviz_version", None)
        self._kwargs = kwargs

    def _init_dict(self, attr_name):
        """Return the attribute *attr_name*, or ``{}`` if it is missing or ``None``."""
        dict_or_none = getattr(self, attr_name, {})
        return {} if dict_or_none is None else dict_or_none

    def _checked_dict_pair(self, group):
        """Return ``(data, data_warmup)`` for *group* after checking both are dicts.

        Raises
        ------
        TypeError
            If the group or its warmup counterpart is not a dictionary.
        """
        data = self._init_dict(group)
        data_warmup = self._init_dict(f"{WARMUP_TAG}{group}")
        if not isinstance(data, dict):
            raise TypeError(f"DictConverter.{group} is not a dictionary")
        if not isinstance(data_warmup, dict):
            raise TypeError(f"DictConverter.warmup_{group} is not a dictionary")
        return data, data_warmup

    def _to_dataset(self, data, attrs_key, *, coords, dims, **dataset_kwargs):
        """Convert *data* with ``dict_to_dataset`` using the attrs stored under *attrs_key*."""
        return dict_to_dataset(
            data,
            library=None,
            coords=coords,
            dims=dims,
            attrs=self._kwargs.get(attrs_key),
            index_origin=self.index_origin,
            **dataset_kwargs,
        )

    def _pair_to_datasets(self, group, data, data_warmup, *, coords, dims, **dataset_kwargs):
        """Convert a group and its warmup data; return ``(dataset, warmup_dataset)``."""
        return (
            self._to_dataset(
                data, f"{group}_attrs", coords=coords, dims=dims, **dataset_kwargs
            ),
            self._to_dataset(
                data_warmup, f"{group}_warmup_attrs", coords=coords, dims=dims, **dataset_kwargs
            ),
        )

    def _single_group_to_xarray(self, group):
        """Convert a group that has no warmup counterpart.

        Raises
        ------
        TypeError
            If the group is not a dictionary.
        """
        data = getattr(self, group)
        if not isinstance(data, dict):
            raise TypeError(f"DictConverter.{group} is not a dictionary")
        return self._to_dataset(data, f"{group}_attrs", coords=self.coords, dims=self.dims)

    @requires(["posterior", f"{WARMUP_TAG}posterior"])
    def posterior_to_xarray(self):
        """Convert posterior samples to xarray.

        Returns the pair ``(posterior, warmup_posterior)``.  Emits a
        ``UserWarning`` if the posterior contains a ``log_likelihood`` variable.
        """
        data, data_warmup = self._checked_dict_pair("posterior")
        if "log_likelihood" in data:
            warnings.warn(
                "log_likelihood variable found in posterior group."
                " For stats functions log likelihood data needs to be in log_likelihood group.",
                UserWarning,
            )
        return self._pair_to_datasets(
            "posterior", data, data_warmup, coords=self.coords, dims=self.dims
        )

    @requires(["sample_stats", f"{WARMUP_TAG}sample_stats"])
    def sample_stats_to_xarray(self):
        """Convert sample_stats samples to xarray.

        Returns the pair ``(sample_stats, warmup_sample_stats)``.  Emits a
        ``PendingDeprecationWarning`` if ``log_likelihood`` is stored here.
        """
        data, data_warmup = self._checked_dict_pair("sample_stats")
        if "log_likelihood" in data:
            warnings.warn(
                "log_likelihood variable found in sample_stats."
                " Storing log_likelihood data in sample_stats group will be deprecated in "
                "favour of storing them in the log_likelihood group.",
                PendingDeprecationWarning,
            )
        return self._pair_to_datasets(
            "sample_stats", data, data_warmup, coords=self.coords, dims=self.dims
        )

    @requires(["log_likelihood", f"{WARMUP_TAG}log_likelihood"])
    def log_likelihood_to_xarray(self):
        """Convert log_likelihood samples to xarray.

        Returns the pair ``(log_likelihood, warmup_log_likelihood)``; both are
        converted with ``skip_event_dims=True``.
        """
        data, data_warmup = self._checked_dict_pair("log_likelihood")
        return self._pair_to_datasets(
            "log_likelihood",
            data,
            data_warmup,
            coords=self.coords,
            dims=self.dims,
            skip_event_dims=True,
        )

    @requires(["posterior_predictive", f"{WARMUP_TAG}posterior_predictive"])
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray.

        Returns the pair ``(posterior_predictive, warmup_posterior_predictive)``.
        """
        data, data_warmup = self._checked_dict_pair("posterior_predictive")
        return self._pair_to_datasets(
            "posterior_predictive", data, data_warmup, coords=self.coords, dims=self.dims
        )

    @requires(["predictions", f"{WARMUP_TAG}predictions"])
    def predictions_to_xarray(self):
        """Convert predictions to xarray.

        Returns the pair ``(predictions, warmup_predictions)``.  Uses
        ``pred_dims`` and the prediction coords instead of ``dims``/``coords``.
        """
        data, data_warmup = self._checked_dict_pair("predictions")
        return self._pair_to_datasets(
            "predictions", data, data_warmup, coords=self.pred_coords, dims=self.pred_dims
        )

    @requires("prior")
    def prior_to_xarray(self):
        """Convert prior samples to xarray."""
        return self._single_group_to_xarray("prior")

    @requires("sample_stats_prior")
    def sample_stats_prior_to_xarray(self):
        """Convert sample_stats_prior samples to xarray."""
        return self._single_group_to_xarray("sample_stats_prior")

    @requires("prior_predictive")
    def prior_predictive_to_xarray(self):
        """Convert prior_predictive samples to xarray."""
        return self._single_group_to_xarray("prior_predictive")

    def data_to_xarray(self, data, group, dims=None, coords=None):
        """Convert data to xarray.

        Parameters
        ----------
        data : dict
            Variables of the group.
        group : str
            Group name, only used in the error message.
        dims : dict, optional
            Dimension names per variable.  Falls back to ``self.dims`` (or an
            empty dict when that is ``None`` too).
        coords : dict, optional
            Coordinate values per dimension.  Falls back to ``self.coords``.

        Raises
        ------
        TypeError
            If *data* is not a dictionary.
        """
        if not isinstance(data, dict):
            raise TypeError(f"DictConverter.{group} is not a dictionary")
        if dims is None:
            dims = {} if self.dims is None else self.dims
        if coords is None:
            coords = self.coords
        return dict_to_dataset(
            data,
            library=None,
            coords=coords,
            # Pass the resolved ``dims``; ``self.dims`` was passed here before,
            # which discarded the argument (and with it ``pred_dims``).
            dims=dims,
            default_dims=[],
            attrs=self.attrs,
            index_origin=self.index_origin,
        )

    @requires("observed_data")
    def observed_data_to_xarray(self):
        """Convert observed_data to xarray."""
        return self.data_to_xarray(self.observed_data, group="observed_data", dims=self.dims)

    @requires("constant_data")
    def constant_data_to_xarray(self):
        """Convert constant_data to xarray."""
        return self.data_to_xarray(self.constant_data, group="constant_data")

    @requires("predictions_constant_data")
    def predictions_constant_data_to_xarray(self):
        """Convert predictions_constant_data to xarray.

        Uses ``pred_dims`` and the prediction coords.
        """
        return self.data_to_xarray(
            self.predictions_constant_data,
            group="predictions_constant_data",
            dims=self.pred_dims,
            coords=self.pred_coords,
        )

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created, then the InferenceData
        will not have those groups.
        """
        return InferenceData(
            posterior=self.posterior_to_xarray(),
            sample_stats=self.sample_stats_to_xarray(),
            log_likelihood=self.log_likelihood_to_xarray(),
            posterior_predictive=self.posterior_predictive_to_xarray(),
            predictions=self.predictions_to_xarray(),
            prior=self.prior_to_xarray(),
            sample_stats_prior=self.sample_stats_prior_to_xarray(),
            prior_predictive=self.prior_predictive_to_xarray(),
            observed_data=self.observed_data_to_xarray(),
            constant_data=self.constant_data_to_xarray(),
            predictions_constant_data=self.predictions_constant_data_to_xarray(),
            save_warmup=self.save_warmup,
            attrs=self.attrs,
        )
# pylint: disable=too-many-instance-attributes
def from_dict(
    posterior=None,
    *,
    posterior_predictive=None,
    predictions=None,
    sample_stats=None,
    log_likelihood=None,
    prior=None,
    prior_predictive=None,
    sample_stats_prior=None,
    observed_data=None,
    constant_data=None,
    predictions_constant_data=None,
    warmup_posterior=None,
    warmup_posterior_predictive=None,
    warmup_predictions=None,
    warmup_log_likelihood=None,
    warmup_sample_stats=None,
    save_warmup=None,
    index_origin: Optional[int] = None,
    coords=None,
    dims=None,
    pred_dims=None,
    pred_coords=None,
    attrs=None,
    **kwargs,
):
    """Build an InferenceData object from plain dictionaries.

    Every argument is handed to :class:`DictConverter` under the same name and
    the converter's ``to_inference_data`` result is returned.

    For a usage example read the
    :ref:`Creating InferenceData section on from_dict <creating_InferenceData>`

    Parameters
    ----------
    posterior : dict, optional
    posterior_predictive : dict, optional
    predictions: dict, optional
    sample_stats : dict, optional
    log_likelihood : dict, optional
        For stats functions, log likelihood data should be stored here.
    prior : dict, optional
    prior_predictive : dict, optional
    sample_stats_prior : dict, optional
    observed_data : dict, optional
    constant_data : dict, optional
    predictions_constant_data: dict, optional
    warmup_posterior : dict, optional
    warmup_posterior_predictive : dict, optional
    warmup_predictions : dict, optional
    warmup_log_likelihood : dict, optional
    warmup_sample_stats : dict, optional
    save_warmup : bool, optional
        Save warmup iterations InferenceData object. If not defined, use default
        defined by the rcParams.
    index_origin : int, optional
        Passed through to the dataset conversion as ``index_origin``.
    coords : dict of {str : list}, optional
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict of {str : list of str}, optional
        A mapping from variables to a list of coordinate names for the variable.
    pred_dims : dict of {str : list of str}, optional
        A mapping from variables to a list of coordinate names for predictions.
    pred_coords : dict of {str : list}, optional
        A dictionary containing the values that are used as index for
        predictions. The key is the name of the dimension, the values are the
        index values.
    attrs : dict, optional
        A dictionary containing attributes for different groups.
    kwargs : dict, optional
        A dictionary containing group attrs. Accepted kwargs are:

        - posterior_attrs, posterior_warmup_attrs : attrs for posterior group
        - sample_stats_attrs, sample_stats_warmup_attrs : attrs for sample_stats group
        - log_likelihood_attrs, log_likelihood_warmup_attrs : attrs for log_likelihood group
        - posterior_predictive_attrs, posterior_predictive_warmup_attrs : attrs for
          posterior_predictive group
        - predictions_attrs, predictions_warmup_attrs : attrs for predictions group
        - prior_attrs : attrs for prior group
        - sample_stats_prior_attrs : attrs for sample_stats_prior group
        - prior_predictive_attrs : attrs for prior_predictive group

    Returns
    -------
    InferenceData
    """
    # Data for each group, keyed by the converter argument it feeds.
    group_data = {
        "posterior": posterior,
        "posterior_predictive": posterior_predictive,
        "predictions": predictions,
        "sample_stats": sample_stats,
        "log_likelihood": log_likelihood,
        "prior": prior,
        "prior_predictive": prior_predictive,
        "sample_stats_prior": sample_stats_prior,
        "observed_data": observed_data,
        "constant_data": constant_data,
        "predictions_constant_data": predictions_constant_data,
        "warmup_posterior": warmup_posterior,
        "warmup_posterior_predictive": warmup_posterior_predictive,
        "warmup_predictions": warmup_predictions,
        "warmup_log_likelihood": warmup_log_likelihood,
        "warmup_sample_stats": warmup_sample_stats,
    }
    # Settings that steer the conversion rather than supply data.
    conversion_options = {
        "save_warmup": save_warmup,
        "index_origin": index_origin,
        "coords": coords,
        "dims": dims,
        "pred_dims": pred_dims,
        "pred_coords": pred_coords,
        "attrs": attrs,
    }
    converter = DictConverter(**group_data, **conversion_options, **kwargs)
    return converter.to_inference_data()
# ``from_pytree`` is an alias: it is bound to the very same function object as
# ``from_dict``.
from_pytree = from_dict