"""dist plot code."""
from collections.abc import Mapping, Sequence
from importlib import import_module
from typing import Any, Literal
import arviz_stats
import xarray as xr
from arviz_base import rcParams
from arviz_base.validate import (
validate_ci_prob,
validate_dict_argument,
validate_or_use_rcparam,
validate_sample_dims,
)
from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.utils import (
_compute_func,
compute_dist,
filter_aes,
filter_aes_full,
get_visual_kwargs,
process_group_variables_coords,
set_wrap_layout,
)
from arviz_plots.visuals import (
ecdf_line,
fill_between_y,
hist,
labelled_title,
line_x,
line_xy,
point_estimate_text,
remove_axis,
scatter_x,
scatter_xy,
step_hist,
)
def _density_heights(density):
    """Select the height entries of a computed density representation.

    Keeps only the ``plot_axis`` coordinate values ``"y"`` and ``"histogram"``,
    i.e. the entries that hold the vertical extent of the drawn density.
    """
    return density.sel(
        plot_axis=[coord for coord in density["plot_axis"].values if coord in ("y", "histogram")]
    )


def plot_dist(
    dt,
    *,
    var_names=None,
    filter_vars=None,
    group="posterior",
    coords=None,
    sample_dims=None,
    kind=None,
    point_estimate=None,
    ci_kind=None,
    ci_prob=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_by_visuals: Mapping[
        Literal[
            "dist",
            "face",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
        ],
        Sequence[str],
    ]
    | None = None,
    visuals: Mapping[
        Literal[
            "dist",
            "face",
            "credible_interval",
            "point_estimate",
            "point_estimate_text",
            "title",
            "rug",
            "remove_axis",
        ],
        Mapping[str, Any] | bool,
    ]
    | None = None,
    stats: Mapping[
        Literal["dist", "credible_interval", "point_estimate"], Mapping[str, Any] | xr.Dataset
    ]
    | None = None,
    **pc_kwargs,
):
    """Plot 1D marginal densities in the style of John K. Kruschke’s book [1]_.

    Generate :term:`faceted` :term:`plots` with: a graphical representation of 1D marginal
    densities (as KDE, histogram, ECDF or dotplot), a credible interval and a point estimate.

    Parameters
    ----------
    dt : DataTree or dict of {str : DataTree}
        Input data. In case of dictionary input, the keys are taken to be model names.
        In such cases, a dimension "model" is generated and can be used to map to aesthetics.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If “like”, interpret var_names as substrings of the real variables names.
        If “regex”, interpret var_names as regular expressions on the real variables names.
    group : str, default "posterior"
        Group to be plotted.
    coords : dict, optional
        Coordinates used to subset the data before plotting.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"auto", "kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    point_estimate : {"mean", "median", "mode"}, optional
        Which point estimate to plot. Defaults to rcParam :data:`stats.point_estimate`
    ci_kind : {"eti", "hdi"}, optional
        Which credible interval to use. Defaults to ``rcParams["stats.ci_kind"]``
    ci_prob : float, optional
        Indicates the probability that should be contained within the plotted credible interval.
        Defaults to ``rcParams["stats.ci_prob"]``
    plot_collection : PlotCollection, optional
    backend : {"matplotlib", "bokeh"}, optional
    labeller : labeller, optional
    aes_by_visuals : mapping of {str : sequence of str}, optional
        Mapping of visuals to aesthetics that should use their mapping in `plot_collection`
        when plotted. Valid keys are the same as for `visuals`.

        With a single model, no aesthetic mappings are generated by default,
        each variable+coord combination gets a :term:`plot` but they all look the same,
        unless there are user provided aesthetic mappings.
        With multiple models, ``plot_dist`` maps "color" and "y" to the "model" dimension.

        By default, all aesthetics but "y" are mapped to the density representation,
        and if multiple models are present, "color" and "y" are mapped to the
        credible interval and the point estimate.

        When "point_estimate" key is provided but "point_estimate_text" isn't,
        the values assigned to the first are also used for the second.
    visuals : mapping of {str : mapping or bool}, optional
        Valid keys are:

        * dist -> depending on the value of `kind` passed to:

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.step_hist`
          * "dot" -> passed to :func:`~arviz_plots.visuals.scatter_xy`

        * face -> :term:`visual` that fills the area under the marginal distribution representation.
          Defaults to False. Depending on the value of `kind` it is passed to:

          * "kde", "ecdf" or "dot" -> passed to :func:`~arviz_plots.visuals.fill_between_y`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * credible_interval -> passed to :func:`~arviz_plots.visuals.line_x`
        * point_estimate -> passed to :func:`~arviz_plots.visuals.scatter_x`
        * point_estimate_text -> passed to :func:`~arviz_plots.visuals.point_estimate_text`
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * rug -> passed to :func:`~arviz_plots.visuals.scatter_x`. Defaults to False.
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function

    stats : mapping of {str : mapping or Dataset}, optional
        Valid keys are:

        * dist -> passed to :func:`~arviz_stats.kde`, :func:`~arviz_stats.histogram`,
          :func:`~arviz_stats.ecdf`, or :func:`~arviz_stats.qds` depending on `kind`
        * credible_interval -> passed to :func:`~arviz_stats.eti` or :func:`~arviz_stats.hdi`
        * point_estimate -> passed to :func:`~arviz_stats.mean`, :func:`~arviz_stats.median`
          or :func:`~arviz_stats.mode` depending on `point_estimate`. Its ``round_to``
          argument is always set to ``"none"``, so the computed values are not rounded.

        In case a :class:`~xarray.Dataset` is provided, it will be interpreted
        as pre-computed values for that statistic.
    **pc_kwargs
        Passed to :class:`arviz_plots.PlotCollection.wrap`

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If `ci_kind` or `point_estimate` is not supported and the corresponding
        statistic has to be computed (i.e. no pre-computed Dataset was given in `stats`).

    See Also
    --------
    :ref:`plots_intro` :
        General introduction to batteries-included plotting functions, common use and logic overview

    Examples
    --------
    Map the color to the variable, and have the mapping apply
    to the title too instead of only the density representation:

    .. plot::
        :context: close-figs

        >>> from arviz_plots import plot_dist, style
        >>> style.use("arviz-variat")
        >>> from arviz_base import load_arviz_data
        >>> non_centered = load_arviz_data('non_centered_eight')
        >>> pc = plot_dist(
        >>>     non_centered,
        >>>     coords={"school": ["Choate", "Deerfield", "Hotchkiss"]},
        >>>     aes={"color": ["__variable__"]},
        >>>     aes_by_visuals={"title": ["color"]},
        >>> )

    Faceting and aesthetics mappings happen on unique coordinate values. If there are repeated
    coordinate values they will be grouped and reduced along with `sample_dims`.

    .. plot::
        :context: close-figs

        >>> post = non_centered.posterior.to_dataset()
        >>> repeated_coords = ["a", "a", "a", "b", "b", "b", "b", "c"]
        >>> pc = plot_dist(post.assign_coords(school=repeated_coords))

    .. minigallery:: plot_dist

    References
    ----------
    .. [1] Kruschke. Doing Bayesian Data Analysis, Second Edition: A Tutorial with R,
       JAGS, and Stan. Academic Press, 2014. ISBN 978-0-12-405888-0.
       https://www.sciencedirect.com/book/9780124058880
    """
    kind = validate_or_use_rcparam(kind, "plot.density_kind")
    point_estimate = validate_or_use_rcparam(point_estimate, "stats.point_estimate")
    ci_kind = validate_or_use_rcparam(ci_kind, "stats.ci_kind")
    ci_prob = validate_ci_prob(ci_prob)
    # Shallow copies: the ``setdefault`` calls below add defaults to these mappings
    # and must not leak into the mappings the caller passed in.
    aes_by_visuals = dict(validate_dict_argument(aes_by_visuals, (plot_dist, "aes_by_visuals")))
    visuals = dict(validate_dict_argument(visuals, (plot_dist, "visuals")))
    stats = validate_dict_argument(stats, (plot_dist, "stats"))
    if kind == "ecdf":
        # ECDF y values are meaningful, so keep the y axis unless the user asks otherwise.
        visuals.setdefault("remove_axis", False)

    distribution = process_group_variables_coords(
        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )
    sample_dims = validate_sample_dims(sample_dims, data=distribution)

    if backend is None:
        backend = rcParams["plot.backend"] if plot_collection is None else plot_collection.backend
    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    if plot_collection is None:
        pc_kwargs["figure_kwargs"] = pc_kwargs.get("figure_kwargs", {}).copy()
        pc_kwargs.setdefault(
            "cols",
            ["__variable__"]
            + [dim for dim in distribution.dims if dim not in {"model"}.union(sample_dims)],
        )
        if "model" in distribution:
            pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
            pc_kwargs["aes"].setdefault("color", ["model"])
            pc_kwargs["aes"].setdefault("y", ["model"])
        pc_kwargs = set_wrap_layout(pc_kwargs, plot_bknd, distribution)
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    face_kwargs = get_visual_kwargs(visuals, "face", False)
    density_kwargs = get_visual_kwargs(visuals, "dist")

    aes_by_visuals.setdefault("dist", plot_collection.aes_set.difference({"y"}))
    if face_kwargs is not False:
        aes_by_visuals.setdefault("face", set(aes_by_visuals["dist"]).difference({"linestyle"}))
    if "model" in distribution:
        aes_by_visuals.setdefault("credible_interval", ["color", "y"])
        aes_by_visuals.setdefault("point_estimate", ["color", "y"])
    if "point_estimate" in aes_by_visuals and "point_estimate_text" not in aes_by_visuals:
        aes_by_visuals["point_estimate_text"] = aes_by_visuals["point_estimate"]

    density_reduce, density_active, density_aes, density_ignore = filter_aes_full(
        plot_collection, aes_by_visuals, "dist", sample_dims
    )
    density = compute_dist(distribution, density_reduce, density_active, kind=kind, stats=stats)
    # Group the variables by the density representation actually computed for them.
    kind_var_map = {
        kind_i: [k for k, da in density.items() if kind_i == da.attrs["kind"]]
        for kind_i in ("dot", "ecdf", "hist", "kde")
    }
    kind_var_map = {k: v for k, v in kind_var_map.items() if v}

    # filled face (should go under the dist visual if both present)
    if face_kwargs is not False:
        _, face_aes, face_ignore = filter_aes(plot_collection, aes_by_visuals, "face", sample_dims)
        if "color" not in face_aes:
            face_kwargs.setdefault("color", "C0")
        if "alpha" not in face_aes:
            face_kwargs.setdefault("alpha", 0.4)

        # The face of a dotplot follows only the top dot of each stack. That data is kept
        # in a separate dataset so the dist visual below still draws every dot.
        face_source = density
        if "dot" in kind_var_map:
            dist_stats = stats.get("dist", {})
            # A pre-computed Dataset cannot be forwarded as keyword arguments.
            # NOTE(review): in that case the top-only dots use the qds defaults.
            qds_kwargs = {} if isinstance(dist_stats, xr.Dataset) else dict(dist_stats)
            qds_kwargs["top_only"] = True
            top_only_qds = distribution[kind_var_map["dot"]].azstats.qds(
                dim=density_reduce, **qds_kwargs
            )
            face_source = xr.merge((top_only_qds, density), compat="override")

        fill_between_vars = [
            var_name
            for kind_i in ("kde", "ecdf", "dot")
            for var_name in kind_var_map.get(kind_i, [])
        ]
        if fill_between_vars:
            # Turn the (x, y) plot_axis into the x / y_top / y_bottom kwargs expected by
            # fill_between_y, with y_bottom padded to 0.
            face_density = (
                face_source[fill_between_vars]
                .rename(plot_axis="kwarg")
                .sel(kwarg=["x", "y"])
                .pad(kwarg=(0, 1), constant_values=0)
                .assign_coords(kwarg=["x", "y_top", "y_bottom"])
            )
            plot_collection.map(
                fill_between_y,
                "face",
                data=face_density,
                ignore_aes=face_ignore,
                **face_kwargs,
            )
        if "hist" in kind_var_map:
            plot_collection.map(
                hist,
                "face",
                data=face_source[kind_var_map["hist"]],
                ignore_aes=face_ignore,
                **face_kwargs,
            )

    # density
    if density_kwargs is not False:
        if "color" not in density_aes:
            density_kwargs.setdefault("color", "C0")
        dist_visual_funcs = {
            "kde": line_xy,
            "ecdf": ecdf_line,
            "hist": step_hist,
            "dot": scatter_xy,
        }
        for kind_i, visual_func in dist_visual_funcs.items():
            if kind_i in kind_var_map:
                plot_collection.map(
                    visual_func,
                    "dist",
                    data=density[kind_var_map[kind_i]],
                    ignore_aes=density_ignore,
                    **density_kwargs,
                )

    # rug
    rug_kwargs = get_visual_kwargs(visuals, "rug", False)
    if rug_kwargs is not False:
        if not isinstance(rug_kwargs, dict):
            rug_kwargs = {}
        _, rug_aes, rug_ignore = filter_aes(plot_collection, aes_by_visuals, "rug", sample_dims)
        if "color" not in rug_aes:
            rug_kwargs.setdefault("color", "B1")
        if "marker" not in rug_aes:
            rug_kwargs.setdefault("marker", "|")
        if "size" not in rug_aes:
            rug_kwargs.setdefault("size", 15)
        plot_collection.map(
            scatter_x,
            "rug",
            data=distribution,
            ignore_aes=rug_ignore,
            **rug_kwargs,
        )

    # With several models stacked vertically, scale the "y" aesthetic by the density
    # heights so that the per-model offsets are proportional to what is drawn.
    if (
        (density_kwargs is not False or face_kwargs is not False)
        and ("model" in distribution)
        and (plot_collection.coords is None)
    ):
        y_ds = plot_collection.get_aes_as_dataset("y")["mapping"]
        density_ys = _density_heights(density)
        density_ys_max = density_ys.max(
            [dim for dim in density_ys.dims if dim not in plot_collection.facet_dims]
        )
        y_ds = 0.15 * y_ds * density_ys_max
        plot_collection.update_aes_from_dataset("y", y_ds)

    # credible interval
    ci_kwargs = get_visual_kwargs(visuals, "credible_interval")
    if ci_kwargs is not False:
        ci_reduce, ci_active, ci_aes, ci_ignore = filter_aes_full(
            plot_collection, aes_by_visuals, "credible_interval", sample_dims
        )
        ci_stats_value = stats.get("credible_interval", {})
        if isinstance(ci_stats_value, xr.Dataset):
            ci = ci_stats_value
        else:
            ci_funcs = {"eti": arviz_stats.eti, "hdi": arviz_stats.hdi}
            if ci_kind not in ci_funcs:
                raise ValueError(
                    f"ci_kind must be one of {list(ci_funcs)}, got {ci_kind!r} instead."
                )
            ci_stats_kwargs = ci_stats_value.copy()
            ci_stats_kwargs["prob"] = ci_prob
            ci = _compute_func(
                ci_funcs[ci_kind],
                distribution,
                active_dims=ci_active,
                reduce_dims=ci_reduce,
                kwargs=ci_stats_kwargs,
            )
        if "color" not in ci_aes:
            ci_kwargs.setdefault("color", "B2")
        plot_collection.map(line_x, "credible_interval", data=ci, ignore_aes=ci_ignore, **ci_kwargs)

    # point estimate
    pe_kwargs = get_visual_kwargs(visuals, "point_estimate")
    pet_kwargs = get_visual_kwargs(visuals, "point_estimate_text")
    if (pe_kwargs is not False) or (pet_kwargs is not False):
        pe_reduce, pe_active, pe_aes, pe_ignore = filter_aes_full(
            plot_collection, aes_by_visuals, "point_estimate", sample_dims
        )
        pe_stats_value = stats.get("point_estimate", {})
        if isinstance(pe_stats_value, xr.Dataset):
            point = pe_stats_value
        else:
            pe_funcs = {
                "mean": arviz_stats.mean,
                "median": arviz_stats.median,
                "mode": arviz_stats.mode,
            }
            if point_estimate not in pe_funcs:
                raise ValueError(
                    f"point_estimate must be one of {list(pe_funcs)}, "
                    f"got {point_estimate!r} instead."
                )
            pe_stats_kwargs = pe_stats_value.copy()
            # Plot the unrounded value; any user-provided round_to is overridden.
            pe_stats_kwargs["round_to"] = "none"
            point = _compute_func(
                pe_funcs[point_estimate],
                distribution,
                active_dims=pe_active,
                reduce_dims=pe_reduce,
                kwargs=pe_stats_kwargs,
            )

        if pe_kwargs is not False:
            if "color" not in pe_aes:
                pe_kwargs.setdefault("color", "B2")
            plot_collection.map(
                scatter_x,
                "point_estimate",
                data=point,
                ignore_aes=pe_ignore,
                **pe_kwargs,
            )

        # point estimate text
        if pet_kwargs is not False:
            # Place the text slightly above 0, relative to the drawn density height if any.
            if density_kwargs is False and face_kwargs is False:
                point_y = xr.full_like(point, 0.05)
            else:
                density_ys = _density_heights(density)
                point_y = 0.1 * density_ys.max(
                    [dim for dim in density_ys.dims if dim not in point.dims]
                )
            point = xr.concat((point, point_y), dim="plot_axis").assign_coords(plot_axis=["x", "y"])
            _, pet_aes, pet_ignore = filter_aes(
                plot_collection, aes_by_visuals, "point_estimate_text", sample_dims
            )
            if "color" not in pet_aes:
                pet_kwargs.setdefault("color", "B2")
            pet_kwargs.setdefault("horizontal_align", "center")
            pet_kwargs.setdefault("point_label", "x")
            plot_collection.map(
                point_estimate_text,
                "point_estimate_text",
                data=point,
                point_estimate=point_estimate,
                ignore_aes=pet_ignore,
                **pet_kwargs,
            )

    # aesthetics
    title_kwargs = get_visual_kwargs(visuals, "title")
    if title_kwargs is not False:
        _, title_aes, title_ignore = filter_aes(
            plot_collection, aes_by_visuals, "title", sample_dims
        )
        if "color" not in title_aes:
            title_kwargs.setdefault("color", "B1")
        plot_collection.facet_map(
            labelled_title,
            "title",
            ignore_aes=title_ignore,
            subset_info=True,
            labeller=labeller,
            **title_kwargs,
        )

    if visuals.get("remove_axis", True) is not False:
        plot_collection.facet_map(
            remove_axis,
            store_artist=backend == "none",
            axis="y",
        )

    return plot_collection