"""Contain functions for Bayes Factor plotting."""
from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal
import numpy as np
import xarray as xr
from arviz_base import extract, rcParams
from arviz_base.validate import validate_dict_argument, validate_sample_dims
from xarray import concat
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.dist_plot import plot_dist
from arviz_plots.plots.utils import (
get_visual_kwargs,
process_group_variables_coords,
set_wrap_layout,
)
def plot_prior_posterior(
    dt,
    *,
    var_names=None,
    filter_vars=None,
    group=None,  # pylint: disable=unused-argument
    coords=None,
    sample_dims=None,
    kind=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "dist",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
        ],
        Sequence[str],
    ] = None,
    visuals: Mapping[
        Literal[
            "dist",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
            "remove_axis",
        ],
        Mapping[str, Any] | bool,
    ] = None,
    stats: Mapping[
        Literal["dist", "credible_interval", "point_estimate"], Mapping[str, Any] | xr.Dataset
    ] = None,
    **pc_kwargs,
):
    r"""Plot 1D marginal densities for prior and posterior.

    The ``prior`` and ``posterior`` groups of `dt` are subsampled to the same number of
    draws, concatenated along a new ``"group"`` dimension and then drawn together by
    ``plot_dist``. A legend over the ``"group"`` dimension is added at the end.

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to aesthetics.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : None
        This argument is ignored. Have it here for compatibility with other plotting functions.
    coords : dict, optional
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. The prior and posterior groups are combined creating a new
        dimension "group". By default, there is an aesthetic mapping from group to color.
        Valid keys are the same as for `visuals`.
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.step_hist`
          * "dot" -> passed to :func:`~arviz_plots.visuals.scatter_xy`

        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * legend -> passed to :meth:`arviz_plots.PlotCollection.add_legend`

        The "credible_interval", "point_estimate" and "point_estimate_text" entries
        default to ``False`` in this function.
    stats : mapping, optional
        Valid keys are:

        * dist -> passed to kde, ecdf, ...

    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Examples
    --------
    Select two variables and plot them with an ecdf.

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_prior_posterior, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> dt = load_arviz_data('centered_eight')
        >>> plot_prior_posterior(dt, var_names=["mu", "tau"], kind="ecdf")

    .. minigallery:: plot_prior_posterior

    """
    # Normalise the optional mapping arguments before they are modified below.
    aes_by_visuals = validate_dict_argument(aes_by_visuals, (plot_dist, "aes_by_visuals"))
    visuals = validate_dict_argument(visuals, (plot_dist, "visuals"))
    stats = validate_dict_argument(stats, (plot_dist, "stats"))

    # An existing plot collection dictates the backend; otherwise use the rcParams default.
    if backend is None:
        backend = (
            rcParams["plot.backend"] if plot_collection is None else plot_collection.backend
        )
    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    # NOTE(review): ``dt.prior`` / ``dt.posterior`` are read directly, so the dictionary
    # input advertised in the docstring does not appear to be handled here -- confirm.
    prior_sample_dims = validate_sample_dims(sample_dims, data=dt.prior)
    posterior_sample_dims = validate_sample_dims(sample_dims, data=dt.posterior)

    # Both groups are reduced to the size of the smaller one so they can be concatenated.
    n_prior = np.prod([dt.prior.sizes[name] for name in prior_sample_dims])
    n_posterior = np.prod([dt.posterior.sizes[name] for name in posterior_sample_dims])
    num_samples = min(n_prior, n_posterior)

    def _subsample(group_name, group_sample_dims):
        """Extract `num_samples` draws of one group, re-indexed with a plain integer range."""
        subset = extract(
            dt,
            group=group_name,
            sample_dims=group_sample_dims,
            combined=True,
            num_samples=num_samples,
            random_seed=0,
            keep_dataset=True,
        )
        extracted_dims = subset.attrs["sample_dims"]
        # Drop the coordinates of both the requested and the resulting sample dims, then
        # relabel the first resulting sample dim as 0..num_samples-1.
        to_drop = list(set(group_sample_dims).union(extracted_dims))
        subset = subset.drop_vars(to_drop).assign_coords(
            {extracted_dims[0]: np.arange(num_samples)}
        )
        return subset, extracted_dims

    ds_prior, _ = _subsample("prior", prior_sample_dims)
    # From here on `sample_dims` refers to the sample dims of the extracted posterior.
    ds_posterior, sample_dims = _subsample("posterior", posterior_sample_dims)

    stacked = concat([ds_prior, ds_posterior], dim="group")
    stacked = stacked.assign_coords({"group": ["prior", "posterior"]})
    distribution = process_group_variables_coords(
        stacked,
        group=None,
        var_names=var_names,
        filter_vars=filter_vars,
        coords=coords,
    )

    if plot_collection is None:
        # Copy nested dicts so the caller's objects are left untouched.
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        aes = pc_kwargs.get("aes", {}).copy()
        aes.setdefault("color", ["group"])
        pc_kwargs["aes"] = aes
        pc_kwargs.setdefault("col_wrap", 4)

        # Facet over the variable and every dimension that is neither sampled nor "group".
        not_faceted = sample_dims + ["group"]
        facet_dims = [dim for dim in distribution.dims if dim not in not_faceted]
        pc_kwargs.setdefault("cols", ["__variable__"] + facet_dims)

        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    # Interval and point-estimate visuals are off unless the caller asks for them.
    for visual_name in ("credible_interval", "point_estimate", "point_estimate_text"):
        visuals.setdefault(visual_name, False)

    if kind == "hist":
        visuals.setdefault("dist", {})
        visuals.setdefault("remove_axis", True)

    plot_collection = plot_dist(
        distribution,
        var_names=None,
        group=None,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        point_estimate=None,
        ci_kind=None,
        ci_prob=None,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        aes_by_visuals=aes_by_visuals,
        visuals=visuals,
        stats=stats,
        **pc_kwargs,
    )

    legend_kwargs = get_visual_kwargs(visuals, "legend")
    if legend_kwargs is False:
        return plot_collection

    legend_kwargs.setdefault("dim", ["group"])
    plot_collection.add_legend(**legend_kwargs)
    return plot_collection