# Source code for mach3sbitools.utils.file_utils
"""
Feather file I/O utilities for simulation data.
"""
from fnmatch import fnmatch
from pathlib import Path
from typing import TypedDict
import numpy as np
from pyarrow import Table, feather
from mach3sbitools.types import SimulatorData, SimulatorDataGrouped
class FeatherOutput(TypedDict):
    """Schema for feather files written by :func:`to_feather`.

    Each key names one column of the table that :func:`to_feather` builds
    and that :func:`from_feather` reads back.
    """

    # Column "x": filled by to_feather from its ``x_values`` argument.
    x: SimulatorData
    # Column "theta": filled by to_feather from its ``theta_values`` argument.
    theta: SimulatorData
[docs]
def filter_nuisance(
    parameter_names: list[str], nuisance_pars: list[str], theta: SimulatorData
) -> SimulatorData:
    """
    Remove nuisance parameters from a theta array by name pattern.

    :param parameter_names: Ordered parameter names, length must match
        ``theta.shape[1]``.
    :param nuisance_pars: fnmatch patterns for parameters to exclude
        (e.g. ``["syst_*"]``). ``None`` disables filtering.
    :param theta: Parameter array of shape ``(n_samples, n_params)``.
        An empty batch (``n_samples == 0``) is accepted.
    :returns: Filtered array (a copy) with nuisance columns removed. When
        *nuisance_pars* is ``None`` the input object itself is returned.
    :raises ValueError: If *theta* has fewer than two dimensions or
        ``len(parameter_names) != theta.shape[1]``.
    """
    # Validate against the shape instead of ``len(theta[0])``: indexing row 0
    # raises IndexError for an empty batch and TypeError for 1-D input, which
    # hides the documented ValueError from callers.
    theta_shape = np.shape(theta)
    if len(theta_shape) < 2 or theta_shape[1] != len(parameter_names):
        raise ValueError("Parameter names and theta must have same length")

    if nuisance_pars is None:
        return theta

    # True for every column whose name matches none of the nuisance patterns.
    keep_column = np.array(
        [
            not any(fnmatch(name, pattern) for pattern in nuisance_pars)
            for name in parameter_names
        ],
        dtype=bool,
    )
    return theta[:, keep_column].copy()
[docs]
def from_feather(
    file_name: Path,
    parameter_names: list[str],
    nuisance_pars: list[str] | None = None,
) -> SimulatorDataGrouped:
    """
    Read the ``theta`` and ``x`` columns of a feather file into arrays.

    :param file_name: Location of the ``.feather`` file to read.
    :param parameter_names: Ordered parameter names; only consulted when
        *nuisance_pars* is given, to decide which theta columns to drop.
    :param nuisance_pars: fnmatch patterns naming parameters to strip from
        *theta*. With ``None`` every parameter is kept.
    :returns: ``(theta, x)`` tuple of ``float32`` numpy arrays.
    :raises FileNotFoundError: If *file_name* does not exist.
    """
    if not file_name.exists():
        raise FileNotFoundError(file_name)

    frame = feather.read_feather(str(file_name))
    # Columns are converted in the order theta, x, each to a float32 array.
    theta, x = (
        np.array(frame[column].to_list(), dtype=np.float32)
        for column in ("theta", "x")
    )

    if nuisance_pars is None:
        return theta, x
    return filter_nuisance(parameter_names, nuisance_pars, theta), x
[docs]
def to_feather(
    file_name: Path,
    theta_values: SimulatorData,
    x_values: SimulatorData,
) -> None:
    """
    Store a ``(theta, x)`` pair as a two-column feather file.

    :param file_name: Output path; its suffix has to be ``.feather``.
        Missing parent directories are created.
    :param theta_values: Parameter array of shape ``(n_samples, n_params)``.
    :param x_values: Observable array of shape ``(n_samples, n_bins)``.
    :raises ValueError: If *file_name* does not have a ``.feather`` suffix.
    """
    # Reject a wrong extension before any conversion or filesystem work.
    if file_name.suffix != ".feather":
        raise ValueError("Must store outputs files with the *.feather extension")

    x_rows = x_values.tolist()
    theta_rows = theta_values.tolist()
    columns: FeatherOutput = {"x": x_rows, "theta": theta_rows}
    arrow_table = Table.from_pydict(columns)

    target_dir = file_name.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    feather.write_feather(arrow_table, str(file_name))