# Source code for arviz.plots.tsplot

"""Plot timeseries data."""
import warnings
import numpy as np

from ..sel_utils import xarray_var_iter
from ..rcparams import rcParams
from .plot_utils import default_grid, get_plotting_function


def plot_ts(
    idata,
    y,
    x=None,
    y_hat=None,
    y_holdout=None,
    y_forecasts=None,
    x_holdout=None,
    plot_dim=None,
    holdout_dim=None,
    num_samples=100,
    backend=None,
    backend_kwargs=None,
    y_kwargs=None,
    y_hat_plot_kwargs=None,
    y_mean_plot_kwargs=None,
    vline_kwargs=None,
    textsize=None,
    figsize=None,
    legend=True,
    axes=None,
    show=None,
):
    """Plot timeseries data.

    Parameters
    ----------
    idata : InferenceData
        :class:`arviz.InferenceData` object.
    y : str
        Variable name from ``observed_data``. Values to be plotted on y-axis before holdout.
    x : str or tuple of str, optional
        Variable name(s) from ``constant_data``. Values to be plotted on x-axis before
        holdout. If None, coords of ``y`` dims is chosen.
    y_hat : str, optional
        Variable name from ``posterior_predictive``. Assumed to be of shape
        ``(chain, draw, *y_dims)``. If None and ``y`` is a string, ``y`` is used.
    y_holdout : str, optional
        Variable name from ``observed_data``. It represents the observed data after the
        holdout period. Useful while testing the model, when you want to compare
        observed test data with predictions/forecasts.
    y_forecasts : str, optional
        Variable name from ``posterior_predictive``. It represents forecasts
        (posterior predictive) values after holdout period. Useful to compare
        observed vs predictions/forecasts. Assumed shape ``(chain, draw, *shape)``.
        If None and ``y_holdout`` is a string, ``y_holdout`` is used.
    x_holdout : str or tuple of str, optional
        Variable name(s) from ``constant_data``. If None, coords of ``y_holdout`` or
        coords of ``y_forecasts`` (either of the two available) is chosen.
    plot_dim : str or sequence of str, optional
        Should be present in ``y.dims``. Necessary for selection of ``x`` if ``x`` is
        None and ``y`` is multidimensional.
    holdout_dim : str or sequence of str, optional
        Should be present in ``y_holdout.dims`` or ``y_forecasts.dims``. Necessary to
        choose ``x_holdout`` if ``x_holdout`` is None and if ``y_holdout`` or
        ``y_forecasts`` is multidimensional.
    num_samples : int, default 100
        Number of posterior predictive samples drawn from ``y_hat`` and ``y_forecasts``.
        Samples are drawn without replacement, so this must not exceed the number of
        available ``chain * draw`` samples.
    backend : {"matplotlib", "bokeh"}, optional
        Select plotting backend. If None, ``rcParams["plot.backend"]`` is used.
    y_kwargs : dict, optional
        Passed to :meth:`matplotlib.axes.Axes.plot` in matplotlib.
    y_hat_plot_kwargs : dict, optional
        Passed to :meth:`matplotlib.axes.Axes.plot` in matplotlib.
    y_mean_plot_kwargs : dict, optional
        Passed to :meth:`matplotlib.axes.Axes.plot` in matplotlib.
    vline_kwargs : dict, optional
        Passed to :meth:`matplotlib.axes.Axes.axvline` in matplotlib.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used. Passed to
        :func:`matplotlib.pyplot.subplots`.
    figsize : tuple, optional
        Figure size. If None, it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None, it will be
        autoscaled based on ``figsize``.
    legend : bool, default True
        Forwarded to the backend plotting function; expected to toggle the legend.
    axes : optional
        Forwarded to the backend plotting function; expected to be existing
        axes/figures to draw on.
    show : bool, optional
        Forwarded to the backend plotting function; expected to trigger the backend
        show call.

    Returns
    -------
    axes: matplotlib axes or bokeh figures.

    Raises
    ------
    ValueError
        If ``y`` is multidimensional and ``plot_dim`` is None, or if ``y_holdout`` is
        multidimensional and ``holdout_dim`` is None.

    Warns
    -----
    UserWarning
        If ``y_hat`` or ``y_forecasts`` is requested but the ``posterior_predictive``
        group, or the variable inside it, is missing. The corresponding layer is
        then skipped.

    See Also
    --------
    plot_lm : Posterior predictive and mean plots for regression-like data.
    plot_ppc : Plot for posterior/prior predictive checks.

    Examples
    --------
    Plot timeseries default plot

    .. plot::
        :context: close-figs

        >>> import numpy as np
        >>> import arviz as az
        >>> nchains, ndraws = (4, 500)
        >>> obs_data = {
        ...     "y": 2 * np.arange(1, 9) + 3,
        ...     "z": 2 * np.arange(8, 12) + 3,
        ... }
        >>> posterior_predictive = {
        ...     "y": np.random.normal(
        ...         (obs_data["y"] * 1.2) - 3, size=(nchains, ndraws, len(obs_data["y"]))
        ...     ),
        ...     "z": np.random.normal(
        ...         (obs_data["z"] * 1.2) - 3, size=(nchains, ndraws, len(obs_data["z"]))
        ...     ),
        ... }
        >>> idata = az.from_dict(
        ...     observed_data=obs_data,
        ...     posterior_predictive=posterior_predictive,
        ...     coords={"obs_dim": np.arange(1, 9), "pred_dim": np.arange(8, 12)},
        ...     dims={"y": ["obs_dim"], "z": ["pred_dim"]},
        ... )
        >>> ax = az.plot_ts(idata=idata, y="y", y_holdout="z")

    Plot timeseries multidim plot

    .. plot::
        :context: close-figs

        >>> ndim1, ndim2 = (5, 7)
        >>> data = {
        ...     "y": np.random.normal(size=(ndim1, ndim2)),
        ...     "z": np.random.normal(size=(ndim1, ndim2)),
        ... }
        >>> posterior_predictive = {
        ...     "y": np.random.randn(nchains, ndraws, ndim1, ndim2),
        ...     "z": np.random.randn(nchains, ndraws, ndim1, ndim2),
        ... }
        >>> const_data = {"x": np.arange(1, 6), "x_pred": np.arange(5, 10)}
        >>> idata = az.from_dict(
        ...     observed_data=data,
        ...     posterior_predictive=posterior_predictive,
        ...     constant_data=const_data,
        ...     dims={
        ...         "y": ["dim1", "dim2"],
        ...         "z": ["holdout_dim1", "holdout_dim2"],
        ...     },
        ...     coords={
        ...         "dim1": range(ndim1),
        ...         "dim2": range(ndim2),
        ...         "holdout_dim1": range(ndim1 - 1, ndim1 + 4),
        ...         "holdout_dim2": range(ndim2 - 1, ndim2 + 6),
        ...     },
        ... )
        >>> az.plot_ts(
        ...     idata=idata,
        ...     y="y",
        ...     plot_dim="dim1",
        ...     y_holdout="z",
        ...     holdout_dim="holdout_dim1",
        ... )

    """
    # Assign default values if none is provided: the predictive variables default
    # to the same names as their observed counterparts.
    y_hat = y if y_hat is None and isinstance(y, str) else y_hat
    y_forecasts = y_holdout if y_forecasts is None and isinstance(y_holdout, str) else y_forecasts

    if isinstance(y, str):
        y = idata.observed_data[y]

    if isinstance(y_holdout, str):
        y_holdout = idata.observed_data[y_holdout]

    if len(y.dims) > 1 and plot_dim is None:
        raise ValueError("Argument plot_dim is needed in case of multidimensional data")

    if y_holdout is not None and len(y_holdout.dims) > 1 and holdout_dim is None:
        raise ValueError("Argument holdout_dim is needed in case of multidimensional data")

    # Assigning values to x
    x_var_names = None
    if isinstance(x, str):
        x = idata.constant_data[x]
    elif isinstance(x, tuple):
        # Several x variables: iterate over the whole group, filtered by name.
        x_var_names = x
        x = idata.constant_data
    elif x is None:
        if plot_dim is None:
            x = y.coords[y.dims[0]]
        else:
            x = y.coords[plot_dim]

    # If posterior_predictive is present in idata and y_hat is there, get its values
    if isinstance(y_hat, str):
        if "posterior_predictive" not in idata.groups():
            warnings.warn("posterior_predictive not found in idata", UserWarning)
            y_hat = None
        elif hasattr(idata.posterior_predictive, y_hat):
            y_hat = idata.posterior_predictive[y_hat]
        else:
            warnings.warn("y_hat not found in posterior_predictive", UserWarning)
            y_hat = None

    # If posterior_predictive is present in idata and y_forecasts is there, get its values
    x_holdout_var_names = None
    if isinstance(y_forecasts, str):
        if "posterior_predictive" not in idata.groups():
            warnings.warn("posterior_predictive not found in idata", UserWarning)
            y_forecasts = None
        elif hasattr(idata.posterior_predictive, y_forecasts):
            y_forecasts = idata.posterior_predictive[y_forecasts]
        else:
            warnings.warn("y_forecasts not found in posterior_predictive", UserWarning)
            y_forecasts = None

    has_holdout = y_holdout is not None or y_forecasts is not None

    # Assign values to x_holdout.
    if has_holdout:
        if x_holdout is None:
            if holdout_dim is None:
                if y_holdout is None:
                    x_holdout = y_forecasts.coords[y_forecasts.dims[-1]]
                else:
                    x_holdout = y_holdout.coords[y_holdout.dims[-1]]
            elif y_holdout is None:
                x_holdout = y_forecasts.coords[holdout_dim]
            else:
                x_holdout = y_holdout.coords[holdout_dim]
        elif isinstance(x_holdout, str):
            x_holdout = idata.constant_data[x_holdout]
        elif isinstance(x_holdout, tuple):
            x_holdout_var_names = x_holdout
            x_holdout = idata.constant_data

    # Choose dims to generate y plotters
    if plot_dim is None:
        skip_dims = list(y.dims)
    elif isinstance(plot_dim, str):
        skip_dims = [plot_dim]
    else:
        # Any other sequence of dimension names (tuple, list, ...).
        skip_dims = list(plot_dim)

    # Choose dims to generate y_holdout plotters
    skip_holdout_dims = []  # only consumed when holdout data is present
    if holdout_dim is None:
        if y_holdout is not None:
            skip_holdout_dims = list(y_holdout.dims)
        elif y_forecasts is not None:
            skip_holdout_dims = list(y_forecasts.dims)
    elif isinstance(holdout_dim, str):
        skip_holdout_dims = [holdout_dim]
    else:
        skip_holdout_dims = list(holdout_dim)

    # Compulsory plotters
    y_plotters = list(
        xarray_var_iter(
            y,
            skip_dims=set(skip_dims),
            combined=True,
        )
    )

    # Compulsory plotters
    x_plotters = list(
        xarray_var_iter(
            x,
            var_names=x_var_names,
            skip_dims=set(x.dims),
            combined=True,
        )
    )

    # Necessary when multidim y
    # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
    len_y = len(y_plotters)
    len_x = len(x_plotters)
    length_plotters = len_x * len_y
    # TODO: Incompatible types in assignment (expression has type "ndarray[Any, dtype[Any]]",
    # TODO: variable has type "List[Any]")  [assignment]
    y_plotters = np.tile(np.array(y_plotters, dtype=object), (len_x, 1))  # type: ignore[assignment]
    x_plotters = np.tile(np.array(x_plotters, dtype=object), (len_y, 1))  # type: ignore[assignment]

    # Generate plotters for all the available data
    y_mean_plotters = None
    y_hat_plotters = None
    if y_hat is not None:
        total_samples = y_hat.sizes["chain"] * y_hat.sizes["draw"]
        # Sampling without replacement: num_samples must not exceed total_samples.
        pp_sample_ix = np.random.choice(total_samples, size=num_samples, replace=False)

        y_hat_stacked = y_hat.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]

        y_hat_plotters = list(
            xarray_var_iter(
                y_hat_stacked,
                skip_dims=set(skip_dims + ["__sample__"]),
                combined=True,
            )
        )

        y_mean = y_hat.mean(("chain", "draw"))
        y_mean_plotters = list(
            xarray_var_iter(
                y_mean,
                skip_dims=set(skip_dims),
                combined=True,
            )
        )

        # Necessary when multidim y
        # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
        y_hat_plotters = np.tile(
            np.array(y_hat_plotters, dtype=object), (len_x, 1)
        )  # type: ignore[assignment]
        y_mean_plotters = np.tile(
            np.array(y_mean_plotters, dtype=object), (len_x, 1)
        )  # type: ignore[assignment]

    # The x values of the holdout period are shared by the observed holdout data
    # and the forecasts, so they are built once for whichever is available.
    x_holdout_plotters = None
    if has_holdout:
        x_holdout_plotters = list(
            xarray_var_iter(
                x_holdout,
                var_names=x_holdout_var_names,
                skip_dims=set(x_holdout.dims),
                combined=True,
            )
        )
        x_holdout_plotters = np.tile(
            np.array(x_holdout_plotters, dtype=object), (len_y, 1)
        )  # type: ignore[assignment]

    y_holdout_plotters = None
    if y_holdout is not None:
        y_holdout_plotters = list(
            xarray_var_iter(
                y_holdout,
                skip_dims=set(skip_holdout_dims),
                combined=True,
            )
        )

        # Necessary when multidim y
        # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
        y_holdout_plotters = np.tile(
            np.array(y_holdout_plotters, dtype=object), (len_x, 1)
        )  # type: ignore[assignment]

    y_forecasts_plotters = None
    y_forecasts_mean_plotters = None
    if y_forecasts is not None:
        total_samples = y_forecasts.sizes["chain"] * y_forecasts.sizes["draw"]
        # Sampling without replacement: num_samples must not exceed total_samples.
        pp_sample_ix = np.random.choice(total_samples, size=num_samples, replace=False)

        y_forecasts_stacked = y_forecasts.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]

        y_forecasts_plotters = list(
            xarray_var_iter(
                y_forecasts_stacked,
                skip_dims=set(skip_holdout_dims + ["__sample__"]),
                combined=True,
            )
        )

        y_forecasts_mean = y_forecasts.mean(("chain", "draw"))
        y_forecasts_mean_plotters = list(
            xarray_var_iter(
                y_forecasts_mean,
                skip_dims=set(skip_holdout_dims),
                combined=True,
            )
        )

        # Necessary when multidim y
        # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
        y_forecasts_mean_plotters = np.tile(
            np.array(y_forecasts_mean_plotters, dtype=object), (len_x, 1)
        )  # type: ignore[assignment]
        y_forecasts_plotters = np.tile(
            np.array(y_forecasts_plotters, dtype=object), (len_x, 1)
        )  # type: ignore[assignment]

    rows, cols = default_grid(length_plotters)

    tsplot_kwargs = dict(
        x_plotters=x_plotters,
        y_plotters=y_plotters,
        y_mean_plotters=y_mean_plotters,
        y_hat_plotters=y_hat_plotters,
        y_holdout_plotters=y_holdout_plotters,
        x_holdout_plotters=x_holdout_plotters,
        y_forecasts_plotters=y_forecasts_plotters,
        y_forecasts_mean_plotters=y_forecasts_mean_plotters,
        num_samples=num_samples,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        backend_kwargs=backend_kwargs,
        y_kwargs=y_kwargs,
        y_hat_plot_kwargs=y_hat_plot_kwargs,
        y_mean_plot_kwargs=y_mean_plot_kwargs,
        vline_kwargs=vline_kwargs,
        textsize=textsize,
        figsize=figsize,
        legend=legend,
        axes=axes,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_ts", "tsplot", backend)
    ax = plot(**tsplot_kwargs)
    return ax