Inference

class InferenceHandler(prior_path, nuisance_pars=None)[source]

High-level interface for NPE training and posterior sampling.

Manages the full inference pipeline: loading simulations from disk, building and training an NPE density estimator, and drawing posterior samples conditioned on observed data.

Parameters:
  • prior_path (Path) – Path to the stored prior distribution.

  • nuisance_pars (list[str] | None) – Names of parameters to treat as nuisance parameters, if any.

set_dataset(data_folder)[source]

Point the handler at a folder of .feather simulation files.

Parameters:

data_folder (Path) – Directory containing .feather files produced by save().

Return type:

None

load_training_data()[source]

Pre-load all .feather files into RAM as a flat TensorDataset.

Call once before train_posterior(). Keeps data on CPU; the DataLoader handles GPU transfers via pinned memory.

Raises:

ValueError – If set_dataset() has not been called.

Return type:

None

create_posterior(config)[source]

Build the NPE inference object and density estimator network.

Parameters:

config (PosteriorConfig) – Architecture and hyperparameter settings. See PosteriorConfig.

Return type:

None

train_posterior(config)[source]

Train the density estimator using the custom SBITrainer.

Parameters:

config (TrainingConfig) – Training loop settings. See TrainingConfig.

Raises:

ValueError – If load_training_data() or create_posterior() has not been called.

Return type:

None

build_posterior()[source]

Wrap the trained density estimator in an sbi posterior object.

Called automatically by sample_posterior().

Raises:

ValueError – If no density estimator has been trained or loaded.

Return type:

None

sample_posterior(num_samples, x, **kwargs)[source]

Draw samples from the posterior conditioned on x.

Parameters:
  • num_samples (int) – Number of posterior samples to draw.

  • x (list[float]) – Observed data vector x_o.

  • kwargs – Additional keyword arguments forwarded to the underlying sbi posterior's sample() method.

Return type:

Tensor

Returns:

Tensor of shape (num_samples, n_params).

Raises:

ValueError – If no density estimator is available.
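To illustrate the (num_samples, n_params) convention: each row is one posterior draw, so per-parameter summaries reduce over the first axis. A minimal pure-Python sketch (the actual return value is a torch Tensor, where the equivalent is samples.mean(dim=0)):

```python
def posterior_means(samples):
    """Column means of a (num_samples, n_params) array given as nested
    lists: one mean per parameter, averaged over all posterior draws."""
    n = len(samples)
    # zip(*samples) transposes rows (draws) into columns (parameters).
    return [sum(col) / n for col in zip(*samples)]
```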

load_posterior(checkpoint_path, config)[source]

Load a trained density estimator from a checkpoint file.

Supports both best-model state dicts (plain state_dict) and autosave checkpoints (dicts with a "model_state" key).

Parameter and observable dimensions are inferred from the prior.
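The two supported checkpoint layouts can be distinguished with a small helper along these lines (a sketch of the documented behaviour; the handler's internal logic may differ):

```python
def extract_state_dict(checkpoint):
    """Return model weights from either supported checkpoint layout:
    autosave checkpoints nest the weights under a "model_state" key,
    while best-model files are the plain state_dict itself."""
    if isinstance(checkpoint, dict) and "model_state" in checkpoint:
        return checkpoint["model_state"]
    return checkpoint
```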

Parameters:
  • checkpoint_path (Path) – Path to a .pt checkpoint file.

  • config (PosteriorConfig) – Architecture config — must match the settings used during training.

Raises:
  • FileNotFoundError – If checkpoint_path does not exist.

  • ValueError – If the inference object is unavailable after create_posterior().

Return type:

None

Trainer

class SBITrainer(dataset, config, device)[source]

Training loop for sbi density estimators.

Handles data splitting, DataLoader construction, optimiser and scheduler setup, AMP, gradient clipping, EMA-based early stopping, TensorBoard logging, and checkpointing.
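The EMA-based early stopping can be sketched as follows; the alpha and patience values here are illustrative defaults, not the trainer's actual settings (which come from TrainingConfig):

```python
def ema_early_stop_step(val_loss, ema, best, counter, alpha=0.1, patience=20):
    """One epoch of EMA-smoothed early stopping (alpha and patience
    are illustrative, not the trainer's real defaults)."""
    # Exponential moving average of the raw validation loss.
    ema = val_loss if ema is None else alpha * val_loss + (1.0 - alpha) * ema
    if ema < best:
        best, counter = ema, 0   # improvement: reset the counter
    else:
        counter += 1             # no improvement this epoch
    return ema, best, counter, counter >= patience
```

Smoothing the validation loss before comparison makes the stopping criterion less sensitive to epoch-to-epoch noise than stopping on the raw loss.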

Typical usage via InferenceHandler:

trainer = SBITrainer(dataset, config, device)
best_model = trainer.train(density_estimator)
Parameters:
  • dataset (TensorDataset) – Pre-loaded training data (as built by load_training_data()).

  • config (TrainingConfig) – Training loop settings. See TrainingConfig.

  • device (device) – Torch device to train on.

train(density_estimator, optimizer=None, resume_checkpoint=None)[source]

Run the full training loop.

Parameters:
  • density_estimator (Module) – The network to train.

  • optimizer (Optimizer | None) – Optional pre-built optimiser (defaults to Adam).

  • resume_checkpoint (Path | None) – Path to a checkpoint to resume from.

Return type:

Module

Returns:

The density estimator restored to its best validation loss.

Raises:

DensityEstimatorError – If density_estimator is None.

save_checkpoint(epoch, density_estimator, optimizer, warmup_scheduler, plateau_scheduler, scaler, best_val_loss, epochs_no_improve, save_path, training_config=None, use_unique_path=True)[source]

Serialise a full training checkpoint to disk.

Parameters:
  • epoch (int) – Current epoch number.

  • density_estimator (Module) – The density estimator module to save.

  • optimizer (Optimizer) – Current optimiser.

  • warmup_scheduler (LinearLR) – Linear warm-up scheduler.

  • plateau_scheduler (ReduceLROnPlateau) – ReduceLROnPlateau scheduler.

  • scaler (GradScaler) – AMP gradient scaler.

  • best_val_loss (float) – Best EMA validation loss seen so far.

  • epochs_no_improve (int) – Current early-stopping counter.

  • save_path (Path) – Base file path.

  • training_config (TrainingConfig | None) – Optionally embed the config for provenance.

  • use_unique_path (bool) – Append _epoch{N} to the stem when True.

Return type:

None
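The use_unique_path naming scheme described above amounts to the following pathlib manipulation (a sketch of the documented _epoch{N} suffixing):

```python
from pathlib import Path

def unique_checkpoint_path(save_path, epoch, use_unique_path=True):
    """Append _epoch{N} to the file stem when use_unique_path is True,
    leaving the directory and suffix unchanged."""
    if not use_unique_path:
        return save_path
    return save_path.with_name(f"{save_path.stem}_epoch{epoch}{save_path.suffix}")
```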

TensorBoard

class TensorBoardWriter(log_dir, device_type)[source]

Thin wrapper around SummaryWriter.

Writes training scalars (losses, learning rates, throughput, GPU stats) to a TensorBoard event file each epoch.

Parameters:
  • log_dir (str)

  • device_type (str)

add_to_writer(epoch, train_loss, val_loss, ema_val_loss, best_val_loss, optimizer, elapsed, epochs_no_improve, total_samples)[source]

Write all training scalars for one epoch.

Records loss curves, per-group learning rates, throughput metrics, early-stopping state, and GPU memory statistics.

Parameters:
  • epoch (int) – Current epoch number (x-axis value).

  • train_loss (float) – Mean training loss for the epoch.

  • val_loss (float) – Mean validation loss for the epoch.

  • ema_val_loss (float) – EMA-smoothed validation loss.

  • best_val_loss (float) – Best EMA validation loss seen so far.

  • optimizer (Optimizer) – Current optimiser (used to read learning rates).

  • elapsed (float) – Wall-clock seconds for the epoch.

  • epochs_no_improve (int) – Current early-stopping counter.

  • total_samples (int) – Number of training samples processed during the epoch (used for the throughput calculation).

Return type:

None
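The throughput metric derived from total_samples and elapsed is, in essence, samples per wall-clock second; a sketch of how such a metric is typically computed (the writer's exact formula is not shown here):

```python
def samples_per_second(total_samples, elapsed):
    """Epoch throughput in samples/second, guarding against a
    zero-length elapsed time."""
    return total_samples / elapsed if elapsed > 0 else 0.0
```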

get_gpu_stats()[source]

Return GPU memory and utilisation statistics.

Returns zeros for all fields when running on CPU or when CUDA is unavailable.

Return type:

dict

Returns:

Dict with keys allocated_mb, reserved_mb, max_reserved_mb, memory_utilization_pct, and optionally sm_utilization_pct and memory_bandwidth_utilization_pct when CUDA is active.

close()[source]

Flush and close the underlying SummaryWriter.

Return type:

None