{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(working_with_InferenceData)=\n", "\n", "# Working with InferenceData\n", "\n", "Here we present a collection of common manipulations you can use while working with `InferenceData`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import numpy as np\n", "import xarray as xr\n", "\n", "xr.set_options(display_expand_data=False, display_expand_attrs=False);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`display_expand_data=False` makes the default view for {class}`xarray.DataArray` fold the data values to a single line. To explore the values, click on the {fas}`database` icon on the left of the view, right under the `xarray.DataArray` text. It has no effect on `Dataset` objects, which already default to folded views.\n", "\n", "`display_expand_attrs=False` folds the attributes in both `DataArray` and `Dataset` objects to keep the views shorter. On this page we print DataArrays and Datasets several times, and they always have the same attributes." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> sample_stats\n", "\t> prior\n", "\t> observed_data" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata = az.load_arviz_data(\"centered_eight\")\n", "idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get the dataset corresponding to a single group" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset>\n",
       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
       "Coordinates:\n",
       "  * chain    (chain) int64 0 1 2 3\n",
       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
       "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
       "Data variables:\n",
       "    mu       (chain, draw) float64 -3.477 -2.456 -2.826 ... 4.597 5.899 0.1614\n",
       "    theta    (chain, draw, school) float64 1.669 -8.537 -2.623 ... 10.59 4.523\n",
       "    tau      (chain, draw) float64 3.73 2.075 3.703 4.146 ... 8.346 7.711 5.407\n",
       "Attributes: (3)
" ], "text/plain": [ "\n", "Dimensions: (chain: 4, draw: 500, school: 8)\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", " * school (school) object 'Choate' 'Deerfield' ... \"St. Paul's\" 'Mt. Hermon'\n", "Data variables:\n", " mu (chain, draw) float64 -3.477 -2.456 -2.826 ... 4.597 5.899 0.1614\n", " theta (chain, draw, school) float64 1.669 -8.537 -2.623 ... 10.59 4.523\n", " tau (chain, draw) float64 3.73 2.075 3.703 4.146 ... 8.346 7.711 5.407\n", "Attributes: (3)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "post = idata.posterior\n", "post" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ":::{tip} \n", "You'll have noticed we stored the posterior group in a new variable: `post`. As `.copy()` was not called, now using `idata.posterior` or `post` is equivalent.\n", "\n", "Use this to keep your code short yet easy to read. Store the groups you'll need very often as separate variables to use explicitly, but don't delete the InferenceData parent. You'll need it for many ArviZ functions to work properly. For example: {func}`~arviz.plot_pair` needs data from `sample_stats` group to show divergences, {func}`~arviz.compare` needs data from both `log_likelihood` and `posterior` groups, {func}`~arviz.plot_loo_pit` needs not 2 but 3 groups: `log_likelihood`, `posterior_predictive` and `posterior`.\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Add a new variable\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset>\n",
       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
       "Coordinates:\n",
       "  * chain    (chain) int64 0 1 2 3\n",
       "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
       "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
       "Data variables:\n",
       "    mu       (chain, draw) float64 -3.477 -2.456 -2.826 ... 4.597 5.899 0.1614\n",
       "    theta    (chain, draw, school) float64 1.669 -8.537 -2.623 ... 10.59 4.523\n",
       "    tau      (chain, draw) float64 3.73 2.075 3.703 4.146 ... 8.346 7.711 5.407\n",
       "    log_tau  (chain, draw) float64 1.316 0.7301 1.309 ... 2.122 2.043 1.688\n",
       "Attributes: (3)
" ], "text/plain": [ "\n", "Dimensions: (chain: 4, draw: 500, school: 8)\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", " * school (school) object 'Choate' 'Deerfield' ... \"St. Paul's\" 'Mt. Hermon'\n", "Data variables:\n", " mu (chain, draw) float64 -3.477 -2.456 -2.826 ... 4.597 5.899 0.1614\n", " theta (chain, draw, school) float64 1.669 -8.537 -2.623 ... 10.59 4.523\n", " tau (chain, draw) float64 3.73 2.075 3.703 4.146 ... 8.346 7.711 5.407\n", " log_tau (chain, draw) float64 1.316 0.7301 1.309 ... 2.122 2.043 1.688\n", "Attributes: (3)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "post[\"log_tau\"] = np.log(post[\"tau\"])\n", "idata.posterior" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combine chains and draws" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset>\n",
       "Dimensions:  (school: 8, sample: 2000)\n",
       "Coordinates:\n",
       "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
       "  * sample   (sample) MultiIndex\n",
       "  - chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3\n",
       "  - draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
       "Data variables:\n",
       "    mu       (sample) float64 -3.477 -2.456 -2.826 -1.996 ... 4.597 5.899 0.1614\n",
       "    theta    (school, sample) float64 1.669 -6.239 2.195 ... -1.095 4.013 4.523\n",
       "    tau      (sample) float64 3.73 2.075 3.703 4.146 ... 8.589 8.346 7.711 5.407\n",
       "    log_tau  (sample) float64 1.316 0.7301 1.309 1.422 ... 2.122 2.043 1.688\n",
       "Attributes: (3)
" ], "text/plain": [ "\n", "Dimensions: (school: 8, sample: 2000)\n", "Coordinates:\n", " * school (school) object 'Choate' 'Deerfield' ... \"St. Paul's\" 'Mt. Hermon'\n", " * sample (sample) MultiIndex\n", " - chain (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3\n", " - draw (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n", "Data variables:\n", " mu (sample) float64 -3.477 -2.456 -2.826 -1.996 ... 4.597 5.899 0.1614\n", " theta (school, sample) float64 1.669 -6.239 2.195 ... -1.095 4.013 4.523\n", " tau (sample) float64 3.73 2.075 3.703 4.146 ... 8.589 8.346 7.711 5.407\n", " log_tau (sample) float64 1.316 0.7301 1.309 1.422 ... 2.122 2.043 1.688\n", "Attributes: (3)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stacked = az.extract(idata)\n", "stacked" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also use {meth}`xarray.Dataset.stack` if you only want to combine the chain and draw dimensions. {func}`arviz.extract` is a convenience function aimed at taking care of the most common subsetting operations with MCMC samples. It can:\n", "- Combine chains and draws\n", "- Return a subset of variables (with optional filtering with regular expressions or string matching)\n", "- Return a subset of samples. Moreover by default it returns a random subset to prevent getting non-representative samples due to bad mixing.\n", "- Access any group" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(idata/random_subset)=\n", "## Get a random subset of the samples" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset>\n",
       "Dimensions:  (school: 8, sample: 100)\n",
       "Coordinates:\n",
       "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
       "  * sample   (sample) MultiIndex\n",
       "  - chain    (sample) int64 0 0 0 3 1 2 2 2 3 2 1 0 ... 3 3 3 1 1 0 1 1 1 1 2 3\n",
       "  - draw     (sample) int64 419 274 161 193 178 203 ... 238 49 212 448 95 412\n",
       "Data variables:\n",
       "    mu       (sample) float64 6.95 7.4 4.131 1.644 ... 5.142 1.917 8.56 3.486\n",
       "    theta    (school, sample) float64 8.734 3.829 19.1 ... 3.086 11.3 3.606\n",
       "    tau      (sample) float64 1.867 1.603 8.83 1.929 ... 1.636 5.707 5.941 1.582\n",
       "    log_tau  (sample) float64 0.6243 0.4717 2.178 0.6572 ... 1.742 1.782 0.4588\n",
       "Attributes: (3)
" ], "text/plain": [ "\n", "Dimensions: (school: 8, sample: 100)\n", "Coordinates:\n", " * school (school) object 'Choate' 'Deerfield' ... \"St. Paul's\" 'Mt. Hermon'\n", " * sample (sample) MultiIndex\n", " - chain (sample) int64 0 0 0 3 1 2 2 2 3 2 1 0 ... 3 3 3 1 1 0 1 1 1 1 2 3\n", " - draw (sample) int64 419 274 161 193 178 203 ... 238 49 212 448 95 412\n", "Data variables:\n", " mu (sample) float64 6.95 7.4 4.131 1.644 ... 5.142 1.917 8.56 3.486\n", " theta (school, sample) float64 8.734 3.829 19.1 ... 3.086 11.3 3.606\n", " tau (sample) float64 1.867 1.603 8.83 1.929 ... 1.636 5.707 5.941 1.582\n", " log_tau (sample) float64 0.6243 0.4717 2.178 0.6572 ... 1.742 1.782 0.4588\n", "Attributes: (3)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "az.extract(idata, num_samples=100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ":::{tip}\n", "Use a random seed to get the same subset from multiple groups: `az.extract(idata, num_samples=100, rng=3)` and `az.extract(idata, group=\"log_likelihood\", num_samples=100, rng=3)` will continue to have matching samples\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Obtain a NumPy array for a given parameter\n", "\n", "Let's say we want to get the values for `mu` as a NumPy array." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([-3.47698606, -2.45587061, -2.82625433, ..., 4.59705819,\n", " 5.89850592, 0.16138927])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "stacked.mu.values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get the dimension lengths\n", "\n", "Let's check how many groups are in our hierarchical model." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "8" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(idata.observed_data.school)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get coordinate values\n", "\n", "What are the names of the groups in our hierarchical model? You can access them from the coordinate name `school` in this case" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray 'school' (school: 8)>\n",
       "'Choate' 'Deerfield' 'Phillips Andover' ... "St. Paul's" 'Mt. Hermon'\n",
       "Coordinates:\n",
       "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
" ], "text/plain": [ "\n", "'Choate' 'Deerfield' 'Phillips Andover' ... \"St. Paul's\" 'Mt. Hermon'\n", "Coordinates:\n", " * school (school) object 'Choate' 'Deerfield' ... \"St. Paul's\" 'Mt. Hermon'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata.observed_data.school" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get a subset of chains\n", "\n", "Let's keep only chain 0 and 2 here. For the subset to take effect on all relevant InferenceData groups: posterior, sample_stats, log_likelihood, posterior_predictive we will use the {meth}`arviz.InferenceData.sel`, the method of InferenceData instead of {meth}`xarray.Dataset.sel`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 2, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 2\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu       (chain, draw) float64 -3.477 -2.456 -2.826 ... -1.571 -4.435 9.763\n",
             "    theta    (chain, draw, school) float64 1.669 -8.537 -2.623 ... 12.01 16.67\n",
             "    tau      (chain, draw) float64 3.73 2.075 3.703 4.146 ... 2.812 12.18 4.453\n",
             "    log_tau  (chain, draw) float64 1.316 0.7301 1.309 1.422 ... 1.034 2.5 1.494\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:  (chain: 2, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 2\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 7.85 -19.03 -22.5 ... 9.892 17.29\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:           (chain: 2, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain             (chain) int64 0 2\n",
             "  * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499\n",
             "  * school            (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tune              (chain, draw) bool True False False ... False False False\n",
             "    depth             (chain, draw) int64 5 3 3 4 5 5 4 4 5 ... 4 4 4 5 4 4 4 5\n",
             "    tree_size         (chain, draw) float64 31.0 7.0 7.0 15.0 ... 15.0 15.0 31.0\n",
             "    lp                (chain, draw) float64 -59.05 -56.19 ... -63.1 -61.91\n",
             "    energy_error      (chain, draw) float64 0.07387 -0.1841 ... 1.118 -0.5052\n",
             "    step_size_bar     (chain, draw) float64 0.2417 0.2417 ... 0.2501 0.2501\n",
             "    max_energy_error  (chain, draw) float64 0.131 -0.2067 ... 4.38 -0.5052\n",
             "    energy            (chain, draw) float64 60.76 62.76 64.4 ... 68.89 67.32\n",
             "    mean_tree_accept  (chain, draw) float64 0.9506 0.9906 ... 0.1054 0.9791\n",
             "    step_size         (chain, draw) float64 0.1275 0.1275 ... 0.2075 0.2075\n",
             "    diverging         (chain, draw) bool False False False ... False False False\n",
             "    log_likelihood    (chain, draw, school) float64 -5.168 -4.589 ... -3.843\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:    (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau        (chain, draw) float64 6.561 1.016 68.91 ... 1.56 5.949 0.7631\n",
             "    tau_log__  (chain, draw) float64 1.881 0.01593 4.233 ... 1.783 -0.2704\n",
             "    mu         (chain, draw) float64 5.293 0.8137 0.7122 ... -1.658 -3.273\n",
             "    theta      (chain, draw, school) float64 2.357 7.371 7.251 ... -3.775 -3.555\n",
             "    obs        (chain, draw, school) float64 -3.54 6.769 19.68 ... -21.16 -6.071\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
             "Attributes: (3)

\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> sample_stats\n", "\t> prior\n", "\t> observed_data" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata.sel(chain=[0, 2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Remove the first n draws (burn-in)\n", "\n", "Let's say we want to remove the first 100 draws from all chains and from all `InferenceData` groups that have a `draw` dimension." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu       (chain, draw) float64 4.271 4.517 0.3265 ... 4.597 5.899 0.1614\n",
             "    theta    (chain, draw, school) float64 32.74 1.796 2.199 ... 10.59 4.523\n",
             "    tau      (chain, draw) float64 11.98 9.164 11.72 6.183 ... 8.346 7.711 5.407\n",
             "    log_tau  (chain, draw) float64 2.483 2.215 2.462 1.822 ... 2.122 2.043 1.688\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 24.5 11.84 28.08 ... 4.698 -15.07\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:           (chain: 4, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain             (chain) int64 0 1 2 3\n",
             "  * draw              (draw) int64 100 101 102 103 104 ... 495 496 497 498 499\n",
             "  * school            (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tune              (chain, draw) bool False False False ... False False False\n",
             "    depth             (chain, draw) int64 5 5 5 5 5 5 4 4 4 ... 4 4 4 5 5 5 5 5\n",
             "    tree_size         (chain, draw) float64 31.0 31.0 31.0 ... 31.0 31.0 31.0\n",
             "    lp                (chain, draw) float64 -67.62 -66.08 ... -63.62 -58.35\n",
             "    energy_error      (chain, draw) float64 0.003801 -0.0119 ... -0.003652\n",
             "    step_size_bar     (chain, draw) float64 0.2417 0.2417 ... 0.1502 0.1502\n",
             "    max_energy_error  (chain, draw) float64 -0.03831 -0.02486 ... -0.101 -0.1757\n",
             "    energy            (chain, draw) float64 72.68 74.16 73.41 ... 67.77 67.21\n",
             "    mean_tree_accept  (chain, draw) float64 0.9998 1.0 0.8716 ... 0.9875 0.9967\n",
             "    step_size         (chain, draw) float64 0.1275 0.1275 ... 0.1064 0.1064\n",
             "    diverging         (chain, draw) bool False False False ... False False False\n",
             "    log_likelihood    (chain, draw, school) float64 -3.677 -3.414 ... -3.896\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:    (chain: 1, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0\n",
             "  * draw       (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499\n",
             "  * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau        (chain, draw) float64 1.588 0.4472 1.197 ... 1.56 5.949 0.7631\n",
             "    tau_log__  (chain, draw) float64 0.4625 -0.8048 0.1801 ... 1.783 -0.2704\n",
             "    mu         (chain, draw) float64 -1.087 -8.631 -0.7139 ... -1.658 -3.273\n",
             "    theta      (chain, draw, school) float64 1.556 1.323 2.802 ... -3.775 -3.555\n",
             "    obs        (chain, draw, school) float64 18.6 12.49 7.67 ... -21.16 -6.071\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
             "Attributes: (3)

\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> sample_stats\n", "\t> prior\n", "\t> observed_data" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata.sel(draw=slice(100, None))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you check the returned object, you will see that the groups `posterior`, `posterior_predictive`, `prior` and `sample_stats` have 400 draws, compared to 500 in `idata`. The `observed_data` group is not affected because it does not have a `draw` dimension. Alternatively, you can use the `groups` argument to specify which group or groups you want to change." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 100 101 102 103 104 105 ... 494 495 496 497 498 499\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu       (chain, draw) float64 4.271 4.517 0.3265 ... 4.597 5.899 0.1614\n",
             "    theta    (chain, draw, school) float64 32.74 1.796 2.199 ... 10.59 4.523\n",
             "    tau      (chain, draw) float64 11.98 9.164 11.72 6.183 ... 8.346 7.711 5.407\n",
             "    log_tau  (chain, draw) float64 2.483 2.215 2.462 1.822 ... 2.122 2.043 1.688\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 7.85 -19.03 -22.5 ... 4.698 -15.07\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:           (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain             (chain) int64 0 1 2 3\n",
             "  * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499\n",
             "  * school            (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tune              (chain, draw) bool True False False ... False False False\n",
             "    depth             (chain, draw) int64 5 3 3 4 5 5 4 4 5 ... 4 4 4 5 5 5 5 5\n",
             "    tree_size         (chain, draw) float64 31.0 7.0 7.0 15.0 ... 31.0 31.0 31.0\n",
             "    lp                (chain, draw) float64 -59.05 -56.19 ... -63.62 -58.35\n",
             "    energy_error      (chain, draw) float64 0.07387 -0.1841 ... -0.087 -0.003652\n",
             "    step_size_bar     (chain, draw) float64 0.2417 0.2417 ... 0.1502 0.1502\n",
             "    max_energy_error  (chain, draw) float64 0.131 -0.2067 ... -0.101 -0.1757\n",
             "    energy            (chain, draw) float64 60.76 62.76 64.4 ... 67.77 67.21\n",
             "    mean_tree_accept  (chain, draw) float64 0.9506 0.9906 ... 0.9875 0.9967\n",
             "    step_size         (chain, draw) float64 0.1275 0.1275 ... 0.1064 0.1064\n",
             "    diverging         (chain, draw) bool False False False ... False False False\n",
             "    log_likelihood    (chain, draw, school) float64 -5.168 -4.589 ... -3.896\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:    (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau        (chain, draw) float64 6.561 1.016 68.91 ... 1.56 5.949 0.7631\n",
             "    tau_log__  (chain, draw) float64 1.881 0.01593 4.233 ... 1.783 -0.2704\n",
             "    mu         (chain, draw) float64 5.293 0.8137 0.7122 ... -1.658 -3.273\n",
             "    theta      (chain, draw, school) float64 2.357 7.371 7.251 ... -3.775 -3.555\n",
             "    obs        (chain, draw, school) float64 -3.54 6.769 19.68 ... -21.16 -6.071\n",
             "Attributes: (3)

      <xarray.Dataset>\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
             "Attributes: (3)

\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> sample_stats\n", "\t> prior\n", "\t> observed_data" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata.sel(draw=slice(100, None), groups=\"posterior\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute posterior mean values along `draw` and `chain` dimensions\n", "\n", "To compute the mean value of the posterior samples, do the following:\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.Dataset>\n",
       "Dimensions:  ()\n",
       "Data variables:\n",
       "    mu       float64 4.093\n",
       "    theta    float64 4.56\n",
       "    tau      float64 4.089\n",
       "    log_tau  float64 1.15
" ], "text/plain": [ "\n", "Dimensions: ()\n", "Data variables:\n", " mu float64 4.093\n", " theta float64 4.56\n", " tau float64 4.089\n", " log_tau float64 1.15" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "post.mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This computes the mean along all dimensions. This is probably what you want for `mu` and `tau`, which have two dimensions (`chain` and `draw`), but maybe not what you expected for `theta`, which has one more dimension `school`. \n", "\n", "You can specify along which dimension you want to compute the mean (or other functions)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset>\n",
       "Dimensions:  (school: 8)\n",
       "Coordinates:\n",
       "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
       "Data variables:\n",
       "    mu       float64 4.093\n",
       "    theta    (school) float64 6.026 4.724 3.576 4.478 3.064 3.821 6.25 4.544\n",
       "    tau      float64 4.089\n",
       "    log_tau  float64 1.15
" ], "text/plain": [ "\n", "Dimensions: (school: 8)\n", "Coordinates:\n", " * school (school) object 'Choate' 'Deerfield' ... \"St. Paul's\" 'Mt. Hermon'\n", "Data variables:\n", " mu float64 4.093\n", " theta (school) float64 6.026 4.724 3.576 4.478 3.064 3.821 6.25 4.544\n", " tau float64 4.089\n", " log_tau float64 1.15" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "post.mean(dim=[\"chain\", \"draw\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute and store posterior pushforward quantities" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use \"posterior pushfoward quantities\" to refer to quantities that are not variables in the posterior but deterministic computations using posterior variables. \n", "\n", "You can use xarray for these pushforward operations and store them as a new variable in the posterior group. You'll then be able to plot them with ArviZ functions, calculate stats and diagnostics on them (like the {func}`~arviz.mcse`) or save and share the inferencedata object with the pushforward quantities included. \n", "\n", "Compute the rolling mean of $\\log(\\tau)$ with {meth}`xarray.DataArray.rolling`, storing the result in the posterior" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "post[\"mlogtau\"] = post[\"log_tau\"].rolling({\"draw\": 50}).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using xarray for pusforward calculations has all the advantages of working with xarray. It also inherits the disadvantages of working with xarray, but we believe those to be outweighed by the advantages, and we have already shown how to extract the data as NumPy arrays. 
Working with InferenceData mainly means working with xarray objects, and that is what this guide shows.\n", "\n", "Some examples of these advantages are specifying operations with named dimensions instead of positional ones (as seen in some previous sections), \n", "automatic alignment and broadcasting of arrays (as we'll see now),\n", "or integration with Dask (as shown in the {ref}`dask_for_arviz` guide).\n", "\n", "In this cell you will compute pairwise differences between schools on their mean effects (variable `theta`).\n", "To do so, subtract from `theta` a copy of itself whose `school` dimension has been renamed to `school_bis`.\n", "Because the two variables have different dimensions, xarray aligns and broadcasts them,\n", "so the result is a 4D variable (`chain`, `draw`, `school`, `school_bis`) with all the pairwise differences.\n", "\n", "Finally, store the result in the `theta_school_diff` variable:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "post[\"theta_school_diff\"] = post.theta - post.theta.rename(school=\"school_bis\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ":::{note}\n", ":class: dropdown\n", "\n", "This same operation using NumPy would require manual alignment of the two arrays to make sure they broadcast correctly. The code would be something like:\n", "\n", "```python\n", "theta = post[\"theta\"].values  # NumPy array with shape (chain, draw, school)\n", "theta_school_diff = theta[:, :, :, None] - theta[:, :, None, :]\n", "```\n", ":::\n", "\n", "The `theta_school_diff` variable in the posterior has kept the named dimensions and coordinates:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset>\n",
       "Dimensions:            (chain: 4, draw: 500, school: 8, school_bis: 8)\n",
       "Coordinates:\n",
       "  * chain              (chain) int64 0 1 2 3\n",
       "  * draw               (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
       "  * school             (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
       "  * school_bis         (school_bis) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
       "Data variables:\n",
       "    mu                 (chain, draw) float64 -3.477 -2.456 ... 5.899 0.1614\n",
       "    theta              (chain, draw, school) float64 1.669 -8.537 ... 4.523\n",
       "    tau                (chain, draw) float64 3.73 2.075 3.703 ... 7.711 5.407\n",
       "    log_tau            (chain, draw) float64 1.316 0.7301 1.309 ... 2.043 1.688\n",
       "    mlogtau            (chain, draw) float64 nan nan nan ... 0.9753 1.004 1.034\n",
       "    theta_school_diff  (chain, draw, school, school_bis) float64 0.0 ... 0.0\n",
       "Attributes: (3)
" ], "text/plain": [ "\n", "Dimensions: (chain: 4, draw: 500, school: 8, school_bis: 8)\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", " * school (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n", " * school_bis (school_bis) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n", "Data variables:\n", " mu (chain, draw) float64 -3.477 -2.456 ... 5.899 0.1614\n", " theta (chain, draw, school) float64 1.669 -8.537 ... 4.523\n", " tau (chain, draw) float64 3.73 2.075 3.703 ... 7.711 5.407\n", " log_tau (chain, draw) float64 1.316 0.7301 1.309 ... 2.043 1.688\n", " mlogtau (chain, draw) float64 nan nan nan ... 0.9753 1.004 1.034\n", " theta_school_diff (chain, draw, school, school_bis) float64 0.0 ... 0.0\n", "Attributes: (3)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "post" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Advanced subsetting\n", "To select the value corresponding to the difference between the Choate and Deerfield schools do:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500)>\n",
       "10.21 -7.311 5.116 2.606 -1.116 24.96 ... 3.128 -4.62 4.288 2.424 2.613 -0.1137\n",
       "Coordinates:\n",
       "  * chain       (chain) int64 0 1 2 3\n",
       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
       "    school      <U6 'Choate'\n",
       "    school_bis  <U9 'Deerfield'
" ], "text/plain": [ "\n", "10.21 -7.311 5.116 2.606 -1.116 24.96 ... 3.128 -4.62 4.288 2.424 2.613 -0.1137\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n", " school \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500,\n",
       "                                       pairwise_school_diff: 3)>\n",
       "10.21 -5.673 2.356 -7.311 2.817 -1.51 ... 2.613 8.154 8.915 -0.1137 2.805 5.63\n",
       "Coordinates:\n",
       "  * chain       (chain) int64 0 1 2 3\n",
       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
       "    school      (pairwise_school_diff) object 'Choate' 'Hotchkiss' 'Mt. Hermon'\n",
       "    school_bis  (pairwise_school_diff) object 'Deerfield' ... 'Lawrenceville'\n",
       "Dimensions without coordinates: pairwise_school_diff
" ], "text/plain": [ "\n", "10.21 -5.673 2.356 -7.311 2.817 -1.51 ... 2.613 8.154 8.915 -0.1137 2.805 5.63\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n", " school (pairwise_school_diff) object 'Choate' 'Hotchkiss' 'Mt. Hermon'\n", " school_bis (pairwise_school_diff) object 'Deerfield' ... 'Lawrenceville'\n", "Dimensions without coordinates: pairwise_school_diff" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "school_idx = xr.DataArray([\"Choate\", \"Hotchkiss\", \"Mt. Hermon\"], dims=[\"pairwise_school_diff\"])\n", "school_bis_idx = xr.DataArray(\n", " [\"Deerfield\", \"Choate\", \"Lawrenceville\"], dims=[\"pairwise_school_diff\"]\n", ")\n", "post[\"theta_school_diff\"].sel(school=school_idx, school_bis=school_bis_idx)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using lists or NumPy arrays instead of DataArrays does colum/row based indexing. As you can see, the result has 9 values of `theta_shool_diff` instead of the 3 pairs of difference we selected in the previous cell:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500, school: 3,\n",
       "                                       school_bis: 3)>\n",
       "10.21 0.0 10.84 4.533 -5.673 5.169 1.719 ... 2.691 2.805 3.861 4.46 4.574 5.63\n",
       "Coordinates:\n",
       "  * chain       (chain) int64 0 1 2 3\n",
       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
       "  * school      (school) object 'Choate' 'Hotchkiss' 'Mt. Hermon'\n",
       "  * school_bis  (school_bis) object 'Deerfield' 'Choate' 'Lawrenceville'
" ], "text/plain": [ "\n", "10.21 0.0 10.84 4.533 -5.673 5.169 1.719 ... 2.691 2.805 3.861 4.46 4.574 5.63\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n", " * school (school) object 'Choate' 'Hotchkiss' 'Mt. Hermon'\n", " * school_bis (school_bis) object 'Deerfield' 'Choate' 'Lawrenceville'" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "post[\"theta_school_diff\"].sel(\n", " school=[\"Choate\", \"Hotchkiss\", \"Mt. Hermon\"],\n", " school_bis=[\"Deerfield\", \"Choate\", \"Lawrenceville\"],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Add new chains using concat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After checking the {func}`~arviz.mcse` and realizing you need more samples, you rerun the model with two chains\n", "and obtain an `idata_rerun` object." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "idata_rerun = (\n", " idata.sel(chain=[0, 1])\n", " .copy()\n", " .assign_coords(coords={\"chain\": [4, 5]}, groups=\"posterior_groups\")\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can combine the two into a single InferenceData object using {func}`arviz.concat`:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata_complete = az.concat(idata, idata_rerun, dim=\"chain\")\n", "idata_complete.posterior.dims[\"chain\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Add groups to InferenceData objects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also add new groups to InferenceData objects with the {meth}`~arviz.InferenceData.extend` (if the new groups are already in an InferenceData object) or with 
{meth}`~arviz.InferenceData.add_groups` (if the new groups are dictionaries or `xarray.Dataset` objects)." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:            (chain: 4, draw: 500, school: 8, school_bis: 8)\n",
             "Coordinates:\n",
             "  * chain              (chain) int64 0 1 2 3\n",
             "  * draw               (draw) int64 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
             "  * school             (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "  * school_bis         (school_bis) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu                 (chain, draw) float64 -3.477 -2.456 ... 5.899 0.1614\n",
             "    theta              (chain, draw, school) float64 1.669 -8.537 ... 4.523\n",
             "    tau                (chain, draw) float64 3.73 2.075 3.703 ... 7.711 5.407\n",
             "    log_tau            (chain, draw) float64 1.316 0.7301 1.309 ... 2.043 1.688\n",
             "    mlogtau            (chain, draw) float64 nan nan nan ... 0.9753 1.004 1.034\n",
             "    theta_school_diff  (chain, draw, school, school_bis) float64 0.0 ... 0.0\n",
             "Attributes: (3)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 7.85 -19.03 -22.5 ... 4.698 -15.07\n",
             "Attributes: (3)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (chain: 4, draw: 500, new_school: 2)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 0 1 2 3\n",
             "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * new_school  (new_school) <U13 'Essex College' 'Moordale'\n",
             "Data variables:\n",
             "    obs         (chain, draw, new_school) float64 2.041 -2.556 ... -0.2822\n",
             "Attributes: (2)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:           (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain             (chain) int64 0 1 2 3\n",
             "  * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499\n",
             "  * school            (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tune              (chain, draw) bool True False False ... False False False\n",
             "    depth             (chain, draw) int64 5 3 3 4 5 5 4 4 5 ... 4 4 4 5 5 5 5 5\n",
             "    tree_size         (chain, draw) float64 31.0 7.0 7.0 15.0 ... 31.0 31.0 31.0\n",
             "    lp                (chain, draw) float64 -59.05 -56.19 ... -63.62 -58.35\n",
             "    energy_error      (chain, draw) float64 0.07387 -0.1841 ... -0.087 -0.003652\n",
             "    step_size_bar     (chain, draw) float64 0.2417 0.2417 ... 0.1502 0.1502\n",
             "    max_energy_error  (chain, draw) float64 0.131 -0.2067 ... -0.101 -0.1757\n",
             "    energy            (chain, draw) float64 60.76 62.76 64.4 ... 67.77 67.21\n",
             "    mean_tree_accept  (chain, draw) float64 0.9506 0.9906 ... 0.9875 0.9967\n",
             "    step_size         (chain, draw) float64 0.1275 0.1275 ... 0.1064 0.1064\n",
             "    diverging         (chain, draw) bool False False False ... False False False\n",
             "    log_likelihood    (chain, draw, school) float64 -5.168 -4.589 ... -3.896\n",
             "Attributes: (3)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:    (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau        (chain, draw) float64 6.561 1.016 68.91 ... 1.56 5.949 0.7631\n",
             "    tau_log__  (chain, draw) float64 1.881 0.01593 4.233 ... 1.783 -0.2704\n",
             "    mu         (chain, draw) float64 5.293 0.8137 0.7122 ... -1.658 -3.273\n",
             "    theta      (chain, draw, school) float64 2.357 7.371 7.251 ... -3.775 -3.555\n",
             "    obs        (chain, draw, school) float64 -3.54 6.769 19.68 ... -21.16 -6.071\n",
             "Attributes: (3)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
             "Attributes: (3)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> predictions\n", "\t> sample_stats\n", "\t> prior\n", "\t> observed_data" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rng = np.random.default_rng(3)\n", "idata.add_groups(\n", " {\"predictions\": {\"obs\": rng.normal(size=(4, 500, 2))}},\n", " dims={\"obs\": [\"new_school\"]},\n", " coords={\"new_school\": [\"Essex College\", \"Moordale\"]},\n", ")\n", "idata" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }