Utilities
Configuration
- class TrainingConfig(save_path=None, batch_size=2048, learning_rate=0.0005, max_epochs=500, stop_after_epochs=100, scheduler_patience=20, validation_fraction=0.1, num_workers=1, autosave_every=10, resume_checkpoint=None, use_amp=False, print_interval=10, show_progress=False, tensorboard_dir=None, warmup_epochs=50, ema_alpha=0.05, compile=False)
Configuration for the SBI training loop.
- Parameters:
save_path (Path | None) – Directory to write model checkpoints. None disables checkpointing.
batch_size (int) – Number of samples per training batch.
learning_rate (float) – Initial learning rate for the Adam optimiser.
max_epochs (int) – Hard upper limit on the number of training epochs.
stop_after_epochs (int) – Stop training if the EMA validation loss has not improved for this many consecutive epochs.
scheduler_patience (int) – Number of epochs without improvement before ReduceLROnPlateau halves the learning rate.
validation_fraction (float) – Fraction of the data held out for validation.
num_workers (int) – Number of DataLoader worker processes.
autosave_every (int) – Save a periodic checkpoint every N epochs.
resume_checkpoint (Path | None) – Path to a checkpoint to resume training from.
use_amp (bool) – Enable automatic mixed precision.
print_interval (int) – Log a training summary every N epochs.
show_progress (bool) – Show the two-level fit/epoch progress bars. Works in both CLI terminals and Jupyter notebooks. Set to False for non-interactive or CI environments.
tensorboard_dir (Path | None) – Directory for TensorBoard event files. None disables TensorBoard logging.
warmup_epochs (int) – Number of epochs over which the learning rate is warmed up linearly from 1% to 100% of learning_rate.
ema_alpha (float) – EMA smoothing factor for the validation loss used in early stopping. Smaller values give a smoother curve.
compile (bool) – Compile the model with torch.compile.
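The interaction of warmup_epochs, ema_alpha and stop_after_epochs can be illustrated with a short, framework-free sketch. It assumes the standard EMA update ema = ema_alpha * loss + (1 - ema_alpha) * ema (consistent with "smaller values are smoother") and a warm-up multiplier that starts at exactly 0.01 at epoch 0. The names warmup_factor and EMAEarlyStopper are hypothetical and not part of mach3sbitools; the library's actual training loop may differ in detail.

```python
from typing import Optional


def warmup_factor(epoch: int, warmup_epochs: int) -> float:
    """LR multiplier rising linearly from 0.01 (epoch 0) to 1.0 (epoch == warmup_epochs).

    Assumed schedule for illustration only.
    """
    if warmup_epochs <= 0 or epoch >= warmup_epochs:
        return 1.0
    return 0.01 + 0.99 * epoch / warmup_epochs


class EMAEarlyStopper:
    """Hypothetical early-stopping helper driven by an EMA of the validation loss."""

    def __init__(self, ema_alpha: float = 0.05, stop_after_epochs: int = 100) -> None:
        self.alpha = ema_alpha
        self.patience = stop_after_epochs
        self.ema: Optional[float] = None
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if self.ema is None:
            # Seed the average with the first observation.
            self.ema = val_loss
        else:
            # A small alpha gives each new loss little weight, so the curve is smoother.
            self.ema = self.alpha * val_loss + (1.0 - self.alpha) * self.ema
        if self.ema < self.best:
            self.best = self.ema
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        # Stop once the smoothed loss has failed to improve for `patience` epochs in a row.
        return self.epochs_without_improvement >= self.patience
```

With the defaults, the learning rate would reach its full value after 50 epochs, and training would stop once the smoothed loss has gone 100 epochs without a new minimum. In a PyTorch loop, warmup_factor could be supplied as the lr_lambda of torch.optim.lr_scheduler.LambdaLR (binding warmup_epochs with a lambda); how it is combined with ReduceLROnPlateau in mach3sbitools is not specified here.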
- class PosteriorConfig(model='maf', hidden_features=128, num_transforms=6, dropout_probability=0.1, num_blocks=2, num_bins=10)
Configuration for the NPE density estimator architecture.
- Parameters:
model (str) – "maf" (Masked Autoregressive Flow) or "nsf" (Neural Spline Flow).
hidden_features (int) – Number of hidden units per layer.
num_transforms (int) – Number of autoregressive transforms (MAF only).
dropout_probability (float) – Dropout probability during training.
num_blocks (int) – Number of residual blocks.
num_bins (int) – Number of spline bins (NSF only).
Logging
- class MaCh3Logger(name='mach3sbi', level='INFO', log_file=None, file_level=None, show_path=False)
Rich-backed logger for mach3sbitools.
Wraps logging.Logger with a RichHandler for coloured console output and an optional plain-text file handler.
Usage in application code:
logger = MaCh3Logger("mach3sbi", log_file="run.log", level="INFO")
Usage in submodules:
from mach3sbitools.utils.logger import get_logger

logger = get_logger(__name__)
logger.info("Loaded [bold]50k[/] pairs")
- Parameters:
name (str)
level (str)
log_file (Path | None)
file_level (str | None)
show_path (bool)
- debug(msg, *args, **kwargs)
Log msg at DEBUG level.
- Return type:
None
- Parameters:
msg (str)
- warning(msg, *args, **kwargs)
Log msg at WARNING level.
- Return type:
None
- Parameters:
msg (str)
- error(msg, *args, **kwargs)
Log msg at ERROR level.
- Return type:
None
- Parameters:
msg (str)
- critical(msg, *args, **kwargs)
Log msg at CRITICAL level.
- Return type:
None
- Parameters:
msg (str)
- set_level(level)
Adjust the console handler's log level at runtime.
- Parameters:
level (str) – New log level name, for example "DEBUG".
- Return type:
None
- property logger: Logger
The underlying logging.Logger, exposed for stdlib compatibility.
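The handler layout described for MaCh3Logger (a console handler whose threshold set_level changes, plus an optional plain-text file handler with its own file_level) can be sketched with the standard library alone. This is a hypothetical stand-in, not the mach3sbitools implementation: it swaps RichHandler for logging.StreamHandler so it runs without Rich, ignores show_path, and assumes file_level falls back to level when it is None.

```python
import logging
from pathlib import Path
from typing import Optional, Union


class StdlibLoggerSketch:
    """Hypothetical stdlib-only stand-in for MaCh3Logger (StreamHandler replaces RichHandler)."""

    def __init__(
        self,
        name: str = "mach3sbi",
        level: str = "INFO",
        log_file: Optional[Union[str, Path]] = None,
        file_level: Optional[str] = None,
    ) -> None:
        self._logger = logging.getLogger(name)
        # The logger passes everything through; each handler applies its own threshold.
        self._logger.setLevel(logging.DEBUG)
        self._logger.propagate = False
        # Drop handlers left by an earlier instance with the same name to avoid duplicate output.
        for old in list(self._logger.handlers):
            self._logger.removeHandler(old)
            old.close()

        # Console handler: this is the threshold that set_level() adjusts at runtime.
        self._console = logging.StreamHandler()
        self._console.setLevel(level.upper())
        self._logger.addHandler(self._console)

        # Optional plain-text file handler; assumed to fall back to `level` when file_level is None.
        if log_file is not None:
            file_handler = logging.FileHandler(Path(log_file), encoding="utf-8")
            file_handler.setLevel((file_level or level).upper())
            file_handler.setFormatter(
                logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
            )
            self._logger.addHandler(file_handler)

    def set_level(self, level: str) -> None:
        # Only the console threshold changes; the file handler keeps its own level.
        self._console.setLevel(level.upper())

    def debug(self, msg: str, *args, **kwargs) -> None:
        self._logger.debug(msg, *args, **kwargs)

    def warning(self, msg: str, *args, **kwargs) -> None:
        self._logger.warning(msg, *args, **kwargs)

    @property
    def logger(self) -> logging.Logger:
        # Expose the underlying stdlib logger for code that expects logging.Logger.
        return self._logger
```

Giving each handler its own level is what lets a quiet console (level="WARNING") coexist with a verbose log file (file_level="DEBUG"). In mach3sbitools itself, use MaCh3Logger or get_logger as in the usage examples for this class.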
Device Handling
- class TorchDeviceHandler
Detects the best available PyTorch device and provides tensor conversion.
The device is detected once at construction time and cached.
- property device: str
The detected device string, e.g. "cuda" or "cpu".
- to_tensor(data)
Convert an array-like object to a torch.Tensor on the active device.
Handles DataFrame, ndarray, and any object accepted by torch.tensor().
- Parameters:
data – Input data to convert.
- Return type:
Tensor
- Returns:
Float tensor on device.
- Raises:
TensorConversionError – If conversion fails.
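The detect-once, convert-many pattern described for TorchDeviceHandler can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: it checks only CUDA before falling back to CPU, returns float32 tensors via torch.as_tensor, recognises DataFrames by duck-typing on to_numpy() so pandas need not be imported, and defines its own TensorConversionError as a stand-in. The class name DeviceHandlerSketch is hypothetical.

```python
import torch


class TensorConversionError(Exception):
    """Stand-in for the library's TensorConversionError in this sketch."""


class DeviceHandlerSketch:
    """Hypothetical stand-in for TorchDeviceHandler: detect once, convert many times."""

    def __init__(self) -> None:
        # Detect the device once at construction and cache it (assumed order: CUDA, then CPU).
        self._device = "cuda" if torch.cuda.is_available() else "cpu"

    @property
    def device(self) -> str:
        return self._device

    def to_tensor(self, data) -> torch.Tensor:
        kind = type(data).__name__
        try:
            # pandas DataFrames expose to_numpy(); duck-typing avoids a hard pandas dependency.
            if hasattr(data, "to_numpy"):
                data = data.to_numpy()
            # as_tensor avoids a copy when the input already matches dtype and device.
            return torch.as_tensor(data, dtype=torch.float32, device=self._device)
        except (TypeError, ValueError, RuntimeError) as exc:
            raise TensorConversionError(f"could not convert {kind} to a tensor") from exc
```

Caching the device string means repeated to_tensor calls never re-query CUDA, and chaining the original exception with "from exc" keeps the underlying PyTorch error visible when conversion fails.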