{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Refitting NumPyro models with ArviZ\n", "\n", "ArviZ is backend agnostic and therefore does not sample directly. In order to take advantage of algorithms that require refitting models several times, ArviZ uses {class}`~arviz.SamplingWrapper` to convert the API of the sampling backend to a common set of functions. Hence, algorithms such as Leave Future Out Cross Validation can be used in ArviZ regardless of the sampling backend." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is an example of `SamplingWrapper` usage for [NumPyro](https://pyro.ai/numpyro/)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import numpyro\n", "import numpyro.distributions as dist\n", "import jax.random as random\n", "from numpyro.infer import MCMC, NUTS\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import scipy.stats as stats\n", "import xarray as xr" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "numpyro.set_host_device_count(4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this example, we will use a simple linear regression model." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(26)\n", "\n", "xdata = np.linspace(0, 50, 100)\n", "b0, b1, sigma = -2, 1, 3\n", "ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAA30ElEQVR4nO3dd3wb15Xo8d8FQLAA7GDvlKhuNUuWZLm3uG3sOM1O7ChtnR7vvs3mOS/ZbMrL2+zu2/Ts5jmJW+ISb1zjJE5kyZa71S1REk1SFMVexQYWgADu+2MAkBCLWACSAs/38/GHxGAGM2N/fHB55txzldYaIYQQ0ck03xcghBAiciTICyFEFJMgL4QQUUyCvBBCRDEJ8kIIEcUs830BozkcDl1cXDzflyGEEOeVAwcOdGitM8Z7b0EF+eLiYvbv3z/flyGEEOcVpdTpid6TdI0QQkQxCfJCCBHFJMgLIUQUkyAvhBBRTIK8EEJEMQnyQggRxSTICyFEFJMgL4QQs+T1aR7bW8fQsHe+L2UMCfJCCDFLb57s5GtPHeWpg43zfSljSJAXQohZqmztA+DFE60h24eGvfztw/s5cPrMfFwWsMDaGgghxPmoqs0I8q9Vd9Dv8mCLNULrrhNt7DzeSqzFxIVFafNybWEZySulapVSR5VSh5VS+/3b0pRSO5VSVf6fqeE4lxBCLDRVrU7ssRbcHh+vVnUEtz91sAGA3RVt85avD2e65kqt9Xqt9Sb/63uBXVrrMmCX/7UQQkQVrTWVrX3cvDaH5PgYdh43UjYdThcvV7azviCFAbeXPZXt83J9kczJ3wI85P/9IeDWCJ5LCCHmRXufi94hDytzkrhyeQa7K1rxeH08e7gJr0/zL7ddQGpCDH8+2jwv1xeuIK+BvyqlDiil7vZvy9JaNwP4f2aOd6BS6m6l1H6l1P729vn5phNCiJmqbHUCUJZp59pV2XQNDHPgdBdPHWxgbX4yK3OSeM/qbF480YbLM/cpm3AF+e1a643ADcAXlFKXTfVArfV9WutNWutNGRnj9rwXQogFK/DQtSwrkcuWOYgxK37+8kmONfVy24Y8AG64IAeny8Nro/L1cyUsQV5r3eT/2QY8DVwEtCqlcgD8P9vCcS4hhFhIKludpCTE4LBbSYyLYdsSB69UtmMxKd673gjyFy9JJzk+hj8dbZnz65t1kFdK2ZRSiYHfgeuAcuA5YId/tx3As7M9lxBCLDTVbX0sy0xEKQXAtSuNzPSVKzJJs1kBiDGbuHZVFjuPt+D2+Ob0+sIxks8CXlNKvQPsBf6otX4B+D5wrVKqCrjW/1oIIc5bFS29fOOZowy6jdy6UVnjZGmWPbjPe1Znk50Ux45txSHH3nhBNr1DHr71h2P82wsVfPf54xxt6In4Nc96MpTWugZYN872TuDq2X6+EEIsBE6Xh8/+5gC1nQNckJfMhzcX0u500TM4zLLMkSCfmRTHW/9rbOjbvtRBTnIcj75dh8Wk8GlNeWMPv/vMtohet8x4FUKIKfjnZ49Rd
2aAjMRYHttbz4c3F1Llr6xZlpV4zuNjLWZe/eqVaIz0zU93VfEfOyupPzNAQVpCxK5betcIIcQ5PHOokScPNvClq8r47OVLOFzfzfGmXqr8PWtGp2smYzGbiDEbYfd9G42Hsk8fimxTMwnyQggxifozA3zjmXIuKk7jS1ct5bYNeVgtJh7fV0dlm1FZk2GPnfbn5qcmsLU0jacONqC1jsCVGyTICyHEJJ7YX8/gsJcf3r4ei9lEqs3KDWuyefpQI0cbeijLtAcra6br/Rvzqe0c4GBdV5iveoQEeSGEmMQ7DT0sy0okLyU+uO2OiwrpG/JwtLGHsink4ydywwU5xMeYeTKCfeglyAshxAS01hxp6GZdfnLI9i0laZQ6bIDRzmCm7LEWrl+TzfPvNEWsS6UEeSGEmED9mUG6B4ZZm58Ssl0pxe0XFQBTq6yZzG0b8+gd8rDrRGSaAkgJpRBCTOCdhm4A1p41kge4a2sxtlgLW0vTZ3WOi5c4yE6K46mDDdy0NmdWnzUeCfJCCDGBIw3dxFpMLM8eO1qPt5r56JaiWZ/DbFL86Pb1lPjTP+EmQV4IISbwTkMPq3KTgrXtkTLbvwYmIzl5IYQYh9dntB1Yd1Y+/nwjQV4IsWjVdQ7w5snOcScjVbc5GXB7x83Hn08kyAshFq3v/ek4d/zyLT52/14q/S0KAkYeuqbM/YWFkQR5IcSidbpzgLyUeN6p7+aGH7/Kv75QERzVH2noxh5rCdbDn68kyAshFiWtNQ1dg1y7KouX//FKbl2fx3+9fJLfvHUagCMNPazJS8JkmlnLgoVCqmuEEItSz+AwTpeH/NR40mxW/v0Da+kacPPd54+zPCuRE829fHJ7yXxf5qzJSF4IsSg1dA0CRjdIAJNJ8cMPrScrKY5PPriPYa8+7/PxIEFeCBFFdp1opaFrYEr7BvbLTx1pPJacEMMv7ryQYZ+Rlz/fK2tAgrwQIkq4PT4+85sD/GLPySntX3/GGMkXpIauyrQmL5n/+OA6/mZdbsgXwPlKgrwQYkHyeH389/56PF7flPavO9OPx6epbnNOaf+GrgESYy0kxY99NPk363L56R0bZtwnfiGRIC+EWJB2Hm/lH39/hFeq2qe0f3Vbf8jPc2noGiQ/LSEqAvlkJMgLIRakQ/XdAJxo7pt8R7+aDmME3+F00TMwHPLeUwcb2F3RGrKtoWswKtIx5yJBXgixIB2u6wbg3ZYpBvn2kRF8dftIykZrzff+eIIfvVgVsq2ha0CCvBBCzAeP18fRxh5g6kH+ZLuT3OS44O8BLb1DdPa7OdHci8tjrL7UNTBMv9sbLJ+MZhLkhRALzrutfQwOeylIi+dkuxO3Z/KHr1pratr7uXx5BlaziZOjHr6WN/YCMOzVwdTPeOWT0SpsQV4pZVZKHVJKPe9/naaU2qmUqvL/TA3XuYQQ0e2wPx//4U0FeHw6mG+fyJl+Nz2Dw5RlJlLisIWM5AN/EQC84//cwESos8sno1E4R/L3ACdGvb4X2KW1LgN2+V8LIcQ5Ha7rJs1m5dpV2cC5UzYn/fn40gwbSzJtIWWUxxp7KMu047DHBjtLBkbyeTKSnxqlVD5wE/CrUZtvAR7y//4QcGs4ziWEWNi+9dwxvvzYoVl9xuH6btblJ1OaYSPGrKg4R5Cv8Y/cl2TYWZphp+7MQDD/Xt7UwwV5yawvSA4ZySfFWUiOj5nVdZ4PwtWg7EfAV4HRCyFmaa2bAbTWzUqpzPEOVErdDdwNUFhYGKbLEULMl5febaPf5Z3x8X1Dw1S3O7l5bS4xZhNLMuznHMnXdPRjtZjITYlnSaYdn4bajgFSbTG09rpYnZfMgMvDroo2eoeGqT8zsCgeukIYRvJKqZuBNq31gZkcr7W+T2u9SWu9KSMjY7aXI4SYR/0uD6c7B+hwuhh0zyzQH2noQWtYX5gCwPLsxHOna9qclDpsmE2KJRl2Y1u7k2P+h65rcpNYW5CC1lDe0LNoauQhP
Oma7cB7lVK1wOPAVUqp3wKtSqkcAP/PtjCcSwixgI1Oq0y1UdjZAg9d1/s7QC7PTqSxe5DeoeEJj6np6Kc0w1jcI/Czus0ZfOi6Oi+Zdf5mY4cbumnoGqQgTUbyU6K1/prWOl9rXQzcDuzWWt8JPAfs8O+2A3h2tucSQixsFS29wd/rZxjkD9V1U5phIznByJevyDaywJUTjObdHh91ZwYodRgj+ASrhbwUo/SyvLGHUocNe6yFlAQrxekJvFzRzuCwV0byYfB94FqlVBVwrf+1EGKB01rz4OunaOsdmvaxJ5p7sfhXUgp0eZzuuQ/Xd7O+ICW4bXl2EsCED1/rzgzg9WmWZI4s07ck0051mxHk1+SNtAtem5/CvtNnACQnPxNa65e11jf7f+/UWl+ttS7z/zwTznMJISKjoWuQb/3hOI/vq5/2sRXNfawvSCHWYqL+zMQjeZ9P09XvHrO9sXuQDqeLDaOCfG5yHIlxlgnz8oGa+MBIHmBphp3K1j6aeoZYk5cU3L7On5eHxTERCmTGqxDiLIE0S2Xr1NoJBGitqWjpY1VuEvmp8ZOma36yu4qL/s+LvHi8NeT4/3rZ6AW/sWhk7qRSiuVZEz98rRlVIx+wNNPOsNeI5qNH8utGLQIiQV4IsSgFZoNOtS/76OOcLg8rspMoSEuYMF0zNOzl4TdP4/FpPvfIAV483orWmn9+7hiPvF3H3ZeVsjo3dEWm5dmJnGjpRQeG4aPUtDvJSIwlMW6k5n3JqIA/+rNW5yZjNilSEmJC9o9mEuSFECECQb6mvX/KC3aAkY8HWJmTSEFqwoQj+T8eaeZMv5v//MhGVuUk8blHDvDph/bz8Jun+cxlpXzthhVjjlmRnUjfkIfmnrHPCU62O0OCOhgjeYCi9ISQCU/xVjPLsxIXzSgeJMgLIc7S4M+lu70+Tk+SVz/bieY+lIJlWYkUpMXTN+QZ09cd4OE3a1mSYeP6Ndk8/KktrMpJYldFG5+5rJR7b1gx7iIeK3OMvPo3nz3G8aaRCp7ajn6q25yUZthD9k+zWUm3Wbkgb+ward973xq+/d41U76v8124ZrwKIeZZY/cg5Y09vGd19qw+p6FrkMRYC30uD1WtfcHJRedS0dJLUVoCtlhLsPFXfdcAyQkjgfZwfTfvNPTwnVtWo5QiOT6G3356C4fru7lkqWPCVZouLErly1eXcf9rp7jxJ69y8ZJ0WnqHgvn4zcWh/Q+VUvz645vJSIwd81kbChdXr0QZyQsRJe5/7RSff+QgXt/YvPV0NHQNsH2pA4Cq1qnn5Sta+oIj7sBEo7MnRD38Ri32WAu3bcwPbkuMi+HSsoxJl+FTSvE/rl3G6/dexVeuW+afsZrAt9+7mle/eiXv25A/5pj1BSnkpSyetMxEZCQvRJRo7BrE69N0Dbhx2MeOYKfC7fHR3DvEB7ITKW/qoXKKD18H3B5qO/u5dX0eMNLCd/TD1w6ni+ePNHPHRQXYY2cWepLjY/jiVWV88aqyGR2/GMlIXogo0dRjBNRO59j686lq7hlEa6O8cFlWIlVTLKN8t6UPrY2HrgDJCTEkxllCHr4+sb8et9fHXduKZ3x9YvokyAsRJZq6A0HeNePPCFTW5KfGU5Zpn3KFTWA2aiBdA8ZofvSEqD8dbWZDYUqw8kXMDQnyQkSBoWEvHf4RfOc4M0nH09A1wFX/8XLIKkqBHHpBagJlWYm4vUZfmHM50dyLPdYSUppYkBZPvf9Lo6FrgPLG3lk/FBbTJ0FeiCgwun58qiP5Vyo7qGnvZ9eJkVmn9WcGMZsUOclxlPlH3JXnePha0+7kxeOtrMxJDHl4WpCaQEPXAFpr/nrMOIcE+bknQV6IKBBI1YCx3ulUHKrrAuDg6e7gtoauAbKT4rCYTcG0SnXbxHn516o6uPXnrzPk8fG1G1eGvFeQlsDQsI92p4u/HGthWZadEodtgk8SkSJBXogo0DgqyHdMNcj7+7YfqOsKtgsw+qwbK
RebP/0y0Uj+8b117HhgLznJ8Tz7he1sPKv+PPA5R+p72Fd7Rkbx80SCvBBRoKl7EKWgOD1hSumansFhqtuc5CbH0d7nCj5wre8KXRavLNNO1ThllMeaevj6M+VsX+rgyc9fPO4CHIEyygffqMWnJVUzXyTICxEFmruHyLDHkpMcP6V0TWBB649vLwbgYF0XLo+X1l5XMDiD0aLgZLszZIKVx+vjfz55hNQEKz+9fcOENe+BL4vXqjvIS4lndW7SuPuJyJIgL0QUaOoZJDclnjS7dUp18ofqulEKPrypkASrmYOnu2jqNh7ejq6QWZppD668FPDr105R3tjLd25ZHVy9aTzxVnNwUtZ1q7MmndEqIkeCvBBRoLF7kLyUeBw2Kx1TSNccqu9iWWYiyQkxrMtP4WBdd7CmfXSQX5blX3rPPymqtqOfH+ys5LpVWdyw5tzpl0BeXlI180faGghxntNa09Q9yNUrMrHHxtA75MHt8WG1jD+G01pzqK47GKQ3FqXwiz01wUA+Or++NNOOScHnHzlIhj0Wj09jtZj47q1rpjQyX5php/7MIJuKFldTsIVEgrwQ57mugWGGhn3kpsQTYzb5t7nJSoobd/9THf30DA6zoTAFMDo8en2aP5e3YDGpkONssRZ+ceeFHGnooaV3iLY+F3dsLpjws8/2tRtX8sWrlmIxS9JgvkiQF+I8F6iRz02JD5ZCdjonDvKH6rqBkZa7GwqMnwdOd1GYloDZFDpCv251NtfNMN2SZrOSZrPO6FgRHvL1KsR5LlAjn5cST5rNeNDZ2T9xXv5QfReJsRaW+vvEp9qswfVRAzl0ET0kyAtxnhs9kk+3G6PmySpsDtV1s64gBdOoEXtgIlN+yth6d3F+kyAvxHmuqXuQuBgTqQkxpPtTIxM1KRtwe6ho6Qvm4wOCQX4RrX26WEiQF+I819Q9RG5KPEopkuJisJjUhLNeD9V14/XpMUF+S2kaSkGZv2RSRI9ZP3hVSsUBrwCx/s/7vdb6n5VSacDvgGKgFviQ1rprtucTQoQK1MgDmEyKNJt1wlmvD71RS3J8DFtK0kO2L8mws/sfrqBonPYE4vwWjpG8C7hKa70OWA9cr5TaCtwL7NJalwG7/K+FEGHW1D1ITvJIJU2azRrsLT9adVsffz3eyo5tRdjGaUVQ4rCF5OlFdJh1kNeGQAejGP8/GrgFeMi//SHg1tmeSwgRyuXx0tbnInfUgtUOe+y41TW/2FNDXIyJHRcXz+EVivkWlpy8UsqslDoMtAE7tdZvA1la62YA/8/MCY69Wym1Xym1v729PRyXI8Si0dpjBPPRQX68dE1T9yDPHGrk9s2FpM9wkW9xfgpLkNdae7XW64F84CKl1JppHHuf1nqT1npTRkZGOC5HiEVjdI18QPo4Tcp+/dopNPDpS0vm8vLEAhDW6hqtdTfwMnA90KqUygHw/2wL57mEEKE18gEOeyxOl4ehYS8AXf1uHttbxy3rckN6xYvFYdZBXimVoZRK8f8eD1wDVADPATv8u+0Anp3tuYQQoZp7jCB/9oNXGFkG8Ll3mhhwe/nby0rn/gLFvAtH75oc4CGllBnjS+MJrfXzSqk3gSeUUp8C6oAPhuFcQohRatr7yUiMJS7GHNwWnBDldJObEs8rle0UpSewMkcW7ViMZh3ktdZHgA3jbO8Erp7t5wshxqe15tXqDraWhta8B1sb9Ltwe3y8WdPJ+zfmz8cligVAulAKsYA8c6iRzn43Vy7PoNTfQGwiJ5r7aO9zcfmy0IKF9ECTMqebg3VdDLi9XFrmiNg1i4VNgrwQc8Dp8vCP//0Od24tYvvS8QOu1pp/eqacPpeH7z5vLMr91etXcOMFOePuv6fSKDm+7KwAPnokX9PhxGxSbFuSPuZ4sThI7xoh5sDTBxv4c3kLn/nNAY439Y67T++ghz6Xh7+9tITv3LIai9nEN58tx+Xxjrv/nso2VuUkkXlW33h7rAWr2URnv5tXq
zrYWJhCYtzEa7GK6CZBXogI01rzyNt1LM20Y4+18MkH99HSMzRmv/ouY43VC4tS+di2Yr558yo6nG7+fLRlzL5Ol4f9tV1cvnzs3BKlFOl2KyfbnBxt7OHSMpl/sphJkBciwg7Vd1PR0scnthfzwCc243R5+OSD+3C6PCH7jUxsMmrZL1nqoNRh48E3asd85hvVHXh8mssmCOBpNit7KtvRGsnHL3IS5IWIsEffrsNmNXPL+jxW5iTx849u5HhzL4+9XReyX0OXEeQDPd1NJsXHthVxuL6bd+q7Q/bdU9mOzWrmwgkWyE63xzLs1STFWVibnxL2exLnDwnyQkRQz+Awzx9p4r3r87D7Oz9eviyDzMRYKlv7QvZt6BrAZjWTkjCSP3//hfnYrGYeerM2uE1rzZ7Kdi5e6sBqGf9/YYe/Vv6SMseYNVvF4iJBXogIevpgA0PDPj66pTBke2mGjVMd/SHbGrsGyUs1Fv8ISIyL4f0X5vP8O83BhUBqOvpp6BocUzo5WmDWq+TjhQR5ISJEa82je+tYm5/MmrzkkPdKHHZqzgryDV2D4/aW+di2YtxeH999/ji/P9DAr187BTBpkM/2tzmQfLyQOnkhZklrTW3nACUOW8j2o409VLY6+ZfbLhhzTKnDxpl+N90DblISjFF3Q9fAuDn2pZl2rlmZxTOHm3jmcBMAK7ITKZhkFacPby5gbX6KNCQTEuSFmK3fH2jgH39/hD9++RJW546M2F883opJwfWrs8ccE/hCONXRz4ZCK71Dw/QOeSZcSPu+uy6ka8DNgNtLv9tDVmLcuPsFJMbFcFFJ2izuSkQLSdcIMUuP7TWqZJ4/0hyyfVdFGxcWpZLqz4+PVpoxEuTByMcDE468TSZFuj2WgrQEVmQnjfuZQoxHgrwQs1Dd1sfBum4sJsWfjzajtQagpWeIY029XLUia9zjCtISMJsUNe2hQT5vgpG8EDMlQV6IWfjv/Q1YTIovX11GbecAFS1GWeTuCmONnKtXjrvqJTFmE4VpCcGRfIN/tutE6RohZkqCvBAzNOz18eTBRq5ckclHthRiUvDno0bKZteJVvJT4ynLnLiTZInDFqywaegaJC7GFOwFL0S4SJAXYob2vNtOh9PFhzYV4LDHsqUknT+VtzDo9vJadQfXrMwKqXk/W4nDRm1HPz6fprF7kLyU+En3F2ImJMgLMUNP7K/HYY/lCn+TsBsvyKa6zcnDb9bi8vi4asX4qZqA0gwbg8NeWnqHJqyRF2K2JMgLMQPtfS52V7Rx28Y8YszG/0bvWZ2NUvCjF6tIsJrZUjp5CePoMsqGrgF56CoiQurkhTjLsNfHsNdHgjX0fw+tNW+c7OS5w028cKwFr9Z88MKRZfUyk+LYVJTKvtou3rM6i1iL+eyPDlHqMPL15Y09dA0My0NXEREykhfiLN/5w3Fu+slreLy+kO3/+fJJPvqrt3n+SBNXrcjk0U9vpSwrMWSfG9YYqzhdvXL80snRspJiSbCaea26A4C8FAnyIvxkJC+i3rf/cIxLljqmFHgB3j7VyamOfl480cb1a4zZqi6PlwdeP8WlZQ7uu2sT8dbxR+kf2JRPa98QN02wZN9oSilKHDbePnUGmHgilBCzISN5EdVOd/bzwOu13P/6qSntP+j2Ut3mBODBN0aO+eORZjqcbu6+rHTCAA+QFBfD125YiS12auOnEocNt8f4i6FA0jUiAiTIi6i283grAPtquxgaDl0r9Qc7K/nVqzUh2ypaevFp2FSUyls1Z6ho6UVrzQOv17I0084lEyzCPVOl/oevVrMJhz02rJ8tBEiQF1HuxROtWM0m3B4fe/1pETDWSP3FnpM88HptyP7l/kW2v33LamItJh564zQH67o52tjDjouLw17HXuLvYZOXGo9JFvcQETDrIK+UKlBKvaSUOqGUOqaUuse/PU0ptVMpVeX/Of46ZUJESPeAm321XXx0ayFWsyn4gBOMtgNuj4/G7kHqz
wwEtx9r7CE1IYZVOUm8b0MeTx9q4Ce7qkiMs3DbhrywX2OgwkYeuopICcdI3gP8g9Z6JbAV+IJSahVwL7BLa10G7PK/FmLOvPxuO16f5r3rctlYlMKrVSNB/oXy5uDSeW+PGuEfa+plTV4ySil2XFzM0LCPPZXtfGhTwZTz7NNR7E/XSPmkiJRZB3mtdbPW+qD/9z7gBJAH3AI85N/tIeDW2Z5LRIcTzb385q3TY3LkMxHo+jienSdacdhjWZefwqVlGZxo7qW9z8Wg28tLFe28f2M+KQkxvF3TCYDb4+Pdlj5W5SYBsDIniS0laSgFH9tWNOtrHU9yfAx3bS3iprXnrsYRYibCmpNXShUDG4C3gSytdTMYXwTA5HO8xaLxy1dq+KdnyrnmB3tC2vNO10sVbWz9l128WtU+5j23x8eed9u5ZmUmJpMKPjB942QHeyrbGRz2ctMFOWwuTmNvrTGSr2rrw+31sWbUwh/fe98afnL7BorSbWPOES7fvXWNrMUqIiZsQV4pZQeeBP5Oa907jePuVkrtV0rtb28f+z+riD7tThd5KfHYYy187pGD/O3DB6Yd6P90tJm7f7Of1l4X//bCu2OOf/tUJ06Xh2tXGbXxa/KSSY6P4dWqDl4obyYlIYYtpWlsKUnjdOdAsP87wGr/SB5gaWYif7Mud5Z3LMT8CUuQV0rFYAT4R7TWT/k3tyqlcvzv5wBt4x2rtb5Pa71Ja70pI0NGM4tBp9PNiuxEnv/SJXzxyqW8eKKVfbVdUz7+9wca+OKjB1mXn8I3blrJ0cYe9lSGDhBePN5KXIyJ7f4RvNmk2L40nVer2tl1oo1rV2YRYzaxtTQdML4UjjX2YLOaKY7gqF2IuRaO6hoF/Bo4obX+wai3ngN2+H/fATw723OJ6NDZ7yLdbsViNvGFK5eSGGcJLqF3LseaevjKf7/D9qUOHv7URXxsWzF5KfH8dHd1cDQ/4Pbw1+OtXFqWQVzMyMSlS5Zm0Nrros/l4YYLjJmsK3OSSIyz8FbNGcqbelmdmyyljCKqhGMkvx24C7hKKXXY/8+NwPeBa5VSVcC1/tdikdNa0+l0k+6f+BNvNXPr+jz+eLSZ7gH3OY9/+V1jxP6jD68nwWrBajHx2ctLOXC6izdrOul3efj4A/to7R3io1sKQ469tMwY1SfGWkJG+JuL03irppPjTb3Bh65CRItZ14RprV8DJhr6XD3bzxfRpXfQg8enQ1ZAuuOiQn7z1mmePtTIJ7aXTHr8myc7WZGdGPySAPjgpgJ+uruaH+6sRGs4VN/Nj27fwBXLQ5/1F6QlsConiXUFySEdIreUpAWX61uTl4wQ0URmvIo51dHvAgiZwr8qN4l1BSk8trdu0gewLo+XfbVn2LYkPWR7XIyZuy8rZV9tF4fru/npHRt47wQPS5/6/MV855Y1Idu2lI583po8GcmL6CJBXsypTqeRkkm3h65lesfmAipbnRys657w2MN13bg8Pi5eMrZ/zEe2FHLT2hz+684LuXGSDpBxMebgIh8Ba3KTSLCasVpMLMmYeE1WIc5HEuTFnOp0GiP5dFtoM66/WZeLzWqe9AHsGyc7MSm4qGTsiksJVgs//8jGYMnkdFjMJi4tc3BhYeqYLwAhznfST17MqY5+YyTvOGskb4u1cMuGPJ480MDtmwvYVDw2kL9Z0xmsdw+3H354Pb6ZzckSYkGTYYuYU4GRfKrNOua9e64uIy8lnrt+vZfXRvWZAaPP+6G6LraVpo85LhwSrBbsEehNI8R8kyAv5lSn001KQsy4aZGspDh+95ltFKUn8MkH9/HXYy3B9w6c7mLYq8c8dBVCTE6CvJg2n0/zoxcrOdrQM+1jO/tdIeWTZ8tIjOXxu7eyMjeJzz1ykOePNAFGzxmLv6ZdCDF18vepmLYf76rix7uqaOtzcUH+BdM6dvREqImkJFh55NNb+MQDe7nn8cP4tJGPX1eQEpF2v0JEMxnJi2nZXdHKj3dVA
VDT7pz28Z397jEPXcdjj7Xw4Ccu4sKiVP7u8UO8U98dsXy8ENFMgryYstOd/fzd44dZlZPETWtzqGnvn/ZndDpdY8onJ2KLtfDgJzazpSQdn4aLl0qQF2K6JMiLKfF4fXzutwdRSvGLOy9kVU4SbX0unC7PtD6ja2B4zESoySRYLdz/8c389lNbZCQvxAxIkF+g/v53h/njkeb5voygP5W3cLy5l/996xoK0xNY4l+A+tQ0RvNnBgKzXac2kg+It5q5pMwR9kW0hVgMJMgvQB6vj2cON/Ln8vAG+YN1XXzh0YPTXnZPa819r5ykxGHjJn/LgFL/9P+ajqnn5QMtDRyTVNcIIcJLgvwCdGbAjdZwqmP6Oe+J+Hyabz5bzh+PNPOHd5qmdexbNWcob+zl05eWBHutF6UnYFJwchoj+ZG+NdMbyQshZk6C/AIUCIanOvrHdGV87p0mdp1onfZnvnCshfLGXuJjzDzweu20ltv75as1pNmsvH9jfnBbrMVMfmrCtCpsOv0dKKeTkxdCzI4E+QUoEOQH3F7a+lwh7/3rnyv46e7qaX2ex+vj//71XZZl2fnGzSs53tzL3lNnpnRsdVsfuyva+Ni2opBVlgBKHLZJ/9p442QH+2pHztMRGMlLukaIOSNBfgHqcI4E9tFBtG9omMbuQarbnNMaiT91qJGa9n7+4brl3LYhn5SEGB54vXZKx/7q1VPEWkzctbVozHulGbZx/9oAqGrt4xMP7OPrTx8Nbut0urCYFElx4W8wJoQYnwT5BWiiIF/Z2geA0+WhuWdoSp/l8nj58YtVrCtI4bpVWcRbzdxxUSF/Pd5C/ZkBtNY8/GYtl/3bS/zpaOiD3jdPdvLUoUY+cGH+uHn00gw7A24vLb2h1zI07OVLjx3C5fFR1eakd2gYMP5CSbNZZQ1VIeaQBPkFqMPpJsassFpMIUG+oqUv+HtV29Ry4U8dbKSxe5Cvvmd5sATxrq1FKKX4z5er+exvD/DNZ4/ROzTM5x85yM92V6G15vG9ddz167cpSI3ny1eXjfvZSxxGGeXZk6L+9YUKKlr6+OT2ErQ2FvuAwALe8tBViLkkQX4BCswKLU5PCAmglS19WC3Gf7Kq1r6JDg9xpKGbdJs1uHA1QG5KPNevyeaxvfXsrmjjGzet5K2vXc2t63P5v3+t5Oafvsa9Tx3l4qUOnv7CdrKS4sb97GAZ5aiHry9VtPHA67V8/OJi/v7aMpQySjfB+PKaSksDIUT4SLenBaiz340j0Upucjw1Z43k1+QmUXdmgKrWqY3kazsGKPaPuEe75+oyBt1e7rm6jHUFKYCxcMbSTDv/sbOSj19czDduWollkpWSspJiSbCag9c47PXxjWfKWZ6VyL03rCAuxsyyzMTgkn6d/S6K0xOm+G9BCBEOEuQXoA7/SL4kw8ZL77bh9WlMysjJX78mG6vFRGXb1Ebypzv72TpOD/ZlWYnc//HNIduUUnzxqjJ2XFxM4hQejiqlKHHYgn9t/OloM43dg/zyY5uClTgbi1J5/kgTPp+eUgdKIUR4SbpmATKCoZVSh41hr6axa5D2PhddA8Msz0qkLDOR6tZzV9gMDXtp6hmiOH3sSH4yUwnwAaUZdmo6jGv59WunKHXYuHpFZvD9jYUp9A15ONrYw4DbKzXyQswxCfILjNaadqeLDHssJQ4j532qsz/40HVZdiLLsuz0uTy09rom+yjqzwwAxuzUSCl12GjoGuS16g6ONPTwyUtKQqpnNhalArDzuDGByzHFDpRCiPCQdM0C43R5cHt8pNutFDuM4Hyq3YnHv8r0iuwkFEYQrWztIzt5/IeiALWdgSA/vZH8dJRm2NAavvXcMVITYkJmxYLxJZCSEBMM8jKSF2JuhWUkr5S6XynVppQqH7UtTSm1UylV5f+ZGo5zRbtgEy97LBn2WOyxFk51GCP5jMRY0mxWlmUZI/xzlVGe7jRy5ZF82LnEX2Fzsr2fO7cWEW8NnRWrl
GJDQQrv+quBJCcvxNwKV7rmQeD6s7bdC+zSWpcBu/yvxTkEJkKl22NHHmx29FPZ2sfyrMTge2k26znLKGs7+0mOjyElIXKj5xJ/5Y7VbOKubWNnxQJsLBz5fpeWBkLMrbAEea31K8DZzVBuAR7y//4QcGs4zhXtzu7vEqheqWztY3l2YnC/pZn2KYzkByJesmiLtbA8K5EPby4gM3H81NGFRaOCvKRrhJhTkczJZ2mtmwG01s1KqczxdlJK3Q3cDVBYWBjByzk/BEbyGYlGWqPYYeM5f2vgwEgeYFmWnecON6G1nnAxjdrOfjYURD5L9tyXtmMxTTxeWFeQgklBXIyZBKs8BhJiLs17dY3W+j6t9Sat9aaMjIz5vpx5F8jJp/lH8qWjJjKNHsmXZSbSO+QJdqn8xZ6TfOu5Y8H33R4fjV2DczL5KNZixjxJPxpbrIXl2UkyihdiHkRyWNWqlMrxj+JzgLYInitqdPa7SEmIIcY/0zSQ81YKyvwPXAHKMv0PX1udvFXTyff/XIFJwd9fs4zkhBgaugbw6chW1kzHZy8vDaaihBBzJ5Ij+eeAHf7fdwDPRvBcUcOY7Toy4g20JChMSwhJdZT5UzdPHWzgq78/QmFaAj4Nr1a3A3DaXyMfKMOcb7esz+NTl5TM92UIseiEq4TyMeBNYLlSqkEp9Sng+8C1Sqkq4Fr/a3EORhOvkTLD5PgYHPZYVoxK1QA47FZSEmJ46lAjDnssv//cNpLjY3j5XX+Q9/eTWSgjeSHE/AhLukZrfccEb10djs9fTDqcLlZmJ4Vs+9lHNgQfxAYopVielcjRxh5+tWMTmYlxXFrmYE9lOz6fprZzAHusRUoWhVjkpNRhgQn0rRlta+nYBmMA33vfBbg8XlbmGF8KVyzP5PkjzRxv7uV0Zz9F6QkTVt4IIRYHCfILiNvjo2dwOCRdM5mlmfaQ15cvM6qT9lS2c7pzIBj8hRCL17yXUIoRZ/r9E6FmWGqYkRjLmrwkdp1opb5rIKKNyYQQ5wcJ8hHi8nh5+lDDtBbcDkyEmupIfjyXL8vgYF03w1497RbDQojoI0E+Qv5yrJW//9077D11dreHiY0E+Zk/LL1i+cjEYhnJCyEkyEdIoJd7eVPvlI/pDPatmflIfkNBCklxxqMWKZ8UQkiQj5CGLiPIH2vqmfIxnf3+kXzizIO8xWzi0mUZJFjNZM7ic4QQ0UGqayKk/swgAMcapz6S73C6ibWYsJ3Vk326vn7jSu7cUhSyQpMQYnGSIB8hgZF8dbuToWFvcGHryXQ4XTj8feRnIzclntyU+Fl9hhAiOki6JgJ8Pk1j9yAlDhtenw6uz3ouRksDmaEqhAgfCfIR0NbnYtirec/qbGDqeflOp0uWxxNChJUE+RnQWk9a/x5I1WwpTSM5PobyKeblO2UkL4QIMwnyM/Cbt05z8fd3TzhCb+gyHroWpCawOjdpSiN5j9dHZ7+M5IUQ4SVBfgZ2V7TR3DPEHfe9xeH67jHvB2rk81PjWZ2bREVLH8Ne36Sf+czhJoa9mgsLI79cnxBi8ZAgPwPljb1cWuYgJcHKnb96e8ys1oauQRz2WOJizKzJS8bt8VE9yaLbHq+Pn+2uYlVOElevHHcpXCGEmBEJ8tPU1jtEh9PFVSsyeeIz28hKiuWTD+5jwO0J7tPQPUBBmlHCuDrX6AR5bJKZr88ebqK2c4B7rimT1sBCiLCSID9N5f78+urcZLKT4/inm1fhdHk4XNcd3Keha5D8VKNvTInDTnyMmfLG8fPyHq+Pn71UzcqcJK5blRXx6xdCLC4S5KcpMIN1ZY6xHN/GolSUgn21XQB4fZqm7kHyU42RvNmkWJmTyPEJRvJ/ONLEqY5+7rl6qYzihRBhJzNep6m8qYcSh43EuBgAkuJiWJGdxP7TRl6+rW+IYa8OBnmANXnJPHmgAZ9PYzIphr0+qlqdHG3s5mcvVbMiO5HrVmXPy/0IIaKbBPlpOtbUy7qCl
JBtm4tTefJAAx6vL9izJpCuASMv//CbXm7/5Vu097lo7BrE7a+2SY6P4fu3rZU+M0KIiJAgPw3dA24augb56JaikO2bitN4+M3TVLT0BSdCjR7Jb1/qoNRhY2jYy6qcJN6zOpuVOYmszU+hKC1BArwQImIWTZD3+jQH67pYkZ0YTLVMVyCvHqiYCdhcbNS276s9Q9+QUWWTN6pBWH5qAru/csWMzimEELMR9UH+dGc/v9tXz1MHG2npHeLDmwr41w+sDdnn3/9SwZn+YXZcXMSK7IkXvx6prAndJyc5nryUePbXdmGLNfq4T6XrpBBCRFpUB/l+l4frfvgKw14fVyzPpCzLzjOHG7n3hhWk2oweMVWtffz8pZMAPLa3jm2l6Xz9ppWsyUse83nHmnrJSY4bt/XA5uJU3jjZyZIMe0iqRggh5lPESyiVUtcrpd5VSlUrpe6N9PlGa+oexOXx8e8fWMf9H9/MN25ahcvj43f764P73PdKDXExJl76yhXce8MK3m3t4+tPHx3388obe1idOzb4g5GXb+tzcbi+O+ShqxBCzKeIBnmllBn4OXADsAq4Qym1KpLnHK25ZwiAgjQj6C7PTmRraRq/efM0Xp+mtXeIZw438qFNBZQ4bHz28iV88MJ8TjT34faE9poZcHuo6egfk6oJuKgkDYDBYa+M5IUQC0akR/IXAdVa6xqttRt4HLglwucMavEH+ZzkuOC2j19cTGP3ILtOtHL/66fw+jSfvqQ0+P7qvGTcXh9VbaELfZxo7kVrxk3jACzNsJMcbzzQlZG8EGKhiHSQzwPqR71u8G8LUkrdrZTar5Ta397eHtaTt/QaQT4zaSSHfs3KLHKS4/jFnpM8+lYdN1yQQ2H6SFBeE+g1c1YP+GMTVNYEmEyKTUVGlU2gb40QQsy3SAf58QrAQ1bb0Frfp7XepLXelJGREdaTN/cM4bBbibWMVLpYzCbu3FrEwbpu+lwePnNZacgxxek27LGWYCVNwDv1PaTZrCF/FZxtU7GRsimQkbwQYoGIdHVNA1Aw6nU+0BThcwa19AySPU5Qvn1zAT/eVcXGwhTW5qeEvGcyKVblJoU0FNNa83p1B1tK0ibtL3Pn1kJyU+IodtjCdg9CCDEbkQ7y+4AypVQJ0AjcDnwkwucMau4ZGvchaLo9lkc/vYWclPHTKmtyk3l072k8Xh8Ws4mT7U5aeoe4tGzyvzQS42K4ZX3epPsIIcRcimi6RmvtAb4I/AU4ATyhtT4W7vOc7uzn608f5WR76MIcLb1D447kwUit5E0U5POSGBr2UdPRD8CrVR0AXFrmCONVCyFE5EW8Tl5r/Set9TKt9RKt9fcicQ63x8cjb9eFpFiGhr10DwyTkzz9h6CBCprA571W1UFRekKwFFMIIc4XUdFPvjA9AZOCk6OW2AuUT2YnTfygdCKlDhtxMSbKG3txe3y8WdPJJUtlFC+EOP9ERZCPtZgpTEvgZHt/cFtgItRE6ZrJWMwmVuUkUd7Uw6G6Lgbc3nPm44UQYiGKiiAPsCTDHpKTb+k1+rrPJMiDkbI53tTLK1XtmBRsW5IelusUQoi5FD1BPtPOqY5+vD6jDL+lxwXMLF0DRoWN0+Xhif0NrCtICc5mFUKI80n0BPkMGy6Pj6ZuYwTf0jNIUpwFW+zMqkRX5xkzW9v7XFwq+XghxHkqioK8HYBqf8qmuWdoRpU1AcuyErGajX89l0g+XghxnoqaIF/qD/KBCpuW3iGyZpiPB4gxm1iRk4jNamZDYUo4LlEIIeZc1CwakmazkpoQE6ywae4ZYuUkqzxNxecuX0JHv5sYc9R8FwohFpmoCfIwUmEz7PXR4XTNuLIm4IYLcsJ0ZUIIMT+iaoi6JMNOTbuTtj4XWjNpx0ghhFgMoivIZ9rocLp5t8Xo/T6bnLwQQkSD6Ary/oevr1d3AjKSF0KIqArypcEgb3SNzEmSFZqEEItbVAX5gtR4YsyKipY+4mPMJMVH1XNlI
YSYtqgK8hazieJ0Y1WmnOS4SVdxEkKIxSCqgjyM5OWzZtizRgghokn0BfnMkZG8EEIsdlEX5Esdxkh+thOhhBAiGkRdkF+SKUFeCCECoi7Ir8lN4rOXL+H61dnzfSlCCDHvoq7G0GI2ce8NK+b7MoQQYkGIupG8EEKIERLkhRAiikmQF0KIKCZBXgghotisgrxS6oNKqWNKKZ9SatNZ731NKVWtlHpXKfWe2V2mEEKImZhtdU05cBvw/0ZvVEqtAm4HVgO5wItKqWVaa+8szyeEEGIaZjWS11qf0Fq/O85btwCPa61dWutTQDVw0WzOJYQQYvoilZPPA+pHvW7wbxtDKXW3Umq/Ump/e3t7hC5HCCEWp3Oma5RSLwLjTR/9utb62YkOG2ebHm9HrfV9wH3+c7UrpU6f65om4QA6ZnH8+Wax3S/IPS8Wcs/TUzTRG+cM8lrra2ZwwgagYNTrfKBpCufKmMG5gpRS+7XWm869Z3RYbPcLcs+Lhdxz+EQqXfMccLtSKlYpVQKUAXsjdC4hhBATmG0J5fuUUg3ANuCPSqm/AGitjwFPAMeBF4AvSGWNEELMvVmVUGqtnwaenuC97wHfm83nz8B9c3y++bbY7hfknhcLuecwUVqP+zxUCCFEFJC2BkIIEcUkyAshRBSLiiCvlLre3yOnWil173xfTyQope5XSrUppcpHbUtTSu1USlX5f6bO5zWGm1KqQCn1klLqhL9H0j3+7VF730qpOKXUXqXUO/57/rZ/e9TeM4BSyqyUOqSUet7/Otrvt1YpdVQpdVgptd+/LSL3fN4HeaWUGfg5cAOwCrjD3zsn2jwIXH/WtnuBXVrrMmCX/3U08QD/oLVeCWwFvuD/bxvN9+0CrtJarwPWA9crpbYS3fcMcA9wYtTraL9fgCu11utH1cZH5J7P+yCP0ROnWmtdo7V2A49j9M6JKlrrV4AzZ22+BXjI//tDwK1zeU2RprVu1lof9P/ehxEE8oji+9YGp/9ljP8fTRTfs1IqH7gJ+NWozVF7v5OIyD1HQ5Cfcp+cKJSltW4GIyACmfN8PRGjlCoGNgBvE+X37U9dHAbagJ1a62i/5x8BXwV8o7ZF8/2C8cX9V6XUAaXU3f5tEbnnaFjIe8p9csT5SSllB54E/k5r3avUeP/Jo4d/4uB6pVQK8LRSas08X1LEKKVuBtq01geUUlfM8+XMpe1a6yalVCawUylVEakTRcNIfkZ9cqJEq1IqB8D/s22eryfslFIxGAH+Ea31U/7NUX/fAFrrbuBljGcx0XrP24H3KqVqMVKtVymlfkv03i8AWusm/882jAmlFxGhe46GIL8PKFNKlSilrBiLlTw3z9c0V54Ddvh/3wFM1BX0vKSMIfuvgRNa6x+Meitq71spleEfwaOUigeuASqI0nvWWn9Na52vtS7G+H93t9b6TqL0fgGUUjalVGLgd+A6jAWYInLPUTHjVSl1I0Zezwzc72+pEFWUUo8BV2C0I20F/hl4BqNHUCFQB3xQa332w9nzllLqEuBV4Cgj+dr/hZGXj8r7VkqtxXjoZsYYhD2htf6OUiqdKL3nAH+65ita65uj+X6VUqWMtIOxAI9qrb8XqXuOiiAvhBBifNGQrhFCCDEBCfJCCBHFJMgLIUQUkyAvhBBRTIK8EEJEMQnyQggRxSTICyFEFPv/U2NDJ5TaV8EAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(xdata, ydata)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we write the NumPyro model:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def model(N, x, y=None):\n", "    b0 = numpyro.sample(\"b0\", dist.Normal(0, 10))\n", "    b1 = numpyro.sample(\"b1\", dist.Normal(0, 10))\n", "    sigma_e = numpyro.sample(\"sigma_e\", dist.HalfNormal(10))\n", "    numpyro.sample(\"y\", dist.Normal(b0 + b1 * x, sigma_e), obs=y)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "data_dict = {\n", "    \"N\": len(ydata),\n", "    \"y\": ydata,\n", "    \"x\": xdata,\n", "}\n", "kernel = NUTS(model)\n", "sample_kwargs = dict(\n", "    sampler=kernel, num_warmup=1000, num_samples=1000, num_chains=4, chain_method=\"parallel\"\n", ")\n", "mcmc = MCMC(**sample_kwargs)\n", "mcmc.run(random.PRNGKey(0), **data_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have defined a dictionary `sample_kwargs` that will be passed to the `SamplingWrapper` to make sure that all refits use the same sampler parameters. We follow the same pattern with the keyword arguments of {func}`~arviz.from_numpyro`, which we store in `idata_kwargs`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
arviz.InferenceData
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    b0       (chain, draw) float32 -3.0963688 -3.1254756 ... -2.5883367\n",
             "    b1       (chain, draw) float32 1.0462681 1.0379426 ... 1.038727 1.0135907\n",
             "    sigma_e  (chain, draw) float32 3.047911 2.6600552 ... 3.0927758 3.2862334\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.467097\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 1000, time: 100)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    y        (chain, draw, time) float32 -2.1860917 -3.248132 ... -2.305284\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.544419\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0
      <xarray.Dataset>\n",
             "Dimensions:    (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0 1 2 3\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
             "Data variables:\n",
             "    diverging  (chain, draw) bool False False False False ... False False False\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.468495\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0
      <xarray.Dataset>\n",
             "Dimensions:  (time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    y        (time) float64 -1.412 -7.319 1.151 1.502 ... 48.49 48.52 46.03\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.545286\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0
      <xarray.Dataset>\n",
             "Dimensions:  (time: 100)\n",
             "Coordinates:\n",
             "  * time     (time) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99\n",
             "Data variables:\n",
             "    x        (time) float64 0.0 0.5051 1.01 1.515 ... 48.48 48.99 49.49 50.0\n",
             "Attributes:\n",
             "    created_at:                 2020-10-06T03:36:51.545865\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          numpyro\n",
             "    inference_library_version:  0.4.0
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dims = {\"y\": [\"time\"], \"x\": [\"time\"]}\n", "idata_kwargs = {\"dims\": dims, \"constant_data\": {\"x\": xdata}}\n", "idata = az.from_numpyro(mcmc, **idata_kwargs)\n", "idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will create a subclass of `az.SamplingWrapper`. Instead of implementing every method required by {func}`~arviz.reloo` from scratch for each model, we first write a generic `NumPyroSamplingWrapper` that implements {func}`~arviz.SamplingWrapper.sample`, {func}`~arviz.SamplingWrapper.get_inference_data` and {func}`~arviz.SamplingWrapper.log_likelihood__i` for any NumPyro model. The log likelihood of the excluded observations is computed automatically with `numpyro.infer.log_likelihood` from the model function and the refitted posterior samples. The model-specific `LinRegWrapper` then only has to implement {func}`~arviz.SamplingWrapper.sel_observations`.\n", "\n", "Let's look at the two outputs of `sel_observations`.\n", "1. `data__i` is a dictionary because it is an argument of `sample`, which passes it as keyword arguments to `mcmc.run` and therefore to the model function.\n", "2. `data_ex` is also a dictionary because it is an argument of `log_likelihood__i`, which passes it as keyword arguments to `numpyro.infer.log_likelihood`." 
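The index-based split performed by `sel_observations` can be reproduced with NumPy alone. The sketch below is a standalone illustration on synthetic data; the function name `split_observations` is hypothetical and simply mirrors the `LinRegWrapper.sel_observations` method of the next cell, assuming the same `x`, `y` and `N` keys as the model arguments.

```python
import numpy as np


def split_observations(x, y, idx):
    """Hypothetical helper mirroring LinRegWrapper.sel_observations.

    Returns ``(data__i, data_ex)``: the data used to refit the model without
    the observations in ``idx``, and the excluded observations themselves.
    Both are dictionaries whose keys match the arguments of ``model``.
    """
    # Boolean mask that is True only at the positions listed in ``idx``.
    mask = np.isin(np.arange(len(x)), idx)
    data__i = {"x": x[~mask], "y": y[~mask], "N": int((~mask).sum())}
    data_ex = {"x": x[mask], "y": y[mask], "N": int(mask.sum())}
    return data__i, data_ex


# Synthetic data shaped like the example above (100 evenly spaced points).
x = np.linspace(0, 50, 100)
y = -2 + 1 * x
data__i, data_ex = split_observations(x, y, 13)
print(data__i["N"], data_ex["N"])  # 99 1
```

Because `np.isin` accepts either a single index or a list of indices, the same helper works whether one or several observations are excluded per refit.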
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class NumPyroSamplingWrapper(az.SamplingWrapper):\n", "    def __init__(self, model, **kwargs):\n", "        # `model` is the MCMC object; keep its model function for log likelihood computation.\n", "        self.model_fun = model.sampler.model\n", "        self.rng_key = kwargs.pop(\"rng_key\", random.PRNGKey(0))\n", "\n", "        super().__init__(model, **kwargs)\n", "\n", "    def log_likelihood__i(self, excluded_obs, idata__i):\n", "        # Flatten (chain, draw, ...) posterior samples to (chain * draw, ...).\n", "        samples = {\n", "            key: values.values.reshape((-1, *values.values.shape[2:]))\n", "            for key, values in idata__i.posterior.items()\n", "        }\n", "        log_likelihood_dict = numpyro.infer.log_likelihood(self.model_fun, samples, **excluded_obs)\n", "        if len(log_likelihood_dict) > 1:\n", "            raise ValueError(\"multiple likelihoods found\")\n", "        data = {}\n", "        nchains = idata__i.posterior.sizes[\"chain\"]\n", "        ndraws = idata__i.posterior.sizes[\"draw\"]\n", "        for obs_name, log_like in log_likelihood_dict.items():\n", "            # Restore the (chain, draw, ...) shape expected by ArviZ.\n", "            shape = (nchains, ndraws) + log_like.shape[1:]\n", "            data[obs_name] = np.reshape(log_like.copy(), shape)\n", "        return az.dict_to_dataset(data)[obs_name]\n", "\n", "    def sample(self, modified_observed_data):\n", "        # Use a fresh subkey for every refit.\n", "        self.rng_key, subkey = random.split(self.rng_key)\n", "        mcmc = MCMC(**self.sample_kwargs)\n", "        mcmc.run(subkey, **modified_observed_data)\n", "        return mcmc\n", "\n", "    def get_inference_data(self, fit):\n", "        # Convert the refitted MCMC object passed as `fit`, not the original one.\n", "        idata = az.from_numpyro(fit, **self.idata_kwargs)\n", "        return idata\n", "\n", "\n", "class LinRegWrapper(NumPyroSamplingWrapper):\n", "    def sel_observations(self, idx):\n", "        xdata = self.idata_orig.constant_data[\"x\"].values\n", "        ydata = self.idata_orig.observed_data[\"y\"].values\n", "        mask = np.isin(np.arange(len(xdata)), idx)\n", "        data__i = {\"x\": xdata[~mask], \"y\": ydata[~mask], \"N\": len(ydata[~mask])}\n", "        data_ex = {\"x\": xdata[mask], \"y\": ydata[mask], \"N\": len(ydata[mask])}\n", "        return data__i, data_ex" ] }, { "cell_type": "code", "execution_count": 
9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.92 7.20\n", "p_loo 3.11 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 100 100.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 0 0.0%\n", " (1, Inf) (very bad) 0 0.0%" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig = az.loo(idata, pointwise=True)\n", "loo_orig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, the Leave-One-Out Cross Validation (LOO-CV) approximation using [Pareto Smoothed Importance Sampling](https://arxiv.org/abs/1507.02646) (PSIS) works for all observations, so we will modify `loo_orig` in order to make `az.reloo` believe that PSIS failed for some observations. This also serves as a validation of our wrapper: since the PSIS LOO-CV estimate is already reliable here, the refitted result should be very close to it." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we initialize our sampling wrapper. Let's analyze each of its arguments.\n", "\n", "* We use `idata_orig` as a starting point, mainly as the source of observed and constant data, which is then subset in `sel_observations`.\n", "\n", "* We pass the original `mcmc` object as `model`; the wrapper extracts the model function from it, which enables automatic log likelihood computation. We also have the option to set `rng_key`: although the data for each fit is different, the key is split before every refit so that each fit uses a new random key.\n", "\n", "* Finally, `sample_kwargs` and `idata_kwargs` make sure that all refits and the corresponding `InferenceData` objects are generated with the same properties." 
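When called with such a wrapper, `az.reloo` refits the model only for the observations whose Pareto k exceeds its `k_thresh` argument (0.7 by default), and replaces the PSIS estimate of each of them with the log of the average predictive density of the excluded observation over the refitted posterior draws. A minimal NumPy/SciPy sketch of these two steps follows. It is an illustration rather than ArviZ's source code; the helper name `exact_elpd_i`, the placeholder k value of 0.3 for the untouched observations and the fake log-likelihood draws are assumptions.

```python
import numpy as np
from scipy.special import logsumexp

# Pareto k values after the modification above: four observations are
# flagged, the rest get an assumed "good" placeholder value of 0.3.
pareto_k = np.full(100, 0.3)
pareto_k[[13, 42, 56, 73]] = [0.8, 1.2, 2.6, 0.9]

k_thresh = 0.7  # default threshold used by az.reloo
to_refit = np.flatnonzero(pareto_k > k_thresh)
print(to_refit)  # [13 42 56 73]


def exact_elpd_i(log_lik_excluded):
    """Hypothetical helper: log mean predictive density of one held-out point.

    ``log_lik_excluded`` holds the log likelihood of the excluded observation
    evaluated at every draw of the refitted posterior (any shape).
    """
    draws = np.ravel(log_lik_excluded)
    # log(mean(exp(draws))) computed stably as logsumexp(draws) - log(S).
    return logsumexp(draws) - np.log(draws.size)


# Fake log-likelihood draws with the (chain, draw) shape of the refits.
rng = np.random.default_rng(0)
fake_log_lik = rng.normal(loc=-2.5, scale=0.1, size=(4, 1000))
print(exact_elpd_i(fake_log_lik))
```

Working on the log scale with `logsumexp` avoids underflow when the individual log-likelihood values are very negative.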
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "numpyro_wrapper = LinRegWrapper(\n", "    mcmc,\n", "    rng_key=random.PRNGKey(5),\n", "    idata_orig=idata,\n", "    sample_kwargs=sample_kwargs,\n", "    idata_kwargs=idata_kwargs,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can use this wrapper to call `az.reloo` and compare its results with the PSIS LOO-CV results." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/oriol/miniconda3/envs/arviz/lib/python3.8/site-packages/arviz/stats/stats_refitting.py:99: UserWarning: reloo is an experimental and untested feature\n", " warnings.warn(\"reloo is an experimental and untested feature\", UserWarning)\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 13\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 13\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 42\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 42\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 56\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 56\n", "arviz.stats.stats_refitting - INFO - Refitting model excluding observation 73\n", "INFO:arviz.stats.stats_refitting:Refitting model excluding observation 73\n" ] } ], "source": [ "loo_relooed = az.reloo(numpyro_wrapper, loo_orig=loo_orig)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.89 7.20\n", "p_loo 3.08 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 100 100.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 0 0.0%\n", " (1, Inf) (very bad) 0 0.0%" ] }, "execution_count": 
13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_relooed" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.92 7.20\n", "p_loo 3.11 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 96 96.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 2 2.0%\n", " (1, Inf) (very bad) 2 2.0%" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 2 }