# Source code for mach3sbitools.simulator.simulator_injector
"""
Simulator injection utilities.
Simulators are expected to follow the :class:`SimulatorProtocol` contract and
be configurable via an input file (e.g. a MaCh3 fitter YAML). This module
handles dynamic import, protocol validation, and instantiation.
"""
import importlib
import inspect
import pkgutil
from collections.abc import Callable
from difflib import get_close_matches
from importlib.util import find_spec
from pathlib import Path
from typing import Protocol, cast, runtime_checkable
import numpy as np
from mach3sbitools.types import BoundaryConditions
from mach3sbitools.utils.logger import get_logger
logger = get_logger()
# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------
class SimulatorException(Exception):
    """
    Common ancestor of every simulator-related error in this module.

    Catch this type to handle any simulator failure in one place.
    """
class SimulatorImportError(SimulatorException):
    """
    Signals that the requested simulator module or class could not be imported.
    """
class SimulatorImplementationError(SimulatorException):
    """
    Signals that a simulator class fails to implement :class:`SimulatorProtocol`.
    """
class SimulatorSetupError(SimulatorException):
    """
    Signals that the simulator configuration file could not be found.
    """
# ---------------------------------------------------------------------------
# Protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class SimulatorProtocol(Protocol):
    """
    Structural contract that every injectable simulator must satisfy.

    A simulator is built from a single configuration file path handed to
    ``__init__`` (for MaCh3 this is the fitter YAML config). Every
    parameter-level method works on the complete, un-filtered parameter
    vector.
    """

    def __init__(self, simulator_config: Path | str) -> None:
        """
        Configure the simulator from a file.

        :param simulator_config: Location of the simulator configuration file.
        """
        ...

    def simulate(self, theta: list[float]) -> list[float]:
        """
        Perform one forward simulation.

        :param theta: Parameter vector to simulate at.
        :returns: The predicted observable vector *x*.
        """
        ...

    def get_parameter_names(self) -> list[str]:
        """
        Give the names of the parameters that make up *theta*.

        :returns: Parameter names, in parameter order.
        """
        ...

    def get_parameter_bounds(self) -> BoundaryConditions:
        """
        Give the hard lower and upper limits of every parameter.

        :returns: Tuple of ``(lower_bounds, upper_bounds)``; each is a list
            of floats holding one entry per parameter.
        """
        ...

    def get_is_flat(self, i: int) -> bool:
        """
        Tell whether parameter *i* takes a flat (uniform) prior.

        :param i: Zero-based index of the parameter.
        :returns: ``True`` for a flat prior, ``False`` for a Gaussian one.
        """
        ...

    def get_data_bins(self) -> list[float]:
        """
        Give the observed data bin values *x_o*.

        :returns: The observed data vector.
        """
        ...

    def get_parameter_nominals(self) -> list[float]:
        """
        Give the nominal (mean) value of every parameter.

        :returns: Nominal values, in parameter order.
        """
        ...

    def get_parameter_errors(self) -> list[float]:
        """
        Give the 1σ error of every parameter.

        :returns: Parameter errors, in parameter order.
        """
        ...

    def get_covariance_matrix(self) -> np.ndarray:
        """
        Give the full covariance matrix of the parameters.

        :returns: Square numpy array with shape ``(n_params, n_params)``.
        """
        ...
def _implements(proto: type) -> Callable[[type], type]:
    """
    Build a class decorator that verifies the decorated class satisfies
    *proto* at decoration time.

    :param proto: A :func:`runtime_checkable` Protocol class.
    :returns: Decorator that hands the class back unchanged when it conforms.
    :raises SimulatorImplementationError: From the returned decorator, when
        the decorated object is not a class or does not satisfy *proto*.
    """

    def _deco(cls_def: type) -> type:
        # issubclass() raises a bare TypeError when its first argument is not
        # a class (e.g. a function or constant that happens to carry the
        # requested name). Guard with isinstance so callers always receive
        # the documented SimulatorImplementationError instead.
        if isinstance(cls_def, type) and issubclass(cls_def, proto):
            return cls_def
        raise SimulatorImplementationError(
            f"{cls_def} does not implement protocol {proto}. "
            f"Please see {__file__} for the required interface."
        )

    return _deco
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _closest_match(name: str, candidates: list[str]) -> str | None:
"""
Return the closest fuzzy match for *name* from *candidates*, or ``None``.
:param name: The name to search for.
:param candidates: List of candidate strings.
:returns: Best match string, or ``None`` if no match above threshold.
"""
matches = get_close_matches(name, candidates, n=1, cutoff=0.6)
return matches[0] if matches else None
def _hint(name: str, candidates: list[str]) -> str:
    """
    Produce a "did you mean?" suffix for an error message.

    :param name: The name that could not be found.
    :param candidates: Valid names to search for a near miss.
    :returns: The suffix text (with a leading space), or an empty string
        when no candidate is close enough.
    """
    suggestion = _closest_match(name, candidates)
    if not suggestion:
        return ""
    return f" Did you mean: {suggestion}?"
# ---------------------------------------------------------------------------
# Loader
# ---------------------------------------------------------------------------
def get_simulator(
    module_name: str, class_name: str, config: Path | str
) -> SimulatorProtocol:
    """
    Dynamically import, validate, and instantiate a simulator.

    The class is checked against :class:`SimulatorProtocol` before
    instantiation. Equivalent to::

        from <module_name> import <class_name>
        return class_name(config)

    .. code-block:: python

        # Example — loading a MaCh3 simulator
        get_simulator("mypackage.simulator", "MySimulator", Path("fitter.yaml"))

    :param module_name: Dotted Python module path (e.g. ``'mypackage.simulator'``).
    :param class_name: Name of the simulator class within the module.
    :param config: Path to the simulator configuration file; a plain string
        is accepted and converted to a :class:`~pathlib.Path`.
    :returns: An instantiated, protocol-validated simulator object.
    :raises SimulatorImportError: If the module or class cannot be found.
    :raises SimulatorImplementationError: If the class does not satisfy
        :class:`SimulatorProtocol`.
    :raises SimulatorSetupError: If *config* does not exist on disk.
    """
    # find_spec() imports the parent packages of a dotted name and raises
    # ModuleNotFoundError when one of them is missing (ValueError for
    # malformed names) rather than returning None. Treat those cases as
    # "not found" too, keeping the underlying error as the cause.
    lookup_error: Exception | None = None
    try:
        spec = find_spec(module_name)
    except (ImportError, ValueError) as err:
        spec, lookup_error = None, err
    if spec is None:
        installed = [m.name for m in pkgutil.iter_modules()]
        raise SimulatorImportError(
            f"Module '{module_name}' not found.{_hint(module_name, installed)}"
        ) from lookup_error

    module = importlib.import_module(module_name)
    logger.info("Found simulator '%s'", module_name)

    if not hasattr(module, class_name):
        all_classes = [n for n, _ in inspect.getmembers(module, inspect.isclass)]
        raise SimulatorImportError(
            f"Class '{class_name}' not found in '{module_name}'."
            f"{_hint(class_name, all_classes)}"
        )

    # Validate against the protocol before touching the config so that an
    # interface problem is reported even when the config path is also wrong.
    simulator_cls = _implements(SimulatorProtocol)(getattr(module, class_name))
    logger.info("Imported simulator '%s' from '%s'", class_name, module_name)

    # Callers may hand in a plain string; normalise so .exists() is available.
    config_path = Path(config)
    if not config_path.exists():
        raise SimulatorSetupError(f"Config file not found: {config_path}")
    logger.info("Found simulator config '%s'", config_path)

    # The simulator receives the path as a string.
    return cast(SimulatorProtocol, simulator_cls(str(config_path)))