# Source code for arviz.plots.essplot

"""Plot quantile or local effective sample sizes."""

import numpy as np
import xarray as xr

from ..data import convert_to_dataset
from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..sel_utils import xarray_var_iter
from ..stats import ess
from ..utils import _var_names, get_coords
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function


def plot_ess(
    idata,
    var_names=None,
    filter_vars=None,
    kind="local",
    relative=False,
    coords=None,
    figsize=None,
    grid=None,
    textsize=None,
    rug=False,
    rug_kind="diverging",
    n_points=20,
    extra_methods=False,
    min_ess=400,
    labeller=None,
    ax=None,
    extra_kwargs=None,
    text_kwargs=None,
    hline_kwargs=None,
    rug_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    **kwargs,
):
    r"""Generate quantile, local, or evolution ESS plots.

    The local and the quantile ESS plots are recommended for checking
    that there are enough samples for all the explored regions of the
    parameter space. Checking local and quantile ESS is particularly
    relevant when working with HDI intervals as opposed to ESS bulk,
    which is suitable for point estimates.

    Parameters
    ----------
    idata : InferenceData
        Any object that can be converted to an :class:`arviz.InferenceData` object
        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
    var_names : list of str, optional
        Variables to be plotted. Prefix the variables by ``~`` when you want to exclude
        them from the plot. See :ref:`this section <common_var_names>` for usage examples.
    filter_vars : {None, "like", "regex"}, default None
        If `None` (default), interpret `var_names` as the real variables names. If "like",
        interpret `var_names` as substrings of the real variables names. If "regex",
        interpret `var_names` as regular expressions on the real variables names. See
        :ref:`this section <common_filter_vars>` for usage examples.
    kind : {"local", "quantile", "evolution"}, default "local"
        Specify the kind of plot:

        * The ``kind="local"`` argument generates the ESS' local efficiency for
          estimating small-interval probability of a desired posterior.
        * The ``kind="quantile"`` argument generates the ESS' local efficiency for
          estimating quantiles of a desired posterior.
        * The ``kind="evolution"`` argument generates the estimated ESS' with
          increasing number of iterations of a desired posterior.

    relative : bool, default False
        Show relative ess in plot ``ress = ess / N``.
    coords : dict, optional
        Coordinates of `var_names` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
        ``chain`` and ``draw`` are not allowed as keys.
        See :ref:`this section <common_coords>` for usage examples.
    figsize : (float, float), optional
        Figure size. If ``None`` it will be defined automatically.
    grid : tuple, optional
        Number of rows and columns. By default, the rows and columns are
        automatically inferred. See :ref:`this section <common_grid>` for usage examples.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If ``None`` it will be
        autoscaled based on `figsize`.
    rug : bool, default False
        Add a `rug plot <https://en.wikipedia.org/wiki/Rug_plot>`_ for a specific subset
        of values.
    rug_kind : str, default "diverging"
        Variable in sample stats to use as rug mask. Must be a boolean variable.
    n_points : int, default 20
        Number of points for which to plot their quantile/local ess or number of subsets
        in the evolution plot. Must be at least 1.
    extra_methods : bool, default False
        Plot mean and sd ESS as horizontal lines. Not taken into account if
        ``kind = 'evolution'``.
    min_ess : int, default 400
        Minimum number of ESS desired. If ``relative=True`` the line is plotted at
        ``min_ess / n_samples`` for local and quantile kinds and as a curve following
        the ``min_ess / n`` dependency in evolution kind.
    labeller : Labeller, optional
        Class providing the method ``make_label_vert`` to generate the labels in the
        plot titles. Read the :ref:`label_guide` for more details and usage examples.
    ax : 2D array-like of matplotlib_axes or bokeh_figure, optional
        A 2D array of locations into which to plot the ESS. If not supplied, ArviZ will
        create its own array of plot areas (and return it).
    extra_kwargs : dict, optional
        If evolution plot, `extra_kwargs` is used to plot ess tail and differentiate it
        from ess bulk. Otherwise, passed to extra methods lines.
    text_kwargs : dict, optional
        Only taken into account when ``extra_methods=True``. kwargs passed to ax.annotate
        for extra methods lines labels. It accepts the additional
        key ``x`` to set ``xy=(text_kwargs["x"], mcse)``
    hline_kwargs : dict, optional
        kwargs passed to :func:`~matplotlib.axes.Axes.axhline` or to
        :class:`~bokeh.models.Span` depending on the backend for the horizontal minimum
        ESS line. For relative ess evolution plots the kwargs are passed to
        :func:`~matplotlib.axes.Axes.plot` or to :class:`~bokeh.plotting.figure.line`
    rug_kwargs : dict
        kwargs passed to rug plot.
    backend : {"matplotlib", "bokeh"}, optional
        Select plotting backend. Defaults to ``rcParams["plot.backend"]``.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.
    **kwargs
        Passed as-is to :meth:`mpl:matplotlib.axes.Axes.hist` or
        :meth:`mpl:matplotlib.axes.Axes.plot` function depending on the value of `kind`.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figure

    Raises
    ------
    ValueError
        If `kind` is not one of the valid kinds, if `coords` selects on ``chain`` or
        ``draw``, or if `n_points` is smaller than 1.

    See Also
    --------
    ess : Calculate estimate of the effective sample size.

    References
    ----------
    .. [1] Vehtari et al. (2019). Rank-normalization, folding, and localization:
        An improved Rhat for assessing convergence of MCMC
        https://arxiv.org/abs/1903.08008

    Examples
    --------
    Plot local ESS.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> coords = {"school": ["Choate", "Lawrenceville"]}
        >>> az.plot_ess(
        ...     idata, kind="local", var_names=["mu", "theta"], coords=coords
        ... )

    Plot ESS evolution as the number of samples increases. When the model is
    converging properly, both lines in this plot should be roughly linear.

    .. plot::
        :context: close-figs

        >>> az.plot_ess(
        ...     idata, kind="evolution", var_names=["mu", "theta"], coords=coords
        ... )

    Customize local ESS plot to look like reference paper.

    .. plot::
        :context: close-figs

        >>> az.plot_ess(
        ...     idata, kind="local", var_names=["mu"], drawstyle="steps-mid", color="k",
        ...     linestyle="-", marker=None, rug=True, rug_kwargs={"color": "r"}
        ... )

    Customize ESS evolution plot to look like reference paper.

    .. plot::
        :context: close-figs

        >>> extra_kwargs = {"color": "lightsteelblue"}
        >>> az.plot_ess(
        ...     idata, kind="evolution", var_names=["mu"],
        ...     color="royalblue", extra_kwargs=extra_kwargs
        ... )
    """
    valid_kinds = ("local", "quantile", "evolution")
    kind = kind.lower()
    if kind not in valid_kinds:
        raise ValueError(f"Invalid kind, kind must be one of {valid_kinds} not {kind}")
    # Every kind divides by or builds a grid from n_points; fail early with a clear
    # message instead of a ZeroDivisionError / empty-concat error further down.
    if n_points < 1:
        raise ValueError(f"n_points must be a positive integer, not {n_points}")

    if coords is None:
        coords = {}
    if "chain" in coords or "draw" in coords:
        raise ValueError("chain and draw are invalid coordinates for this kind of plot")

    if labeller is None:
        labeller = BaseLabeller()

    # The mean/sd reference lines are never drawn on evolution plots.
    extra_methods = False if kind == "evolution" else extra_methods

    data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
    var_names = _var_names(var_names, data, filter_vars)

    n_samples = data.sizes["draw"] * data.sizes["chain"]

    ess_tail_dataset = None
    if kind == "quantile":
        xdata, ylabel, ess_dataset = _quantile_ess(data, var_names, relative, n_points)
    elif kind == "local":
        xdata, ylabel, ess_dataset = _local_ess(data, var_names, relative, n_points)
    else:
        xdata, ylabel, ess_dataset, ess_tail_dataset = _evolution_ess(
            data, var_names, relative, n_points
        )

    plotters = filter_plotters_list(
        list(xarray_var_iter(ess_dataset, var_names=var_names, skip_dims={"ess_dim"})),
        "plot_ess",
    )
    length_plotters = len(plotters)
    rows, cols = default_grid(length_plotters, grid=grid)

    mean_ess = None
    sd_ess = None
    if extra_methods:
        mean_ess = ess(data, var_names=var_names, method="mean", relative=relative)
        sd_ess = ess(data, var_names=var_names, method="sd", relative=relative)

    essplot_kwargs = dict(
        ax=ax,
        plotters=plotters,
        xdata=xdata,
        ess_tail_dataset=ess_tail_dataset,
        mean_ess=mean_ess,
        sd_ess=sd_ess,
        idata=idata,
        data=data,
        kind=kind,
        extra_methods=extra_methods,
        textsize=textsize,
        rows=rows,
        cols=cols,
        figsize=figsize,
        kwargs=kwargs,
        extra_kwargs=extra_kwargs,
        text_kwargs=text_kwargs,
        n_samples=n_samples,
        relative=relative,
        min_ess=min_ess,
        labeller=labeller,
        ylabel=ylabel,
        rug=rug,
        rug_kind=rug_kind,
        rug_kwargs=rug_kwargs,
        hline_kwargs=hline_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_ess", "essplot", backend)
    ax = plot(**essplot_kwargs)
    return ax


def _quantile_ess(data, var_names, relative, n_points):
    """Compute quantile ESS at ``n_points`` evenly spaced probabilities.

    Returns ``(xdata, ylabel, ess_dataset)`` where ``xdata`` holds the probabilities
    (from ``1 / n_points`` to ``1 - 1 / n_points``) and ``ess_dataset`` stacks one ESS
    result per probability along a new ``ess_dim`` dimension.
    """
    probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points)
    ess_dataset = xr.concat(
        [
            ess(data, var_names=var_names, relative=relative, method="quantile", prob=p)
            for p in probs
        ],
        dim="ess_dim",
    )
    return probs, "{} for quantiles", ess_dataset


def _local_ess(data, var_names, relative, n_points):
    """Compute local ESS on ``n_points`` equal-width probability intervals covering [0, 1].

    Returns ``(xdata, ylabel, ess_dataset)`` where ``xdata`` holds the lower bound of
    each interval and ``ess_dataset`` stacks one ESS result per interval along a new
    ``ess_dim`` dimension.
    """
    # Take both bounds from one grid of edges: the lower bounds are identical to
    # np.linspace(0, 1, n_points, endpoint=False), and the last upper bound is exactly
    # 1.0 instead of the rounded floating point sum ``p + 1 / n_points``.
    edges = np.linspace(0, 1, n_points + 1)
    lower, upper = edges[:-1], edges[1:]
    ess_dataset = xr.concat(
        [
            ess(data, var_names=var_names, relative=relative, method="local", prob=[lo, hi])
            for lo, hi in zip(lower, upper)
        ],
        dim="ess_dim",
    )
    return lower, "{} for small intervals", ess_dataset


def _evolution_ess(data, var_names, relative, n_points):
    """Compute bulk and tail ESS on ``n_points`` growing prefixes of every chain.

    Returns ``(xdata, ylabel, ess_bulk_dataset, ess_tail_dataset)`` where ``xdata``
    is the total number of samples (draws kept per chain times number of chains)
    used for each subset.
    """
    n_draws = data.sizes["draw"]
    n_chains = data.sizes["chain"]
    # Number of leading draws kept from each chain per subset; never zero so that no
    # subset is empty when n_draws < n_points.
    draw_counts = np.maximum(np.linspace(n_draws // n_points, n_draws, n_points, dtype=int), 1)
    # Positional slicing: label-based ``sel(draw=slice(stop))`` includes the stop label
    # (one extra draw) and assumes contiguous integer draw labels.
    subsets = [data.isel(draw=slice(None, int(count))) for count in draw_counts]
    ess_bulk_dataset = xr.concat(
        [
            ess(subset, var_names=var_names, relative=relative, method="bulk")
            for subset in subsets
        ],
        dim="ess_dim",
    )
    ess_tail_dataset = xr.concat(
        [
            ess(subset, var_names=var_names, relative=relative, method="tail")
            for subset in subsets
        ],
        dim="ess_dim",
    )
    return draw_counts * n_chains, "{}", ess_bulk_dataset, ess_tail_dataset