Utilities

Configuration

class TrainingConfig(save_path=None, batch_size=2048, learning_rate=0.0005, max_epochs=500, stop_after_epochs=100, scheduler_patience=20, validation_fraction=0.1, num_workers=1, autosave_every=10, resume_checkpoint=None, use_amp=False, print_interval=10, show_progress=False, tensorboard_dir=None, warmup_epochs=50, ema_alpha=0.05, compile=False)[source]

Configuration for the SBI training loop.

Parameters:
  • save_path (Path | None) – Directory to write model checkpoints. None disables checkpointing.

  • batch_size (int) – Number of samples per training batch.

  • learning_rate (float) – Initial learning rate for the Adam optimiser.

  • max_epochs (int) – Hard upper limit on training epochs.

  • stop_after_epochs (int) – Stop if the EMA validation loss has not improved for this many consecutive epochs.

  • scheduler_patience (int) – Epochs without improvement before ReduceLROnPlateau halves the LR.

  • validation_fraction (float) – Fraction of data held out for validation.

  • num_workers (int) – Number of DataLoader worker processes.

  • autosave_every (int) – Save a periodic checkpoint every N epochs.

  • resume_checkpoint (Path | None) – Path to a checkpoint to resume from.

  • use_amp (bool) – Enable automatic mixed precision.

  • print_interval (int) – Log a training summary every N epochs.

  • show_progress (bool) – Show the two-level fit/epoch progress bars; works in both CLI terminals and Jupyter notebooks. Set to False for non-interactive or CI environments.

  • tensorboard_dir (Path | None) – Directory for TensorBoard event files. None disables TensorBoard logging.

  • warmup_epochs (int) – Epochs for linear LR warm-up from 1% to 100%.

  • ema_alpha (float) – EMA smoothing factor for the validation loss used in early stopping. Smaller values give a smoother, slower-reacting estimate.

  • compile (bool) – Compile the model with torch.compile.
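Two of these knobs interact with the training loop in ways worth spelling out: warmup_epochs ramps the learning rate linearly from 1% to 100% of its base value, and stop_after_epochs is judged against the EMA-smoothed validation loss controlled by ema_alpha. The sketch below is illustrative, not the library's code; class and function names here are hypothetical.

```python
# Illustrative sketch (NOT the library's implementation) of two
# TrainingConfig behaviours: linear LR warm-up and EMA-smoothed
# early stopping.

def warmup_factor(epoch: int, warmup_epochs: int) -> float:
    """Linear ramp from 1% to 100% of the base LR over warmup_epochs."""
    if epoch >= warmup_epochs:
        return 1.0
    return 0.01 + (1.0 - 0.01) * epoch / warmup_epochs

class EMAEarlyStopper:
    """Stop when the EMA of the validation loss has not improved for
    `patience` consecutive epochs (cf. stop_after_epochs / ema_alpha)."""

    def __init__(self, patience: int = 100, alpha: float = 0.05) -> None:
        self.patience = patience
        self.alpha = alpha          # smaller alpha -> smoother EMA
        self.ema = None
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True to stop."""
        if self.ema is None:
            self.ema = val_loss
        else:
            self.ema = self.alpha * val_loss + (1 - self.alpha) * self.ema
        if self.ema < self.best:
            self.best = self.ema
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Because the stopping criterion looks at the EMA rather than the raw loss, a single noisy good epoch does not reset the patience counter.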

class PosteriorConfig(model='maf', hidden_features=128, num_transforms=6, dropout_probability=0.1, num_blocks=2, num_bins=10)[source]

Configuration for the NPE density estimator architecture.

Parameters:
  • model (str) – "maf" (Masked Autoregressive Flow) or "nsf" (Neural Spline Flow).

  • hidden_features (int) – Number of hidden units per layer.

  • num_transforms (int) – Number of autoregressive transforms (MAF only).

  • dropout_probability (float) – Dropout probability during training.

  • num_blocks (int) – Number of residual blocks.

  • num_bins (int) – Number of spline bins (NSF only).
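Since some fields apply only to one architecture (num_transforms to MAF, num_bins to NSF), a builder has to dispatch on model. A hedged sketch of that dispatch, assuming kwarg names that simply mirror the config fields (the build step itself is hypothetical):

```python
# Hypothetical sketch of how PosteriorConfig fields might be split into
# model-specific keyword arguments. Kwarg names mirror the config
# fields; this is not the library's actual builder.

from dataclasses import dataclass

@dataclass
class PosteriorConfig:  # mirrors the documented defaults
    model: str = "maf"
    hidden_features: int = 128
    num_transforms: int = 6
    dropout_probability: float = 0.1
    num_blocks: int = 2
    num_bins: int = 10

def estimator_kwargs(cfg: PosteriorConfig) -> dict:
    common = {
        "hidden_features": cfg.hidden_features,
        "dropout_probability": cfg.dropout_probability,
        "num_blocks": cfg.num_blocks,
    }
    if cfg.model == "maf":
        # num_bins is an NSF-only knob, so it is dropped here
        return {**common, "num_transforms": cfg.num_transforms}
    if cfg.model == "nsf":
        # num_transforms is a MAF-only knob per the docs
        return {**common, "num_bins": cfg.num_bins}
    raise ValueError(f"unknown model: {cfg.model!r}")
```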

Logging

class MaCh3Logger(name='mach3sbi', level='INFO', log_file=None, file_level=None, show_path=False)[source]

Rich-backed logger for mach3sbitools.

Wraps logging.Logger with a RichHandler for coloured console output and an optional plain-text file handler.

Usage in application code:

logger = MaCh3Logger("mach3sbi", log_file="run.log", level="INFO")

Usage in submodules:

from mach3sbitools.utils.logger import get_logger
logger = get_logger(__name__)
logger.info("Loaded [bold]50k[/] pairs")
Parameters:
  • name (str)

  • level (str)

  • log_file (Path | None)

  • file_level (str | None)

  • show_path (bool)
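The wrapper pattern behind MaCh3Logger can be sketched with stdlib logging alone: a console handler whose level can be adjusted at runtime (cf. set_level, which touches only the console handler), plus an optional file handler with its own level. RichHandler is swapped for StreamHandler here so the sketch has no third-party dependency; the class name is hypothetical.

```python
# Stdlib-only sketch of the MaCh3Logger wrapper pattern. The real class
# uses a RichHandler for coloured output; a plain StreamHandler stands
# in for it here.

import logging
import sys

class SimpleWrappedLogger:
    def __init__(self, name="mach3sbi", level="INFO",
                 log_file=None, file_level=None):
        self._logger = logging.getLogger(name)
        self._logger.setLevel(logging.DEBUG)  # handlers do the filtering
        self._console = logging.StreamHandler(sys.stderr)
        self._console.setLevel(level)
        self._logger.addHandler(self._console)
        if log_file is not None:
            fh = logging.FileHandler(log_file)
            fh.setLevel(file_level or level)
            self._logger.addHandler(fh)

    def set_level(self, level: str) -> None:
        """Adjust only the console handler, as documented."""
        self._console.setLevel(level)

    @property
    def logger(self) -> logging.Logger:
        """Underlying logging.Logger for stdlib compatibility."""
        return self._logger

    def info(self, msg, *args, **kwargs):
        self._logger.info(msg, *args, **kwargs)
```

Keeping the wrapped logger's own level at DEBUG and filtering in the handlers is what lets the console and file levels differ.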

debug(msg, *args, **kwargs)[source]

Log msg at DEBUG level.

Return type:

None

Parameters:

msg (str)

info(msg, *args, **kwargs)[source]

Log msg at INFO level.

Return type:

None

Parameters:

msg (str)

warning(msg, *args, **kwargs)[source]

Log msg at WARNING level.

Return type:

None

Parameters:

msg (str)

error(msg, *args, **kwargs)[source]

Log msg at ERROR level.

Return type:

None

Parameters:

msg (str)

critical(msg, *args, **kwargs)[source]

Log msg at CRITICAL level.

Return type:

None

Parameters:

msg (str)

set_level(level)[source]

Adjust the console handler log level at runtime.

Parameters:

level (str) – New log level string.

Return type:

None

property logger: Logger

The underlying logging.Logger for stdlib compatibility.

get_logger(name='mach3sbi')[source]

Return a named child logger for use in submodules.

Parameters:

name (str) – Logger name, typically __name__.

Return type:

Logger

Returns:

A logging.Logger instance.

Device Handling

class TorchDeviceHandler[source]

Detects the best available PyTorch device and provides tensor conversion.

The device is detected once at construction time and cached.

property device: str

The detected device string, e.g. "cuda" or "cpu".

to_tensor(data)[source]

Convert an array-like object to a torch.Tensor on the active device.

Handles DataFrame, ndarray, and any object accepted by torch.tensor().

Parameters:

data – Input data to convert.

Return type:

Tensor

Returns:

Float tensor on device.

Raises:

TensorConversionError – If conversion fails.
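The detect-once-and-cache behaviour described above can be sketched without importing torch. The boolean probe below stands in for torch.cuda.is_available(); the real class also performs the torch.tensor() conversion, which is omitted here to keep the sketch dependency-free. The class name is hypothetical.

```python
# Sketch of the detect-once-at-construction pattern TorchDeviceHandler
# describes. `cuda_available` stands in for torch.cuda.is_available();
# this is illustrative, not the library's code.

class DeviceHandler:
    def __init__(self, cuda_available: bool = False) -> None:
        # Detection happens exactly once, here, and the result is cached;
        # later reads of .device never re-probe the hardware.
        self._device = "cuda" if cuda_available else "cpu"

    @property
    def device(self) -> str:
        """The cached device string, e.g. "cuda" or "cpu"."""
        return self._device
```

Caching at construction means the handler's answer is stable for its lifetime, even if device availability changes mid-run.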