"""
Prior distribution for MaCh3 SBI.
Constructs a composite prior from three distribution types, checked in order:
1. **Cyclical** — parameters matching *cyclical_parameters* patterns, forced
to bounds of ``[-2π, 2π]``.
2. **Flat (Uniform)** — parameters flagged via *flat_msk* and not cyclical.
3. **Gaussian** — all remaining parameters, modelled as a
:class:`~torch.distributions.MultivariateNormal`.
"""
import fnmatch
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import TypeAlias
import numpy as np
import torch
from torch.distributions import MultivariateNormal, Uniform, constraints
from mach3sbitools.utils import TorchDeviceHandler, get_logger
from ..simulator_injector import SimulatorProtocol
from .cyclical_distribution import CyclicalDistribution
from .dataclasses import PriorData
size_: TypeAlias = torch.Size | list[int] | tuple[int, ...]
class PriorNotFound(Exception):
    """Signals that a prior file is missing or could not be deserialised."""
logger = get_logger()
@dataclass(frozen=True)
class MaskDistributionMap:
    """
    Pairs a boolean parameter mask with the distribution that governs it.

    :param mask: Boolean tensor of shape ``(n_params,)``; ``True`` entries
        select the parameters handled by *distribution*.
    :param distribution: The :class:`torch.distributions.Distribution`
        responsible for the selected parameters.
    """

    mask: torch.Tensor
    distribution: torch.distributions.Distribution

    def to(self, device: torch.device) -> "MaskDistributionMap":
        """
        Return a copy whose mask lives on *device*.

        Only the mask tensor is transferred; the distribution's internal
        tensors are left untouched.

        :param device: Target PyTorch device.
        :returns: New :class:`MaskDistributionMap` with the moved mask.
        """
        moved_mask = self.mask.to(device)
        return MaskDistributionMap(moved_mask, self.distribution)
class Prior(torch.distributions.Distribution):
    """
    Composite MaCh3 prior combining cyclical, flat, and Gaussian components.

    Designed to replicate MaCh3's prior construction
    (https://github.com/mach3-software/MaCh3) and satisfy the ``sbi``
    :class:`~torch.distributions.Distribution` interface.

    Parameters are assigned to distributions in the following order:

    - **Cyclical** — matched by *cyclical_parameters* (fnmatch patterns).
    - **Flat** — flagged by *flat_msk* and not cyclical.
    - **Gaussian** — everything else.

    Nuisance parameters can be filtered at construction time or later via
    :meth:`set_nuisance_filter`, without modifying the underlying data.
    """

    # Boolean tensor over ALL parameters in the raw prior data;
    # True = parameter is active (not filtered out as nuisance).
    nuisance_filter: torch.Tensor

    def __init__(
        self,
        prior_data: PriorData,
        flat_msk: list[bool] | None = None,
        cyclical_parameters: list[str] | None = None,
        nuisance_parameters: list[str] | None = None,
    ):
        """
        Construct the composite prior.

        :param prior_data: Raw prior arrays (names, nominals, bounds, covariance).
        :param flat_msk: Per-parameter flat flags (index-aligned with
            *prior_data*). ``None`` means no parameter is flat.
        :param cyclical_parameters: fnmatch patterns selecting cyclical
            parameters. Matched parameters receive bounds of ``±2π``.
        :param nuisance_parameters: fnmatch patterns selecting parameters to
            exclude. Can be updated later with :meth:`set_nuisance_filter`.
        """
        self.device_handler = TorchDeviceHandler()
        self._prior_data = prior_data
        # Must run before any access to self.prior_data (the property
        # applies this filter on every read).
        self.set_nuisance_filter(nuisance_parameters)
        self._priors: list[MaskDistributionMap] = []
        n_active = len(self.prior_data.nominals)
        cyclical_mask = torch.zeros(n_active, dtype=torch.bool)
        if cyclical_parameters:
            cyclical_mask_ = [
                any(fnmatch.fnmatch(p, c) for c in cyclical_parameters)
                for p in self.prior_data.parameter_names
            ]
            cyclical_mask = self.device_handler.to_tensor(cyclical_mask_)
        if cyclical_mask.any():
            self._priors.append(self._get_cyclical_map(cyclical_mask))
        # Bug fix: flat_msk defaults to None, but was previously passed
        # straight to to_tensor(), which cannot handle None. Interpret
        # None as "no flat parameters".
        if flat_msk is None:
            flat_msk = [False] * n_active
        flat_mask = self.device_handler.to_tensor(flat_msk) & ~cyclical_mask
        if flat_mask.any():
            self._priors.append(self._get_flat_map(flat_mask))
        gaussian_mask = ~cyclical_mask & ~flat_mask
        if gaussian_mask.any():
            self._priors.append(self._get_gaussian_map(gaussian_mask))
        super().__init__(
            batch_shape=torch.Size(),
            event_shape=torch.Size([n_active]),
            validate_args=False,
        )

    # ── Private distribution builders ─────────────────────────────────────────
    def _get_cyclical_map(self, cyclical_mask: torch.Tensor) -> MaskDistributionMap:
        """Build the cyclical sub-distribution, forcing bounds to ``±2π``."""
        # NOTE(review): self.prior_data is a property that re-filters
        # self._prior_data on every access; whether these in-place bound
        # updates persist in the stored data depends on PriorData indexing
        # returning views rather than copies — TODO confirm.
        self.prior_data.lower_bounds[cyclical_mask] = -2 * torch.pi
        self.prior_data.upper_bounds[cyclical_mask] = 2 * torch.pi
        cyclical_data = self.prior_data[cyclical_mask]
        cyclical_dist = CyclicalDistribution(
            cyclical_data.nominals,
            cyclical_data.lower_bounds,
            cyclical_data.upper_bounds,
        )
        return MaskDistributionMap(cyclical_mask, cyclical_dist)

    def _get_flat_map(self, flat_mask: torch.Tensor) -> MaskDistributionMap:
        """Build a :class:`Uniform` sub-distribution over the flat parameters."""
        flat_data = self.prior_data[flat_mask]
        flat_dist = Uniform(flat_data.lower_bounds, flat_data.upper_bounds)
        return MaskDistributionMap(flat_mask, flat_dist)

    def _get_gaussian_map(self, gaussian_mask: torch.Tensor) -> MaskDistributionMap:
        """Build a :class:`MultivariateNormal` over the remaining parameters."""
        gaussian_data = self.prior_data[gaussian_mask]
        gaussian_dist = MultivariateNormal(
            gaussian_data.nominals, covariance_matrix=gaussian_data.covariance_matrix
        )
        return MaskDistributionMap(gaussian_mask, gaussian_dist)

    # ── Properties ────────────────────────────────────────────────────────────
    @property
    def prior_data(self) -> PriorData:
        """Active :class:`PriorData` after applying the nuisance filter."""
        return self._prior_data[self.nuisance_filter]

    @property
    def mean(self) -> torch.Tensor:
        """Prior mean — the nominal parameter values."""
        return self.device_handler.to_tensor(self.prior_data.nominals)

    @property
    def n_params(self) -> int:
        """Number of active (non-nuisance) parameters."""
        return len(self.prior_data.nominals)

    @property
    def variance(self) -> torch.Tensor:
        """Per-parameter prior variance, assembled from all sub-distributions."""
        variance = torch.zeros(len(self.prior_data.nominals))
        for mask_map in self._priors:
            variance[mask_map.mask] = mask_map.distribution.variance
        return variance

    @property
    def support(self):
        """Independent interval support defined by the parameter bounds."""
        return constraints.independent(
            constraints.interval(
                self.prior_data.lower_bounds, self.prior_data.upper_bounds
            ),
            1,
        )

    # ── Nuisance filtering ────────────────────────────────────────────────────
    def set_nuisance_filter(self, nuisance_patterns: list[str] | None = None) -> None:
        """
        Update the nuisance parameter filter.

        Parameters matching any of *nuisance_patterns* (fnmatch-style) are
        excluded from sampling and density evaluation. Call with ``None`` to
        reset and include all parameters.

        :param nuisance_patterns: fnmatch patterns (e.g. ``['syst_*']``), or
            ``None`` to clear the filter.
        """
        if nuisance_patterns is None:
            # No patterns: keep every parameter active.
            n_pars = len(self._prior_data.parameter_names)
            self.nuisance_filter = torch.ones(n_pars, dtype=torch.bool)
            return
        nuisance_filter_ = [
            not any(fnmatch.fnmatch(p, n) for n in nuisance_patterns)
            for p in self._prior_data.parameter_names
        ]
        self.nuisance_filter = self.device_handler.to_tensor(nuisance_filter_)

    # ── Distribution interface ────────────────────────────────────────────────
    def sample(self, sample_shape=torch.Size([])) -> torch.Tensor:
        """
        Draw samples from the composite prior.

        Each sub-distribution fills its own masked slice of the output.

        :param sample_shape: Batch shape, e.g. ``torch.Size([1000])``.
        :returns: Tensor of shape ``(*sample_shape, n_params)``.
        """
        sample_shape = torch.Size(sample_shape)
        samples = torch.empty(*sample_shape, self.n_params, dtype=torch.double)
        for mask_map in self._priors:
            samples[..., mask_map.mask] = mask_map.distribution.sample(sample_shape).to(
                torch.double
            )
        return samples

    def rsample(self, sample_shape=torch.Size([])) -> torch.Tensor:
        """
        Draw reparameterised samples (where supported by sub-distributions).

        :param sample_shape: Batch shape.
        :returns: Tensor of shape ``(*sample_shape, n_params)``.
        """
        sample_shape = torch.Size(sample_shape)
        samples = torch.empty(*sample_shape, self.n_params, dtype=torch.double)
        for mask_map in self._priors:
            samples[..., mask_map.mask] = mask_map.distribution.rsample(
                sample_shape
            ).to(torch.double)
        return samples

    def check_bounds(self, params: torch.Tensor) -> torch.Tensor:
        """
        Check whether each sample in *params* lies within the prior bounds.

        :param params: Tensor of shape ``(n_samples, n_params)``.
        :returns: Boolean tensor of shape ``(n_samples,)``.
        """
        in_bounds = (params >= self.prior_data.lower_bounds) & (
            params <= self.prior_data.upper_bounds
        )
        return self.device_handler.to_tensor(in_bounds.all(dim=-1))

    # ── Persistence ───────────────────────────────────────────────────────────
    def save(self, output_path: Path) -> None:
        """
        Pickle the prior to *output_path*.

        :param output_path: Destination file path. Parent directories are
            created automatically.
        """
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("wb") as f:
            pickle.dump(self, f)

    def to(self, device: torch.device) -> "Prior":
        """
        Move the prior data, sub-distribution masks, and nuisance filter to
        *device* in-place. Tensors inside the sub-distributions themselves
        are not moved (see :meth:`MaskDistributionMap.to`).

        :param device: Target PyTorch device.
        :returns: ``self``, for chaining.
        """
        self._prior_data = self._prior_data.to(device)
        for i, mask_map in enumerate(self._priors):
            self._priors[i] = mask_map.to(device)
        self.nuisance_filter = self.nuisance_filter.to(device)
        return self
# ---------------------------------------------------------------------------
# Module-level helpers
# ---------------------------------------------------------------------------
def _check_boundary(
    nominal: torch.Tensor,
    error: torch.Tensor,
    lower_bound: torch.Tensor,
    upper_bound: torch.Tensor,
    parameter_names: np.ndarray,
) -> None:
    """
    Warn if any parameter has bounds further than 10σ from its nominal.

    :param nominal: Nominal values, shape ``(n_params,)``.
    :param error: 1σ errors, shape ``(n_params,)``.
    :param lower_bound: Lower bounds, shape ``(n_params,)``.
    :param upper_bound: Upper bounds, shape ``(n_params,)``.
    :param parameter_names: Parameter name strings, shape ``(n_params,)``.
    """
    warning_thresh = 10
    warning_ub = nominal + error * warning_thresh
    warning_lb = nominal - error * warning_thresh
    # Flag parameters whose bounds extend beyond the ±10σ window.
    mask = (lower_bound < warning_lb) | (upper_bound > warning_ub)
    if not mask.any():
        return
    # Lazy %-style args: the message is only built if the logger emits it.
    logger.warning(
        "The following parameters have boundaries > %dσ from their prior nominal",
        warning_thresh,
    )
    for name, nom, err, lo, hi in zip(
        parameter_names[mask.cpu().numpy()],
        nominal[mask],
        error[mask],
        lower_bound[mask],
        upper_bound[mask],
    ):
        logger.warning(
            " '%s' | Nominal: %4f, Error %4f | Lower Bnd %4f, Upper Bnd %4f",
            name,
            nom,
            err,
            lo,
            hi,
        )
def create_prior(
    simulator_instance: SimulatorProtocol,
    nuisance_pars: list[str] | None = None,
    cyclical_pars: list[str] | None = None,
) -> Prior:
    """
    Convenience wrapper that builds a :class:`Prior` from a simulator.

    Pulls all parameter metadata from *simulator_instance*, warns about
    parameters whose bounds sit unusually far (>10σ) from their nominal,
    and assembles the composite prior.

    .. code-block:: python

        prior = create_prior(simulator, nuisance_pars=["syst_*"], cyclical_pars=["angle"])
        prior.save(Path("prior.pkl"))

    :param simulator_instance: An object implementing :class:`SimulatorProtocol`.
    :param nuisance_pars: fnmatch patterns for parameters to exclude from the
        prior (e.g. ``['syst_*']``).
    :param cyclical_pars: fnmatch patterns for parameters that should use a
        cyclical sinusoidal prior over ``[-2π, 2π]``.
    :returns: Configured :class:`Prior` ready for use with ``sbi``.
    """
    logger.info("Creating Prior")
    device_handler = TorchDeviceHandler()
    nominal_values = device_handler.to_tensor(
        simulator_instance.get_parameter_nominals()
    )
    sigma_values = device_handler.to_tensor(simulator_instance.get_parameter_errors())
    raw_lower, raw_upper = simulator_instance.get_parameter_bounds()
    lower_bounds = device_handler.to_tensor(raw_lower)
    upper_bounds = device_handler.to_tensor(raw_upper)
    parameter_names = np.array(simulator_instance.get_parameter_names(), dtype=str)
    _check_boundary(
        nominal_values, sigma_values, lower_bounds, upper_bounds, parameter_names
    )
    covariance = device_handler.to_tensor(simulator_instance.get_covariance_matrix())
    is_flat = [
        simulator_instance.get_is_flat(idx) for idx in range(len(parameter_names))
    ]
    prior_data = PriorData(
        parameter_names=parameter_names,
        nominals=nominal_values,
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        covariance_matrix=covariance,
    )
    return Prior(
        prior_data=prior_data,
        flat_msk=is_flat,
        nuisance_parameters=nuisance_pars,
        cyclical_parameters=cyclical_pars,
    )
def load_prior(prior_path: Path, device=torch.device("cpu")) -> Prior:
    """
    Load a pickled :class:`Prior` from disk.

    .. code-block:: python

        prior = load_prior(Path("prior.pkl"))

    :param prior_path: Path to a ``.pkl`` file produced by :meth:`Prior.save`.
    :param device: Device to move the prior to after loading. Defaults to CPU.
    :returns: The loaded :class:`Prior`.
    :raises PriorNotFound: If *prior_path* does not exist or does not contain
        a valid :class:`Prior`.
    """
    # is_file() already implies existence, so a separate exists() check is
    # redundant. f-strings fix the old %-style args, which exceptions never
    # interpolate into their message.
    if not prior_path.is_file():
        raise PriorNotFound(f"Could not find prior {prior_path}")
    # SECURITY: pickle.load can execute arbitrary code — only load prior
    # files from trusted sources.
    with prior_path.open("rb") as f:
        prior = pickle.load(f)
    if not isinstance(prior, Prior):
        raise PriorNotFound(
            f"No valid prior in {prior_path}. Instead found {type(prior)}"
        )
    return prior.to(device)