"""Plot kde or histograms and values from MCMC samples."""
import warnings
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union, Sequence
from ..data import CoordSpec, InferenceData, convert_to_dataset
from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..sel_utils import xarray_var_iter
from ..utils import _var_names, get_coords
from .plot_utils import KwargSpec, get_plotting_function
def plot_trace(
    data: InferenceData,
    var_names: Optional[Sequence[str]] = None,
    filter_vars: Optional[str] = None,
    transform: Optional[Callable] = None,
    coords: Optional[CoordSpec] = None,
    divergences: Optional[str] = "auto",
    kind: Optional[str] = "trace",
    figsize: Optional[Tuple[float, float]] = None,
    rug: bool = False,
    lines: Optional[List[Tuple[str, CoordSpec, Any]]] = None,
    circ_var_names: Optional[List[str]] = None,
    circ_var_units: str = "radians",
    compact: bool = True,
    compact_prop: Optional[Union[str, Mapping[str, Any]]] = None,
    combined: bool = False,
    chain_prop: Optional[Union[str, Mapping[str, Any]]] = None,
    legend: bool = False,
    plot_kwargs: Optional[KwargSpec] = None,
    fill_kwargs: Optional[KwargSpec] = None,
    rug_kwargs: Optional[KwargSpec] = None,
    hist_kwargs: Optional[KwargSpec] = None,
    trace_kwargs: Optional[KwargSpec] = None,
    rank_kwargs: Optional[KwargSpec] = None,
    labeller=None,
    axes=None,
    backend: Optional[str] = None,
    backend_config: Optional[KwargSpec] = None,
    backend_kwargs: Optional[KwargSpec] = None,
    show: Optional[bool] = None,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values or rank plot.

    If `divergences` data is available in `sample_stats`, will plot the location of divergences as
    dashed vertical lines.

    Parameters
    ----------
    data: obj
        Any object that can be converted to an :class:`arviz.InferenceData` object
        Refer to documentation of :func:`arviz.convert_to_dataset` for details
    var_names: str or list of str, optional
        One or more variables to be plotted. Prefix the variables by ``~`` when you want
        to exclude them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    transform: callable, optional
        Function to transform data (defaults to None i.e. the identity function). It is
        applied to the posterior dataset after the ``coords`` selection.
    coords: dict of {str: slice or array_like}, optional
        Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`
    divergences: {"auto", "bottom", "top", None}, optional
        Plot location of divergences on the traceplots. The default, "auto", resolves to
        "top" when ``rug=True`` and to "bottom" otherwise. Divergences are silently
        disabled when `data` has no ``sample_stats`` group or no ``diverging`` variable.
    kind: {"trace", "rank_bars", "rank_vlines"}, optional
        Choose between plotting sampled values per iteration and rank plots.
    figsize: tuple of (float, float), optional
        If None, size is (12, variables * 2)
    rug: bool, optional
        If True adds a rugplot of samples. Defaults to False. Ignored for 2D KDE.
        Only affects continuous variables.
    lines: list of tuple of (str, dict, array_like), optional
        List of (var_name, {'coord': selection}, [line, positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace.
    circ_var_names : str or list of str, optional
        List of circular variables to account for when plotting KDE.
    circ_var_units : str
        Whether the variables in ``circ_var_names`` are in "degrees" or "radians".
    compact: bool, optional
        Plot multidimensional variables in a single plot.
    compact_prop: str or dict {str: array_like}, optional
        Defines the property name and the property values to distinguish different
        dimensions with compact=True.
        When compact=True it defaults to color, it is
        ignored otherwise.
    combined: bool, optional
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    chain_prop: str or dict {str: array_like}, optional
        Defines the property name and the property values to distinguish different chains.
        If compact=True it defaults to linestyle,
        otherwise it uses the color to distinguish
        different chains.
    legend: bool, optional
        Add a legend to the figure with the chain color code.
    plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs: dict, optional
        Extra keyword arguments passed to :func:`arviz.plot_dist`. Only affects continuous
        variables.
    trace_kwargs: dict, optional
        Extra keyword arguments passed to :meth:`matplotlib.axes.Axes.plot`
    rank_kwargs : dict, optional
        Extra keyword arguments passed to :func:`arviz.plot_rank`
    labeller : labeller instance, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    axes: axes, optional
        Matplotlib axes or bokeh figures.
    backend: {"matplotlib", "bokeh"}, optional
        Select plotting backend. Defaults to ``rcParams["plot.backend"]``.
    backend_config: dict, optional
        Currently specifies the bounds to use for bokeh axes. Defaults to value set in rcParams.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`.
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If `kind` is not one of "trace", "rank_vlines" or "rank_bars".

    Warns
    -----
    UserWarning
        If there are more variables to plot than the limit derived from
        ``rcParams["plot.max_subplots"]``; only the first ones are plotted.

    See Also
    --------
    plot_rank : Plot rank order statistics of chains.

    Examples
    --------
    Plot a subset variables and select them with partial naming

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta'), filter_vars="like", coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Display a rank plot instead of trace

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=["mu", "tau"], kind="rank_bars")

    Combine all chains into one distribution and select variables with regular expressions

    .. plot::
        :context: close-figs

        >>> az.plot_trace(
        >>>     data, var_names=('^theta'), filter_vars="regex", coords=coords, combined=True
        >>> )

    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    if kind not in {"trace", "rank_vlines", "rank_bars"}:
        raise ValueError("The value of kind must be either trace, rank_vlines or rank_bars.")

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    if divergences == "auto":
        divergences = "top" if rug else "bottom"

    # ``False`` is what the backend receives when no divergence data is drawn.
    divergence_data = False
    if divergences:
        try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging.transpose(
                "chain", "draw"
            )
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = None
        else:
            # Only "chain"/"draw" selections are forwarded to the divergence data;
            # other coordinates in ``coords`` are not used for it.
            divergence_data = get_coords(
                divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")}
            )

    coords_data = get_coords(convert_to_dataset(data, group="posterior"), coords)

    if transform is not None:
        coords_data = transform(coords_data)

    var_names = _var_names(var_names, coords_data, filter_vars)

    # In compact mode every non-sampling dimension is kept inside a single plotter.
    skip_dims = set(coords_data.dims) - {"chain", "draw"} if compact else set()

    plotters = list(
        xarray_var_iter(
            coords_data,
            var_names=var_names,
            combined=True,
            skip_dims=skip_dims,
            dim_order=["chain", "draw"],
        )
    )

    # The configured subplot budget is halved to obtain the variable limit
    # (presumably because each variable uses two subplots: distribution and trace).
    rc_max_subplots = rcParams["plot.max_subplots"]
    max_plots = len(plotters) if rc_max_subplots is None else max(rc_max_subplots // 2, 1)
    if len(plotters) > max_plots:
        # Report the actual rcParams value and the derived limit separately, so the
        # message does not present the halved limit as the configured value.
        warnings.warn(
            "rcParams['plot.max_subplots'] ({rc_max}) limits the trace plot to "
            "{max_plots} variables, which is smaller than the number "
            "of variables to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(
                rc_max=rc_max_subplots, max_plots=max_plots, len_plotters=len(plotters)
            ),
            UserWarning,
        )
        plotters = plotters[:max_plots]

    trace_plot_args = dict(
        # User kwargs
        data=coords_data,
        var_names=var_names,
        divergences=divergences,
        kind=kind,
        figsize=figsize,
        rug=rug,
        lines=lines,
        circ_var_names=circ_var_names,
        circ_var_units=circ_var_units,
        plot_kwargs=plot_kwargs,
        fill_kwargs=fill_kwargs,
        rug_kwargs=rug_kwargs,
        hist_kwargs=hist_kwargs,
        trace_kwargs=trace_kwargs,
        rank_kwargs=rank_kwargs,
        compact=compact,
        compact_prop=compact_prop,
        combined=combined,
        chain_prop=chain_prop,
        legend=legend,
        labeller=labeller,
        # Generated kwargs
        divergence_data=divergence_data,
        plotters=plotters,
        axes=axes,
        backend_config=backend_config,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # Dispatch to the backend-specific implementation of the trace plot.
    plot = get_plotting_function("plot_trace", "traceplot", backend)
    axes = plot(**trace_plot_args)

    return axes