Refitting PyStan (2.x) models with ArviZ

ArviZ is backend agnostic and therefore does not sample directly. In order to take advantage of algorithms that require refitting models several times, ArviZ uses SamplingWrapper to convert the API of the sampling backend to a common set of functions. Hence, functions like Leave Future Out Cross Validation can be used in ArviZ independently of the sampling backend used.
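To make the idea concrete, the sketch below shows the call sequence a refitting function could follow. The method names mirror those of ArviZ's SamplingWrapper (sel_observations, sample, get_inference_data, log_likelihood__i), but ToyWrapper, its "backend" (a plain sample mean) and the refit_pointwise helper are hypothetical stand-ins written for illustration, not ArviZ code.

```python
import math


class ToyWrapper:
    """Hypothetical stand-in for a sampling wrapper; not an ArviZ class."""

    def __init__(self, data):
        self.data = list(data)

    def sel_observations(self, idx):
        # Split the data into what is used for fitting and what is held out.
        kept = [y for i, y in enumerate(self.data) if i != idx]
        return kept, self.data[idx]

    def sample(self, modified_observed_data):
        # Toy "backend": the fit is just the mean of the kept observations.
        return sum(modified_observed_data) / len(modified_observed_data)

    def get_inference_data(self, fit):
        # A real wrapper would convert the backend fit to InferenceData.
        return {"mu": fit}

    def log_likelihood__i(self, excluded_obs, idata__i):
        # Log density of the held-out point under Normal(mu, 1).
        resid = excluded_obs - idata__i["mu"]
        return -0.5 * math.log(2 * math.pi) - 0.5 * resid**2


def refit_pointwise(wrapper, indices):
    """Refit once per index and collect the held-out log likelihoods."""
    results = {}
    for idx in indices:
        kept, excluded = wrapper.sel_observations(idx)
        fit = wrapper.sample(kept)
        idata = wrapper.get_inference_data(fit)
        results[idx] = wrapper.log_likelihood__i(excluded, idata)
    return results


pointwise = refit_pointwise(ToyWrapper([1.0, 2.0, 3.0, 10.0]), indices=[0, 3])
```

Because the loop only talks to the wrapper methods, the same refitting function could work with any backend that provides them.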

The example below shows SamplingWrapper usage for PyStan. It extends arviz.PyStan2SamplingWrapper, which already implements some default methods targeted to PyStan.

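Before starting, it is important to note that PyStan cannot call the C++ functions that the compiled model uses, so the pointwise log likelihood of excluded data has to be computed inside the Stan program itself. Therefore, the code of the model must be slightly modified in order to be compatible with the cross validation refitting functions.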

import arviz as az
import pystan
import numpy as np
import matplotlib.pyplot as plt

For the example, we will use a linear regression model.

np.random.seed(26)

xdata = np.linspace(0, 50, 100)
b0, b1, sigma = -2, 1, 3
ydata = np.random.normal(loc=b1 * xdata + b0, scale=sigma)
plt.plot(xdata, ydata)
[<matplotlib.lines.Line2D at 0x1258dee10>]
[Figure: line plot of ydata against xdata]

Now we will write the Stan code, keeping in mind that it must be able to compute the pointwise log likelihood on excluded data, i.e., data that is not used to fit the model. Thus, the backbone of the code must look like the following:

data {
    data_for_fitting
    excluded_data
    ...
}
model {
    // fit against data_for_fitting
    ...
}
generated quantities {
    ...
    log_lik for data_for_fitting
    log_lik_excluded for excluded_data
}
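The two log likelihood vectors in this skeleton evaluate the same likelihood on different subsets of the data. As a rough Python illustration (not Stan, and using a single hypothetical parameter draw rather than a posterior sample), the computation for a normal linear regression could look like the following. The normal_lpdf helper is hand-written here to mimic the Stan function of the same name, and all numeric values are made up.

```python
import math


def normal_lpdf(y, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at y."""
    z = (y - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def pointwise_log_lik(x, y, b0, b1, sigma):
    """One log likelihood value per observation."""
    return [normal_lpdf(yi, b0 + b1 * xi, sigma) for xi, yi in zip(x, y)]


# Hypothetical values standing in for one posterior draw.
b0, b1, sigma = -2.0, 1.0, 3.0

# Data used for fitting and data held out from the fit (made-up numbers).
x_fit, y_fit = [0.0, 1.0, 2.0], [-2.5, -0.4, 0.3]
x_ex, y_ex = [3.0], [1.0]

log_lik = pointwise_log_lik(x_fit, y_fit, b0, b1, sigma)
log_lik_ex = pointwise_log_lik(x_ex, y_ex, b0, b1, sigma)
```

When nothing is excluded, the excluded inputs are empty and so is the resulting vector.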
refit_lr_code = """
data {
  // Define data for fitting
  int<lower=0> N;
  vector[N] x;
  vector[N] y;
  // Define excluded data. It will not be used when fitting.
  int<lower=0> N_ex;
  vector[N_ex] x_ex;
  vector[N_ex] y_ex;
}

parameters {
  real b0;
  real b1;
  real<lower=0> sigma_e;
}

model {
  b0 ~ normal(0, 10);
  b1 ~ normal(0, 10);
  sigma_e ~ normal(0, 10);
  for (i in 1:N) {
    y[i] ~ normal(b0 + b1 * x[i], sigma_e);  // use only data for fitting
  }
  
}

generated quantities {
    vector[N] log_lik;
    vector[N_ex] log_lik_ex;
    vector[N] y_hat;
    
    for (i in 1:N) {
        // calculate log likelihood and posterior predictive, there are 
        // no restrictions on adding more generated quantities
        log_lik[i] = normal_lpdf(y[i] | b0 + b1 * x[i], sigma_e);
        y_hat[i] = normal_rng(b0 + b1 * x[i], sigma_e);
    }
    for (j in 1:N_ex) {
        // calculate the log likelihood of the excluded data given data_for_fitting
        log_lik_ex[j] = normal_lpdf(y_ex[j] | b0 + b1 * x_ex[j], sigma_e);
    }
}
"""
sm = pystan.StanModel(model_code=refit_lr_code)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_4275bea8cf61cb4b45f01fa01c73d194 NOW.
data_dict = {
    "N": len(ydata),
    "y": ydata,
    "x": xdata,
    # No excluded data in initial fit
    "N_ex": 0,
    "x_ex": [],
    "y_ex": [],
}
sample_kwargs = {"iter": 1000, "chains": 4}
fit = sm.sampling(data=data_dict, **sample_kwargs)

We have defined a dictionary sample_kwargs that will be passed to the SamplingWrapper in order to make sure that all refits use the same sampler parameters. We will follow the same pattern with az.from_pystan, collecting its keyword arguments in a dictionary called idata_kwargs.

dims = {"y": ["time"], "x": ["time"], "log_likelihood": ["time"], "y_hat": ["time"]}
idata_kwargs = {
    "posterior_predictive": ["y_hat"],
    "observed_data": "y",
    "constant_data": "x",
    "log_likelihood": ["log_lik", "log_lik_ex"],
    "dims": dims,
}
idata = az.from_pystan(posterior=fit, **idata_kwargs)

We will create a subclass of PyStan2SamplingWrapper. That way, instead of having to implement all the functions required by reloo(), we only have to implement sel_observations(). As explained in its docs, it takes one argument, the indices of the data to be excluded, and returns two values: modified_observed_data, which is passed as data to the sampling function of the PyStan model, and excluded_observed_data, which is used to retrieve the log likelihood of the excluded data. Here the second value is the name of the Stan variable that holds that log likelihood, because the excluded data itself is already passed to Stan as part of modified_observed_data.

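class LinearRegressionWrapper(az.PyStan2SamplingWrapper):
    def sel_observations(self, idx):
        # Recover the full dataset from the original InferenceData
        xdata = self.idata_orig.constant_data.x.values
        ydata = self.idata_orig.observed_data.y.values
        # Boolean mask: True for data used in the fit, False for excluded data
        mask = np.full_like(xdata, True, dtype=bool)
        mask[idx] = False
        N_obs = len(mask)
        N_ex = np.sum(~mask)
        # Data dictionary matching the data block of the Stan program
        observations = {
            "N": N_obs - N_ex,
            "x": xdata[mask],
            "y": ydata[mask],
            "N_ex": N_ex,
            "x_ex": xdata[~mask],
            "y_ex": ydata[~mask],
        }
        # Also return the name of the variable with the excluded log likelihood
        return observations, "log_lik_ex"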
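To see what this method hands to Stan, the same masking logic can be run on a small toy array outside the wrapper. The split_observations helper below is a hypothetical standalone copy of the method body, written only for illustration, and the data values are made up.

```python
import numpy as np


def split_observations(xdata, ydata, idx):
    """Standalone version of the masking logic in sel_observations."""
    mask = np.full_like(xdata, True, dtype=bool)
    mask[idx] = False  # False marks the excluded observations
    n_ex = int(np.sum(~mask))
    return {
        "N": len(mask) - n_ex,
        "x": xdata[mask],
        "y": ydata[mask],
        "N_ex": n_ex,
        "x_ex": xdata[~mask],
        "y_ex": ydata[~mask],
    }


x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([-2.1, -0.8, 0.3, 1.2, 1.9])

# Exclude the observation at position 2.
obs = split_observations(x, y, 2)
```

The resulting dictionary has the same keys as the data block of the Stan program, with the excluded point moved from x and y to x_ex and y_ex.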
loo_orig = az.loo(idata, pointwise=True)
loo_orig
Computed from 2000 by 100 log-likelihood matrix

         Estimate       SE
elpd_loo  -250.66     7.17
p_loo        2.85        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.5]   (good)      100  100.0%
 (0.5, 0.7]   (ok)          0    0.0%
   (0.7, 1]   (bad)         0    0.0%
   (1, Inf)   (very bad)    0    0.0%


The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if
you rely on a specific value.
A higher log-score (or a lower deviance) indicates a model with better predictive
accuracy.

In this case, the Leave-One-Out Cross Validation (LOO-CV) approximation using Pareto Smoothed Importance Sampling (PSIS) works for all observations, so we will modify loo_orig in order to make reloo() believe that PSIS failed for some observations. This will also serve as a validation of our wrapper, as the PSIS LOO-CV already returned the correct value.

loo_orig.pareto_k[[13, 42, 56, 73]] = np.array([0.8, 1.2, 2.6, 0.9])
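reloo() refits the model only for observations whose Pareto k exceeds a threshold (its k_thresh argument, which to our knowledge defaults to 0.7). A small sketch of that selection, using a plain list as a stand-in for loo_orig.pareto_k and an arbitrary "good" value of 0.1 for the untouched observations:

```python
# Stand-in for loo_orig.pareto_k: 100 well-behaved observations...
pareto_k = [0.1] * 100
# ...with four of them overwritten to simulate PSIS failures.
for i, k in zip([13, 42, 56, 73], [0.8, 1.2, 2.6, 0.9]):
    pareto_k[i] = k

k_thresh = 0.7  # assumed threshold; check the reloo documentation
to_refit = [i for i, k in enumerate(pareto_k) if k > k_thresh]
print(to_refit)  # prints [13, 42, 56, 73]
```

Only these four observations should therefore trigger a refit.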

We initialize our sampling wrapper:

pystan_wrapper = LinearRegressionWrapper(
    sm, idata_orig=idata, sample_kwargs=sample_kwargs, idata_kwargs=idata_kwargs
)

Finally, we can use this wrapper to call az.reloo and compare the results with the PSIS LOO-CV results.

loo_relooed = az.reloo(pystan_wrapper, loo_orig=loo_orig)
/Users/percy/anaconda3/envs/arviz/lib/python3.6/site-packages/arviz/stats/stats_refitting.py:98: UserWarning: reloo is an experimental and untested feature
  warnings.warn("reloo is an experimental and untested feature", UserWarning)
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 13
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 13
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 42
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 42
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 56
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 56
arviz.stats.stats_refitting - INFO - Refitting model excluding observation 73
INFO:arviz.stats.stats_refitting:Refitting model excluding observation 73
loo_relooed
Computed from 2000 by 100 log-likelihood matrix

         Estimate       SE
elpd_loo  -250.67     7.17
p_loo        2.86        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.5]   (good)      100  100.0%
 (0.5, 0.7]   (ok)          0    0.0%
   (0.7, 1]   (bad)         0    0.0%
   (1, Inf)   (very bad)    0    0.0%


The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if
you rely on a specific value.
A higher log-score (or a lower deviance) indicates a model with better predictive
accuracy.
loo_orig
Computed from 2000 by 100 log-likelihood matrix

         Estimate       SE
elpd_loo  -250.66     7.17
p_loo        2.85        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.5]   (good)       96   96.0%
 (0.5, 0.7]   (ok)          0    0.0%
   (0.7, 1]   (bad)         2    2.0%
   (1, Inf)   (very bad)    2    2.0%


The scale is now log by default. Use 'scale' argument or 'stats.ic_scale' rcParam if
you rely on a specific value.
A higher log-score (or a lower deviance) indicates a model with better predictive
accuracy.
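Conceptually, for each refit observation the PSIS approximation is replaced by a direct estimate: the log of the held-out point's likelihood averaged over the draws of the refit posterior. A minimal sketch of that per-observation computation follows. It assumes log_lik_ex_draws holds hypothetical log_lik_ex values (one per posterior draw), uses a hand-written log_mean_exp helper, and all numbers are made up rather than taken from the fits above.

```python
import math


def log_mean_exp(values):
    """Numerically stable log(mean(exp(values)))."""
    peak = max(values)
    total = sum(math.exp(v - peak) for v in values)
    return peak + math.log(total / len(values))


# Hypothetical log_lik_ex values for one held-out point, one per draw.
log_lik_ex_draws = [-2.3, -2.1, -2.6, -2.2]
elpd_i = log_mean_exp(log_lik_ex_draws)

# Pointwise PSIS estimates (made-up numbers) with one entry replaced
# by the value obtained from the refit.
elpd_pointwise = [-2.4, -2.5, -2.2]
elpd_pointwise[1] = elpd_i
elpd_loo = sum(elpd_pointwise)
```

Since only the flagged entries change, the total elpd_loo moves very little when PSIS was already accurate for them, which is consistent with the nearly identical estimates reported above.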