{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Refitting NumPyro models with ArviZ (and xarray)\n", "\n", "ArviZ is backend agnostic and therefore does not sample directly. In order to take advantage of algorithms that require refitting models several times, ArviZ uses `SamplingWrappers` to convert the API of the sampling backend to a common set of functions. Hence, functions like Leave Future Out Cross Validation can be used in ArviZ independently of the sampling backend used." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below there is an example of `SamplingWrapper` usage for [NumPyro](https://pyro.ai/numpyro/)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import numpyro\n", "import numpyro.distributions as dist\n", "import jax.random as random\n", "from numpyro.infer import MCMC, NUTS\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import scipy.stats as stats\n", "import xarray as xr" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "numpyro.set_host_device_count(4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this example, we will use a linear regression model." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(26)\n", "\n", "xdata = np.linspace(0, 50, 100)\n", "b0, b1, sigma = -2, 1, 3\n", "ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(xdata, ydata)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will write the NumPyro model:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def model(N, x, y=None):\n", "    b0 = numpyro.sample(\"b0\", dist.Normal(0, 10))\n", "    b1 = numpyro.sample(\"b1\", dist.Normal(0, 10))\n", "    sigma_e = numpyro.sample(\"sigma_e\", dist.HalfNormal(10))\n", "    numpyro.sample(\"y\", dist.Normal(b0 + b1 * x, sigma_e), obs=y)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "data_dict = {\n", "    \"N\": len(ydata),\n", "    \"y\": ydata,\n", "    \"x\": xdata,\n", "}\n", "kernel = NUTS(model)\n", "sample_kwargs = dict(\n", "    sampler=kernel, num_warmup=1000, num_samples=1000, num_chains=4, chain_method=\"parallel\"\n", ")\n", "mcmc = MCMC(**sample_kwargs)\n", "mcmc.run(random.PRNGKey(0), **data_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have defined a dictionary `sample_kwargs` that will be passed to the `SamplingWrapper` in order to make sure that all\n", "refits use the same sampler parameters. We follow the same pattern with the keyword arguments of {func}`arviz.from_numpyro`, stored in `idata_kwargs`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    b0       (chain, draw) float32 -3.0963688 -3.1254756 ... -2.5883367\n",
             "    b1       (chain, draw) float32 1.0462681 1.0379426 ... 1.038727 1.0135907\n",
             "    sigma_e  (chain, draw) float32 3.047911 2.6600552 ... 3.0927758 3.2862334\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:44:50.997985\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      <xarray.Dataset>\n",
             "Dimensions:    (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0 1 2 3\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    diverging  (chain, draw) bool False False False False ... False False False\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:44:50.999466\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      <xarray.Dataset>\n",
             "Dimensions:  (time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    y        (time) float64 -1.412 -7.319 1.151 1.502 ... 48.49 48.52 46.03\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:44:51.079386\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      <xarray.Dataset>\n",
             "Dimensions:  (time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    x        (time) float64 0.0 0.5051 1.01 1.515 ... 48.48 48.99 49.49 50.0\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:44:51.079921\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> sample_stats\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dims = {\"y\": [\"time\"], \"x\": [\"time\"]}\n", "idata_kwargs = {\"dims\": dims, \"constant_data\": {\"x\": xdata}}\n", "idata = az.from_numpyro(mcmc, **idata_kwargs)\n", "del idata.log_likelihood\n", "idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now missing the `log_likelihood` group because we deleted it right after calling `az.from_numpyro`, which computes it by default. We are doing this to ease the job of the sampling wrapper. Instead of going out of our way to get NumPyro to calculate the pointwise log likelihood values for each refit and for the excluded observation at every refit, we will compromise and manually write a function to calculate the pointwise log likelihood.\n", "\n", "Even though it is not ideal to lose part of the out-of-the-box capabilities of the NumPyro-ArviZ integration, this should generally not be a problem. We are basically moving the pointwise log likelihood calculation from the NumPyro model to plain Python code, which means writing the likelihood a second time, outside of the model.\n", "\n", "Moreover, the Python computation could even be written to be compatible with [Dask](https://docs.dask.org/en/latest/). Thus it would work even in cases where the large number of observations makes it impossible to store pointwise log likelihood values (with shape `n_samples * n_observations`) in memory." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def calculate_log_lik(x, y, b0, b1, sigma_e):\n", "    mu = b0 + b1 * x\n", "    return stats.norm(mu, sigma_e).logpdf(y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function should work for any shape of the input arrays as long as their shapes are compatible and can be broadcast together. There is no need to loop over each draw in order to calculate the pointwise log likelihood using scalars.\n", "\n", "Therefore, we can use {func}`xarray.apply_ufunc` to handle the broadcasting and preserve the dimension names:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "log_lik = xr.apply_ufunc(\n", "    calculate_log_lik,\n", "    idata.constant_data[\"x\"],\n", "    idata.observed_data[\"y\"],\n", "    idata.posterior[\"b0\"],\n", "    idata.posterior[\"b1\"],\n", "    idata.posterior[\"sigma_e\"],\n", ")\n", "idata.add_groups(log_likelihood=log_lik)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first argument is the function, followed by as many positional arguments as the function needs, 5 in our case. As this case involves only a few dimensions and no tricky combinations of them, we do not need to pass any extra kwargs to `xr.apply_ufunc`.\n", "\n", "Note that the arguments we pass are `xr.DataArray` objects, not NumPy arrays. Behind the scenes, `xr.apply_ufunc` broadcasts and aligns the dimensions of all the DataArrays involved, matching them by name, and only afterwards passes NumPy arrays to `calculate_log_lik`. Everything works automagically. 
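
As a minimal sketch of this name-based broadcasting, the snippet below builds small toy DataArrays that reuse the dimension names of `idata` (`time`, `chain` and `draw`); the sizes and values are hypothetical stand-ins, not the notebook data. It compares `xr.apply_ufunc` with calling `calculate_log_lik` on raw NumPy arrays, once with the missing axes added by hand and once without them.

```python
import numpy as np
import scipy.stats as stats
import xarray as xr


def calculate_log_lik(x, y, b0, b1, sigma_e):
    # Same pointwise log likelihood as above: plain NumPy/SciPy, no xarray inside
    mu = b0 + b1 * x
    return stats.norm(mu, sigma_e).logpdf(y)


rng = np.random.default_rng(0)
# Hypothetical toy data: 5 observations, 2 chains and 3 draws per chain
x = xr.DataArray(np.linspace(0, 1, 5), dims=['time'])
y = xr.DataArray(rng.normal(size=5), dims=['time'])
b0 = xr.DataArray(rng.normal(size=(2, 3)), dims=['chain', 'draw'])
b1 = xr.DataArray(rng.normal(size=(2, 3)), dims=['chain', 'draw'])
sigma_e = xr.DataArray(rng.uniform(1, 2, size=(2, 3)), dims=['chain', 'draw'])

# apply_ufunc matches dimensions by name and inserts length-1 axes where needed,
# so calculate_log_lik receives NumPy arrays that broadcast cleanly
log_lik = xr.apply_ufunc(calculate_log_lik, x, y, b0, b1, sigma_e)
print(log_lik.dims)  # ('time', 'chain', 'draw')
print(log_lik.shape)  # (5, 2, 3)

# Manual NumPy equivalent: add the missing chain and draw axes to x and y by hand
manual = calculate_log_lik(
    x.values[:, None, None],
    y.values[:, None, None],
    b0.values,
    b1.values,
    sigma_e.values,
)
print(np.allclose(log_lik.values, manual))  # True

# Raw arrays fail: shapes (2, 3) and (5,) cannot be broadcast positionally
try:
    calculate_log_lik(x.values, y.values, b0.values, b1.values, sigma_e.values)
except ValueError as err:
    print('ValueError:', err)
```

Adding the axes by hand reproduces the `xr.apply_ufunc` result here, but it depends on getting the axis order right for every argument; letting xarray match dimensions by name avoids that bookkeeping.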
\n", "\n", "Now let's see what happens if we were to pass the arrays directly to `calculate_log_lik` instead:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "ename": "ValueError", "evalue": "operands could not be broadcast together with shapes (4,1000) (100,) ", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m calculate_log_lik(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0midata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconstant_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"x\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0midata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobserved_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"y\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0midata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposterior\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"b0\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0midata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposterior\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"b1\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mcalculate_log_lik\u001b[0;34m(x, y, b0, b1, sigma_e)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mcalculate_log_lik\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msigma_e\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mb0\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mb1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mstats\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msigma_e\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogpdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: operands could not be broadcast together with shapes (4,1000) (100,) " ] } ], "source": [ "calculate_log_lik(\n", "    idata.constant_data[\"x\"].values,\n", "    idata.observed_data[\"y\"].values,\n", "    idata.posterior[\"b0\"].values,\n", "    idata.posterior[\"b1\"].values,\n", "    idata.posterior[\"sigma_e\"].values,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you are still curious about the magic of xarray and `apply_ufunc`, you can also try to modify the `dims` used to generate the `InferenceData` a couple of cells before:\n", "\n", "    dims = {\"y\": [\"time\"], \"x\": [\"time\"]}\n", "\n", "What happens to the result if you use a different name for the dimension of `x`?" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    b0       (chain, draw) float32 -3.0963688 -3.1254756 ... -2.5883367\n",
             "    b1       (chain, draw) float32 1.0462681 1.0379426 ... 1.038727 1.0135907\n",
             "    sigma_e  (chain, draw) float32 3.047911 2.6600552 ... 3.0927758 3.2862334\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:44:50.997985\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      <xarray.Dataset>\n",
             "Dimensions:    (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0 1 2 3\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    diverging  (chain, draw) bool False False False False ... False False False\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:44:50.999466\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      <xarray.Dataset>\n",
             "Dimensions:  (time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    y        (time) float64 -1.412 -7.319 1.151 1.502 ... 48.49 48.52 46.03\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:44:51.079386\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      <xarray.Dataset>\n",
             "Dimensions:  (time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    x        (time) float64 0.0 0.5051 1.01 1.515 ... 48.48 48.99 49.49 50.0\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:44:51.079921\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0

      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 1000, time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    x        (time, chain, draw) float64 -2.186 -2.105 -2.077 ... -2.646 -2.305

\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> sample_stats\n", "\t> observed_data\n", "\t> constant_data\n", "\t> log_likelihood" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will create a subclass of {class}`~arviz.SamplingWrapper`. Therefore, instead of having to implement all the functions required by {func}`~arviz.reloo`, we only have to implement a few of them. The generic `NumPyroSamplingWrapper` implements {func}`~arviz.SamplingWrapper.sample` and {func}`~arviz.SamplingWrapper.get_inference_data` once for any NumPyro model, and the model-specific `LinRegWrapper` then only has to implement {func}`~arviz.SamplingWrapper.sel_observations`. The pointwise log likelihood of the excluded observations is computed with `apply_ufunc`, instead of assuming it is calculated within NumPyro.\n", "\n", "Let's check the two outputs of `sel_observations`.\n", "1. `data__i` is a dictionary because it is an argument of `sample`, which passes it as keyword arguments to `mcmc.run`.\n", "2. `data_ex` is a list because it is an argument to `log_likelihood__i`, which passes it as `*data_ex` to `apply_ufunc`.\n", "\n", "More on `data_ex` and `apply_ufunc` integration is given below." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class NumPyroSamplingWrapper(az.SamplingWrapper):\n", "    def __init__(self, model, **kwargs):\n", "        # Key used to derive a fresh PRNG subkey for every refit\n", "        self.rng_key = kwargs.pop(\"rng_key\", random.PRNGKey(0))\n", "\n", "        super().__init__(model, **kwargs)\n", "\n", "    def sample(self, modified_observed_data):\n", "        self.rng_key, subkey = random.split(self.rng_key)\n", "        mcmc = MCMC(**self.sample_kwargs)\n", "        mcmc.run(subkey, **modified_observed_data)\n", "        return mcmc\n", "\n", "    def get_inference_data(self, fit):\n", "        # Convert the refitted MCMC object returned by `sample`, not the original `mcmc`\n", "        idata = az.from_numpyro(fit, **self.idata_kwargs)\n", "        return idata\n", "\n", "\n", "class LinRegWrapper(NumPyroSamplingWrapper):\n", "    def sel_observations(self, idx):\n", "        xdata = self.idata_orig.constant_data[\"x\"]\n", "        ydata = self.idata_orig.observed_data[\"y\"]\n", "        mask = np.isin(np.arange(len(xdata)), idx)\n", "        # data__i is passed to numpyro to sample on it -> dict of numpy arrays\n", "        # data_ex is passed to apply_ufunc -> list of DataArrays\n", "        data__i = {\"x\": xdata[~mask].values, \"y\": ydata[~mask].values, \"N\": len(ydata[~mask])}\n", "        data_ex = [xdata[mask], ydata[mask]]\n", "        return data__i, data_ex" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.92 7.20\n", "p_loo 3.11 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 100 100.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 0 0.0%\n", " (1, Inf) (very bad) 0 0.0%" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig = az.loo(idata, pointwise=True)\n", "loo_orig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, the Leave-One-Out Cross Validation (LOO-CV) approximation using [Pareto Smoothed
Importance Sampling](https://arxiv.org/abs/1507.02646) (PSIS) works for all observations, so we will modify `loo_orig` in order to make {func}`~arviz.reloo` believe that PSIS failed for some observations. This will also serve as a validation of our wrapper, as the PSIS LOO-CV already returned the correct value." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We initialize our sampling wrapper. Let's stop and analyze each of the arguments.\n", "\n", "We use the `log_lik_fun` and `posterior_vars` arguments to tell the wrapper how to call {func}`~xarray:xarray.apply_ufunc`. `log_lik_fun` is the function to be called, and it receives the following positional arguments:\n", "\n", "    log_lik_fun(*data_ex, *[idata__i.posterior[var_name] for var_name in posterior_vars])\n", "\n", "where `data_ex` is the second element returned by `sel_observations` and `idata__i` is the `InferenceData` object returned by `get_inference_data`, which contains the fit on the subsetted data. We have generated `data_ex` to be a list of DataArrays so it plays nicely with this call signature.\n", "\n", "We use `idata_orig` as a starting point, and mostly as a source of observed and constant data which is then subsetted in `sel_observations`.\n", "\n", "Finally, `sample_kwargs` and `idata_kwargs` are used to make sure all refits and corresponding `InferenceData` are generated with the same properties." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "numpyro_wrapper = LinRegWrapper(\n", "    mcmc,\n", "    rng_key=random.PRNGKey(7),\n", "    log_lik_fun=calculate_log_lik,\n", "    posterior_vars=(\"b0\", \"b1\", \"sigma_e\"),\n", "    idata_orig=idata,\n", "    sample_kwargs=sample_kwargs,\n", "    idata_kwargs=idata_kwargs,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can use this wrapper to call {func}`~arviz.reloo` and compare the results with the PSIS LOO-CV results." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/oriol/miniconda3/envs/arviz/lib/python3.8/site-packages/arviz/stats/stats_refitting.py:99: UserWarning: reloo is an experimental and untested feature\n", " warnings.warn(\"reloo is an experimental and untested feature\", UserWarning)\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 13\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 13\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 42\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 42\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 56\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 56\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 73\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 73\n" ] } ], "source": [ "loo_relooed = az.reloo(numpyro_wrapper, loo_orig=loo_orig)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.89 7.20\n", "p_loo 3.08 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 100 100.0%\n", " (0.5, 0.7] 
(ok) 0 0.0%\n", " (0.7, 1] (bad) 0 0.0%\n", " (1, Inf) (very bad) 0 0.0%" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_relooed" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.92 7.20\n", "p_loo 3.11 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 96 96.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 2 2.0%\n", " (1, Inf) (very bad) 2 2.0%" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 2 }