Source code for arviz.wrappers.wrap_stan
# pylint: disable=arguments-differ
"""Base class for Stan interface wrappers."""
from typing import Union
from ..data import from_cmdstanpy, from_pystan
from .base import SamplingWrapper
# pylint: disable=abstract-method
class StanSamplingWrapper(SamplingWrapper):
    """Stan sampling wrapper base class.

    See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
    description. An example of ``PyStanSamplingWrapper`` usage can be found
    in the :ref:`pystan_refitting` notebook. For usage examples of other wrappers
    see the user guide pages on :ref:`wrapper_guide`.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    See Also
    --------
    SamplingWrapper
    """

    def sel_observations(self, idx):
        """Select a subset of the observations in idata_orig.

        **Not implemented**: This method must be implemented on a model basis.
        It is documented here to show its format and call signature.

        Parameters
        ----------
        idx
            Indexes to separate from the rest of the observed data.

        Returns
        -------
        modified_observed_data : dict
            Dictionary containing both excluded and included data but properly divided
            in the different keys. Passed to ``self.sample``, which uses it as the Stan
            data of the refit (``model.sampling(data=...)`` for PyStan 2,
            ``stan.build(..., data=...)`` for PyStan 3, written to a JSON data file
            for CmdStanPy).
        excluded_observed_data : str
            Variable name containing the pointwise log likelihood data of the excluded
            data. As PyStan cannot call C++ functions and log_likelihood__i is already
            calculated *during* the simulation, instead of the value on which to evaluate
            the likelihood, ``log_likelihood__i`` expects a string so it can extract the
            corresponding data from the InferenceData object.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a model-specific subclass.
        """
        raise NotImplementedError("sel_observations must be implemented on a model basis")

    def get_inference_data(self, fitted_model):  # pylint: disable=arguments-renamed
        """Convert the fit object returned by ``self.sample`` to InferenceData.

        Parameters
        ----------
        fitted_model
            Fit object returned by ``self.sample``: a cmdstanpy ``CmdStanMCMC``
            or a PyStan fit.

        Returns
        -------
        InferenceData
            Result of ``from_cmdstanpy`` or ``from_pystan`` called with
            ``self.idata_kwargs``.
        """
        # Dispatch on the class name so cmdstanpy never has to be imported here.
        if fitted_model.__class__.__name__ == "CmdStanMCMC":
            idata = from_cmdstanpy(posterior=fitted_model, **self.idata_kwargs)
        else:
            idata = from_pystan(posterior=fitted_model, **self.idata_kwargs)
        return idata

    def log_likelihood__i(self, excluded_obs, idata__i):
        """Retrieve the log likelihood of the excluded observations from ``idata__i``.

        Parameters
        ----------
        excluded_obs : str
            Name of the log likelihood variable, as returned by ``sel_observations``.
        idata__i : InferenceData
            InferenceData of the refit, as built by ``get_inference_data``.

        Returns
        -------
        The ``excluded_obs`` variable of the ``log_likelihood`` group of ``idata__i``.
        """
        return idata__i.log_likelihood[excluded_obs]
[docs]
class PyStan2SamplingWrapper(StanSamplingWrapper):
    """PyStan (2.x) sampling wrapper base class.

    See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
    description. An example of ``PyStanSamplingWrapper`` usage can be found
    in the :ref:`pystan_refitting` notebook. For usage examples of other wrappers
    see the user guide pages on :ref:`wrapper_guide`.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    See Also
    --------
    SamplingWrapper
    """

    def sample(self, modified_observed_data):
        """Resample the PyStan model stored in self.model on modified_observed_data.

        Parameters
        ----------
        modified_observed_data : dict
            Data dictionary (as returned by ``sel_observations``) used as the
            ``data`` argument of ``self.model.sampling``.

        Returns
        -------
        fit
            The fit object returned by ``self.model.sampling``.
        """
        # Merge rather than pass ``data=`` alongside ``**self.sample_kwargs``: a
        # ``"data"`` key in ``sample_kwargs`` would otherwise raise a duplicate
        # keyword TypeError. The modified data always takes precedence, the same
        # way the CmdStanPy wrapper builds its ``sample`` call.
        fit = self.model.sampling(**{**self.sample_kwargs, "data": modified_observed_data})
        return fit
[docs]
class PyStanSamplingWrapper(StanSamplingWrapper):
    """PyStan (3.0+) sampling wrapper base class.

    See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
    description. An example of ``PyStanSamplingWrapper`` usage can be found
    in the :ref:`pystan_refitting` notebook.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def sample(self, modified_observed_data):
        """Rebuild and resample the PyStan model on modified_observed_data.

        ``self.model`` may hold either the Stan program source (``str``) or a
        previously built ``stan.Model``. In both cases the program is rebuilt
        with the new data, stored back into ``self.model``, and sampled.

        Parameters
        ----------
        modified_observed_data : dict
            Data dictionary passed as ``data`` to ``stan.build``.

        Returns
        -------
        fit
            The object returned by ``self.model.sample(**self.sample_kwargs)``.
        """
        import stan  # pylint: disable=import-error,import-outside-toplevel

        self.model: Union[str, stan.Model]
        # PyStan 3 binds data at build time, so every refit requires a rebuild.
        source = self.model if isinstance(self.model, str) else self.model.program_code
        self.model = stan.build(source, data=modified_observed_data)
        return self.model.sample(**self.sample_kwargs)
[docs]
class CmdStanPySamplingWrapper(StanSamplingWrapper):
    """CmdStanPy sampling wrapper base class.

    See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
    description. An example of ``CmdStanPySamplingWrapper`` usage can be found
    in the :ref:`cmdstanpy_refitting` notebook.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def __init__(self, data_file, **kwargs):
        """Initialize the CmdStanPySamplingWrapper.

        Parameters
        ----------
        data_file : str
            Filename on which to store the data for every refit.
            Its contents will be overwritten.
        **kwargs
            Forwarded unchanged to ``super().__init__``.
        """
        super().__init__(**kwargs)
        self.data_file = data_file

    def sample(self, modified_observed_data):
        """Resample cmdstanpy model on modified_observed_data.

        Parameters
        ----------
        modified_observed_data : dict
            Data dictionary (as returned by ``sel_observations``). It is written
            to ``self.data_file``, overwriting any previous contents.

        Returns
        -------
        fit
            The object returned by ``self.model.sample``.
        """
        from cmdstanpy import (  # pylint: disable=import-error,import-outside-toplevel
            write_stan_json,
        )

        write_stan_json(self.data_file, modified_observed_data)
        # The data file always wins over any "data" entry in sample_kwargs.
        fit = self.model.sample(**{**self.sample_kwargs, "data": self.data_file})
        return fit