"""
Simulation-Based Calibration (SBC) diagnostics.
Provides three complementary posterior-quality checks:
* **SBC rank plot** — verifies that posterior samples are statistically
consistent with the prior by checking rank uniformity.
* **Expected coverage** — plots the empirical coverage as a CDF; a
well-calibrated posterior lies on the diagonal.
* **TARP** (Test of Accuracy with Random Points) — a global calibration
test that does not require marginalisation over individual parameters.
Background and interpretation guidance is available in the
`sbi diagnostics documentation <https://sbi.readthedocs.io/en/stable/how_to_guide/diagnostics.html>`_.
Typical usage::
sbc = SBCDiagnostic(simulator, inference_handler, Path("plots/"))
sbc.create_prior_samples(num_prior_samples=200)
sbc.rank_plot(num_posterior_samples=1000)
sbc.expected_coverage(num_posterior_samples=1000)
sbc.tarp(num_posterior_samples=1000)
"""
from pathlib import Path
import numpy as np
import torch
from matplotlib import pyplot as plt
from sbi.analysis.plot import plot_tarp, sbc_rank_plot
from sbi.diagnostics import check_tarp, run_sbc, run_tarp
from tqdm.auto import tqdm
from mach3sbitools.inference import InferenceHandler
from mach3sbitools.simulator import Simulator
from mach3sbitools.utils import TorchDeviceHandler, get_logger
# Module-level logger used by all diagnostics in this file (see tarp()).
logger = get_logger()
class SBCDiagnostic:
    """
    Posterior calibration diagnostics via Simulation-Based Calibration.

    Wraps the ``sbi`` SBC, expected-coverage, and TARP diagnostics behind a
    common interface. All three methods share a single pool of prior predictive
    samples generated by :meth:`create_prior_samples`, which must be called
    before any plot method.

    The general workflow is:

    1. Construct the object — this calls
       :meth:`~mach3sbitools.inference.InferenceHandler.build_posterior`
       on the provided handler.
    2. Call :meth:`create_prior_samples` to draw ``θ ~ prior`` and simulate
       the corresponding observables ``x ~ p(x | θ)``.
    3. Call any combination of :meth:`rank_plot`, :meth:`expected_coverage`,
       and :meth:`tarp`. Each saves a PDF to *plot_dir*.

    :param simulator: Simulator object used to generate prior predictive
        samples. Must implement
        :class:`~mach3sbitools.simulator.SimulatorProtocol`.
    :param inference_handler: Trained
        :class:`~mach3sbitools.inference.InferenceHandler` whose posterior
        is used for all diagnostic evaluations.
    :param plot_dir: Directory where output PDFs are written. Created
        automatically if it does not exist.
    """

    def __init__(
        self,
        simulator: Simulator,
        inference_handler: InferenceHandler,
        plot_dir: Path,
    ) -> None:
        self.plot_dir = plot_dir
        self.plot_dir.mkdir(exist_ok=True, parents=True)
        self.simulator = simulator
        self._device_handler = TorchDeviceHandler()
        self.inference_handler = inference_handler
        # Build the posterior up front so every diagnostic can sample from it.
        inference_handler.build_posterior()
        self.posterior = inference_handler.posterior
        # Populated by create_prior_samples(); required by all plot methods.
        self.prior_samples: torch.Tensor | None = None
        self.prior_predictives: torch.Tensor | None = None

    def create_prior_samples(self, num_prior_samples: int) -> None:
        """
        Draw samples from the prior and generate corresponding prior predictives.

        Samples ``θ ~ prior`` using the inference handler's prior, then runs
        each ``θ`` through the simulator to produce the prior predictive
        observables ``x ~ p(x | θ)``. Both tensors are stored on the instance
        and reused by all subsequent diagnostic methods.

        :param num_prior_samples: Number of ``(θ, x)`` pairs to generate.
        """
        self.prior_samples = self.inference_handler.prior.sample(
            (num_prior_samples,)
        ).to(torch.float32)
        # The simulator operates on individual NumPy parameter vectors, so the
        # prior draws are moved to CPU and simulated one at a time.
        prior_predictives_np = np.array(
            [
                self.simulator.simulator_wrapper.simulate(p)
                for p in tqdm(
                    self.prior_samples.cpu().numpy(), desc="Running SBC diagnostic"
                )
            ],
            dtype=np.float32,
        )
        self.prior_predictives = self._device_handler.to_tensor(
            prior_predictives_np
        ).to(torch.float32)

    def _check_prior_sampled(self) -> None:
        """
        Assert that :meth:`create_prior_samples` has been called.

        :raises ValueError: If :attr:`prior_predictives` is ``None``.
        """
        if self.prior_predictives is None:
            raise ValueError("Prior predictives not set")

    def _save_figure(self, fig: plt.Figure, filename: str) -> None:
        """
        Save *fig* as *filename* inside :attr:`plot_dir`, display it when
        running in an interactive matplotlib session, and close it to
        release the figure's memory.
        """
        fig.savefig(self.plot_dir / filename)
        if plt.isinteractive():
            fig.show()
        plt.close(fig)

    def rank_plot(
        self,
        num_posterior_samples: int = 1000,
        num_rank_bins: int = 20,
    ) -> None:
        """
        Produce an SBC rank-uniformity plot and save it to *plot_dir*.

        For each prior sample ``θ*``, draws ``num_posterior_samples`` samples
        from ``p(θ | x*)`` and computes the rank of ``θ*`` within those
        samples. Under a well-calibrated posterior the ranks are uniformly
        distributed; systematic deviations indicate over- or
        under-dispersion.

        Saves ``rank_plot.pdf`` to :attr:`plot_dir`.

        :param num_posterior_samples: Number of posterior samples drawn per
            prior sample when computing ranks.
        :param num_rank_bins: Number of histogram bins in the rank plot.
        :raises ValueError: If :meth:`create_prior_samples` has not been
            called.
        """
        self._check_prior_sampled()
        ranks, _ = run_sbc(
            self.prior_samples,
            self.prior_predictives,
            self.posterior,
            num_posterior_samples=num_posterior_samples,
            use_batched_sampling=True,
        )
        fig, _ = sbc_rank_plot(
            ranks,
            num_posterior_samples,
            num_bins=num_rank_bins,
            figsize=(20, 20),
        )
        self._save_figure(fig, "rank_plot.pdf")

    def expected_coverage(
        self,
        num_posterior_samples: int = 1000,
        num_rank_bins: int = 20,
    ) -> None:
        """
        Produce an expected-coverage (CDF) plot and save it to *plot_dir*.

        Uses the negative log-probability of the posterior as a test
        statistic. The empirical coverage should match the nominal level for a
        well-calibrated posterior: the CDF curve should lie on the diagonal.
        Curves above the diagonal indicate over-coverage (conservative
        posterior); curves below indicate under-coverage (overconfident
        posterior).

        Saves ``expected_coverage.pdf`` to :attr:`plot_dir`.

        :param num_posterior_samples: Number of posterior samples drawn per
            prior sample when computing the test statistic.
        :param num_rank_bins: Number of bins used when constructing the
            empirical CDF.
        :raises ValueError: If :meth:`create_prior_samples` has not been
            called, or if :attr:`posterior` is ``None``.
        """
        self._check_prior_sampled()
        if self.posterior is None:
            # The lambda below dereferences self.posterior, so fail early
            # with an accurate message rather than an AttributeError.
            raise ValueError("Posterior not set")
        ranks, _ = run_sbc(
            self.prior_samples,
            self.prior_predictives,
            self.posterior,
            num_posterior_samples=num_posterior_samples,
            reduce_fns=lambda theta, x: -self.posterior.log_prob(theta, x),
            use_batched_sampling=True,
        )
        fig, _ = sbc_rank_plot(
            ranks,
            num_posterior_samples,
            num_bins=num_rank_bins,
            plot_type="cdf",
            figsize=(20, 20),
        )
        self._save_figure(fig, "expected_coverage.pdf")

    def tarp(self, num_posterior_samples: int = 1000) -> None:
        """
        Produce a TARP diagnostic plot and save it to *plot_dir*.

        TARP (Test of Accuracy with Random Points) is a global calibration
        test that avoids the marginalisation assumptions of rank-based methods.
        It computes the empirical coverage probability (ECP) as a function of
        the credibility level ``α`` and reports two summary statistics:

        * **ATC** (Area To Curve) — should be close to 0 for a well-calibrated
          posterior; positive values indicate over-coverage, negative values
          indicate under-coverage.
        * **KS p-value** — from a Kolmogorov–Smirnov test against the
          diagonal; a large p-value (> 0.05) is consistent with calibration.

        Both statistics are logged at INFO level. Saves ``tarp.pdf`` to
        :attr:`plot_dir`.

        :param num_posterior_samples: Number of posterior samples drawn per
            prior sample when estimating the ECP.
        :raises ValueError: If :meth:`create_prior_samples` has not been
            called.
        """
        self._check_prior_sampled()
        ecp, alpha = run_tarp(
            self.prior_samples,
            self.prior_predictives,
            self.posterior,
            references=None,  # let run_tarp draw its own random reference points
            num_posterior_samples=num_posterior_samples,
            use_batched_sampling=True,
        )
        # check_tarp and plot_tarp expect CPU tensors.
        atc, ks_pval = check_tarp(ecp.cpu(), alpha.cpu())
        logger.info(f"ATC: {atc}, should be close to 0")
        logger.info(f"KS p-value: {ks_pval}")
        fig, _ = plot_tarp(ecp.cpu(), alpha.cpu())
        self._save_figure(fig, "tarp.pdf")