{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(numba_for_arviz)=\n", "# Numba - an overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Numba](https://numba.pydata.org/numba-doc/latest/index.html) is a just-in-time compiler for Python that works best on code that uses NumPy arrays, NumPy functions, and loops.\n", "ArviZ includes {ref}`Numba as an optional dependency `, and `utils.py` provides a number of functions for systems on which Numba is installed. The {class}`arviz.Numba` class additionally allows Numba to be disabled and re-enabled on those systems." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A simple example to display the effectiveness of Numba" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import numpy as np\n", "import timeit\n", "\n", "from arviz.utils import conditional_jit, Numba\n", "from arviz.stats.diagnostics import ks_summary" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data = np.random.randn(1000000)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def variance(data, ddof=0):  # Compute the variance in pure Python, without Numba\n", "    a_a, b_b = 0, 0\n", "    for i in data:\n", "        a_a = a_a + i\n", "        b_b = b_b + i * i\n", "    var = b_b / (len(data)) - ((a_a / (len(data))) ** 2)\n", "    var = var * (len(data) / (len(data) - ddof))\n", "    return var" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "140 ms ± 2.59 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" ] } ], "source": [ "%timeit variance(data, ddof=1)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "@conditional_jit\n", "def variance_jit(data, ddof=0):  # Compute the same variance, compiled with Numba when it is available\n", "    a_a, b_b = 0, 0\n", "    for i in data:\n", "        a_a = a_a + i\n", "        b_b = b_b + i * i\n", "    var = b_b / (len(data)) - ((a_a / (len(data))) ** 2)\n", "    var = var * (len(data) / (len(data) - ddof))\n", "    return var" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.03 ms ± 44.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "%timeit variance_jit(data, ddof=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That is about 136 times faster (140 ms versus 1.03 ms). Let's compare this to NumPy." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.79 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "%timeit np.var(data, ddof=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here the Numba version (1.03 ms) is even faster than `np.var` (1.79 ms), so in certain scenarios Numba can outperform NumPy." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Numba within ArviZ\n", "Let's see Numba's effect on a few ArviZ functions." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "summary_data = np.random.randn(1000, 100, 10)\n", "school = az.load_arviz_data(\"centered_eight\").posterior[\"mu\"].values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The methods of the {class}`~arviz.Numba` class can be used to enable or disable Numba. The attribute `numba_flag` indicates whether Numba is currently enabled within ArviZ." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Numba.disable_numba()\n", "Numba.numba_flag" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "57.8 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" ] } ], "source": [ "%timeit ks_summary(summary_data)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "462 µs ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "%timeit ks_summary(school)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Numba.enable_numba()\n", "Numba.numba_flag" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "7.18 ms ± 359 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "%timeit ks_summary(summary_data)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "359 µs ± 62.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "%timeit ks_summary(school)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Numba again provides a substantial speedup on the larger array: `ks_summary(summary_data)` runs about 8 times faster (57.8 ms versus 7.18 ms). The gain on the much smaller `school` array is modest (462 µs versus 359 µs)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }