# Source code for mach3sbitools.simulator.simulator_injector

"""
Simulator injection utilities.

Simulators are expected to follow the :class:`SimulatorProtocol` contract and
be configurable via an input file (e.g. a MaCh3 fitter YAML). This module
handles dynamic import, protocol validation, and instantiation.
"""

import importlib
import inspect
import pkgutil
from collections.abc import Callable
from difflib import get_close_matches
from importlib.util import find_spec
from pathlib import Path
from typing import Protocol, cast, runtime_checkable

import numpy as np

from mach3sbitools.types import BoundaryConditions
from mach3sbitools.utils.logger import get_logger

# Module-level logger; get_simulator() reports each successful loading step through it.
logger = get_logger()


# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------


class SimulatorException(Exception):
    """Common ancestor of the simulator-loading exceptions defined in this module."""


class SimulatorImportError(SimulatorException):
    """Signals that the requested simulator module, or the class inside it, could not be imported."""


class SimulatorImplementationError(SimulatorException):
    """Signals that a simulator class fails the :class:`SimulatorProtocol` interface check."""


class SimulatorSetupError(SimulatorException):
    """Signals that the simulator configuration file does not exist on disk."""


# ---------------------------------------------------------------------------
# Protocol
# ---------------------------------------------------------------------------


@runtime_checkable
class SimulatorProtocol(Protocol):
    """
    Interface that every simulator must implement.

    Simulators are configured via a single file path passed to ``__init__``.
    For MaCh3 this is the fitter YAML config. All parameter-level methods
    operate over the full (un-filtered) parameter vector.

    Because the protocol is :func:`~typing.runtime_checkable`,
    :func:`issubclass` / :func:`isinstance` checks against it only confirm
    that the methods below are present. Signatures and return types are not
    verified, and ``__init__`` is not part of the runtime check.
    """

    def __init__(self, simulator_config: Path | str) -> None:
        """
        Initialise and configure the simulator from a file.

        :param simulator_config: Path to the simulator configuration file.
        """
        ...

    def simulate(self, theta: list[float]) -> list[float]:
        """
        Run a single forward simulation.

        :param theta: Input parameter vector.
        :returns: Predicted observable vector *x*.
        """
        ...

    def get_parameter_names(self) -> list[str]:
        """
        Return the name of each parameter in *theta*.

        :returns: Ordered list of parameter name strings.
        """
        ...

    def get_parameter_bounds(self) -> BoundaryConditions:
        """
        Return hard lower and upper bounds for each parameter.

        :returns: Tuple of ``(lower_bounds, upper_bounds)``, each a list of
            floats with one entry per parameter.
        """
        ...

    def get_is_flat(self, i: int) -> bool:
        """
        Return whether parameter *i* should use a flat (uniform) prior.

        :param i: Zero-based parameter index.
        :returns: ``True`` if the parameter is flat, ``False`` for Gaussian.
        """
        ...

    def get_data_bins(self) -> list[float]:
        """
        Return the observed data bin values *x_o*.

        :returns: Observed data vector.
        """
        ...

    def get_parameter_nominals(self) -> list[float]:
        """
        Return the nominal (mean) value for each parameter.

        :returns: Ordered list of nominal values.
        """
        ...

    def get_parameter_errors(self) -> list[float]:
        """
        Return the 1σ error for each parameter.

        :returns: Ordered list of parameter errors.
        """
        ...

    def get_covariance_matrix(self) -> np.ndarray:
        """
        Return the full parameter covariance matrix.

        :returns: Square numpy array of shape ``(n_params, n_params)``.
        """
        ...
def _implements(proto: type) -> Callable[[type], type]: """ Class decorator that asserts the decorated class satisfies *proto* at decoration time. :param proto: A :func:`runtime_checkable` Protocol class. :returns: Decorator that returns the class unchanged or raises :exc:`SimulatorImplementationError`. """ def _deco(cls_def): if issubclass(cls_def, proto): return cls_def raise SimulatorImplementationError( f"{cls_def} does not implement protocol {proto}. " f"Please see {__file__} for the required interface." ) return _deco # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _closest_match(name: str, candidates: list[str]) -> str | None: """ Return the closest fuzzy match for *name* from *candidates*, or ``None``. :param name: The name to search for. :param candidates: List of candidate strings. :returns: Best match string, or ``None`` if no match above threshold. """ matches = get_close_matches(name, candidates, n=1, cutoff=0.6) return matches[0] if matches else None def _hint(name: str, candidates: list[str]) -> str: """ Build a "did you mean?" hint string for error messages. :param name: The name that was not found. :param candidates: List of valid names to search. :returns: A hint string, or an empty string if no close match exists. """ match = _closest_match(name, candidates) return f" Did you mean: {match}?" if match else "" # --------------------------------------------------------------------------- # Loader # ---------------------------------------------------------------------------
def get_simulator(
    module_name: str, class_name: str, config: Path | str
) -> SimulatorProtocol:
    """
    Dynamically import, validate, and instantiate a simulator.

    The class is checked against :class:`SimulatorProtocol` before
    instantiation. Equivalent to::

        from <module_name> import <class_name>
        return <class_name>(str(config))

    .. code-block:: python

        # Example — loading a MaCh3 simulator
        get_simulator("mypackage.simulator", "MySimulator", Path("fitter.yaml"))

    :param module_name: Dotted Python module path (e.g. ``'mypackage.simulator'``).
    :param class_name: Name of the simulator class within the module.
    :param config: Path to the simulator configuration file. A plain string
        is also accepted and converted to :class:`~pathlib.Path`.
    :returns: An instantiated, protocol-validated simulator object.
    :raises SimulatorImportError: If the module or class cannot be found, or
        the module exists but raises :exc:`ImportError` while being imported.
    :raises SimulatorImplementationError: If the class does not satisfy
        :class:`SimulatorProtocol`.
    :raises SimulatorSetupError: If *config* does not exist on disk.
    """
    # find_spec() imports the parent packages of a dotted name. A missing parent
    # therefore raises ModuleNotFoundError rather than returning None. Relative
    # names without a package, or a module whose __spec__ is None, raise
    # ImportError/ValueError. All of these mean "not found" here.
    try:
        spec = find_spec(module_name)
    except (ImportError, ValueError):
        spec = None
    if spec is None:
        installed = [m.name for m in pkgutil.iter_modules()]
        raise SimulatorImportError(
            f"Module '{module_name}' not found.{_hint(module_name, installed)}"
        )

    try:
        module = importlib.import_module(module_name)
    except ImportError as err:
        # The module exists but one of its own imports failed (e.g. a missing
        # dependency). Surface it through the documented exception type while
        # keeping the original cause chained.
        raise SimulatorImportError(
            f"Module '{module_name}' was found but could not be imported: {err}"
        ) from err
    logger.info("Found simulator '%s'", module_name)

    if not hasattr(module, class_name):
        all_classes = [n for n, _ in inspect.getmembers(module, inspect.isclass)]
        raise SimulatorImportError(
            f"Class '{class_name}' not found in '{module_name}'."
            f"{_hint(class_name, all_classes)}"
        )

    simulator_cls = getattr(module, class_name)
    simulator_cls = _implements(SimulatorProtocol)(simulator_cls)
    logger.info("Imported simulator '%s' from '%s'", class_name, module_name)

    # Accept str as well as Path: SimulatorProtocol.__init__ takes either, and
    # a bare str would otherwise fail on .exists().
    config_path = Path(config)
    if not config_path.exists():
        raise SimulatorSetupError(f"Config file not found: {config_path}")
    logger.info("Found simulator config '%s'", config_path)

    return cast(SimulatorProtocol, simulator_cls(str(config_path)))