Source code for arviz.plots.lmplot

"""Plot regression figure."""

import warnings
from numbers import Integral
from itertools import repeat

import xarray as xr
import numpy as np
from xarray.core.dataarray import DataArray

from ..sel_utils import xarray_var_iter
from ..rcparams import rcParams
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function


def _repeat_flatten_list(lst, n):
    return [item for sublist in repeat(lst, n) for item in sublist]


[docs] def plot_lm( y, idata=None, x=None, y_model=None, y_hat=None, num_samples=50, kind_pp="samples", kind_model="lines", xjitter=False, plot_dim=None, backend=None, y_kwargs=None, y_hat_plot_kwargs=None, y_hat_fill_kwargs=None, y_model_plot_kwargs=None, y_model_fill_kwargs=None, y_model_mean_kwargs=None, backend_kwargs=None, show=None, figsize=None, textsize=None, axes=None, legend=True, grid=True, ): """Posterior predictive and mean plots for regression-like data. Parameters ---------- y : str or DataArray or ndarray If str, variable name from ``observed_data``. idata : InferenceData, Optional Optional only if ``y`` is not str. x : str, tuple of strings, DataArray or array-like, optional If str or tuple, variable name from ``constant_data``. If ndarray, could be 1D, or 2D for multiple plots. If None, coords name of ``y`` (``y`` should be DataArray). y_model : str or Sequence, Optional If str, variable name from ``posterior``. Its dimensions should be same as ``y`` plus added chains and draws. y_hat : str, Optional If str, variable name from ``posterior_predictive``. Its dimensions should be same as ``y`` plus added chains and draws. num_samples : int, Optional, Default 50 Significant if ``kind_pp`` is "samples" or ``kind_model`` is "lines". Number of samples to be drawn from posterior predictive or kind_pp : {"samples", "hdi"}, Default "samples" Options to visualize uncertainty in data. kind_model : {"lines", "hdi"}, Default "lines" Options to visualize uncertainty in mean of the data. plot_dim : str, Optional Necessary if ``y`` is multidimensional. backend : str, Optional Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". y_kwargs : dict, optional Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib and :meth:`bokeh:bokeh.plotting.Figure.circle` in bokeh y_hat_plot_kwargs : dict, optional Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib and :meth:`bokeh:bokeh.plotting.Figure.circle` in bokeh y_hat_fill_kwargs : dict, optional Passed to :func:`arviz.plot_hdi` y_model_plot_kwargs : dict, optional Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib and :meth:`bokeh:bokeh.plotting.Figure.line` in bokeh y_model_fill_kwargs : dict, optional Significant if ``kind_model`` is "hdi". Passed to :func:`arviz.plot_hdi` y_model_mean_kwargs : dict, optional Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib and :meth:`bokeh:bokeh.plotting.Figure.line` in bokeh backend_kwargs : dict, optional These are kwargs specific to the backend being used. Passed to :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`. figsize : (float, float), optional Figure size. If None it will be defined automatically. textsize : float, optional Text size scaling factor for labels, titles and lines. If None it will be autoscaled based on ``figsize``. axes : 2D numpy array-like of matplotlib_axes or bokeh_figures, optional A 2D array of locations into which to plot the densities. If not supplied, Arviz will create its own array of plot areas (and return it). show : bool, optional Call backend show function. legend : bool, optional Add legend to figure. By default True. grid : bool, optional Add grid to figure. By default True. Returns ------- axes: matplotlib axes or bokeh figures See Also -------- plot_ts : Plot timeseries data plot_ppc : Plot for posterior/prior predictive checks Examples -------- Plot regression default plot .. plot:: :context: close-figs >>> import arviz as az >>> import numpy as np >>> import xarray as xr >>> idata = az.load_arviz_data('regression1d') >>> x = xr.DataArray(np.linspace(0, 1, 100)) >>> idata.posterior["y_model"] = idata.posterior["intercept"] + idata.posterior["slope"]*x >>> az.plot_lm(idata=idata, y="y", x=x) Plot regression data and mean uncertainty .. plot:: :context: close-figs >>> az.plot_lm(idata=idata, y="y", x=x, y_model="y_model") Plot regression data and mean uncertainty in hdi form .. plot:: :context: close-figs >>> az.plot_lm( ... idata=idata, y="y", x=x, y_model="y_model", kind_pp="hdi", kind_model="hdi" ... ) Plot regression data for multi-dimensional y using plot_dim .. plot:: :context: close-figs >>> data = az.from_dict( ... observed_data = { "y": np.random.normal(size=(5, 7)) }, ... posterior_predictive = {"y": np.random.randn(4, 1000, 5, 7) / 2}, ... dims={"y": ["dim1", "dim2"]}, ... coords={"dim1": range(5), "dim2": range(7)} ... ) >>> az.plot_lm(idata=data, y="y", plot_dim="dim1") """ if kind_pp not in ("samples", "hdi"): raise ValueError("kind_ppc should be either samples or hdi") if kind_model not in ("lines", "hdi"): raise ValueError("kind_model should be either lines or hdi") if y_hat is None and isinstance(y, str): y_hat = y if isinstance(y, str): y = idata.observed_data[y] elif not isinstance(y, DataArray): y = xr.DataArray(y) if len(y.dims) > 1 and plot_dim is None: raise ValueError("Argument plot_dim is needed in case of multidimensional data") x_var_names = None if isinstance(x, str): x = idata.constant_data[x] x_skip_dims = x.dims elif isinstance(x, tuple): x_var_names = x x = idata.constant_data x_skip_dims = x.dims elif isinstance(x, DataArray): x_skip_dims = x.dims elif x is None: x = y.coords[y.dims[0]] if plot_dim is None else y.coords[plot_dim] x_skip_dims = x.dims else: x = xr.DataArray(x) x_skip_dims = [x.dims[-1]] # If posterior is present in idata and y_hat is there, get its values if isinstance(y_model, str): if "posterior" not in idata.groups(): warnings.warn("Posterior not found in idata", UserWarning) y_model = None elif hasattr(idata.posterior, y_model): y_model = idata.posterior[y_model] else: warnings.warn("y_model not found in posterior", UserWarning) y_model = None # If posterior_predictive is present in idata and y_hat is there, get its values if isinstance(y_hat, str): if "posterior_predictive" not in idata.groups(): warnings.warn("posterior_predictive not found in idata", UserWarning) y_hat = None elif hasattr(idata.posterior_predictive, y_hat): y_hat = idata.posterior_predictive[y_hat] else: warnings.warn("y_hat not found in posterior_predictive", UserWarning) y_hat = None # Check if num_pp_smaples is valid and generate num_pp_smaples number of random indexes. # Only needed if kind_pp="samples" or kind_model="lines". Not req for plotting hdi pp_sample_ix = None if (y_hat is not None and kind_pp == "samples") or ( y_model is not None and kind_model == "lines" ): if y_hat is not None: total_pp_samples = y_hat.sizes["chain"] * y_hat.sizes["draw"] else: total_pp_samples = y_model.sizes["chain"] * y_model.sizes["draw"] if ( not isinstance(num_samples, Integral) or num_samples < 1 or num_samples > total_pp_samples ): raise TypeError(f"`num_samples` must be an integer between 1 and {total_pp_samples}.") pp_sample_ix = np.random.choice(total_pp_samples, size=num_samples, replace=False) # crucial step in case of multidim y if plot_dim is None: skip_dims = list(y.dims) elif isinstance(plot_dim, str): skip_dims = [plot_dim] elif isinstance(plot_dim, tuple): skip_dims = list(plot_dim) # Generate x axis plotters. x = filter_plotters_list( plotters=list( xarray_var_iter( x, var_names=x_var_names, skip_dims=set(x_skip_dims), combined=True, ) ), plot_kind="plot_lm", ) # Generate y axis plotters y = filter_plotters_list( plotters=list( xarray_var_iter( y, skip_dims=set(skip_dims), combined=True, ) ), plot_kind="plot_lm", ) # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs len_y = len(y) len_x = len(x) length_plotters = len_x * len_y y = _repeat_flatten_list(y, len_x) x = _repeat_flatten_list(x, len_y) # Filter out the required values to generate plotters if y_hat is not None: if kind_pp == "samples": y_hat = y_hat.stack(__sample__=("chain", "draw"))[..., pp_sample_ix] skip_dims += ["__sample__"] y_hat = [ tup for _, tup in zip( range(len_y), xarray_var_iter( y_hat, skip_dims=set(skip_dims), combined=True, ), ) ] y_hat = _repeat_flatten_list(y_hat, len_x) # Filter out the required values to generate plotters if y_model is not None: if kind_model == "lines": y_model = y_model.stack(__sample__=("chain", "draw"))[..., pp_sample_ix] y_model = [ tup for _, tup in zip( range(len_y), xarray_var_iter( y_model, skip_dims=set(y_model.dims), combined=True, ), ) ] y_model = _repeat_flatten_list(y_model, len_x) rows, cols = default_grid(length_plotters) lmplot_kwargs = dict( x=x, y=y, y_model=y_model, y_hat=y_hat, num_samples=num_samples, kind_pp=kind_pp, kind_model=kind_model, length_plotters=length_plotters, xjitter=xjitter, rows=rows, cols=cols, y_kwargs=y_kwargs, y_hat_plot_kwargs=y_hat_plot_kwargs, y_hat_fill_kwargs=y_hat_fill_kwargs, y_model_plot_kwargs=y_model_plot_kwargs, y_model_fill_kwargs=y_model_fill_kwargs, y_model_mean_kwargs=y_model_mean_kwargs, backend_kwargs=backend_kwargs, show=show, figsize=figsize, textsize=textsize, axes=axes, legend=legend, grid=grid, ) if backend is None: backend = rcParams["plot.backend"] backend = backend.lower() plot = get_plotting_function("plot_lm", "lmplot", backend) ax = plot(**lmplot_kwargs) return ax