{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Refitting NumPyro models with ArviZ\n", "\n", "ArviZ is backend agnostic and therefore does not sample directly. In order to take advantage of algorithms that require refitting models several times, ArviZ uses {class}`~arviz.SamplingWrapper` to convert the API of the sampling backend to a common set of functions. Hence, functions like Leave Future Out Cross Validation can be used in ArviZ independently of the sampling backend used." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below there is an example of `SamplingWrapper` usage for [NumPyro](https://pyro.ai/numpyro/)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import numpyro\n", "import numpyro.distributions as dist\n", "import jax.random as random\n", "from numpyro.infer import MCMC, NUTS\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import scipy.stats as stats\n", "import xarray as xr" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "numpyro.set_host_device_count(4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the example, we will use a linear regression model." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(26)\n", "\n", "xdata = np.linspace(0, 50, 100)\n", "b0, b1, sigma = -2, 1, 3\n", "ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(xdata, ydata)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will write the NumPyro code:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def model(N, x, y=None):\n", " b0 = numpyro.sample(\"b0\", dist.Normal(0, 10))\n", " b1 = numpyro.sample(\"b1\", dist.Normal(0, 10))\n", " sigma_e = numpyro.sample(\"sigma_e\", dist.HalfNormal(10))\n", " numpyro.sample(\"y\", dist.Normal(b0 + b1 * x, sigma_e), obs=y)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "data_dict = {\n", " \"N\": len(ydata),\n", " \"y\": ydata,\n", " \"x\": xdata,\n", "}\n", "kernel = NUTS(model)\n", "sample_kwargs = dict(\n", " sampler=kernel, num_warmup=1000, num_samples=1000, num_chains=4, chain_method=\"parallel\"\n", ")\n", "mcmc = MCMC(**sample_kwargs)\n", "mcmc.run(random.PRNGKey(0), **data_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have defined a dictionary `sample_kwargs` that will be passed to the `SamplingWrapper` in order to make sure that all refits use the same sampler parameters. We follow the same pattern with {func}`az.from_numpyro `." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    b0       (chain, draw) float32 -3.0963688 -3.1254756 ... -2.5883367\n",
             "    b1       (chain, draw) float32 1.0462681 1.0379426 ... 1.038727 1.0135907\n",
             "    sigma_e  (chain, draw) float32 3.047911 2.6600552 ... 3.0927758 3.2862334\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.467097\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 1000, time: 100)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    y        (chain, draw, time) float32 -2.1860917 -3.248132 ... -2.305284\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.544419\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:    (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0 1 2 3\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    diverging  (chain, draw) bool False False False False ... False False False\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.468495\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    y        (time) float64 -1.412 -7.319 1.151 1.502 ... 48.49 48.52 46.03\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.545286\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    x        (time) float64 0.0 0.5051 1.01 1.515 ... 48.48 48.99 49.49 50.0\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.545865\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dims = {\"y\": [\"time\"], \"x\": [\"time\"]}\n", "idata_kwargs = {\"dims\": dims, \"constant_data\": {\"x\": xdata}}\n", "idata = az.from_numpyro(mcmc, **idata_kwargs)\n", "idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will create a subclass of `az.SamplingWrapper`. Therefore, instead of having to implement all functions required by {func}`~arviz.reloo` we only have to implement {func}`~arviz.SamplingWrapper.sel_observations` (we are cloning {func}`~arviz.SamplingWrapper.sample` and {func}`~arviz.SamplingWrapper.get_inference_data` from the {class}`~arviz.PyStanSamplingWrapper` in order to use {func}`~xarray:xarray.apply_ufunc` instead of assuming the log likelihood is calculated within Stan). \n", "\n", "Let's check the 2 outputs of `sel_observations`.\n", "1. `data__i` is a dictionary because it is an argument of `sample` which will pass it as is to `model.sampling`.\n", "2. `data_ex` is a list because it is an argument to `log_likelihood__i` which will pass it as `*data_ex` to `apply_ufunc`.\n", "\n", "More on `data_ex` and `apply_ufunc` integration is given below." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class NumPyroSamplingWrapper(az.SamplingWrapper):\n", " def __init__(self, model, **kwargs):\n", " self.model_fun = model.sampler.model\n", " self.rng_key = kwargs.pop(\"rng_key\", random.PRNGKey(0))\n", "\n", " super(NumPyroSamplingWrapper, self).__init__(model, **kwargs)\n", "\n", " def log_likelihood__i(self, excluded_obs, idata__i):\n", " samples = {\n", " key: values.values.reshape((-1, *values.values.shape[2:]))\n", " for key, values in idata__i.posterior.items()\n", " }\n", " log_likelihood_dict = numpyro.infer.log_likelihood(self.model_fun, samples, **excluded_obs)\n", " if len(log_likelihood_dict) > 1:\n", " raise ValueError(\"multiple likelihoods found\")\n", " data = {}\n", " nchains = idata__i.posterior.dims[\"chain\"]\n", " ndraws = idata__i.posterior.dims[\"draw\"]\n", " for obs_name, log_like in log_likelihood_dict.items():\n", " shape = (nchains, ndraws) + log_like.shape[1:]\n", " data[obs_name] = np.reshape(log_like.copy(), shape)\n", " return az.dict_to_dataset(data)[obs_name]\n", "\n", " def sample(self, modified_observed_data):\n", " self.rng_key, subkey = random.split(self.rng_key)\n", " mcmc = MCMC(**self.sample_kwargs)\n", " mcmc.run(subkey, **modified_observed_data)\n", " return mcmc\n", "\n", " def get_inference_data(self, fit):\n", " # Cloned from PyStanSamplingWrapper.\n", " idata = az.from_numpyro(mcmc, **self.idata_kwargs)\n", " return idata\n", "\n", "\n", "class LinRegWrapper(NumPyroSamplingWrapper):\n", " def sel_observations(self, idx):\n", " xdata = self.idata_orig.constant_data[\"x\"].values\n", " ydata = self.idata_orig.observed_data[\"y\"].values\n", " mask = np.isin(np.arange(len(xdata)), idx)\n", " data__i = {\"x\": xdata[~mask], \"y\": ydata[~mask], \"N\": len(ydata[~mask])}\n", " data_ex = {\"x\": xdata[mask], \"y\": ydata[mask], \"N\": len(ydata[mask])}\n", " return data__i, data_ex" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.92 7.20\n", "p_loo 3.11 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 100 100.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 0 0.0%\n", " (1, Inf) (very bad) 0 0.0%" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig = az.loo(idata, pointwise=True)\n", "loo_orig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, the Leave-One-Out Cross Validation (LOO-CV) approximation using [Pareto Smoothed Importance Sampling](https://arxiv.org/abs/1507.02646) (PSIS) works for all observations, so we will use modify `loo_orig` in order to make `az.reloo` believe that PSIS failed for some observations. This will also serve as a validation of our wrapper, as the PSIS LOO-CV already returned the correct value." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We initialize our sampling wrapper. Let's stop and analyze each of the arguments. \n", "\n", "* We use `idata_orig` as a starting point, and mostly as a source of observed and constant data which is then subsetted in `sel_observations`.\n", "\n", "* We also use `model` to get automatic log likelihood computation and we have the option to set the `rng_key`. Even if the data for each fit is different the `rng_key` is split with every fit.\n", "\n", "* Finally, `sample_kwargs` and `idata_kwargs` are used to make sure all refits and corresponding `InferenceData` are generated with the same properties." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "numpyro_wrapper = LinRegWrapper(\n", " mcmc,\n", " rng_key=random.PRNGKey(5),\n", " idata_orig=idata,\n", " sample_kwargs=sample_kwargs,\n", " idata_kwargs=idata_kwargs,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And eventually, we can use this wrapper to call `az.reloo`, and compare the results with the PSIS LOO-CV results." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/oriol/miniconda3/envs/arviz/lib/python3.8/site-packages/arviz/stats/stats_refitting.py:99: UserWarning: reloo is an experimental and untested feature\n", " warnings.warn(\"reloo is an experimental and untested feature\", UserWarning)\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 13\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 13\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 42\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 42\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 56\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 56\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 73\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 73\n" ] } ], "source": [ "loo_relooed = az.reloo(numpyro_wrapper, loo_orig=loo_orig)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.89 7.20\n", "p_loo 3.08 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 100 100.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 0 0.0%\n", " (1, Inf) (very bad) 0 0.0%" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_relooed" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.92 7.20\n", "p_loo 3.11 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 96 96.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 2 2.0%\n", " (1, Inf) (very bad) 2 2.0%" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 2 }