# Source code for mach3sbitools.utils.device_handler

"""
PyTorch device detection and tensor conversion utilities.
"""

import numpy as np
import pandas as pd
import torch


class TensorConversionError(Exception):
    """
    Error signalling that an input could not be turned into a :class:`torch.Tensor`.

    Raised by :meth:`TorchDeviceHandler.to_tensor` when conversion fails; the
    original exception is attached as ``__cause__``.
    """

    pass


class TorchDeviceHandler:
    """
    Detects the best available PyTorch device and provides tensor conversion.

    The device is detected once at construction time and cached.
    """

    def __init__(self):
        self._device: str = self._find_device()

    @property
    def device(self) -> str:
        """The detected device string, e.g. ``"cuda"`` or ``"cpu"``."""
        return self._device

    @staticmethod
    def _find_device() -> str:
        """Return ``"cuda"`` if available, otherwise ``"cpu"``."""
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def to_tensor(self, data) -> torch.Tensor:
        """
        Convert an array-like object to a :class:`torch.Tensor` on the active device.

        Handles :class:`~pandas.DataFrame`, :class:`~numpy.ndarray`,
        :class:`torch.Tensor`, and any object accepted by :func:`torch.tensor`.

        DataFrames and ndarrays are always cast to ``float32``. Tensors keep their
        own dtype, and any other input keeps the dtype :func:`torch.tensor`
        infers (e.g. a list of Python ints becomes ``int64``). The result never
        shares memory with ``data`` and is detached from any autograd graph.

        :param data: Input data to convert.
        :returns: Tensor on :attr:`device`.
        :raises TensorConversionError: If conversion fails, including when a
            DataFrame or ndarray cannot be cast to ``float32``.
        """
        # Every conversion path sits inside the try so that *all* failures
        # surface as TensorConversionError, as documented.
        try:
            if isinstance(data, pd.DataFrame):
                return self._from_numpy(data.values)
            if isinstance(data, np.ndarray):
                return self._from_numpy(data)
            if isinstance(data, torch.Tensor):
                # torch.tensor(tensor) works but emits a UserWarning recommending
                # clone().detach(); this is the equivalent detached copy.
                return data.detach().to(device=self.device, copy=True)
            return torch.tensor(data, device=self.device)
        except Exception as e:
            raise TensorConversionError(
                f"Cannot convert object of type {type(data)} to torch tensor"
            ) from e

    def _from_numpy(self, array: np.ndarray) -> torch.Tensor:
        """Copy *array* into a ``float32`` tensor on :attr:`device`."""
        # astype always returns a fresh C-contiguous copy, so wrapping it with
        # the zero-copy from_numpy avoids the second copy torch.tensor would
        # make, without aliasing the caller's array.
        return torch.from_numpy(array.astype(np.float32, order="C")).to(self.device)