{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(pystan_refitting)=\n", "# Refitting PyStan (3.0+) models with ArviZ\n", "\n", "ArviZ is backend agnostic and therefore does not sample directly. To take advantage of algorithms that require refitting models several times, ArviZ uses {class}`~arviz.SamplingWrapper` to convert the API of the sampling backend to a common set of functions. Hence, refitting-based functions such as Leave-One-Out or Leave-Future-Out cross validation can be used in ArviZ independently of the sampling backend." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is an example of `SamplingWrapper` usage for PyStan, extending {class}`arviz.PyStanSamplingWrapper`, which already implements some default methods targeted to PyStan.\n", "\n", "Before starting, it is important to note that PyStan cannot call the C++ functions it uses from Python. Therefore, the Stan **code** of the model must be slightly modified so that it computes the log likelihood of the excluded data itself, which makes it compatible with the cross validation refitting functions." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import stan\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# PyStan 3 runs on asyncio; nest_asyncio lets it run inside Jupyter's already-running event loop\n", "import nest_asyncio\n", "\n", "nest_asyncio.apply()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this example, we will use a simple linear regression." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(26)\n", "\n", "xdata = np.linspace(0, 50, 100)\n", "b0, b1, sigma = -2, 1, 3\n", "ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAAsTAAALEwEAmpwYAAA3tElEQVR4nO3dd3xc5Zno8d87oy6N+kiyiq3qblwxNrYxzQRILi2QQAokkJDsZnfZLNkQNjf3brKbTbmbsiWbhACJEwglgQRIQjGmGnDvcpVk9d41ozKamff+cWZGGhWrzaiMnu/n448055yZc84H8ejVc573eZXWGiGEEKHJNNMXIIQQIngkyAshRAiTIC+EECFMgrwQQoQwCfJCCBHCwmb6AgZLTU3Vubm5M30ZQggxpxw6dKhZa20dad+sCvK5ubkcPHhwpi9DCCHmFKVUxWj7JF0jhBAhTIK8EEKEMAnyQggRwiTICyFECJMgL4QQIUyCvBBChDAJ8kIIEcIkyAshxBTZ+5w8c6ASt3v2tW6XIC+EEFP00rFaHnruBHtKmmf6UoaRIC+EEFN0rsEGwOunG/y2t9j6uPvx/ZQ322fisgAJ8kIIMWXnG7sAeP1UA4NX2/v9oWreOdfE0weqZurSAhPklVLlSqkTSqmjSqmDnm3JSqldSqnznq9JgTiXEELMNucbbMRFhlHb0UtxbScAWmueO1wNwMsn65ippVYDOZK/Smu9Rmu9wfP6a8BurXURsNvzWgghQkpHTz/1nb3ctTEHpWDXKSNlU1zbybkGG2tyEqlo6eZUXeeMXF8w0zU3Azs93+8EbgniuYQQYkaUNBr5+E35KaxfmOQL8s8dribCbOJHH1+D2aR4+UT9jFxfoIK8Bl5TSh1SSt3v2Zauta7zfF8PpI/0RqXU/Uqpg0qpg01NTQG6HCGEmB7nG4x8fFGahR3L0zlV10lFi50Xj9Zy7fI08lJj2Zyfwl9OzEzKJlBBfqvWeh1wA/AlpdQVg3dq485GvDut9SNa6w1a6w1W64g974UQYtY632gjKtxEdlI01y43xrLfeKGYFruD29ZmA3DDqgzKmu2+KpzpFJAgr7Wu8XxtBP4AbAQalFILADxfGwNxLiGEmE3ONXRRmBaHyaQosMaRb43lnXNNpMRGsH2JMXC9bnkGJgV/OVE3xqcF3pSDvFIqVill8X4PXAecBF4E7vEcdg/wwlTPJYQQs01Jo43FaRbf6x3LjNH8TWsyCTcbIdZqiWRjXjIvn5yDQR4j175HKXUM2A/8WWv9CvBdYIdS6jxwree1EELMWXvLWvj2n0/52hd09vZT19FLYXqc75hb1mZhtURy18aFfu+9cdUCzjXY+O7LZ/jeK2f41z+doqIl+JOkprzGq9a6DFg9wvYW4Jqpfr4QQswG9R29/NUTh2jr7ufKJWlsKUz1VdYMHskvWxDPga9fO+z916/M4PuvnOVnb5cSYTbR73bT1t3PDz42LHwG1KxayFsIIWYjl1vz5WeO0ud0Y4kK46n9lWwpTPVV1
ixOt4zxCZBmieLI/9mBAsLMJh76/XFeOl7Lv9yygpiI4IViaWsghBBj+NnbpXxQ1sI3b1rB7euzebW4nhZbH+cbBiprxiPcbCLMk6f/6Ppsuh0uXjkZ3Pp5CfJCCHERhyvb+OGuc9y0OpPb12dz18aF9LuMlgXnGm2+ypqJ2rAoiZzkaJ4/XBOEqx4gQV4IIS7iNx9UkBAdzrdvXYlSisXpFtYvSuLp/VWcb+iiKG3sVM1ITCbFbWuzea+0mdr2ngBf9aDzBO2ThRAiBByrbmf9oiQsUeG+bXdtXEhZs526jl6KBlXWTNRt67LQGv54NHijeQnyQggxis7efsqa7KzOTvDb/uFVC7BEGQ9LJzuSB1iUEsuluUk8d6g6aC0PJMgLIcQoTlZ3AHBJdqLf9ugIM7euzQJg8RRG8gC3rcumtMnOcc+5Ak2CvBBCjOKYL8gnDNv399cu5v/dfgmLUmKndI4bVy0gIszE857e84EmdfJCCDGK49XtLEqJITEmYti+5NgI7tiQM+VzJESH84u7NwxLCQWKBHkhhBjF8eoO1i0K/qJ22xcHrwOvpGuEEGIETV191LT3BG2EPV0kyAsh5q0z9Z0cq2ofcd/xamP70Ieuc42ka4QQ89aDzx6juLaTj1yygIdvXEZW4kB7gmPVHZgUrMyKn8ErnDoZyQsh5iWtNRUt3eRbY9l1qoFrfvAWv3rvgm//8ep2itIsQW0eNh0kyAsh5qWOnn5sfU4+sXEhux/czub8FP75pVO8fqoBrTXHqztYNcfz8SBBXggxT1W3Gf1ispNiyE6K4aefWs/KrHi+/OxRPihtodXumPMPXUGCvBBinqpu6wbwtQmOCjfz00+uRwGf//VBYO4/dAUJ8kKIEPLC0Ro6uvvHdax3JJ+TFOPblpMcw4/vXIPd4SLcrFi6YPJ9aWYLCfJCiJBQ1drNA08f5akDleM+3hIZRny0/4PVq5em842PLOfuzblEhpmDcanTam4/NhZChCxbn5Pdpxu4aXUmSo29KEdJk7Heqnfd1bFUt/WQlRQ94mfftzVvYhc7i8lIXggxKz21r5IHnj7K+XEG7dLGiQf5nOSYsQ+c4yTICyFmpSNVbQCcrusc1/FlzXYASptsw3qzP7bnAocr23yvtdZUt3WPe23WuUyCvBBiVjpa2Q7A2fqucR1f5knXdPU6aerq82239zn51z+f4tF3y3zb2rv7sTtcZCfJSF4IIaZdY2cvtR29wPiDfGmTncyEKGAgPw9wqq4TreFY1cCiHFVDyidDmQR5IcSsc8TTNCwnOZoz4wjyXb39NHX1cd2KDGAgPw9wssYI7jXtPb4R/sBEKAny46aUMiuljiil/uR5naeU2qeUKlFKPaOUGt51XwghRnC0qp0wk+LWtdnUtPfQ1Xvx2veyJiMfv7kghdgIM6We1wAnagZG8N7OkgMToSRdMxEPAKcHvf4e8COtdSHQBtwXwHMJIULY0cp2li2I97UVONdw8dF8qSc9U2CNoyAtzq/Cprimk8vykjGpgeX8qtt6iI8KIyE6PEh3MHsEJMgrpbKBDwOPel4r4Grg955DdgK3BOJcQojZ7f5fH+R7r5yZ9Ptdbs3x6nbW5CSyJMOYcTpWyqasyY7ZpFiYHEOhNc4X9HscLs43dnFZXjKL0y2+3vHVbT3zYhQPgRvJ/xj4KuD2vE4B2rXWTs/raiBrpDcqpe5XSh1USh1samoK0OUIIWaC0+XmrbNNfFDaMunPKGm0YXe4WJOTSFZiNJbIsDEfvpY121iYHENEmImCtDjqOnqx9Tk5Xd+JW8OKrARWZydyvLodrTVVrfOjfBICEOSVUh8BGrXWhybzfq31I1rrDVrrDVZr8NY5FEIE34VmOw6X25fznoyjnvr4NQsTUUqxOMMy5ki+tNFOgTUWMFI2YJRUFnvy8SuzErgkJ4G27n6qWntkJD9BW4CblFLlwNMYaZr/ABKVUt62CdlATQDOJYSYxU55Ji412
xx0O5xjHD2yo1XtJESHk5diBO0lGRbO1ncNm+Dk5XJrLrTYyfcE98I0430ljTZO1HSQHBtBZkIUqz0dJd8820hPv4ucZBnJj4vW+mGtdbbWOhe4E3hDa/1J4E3gds9h9wAvTPVcQojZbfCI21umOFFHKttZnZOIyWT0lFmaYaGjp5+Gzr4Rj69t78HhdJOfagT3RSmxhJkUpU02TtZ0siIzHqUUSzIsRIaZ+PPxOmB+VNZAcOvkHwL+QSlVgpGjfyyI5xJCBIjT5eZnb5di75v4SPx0XSdhnuBc1TrxlI29z8m5hi7W5CT6ti1J9z58Hbm9ga+yJs0YyYebTSxMieFUbSfnGrpYlZXg274iM54DFa3A/KiRhwAHea31W1rrj3i+L9Nab9RaF2qt79Baj/xrWAgxqxwob+O7L5/h1eL6Cb/3TF0XmwtSgIsHeYfTTecIte8najpwa1g7KMgvzTAW0h7t4au3Jt47kgcotMaxp6QZp1uzMmtgdafVOYl4sz5ZEuSFEPORd8r/uYbxdXP0arM7qO/sZWthKlHhJqoukq75pz+cYMt33uDIoKZhDqebR98tw6SMYOyVEBNORnzUqEG+rMlGQnQ4ybED8y0L0+LodxnRfNXgIO/JyydEhxMfFfo18iBBXggxhDeXXtI4vp4xXqc96ZRlC+LJTooZdSTf2NnLH4/UYHM4ufux/RytasfhdPM3vz3M66cb+eebVvgFbDAevp4eNcjbybfG+vWF91bYJESH+6VlvL885stDV5AgL4QYonqSI/kzdUYQXrYgnpyk6FFH8k/tr8Lp1jx532UkxUbw6Uf38dlf7ee1Uw186+YV3L05d9h7lmZYKG200e9yD9tX2mTzBXWvQk9+fmVWvF/wz02JMQJ/4vx46AoS5IUQQ1S3GsG5qq2bHodr3O87XddJalwEVkskOckxVLd2Dyt77He5eXJfBdsXW7m8MJWn799EUmwE75W0jBrgwfjF4XC5+afnT1DRYuTgtdacqO6gsauPfGus3/H51lhMClZlJfptV0rx359Yy5d3LB73fc11svyfECHiXEMXzV19XF6YOqXPqW4z1j7t6nNS2mTze3B5MWfqu3wPSXOSYujqc9LZ4yQhZiD3/WpxPY1dfXz3o4sAyEyM5vm/vpwLzXYuzU0e9bNvWJXBoYpFPHOwiucOV7OtyEpJo42a9h5MCtYtTPI73hIVzhP3XcayBfHDPmtb0fyadCkjeSFCxA9eO8tXnzs+pc9wON3Ud/aybbHxi+L8OPPyTpebcw1dLFtglDt6c95VQ2a+/vr9ChYmx7B9cZpvW2pc5EUDPEBkmJl/uWUle756FZ/flk9Jo40VmfF876Or2PvwNWzKTxn2nssLU0mKlea3MpIXIkTUtPfQbOtDaz2uha9HUtfRg1vD1kIru041jDsvX95ip8/p9o3kvRONqlq7fX8JnKrtZH95K1+/cRlm0+SuLy0+iodvXMbDNy6b1PvnIxnJCxEiatt76e130z2BPPpQ3sqa3NQY8lJjOT/OIH960ENXwLdA9uCR/BP7KogKN3HHhuxJX5+YOAnyQoSAHoeLVrsDgBabY9Kf462syUmKoSjNMu50zZl6Y6ZrgadvjFGHHkaV5yGuy6159WQ9O5ZnkBgjKZTpJEFeiBBQ2zFQrthiH9/k8mNV7ez44du+Xw5gjOTNJsWChCiK0uOobO2mt3/svwxO13VRmBZHZJjZty0nOcY3kj9Y3kqL3cGHVqSP95ZEgEiQFyIE1LYPCvLjHMm/WlzP+UYbe8sGer9XtXaTER9FmNlEUZoFrfFbZWkkRyrbOHChdVglS86gCVGvFjcQEWbiyiVpI32ECCIJ8kKEgMFBfvDI/GKOVLYDcLhioLWA0WfdqIxZnG5MKLpYkH/haA0ff2QvSbER/N01RX77cpKjqW7rwe3WvFpcz9bCVOIipdZjukmQFyIE1LT3+r5vHke6xuXWHPMsan240j/Iex+a5qYaL
XtHW1/1R7vO8cDTR1mbk8gfv7SFvFT/CUk5yTH0Od28fb6JmvYeSdXMEPm1KkQIqG3vISM+is7e/nGla841dNHtcJGZEMXJmk76nEbevaGr1zeSDzebjAqbEUbyb5xp4D92n+f29dn8262riAgbPl7M8ZRRPvbuBUwKrl0mQX4myEheiBBQ19FDZmIUKXER40rXeFM1n9mSi8Pl5mRNJ7XtvWg9EJwBFqdbOD9kJN/V28/X/3CSJemWUQM8DEyI2lPSzIbcZFLiIid5d2IqJMgLEQJq23vJTIwmOTaSZtvY6ZojlW0kx0Zwy9oswMjLe8snB3dtLEwbXmHz/VfOUt/Zy3c/OnqANz5n4JfFh1ZkTPieRGBIkBdijtNaU9PeQ1ZiNKmxEeNK1xypamdtTiJplihykqM5XNnmq2nPTvYfybv1wOpL+y+08pu9FXz28jzWDukXM1RUuBmrxRi9X7dcUjUzRXLyQsxxLXYHDqebzMRoWu0OimtHXibPq6Onn5JGG7esyQSM5l57y1p8D1oz4qN8xy7JMCpsbvnJe6RZorA7nGQnRfOVD42vi2OhNY4FCVG+h7li+kmQF2KO85ZPZiZGU9fRS4v94v1rjlW1A/hG4usXJfHC0Vr2X2glMzHar69MgTWOH9yxmpImGw0dvTTbHfzNVYXERIwvdPzo42smf2MiICTICzHHDQT5KCpaIuh3abr6nKMub3eksh2l4JJso3GYt03voYo2Li/w7+aolOKj6yffayYjIWrsg0RQSU5eiDnOWyOflRjtWzbvYnn5I1VtLE6zYPH8EliaYSE63GhHMLiyRoQGCfJCzHG17T3ERJhJiA73lSm2jFJho7XmSGU7axcm+raFmU2szjFG9YMra0RokCAvxBxX295DZmI0SilSvCP5UWrlLzTb6ejp9wvyMJCyyZ5HC1zPFxLkhZjjvEEeICXu4umaA+WtAMPKHy8vMFaCKkqzBOsyxQyZcpBXSkUppfYrpY4ppYqVUt/0bM9TSu1TSpUopZ5RSkkTaSGCoKa9l6xE4wGnNyffOkL/Gq01O9+vIN8aS6E1zm/f1qJUdj+4fdzruYq5IxAj+T7gaq31amANcL1SahPwPeBHWutCoA24LwDnEkIM0tvvotnWx4IEYyQfGWbGEhlG8wgj+XfON3OqrpMvXlGAaYTl9wqGBH4RGqYc5LXB28Eo3PNPA1cDv/ds3wncMtVzCSH81XcYlTXedA0YKZuRcvI/fauEjPgobl6bOW3XJ2ZeQHLySimzUuoo0AjsAkqBdq2103NINZA1ynvvV0odVEodbGpqCsTlCDFvDK6R90qOjRiWrjlS2cbeslY+ty3Pb/UmEfoCEuS11i6t9RogG9gILJ3Aex/RWm/QWm+wWq2BuBwh5o0aT5DP8hvJRw578Pqzt0tJiA7nzo0Lp/X6xMwLaHWN1rodeBPYDCQqpbwzarOBmkCeSwhhdJ8E/5mlqUPSNSWNXbxa3MA9mxfJykzzUCCqa6xKqUTP99HADuA0RrC/3XPYPcALUz2XEMJfXUcPVkukXwrGSNc4cLs1AM8cqCLCbOKey3Nn6CrFTArESH4B8KZS6jhwANiltf4T8BDwD0qpEiAFeCwA5xJCDFLWZB82SzUlNhKXW9PR0w/AO+eauTQvSRbtmKem/Leb1vo4sHaE7WUY+XkhRBDY+pwcrmzj81fk+233TYiyO3C43Jxt6OK2deN+TCZCjCTohJhFdr5fTphZcdWSNL+yyJG8X9KM063Zvti/YCEldqB/zVFPW+FtRVLUMF9JkBdiGtR19PCNPxbz5R1FrMgceVZpR3c///fFYt/rpRkWvnnTCi7LTxnx+LfPNREXGebrO+M1eCT/7vkmUuMiWZoh7QrmK+ldI8Q0+NV75bx+uoF7f3WAuo6eEY+p8qyx+vANS3n4hqV09PTzzZdOobUedqzWmrfPNXF5QcqwdVa9TcqabX3sOd/MtqLUEWe4ivlBgrwQQdbnd
PG7Q9WsyUnE3ufis788QFdv/7DjqtuM4L+lMJUvbC/gb68u4lRdJ4cq2oYdW9Zsp7qth+1LhqdhkjxB/t3zzbTYHWwrSg3wHYm5RIK8EEH2anEDrXYHX96xmJ9+ah0ljTa+9Nsj9LvcfscNndh0y9pM4qPC+NX75cM+8+2zxuzwK0bItYebTSREh/PW2UYAthZKkJ/PJMgLEWS/3VdBTnI02wpT2VZk5Vs3r+Sdc028fqrB77jqtm5iI8wkxhgrNsVEhPGxDTm8crKehs5ev2PfPtdEvjV21AWyU+KMZQCXZlhIi5cl+OYzCfJCBFFpk429Za3ceelCX178tnVZKAXnGmx+x1a39ZCdFOO3APenNy/CpTVP7qv0bevtd7G3rGVYVc1gqZ4KmysucoyYHyTICxFET+2rJMykuGPDwGLYUeFmshKjudDsH+Rr2nrIGjKxaVFKLFctSeO3+ypxOI30zr4LrfQ53RcN8t6+8pKPFxLkhQiS3n4Xvz9czXUr0kmz+KdM8lJjKWu2+22rbusecY3Vey7PpdnWx3dePs1zh6p5al8lkWEmNo1SWglGL5vocDOX5iYH5mbEnCV18kJMUb/LTX1H77D8+BtnGmnv7ueuETo/5qfG8vzhGrTWKKXo7O2ns9fp103Sa1thKssWxPPL98p9265dlkZU+Ogtg790VSEfXZd90WPE/CBBXogp+u83Svjp26W899DVWC0D/WFeP9VAYkw4m0cYceelxtLV56TZ5sBqiaTGUz6ZnTT8QarJpHjpb7bQ3tNPd58LW5+ThSkjP3D1sloi/a5FzF+SrhFiClxuzTMHqnA43bx2qt5v+5tnG7lqSRph5uH/m+V7ltq74EnZVPuC/MitDMLMJlLjIlmYEsPyzHhpGSzGTYK8EFPw7vkm6jt7CTMpXjk5EOSPVrXR1t3P1UvTRnxfXmosAGVNxsPXGs9s16EPXoWYKgnyQkzB7w5WkxwbwWe35PJ+aQttnsU6dp9uxGxSo5YwZiZGExFm8hvJR4WbfC0JhAgUCfJCTFKb3cGuUw3cvCaTm9dk4XJrdnkmOO0+3ciluUkkRIeP+F6zSZGbEuOrsKlu6yErMdqvRl6IQJAgL8QkvXC0BofLzR3rc1iRGU9OcjR/OVlHVWs3Zxu6uHZZ+kXfn5ca6xvJ17T3jPjQVYipkiAvxCQ9e7CaVVkJLM+MRynFjSsX8F5JMy8cNZYzHi0f75VvjaOixY7T5R61Rl6IqZIgL8QknKzp4FRdp99M1utXZtDv0vzkzVLyUmN9FTSjyUuNpd+lOddgo627Xx66iqCQOiwhhujtdwEMm0jU73Lz1tkmXjxWy+unGogKN3HT6kzf/jU5iWQmRFHb0TvmKB6MCVEAe0qMjpKSrhHBIEFeiCH+6olD9Dnd/Pbzm/y2/58XinlqfyVJMeHcui6LT2xcSGLMQDWMUorrVy7g8fcucM2ysYO8t4zy3fPNACPOdhViqiTIi5DW73Lz1d8f5zOX57I6J3HM491uzf4LrdgdLk7WdLAyy1iqr6mrj+cOVXPH+mz+7bZVhI8wwQngc9vyiI00s3EcPWOSYyNIiA5n34VWAHIkXSOCQHLyIqQduNDKH47U8NtBrXovprzFjt1hpGsGL9bx1P5KHC43X7yyYNQAD0b9+4PXLRlxlutQSinyUmNxON1EhBkzWoUINAnyIqS95qlb31PSPGyt1IefP86Lx2r9thXXdgKwYVESLx6rpcXWh8Pp5om9FWxfbKVgjIepE+XNy2clRss6rCIoJMiLkKW15vXTDUSYTdS09/hq0sHoGfPU/iqe2Fvh956TtR1EmE186+aVOJxunj5Qxcsn62js6uMzW3IDfo3evLyUT4pgmXKQV0rlKKXeVEqdUkoVK6Ue8GxPVkrtUkqd93xNmvrlCjF+Zxu6qG7r4bNbcwFjNO/18sk6AI5WtfuqaQCKazpZkmFheWY8WwtTeWJvBY+/V05eaizbR1hPdaq8Z
Zby0FUESyBG8k7gQa31cmAT8CWl1HLga8BurXURsNvzWohp411D9b4teWQnRfuqWABeOVlPRJgJh9PN0ap2wBj5F9d2sDIrHjAW66jr6OVYVTt3b14UlHSKjORFsE05yGut67TWhz3fdwGngSzgZmCn57CdwC1TPZcIDfsvtPLcoWqcLveUP2tonn2wXacbWZ2TSFp8FNuKUtlb2uKbXXq8uoN7t+ShFOwrM6pbajt6aevuZ3mmUVFz9dI0cpKjiY0wc/v67FHPMxWFaXH8r9WZXDNGCwQhJiugOXmlVC6wFtgHpGut6zy76gH5KRYA/PurZ3nwd8f4yH/t4YPSlkl/zjMHKtn8nTc4Xdc5bF9jpzEC3+GpV99aaKWrz8mx6nZfS+A7L81haUY8+8uNazhZ0wHAykxjJG82KX788bX85JPrsESN3GhsqiLCTPzXXWtZtiA+KJ8vRMCCvFIqDngO+Huttd//ddoYbo045FJK3a+UOqiUOtjU1BSoyxGzWLOtj8XpcXT1OrnrF3v55xeLJ/wZj+25wEPPnaC+s5f/3H1+2P7dZxoB2LE8A4DLC1JQyph49MrJepZmWMhNjeWyvGQOVbThcLopru3EpGBpxkDAXb8oiSuXjD2xSYjZKiBBXikVjhHgn9RaP+/Z3KCUWuDZvwBoHOm9WutHtNYbtNYbrNbAP9gSs0+zrY/N+SnsfnA7t6/PZucH5VS1do/rvVpr/mv3ef7lT6e4YWUGX9iez8sn6znX0OV33OunGshJjmZxuvFgMyk2glVZCbx4rJZDlW3csHIBAJvyk+ntd3Oipp3img4K0+KIjpB1UUXoCER1jQIeA05rrX84aNeLwD2e7+8BXpjqucTc53C66ex1khIXSVS4mS/vWIwCnj1YNa73v366kR/sOsdt67L4r7vW8sUrCoiJMPOTN0t8xzTb+thT0sy1y9L9+rNvLUylrMmO1nDDKmOEvzHPWH91b1krJ2s7WOnJxwsRKgIxkt8CfBq4Wil11PPvRuC7wA6l1HngWs9rMc+1elZOSokzer5kJUazfbGVZw9WjetB7BtnGrBEhfH9j15CmNlEUmwEn960iJeO1XKh2U5jVy93PbIXpeCO9Tl+791alApAvjWWojRjhJ8cG8Hi9Dj+cqKOhs4+lmdKblyElkBU1+zRWiut9SVa6zWef3/RWrdora/RWhdpra/VWrcG4oLF3NZs6wMgJXZgCv9dGxfS0NnHm2fHfibzQWkLl+Ul+7UNuG9bHuFmE9/5y2nufGQvNe09/PIzG4cF7PWLkkiNi+Cj67L9RviX5aX4Zrp6e9UIESpkxquYVi2ekXxq3ED3xquXppFmieSp/RfvL1Pb3kN5SzebC1L9tqdZorhr40JeO9VAQ0cvO+/dyOaClGHvjwwz885Xr+Kvthf4bb8sf6CZmIzkRaiRLpRiWrV4R/KDmnGFmU18bEMO//NWCbXtPWSOMvvTW255+QgB/K+vLKC2vYcvbC9g/aLRJ1fHRAz/kd+YZwT5RSkxxAepVFKImSIjeTGtWmz+OXmvj1+ag+biD2DfL20hKSacJemWYfvS4qN45O4NFw3wo0mzRLEiM35c7YGFmGtkJC+mVbO9jwizCUuk/49eTnIM24qs/OaDCj5ySSaFaf7dHrXW7C1rYXNBSlDaCzx9/6aLthAWYq6Sn2oxrVpsDlLiIvwefHp948PLUErx8Z9/QHFth9++ytZuatp72Jw/PFUTCJao8GHL/QkRCiTIi2nVYusblqrxKkq38OwXNhEZZuKuR/ZyuLLNt8+bjx/60FUIcXES5MWE9fa7+M7Lp8c9S3WwFrvDr3xyqHxrHM9+cTPJsRF86tF97Cszgvv7pS1YLZEUWGMnfd1CzEcS5MWEaK35+h9O8vO3y3jpeO3YbxjCm665mOykGJ79wmYWJETxmV8e4IPSFj4oa/H0n5HVk4SYCAnyYkKe3FfJc4erAShrso9xtD+tNS32vnGtZZoWH
8VT928iOymaux/fR1NXX9Dy8UKEMgnyYtwOV7bxzZeKuXKJlY25yZQ12Sb0/m6Hi95+NymxFx/Je6VZjECfn2pU2lwu+XghJkyCvBiXjp5+/vqJw2QkRPHjj6+hIC3Ob83U8RiokR97JO+VGhfJs1/YzO++uJmFKTETOp8QQoL8rNTvcnP34/s5WD572v08ua/C6N1+51oSYyIosMbS1t1Pm6dNwXg0272zXcc3kvdKiAnnUpmoJMSkSJCfheo7ennnXBOvnx6xBf+kvVZcz8PPn7joknkj6XO6+NV75WwtTGXtQmNGab6nyqWsefwpG+9IPvUi1TVCiMCSID8LNXn6u1yYQAAdS2+/i2+8cJKn9leyt2xifyG8eLSWxq4+Pn9Fvm+bN09eOoGHrwN9ayY2khdCTJ4E+VnIO+IdKef96w/KOVTRNmz7WH79QTkNnX1Eh5v55XsXxv0+rTWPvnuBJekWrigaePCZnRRNuFlNqMLG24EyeZwPXoUQUydBfhbyjnjLW7pxuwdSKw6nm2+9dGpCQRqgs7ef/3mrlO2Lrdy7NZddpxvGPZHpnfPNnG3o4vNX5PvVqIeZTSxMjrnoXxuvFtdzpn5gud9mWx9xkWHSPkCIaSRBfhbyLqzhcLqp7ejxbS9rtuF0a843TCyN8+i7F2jv7ucfP7SET2/KxawUO98vH9d7f/FOGenxkdy0OnPYvnxr3Kgj+fdKmvniE4f4/itnfdvGMxFKCBFYEuRnoWbbQMXK4JTN2XpjseqyZtu4lsoD46+Cx94t48OrFrAyK4GMhChuWLWAZw5WYe9z4nS5+eFrZ7n639/ytRDwevFYLXtKmvnM5XlEhA3/Ucm3xlLR0o3L7f8gt9Xu4B+ePYrWcKSyzfegt8XeN+4aeSFEYEiQn4W8aQ3wD/JnPEG+36WpGGe65fH3LtDT7+Ifrlvs2/bZLbl09Tr5n7dKuOsXe/nPN0po7Xbwqcf28buDVWit+fHr5/i7p46wYVESn968aMTPLkiNw+FyU902cC1aax567jitdgef3rSItu5+3z0YI3mprBFiOkmQn4VabA6WZFiIiTD7pUPO1Xf5RtTnG7rG9VnHqztYlZVAgXWgP/u6hUmszknkJ2+Wcqq2kx99fDVvf+UqLstL4R9/f5wP/+cefvz6eW5bl8WTn7/M9wtnKF8Z5aBrfHJfJbtONfDQ9Uv51Cbjl8PhynbA+AslVdI1QkwrCfKzUIu9D2tcJLkpsZS3+I/kryiyAow7L1/eYic3dXjnxq9dv5Trlqfzp7/bxq1rs0mICeeXn72UT21ayOn6Tr52w1J+cMdqIsNGf0ia7/nFUeYZqbd3O/i3v5xmW1Eq927JoygtDktkGIcr23C7Na32vot2oBRCBJ6sDDULNdscXJobgdmsOFljLJ7R1dtPTXsPn7hsIWfqOznXOHaQdzjd1LT1cOva7GH7NhekDFvsOtxs4l9vWcVD1y/FMo61TpNiwkmIDvf1sHlyXyXdDhcP37DMt3rTmoWJHK5oo72nH7eWGnkhppuM5GcZp8tNW7eRu85PjaWqtRuH0805z8h9SbqForS4caVrqtu6cWvInWDPl/EEeAClFPnWWMqa7Dicbna+b8yKXZ4Z7ztm3cIkzjV0+f4ikZy8ENNLgvws09rtQGuwxkWQlxqLW0NVW7evsmZJhoXF6RbKmu1jVthUtBgPRBcFsbFXfmocZc02XjpmzIr93LY8v/3rFiXh1vCGp0VDqlTXCDGtJMjPMoM7NXpz6Rea7Jxr6CI2wkxWYjSFaXE4nG4qx6iw8Y6eF6UEbzWlfGssDZ19/OStEorS4ti+2Oq3f01OIgC7TjUAMpIXYroFJMgrpR5XSjUqpU4O2paslNqllDrv+ZoUiHOFOl8TL0+6BowyyjP1nSzOsGAyKRanWwA4P0ZevqKlm7jIsKDWphcMqrD53La8YSs3JUSHU5QWx1lPekly8kJMr0CN5H8FXD9k29eA3VrrI
mC357UYQ/OgJl6JMREkxYRT1mznbH0XSzzBvSDNqGoZKy9f3mJnUUpMUJfM81bYpMZFcPOarBGPWefpXKkUJMVIkBdiOgUkyGut3wGGtja8Gdjp+X4ncEsgzhXqvEHe2443LzWWA+WttHX3syTDCPJxkWFkJUaPaySfG8RUDRj5/sSYcD6/LX/UnjTrFxlBPjkmArNJ1mgVYjoFMyefrrWu83xfD6SPdJBS6n6l1EGl1MGmpqYgXs7c0GxzEG5WxEcb1a25qbGUNA5U1ngVpcddtFbe6XJT1dod1IeuAJFhZj742jXcP6gN8VDrFiUCkqoRYiZMy4NXbTQvGXGlCq31I1rrDVrrDVardaRD5pUWmzFhyJtiyR80kck7kgcoSoujtMmGy63RWvOtl07x07dKfftr23txunXQR/IA0RHmi6aE8lPjiI8Kk4lQQsyAYE6GalBKLdBa1ymlFgCBXeYoRLXYHaRaBka8eanenHekX2VKUZqFPqcxWv/ziToef+8CSTHh3H9FPmaTGlRZM/ProppMii/vWCx95IWYAcEcyb8I3OP5/h7ghSCeK2Q02/yn/ud5RvJLMuL8jitKN17/7O1S/v21syxMjqGtu58Tnhmy3gZmI7U0mAmf3ZI36oNZIUTwBKqE8ingA2CJUqpaKXUf8F1gh1LqPHCt57UYQ4vNQeqgEXtuagxKwdKMeL/jCj0VNk8fqGJlZgLPfGETSsFbZ40/mCqa7USFm0izSIpEiPksIOkarfVdo+y6JhCfP19orWmy9fl1aoyJCOPRuzewMivB71hLVDhZidE4XG5+cfcGMhKiWJ2dyFtnm/j7axdT7qmsCWb5pBBi9pMGZbOIrc+Jw+keVoVyzbIRC5P46afWER8VTkZCFABXLrHyH7vP02p3UNFi97UCFkLMX9LWYBYZPNt1PC7JTvTLuV+5JA2t4e1zjVS0Br9GXggx+0mQn0UGZrtOLo9+SVYCybERPL2/CofTHdSeNUKIuUGCfJC02R28crJu7AMHafaN5CdXamgyKbYVpbLvgjH5eKIthoUQoUeCfJA8daCSLz5xmJr2nnG/x9fSYAqdGq9cMjChbNEsKZ8UQswcCfJBUuWpU/eu7DQe3pz8VCYNXVFkRSmIMJvIiI+a9OcIIUKDBPkgqW4zRvDFtZ3jfk+LvY/EmHDCzZP/z5ISF8kl2YksSomRZmBCCCmhDBbvSL54AiN5Y7br1Kf+//vtl9DT75ry5wgh5j4J8kHgdmtfLn4iI/nmIbNdJ6toULdKIcT8JumaIGjs6qPfpclLjaW+s9f3QHUszba+gAR5IYTwkiAfBNVtRqrmQysygPGP5ltsDum5LoQIKAnyk2C0xx+d96HrdSuMdgTjqbBxON109PTLSF4IEVAS5CfhOy+f4bofve0bsQ/l3b58QTwLk2Morh07yDcNWttVCCECRYL8JLxxppFzDTY+/vO9lDfbh+2vau0hNS6SqHAzKzLjx5Wu+d3BKmBg0WshhAgECfIT1O1wUtpk48ZVGfT0u/jYzz+gpLHL75jq9m6yk6IBWJmVQEVLN529/aN+ZkdPP4/vucCO5eksWxA/6nFCCDFREuQn6HRdF1rDrWuzefr+TWjgvp0H/fL01W095CQbfWOWZxpB+9RFRvO/eq+czl4nD1xTFNRrF0LMPxLkJ8ibX1+RGc/idAsPXFNERUs3Va3Gw1aXW1Pb3jMwks80FvsY7eFrZ28/j+0p49pl6cMWBhFCiKmSID9BxTWdJMWEs8CzUMeluckAHCg3Oj82dvXS79K+IG+1RJJmiRx1JL9TRvFCiCCSID9BJ2s7WJmV4FtWrygtjvioMA5WGEHeWz6ZnTTQ5ndlVgInB1XY9Pa7OFrVzm/2VvDongtcuyyNVdkyihdCBJ60NZgAh9PNuYYu7t2a59tmMik25CZzoLwNGOhZ4x3Jg5HaeetsI3c9spfq9m5q23txuY0cfkZ8FF/50JJpvAshxHwiQX4Cz
jV00e/Svjy714bcJN4400ir3eEbyWclDgT5q5em8eKxWvqcLtbmJHHLmhhWZMazMiuBrMRoWWxbCBE08ybI9zldHK5oZ/2iJCLCJpel8ubVV2T6lzl68/KHKtqobusmzWLUyHutXZjE2/941SSvXAghJi/kg/zpuk6eOVDFH4/W0N7dz4M7FvO3gx5yaq15+PkTxEWGcffmXBZeZMm8k7UdxEaYhy2QvSorgQiziYPlrVS39filaoQQYiaFdJC/0Gznxv98l3CzieuWp1PV1sMT+yr44pUFvoU53jrXxNMHjNmmj713gWuXpfO/P7xsxEWwi2s7WZ4Zj2nIYhxR4WYuyU7gQHkrzTYHa3ISg35vQggxHkGvrlFKXa+UOquUKlFKfS3Y5xusvNmO1vCbezfy359YxwPXFNLQ2cerxfW+Yx55u4yM+Cje/epVfOnKQt4raeb7r5wd9lkut+ZUbScrMkeugtmQm8yJmg6/GnkhhJhpQQ3ySikz8BPgBmA5cJdSankwzzlYXUcvgG/26fbFaSxMjmHn++UAHK9u54OyFu7dmktOcgxf+dASrlqSxvGa9mGfdaHZTk+/a1g+3mtjXhL9Lo3Trf3KJ4UQYiYFeyS/ESjRWpdprR3A08DNQT6nT31HDyZlTEgCMJsUd29exIHyNoprO/j5O2VYIsO4a+NC33tWZMVT1dpDR7d/rxnvTNfRZqWuX5js+15G8kKI2SLYQT4LqBr0utqzzUcpdb9S6qBS6mBTU1NAT17f2YvVEum3MPYd63OICjfx3ZfP8PKJOj6xaSGWqHDffm955ND2wMW1nUSEmShMixvxXAkx4SzxLLvn/ctBCCFm2ozPeNVaP6K13qC13mC1WgP62XUdvWQk+I+qE2LCuXVtFu+eb8ZsUty7Jc9vv3ekfnJIkD9W1c7SDIvfL4yhNuQmYVL4Wh4IIcRMC3aQrwFyBr3O9mybFvUdvSyIHx5w77k8F4Cb12SRPmR/cmwEWYnRnKwZ6DXT7XByuLKNy/KSuZi/vbqIRz69wa9GXgghZlKwSygPAEVKqTyM4H4n8Ikgn9OnvqOXLYWpw7YvzYhn570bWT1Kv5gVmfF+XSP3XWil36XZVnTxvzQyEqLIkFG8EGIWCepIXmvtBP4GeBU4DTyrtS4O9HmOV7fz0O+P+z0stfU56epzjhp0ty+2khgz8lJ7K7MSKGu20+VZ6GPP+WYiwkxsHGMkL4QQs03Qc/Ja679orRdrrQu01t8OxjlabA6eOVhFSdPACk31nvLJyeTHV2YZZZKn64zP23O+mUtzkyQNI4SYc2b8wWsg5FuN2amljQPrrXqDfMYIOfmxDF7oo6Gzl7MNXWwtDOxDYSGEmA4h0dYgOymGCLOJ0iabb1tdh9ENcjI58rT4KNIskZys7SAh2iiv3FY0PLcvhBCzXUgEebNJkZca6xfkvSP5odUz47UyK4GTNR1oDSmxESyXBbaFEHNQSAR5gIK0WM7UDcrJd/aSHBsx6Tz6Ss9CH802B1sKU4c1JRNCiLkgJHLyAAXWOCpau3E43YAxkp9MPt5rRVYCbg2tdgfbRijDFEKIuSCkgrzLralsNR6+1nX0Tmnm6apBPWq2Sj5eCDFHhUyQ91bYlHgqbOo7e0mfQpBfkBBFcmwEBdZYMhOl4ZgQYm4KmZx8vtVoHFbaZKO330Wr3TFiS4PxUkrx4HWLSRplwpQQQswFIRPk4yLDyIiPorTJRmNnHzC58snBPnnZokBcmhBCzJiQSdeAUWFT2mT31cgvSJA0ixBifgutIG+No6zJ5lsRKiMhcoavSAghZlbIBfmuXicnPB0kh/aSF0KI+Sakgry3wua9kmYskWHERYbMIwchhJiUkAryBZ4KmzP1XdLXXQghCLEgnxEfRUyE0cZAgrwQQoRYkDeZlC9lM5WWBkIIESpCKsjDQMpGFtMWQogQDPL5qUaQl8oaIYQIwSBfkOZJ10iNvBBChF6Qv3JJGp/bmsem/
JSZvhQhhJhxIVdIHhcZxv/+yPKZvgwhhJgVQm4kL4QQYoAEeSGECGES5IUQIoRJkBdCiBA2pSCvlLpDKVWslHIrpTYM2fewUqpEKXVWKfWhqV2mEEKIyZhqdc1J4Dbg54M3KqWWA3cCK4BM4HWl1GKttWuK5xNCCDEBUxrJa61Pa63PjrDrZuBprXWf1voCUAJsnMq5hBBCTFywcvJZQNWg19WebcMope5XSh1USh1samoK0uUIIcT8NGa6Rin1OpAxwq6va61fmOoFaK0fAR7xnKtJKVUxyY9KBZqnej1zjNzz/CD3PD9M5Z4XjbZjzCCvtb52EiesAXIGvc72bBvrXNZJnAsApdRBrfWGsY8MHXLP84Pc8/wQrHsOVrrmReBOpVSkUioPKAL2B+lcQgghRjHVEspblVLVwGbgz0qpVwG01sXAs8Ap4BXgS1JZI4QQ029KJZRa6z8Afxhl37eBb0/l8yfokWk812wh9zw/yD3PD0G5Z6W1DsbnCiGEmAWkrYEQQoQwCfJCCBHCQiLIK6Wu9/TIKVFKfW2mrycYlFKPK6UalVInB21LVkrtUkqd93xNmslrDDSlVI5S6k2l1ClPj6QHPNtD9r6VUlFKqf1KqWOee/6mZ3ueUmqf52f8GaVUxExfayAppcxKqSNKqT95Xof6/ZYrpU4opY4qpQ56tgXl53rOB3mllBn4CXADsBy4y9M7J9T8Crh+yLavAbu11kXAbs/rUOIEHtRaLwc2AV/y/LcN5fvuA67WWq8G1gDXK6U2Ad8DfqS1LgTagPtm7hKD4gHg9KDXoX6/AFdprdcMqo0Pys/1nA/yGD1xSrTWZVprB/A0Ru+ckKK1fgdoHbL5ZmCn5/udwC3TeU3BprWu01of9nzfhREEsgjh+9YGm+dluOefBq4Gfu/ZHlL3rJTKBj4MPOp5rQjh+72IoPxch0KQH3efnBCUrrWu83xfD6TP5MUEk1IqF1gL7CPE79uTujgKNAK7gFKgXWvt9BwSaj/jPwa+Crg9r1MI7fsF4xf3a0qpQ0qp+z3bgvJzHXILec9XWmutlArJelilVBzwHPD3WutOY6BnCMX79kwcXKOUSsSYh7J0Zq8oeJRSHwEatdaHlFJXzvDlTKetWusapVQasEspdWbwzkD+XIfCSH5SfXJCRINSagGA52vjDF9PwCmlwjEC/JNa6+c9m0P+vgG01u3AmxgzyhOVUt5BWSj9jG8BblJKlWOkWq8G/oPQvV8AtNY1nq+NGL/INxKkn+tQCPIHgCLP0/gIjMVKXpzha5ouLwL3eL6/B5hyV9DZxJObfQw4rbX+4aBdIXvfSimrZwSPUioa2IHxLOJN4HbPYSFzz1rrh7XW2VrrXIz/d9/QWn+SEL1fAKVUrFLK4v0euA5jAaag/FyHxIxXpdSNGHk9M/C4p6VCSFFKPQVcidGOtAH4v8AfMXoELQQqgI9prYc+nJ2zlFJbgXeBEwzka/8JIy8fkvetlLoE46GbGWMQ9qzW+ltKqXyMkW4ycAT4lNa6b+auNPA86ZqvaK0/Esr367k3bzuYMOC3WutvK6VSCMLPdUgEeSGEECMLhXSNEEKIUUiQF0KIECZBXgghQpgEeSGECGES5IUQIoRJkBdCiBAmQV4IIULY/wfp3iPkf5jG9QAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(xdata, ydata)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we will write the Stan code, keeping in mind that it must be able to compute the pointwise log likelihood on excluded data, that is, data which is not used to fit the model. Thus, the backbone of the code must look like:\n", "\n", "```\n", "data {\n", " data_for_fitting\n", " excluded_data\n", " ...\n", "}\n", "model {\n", " // fit against data_for_fitting\n", " ...\n", "}\n", "generated quantities {\n", " ....\n", " log_lik for data_for_fitting\n", " log_lik_excluded for excluded_data\n", "}\n", "```" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "refit_lr_code = \"\"\"\n", "data {\n", " // Define data for fitting\n", " int N;\n", " vector[N] x;\n", " vector[N] y;\n", " // Define excluded data. It will not be used when fitting.\n", " int N_ex;\n", " vector[N_ex] x_ex;\n", " vector[N_ex] y_ex;\n", "}\n", "\n", "parameters {\n", " real b0;\n", " real b1;\n", " real sigma_e;\n", "}\n", "\n", "model {\n", " b0 ~ normal(0, 10);\n", " b1 ~ normal(0, 10);\n", " sigma_e ~ normal(0, 10);\n", " for (i in 1:N) {\n", " y[i] ~ normal(b0 + b1 * x[i], sigma_e); // use only data for fitting\n", " }\n", " \n", "}\n", "\n", "generated quantities {\n", " vector[N] log_lik;\n", " vector[N_ex] log_lik_ex;\n", " vector[N] y_hat;\n", " \n", " for (i in 1:N) {\n", " // calculate log likelihood and posterior predictive, there are \n", " // no restrictions on adding more generated quantities\n", " log_lik[i] = normal_lpdf(y[i] | b0 + b1 * x[i], sigma_e);\n", " y_hat[i] = normal_rng(b0 + b1 * x[i], sigma_e);\n", " }\n", " for (j in 1:N_ex) {\n", " // calculate the log likelihood of the excluded data given data_for_fitting\n", " log_lik_ex[j] = normal_lpdf(y_ex[j] | b0 + b1 * x_ex[j], sigma_e);\n", " }\n", "}\n", "\"\"\"" ] }, { "cell_type": "code", 
"execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Building... This may take some time.\n", "Done.\n", "Sampling...\n", " 0/8000 [>---------------------------] 0% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 Messages received during sampling:\n", " Gradient evaluation took 4.2e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.42 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 4.2e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.42 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 0.00014 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 1.4 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 3.9e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.39 seconds.\n", " Adjust your expectations accordingly!\n", "\n", " 8000/8000 [============================] 100% 1 sec/0 \n", "Done.\n" ] } ], "source": [ "data_dict = {\n", " \"N\": len(ydata),\n", " \"y\": ydata,\n", " \"x\": xdata,\n", " # No excluded data in initial fit\n", " \"N_ex\": 0,\n", " \"x_ex\": [],\n", " \"y_ex\": [],\n", "}\n", "sm = stan.build(program_code=refit_lr_code, data=data_dict)\n", "sample_kwargs = {\"num_samples\": 1000, \"num_chains\": 4}\n", "fit = sm.sample(**sample_kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have defined a dictionary `sample_kwargs` that will be passed to the `SamplingWrapper` in order to make sure that all\n", "refits use the same sampler parameters. We follow the same pattern with `az.from_pystan`." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "dims = {\"y\": [\"time\"], \"x\": [\"time\"], \"log_likelihood\": [\"time\"], \"y_hat\": [\"time\"]}\n", "idata_kwargs = {\n", "    \"posterior_predictive\": [\"y_hat\"],\n", "    \"observed_data\": \"y\",\n", "    \"constant_data\": \"x\",\n", "    \"log_likelihood\": [\"log_lik\", \"log_lik_ex\"],\n", "    \"dims\": dims,\n", "}\n", "idata = az.from_pystan(posterior=fit, posterior_model=sm, **idata_kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will create a subclass of {class}`~arviz.PyStanSamplingWrapper`. Because this class already implements the other methods required by {func}`~arviz.reloo`, we only have to implement `sel_observations`. As explained in its docs, it takes one argument, the indices of the data to be excluded, and returns two values: `modified_observed_data`, which is passed as `data` when the PyStan model is rebuilt and sampled, and `excluded_observed_data`, which is used to retrieve the log likelihood of the excluded data. Here, `excluded_observed_data` is the name of the generated quantity `log_lik_ex` rather than the excluded data itself, because the log likelihood of the excluded data is already computed inside the Stan program." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class LinearRegressionWrapper(az.PyStanSamplingWrapper):\n", "    def sel_observations(self, idx):\n", "        xdata = self.idata_orig.constant_data.x.values\n", "        ydata = self.idata_orig.observed_data.y.values\n", "        # True for observations kept for fitting, False for excluded ones\n", "        mask = np.full_like(xdata, True, dtype=bool)\n", "        mask[idx] = False\n", "        N_obs = len(mask)\n", "        N_ex = np.sum(~mask)\n", "        # Data dictionary matching the data block of the Stan program\n", "        observations = {\n", "            \"N\": int(N_obs - N_ex),\n", "            \"x\": xdata[mask],\n", "            \"y\": ydata[mask],\n", "            \"N_ex\": int(N_ex),\n", "            \"x_ex\": xdata[~mask],\n", "            \"y_ex\": ydata[~mask],\n", "        }\n", "        # Name of the generated quantity holding the log likelihood of the excluded data\n", "        return observations, \"log_lik_ex\"" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.85 7.20\n", "p_loo 3.05 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 100 100.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 0 0.0%\n", " (1, Inf) (very bad) 0 0.0%" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig = az.loo(idata, pointwise=True)\n", "loo_orig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, the Leave-One-Out Cross Validation (LOO-CV) approximation using Pareto Smoothed Importance Sampling (PSIS) works for all observations, so we will modify `loo_orig` to make {func}`~arviz.reloo` believe that PSIS failed for some observations. This will also serve as a validation of our wrapper: the PSIS LOO-CV already returned the correct value, so the refitted result should match it." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Pretend PSIS failed for 4 observations: k > 0.7 (the default k_thresh) triggers a refit in az.reloo\n", "loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We initialize our sampling wrapper with the Stan program code, the original `InferenceData` and the same `sample_kwargs` and `idata_kwargs` used for the initial fit." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "pystan_wrapper = LinearRegressionWrapper(\n", "    refit_lr_code, idata_orig=idata, sample_kwargs=sample_kwargs, idata_kwargs=idata_kwargs\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we can use this wrapper to call {func}`~arviz.reloo` and compare the results with the PSIS LOO-CV results." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Building...\n", "/home/ahartikainen/github_ubuntu/arviz/arviz/stats/stats_refitting.py:99: UserWarning: reloo is an experimental and untested feature\n", " warnings.warn(\"reloo is an experimental and untested feature\", UserWarning)\n", "Building...\n", "Found model in cache. 
Done.\n", "Sampling...\n", " 0/8000 [>---------------------------] 0% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 Messages received during sampling:\n", " Gradient evaluation took 4e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.4 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 5.9e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.59 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 4.4e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.44 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 3.8e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.38 seconds.\n", " Adjust your expectations accordingly!\n", "\n", " 8000/8000 [============================] 100% 1 sec/0 \n", "Done.\n", "Building...\n", "Found model in cache. 
Done.\n", "Sampling...\n", " 0/8000 [>---------------------------] 0% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 Messages received during sampling:\n", " Gradient evaluation took 2e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.2 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 1.9e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.19 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 2.2e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.22 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 2.4e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.24 seconds.\n", " Adjust your expectations accordingly!\n", "\n", " 8000/8000 [============================] 100% 1 sec/0 \n", "Done.\n", "Building...\n", "Found model in cache. 
Done.\n", "Sampling...\n", " 0/8000 [>---------------------------] 0% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 Messages received during sampling:\n", " Gradient evaluation took 1.9e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.19 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 1.6e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.16 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 1.8e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.18 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 1.3e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.13 seconds.\n", " Adjust your expectations accordingly!\n", "\n", " 8000/8000 [============================] 100% 1 sec/0 \n", "Done.\n", "Building...\n", "Found model in cache. 
Done.\n", "Sampling...\n", " 0/8000 [>---------------------------] 0% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 \n", " 8000/8000 [============================] 100% 1 sec/0 Messages received during sampling:\n", " Gradient evaluation took 1.7e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.17 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 1.8e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.18 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 2.2e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.22 seconds.\n", " Adjust your expectations accordingly!\n", " Gradient evaluation took 2.6e-05 seconds\n", " 1000 transitions using 10 leapfrog steps per transition would take 0.26 seconds.\n", " Adjust your expectations accordingly!\n", "\n", " 8000/8000 [============================] 100% 1 sec/0 \n", "Done.\n" ] } ], "source": [ "loo_relooed = az.reloo(pystan_wrapper, loo_orig=loo_orig)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.85 7.20\n", "p_loo 3.05 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 100 100.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " (0.7, 1] (bad) 0 0.0%\n", " (1, Inf) (very bad) 0 0.0%" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_relooed" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Computed from 4000 by 100 log-likelihood matrix\n", "\n", " Estimate SE\n", "elpd_loo -250.85 7.20\n", "p_loo 3.05 -\n", "------\n", "\n", "Pareto k diagnostic values:\n", " Count Pct.\n", "(-Inf, 0.5] (good) 96 96.0%\n", " (0.5, 0.7] (ok) 0 0.0%\n", " 
(0.7, 1] (bad) 2 2.0%\n", " (1, Inf) (very bad) 2 2.0%" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loo_orig" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }