# Source code for mach3sbitools.utils.config (extracted from Sphinx-rendered docs)

"""
Configuration dataclasses for model architecture and training.
"""

from dataclasses import dataclass
from pathlib import Path


[docs] @dataclass class TrainingConfig: """ Configuration for the SBI training loop. :param save_path: Directory to write model checkpoints. ``None`` disables checkpointing. :param batch_size: Number of samples per training batch. :param learning_rate: Initial learning rate for the Adam optimiser. :param max_epochs: Hard upper limit on training epochs. :param stop_after_epochs: Stop if the EMA validation loss has not improved for this many consecutive epochs. :param scheduler_patience: Epochs without improvement before :class:`~torch.optim.lr_scheduler.ReduceLROnPlateau` halves the LR. :param validation_fraction: Fraction of data held out for validation. :param num_workers: Number of DataLoader worker processes. :param autosave_every: Save a periodic checkpoint every *N* epochs. :param resume_checkpoint: Path to a checkpoint to resume from. :param use_amp: Enable automatic mixed precision. :param print_interval: Log a training summary every *N* epochs. :param show_progress: Show the two-level fit/epoch progress bars. Works correctly in both CLI terminals and Jupyter notebooks. Set to ``False`` for non-interactive / CI environments. :param tensorboard_dir: Directory for TensorBoard event files. ``None`` disables TensorBoard logging. :param warmup_epochs: Epochs for linear LR warm-up from 1% to 100%. :param ema_alpha: EMA smoothing factor for validation loss used in early stopping. Smaller values are smoother. :param compile: Compile the model with ``torch.compile``. """ save_path: Path | None = None batch_size: int = 2048 learning_rate: float = 5e-4 max_epochs: int = 500 stop_after_epochs: int = 100 scheduler_patience: int = 20 validation_fraction: float = 0.1 num_workers: int = 1 autosave_every: int = 10 resume_checkpoint: Path | None = None use_amp: bool = False print_interval: int = 10 show_progress: bool = False tensorboard_dir: Path | None = None warmup_epochs: int = 50 ema_alpha: float = 0.05 compile: bool = False
@dataclass
class PosteriorConfig:
    """
    Configuration for the NPE density estimator architecture.

    :param model: ``"maf"`` (Masked Autoregressive Flow) or ``"nse"``
        (Neural Spline Flow).

        .. note:: The literal ``"nse"`` is documented here verbatim even
           though Neural Spline Flow is usually abbreviated "NSF" — confirm
           the accepted value against the consuming builder before changing.
    :param hidden_features: Number of hidden units per layer.
    :param num_transforms: Number of autoregressive transforms (MAF only).
    :param dropout_probability: Dropout probability during training.
    :param num_blocks: Number of residual blocks.
    :param num_bins: Number of spline bins (NSF only).
    """

    model: str = "maf"
    hidden_features: int = 128
    num_transforms: int = 6
    dropout_probability: float = 0.1
    num_blocks: int = 2
    num_bins: int = 10