brycewang-stanford

bayesian-workflow

Opinionated Bayesian modeling workflow with PyMC and ArviZ. Contains critical guardrails (nutpie sampler, prior/posterior predictive checks, LOO-PIT calibration, prior sensitivity checks, 94% HDI, non-centered parameterizations, reproducible seeds) that agents won't apply unprompted — always consult before writing Bayesian model code. Trigger on: building probabilistic/Bayesian models, prior elicitation, MCMC inference, convergence diagnostics (divergences, R-hat, ESS), model comparison (LOO-CV, ELPD, stacking weights), hierarchical/multilevel models, count regressions, logistic regression with uncertainty, prior sensitivity analysis, reporting Bayesian results, or mentions of PyMC, ArviZ, InferenceData, credible intervals, posterior distributions, shrinkage, uncertainty quantification. Also trigger for model comparison, diagnosing sampling problems, choosing priors, or presenting stats to non-technical audiences.



Install

npx skillscat add brycewang-stanford/auto-empirical-research-skills/bayesian-workflow

Install via the SkillsCat registry.


Bayesian Workflow

Workflow overview

Every Bayesian analysis follows this sequence. Do not skip steps, especially model criticism.

  1. Formulate: define the generative story. What underlying process generated the data we are trying to model?
  2. Specify priors: see references/priors.md.
  3. Implement in PyMC: write the model using PyMC 5+ syntax, on the latest release available.
  4. Run prior predictive checks: pm.sample_prior_predictive(). Verify that the priors produce plausible data ranges before fitting.
  5. Inference: pm.sample(nuts_sampler="nutpie"). Always use nutpie for speed (the nutpie package is a fast Rust implementation of NUTS). Don't hardcode the number of chains; let the sampler pick the best default for the platform.
  6. Diagnose convergence: use arviz_stats.diagnose(idata) as the first check (requires arviz-stats >= 1.0.0). It covers R-hat, ESS, divergences, tree depth, and E-BFMI in one call. See references/diagnostics.md.
  7. Criticize the model: see references/model-criticism.md.
  8. Check prior sensitivity: run psense_summary(idata) to verify that conclusions are robust to prior choices, and visualize with plot_psense_dist(idata) from arviz_plots. Both require the log_likelihood and log_prior groups in the InferenceData; compute them after sampling if needed. See references/sensitivity.md.
  9. Compare models (if applicable): see references/model-comparison.md.
  10. Report results: see references/reporting.md. When the user asks for a report or mentions a non-technical audience, generate a standalone markdown report file (not just code comments) using the template in reporting.md. Adapt the language to the audience; if they're new to Bayesian stats, include a glossary and plain-language explanations of key concepts.
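Here is a concrete picture of step 4. The numpy-only sketch below does by hand what pm.sample_prior_predictive() automates for the toy model in the template further down (mu ~ Normal(0, 10), y ~ Normal(mu, 1)). It draws parameters from the prior, simulates one data set per draw, and checks whether the simulated outcomes are plausible. The 100 observations and the plausible range of ±50 are illustrative assumptions, not recommendations.

```python
import numpy as np

# Descriptive seed derived from the analysis name, as this skill recommends
RANDOM_SEED = sum(map(ord, "prior-predictive-sketch"))
rng = np.random.default_rng(RANDOM_SEED)

n_draws, n_obs = 1000, 100

# 1. Draw parameters from the prior: mu ~ Normal(0, 10)
mu = rng.normal(loc=0.0, scale=10.0, size=n_draws)

# 2. For each prior draw, simulate a full data set: y ~ Normal(mu, 1)
y_sim = rng.normal(loc=mu[:, None], scale=1.0, size=(n_draws, n_obs))

# 3. Summarize the simulated outcomes (central 94%) and compare them with what
#    is plausible for the problem (this plausible range is a made-up example)
low, high = np.quantile(y_sim, [0.03, 0.97])
plausible_low, plausible_high = -50.0, 50.0
print(f"94% of simulated outcomes fall in [{low:.1f}, {high:.1f}]")
if low < plausible_low or high > plausible_high:
    print("Priors allow implausible data; revisit them before fitting.")
```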

Installation

Prefer installing PyMC and its dependencies from conda-forge (for example with mamba); pip can cause issues with compiled backends (nutpie, JAX). Example:

```bash
mamba install -c conda-forge pymc nutpie arviz arviz-stats preliz
```

PyMC model template

```python
import pymc as pm
import arviz as az
import numpy as np
import pandas as pd

RANDOM_SEED = sum(map(ord, "churn-logistic-v1"))
rng = np.random.default_rng(RANDOM_SEED)

# Placeholder data so the template runs end to end; replace with your own DataFrame
df = pd.DataFrame({"y": rng.normal(loc=1.0, scale=1.0, size=100)})

# always use dimensions and coordinates in PyMC models
# (dim name differs from the observed variable's name to avoid a name clash)
coords = {"obs_id": df.index}
with pm.Model(coords=coords) as model:
    # use Data containers when working on a PyMC model
    data = pm.Data("data", df["y"].to_numpy(), dims="obs_id")

    # --- Priors ---
    # Always document WHY each prior was chosen
    mu = pm.Normal("mu", mu=0, sigma=10)  # Weakly informative: allows wide range

    # --- Data model ---
    pm.Normal("obs", mu=mu, sigma=1, observed=data, dims="obs_id")

    # --- Prior predictive check ---
    prior_pred = pm.sample_prior_predictive(random_seed=rng)

    # --- Inference ---
    idata = pm.sample(nuts_sampler="nutpie", random_seed=rng)
    idata.extend(prior_pred)

    # --- Posterior predictive check ---
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=rng))

    # --- Compute log-likelihood and log-prior for sensitivity checks & LOO ---
    pm.compute_log_likelihood(idata, model=model)
    pm.compute_log_prior(idata, model=model)

    # --- Save immediately after sampling ---
    # Late crashes can destroy valid results. Save to disk before any post-processing.
    idata.to_netcdf("model_output.nc")
```

Critical rules

  • Always run prior predictive checks before sampling. If prior predictions span implausible ranges, fix the priors first. If you are unsure about the prior for a parameter, use the PreliZ package to elicit it together with the user.
  • Always check convergence before interpreting results. If any R-hat exceeds 1.01 or the effective sample size falls below 100 per chain (ESS < 100 × number of chains), the results are unreliable.
  • Always run posterior predictive checks. A model that fits well numerically but cannot reproduce the data is useless.
  • Always run calibration checks (PIT / coverage). Use ArviZ's plot_ppc_pit for this — it handles all data types (continuous, binary, count) correctly. See references/model-criticism.md.
  • Document every prior choice with a brief justification in a code comment.
  • Never report point estimates alone. Always include credible intervals (default: 94% HDI).
  • Use arviz_stats.diagnose(idata) as the first diagnostic on every model (arviz-stats >= 1.0.0). It checks R-hat, ESS, divergences, tree depth saturation, and E-BFMI in one call. Follow up with az.plot_trace(idata, kind="rank_vlines") for visual inspection.
  • Don't hardcode number of chains. Let PyMC / nutpie choose the optimal default for the user's platform. Just call pm.sample() without specifying chains.
  • Use reproducible, descriptive seeds. Never use magic numbers like 42. Instead, derive a seed from the analysis name with RANDOM_SEED = sum(map(ord, "my-analysis-name")), create a generator with rng = np.random.default_rng(RANDOM_SEED), and pass rng to pm.sample(random_seed=rng) and pm.sample_prior_predictive(random_seed=rng).
  • Save InferenceData immediately after sampling with idata.to_netcdf("model_output.nc"). Late crashes or kernel restarts can destroy valid MCMC results — save before any post-processing.
  • Use ArviZ for all plots and calibration. Don't write custom plotting code when ArviZ already handles it — including for binary data, count data, and calibration. ArviZ developers have thought through edge cases so you don't have to.
  • Prefer xarray over numpy for InferenceData operations. InferenceData and DataTree objects are backed by xarray — use xarray's labeled indexing (.sel(), .mean(dim=...), etc.) instead of converting to numpy arrays. This preserves dimension labels, avoids shape bugs, and makes code more readable. Fall back to numpy only when xarray can't do what you need.
  • Always generate analysis notes alongside code. When producing a model script, also produce a companion markdown file (analysis_notes.md or similar) that interprets the results — what the diagnostics mean, what the posteriors tell us, what the calibration plots show. Code without interpretation is incomplete.
  • Always use the posterior mean (not the median) for predictive probabilities. The Bayesian predictive distribution averages over the posterior. With S posterior draws it is estimated as P(Y=k | x) ≈ (1/S) Σₛ P(Y=k | x, θₛ), which is a mean, not a median. The per-category median does not correspond to the posterior predictive distribution. It can also violate probability coherence (probabilities may not sum to 1) and biases calibration (Jensen's inequality). In code: use np.mean(probs, axis=sample_axis), never np.median(...).
  • Use pm.set_data() + pm.sample_posterior_predictive() for out-of-sample predictions. Don't manually extract posterior samples and recompute predictions; let PyMC propagate uncertainty properly. Define predictors as pm.Data(...) during model building, then swap in new data:
```python
# After fitting the model:
with model:
    # If X_new has a different number of rows, the likelihood must resize with it
    # (e.g. pass the matching new coords via pm.set_data(..., coords=...))
    pm.set_data({"X": X_new, "group_idx": group_idx_new})
    oos_preds = pm.sample_posterior_predictive(idata, predictions=True, random_seed=rng)
```
  • Check model identifiability before interpreting components. If two model components always appear together in the likelihood (e.g., a league intercept and a home advantage term when every observation is from the home team's perspective), their individual posteriors reflect prior assumptions, not data signal; only their sum is identified. Use az.plot_pair() to check for strong posterior correlations between components. If the correlation is near ±1, the components are not separately identifiable; either merge them or restructure the data.
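The R-hat > 1.01 rule compares between-chain variance with within-chain variance. The sketch below implements the classic split-R-hat from Bayesian Data Analysis (3rd ed.) in numpy on simulated chains, purely to show what the number measures. In practice, use az.rhat, which computes the more robust rank-normalized version. The chain count, draw count, and the offset of 3.0 are arbitrary choices for illustration.

```python
import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Classic split-R-hat for an array of shape (chains, draws).

    Illustrative sketch only; ArviZ uses the rank-normalized variant.
    """
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split each chain in half so within-chain drift shows up as disagreement
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    chain_means = split.mean(axis=1)
    within = split.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)           # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n   # pooled variance estimate
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(sum(map(ord, "split-rhat-sketch")))

# Four well-mixed chains sampling the same distribution
mixed = rng.normal(size=(4, 1000))
# Same draws, but two chains are stuck around a different value
stuck = mixed.copy()
stuck[:2] += 3.0

for label, draws in [("mixed", mixed), ("stuck", stuck)]:
    rhat = split_rhat(draws)
    verdict = "OK" if rhat <= 1.01 else "unreliable"
    print(f"{label}: R-hat = {rhat:.3f} ({verdict})")
```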
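The mean-versus-median rule for predictive probabilities is easy to check numerically. In this sketch a Dirichlet sample stands in for posterior draws of category probabilities; that is an assumption for illustration, and in a real model the draws would come from the fitted posterior. Averaging over draws gives a proper distribution that sums to 1, while per-category medians generally do not. With InferenceData, the same average is .mean(dim=("chain", "draw")) on the xarray object.

```python
import numpy as np

rng = np.random.default_rng(sum(map(ord, "predictive-mean-vs-median")))

# Stand-in for posterior draws of P(Y=k | x, theta_s): 4000 draws over 3 categories
probs = rng.dirichlet(alpha=[1.0, 1.0, 1.0], size=4000)  # shape (draws, categories)

# Posterior predictive probabilities: Monte Carlo average over the draws
p_mean = probs.mean(axis=0)
# Per-category median: NOT the posterior predictive distribution
p_median = np.median(probs, axis=0)

print("mean:  ", p_mean.round(3), "sum =", p_mean.sum().round(3))
print("median:", p_median.round(3), "sum =", p_median.sum().round(3))
```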
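Here is a quick numerical companion to az.plot_pair for the identifiability rule. It computes the posterior correlation between two components and the spread of their sum. The posterior here is simulated, with hypothetical component names league and home, so that only the sum is informed by data. With real InferenceData you would take the same two variables from idata.posterior.

```python
import numpy as np

rng = np.random.default_rng(sum(map(ord, "identifiability-sketch")))
n_draws = 4000

# Hypothetical posterior where only the sum of two terms is informed by data:
# the sum is pinned near 0.5, while the split between the terms follows the prior.
league = rng.normal(loc=0.0, scale=1.0, size=n_draws)   # prior-dominated
total = rng.normal(loc=0.5, scale=0.05, size=n_draws)   # well identified
home = total - league

corr = np.corrcoef(league, home)[0, 1]
print(f"corr(league, home) = {corr:.3f}")
print(f"sd(league) = {league.std():.2f}, sd(league + home) = {(league + home).std():.2f}")
if abs(corr) > 0.95:
    print("Components are not separately identified; interpret only their sum.")
```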

Common model families

| Problem | Data model | Typical priors | Reference |
|---|---|---|---|
| Continuous outcome | Normal / StudentT | Normal; Gamma (avoiding 0) for positive-constrained parameters | references/priors.md |
| Binary outcome | Bernoulli (or Binomial if aggregated) with logit inverse link | Normal(0, 1.5) on coefficients | references/priors.md |
| Count data | Poisson / NegBinomial | Gamma on rate, avoiding 0 | references/priors.md |
| Count data with excess zeros | ZeroInflatedPoisson / ZeroInflatedNegBinomial | Gamma on rate; Beta, or Normal + logit, on the zero-inflation probability | references/priors.md |
| Count data whose zeros come from a separate process | Hurdle Poisson / Hurdle NegBinomial | Separate zero gate (Bernoulli) and truncated count component | references/priors.md |
| Ordinal outcome | OrderedLogistic (cumulative link) | Normal on coefficients; Normal with ordered transform on cutpoints | references/priors.md |
| Censored data (survival, limits of detection) | pm.Censored(name, dist, lower, upper) | Same as uncensored, applied to the underlying distribution | references/priors.md |
| Truncated data | pm.Truncated(name, dist, lower, upper) | Same as the underlying distribution | references/priors.md |
| High-dimensional / sparse regression | Normal / StudentT with a sparsity prior on coefficients | Regularized horseshoe or R2-D2 on coefficients | references/priors.md |
| Hierarchical / multilevel | Varies | See the partial pooling pattern | references/hierarchical.md |
| Time series | State space models / Gaussian processes | Problem-specific | references/priors.md |
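The Normal(0, 1.5) default for logistic coefficients can be motivated with a prior predictive look on the probability scale. This numpy sketch pushes intercept-only prior draws through the inverse logit; it is an illustration and is not taken from references/priors.md. Normal(0, 1.5) keeps roughly 95% of prior probabilities between 0.05 and 0.95. A seemingly vague Normal(0, 10) instead piles most of its mass near 0 and 1.

```python
import numpy as np

rng = np.random.default_rng(sum(map(ord, "logit-prior-scale")))
n_draws = 10_000


def inv_logit(x):
    # Logistic function: maps log-odds to probabilities
    return 1.0 / (1.0 + np.exp(-x))


results = {}
for sigma in (1.5, 10.0):
    # Prior draws for an intercept on the log-odds scale
    p = inv_logit(rng.normal(loc=0.0, scale=sigma, size=n_draws))
    # Share of prior mass on "moderate" probabilities, away from 0 and 1
    results[sigma] = np.mean((p > 0.05) & (p < 0.95))
    print(f"Normal(0, {sigma}): {results[sigma]:.0%} of prior probabilities in (0.05, 0.95)")
```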

Utility scripts

Run diagnose_model.py after sampling to get a structured convergence + diagnostics report:

```bash
python scripts/diagnose_model.py --idata path/to/inference_data.nc
```

Run calibration_check.py to generate calibration plots:

```bash
python scripts/calibration_check.py --idata path/to/inference_data.nc
```

See scripts/ for all available utilities.

Common gotchas

These are battle-tested lessons that save hours of debugging:

  • nutpie silently ignores idata_kwargs for log_likelihood and log_prior. Always compute them explicitly after sampling: pm.compute_log_likelihood(idata, model=model) (needed for LOO-CV) and pm.compute_log_prior(idata, model=model) (needed for prior sensitivity checks). Don't assume they're stored automatically.
  • az.plot_khat() requires the LOO object, not InferenceData. Pass the output of az.loo(idata, pointwise=True) to it.
  • Flat or heavy-tailed priors on scale parameters (HalfFlat, HalfCauchy) worsen funnel geometry in hierarchical models. Use Gamma(2, ...), whose density is zero at zero and so keeps the sampler out of the near-zero region that creates sampling problems, or a lighter-tailed Exponential. If there's no group-level variation to detect, you don't need the hierarchy.
  • Python conditionals on model variables (if x > 0) don't work inside PyMC, because x is a symbolic PyTensor variable rather than a number. Use pm.math.switch or pytensor.tensor.where instead.
  • Forgetting to standardize predictors makes shared priors inappropriate and slows sampling. Always standardize before fitting, then back-transform for interpretation.
  • Horseshoe priors create a double-funnel geometry that standard NUTS can struggle with. Always use the regularized (Finnish) horseshoe (Piironen & Vehtari, 2017), which adds a slab component that smooths the geometry. Set target_accept=0.95 or higher. If you see divergences with a horseshoe model, this is almost certainly the cause.
  • np.median on posterior predictive probabilities is a silent bug. It does not produce the Bayesian predictive distribution and can yield probabilities that don't sum to 1 across categories. Always use np.mean over the posterior samples dimension.
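Here is the standardize-then-back-transform gotcha in a few lines. This sketch uses hypothetical predictors on very different scales. To stay dependency-free, it uses a least-squares fit as a stand-in for posterior means. Because the back-transform is linear, with MCMC output you would apply it to every posterior draw rather than to a point estimate.

```python
import numpy as np

rng = np.random.default_rng(sum(map(ord, "standardize-backtransform")))

# Hypothetical predictors on very different scales (e.g. age in years, income in dollars)
n = 200
X = np.column_stack([rng.normal(45, 12, n), rng.normal(60_000, 15_000, n)])
y = 2.0 + 0.03 * X[:, 0] + 0.00002 * X[:, 1] + rng.normal(0, 0.5, n)

# Standardize: each column gets mean 0 and sd 1, so one prior scale fits all slopes
x_mean, x_sd = X.mean(axis=0), X.std(axis=0)
Z = (X - x_mean) / x_sd

# Stand-in for posterior means on the standardized scale (least squares here only
# to keep the sketch dependency-free; with MCMC, back-transform every draw)
coef = np.linalg.lstsq(np.column_stack([np.ones(n), Z]), y, rcond=None)[0]
alpha_z, beta_z = coef[0], coef[1:]

# Back-transform to original units:
#   alpha_z + sum_j beta_z[j] * (x_j - mean_j) / sd_j
#   = (alpha_z - sum_j beta_z[j] * mean_j / sd_j) + sum_j (beta_z[j] / sd_j) * x_j
beta_orig = beta_z / x_sd
alpha_orig = alpha_z - np.sum(beta_z * x_mean / x_sd)
print("intercept:", round(float(alpha_orig), 3), "slopes:", beta_orig)
```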
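For reference, the regularized horseshoe of Piironen & Vehtari (2017) caps each coefficient's prior scale with a slab of width c: λ̃ⱼ² = c²λⱼ² / (c² + τ²λⱼ²). The numpy sketch below draws coefficients from that prior to show the construction. The problem sizes, the expected number of non-zero coefficients p0, and the slab settings are illustrative assumptions. In a PyMC model the same pieces would map onto pm.HalfCauchy, pm.InverseGamma and pm.Normal.

```python
import numpy as np

rng = np.random.default_rng(sum(map(ord, "regularized-horseshoe-prior")))

# Illustrative settings (assumptions, not recommendations)
D, n, p0 = 50, 200, 5             # coefficients, observations, expected non-zeros
sigma = 1.0                       # residual scale, treated as known here
slab_scale, slab_df = 2.0, 4.0    # slab width and degrees of freedom
n_draws = 4000

# Global scale: tau ~ HalfCauchy(0, tau0), tau0 set from the expected sparsity
tau0 = p0 / (D - p0) * sigma / np.sqrt(n)
tau = tau0 * np.abs(rng.standard_cauchy(size=(n_draws, 1)))

# Local scales: lambda_j ~ HalfCauchy(0, 1)
lam = np.abs(rng.standard_cauchy(size=(n_draws, D)))

# Slab: c^2 ~ InvGamma(slab_df / 2, slab_df * slab_scale^2 / 2),
# drawn as the reciprocal of a Gamma with rate slab_df * slab_scale^2 / 2
c2 = 1.0 / rng.gamma(shape=slab_df / 2, scale=2.0 / (slab_df * slab_scale**2), size=(n_draws, 1))

# Regularized local scales: large lambda_j are softly capped so tau * lambda_tilde_j < c
lam_tilde = np.sqrt(c2 * lam**2 / (c2 + tau**2 * lam**2))
beta = rng.normal(size=(n_draws, D)) * tau * lam_tilde

print("median prior sd of a coefficient:", np.median(tau * lam_tilde))
```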

When things go wrong

| Symptom | Likely cause | Fix |
|---|---|---|
| Divergences | Posterior geometry issue | Reparameterize (non-centered), increase target_accept to 0.95-0.99 |
| Low ESS | High autocorrelation | More tuning steps, reparameterize, reduce correlations |
| R-hat > 1.01 | Chains haven't mixed | More draws, better initialization, check for multimodality |
| Prior predictive looks wrong | Bad priors | Tighten or shift priors, use domain knowledge |
| Posterior predictive misses data | Model misspecification | Add complexity (varying slopes, a different data model, interaction terms) |
| log_likelihood missing | nutpie doesn't auto-store it | Call pm.compute_log_likelihood(idata, model=model) after sampling |
| Slow model | Large Deterministics or recompilation | Profile with model.profile(model.logp()), avoid large Deterministic arrays |
| Slow to initialize / poor warmup | Bad starting point | Try init="adapt_diag_grad" in pm.sample(), or run pmx.fit(method="pathfinder") first (import pymc_extras as pmx) and pass its estimates as initvals |
| Prior sensitivity flag | Prior-data conflict or strong prior | Check psense_summary(idata) (see references/sensitivity.md), then justify or revise the flagged prior |
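The "Reparameterize (non-centered)" fix for divergences replaces theta ~ Normal(mu, tau) with theta = mu + tau * z, where z ~ Normal(0, 1). In PyMC that is typically z = pm.Normal("z", 0, 1, dims="group") followed by theta = pm.Deterministic("theta", mu + tau * z, dims="group"), with "group" a hypothetical dim name. The numpy sketch below uses prior draws only, roughly what the posterior looks like when each group has little data, to show why this helps. In the centered form the spread of theta is tied to tau, which creates the funnel, whereas z is independent of tau. The log-normal spread of tau is an assumption chosen to span several orders of magnitude.

```python
import numpy as np

rng = np.random.default_rng(sum(map(ord, "non-centered-sketch")))
n_draws = 4000
mu = 0.0

# Group-level scale spanning several orders of magnitude, as in a funnel
log_tau = rng.normal(loc=0.0, scale=1.5, size=n_draws)
tau = np.exp(log_tau)

# Centered: theta ~ Normal(mu, tau). The spread of theta depends on tau.
theta_centered = rng.normal(loc=mu, scale=tau)

# Non-centered: z ~ Normal(0, 1), theta = mu + tau * z. The sampler explores z,
# whose distribution does not depend on tau.
z = rng.normal(size=n_draws)
theta_noncentered = mu + tau * z

corr_centered = np.corrcoef(log_tau, np.log(np.abs(theta_centered - mu)))[0, 1]
corr_noncentered = np.corrcoef(log_tau, np.log(np.abs(z)))[0, 1]
print(f"corr(log tau, log|theta - mu|) = {corr_centered:.2f}  (funnel)")
print(f"corr(log tau, log|z|)          = {corr_noncentered:.2f}  (no funnel)")
```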