{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Refitting PyMC models with ArviZ\n", "\n", "ArviZ is backend-agnostic and therefore does not sample directly. To support algorithms that require refitting a model several times, ArviZ uses {class}`~arviz.SamplingWrapper` to map the API of the sampling backend onto a common set of functions. Hence, refitting-based methods such as Leave Future Out Cross Validation can be used in ArviZ regardless of the sampling backend." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is an example of `SamplingWrapper` usage with [PyMC](https://www.pymc.io/projects/docs/en/stable/learn.html)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import pymc as pm\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this example, we will use a linear regression model." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "rng = np.random.default_rng(4)\n", "\n", "xdata = np.linspace(0, 50, 100)\n", "b0, b1, sigma = -2, 1, 3\n", "ydata = rng.normal(loc=b1 * xdata + b0, scale=sigma)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(xdata, ydata);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will write the PyMC model, keeping in mind that the data must be modifiable (both `x` and `y`), which is why both are wrapped in `pm.MutableData`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "with pm.Model() as linreg_model:\n", "    # optional: add coords to \"time\" dimension\n", "    linreg_model.add_coord(\"time\", np.arange(len(xdata)), mutable=True)\n", "\n", "    x = pm.MutableData(\"x\", xdata, dims=\"time\")\n", "    y_obs = pm.MutableData(\"y_obs\", ydata, dims=\"time\")\n", "\n", "    b0 = pm.Normal(\"b0\", 0, 10)\n", "    b1 = pm.Normal(\"b1\", 0, 10)\n", "    sigma_e = pm.HalfNormal(\"sigma_e\", 10)\n", "\n", "    pm.Normal(\"y\", b0 + b1 * x, sigma_e, observed=y_obs, dims=\"time\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [b0, b1, sigma_e]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " 100.00% [6000/6000 00:04<00:00 Sampling 4 chains, 0 divergences]\n", "\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 5 seconds.\n" ] } ], "source": [ "sample_kwargs = {\"chains\": 4, \"draws\": 500}\n", "with linreg_model:\n", "    idata = pm.sample(**sample_kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have defined a dictionary `sample_kwargs` that will also be passed to the `SamplingWrapper`, so that all refits use the same sampler parameters as the original fit." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will create a subclass of `az.PyMCSamplingWrapper`, the PyMC-specific subclass of `az.SamplingWrapper`, and define its `sample`, `log_likelihood__i` and `sel_observations` methods." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from scipy import stats\n", "from xarray_einstats.stats import XrContinuousRV\n", "\n", "\n", "class PyMCLinRegWrapper(az.PyMCSamplingWrapper):\n", "    def sample(self, modified_observed_data):\n", "        with self.model:\n", "            # if the model has coords, the dim needs to be updated before\n", "            # modifying the data in the model with set_data;\n", "            # otherwise, we wouldn't need to overwrite the sample method\n", "            n__i = len(modified_observed_data[\"x\"])\n", "            self.model.set_dim(\"time\", n__i, coord_values=np.arange(n__i))\n", "\n", "            pm.set_data(modified_observed_data)\n", "            idata = pm.sample(\n", "                **self.sample_kwargs,\n", "            )\n", "        return idata\n", "\n", "    def log_likelihood__i(self, excluded_observed_data, idata__i):\n", "        # log likelihood of the excluded observations, evaluated on the refit posterior\n", "        post = idata__i.posterior\n", "        dist = XrContinuousRV(\n", "            stats.norm,\n", "            post[\"b0\"] + post[\"b1\"] * excluded_observed_data[\"x\"],\n", "            post[\"sigma_e\"],\n", "        )\n", "        return dist.logpdf(excluded_observed_data[\"y_obs\"])\n", "\n", "    def sel_observations(self, idx):\n", "        # split the data into the observations kept for refitting (data__i)\n", "        # and the excluded ones (data_ex)\n", "        xdata = self.idata_orig[\"constant_data\"][\"x\"]\n", "        ydata = self.idata_orig[\"observed_data\"][\"y\"]\n", "        mask = np.isin(np.arange(len(xdata)), idx)\n", "        data_dict = {\"x\": xdata, \"y_obs\": ydata}\n", "        data__i = {key: value.values[~mask] for key, value in data_dict.items()}\n", "        data_ex = {key: value.isel(time=idx) for key, value in data_dict.items()}\n", "        return data__i, data_ex" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 2000 posterior samples and 100 observations log-likelihood matrix.\n", "\n", "         Estimate       SE\n", "elpd_loo  -255.84     6.31\n", "p_loo        2.70        -\n", "------\n", "\n", "Pareto k diagnostic values:\n", "                        Count    Pct.\n", "(-Inf, 0.5]   (good)      100  100.0%\n", " (0.5, 0.7]   (ok)          0    0.0%\n", "   (0.7, 1]   (bad)         0    0.0%\n", "   (1, Inf)   (very bad)    0    0.0%" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig = az.loo(idata, pointwise=True)\n", "loo_orig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, the Leave-One-Out Cross Validation (LOO-CV) approximation using [Pareto Smoothed Importance Sampling](https://arxiv.org/abs/1507.02646) (PSIS) works for all observations, so we will modify `loo_orig` to make `az.reloo` believe that PSIS failed for some observations. `az.reloo` refits the model, excluding one observation at a time, for every observation whose Pareto k value exceeds its `k_thresh` argument (0.7 by default), so the four values set in the next cell trigger four refits. This also serves as a validation of our wrapper, since the PSIS LOO-CV already returned the correct value." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we initialize our sampling wrapper." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "pymc_wrapper = PyMCLinRegWrapper(model=linreg_model, idata_orig=idata, sample_kwargs=sample_kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can use this wrapper to call `az.reloo` and compare the results with the PSIS LOO-CV results." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/oriol/Public/arviz/arviz/stats/stats_refitting.py:99: UserWarning: reloo is an experimental and untested feature\n", "  warnings.warn(\"reloo is an experimental and untested feature\", UserWarning)\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [b0, b1, sigma_e]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " 100.00% [6000/6000 00:04<00:00 Sampling 4 chains, 0 divergences]\n", "\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 5 seconds.\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [b0, b1, sigma_e]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " 100.00% [6000/6000 00:04<00:00 Sampling 4 chains, 0 divergences]\n", "\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 4 seconds.\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [b0, b1, sigma_e]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " 100.00% [6000/6000 00:03<00:00 Sampling 4 chains, 0 divergences]\n", "\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 4 seconds.\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [b0, b1, sigma_e]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", " \n", " 100.00% [6000/6000 00:04<00:00 Sampling 4 chains, 0 divergences]\n", "\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 4 seconds.\n" ] } ], "source": [ "loo_relooed = az.reloo(pymc_wrapper, loo_orig=loo_orig)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 2000 posterior samples and 100 observations log-likelihood matrix.\n", "\n", "         Estimate       SE\n", "elpd_loo  -255.82     6.30\n", "p_loo        2.69        -\n", "------\n", "\n", "Pareto k diagnostic values:\n", "                        Count    Pct.\n", "(-Inf, 0.5]   (good)      100  100.0%\n", " (0.5, 0.7]   (ok)          0    0.0%\n", "   (0.7, 1]   (bad)         0    0.0%\n", "   (1, Inf)   (very bad)    0    0.0%" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_relooed" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 2000 posterior samples and 100 observations log-likelihood matrix.\n", "\n", "         Estimate       SE\n", "elpd_loo  -255.84     6.31\n", "p_loo        2.70        -\n", "------\n", "\n", "Pareto k diagnostic values:\n", "                        Count    Pct.\n", "(-Inf, 0.5]   (good)       96   96.0%\n", " (0.5, 0.7]   (ok)          0    0.0%\n", "   (0.7, 1]   (bad)         2    2.0%\n", "   (1, Inf)   (very bad)    2    2.0%" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" } }, "nbformat": 4, "nbformat_minor": 2 }