ArviZ Quickstart#

import arviz as az
import numpy as np

ArviZ style sheets#

# ArviZ ships with style sheets!
az.style.use("arviz-darkgrid")

Feel free to check the examples of style sheets here.
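To see which style sheets ship with your installation, you can list them. A minimal sketch, assuming az.style mirrors matplotlib.style so the bundled ArviZ styles appear alongside the regular matplotlib ones:

# List the ArviZ-specific styles; the exact set depends on your ArviZ version.
print([style for style in az.style.available if style.startswith("arviz")])

# Switching styles works the same way as in matplotlib.
az.style.use("arviz-whitegrid")
az.style.use("arviz-darkgrid")  # switch back to the style used in this notebook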

Get started with plotting#

ArviZ is designed to be used with libraries like PyStan and PyMC3, but works fine with raw NumPy arrays.

When plotting a dictionary of arrays, ArviZ interprets each key as the name of a different random variable. Each row of an array is treated as an independent series of draws from that variable, called a chain. Below, we have 10 chains of 50 draws each, for four different distributions.

size = (10, 50)
az.plot_forest(
    {
        "normal": np.random.randn(*size),
        "gumbel": np.random.gumbel(size=size),
        "student t": np.random.standard_t(df=6, size=size),
        "exponential": np.random.exponential(size=size),
    }
);
[Figure: forest plot of the normal, Gumbel, Student-t, and exponential draws]
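The rcParams note in the next section also mentions plot_posterior(), which accepts the same kind of dictionary input. A minimal sketch, not part of the original notebook output:

# Each key is a variable name, each row of the array an independent chain of draws.
az.plot_posterior({"normal": np.random.randn(10, 50)});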

ArviZ rcParams#

You may have noticed that for both plot_posterior() and plot_forest() the Highest Density Interval (HDI) is 94%, which may seem odd at first. This particular value is a friendly reminder of the arbitrary nature of choosing any single value without further justification, including common values like 95%, 50% and even our own default, 94%. ArviZ stores default values for a number of parameters, which you can access and modify through az.rcParams. To change the default HDI probability to, say, 90%, you can do:

az.rcParams['stats.hdi_prob'] = 0.90
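If you only need a different default for a single figure, ArviZ also provides a context manager. A minimal sketch, assuming az.rc_context is available in your ArviZ version:

# Temporarily use a 95% HDI without touching the global default.
with az.rc_context(rc={"stats.hdi_prob": 0.95}):
    az.plot_forest({"normal": np.random.randn(10, 50)});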
Plotting with PyMC3 objects#

ArviZ is designed to work well with high-dimensional, labelled data from libraries such as PyMC3. The cells below fit the centered parametrization of the eight schools model:

import pymc3 as pm

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
schools = np.array(
    [
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "Phillips Exeter",
        "Hotchkiss",
        "Lawrenceville",
        "St. Paul's",
        "Mt. Hermon",
    ]
)
with pm.Model() as centered_eight:
    mu = pm.Normal("mu", mu=0, sd=5)
    tau = pm.HalfCauchy("tau", beta=5)
    theta = pm.Normal("theta", mu=mu, sd=tau, shape=J)
    obs = pm.Normal("obs", mu=theta, sd=sigma, observed=y)

    # This pattern is useful in PyMC3
    prior = pm.sample_prior_predictive()
    centered_eight_trace = pm.sample(return_inferencedata=False)
    posterior_predictive = pm.sample_posterior_predictive(centered_eight_trace)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [theta, tau, mu]
100.00% [8000/8000 00:03<00:00 Sampling 4 chains, 192 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 14 seconds.
There were 145 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.44992731488390336, but should be close to 0.8. Try to increase the number of tuning steps.
There were 27 divergences after tuning. Increase `target_accept` or reparameterize.
There were 9 divergences after tuning. Increase `target_accept` or reparameterize.
There were 11 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.7074472966124431, but should be close to 0.8. Try to increase the number of tuning steps.
The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.
The estimated number of effective samples is smaller than 200 for some parameters.
100.00% [4000/4000 00:01<00:00]

Most ArviZ functions work fine with trace objects from PyMC3, although it is recommended to convert them to InferenceData before plotting:

az.plot_autocorr(centered_eight_trace, var_names=["mu", "tau"]);
Got error No model on context stack. trying to find log_likelihood in translation.
/Users/yilinxia/opt/miniconda3/envs/arviz/lib/python3.8/site-packages/arviz/data/io_pymc3_3x.py:98: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  warnings.warn(
[Figure: autocorrelation plots for mu and tau]
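The warning above is raised because the plotting function received a raw trace outside of any model context. Following the advice in the warning message, one way to avoid it is to convert inside the model context; a minimal sketch (the full conversion with extra groups is shown in the next section):

# Converting within the model context lets ArviZ find the model
# and attach groups such as log_likelihood without guessing.
with centered_eight:
    idata = az.from_pymc3(centered_eight_trace)

az.plot_autocorr(idata, var_names=["mu", "tau"]);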

Convert to InferenceData#

For much more powerful querying, analysis and plotting, we can use built-in ArviZ utilities to convert PyMC3 objects to xarray datasets. Note that we also pass labelling information through the coords and dims arguments.

ArviZ is built to work with InferenceData, and the more groups it has access to, the more powerful the analyses it can perform. You can check the InferenceData structure specification here. Further below is a trace plot, a common step in PyMC3 workflows; note that the panels pick up the school labels automatically.

data = az.from_pymc3(
    trace=centered_eight_trace,
    prior=prior,
    posterior_predictive=posterior_predictive,
    model=centered_eight,
    coords={"school": schools},
    dims={"theta": ["school"], "obs": ["school"]},
)
data
arviz.InferenceData
    • posterior: <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 4.333 4.754 -0.6848 ... 1.239 3.973 -0.9172
          theta    (chain, draw, school) float64 6.962 4.182 4.186 ... -9.074 1.109
          tau      (chain, draw) float64 4.318 6.168 5.434 2.579 ... 2.299 6.237 5.914
      Attributes:
          created_at:                 2022-04-15T05:26:48.456950
          arviz_version:              0.12.0
          inference_library:          pymc3
          inference_library_version:  3.11.5
          sampling_time:              13.786029815673828
          tuning_steps:               1000

    • posterior_predictive: <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (chain, draw, school) float64 3.892 -3.475 -4.494 ... -25.62 -7.858
      Attributes:
          created_at:                 2022-04-15T05:26:48.654469
          arviz_version:              0.12.0
          inference_library:          pymc3
          inference_library_version:  3.11.5

    • log_likelihood: <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (chain, draw, school) float64 -4.611 -3.294 ... -6.887 -3.992
      Attributes:
          created_at:                 2022-04-15T05:26:48.652469
          arviz_version:              0.12.0
          inference_library:          pymc3
          inference_library_version:  3.11.5

    • sample_stats: <xarray.Dataset>
      Dimensions:             (chain: 4, draw: 1000)
      Coordinates:
        * chain               (chain) int64 0 1 2 3
        * draw                (draw) int64 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables: (12/13)
          diverging           (chain, draw) bool False False False ... False False
          tree_depth          (chain, draw) int64 4 4 4 4 2 4 4 4 ... 3 2 2 3 4 3 3 5
          acceptance_rate     (chain, draw) float64 0.8917 0.6986 ... 0.3739 0.9883
          energy              (chain, draw) float64 64.75 62.72 66.46 ... 62.89 67.11
          n_steps             (chain, draw) float64 15.0 15.0 15.0 ... 7.0 7.0 31.0
          step_size_bar       (chain, draw) float64 0.3722 0.3722 ... 0.3021 0.3021
          ...                  ...
          max_energy_error    (chain, draw) float64 0.224 51.41 ... 1.013 -0.1603
          lp                  (chain, draw) float64 -55.42 -59.67 ... -59.9 -61.02
          process_time_diff   (chain, draw) float64 0.001564 0.001535 ... 0.002764
          perf_counter_diff   (chain, draw) float64 0.001563 0.001534 ... 0.002763
          perf_counter_start  (chain, draw) float64 11.36 11.36 11.36 ... 13.63 13.64
          step_size           (chain, draw) float64 0.4712 0.4712 ... 0.4006 0.4006
      Attributes:
          created_at:                 2022-04-15T05:26:48.462552
          arviz_version:              0.12.0
          inference_library:          pymc3
          inference_library_version:  3.11.5
          sampling_time:              13.786029815673828
          tuning_steps:               1000

    • prior: <xarray.Dataset>
      Dimensions:    (chain: 1, draw: 500, school: 8)
      Coordinates:
        * chain      (chain) int64 0
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * school     (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          tau_log__  (chain, draw) float64 0.7613 2.303 1.944 ... 0.7547 1.664 3.069
          theta      (chain, draw, school) float64 -5.797 -7.748 ... 26.96 -38.51
          tau        (chain, draw) float64 2.141 10.0 6.988 ... 2.127 5.279 21.52
          mu         (chain, draw) float64 -5.671 -2.992 -5.939 ... 11.19 1.339 1.965
      Attributes:
          created_at:                 2022-04-15T05:26:48.656780
          arviz_version:              0.12.0
          inference_library:          pymc3
          inference_library_version:  3.11.5

    • prior_predictive: <xarray.Dataset>
      Dimensions:  (chain: 1, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (chain, draw, school) float64 -12.02 -21.44 -9.397 ... 39.53 -42.67
      Attributes:
          created_at:                 2022-04-15T05:26:48.658653
          arviz_version:              0.12.0
          inference_library:          pymc3
          inference_library_version:  3.11.5

    • observed_data: <xarray.Dataset>
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:                 2022-04-15T05:26:48.659375
          arviz_version:              0.12.0
          inference_library:          pymc3
          inference_library_version:  3.11.5

az.plot_trace(data, compact=False);
[Figure: trace plot of the centered eight schools posterior, one panel per variable with per-school labels]
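With everything collected in a single InferenceData object, the full xarray machinery is available for querying and summarizing. A minimal sketch, using the group and variable names shown in the repr above:

# Posterior draws for a single school, as an xarray DataArray.
choate_theta = data.posterior["theta"].sel(school="Choate")
print(choate_theta.mean().item(), choate_theta.std().item())

# Tabular summary with the usual diagnostics (ESS, r_hat, ...).
az.summary(data, var_names=["mu", "tau"])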

Plotting with PyStan objects#

ArviZ is built with first-class support for PyStan objects and can plot raw fit objects in a reasonable manner. Here is the same centered eight schools model:

import nest_asyncio
nest_asyncio.apply()
import stan  # pystan version 3.4.0


schools_code = """
data {
  int<lower=0> J;
  array[J] real y;
  array[J] real<lower=0> sigma;
}

parameters {
  real mu;
  real<lower=0> tau;
  array[J] real theta;
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
}
generated quantities {
  vector[J] log_lik;
  vector[J] y_hat;
  for (j in 1:J) {
    log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
    y_hat[j] = normal_rng(theta[j], sigma[j]);
  }
}
"""

schools_dat = {
    "J": 8,
    "y": [28, 8, -3, 7, -1, 1, 18, 12],
    "sigma": [15, 10, 16, 11, 9, 11, 10, 18],
}

posterior = stan.build(schools_code, data=schools_dat, random_seed=1)
fit = posterior.sample(num_chains=4, num_samples=1000)
Building: found in cache, done.
Sampling:   0%
Sampling:  25% (2000/8000)
Sampling:  50% (4000/8000)
Sampling:  75% (6000/8000)
Sampling: 100% (8000/8000)
Sampling: 100% (8000/8000), done.
Messages received during sampling:
  Gradient evaluation took 2.7e-05 seconds
  1000 transitions using 10 leapfrog steps per transition would take 0.27 seconds.
  Adjust your expectations accordingly!
  Gradient evaluation took 3.3e-05 seconds
  1000 transitions using 10 leapfrog steps per transition would take 0.33 seconds.
  Adjust your expectations accordingly!
  Gradient evaluation took 3.8e-05 seconds
  1000 transitions using 10 leapfrog steps per transition would take 0.38 seconds.
  Adjust your expectations accordingly!
  Gradient evaluation took 3.8e-05 seconds
  1000 transitions using 10 leapfrog steps per transition would take 0.38 seconds.
  Adjust your expectations accordingly!
az.plot_density(fit, var_names=["mu", "tau"]);
[Figure: posterior density plots for mu and tau]
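Other ArviZ plotting functions accept the raw fit object in the same way, converting it on the fly. A minimal sketch, output not shown:

# Works directly on the PyStan fit, just like the density plot above.
az.plot_trace(fit, var_names=["mu", "tau"]);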

Again, converting to InferenceData (a netCDF datastore that loads data into xarray datasets), we can get much richer labelling and mixing of data. Here is a plot showing where the Hamiltonian sampler had divergences:

data = az.from_pystan(
    posterior=fit,
    posterior_predictive="y_hat",
    observed_data=["y"],
    log_likelihood={"y": "log_lik"},
    coords={"school": schools},
    dims={
        "theta": ["school"],
        "y": ["school"],
        "log_lik": ["school"],
        "y_hat": ["school"],
        "theta_tilde": ["school"],
    },
)
data
arviz.InferenceData
    • posterior: <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 0.599 0.2557 0.7311 ... 2.197 2.982 1.079
          tau      (chain, draw) float64 2.982 1.938 1.214 1.148 ... 3.032 4.728 2.477
          theta    (chain, draw, school) float64 -1.166 -0.1457 ... 0.7122 1.992
      Attributes:
          created_at:                 2022-04-15T05:26:56.712787
          arviz_version:              0.12.0
          inference_library:          stan
          inference_library_version:  3.4.0
          num_chains:                 4
          num_samples:                1000
          num_thin:                   1
          num_warmup:                 1000
          save_warmup:                0

    • posterior_predictive: <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y_hat    (chain, draw, school) float64 -17.26 12.4 7.823 ... -14.9 -27.98
      Attributes:
          created_at:                 2022-04-15T05:26:56.789583
          arviz_version:              0.12.0
          inference_library:          stan
          inference_library_version:  3.4.0

    • log_likelihood: <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 -5.517 -3.553 -3.71 ... -4.716 -3.964
      Attributes:
          created_at:                 2022-04-15T05:26:56.767405
          arviz_version:              0.12.0
          inference_library:          stan
          inference_library_version:  3.4.0

    • sample_stats: <xarray.Dataset>
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 0 1 2 3
        * draw             (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (chain, draw) float64 0.9271 0.7419 1.0 ... 1.0 0.9957
          step_size        (chain, draw) float64 0.1575 0.1575 ... 0.1294 0.1294
          tree_depth       (chain, draw) int64 3 3 3 3 3 3 3 2 3 ... 5 5 6 5 5 5 5 4 4
          n_steps          (chain, draw) int64 7 15 7 7 7 7 7 ... 63 31 31 31 31 15 15
          diverging        (chain, draw) bool False True False ... False False False
          energy           (chain, draw) float64 15.08 14.79 11.82 ... 22.14 20.47
          lp               (chain, draw) float64 -13.65 -10.93 -7.736 ... -17.2 -17.62
      Attributes:
          created_at:                 2022-04-15T05:26:56.743010
          arviz_version:              0.12.0
          inference_library:          stan
          inference_library_version:  3.4.0
          num_chains:                 4
          num_samples:                1000
          num_thin:                   1
          num_warmup:                 1000
          save_warmup:                0

az.plot_pair(
    data,
    coords={"school": ["Choate", "Deerfield", "Phillips Andover"]},
    divergences=True,
);
[Figure: pair plot for Choate, Deerfield, and Phillips Andover with divergences highlighted]
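Because InferenceData is backed by netCDF, the converted object can also be written to disk and reloaded later without rerunning the sampler. A minimal sketch, with an arbitrary file name:

# Persist the converted results and read them back in a later session.
data.to_netcdf("centered_eight_pystan.nc")
data_reloaded = az.from_netcdf("centered_eight_pystan.nc")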

See also

working_with_InferenceData