# Source code for mach3sbitools.diagnostics.sbc

"""
Simulation-Based Calibration (SBC) diagnostics.

Provides three complementary posterior-quality checks:

* **SBC rank plot** — verifies that posterior samples are statistically
  consistent with the prior by checking rank uniformity.
* **Expected coverage** — plots the empirical coverage as a CDF; a
  well-calibrated posterior lies on the diagonal.
* **TARP** (Test of Accuracy with Random Points) — a global calibration
  test that does not require marginalisation over individual parameters.

Background and interpretation guidance is available in the
`sbi diagnostics documentation <https://sbi.readthedocs.io/en/stable/how_to_guide/diagnostics.html>`_.

Typical usage::

    sbc = SBCDiagnostic(simulator, inference_handler, Path("plots/"))
    sbc.create_prior_samples(num_prior_samples=200)
    sbc.rank_plot(num_posterior_samples=1000)
    sbc.expected_coverage(num_posterior_samples=1000)
    sbc.tarp(num_posterior_samples=1000)
"""

from pathlib import Path

import numpy as np
import torch
from matplotlib import pyplot as plt
from sbi.analysis.plot import plot_tarp, sbc_rank_plot
from sbi.diagnostics import check_tarp, run_sbc, run_tarp
from tqdm.auto import tqdm

from mach3sbitools.inference import InferenceHandler
from mach3sbitools.simulator import Simulator
from mach3sbitools.utils import TorchDeviceHandler, get_logger

logger = get_logger()


class SBCDiagnostic:
    """
    Posterior calibration diagnostics via Simulation-Based Calibration.

    Wraps the ``sbi`` SBC, expected-coverage, and TARP diagnostics behind a
    common interface. All three methods share a single pool of prior
    predictive samples generated by :meth:`create_prior_samples`, which must
    be called before any plot method.

    The general workflow is:

    1. Construct the object — this calls
       :meth:`~mach3sbitools.inference.InferenceHandler.build_posterior`
       on the provided handler.
    2. Call :meth:`create_prior_samples` to draw ``θ ~ prior`` and
       simulate the corresponding observables ``x ~ p(x | θ)``.
    3. Call any combination of :meth:`rank_plot`, :meth:`expected_coverage`,
       and :meth:`tarp`. Each saves a PDF to *plot_dir*.

    :param simulator: Simulator object used to generate prior predictive
        samples. Must implement
        :class:`~mach3sbitools.simulator.SimulatorProtocol`.
    :param inference_handler: Trained
        :class:`~mach3sbitools.inference.InferenceHandler` whose posterior
        is used for all diagnostic evaluations.
    :param plot_dir: Directory where output PDFs are written. Created
        automatically if it does not exist.
    """

    def __init__(
        self,
        simulator: Simulator,
        inference_handler: InferenceHandler,
        plot_dir: Path,
    ) -> None:
        self.plot_dir = plot_dir
        self.plot_dir.mkdir(exist_ok=True, parents=True)
        self.simulator = simulator
        self._device_handler = TorchDeviceHandler()
        self.inference_handler = inference_handler
        inference_handler.build_posterior()
        self.posterior = inference_handler.posterior
        # Both are populated together by create_prior_samples(); every
        # diagnostic method requires them.
        self.prior_samples: torch.Tensor | None = None
        self.prior_predictives: torch.Tensor | None = None

    def create_prior_samples(self, num_prior_samples: int) -> None:
        """
        Draw samples from the prior and generate corresponding prior
        predictives.

        Samples ``θ ~ prior`` using the inference handler's prior, then runs
        each ``θ`` through the simulator to produce the prior predictive
        observables ``x ~ p(x | θ)``. Both tensors are stored on the instance
        and reused by all subsequent diagnostic methods.

        :param num_prior_samples: Number of ``(θ, x)`` pairs to generate.
        """
        self.prior_samples = self.inference_handler.prior.sample(
            (num_prior_samples,)
        ).to(torch.float32)
        # The simulator operates on numpy arrays; collect all predictives
        # first, then move them back to the configured torch device.
        prior_predictives_np = np.array(
            [
                self.simulator.simulator_wrapper.simulate(p)
                for p in tqdm(
                    self.prior_samples.cpu().numpy(),
                    desc="Running SBC diagnostic",
                )
            ],
            dtype=np.float32,
        )
        self.prior_predictives = self._device_handler.to_tensor(
            prior_predictives_np
        ).to(torch.float32)

    def _check_prior_sampled(self) -> None:
        """
        Assert that :meth:`create_prior_samples` has been called.

        :raises ValueError: If :attr:`prior_samples` or
            :attr:`prior_predictives` is ``None``.
        """
        # Both attributes are set together by create_prior_samples(), but
        # each is checked individually for robustness.
        if self.prior_samples is None:
            raise ValueError("Prior samples not set")
        if self.prior_predictives is None:
            raise ValueError("Prior predictives not set")

    def _check_posterior(self) -> None:
        """
        Assert that the inference handler produced a posterior.

        :raises ValueError: If :attr:`posterior` is ``None``.
        """
        if self.posterior is None:
            raise ValueError("Posterior not set")

    def _save_figure(self, fig, filename: str) -> None:
        """
        Save *fig* as *filename* in :attr:`plot_dir`, display it when
        matplotlib is in interactive mode, and close it to free memory.
        """
        fig.savefig(self.plot_dir / filename)
        if plt.isinteractive():
            fig.show()
        plt.close(fig)

    def rank_plot(
        self,
        num_posterior_samples: int = 1000,
        num_rank_bins: int = 20,
    ) -> None:
        """
        Produce an SBC rank-uniformity plot and save it to *plot_dir*.

        For each prior sample ``θ*``, draws ``num_posterior_samples`` samples
        from ``p(θ | x*)`` and computes the rank of ``θ*`` within those
        samples. Under a well-calibrated posterior the ranks are uniformly
        distributed; systematic deviations indicate over- or
        under-dispersion.

        Saves ``rank_plot.pdf`` to :attr:`plot_dir`.

        :param num_posterior_samples: Number of posterior samples drawn per
            prior sample when computing ranks.
        :param num_rank_bins: Number of histogram bins in the rank plot.
        :raises ValueError: If :meth:`create_prior_samples` has not been
            called, or if :attr:`posterior` is ``None``.
        """
        self._check_prior_sampled()
        self._check_posterior()
        ranks, _ = run_sbc(
            self.prior_samples,
            self.prior_predictives,
            self.posterior,
            num_posterior_samples=num_posterior_samples,
            use_batched_sampling=True,
        )
        fig, _ = sbc_rank_plot(
            ranks,
            num_posterior_samples,
            num_bins=num_rank_bins,
            figsize=(20, 20),
        )
        self._save_figure(fig, "rank_plot.pdf")

    def expected_coverage(
        self,
        num_posterior_samples: int = 1000,
        num_rank_bins: int = 20,
    ) -> None:
        """
        Produce an expected-coverage (CDF) plot and save it to *plot_dir*.

        Uses the negative log-probability of the posterior as a test
        statistic. The empirical coverage should match the nominal level for
        a well-calibrated posterior: the CDF curve should lie on the
        diagonal. Curves above the diagonal indicate over-coverage
        (conservative posterior); curves below indicate under-coverage
        (overconfident posterior).

        Saves ``expected_coverage.pdf`` to :attr:`plot_dir`.

        :param num_posterior_samples: Number of posterior samples drawn per
            prior sample when computing the test statistic.
        :param num_rank_bins: Number of bins used when constructing the
            empirical CDF.
        :raises ValueError: If :meth:`create_prior_samples` has not been
            called, or if :attr:`posterior` is ``None``.
        """
        self._check_prior_sampled()
        self._check_posterior()
        ranks, _ = run_sbc(
            self.prior_samples,
            self.prior_predictives,
            self.posterior,
            num_posterior_samples=num_posterior_samples,
            # Negative log-prob as the scalar test statistic for coverage.
            reduce_fns=lambda theta, x: -self.posterior.log_prob(theta, x),
            use_batched_sampling=True,
        )
        fig, _ = sbc_rank_plot(
            ranks,
            num_posterior_samples,
            num_bins=num_rank_bins,
            plot_type="cdf",
            figsize=(20, 20),
        )
        self._save_figure(fig, "expected_coverage.pdf")

    def tarp(self, num_posterior_samples: int = 1000) -> None:
        """
        Produce a TARP diagnostic plot and save it to *plot_dir*.

        TARP (Test of Accuracy with Random Points) is a global calibration
        test that avoids the marginalisation assumptions of rank-based
        methods. It computes the empirical coverage probability (ECP) as a
        function of the credibility level ``α`` and reports two summary
        statistics:

        * **ATC** (Area To Curve) — should be close to 0 for a
          well-calibrated posterior; positive values indicate over-coverage,
          negative values indicate under-coverage.
        * **KS p-value** — from a Kolmogorov–Smirnov test against the
          diagonal; a large p-value (> 0.05) is consistent with calibration.

        Both statistics are logged at INFO level. Saves ``tarp.pdf`` to
        :attr:`plot_dir`.

        :param num_posterior_samples: Number of posterior samples drawn per
            prior sample when estimating the ECP.
        :raises ValueError: If :meth:`create_prior_samples` has not been
            called, or if :attr:`posterior` is ``None``.
        """
        self._check_prior_sampled()
        self._check_posterior()
        ecp, alpha = run_tarp(
            self.prior_samples,
            self.prior_predictives,
            self.posterior,
            references=None,
            num_posterior_samples=num_posterior_samples,
            use_batched_sampling=True,
        )
        # check_tarp / plot_tarp expect CPU tensors.
        atc, ks_pval = check_tarp(ecp.cpu(), alpha.cpu())
        logger.info(f"ATC: {atc}, should be close to 0")
        logger.info(f"KS p-value: {ks_pval}")
        fig, _ = plot_tarp(ecp.cpu(), alpha.cpu())
        self._save_figure(fig, "tarp.pdf")