Source code for mach3sbitools.simulator.priors.prior

"""
Prior distribution for MaCh3 SBI.

Constructs a composite prior from three distribution types, checked in order:

1. **Cyclical** — parameters matching *cyclical_parameters* patterns, forced
   to bounds of ``[-2π, 2π]``.
2. **Flat (Uniform)** — parameters flagged via *flat_msk* and not cyclical.
3. **Gaussian** — all remaining parameters, modelled as a
   :class:`~torch.distributions.MultivariateNormal`.
"""

import fnmatch
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import TypeAlias

import numpy as np
import torch
from torch.distributions import MultivariateNormal, Uniform, constraints

from mach3sbitools.utils import TorchDeviceHandler, get_logger

from ..simulator_injector import SimulatorProtocol
from .cyclical_distribution import CyclicalDistribution
from .dataclasses import PriorData

# Type alias for shape-like arguments: a ``torch.Size`` or a plain list/tuple
# of ints. NOTE(review): not referenced in the code visible here — confirm it
# is still used elsewhere before relying on it.
size_: TypeAlias = torch.Size | list[int] | tuple[int, ...]


class PriorNotFound(Exception):
    """
    Raised when a prior file cannot be found or deserialised.

    Raised by :func:`load_prior` when the path is not a file or when the
    unpickled object is not a :class:`Prior`.
    """


# Module-level logger used by the prior construction helpers in this module.
logger = get_logger()


@dataclass(frozen=True)
class MaskDistributionMap:
    """
    Pair a boolean parameter mask with the distribution that governs it.

    :param mask: Boolean tensor of shape ``(n_params,)``; ``True`` entries mark
        the parameters handled by *distribution*.
    :param distribution: The :class:`torch.distributions.Distribution` used
        for the parameters selected by *mask*.
    """

    mask: torch.Tensor
    distribution: torch.distributions.Distribution

    def to(self, device: torch.device) -> "MaskDistributionMap":
        """
        Return a copy whose *mask* lives on *device*.

        Only the mask is transferred; the very same distribution object is
        reused, so any tensors inside it stay on their original device.

        :param device: Target PyTorch device.
        :returns: A new :class:`MaskDistributionMap` (``self`` is frozen and
            left untouched).
        """
        relocated_mask = self.mask.to(device)
        return MaskDistributionMap(relocated_mask, self.distribution)


class Prior(torch.distributions.Distribution):
    """
    Composite MaCh3 prior combining cyclical, flat, and Gaussian components.

    Designed to replicate MaCh3's prior construction
    (https://github.com/mach3-software/MaCh3) and satisfy the ``sbi``
    :class:`~torch.distributions.Distribution` interface.

    Parameters are assigned to distributions in the following order:

    - **Cyclical** — matched by *cyclical_parameters* (fnmatch patterns).
    - **Flat** — flagged by *flat_msk* and not cyclical.
    - **Gaussian** — everything else.

    Nuisance parameters can be filtered at construction time or later via
    :meth:`set_nuisance_filter`, without modifying the underlying data. The
    cyclical/flat assignment is stored for *all* parameters, and the
    sub-distributions (and ``event_shape``) are rebuilt for the active subset
    whenever the filter changes.
    """

    nuisance_filter: torch.Tensor

    def __init__(
        self,
        prior_data: PriorData,
        flat_msk: list[bool] | None = None,
        cyclical_parameters: list[str] | None = None,
        nuisance_parameters: list[str] | None = None,
    ):
        """
        Construct the composite prior.

        :param prior_data: Raw prior arrays (names, nominals, bounds,
            covariance). The stored lower/upper bounds of cyclical parameters
            are overwritten in place with ``-2π`` / ``2π``.
        :param flat_msk: Per-parameter flat flags, index-aligned with
            *prior_data* (all parameters). ``None`` marks no parameter as flat.
            For backward compatibility a mask aligned with the *active*
            (non-nuisance) parameters is also accepted.
        :param cyclical_parameters: fnmatch patterns selecting cyclical
            parameters. Matched parameters receive bounds of ``±2π``.
        :param nuisance_parameters: fnmatch patterns selecting parameters to
            exclude. Can be updated later with :meth:`set_nuisance_filter`.
        :raises ValueError: If *flat_msk* matches neither the total nor the
            active number of parameters.
        """
        self.device_handler = TorchDeviceHandler()
        self._prior_data = prior_data
        self._priors: list[MaskDistributionMap] = []
        # Full-length masks are filled in below. While they are ``None``,
        # set_nuisance_filter only records the filter and does not rebuild.
        self._cyclical_mask_all: torch.Tensor | None = None
        self._flat_mask_all: torch.Tensor | None = None
        self.set_nuisance_filter(nuisance_parameters)

        cyclical_mask = self._match_patterns(cyclical_parameters)
        # Cyclical takes precedence over flat.
        flat_mask = self._expand_flat_mask(flat_msk) & ~cyclical_mask
        self._force_cyclical_bounds(cyclical_mask)
        self._cyclical_mask_all = cyclical_mask
        self._flat_mask_all = flat_mask

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=torch.Size([self.n_params]),
            validate_args=False,
        )
        self._build_priors()

    # ── Private mask builders ─────────────────────────────────────────────────

    def _match_patterns(self, patterns: list[str] | None) -> torch.Tensor:
        """Full-length boolean mask of parameter names matching any pattern."""
        pattern_list = list(patterns or [])
        matches = [
            any(fnmatch.fnmatch(name, pattern) for pattern in pattern_list)
            for name in self._prior_data.parameter_names
        ]
        return self.device_handler.to_tensor(matches).to(torch.bool)

    def _expand_flat_mask(self, flat_msk: list[bool] | None) -> torch.Tensor:
        """Return *flat_msk* as a boolean mask over all parameters."""
        n_total = len(self._prior_data.parameter_names)
        if flat_msk is None:
            return self.device_handler.to_tensor([False] * n_total).to(torch.bool)
        flat = self.device_handler.to_tensor(flat_msk).to(torch.bool)
        if len(flat) == n_total:
            return flat
        active = self.nuisance_filter.to(flat.device)
        if len(flat) == int(active.sum()):
            # Legacy input aligned with the active parameters: scatter it back
            # into a full-length mask (nuisance entries are not flat).
            full = torch.zeros(n_total, dtype=torch.bool, device=flat.device)
            full[active] = flat
            return full
        raise ValueError(
            f"flat_msk has {len(flat)} entries but the prior has "
            f"{n_total} parameters ({int(active.sum())} active)"
        )

    def _force_cyclical_bounds(self, cyclical_mask: torch.Tensor) -> None:
        """Overwrite the stored bounds of cyclical parameters with ``∓2π``."""
        if not bool(cyclical_mask.any()):
            return
        # Write into the unfiltered tensors: ``self.prior_data`` is produced by
        # boolean-mask indexing (presumably a copy), so writing into it would
        # not persist.
        lower = self._prior_data.lower_bounds
        upper = self._prior_data.upper_bounds
        lower[cyclical_mask.to(lower.device)] = -2 * torch.pi
        upper[cyclical_mask.to(upper.device)] = 2 * torch.pi

    def _build_priors(self) -> None:
        """Rebuild sub-distributions and event shape for the active parameters."""
        cyclical_all = self._cyclical_mask_all
        flat_all = self._flat_mask_all
        cyclical_mask = cyclical_all[self.nuisance_filter.to(cyclical_all.device)]
        flat_mask = flat_all[self.nuisance_filter.to(flat_all.device)]
        gaussian_mask = ~cyclical_mask & ~flat_mask

        priors: list[MaskDistributionMap] = []
        if bool(cyclical_mask.any()):
            priors.append(self._get_cyclical_map(cyclical_mask))
        if bool(flat_mask.any()):
            priors.append(self._get_flat_map(flat_mask))
        if bool(gaussian_mask.any()):
            priors.append(self._get_gaussian_map(gaussian_mask))
        self._priors = priors
        # Keep the Distribution's event shape in sync with the active count.
        self._event_shape = torch.Size([self.n_params])

    def _mask_device(self) -> torch.device | None:
        """Device of the sub-distribution masks, or ``None`` if there are none."""
        return self._priors[0].mask.device if self._priors else None

    # ── Private distribution builders ─────────────────────────────────────────

    def _get_cyclical_map(self, cyclical_mask: torch.Tensor) -> MaskDistributionMap:
        # Bounds were already forced to ±2π by _force_cyclical_bounds.
        cyclical_data = self.prior_data[cyclical_mask]
        cyclical_dist = CyclicalDistribution(
            cyclical_data.nominals,
            cyclical_data.lower_bounds,
            cyclical_data.upper_bounds,
        )
        return MaskDistributionMap(cyclical_mask, cyclical_dist)

    def _get_flat_map(self, flat_mask: torch.Tensor) -> MaskDistributionMap:
        flat_data = self.prior_data[flat_mask]
        flat_dist = Uniform(flat_data.lower_bounds, flat_data.upper_bounds)
        return MaskDistributionMap(flat_mask, flat_dist)

    def _get_gaussian_map(self, gaussian_mask: torch.Tensor) -> MaskDistributionMap:
        gaussian_data = self.prior_data[gaussian_mask]
        gaussian_dist = MultivariateNormal(
            gaussian_data.nominals, covariance_matrix=gaussian_data.covariance_matrix
        )
        return MaskDistributionMap(gaussian_mask, gaussian_dist)

    # ── Properties ────────────────────────────────────────────────────────────

    @property
    def prior_data(self) -> PriorData:
        """Active :class:`PriorData` after applying the nuisance filter."""
        return self._prior_data[self.nuisance_filter]

    @property
    def mean(self) -> torch.Tensor:
        """Prior mean — the nominal parameter values."""
        return self.device_handler.to_tensor(self.prior_data.nominals)

    @property
    def n_params(self) -> int:
        """Number of active (non-nuisance) parameters."""
        return len(self.prior_data.nominals)

    @property
    def variance(self) -> torch.Tensor:
        """Per-parameter prior variance, assembled from all sub-distributions."""
        variance = torch.zeros(self.n_params, device=self._mask_device())
        for mask_map in self._priors:
            variance[mask_map.mask] = mask_map.distribution.variance.to(variance)
        return variance

    @property
    def support(self):
        """Independent interval support defined by the parameter bounds."""
        return constraints.independent(
            constraints.interval(
                self.prior_data.lower_bounds, self.prior_data.upper_bounds
            ),
            1,
        )

    # ── Nuisance filtering ────────────────────────────────────────────────────

    def set_nuisance_filter(self, nuisance_patterns: list[str] | None = None) -> None:
        """
        Update the nuisance parameter filter.

        Parameters matching any of *nuisance_patterns* (fnmatch-style) are
        excluded from sampling and density evaluation, and the
        sub-distributions are rebuilt for the remaining parameters. Call with
        ``None`` to reset and include all parameters.

        :param nuisance_patterns: fnmatch patterns (e.g. ``['syst_*']``), or
            ``None`` to clear the filter.
        """
        if nuisance_patterns is None:
            n_pars = len(self._prior_data.parameter_names)
            self.nuisance_filter = torch.ones(n_pars, dtype=torch.bool)
        else:
            self.nuisance_filter = ~self._match_patterns(nuisance_patterns)
        # Priors pickled before full-length masks were stored cannot be
        # rebuilt; they keep their existing sub-distributions.
        if (
            getattr(self, "_cyclical_mask_all", None) is not None
            and getattr(self, "_flat_mask_all", None) is not None
        ):
            self._build_priors()

    # ── Distribution interface ────────────────────────────────────────────────

    def _draw(self, sample_shape, reparameterised: bool) -> torch.Tensor:
        """Assemble one ``float64`` draw of shape ``(*sample_shape, n_params)``."""
        sample_shape = torch.Size(sample_shape)
        samples = torch.empty(
            *sample_shape,
            self.n_params,
            dtype=torch.double,
            device=self._mask_device(),
        )
        for mask_map in self._priors:
            dist = mask_map.distribution
            draw = (
                dist.rsample(sample_shape)
                if reparameterised
                else dist.sample(sample_shape)
            )
            # Sub-distributions are not moved by ``to()``. Bring their output
            # onto the device/dtype of the assembled tensor before scattering.
            samples[..., mask_map.mask] = draw.to(
                device=samples.device, dtype=torch.double
            )
        return samples

    def sample(self, sample_shape=torch.Size([])) -> torch.Tensor:
        """
        Draw samples from the composite prior.

        :param sample_shape: Batch shape, e.g. ``torch.Size([1000])``.
        :returns: ``float64`` tensor of shape ``(*sample_shape, n_params)``.
        """
        return self._draw(sample_shape, reparameterised=False)

    def rsample(self, sample_shape=torch.Size([])) -> torch.Tensor:
        """
        Draw reparameterised samples (where supported by sub-distributions).

        :param sample_shape: Batch shape.
        :returns: ``float64`` tensor of shape ``(*sample_shape, n_params)``.
        """
        return self._draw(sample_shape, reparameterised=True)

    def check_bounds(self, params: torch.Tensor) -> torch.Tensor:
        """
        Check whether each sample in *params* lies within the prior bounds.

        :param params: Tensor of shape ``(n_samples, n_params)``.
        :returns: Boolean tensor of shape ``(n_samples,)``.
        """
        in_bounds = (params >= self.prior_data.lower_bounds) & (
            params <= self.prior_data.upper_bounds
        )
        return self.device_handler.to_tensor(in_bounds.all(dim=-1))

    # ── Persistence ───────────────────────────────────────────────────────────

    def save(self, output_path: Path) -> None:
        """
        Pickle the prior to *output_path*.

        :param output_path: Destination file path (``str`` or ``Path``).
            Parent directories are created automatically.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("wb") as f:
            pickle.dump(self, f)

    def to(self, device: torch.device) -> "Prior":
        """
        Move the prior data and all masks to *device* in-place.

        Sub-distribution tensors are not moved. Sampling brings their output
        onto the mask device.

        :param device: Target PyTorch device.
        :returns: ``self``, for chaining.
        """
        self._prior_data = self._prior_data.to(device)
        self._priors = [mask_map.to(device) for mask_map in self._priors]
        self.nuisance_filter = self.nuisance_filter.to(device)
        # Guarded so priors pickled before these attributes existed still load.
        for attr in ("_cyclical_mask_all", "_flat_mask_all"):
            mask = getattr(self, attr, None)
            if mask is not None:
                setattr(self, attr, mask.to(device))
        return self
# --------------------------------------------------------------------------- # Module-level helpers # --------------------------------------------------------------------------- def _check_boundary( nominal: torch.Tensor, error: torch.Tensor, lower_bound: torch.Tensor, upper_bound: torch.Tensor, parameter_names: np.ndarray, ) -> None: """ Warn if any parameter has bounds further than 10σ from its nominal. :param nominal: Nominal values, shape ``(n_params,)``. :param error: 1σ errors, shape ``(n_params,)``. :param lower_bound: Lower bounds, shape ``(n_params,)``. :param upper_bound: Upper bounds, shape ``(n_params,)``. :param parameter_names: Parameter name strings, shape ``(n_params,)``. """ warning_thresh = 10 warning_ub = nominal + error * warning_thresh warning_lb = nominal - error * warning_thresh mask = (lower_bound < warning_lb) | (upper_bound > warning_ub) if not any(mask): return logger.warning( f"The following parameters have boundaries > {warning_thresh:d}σ from their prior nominal" ) for param_info in zip( parameter_names[mask.cpu().numpy()], nominal[mask], error[mask], lower_bound[mask], upper_bound[mask], ): logger.warning( " '{:s}' | Nominal: {:4f}, Error {:4f} | Lower Bnd {:4f}, Upper Bnd {:4f}".format( *param_info ) )
def create_prior(
    simulator_instance: SimulatorProtocol,
    nuisance_pars: list[str] | None = None,
    cyclical_pars: list[str] | None = None,
) -> Prior:
    """
    Build a :class:`Prior` directly from a simulator instance.

    Nominals, errors, bounds, names, the covariance matrix and the per-parameter
    flat flags are all read from *simulator_instance*. Parameters whose bounds
    extend more than 10σ from their nominal are reported as warnings before
    the prior is assembled.

    .. code-block:: console

        prior = create_prior(simulator, nuisance_pars=["syst_*"], cyclical_pars=["angle"])
        prior.save(Path("prior.pkl"))

    :param simulator_instance: An object implementing :class:`SimulatorProtocol`.
    :param nuisance_pars: fnmatch patterns for parameters to exclude from the
        prior (e.g. ``['syst_*']``).
    :param cyclical_pars: fnmatch patterns for parameters that should use a
        cyclical sinusoidal prior over ``[-2π, 2π]``.
    :returns: Configured :class:`Prior` ready for use with ``sbi``.
    """
    logger.info("Creating Prior")
    handler = TorchDeviceHandler()
    sim = simulator_instance

    nominal_values = handler.to_tensor(sim.get_parameter_nominals())
    sigma = handler.to_tensor(sim.get_parameter_errors())
    raw_lower, raw_upper = sim.get_parameter_bounds()
    lower_t = handler.to_tensor(raw_lower)
    upper_t = handler.to_tensor(raw_upper)
    param_names = np.array(sim.get_parameter_names(), dtype=str)

    # Report suspiciously wide bounds before anything else is built.
    _check_boundary(nominal_values, sigma, lower_t, upper_t, param_names)

    cov = handler.to_tensor(sim.get_covariance_matrix())
    is_flat = list(map(sim.get_is_flat, range(len(param_names))))

    prior_data = PriorData(
        parameter_names=param_names,
        nominals=nominal_values,
        lower_bounds=lower_t,
        upper_bounds=upper_t,
        covariance_matrix=cov,
    )
    return Prior(
        prior_data=prior_data,
        flat_msk=is_flat,
        nuisance_parameters=nuisance_pars,
        cyclical_parameters=cyclical_pars,
    )
def load_prior(prior_path: Path, device=torch.device("cpu")) -> Prior:
    """
    Load a pickled :class:`Prior` from disk.

    .. code-block:: console

        prior = load_prior(Path("prior.pkl"))

    .. warning::
        Unpickling can execute arbitrary code. Only load prior files from
        trusted sources.

    :param prior_path: Path (``Path`` or ``str``) to a ``.pkl`` file produced
        by :meth:`Prior.save`.
    :param device: Device to move the prior to after loading. Defaults to CPU.
    :returns: The loaded :class:`Prior`.
    :raises PriorNotFound: If *prior_path* does not exist, cannot be
        unpickled, or does not contain a valid :class:`Prior`.
    """
    prior_path = Path(prior_path)
    # ``is_file()`` is False for missing paths, so no separate exists() check.
    if not prior_path.is_file():
        raise PriorNotFound(f"Could not find prior {prior_path}")

    try:
        with prior_path.open("rb") as f:
            prior = pickle.load(f)
    except (pickle.UnpicklingError, EOFError) as err:
        raise PriorNotFound(
            f"Could not deserialise prior {prior_path}: {err}"
        ) from err

    if not isinstance(prior, Prior):
        raise PriorNotFound(
            f"No valid prior in {prior_path}. Instead found {type(prior)}"
        )
    return prior.to(device)