# Source code for mach3sbitools.inference.tensorboard_writer

"""
TensorBoard logging wrapper for the SBI training loop.
"""

import torch
from torch.utils.tensorboard import SummaryWriter


class TensorBoardWriter:
    """
    Thin wrapper around :class:`~torch.utils.tensorboard.SummaryWriter`.

    Writes training scalars (losses, learning rates, throughput, GPU stats)
    to a TensorBoard event file each epoch.
    """

    def __init__(self, log_dir: str, device_type: str):
        """
        :param log_dir: Directory for TensorBoard event files.
        :param device_type: PyTorch device type string, e.g. ``"cuda"`` or
            ``"cpu"``.
        """
        self.writer = SummaryWriter(log_dir=log_dir)
        self.device_type = device_type

    def add_to_writer(
        self,
        epoch: int,
        train_loss: float,
        val_loss: float,
        ema_val_loss: float,
        best_val_loss: float,
        optimizer: torch.optim.Optimizer,
        elapsed: float,
        epochs_no_improve: int,
        total_samples: int,
    ) -> None:
        """
        Write all training scalars for one epoch.

        Records loss curves, per-group learning rates, throughput metrics,
        early-stopping state, and GPU memory statistics.

        :param epoch: Current epoch number (x-axis value).
        :param train_loss: Mean training loss for the epoch.
        :param val_loss: Mean validation loss for the epoch.
        :param ema_val_loss: EMA-smoothed validation loss.
        :param best_val_loss: Best EMA validation loss seen so far.
        :param optimizer: Current optimiser (used to read learning rates).
        :param elapsed: Wall-clock seconds for the epoch.
        :param epochs_no_improve: Current early-stopping counter.
        :param total_samples: Amount of work processed in the epoch; divided
            by ``elapsed`` to give ``throughput/samples_per_sec``.
        """
        self.writer.add_scalar("loss/train", train_loss, epoch)
        self.writer.add_scalar("loss/val", val_loss, epoch)
        self.writer.add_scalar("loss/val_ema", ema_val_loss, epoch)
        self.writer.add_scalar("loss/best_val_ema", best_val_loss, epoch)
        self.writer.add_scalar("loss/train_val_gap", train_loss - val_loss, epoch)

        for i, pg in enumerate(optimizer.param_groups):
            self.writer.add_scalar(f"lr/group_{i}", pg["lr"], epoch)

        # A zero (or negative) duration can come from a coarse timer or an
        # empty epoch; logging must never take the training loop down, so
        # report zero throughput instead of dividing by zero.
        samples_per_sec = total_samples / elapsed if elapsed > 0 else 0.0
        self.writer.add_scalar("throughput/samples_per_sec", samples_per_sec, epoch)
        self.writer.add_scalar("throughput/epoch_seconds", elapsed, epoch)

        self.writer.add_scalar(
            "early_stopping/epochs_no_improve", epochs_no_improve, epoch
        )

        gpu_stats = self.get_gpu_stats()
        for key in (
            "allocated_mb",
            "reserved_mb",
            "max_reserved_mb",
            "memory_utilization_pct",
        ):
            self.writer.add_scalar(f"gpu/{key}", gpu_stats[key], epoch)

        # Utilisation figures are optional: each one is written only when the
        # stats dict actually carries it, so a provider that supplies just one
        # of the two no longer triggers a KeyError.
        for key in ("sm_utilization_pct", "memory_bandwidth_utilization_pct"):
            if key in gpu_stats:
                self.writer.add_scalar(f"gpu/{key}", gpu_stats[key], epoch)

    def get_gpu_stats(self) -> dict:
        """
        Return GPU memory statistics in MiB.

        Returns zeros for all fields when the configured device is not a CUDA
        device or when CUDA is unavailable.

        :returns: Dict with keys ``allocated_mb``, ``reserved_mb``,
            ``max_reserved_mb`` and ``memory_utilization_pct``. This
            implementation does not populate ``sm_utilization_pct`` or
            ``memory_bandwidth_utilization_pct``; :meth:`add_to_writer` logs
            them only if an overriding implementation adds them.
        """
        # Only CUDA devices may be handed to the ``torch.cuda`` memory
        # queries; any other device type gets the all-zero placeholder.
        on_cuda = str(self.device_type).startswith("cuda")
        if not on_cuda or not torch.cuda.is_available():
            return {
                "allocated_mb": 0,
                "reserved_mb": 0,
                "max_reserved_mb": 0,
                "memory_utilization_pct": 0,
            }

        bytes_per_mib = 1024**2
        allocated = torch.cuda.memory_allocated(self.device_type) / bytes_per_mib
        reserved = torch.cuda.memory_reserved(self.device_type) / bytes_per_mib
        max_reserved = (
            torch.cuda.max_memory_reserved(self.device_type) / bytes_per_mib
        )
        # Share of the reserved memory that is currently allocated; guarded
        # because nothing may be reserved yet.
        mem_util = (allocated / reserved * 100) if reserved > 0 else 0

        return {
            "allocated_mb": allocated,
            "reserved_mb": reserved,
            "max_reserved_mb": max_reserved,
            "memory_utilization_pct": mem_util,
        }

    def close(self) -> None:
        """Flush and close the underlying :class:`~torch.utils.tensorboard.SummaryWriter`."""
        self.writer.flush()
        self.writer.close()