Source code for mach3sbitools.data_loaders.paraket_dataloader
"""
Dataset implementation for feather-based simulation files.
"""
from pathlib import Path
import torch
from torch.utils.data import Dataset, TensorDataset
from tqdm import tqdm
from mach3sbitools.utils import from_feather
[docs]
class ParaketDataset(Dataset):
"""
File-level PyTorch dataset over a folder of ``.feather`` simulation files.
Each ``__getitem__`` call loads one feather file and returns a
``(theta, x)`` pair. Call :meth:`to_tensor_dataset` before training to
pre-load everything into RAM as a flat :class:`~torch.utils.data.TensorDataset`.
"""
def __init__(
self,
data_folder: Path,
parameter_names: list[str],
nuisance_params: list[str] | None = None,
):
"""
:param data_folder: Directory containing ``.feather`` files.
:param parameter_names: Ordered list of parameter names in each file's
``theta`` column.
:param nuisance_params: fnmatch patterns for parameters to filter out
of *theta* on load.
"""
self.data_folder = data_folder
self.files = sorted(data_folder.glob("*.feather"))
self.nuisance_params = nuisance_params or None
self.parameter_names = parameter_names
def __len__(self) -> int:
"""Number of feather files in the dataset folder."""
return len(self.files)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Load one feather file and return ``(theta, x)`` as tensors.
:param idx: File index.
:returns: Tuple of ``(theta, x)`` float tensors.
"""
theta, x = from_feather(
self.files[idx], self.parameter_names, self.nuisance_params
)
return torch.from_numpy(theta), torch.from_numpy(x)
[docs]
def to_tensor_dataset(self, device: str = "cpu") -> TensorDataset:
"""
Pre-load all feather files into a flat :class:`~torch.utils.data.TensorDataset`.
Concatenates all ``(theta, x)`` pairs along the sample dimension.
This avoids repeated disk reads per epoch during training.
:param device: Target device for the output tensors.
:returns: A :class:`~torch.utils.data.TensorDataset` of
``(theta_tensor, x_tensor)`` with shapes
``(n_total_samples, n_params)`` and ``(n_total_samples, n_bins)``.
"""
all_theta, all_x = [], []
for idx in tqdm(range(len(self)), desc="Pre-loading dataset"):
theta, x = self[idx]
all_theta.append(theta)
all_x.append(x)
theta_tensor = torch.cat(all_theta, dim=0).to(device)
x_tensor = torch.cat(all_x, dim=0).to(device)
print(
f"Loaded {theta_tensor.shape[0]:,} simulations | "
f"θ: {theta_tensor.shape[1]}D x: {x_tensor.shape[1]}D | "
f"RAM: {(theta_tensor.nbytes + x_tensor.nbytes) / 1e9:.2f} GB"
)
return TensorDataset(theta_tensor, x_tensor)