"""Parallel coordinates plot showing posterior points with and without divergences marked."""
import numpy as np
from scipy.stats import rankdata
from ..data import convert_to_dataset
from ..labels import BaseLabeller
from ..sel_utils import xarray_to_ndarray
from ..rcparams import rcParams
from ..stats.stats_utils import stats_variance_2d as svar
from ..utils import _numba_var, _var_names, get_coords
from .plot_utils import get_plotting_function
def plot_parallel(
    data,
    var_names=None,
    filter_vars=None,
    coords=None,
    figsize=None,
    textsize=None,
    legend=True,
    colornd="k",
    colord="C1",
    shadend=0.025,
    labeller=None,
    ax=None,
    norm_method=None,
    backend=None,
    backend_config=None,
    backend_kwargs=None,
    show=None,
):
    """
    Plot parallel coordinates plot showing posterior points with and without divergences.

    Described by https://arxiv.org/abs/1709.01449

    Parameters
    ----------
    data: obj
        Any object that can be converted to an :class:`arviz.InferenceData` object.
        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
        Both the ``posterior`` group and the ``diverging`` variable of the
        ``sample_stats`` group are read.
    var_names: list of variable names
        Variables to be plotted, if `None` all variables are plotted. Can be used to change the
        order of the plotted variables. Prefix the variables by ``~`` when you want to exclude
        them from the plot. At least two variables (after expansion of their coordinates)
        must remain.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    coords: mapping, optional
        Coordinates of ``var_names`` to be plotted.
        Passed to :meth:`xarray.Dataset.sel`.
    figsize: tuple
        Figure size. If None it will be defined automatically.
    textsize: float
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled
        based on ``figsize``.
    legend: bool
        Flag for plotting legend (defaults to True).
    colornd: valid matplotlib color
        Color for non-divergent points. Defaults to 'k'.
    colord: valid matplotlib color
        Color for divergent points. Defaults to 'C1'.
    shadend: float
        Alpha blending value for non-divergent points, between 0 (invisible) and 1 (opaque).
        Defaults to .025.
    labeller : labeller instance, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot.
        Read the :ref:`label_guide` for more details and usage examples.
    ax: axes, optional
        Matplotlib axes or bokeh figures.
    norm_method: {None, "normal", "minmax", "rank"}, optional
        Method used to normalize each plotted variable independently:
        ``"normal"`` subtracts the mean and divides by the standard deviation,
        ``"minmax"`` rescales to the [0, 1] interval and ``"rank"`` replaces the
        values by their (average) ranks. Defaults to None (no normalization).
    backend: str, optional
        Select plotting backend {"matplotlib","bokeh"}. Defaults to
        ``rcParams["plot.backend"]``.
    backend_config: dict, optional
        Currently specifies the bounds to use for bokeh axes.
        Defaults to value set in ``rcParams``.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`.
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If fewer than two variables are selected for plotting, or if ``norm_method``
        is not one of the supported methods.

    See Also
    --------
    plot_pair : Plot a scatter, kde and/or hexbin matrix with (optional) marginals on the diagonal.
    plot_trace : Plot distribution (histogram or kernel density estimates) and sampled values
        or rank plot

    Examples
    --------
    Plot default parallel plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('centered_eight')
        >>> az.plot_parallel(data, var_names=["mu", "tau"])

    Plot parallel plot with normalization

    .. plot::
        :context: close-figs

        >>> az.plot_parallel(data, var_names=["theta", "tau", "mu"], norm_method="normal")

    Plot parallel plot with minmax

    .. plot::
        :context: close-figs

        >>> ax = az.plot_parallel(data, var_names=["theta", "tau", "mu"], norm_method="minmax")
        >>> ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

    Plot parallel plot with rank

    .. plot::
        :context: close-figs

        >>> ax = az.plot_parallel(data, var_names=["theta", "tau", "mu"], norm_method="rank")
        >>> ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
    """
    if coords is None:
        coords = {}
    if labeller is None:
        labeller = BaseLabeller()

    # Get diverging draws and combine chains
    divergent_data = convert_to_dataset(data, group="sample_stats")
    _, diverging_mask = xarray_to_ndarray(
        divergent_data,
        var_names=("diverging",),
        combined=True,
    )
    # Only one variable was requested, so drop the singleton axis/axes.
    diverging_mask = np.squeeze(diverging_mask)

    # Get posterior draws and combine chains
    posterior_data = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, posterior_data, filter_vars)
    var_names, _posterior = xarray_to_ndarray(
        get_coords(posterior_data, coords),
        var_names=var_names,
        combined=True,
        label_fun=labeller.make_label_vert,
    )
    if len(var_names) < 2:
        raise ValueError("Number of variables to be plotted must be 2 or greater.")

    if norm_method is not None:
        # Each row of ``_posterior`` is normalized on its own: statistics are taken
        # along axis 1 and re-expanded with ``np.newaxis`` so they broadcast per row.
        # A new array is built rather than assigning into ``_posterior`` in place,
        # because in-place assignment into an integer-typed array (e.g. discrete
        # variables) would silently truncate the normalized values.
        if norm_method == "normal":
            mean = np.mean(_posterior, axis=1)
            if _posterior.ndim <= 2:
                standard_deviation = np.sqrt(_numba_var(svar, np.var, _posterior, axis=1))
            else:
                standard_deviation = np.std(_posterior, axis=1)
            _posterior = (_posterior - mean[:, np.newaxis]) / standard_deviation[:, np.newaxis]
        elif norm_method == "minmax":
            min_elem = np.min(_posterior, axis=1)
            max_elem = np.max(_posterior, axis=1)
            value_range = max_elem - min_elem
            # NOTE: a constant row has a zero range and therefore normalizes to NaN.
            _posterior = (_posterior - min_elem[:, np.newaxis]) / value_range[:, np.newaxis]
        elif norm_method == "rank":
            _posterior = rankdata(_posterior, axis=1, method="average")
        else:
            raise ValueError(f"{norm_method} is not supported. Use normal, minmax or rank.")

    parallel_kwargs = dict(
        ax=ax,
        colornd=colornd,
        colord=colord,
        shadend=shadend,
        diverging_mask=diverging_mask,
        posterior=_posterior,
        textsize=textsize,
        var_names=var_names,
        legend=legend,
        figsize=figsize,
        backend_kwargs=backend_kwargs,
        backend_config=backend_config,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_parallel", "parallelplot", backend)
    ax = plot(**parallel_kwargs)
    return ax