Converting emcee objects to InferenceData

InferenceData is the central data format for ArviZ. InferenceData itself is just a container that maintains references to one or more xarray.Dataset objects, called groups.
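Here is a minimal sketch of what "container" means in practice. It wraps a single xarray.Dataset as the posterior group and then reads it back. The variable name mu and its random values are made up for illustration.

```python
import numpy as np
import xarray as xr
import arviz as az

rng = np.random.default_rng(0)

# Hypothetical posterior: 4 chains x 100 draws of one scalar variable "mu"
posterior = xr.Dataset(
    {"mu": (("chain", "draw"), rng.normal(size=(4, 100)))},
    coords={"chain": np.arange(4), "draw": np.arange(100)},
)

# InferenceData takes one xarray.Dataset per group, passed as keyword arguments
idata = az.InferenceData(posterior=posterior)

print(idata.groups())  # names of the groups held by the container
print(idata.posterior["mu"].shape)  # each group is a regular xarray.Dataset
```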

Below are various ways to generate an InferenceData from emcee objects.


We start by importing the required packages and defining the model: the well-known eight schools model.
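For reference, the functions below implement the non-centered parametrization of this model. The distributions are read off the code comments, with the half-Cauchy scale of 25 and the normal scales of 10 and 1. The model is, for schools j = 0, …, 7:

```latex
\begin{aligned}
\mu    &\sim \mathcal{N}(0,\, 10^2) \\
\tau   &\sim \text{Half-Cauchy}(0,\, 25) \\
\eta_j &\sim \mathcal{N}(0,\, 1) \\
y_j    &\sim \mathcal{N}(\mu + \tau\,\eta_j,\, \sigma_j^2)
\end{aligned}
\qquad\Longrightarrow\qquad
\log p(\mu, \tau, \eta \mid y) =
  -\tfrac{1}{2}\left(\tfrac{\mu}{10}\right)^2
  - \log\!\left(\tau^2 + 25^2\right)
  - \tfrac{1}{2}\sum_{j} \eta_j^2
  - \tfrac{1}{2}\sum_{j} \left(\tfrac{\mu + \tau\,\eta_j - y_j}{\sigma_j}\right)^2
  + \text{const}, \quad \tau \ge 0.
```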

import arviz as az
import numpy as np
import emcee
az.style.use("arviz-darkgrid")
J = 8
y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])


def log_prior_8school(theta):
    mu, tau, eta = theta[0], theta[1], theta[2:]
    # Half-Cauchy prior on tau, hwhm=25
    if tau < 0:
        return -np.inf
    prior_tau = -np.log(tau**2 + 25**2)
    prior_mu = -0.5 * (mu / 10) ** 2  # normal prior, loc=0, scale=10
    prior_eta = -0.5 * np.sum(eta**2)  # normal prior, loc=0, scale=1
    return prior_mu + prior_tau + prior_eta


def log_likelihood_8school(theta, y, s):
    mu, tau, eta = theta[0], theta[1], theta[2:]
    # pointwise normal log likelihood (up to a constant), one value per school
    return -0.5 * ((mu + tau * eta - y) / s) ** 2


def lnprob_8school(theta, y, s):
    prior = log_prior_8school(theta)
    like_vect = log_likelihood_8school(theta, y, s)
    like = np.sum(like_vect)
    return like + prior


nwalkers = 40  # called chains in ArviZ
ndim = J + 2
draws = 1500
pos = np.random.normal(size=(nwalkers, ndim))
pos[:, 1] = np.absolute(pos[:, 1])
sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob_8school, args=(y_obs, sigma))
sampler.run_mcmc(pos, draws);

Manually set variable names

This first example shows how to convert the sampler while manually setting only the variable names, leaving everything else to the ArviZ defaults.

# define variable names; they cannot be inferred from emcee
var_names = ["mu", "tau"] + [f"eta{i}" for i in range(J)]
idata1 = az.from_emcee(sampler, var_names=var_names)
idata1
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1500)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499
      Data variables:
          mu       (chain, draw) float64 0.6982 0.7962 0.8433 ... 5.763 5.763 5.029
          tau      (chain, draw) float64 0.6679 0.7259 0.8075 ... 2.051 2.051 3.239
          eta0     (chain, draw) float64 0.08153 0.008519 0.007711 ... 0.4684 0.6057
          eta1     (chain, draw) float64 -0.5837 -0.6358 -0.828 ... 1.431 1.431 1.608
          eta2     (chain, draw) float64 0.104 -0.003427 0.08645 ... -1.056 -0.8344
          eta3     (chain, draw) float64 0.8693 1.196 1.423 ... -1.621 -1.621 -0.8859
          eta4     (chain, draw) float64 0.8211 1.27 1.324 ... -1.509 -1.509 -0.9923
          eta5     (chain, draw) float64 0.04491 0.2302 0.1735 ... -0.8137 -0.5359
          eta6     (chain, draw) float64 0.2983 0.1357 0.1385 ... -0.2085 0.0377
          eta7     (chain, draw) float64 -0.5895 -0.5165 -0.6091 ... 0.1594 -0.03057
      Attributes:
          created_at:                 2021-08-30T18:14:53.861857
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • sample_stats
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1500)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499
      Data variables:
          lp       (chain, draw) float64 -16.3 -17.83 -18.92 ... -20.11 -20.11 -16.68
      Attributes:
          created_at:                 2021-08-30T18:14:53.851570
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • observed_data
      <xarray.Dataset>
      Dimensions:      (arg_0_dim_0: 8, arg_1_dim_0: 8)
      Coordinates:
        * arg_0_dim_0  (arg_0_dim_0) int64 0 1 2 3 4 5 6 7
        * arg_1_dim_0  (arg_1_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          arg_0        (arg_0_dim_0) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
          arg_1        (arg_1_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
      Attributes:
          created_at:                 2021-08-30T18:14:53.853598
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

ArviZ has stored the posterior variables under the provided names, as expected, but it has also included other useful information in the InferenceData object. The log probability of each sample is stored in the sample_stats group under the name lp. All the arguments passed to the sampler via args have been saved in the observed_data group, under the default names arg_0 and arg_1.

It can also be useful to discard a burn-in period from the MCMC samples (see arviz.InferenceData.sel for more details):

idata1.sel(draw=slice(100, None))
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1400)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 100 101 102 103 104 105 ... 1495 1496 1497 1498 1499
      Data variables:
          mu       (chain, draw) float64 6.588 6.588 6.588 7.358 ... 5.763 5.763 5.029
          tau      (chain, draw) float64 1.122 1.122 1.122 1.328 ... 2.051 2.051 3.239
          eta0     (chain, draw) float64 -0.4995 -0.4995 -0.4995 ... 0.4684 0.6057
          eta1     (chain, draw) float64 0.2038 0.2038 0.2038 ... 1.431 1.431 1.608
          eta2     (chain, draw) float64 0.1563 0.1563 0.1563 ... -1.056 -0.8344
          eta3     (chain, draw) float64 0.04793 0.04793 0.04793 ... -1.621 -0.8859
          eta4     (chain, draw) float64 -1.467 -1.467 -1.467 ... -1.509 -0.9923
          eta5     (chain, draw) float64 0.4489 0.4489 0.4489 ... -0.8137 -0.5359
          eta6     (chain, draw) float64 0.1747 0.1747 0.1747 ... -0.2085 0.0377
          eta7     (chain, draw) float64 0.4413 0.4413 0.4413 ... 0.1594 -0.03057
      Attributes:
          created_at:                 2021-08-30T18:14:53.861857
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • sample_stats
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1400)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 100 101 102 103 104 105 ... 1495 1496 1497 1498 1499
      Data variables:
          lp       (chain, draw) float64 -14.38 -14.38 -14.38 ... -20.11 -20.11 -16.68
      Attributes:
          created_at:                 2021-08-30T18:14:53.851570
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • observed_data
      <xarray.Dataset>
      Dimensions:      (arg_0_dim_0: 8, arg_1_dim_0: 8)
      Coordinates:
        * arg_0_dim_0  (arg_0_dim_0) int64 0 1 2 3 4 5 6 7
        * arg_1_dim_0  (arg_1_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          arg_0        (arg_0_dim_0) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
          arg_1        (arg_1_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
      Attributes:
          created_at:                 2021-08-30T18:14:53.853598
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1
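The output above shows that sel only slices the groups that actually have a draw dimension. The observed data keeps its 8 values. The minimal sketch below checks the same behavior on made-up draws built with az.from_dict instead of an emcee run.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(3)

# Hypothetical draws: 4 chains x 500 draws, plus an 8-element observed vector
idata = az.from_dict(
    posterior={"mu": rng.normal(size=(4, 500))},
    observed_data={"y": np.arange(8.0)},
)

# Discard the first 100 draws of every chain (burn-in); returns a new object
burned = idata.sel(draw=slice(100, None))

print(burned.posterior.sizes["draw"])
print(burned.posterior["draw"].values[0])
print(burned.observed_data["y"].shape)
```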

Once the results are in an InferenceData object, ArviZ's native data structure, plotting the posterior of a few variables takes a single line:

az.plot_posterior(idata1, var_names=["mu", "tau", "eta4"])
array([<AxesSubplot:title={'center':'mu'}>,
       <AxesSubplot:title={'center':'tau'}>,
       <AxesSubplot:title={'center':'eta4'}>], dtype=object)
[Figure: posterior plots of mu, tau and eta4]

Structuring the posterior as multidimensional variables

Calling from_emcee this way stores each component of eta as a separate variable (eta0 to eta7). However, these are in fact different dimensions of the same variable. The likelihood and prior functions make this visible, because they unpack theta as:

mu, tau, eta = theta[0], theta[1], theta[2:]
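This unpacking amounts to indexing theta with 0, 1 and slice(2, None). A numpy-only sketch of how such a list of integers and slices partitions a 10-element vector follows. The helper split_theta is hypothetical and is not ArviZ's internal code.

```python
import numpy as np

# Hypothetical flat vector laid out like theta: [mu, tau, eta_0, ..., eta_7]
theta = np.arange(10.0)

# Integer entries pick out a scalar; slice entries pick out a block
slices = [0, 1, slice(2, None)]
names = ["mu", "tau", "eta"]


def split_theta(vector, slices, names):
    """Hypothetical helper: map each name to vector[index_or_slice]."""
    return {name: vector[s] for name, s in zip(names, slices)}


parts = split_theta(theta, slices, names)
for name, value in parts.items():
    print(name, np.shape(value))
```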

ArviZ supports multidimensional variables, and its slices argument tells it how to split theta exactly as the likelihood and prior functions do:

idata2 = az.from_emcee(sampler, slices=[0, 1, slice(2, None)])
idata2
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:      (chain: 40, draw: 1500, var_2_dim_0: 8)
      Coordinates:
        * chain        (chain) int64 0 1 2 3 4 5 6 7 8 ... 31 32 33 34 35 36 37 38 39
        * draw         (draw) int64 0 1 2 3 4 5 6 ... 1494 1495 1496 1497 1498 1499
        * var_2_dim_0  (var_2_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          var_0        (chain, draw) float64 0.6982 0.7962 0.8433 ... 5.763 5.029
          var_1        (chain, draw) float64 0.6679 0.7259 0.8075 ... 2.051 3.239
          var_2        (chain, draw, var_2_dim_0) float64 0.08153 -0.5837 ... -0.03057
      Attributes:
          created_at:                 2021-08-30T18:14:54.508656
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • sample_stats
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1500)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499
      Data variables:
          lp       (chain, draw) float64 -16.3 -17.83 -18.92 ... -20.11 -20.11 -16.68
      Attributes:
          created_at:                 2021-08-30T18:14:54.505081
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • observed_data
      <xarray.Dataset>
      Dimensions:      (arg_0_dim_0: 8, arg_1_dim_0: 8)
      Coordinates:
        * arg_0_dim_0  (arg_0_dim_0) int64 0 1 2 3 4 5 6 7
        * arg_1_dim_0  (arg_1_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          arg_0        (arg_0_dim_0) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
          arg_1        (arg_1_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
      Attributes:
          created_at:                 2021-08-30T18:14:54.506402
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

Since no var_names were passed, ArviZ assigned the default names var_0, var_1 and var_2. After checking them, the trace of a single dimension of eta (var_2) can be plotted using ArviZ's coords syntax:

az.plot_trace(idata2, var_names=["var_2"], coords={"var_2_dim_0": 4});
[Figure: trace plot of var_2 at var_2_dim_0 = 4]

blobs: unlock sample stats, posterior predictive and miscellanea

Emcee does not store per-draw sample stats; however, it has a feature called blobs that allows storing any variable on a per-draw basis. Blobs can be used to store sample_stats or even posterior_predictive data.

To use blobs, modify the log probability function so that it also returns the pointwise log likelihood, then rerun the sampler with the new function:

def lnprob_8school_blobs(theta, y, s):
    prior = log_prior_8school(theta)
    like_vect = log_likelihood_8school(theta, y, s)
    like = np.sum(like_vect)
    return like + prior, like_vect


sampler_blobs = emcee.EnsembleSampler(
    nwalkers,
    ndim,
    lnprob_8school_blobs,
    args=(y_obs, sigma),
)
sampler_blobs.run_mcmc(pos, draws);

You can now use the blob_names argument to indicate how to store this blob-defined variable. Because no group is specified, it goes to the log_likelihood group, as the output below shows. Note that blob_names is used in addition to the arguments covered in the previous examples. This example also introduces the coords and dims arguments to show the power and flexibility of the converter. For more on coords and dims see page_in_construction.

dims = {"eta": ["school"], "log_likelihood": ["school"]}
idata3 = az.from_emcee(
    sampler_blobs,
    var_names=["mu", "tau", "eta"],
    slices=[0, 1, slice(2, None)],
    blob_names=["log_likelihood"],
    dims=dims,
    coords={"school": range(8)},
)
idata3
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499
        * school   (school) int64 0 1 2 3 4 5 6 7
      Data variables:
          mu       (chain, draw) float64 0.6982 0.7962 0.8433 ... 5.763 5.763 5.029
          tau      (chain, draw) float64 0.6679 0.7259 0.8075 ... 2.051 2.051 3.239
          eta      (chain, draw, school) float64 0.08153 -0.5837 ... 0.0377 -0.03057
      Attributes:
          created_at:                 2021-08-30T18:14:57.535401
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • log_likelihood
      <xarray.Dataset>
      Dimensions:         (chain: 40, draw: 1500, school: 8)
      Coordinates:
        * chain           (chain) int64 0 1 2 3 4 5 6 7 8 ... 32 33 34 35 36 37 38 39
        * draw            (draw) int64 0 1 2 3 4 5 6 ... 1494 1495 1496 1497 1498 1499
        * school          (school) int64 0 1 2 3 4 5 6 7
      Data variables:
          log_likelihood  (chain, draw, school) float64 -3.3 -0.5916 ... -0.1543
      Attributes:
          created_at:                 2021-08-30T18:14:57.531408
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • sample_stats
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1500)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499
      Data variables:
          lp       (chain, draw) float64 -16.3 -17.83 -18.92 ... -20.11 -20.11 -16.68
      Attributes:
          created_at:                 2021-08-30T18:14:57.532516
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • observed_data
      <xarray.Dataset>
      Dimensions:      (arg_0_dim_0: 8, arg_1_dim_0: 8)
      Coordinates:
        * arg_0_dim_0  (arg_0_dim_0) int64 0 1 2 3 4 5 6 7
        * arg_1_dim_0  (arg_1_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          arg_0        (arg_0_dim_0) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
          arg_1        (arg_1_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
      Attributes:
          created_at:                 2021-08-30T18:14:57.533391
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

Multi-group blobs

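You might even have more complicated blobs, each corresponding to a different group of the InferenceData object. Moreover, you can store the variables passed to the EnsembleSampler via the args argument in the observed_data or constant_data groups. Before that, note that the blobs of the previous run can be inspected directly on the sampler; each entry holds the pointwise log likelihood vector of one walker at one step. The example below then returns two blobs per draw and assigns each blob and each argument to its own group: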

sampler_blobs.blobs[0, 1]
array([-3.41541659e+00, -6.35964292e-01, -4.38888593e-02, -4.32716958e-01,
       -3.27623973e-03, -1.24760573e-02, -3.05547121e+00, -4.60634226e-01])
def lnprob_8school_blobs(theta, y, sigma):
    mu, tau, eta = theta[0], theta[1], theta[2:]
    prior = log_prior_8school(theta)
    like_vect = log_likelihood_8school(theta, y, sigma)
    like = np.sum(like_vect)
    # store pointwise log likelihood, useful for model comparison with az.loo or az.waic
    # and posterior predictive samples as blobs
    return like + prior, (like_vect, np.random.normal((mu + tau * eta), sigma))


sampler_blobs = emcee.EnsembleSampler(
    nwalkers,
    ndim,
    lnprob_8school_blobs,
    args=(y_obs, sigma),
)
sampler_blobs.run_mcmc(pos, draws)

dims = {"eta": ["school"], "log_likelihood": ["school"], "y": ["school"]}
idata4 = az.from_emcee(
    sampler_blobs,
    var_names=["mu", "tau", "eta"],
    slices=[0, 1, slice(2, None)],
    arg_names=["y", "sigma"],
    arg_groups=["observed_data", "constant_data"],
    blob_names=["log_likelihood", "y"],
    blob_groups=["log_likelihood", "posterior_predictive"],
    dims=dims,
    coords={"school": range(8)},
)
idata4
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499
        * school   (school) int64 0 1 2 3 4 5 6 7
      Data variables:
          mu       (chain, draw) float64 0.6982 0.7962 0.8433 ... 5.763 5.763 5.029
          tau      (chain, draw) float64 0.6679 0.7259 0.8075 ... 2.051 2.051 3.239
          eta      (chain, draw, school) float64 0.08153 -0.5837 ... 0.0377 -0.03057
      Attributes:
          created_at:                 2021-08-30T18:15:01.274909
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499
        * school   (school) int64 0 1 2 3 4 5 6 7
      Data variables:
          y        (chain, draw, school) float64 2.55 -6.472 -27.52 ... -0.198 23.99
      Attributes:
          created_at:                 2021-08-30T18:15:01.270461
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • log_likelihood
      <xarray.Dataset>
      Dimensions:         (chain: 40, draw: 1500, school: 8)
      Coordinates:
        * chain           (chain) int64 0 1 2 3 4 5 6 7 8 ... 32 33 34 35 36 37 38 39
        * draw            (draw) int64 0 1 2 3 4 5 6 ... 1494 1495 1496 1497 1498 1499
        * school          (school) int64 0 1 2 3 4 5 6 7
      Data variables:
          log_likelihood  (chain, draw, school) float64 -3.3 -0.5916 ... -0.1543
      Attributes:
          created_at:                 2021-08-30T18:15:01.271500
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • sample_stats
      <xarray.Dataset>
      Dimensions:  (chain: 40, draw: 1500)
      Coordinates:
        * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9 ... 30 31 32 33 34 35 36 37 38 39
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 1493 1494 1495 1496 1497 1498 1499
      Data variables:
          lp       (chain, draw) float64 -16.3 -17.83 -18.92 ... -20.11 -20.11 -16.68
      Attributes:
          created_at:                 2021-08-30T18:15:01.272272
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • observed_data
      <xarray.Dataset>
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) int64 0 1 2 3 4 5 6 7
      Data variables:
          y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:                 2021-08-30T18:15:01.273600
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

    • constant_data
      <xarray.Dataset>
      Dimensions:      (sigma_dim_0: 8)
      Coordinates:
        * sigma_dim_0  (sigma_dim_0) int64 0 1 2 3 4 5 6 7
      Data variables:
          sigma        (sigma_dim_0) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
      Attributes:
          created_at:                 2021-08-30T18:15:01.273208
          arviz_version:              0.11.2
          inference_library:          emcee
          inference_library_version:  3.1.1

This last version, which contains both observed data and posterior predictive samples, can be used to plot posterior predictive checks:

az.plot_ppc(idata4, var_names=["y"], alpha=0.3, num_pp_samples=200);
[Figure: posterior predictive check of y]
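The same two groups also support a numeric check alongside the plot. Because both share the school dimension, xarray aligns them automatically. The sketch below computes, per school, the fraction of predictive draws above the observed value, which is a posterior predictive p-value. Values close to 0 or 1 can hint at misfit. This is a complementary calculation and not something plot_ppc computes. The predictive draws here are made up; in practice they would come from idata4.posterior_predictive.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(7)
y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])

# Hypothetical posterior predictive draws: 4 chains x 250 draws x 8 schools
idata = az.from_dict(
    posterior_predictive={"y": rng.normal(loc=8.0, scale=15.0, size=(4, 250, 8))},
    observed_data={"y": y_obs},
    coords={"school": range(8)},
    dims={"y": ["school"]},
)

# Compare every predictive draw with the observed value of its school,
# then average over chains and draws
exceed = idata.posterior_predictive["y"] > idata.observed_data["y"]
p_values = exceed.mean(dim=("chain", "draw"))
print(p_values.round(2).values)
```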
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Mon Aug 30 2021

Python implementation: CPython
Python version       : 3.8.6
IPython version      : 7.27.0

emcee: 3.1.1
arviz: 0.11.2
numpy: None

Watermark: 2.1.0