Source code for moscot.backends.utils

from typing import TYPE_CHECKING, Any, Callable, Literal, Tuple, Union

from moscot import _registry
from moscot._types import ProblemKind_t

if TYPE_CHECKING:
    from moscot.backends import ott

# Names exported by ``from moscot.backends.utils import *``.
__all__ = ["get_solver", "register_solver", "get_available_backends"]

# Signature of a solver factory: it receives the problem kind and returns a
# solver type. The ``ott`` names are string forward references because the
# ``ott`` module is only imported under ``TYPE_CHECKING`` at module level.
register_solver_t = Callable[
    [Literal["linear", "quadratic"]],
    Union["ott.SinkhornSolver", "ott.GWSolver"],
]


# Module-wide registry; it is indexed by backend name and each entry is called
# with the problem kind to obtain a solver class.
_REGISTRY = _registry.Registry()


def get_solver(
    problem_kind: ProblemKind_t,
    *,
    backend: str = "ott",
    return_class: bool = False,
    **kwargs: Any,
) -> Any:
    """Look up the solver for a problem kind in the requested backend.

    Parameters
    ----------
    problem_kind
        Kind of the problem; forwarded unchanged to the factory registered for ``backend``.
    backend
        Name of a registered backend.
    return_class
        If `True`, return the solver class itself instead of an instance.
    kwargs
        Keyword arguments used to instantiate the solver class.
        Ignored when ``return_class`` is `True`.

    Returns
    -------
    The solver class if ``return_class`` is `True`, otherwise the result of
    calling that class with ``kwargs``.

    Raises
    ------
    ValueError
        If ``backend`` is not present in the registry.
    """
    if backend not in _REGISTRY:
        raise ValueError(f"Backend `{backend!r}` is not available.")

    # The registry entry is a factory: calling it with the problem kind yields the solver class.
    factory = _REGISTRY[backend]
    solver_cls = factory(problem_kind)

    if return_class:
        return solver_cls
    return solver_cls(**kwargs)


def register_solver(
    backend: str,
) -> Union["ott.SinkhornSolver", "ott.GWSolver"]:
    """Obtain the registration hook for a backend from the module registry.

    Parameters
    ----------
    backend
        Name of the backend.

    Returns
    -------
    The value of ``_REGISTRY.register(backend)``. In this module it is applied as a
    decorator to a function that maps a problem kind to a solver type.
    """
    # NOTE(review): the return annotation names a solver type, while the returned object
    # is used as a decorator; the ``type: ignore`` below silences that mismatch - confirm
    # against ``_registry.Registry.register`` before tightening the annotation.
    decorator = _REGISTRY.register(backend)
    return decorator  # type: ignore[return-value]


@register_solver("ott")
def _(
    problem_kind: Literal["linear", "quadratic"],
) -> Union["ott.SinkhornSolver", "ott.GWSolver"]:
    """Return the ``ott`` solver class matching ``problem_kind``.

    ``"linear"`` maps to ``ott.SinkhornSolver`` and ``"quadratic"`` to ``ott.GWSolver``;
    any other value raises :class:`NotImplementedError`.
    """
    # Function-scope import: at module level ``ott`` is only imported under TYPE_CHECKING.
    from moscot.backends import ott

    if problem_kind == "linear":
        solver_cls = ott.SinkhornSolver
    elif problem_kind == "quadratic":
        solver_cls = ott.GWSolver
    else:
        raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}`.")
    return solver_cls  # type: ignore[return-value]


def get_available_backends() -> Tuple[str, ...]:
    """Return all available backends.

    Returns
    -------
    A tuple holding every item produced by iterating over the module registry,
    in iteration order.
    """
    return tuple(_REGISTRY)