Source code for moscot.costs._utils

import collections
from typing import Any, Dict, List, Optional, Tuple

from moscot import _registry

__all__ = ["get_cost", "get_available_costs", "register_cost"]


_REGISTRY = _registry.Registry()
_SEP = "-"


[docs] def get_cost(name: str, *, backend: str = "moscot", **kwargs: Any) -> Any: """Get cost function for a specific backend.""" key = f"{backend}{_SEP}{name}" if key not in _REGISTRY: raise ValueError(f"Cost `{name!r}` is not available for backend `{backend!r}`.") return _REGISTRY[key](**kwargs)
[docs] def get_available_costs(backend: Optional[str] = None) -> Dict[str, Tuple[str, ...]]: """Return available costs. Parameters ---------- backend Select cost specific to a backend. If :obj:`None`, return the costs for each backend. Returns ------- Dictionary with keys as backend names and values as registered cost functions. """ groups: Dict[str, List[str]] = _get_available_backends_and_costs() if backend is None: return {k: tuple(v) for k, v in groups.items()} if backend not in groups: raise KeyError(f"No backend named `{backend!r}`.") return {backend: tuple(groups[backend])}
[docs] def register_cost(name: str, *, backend: str) -> Any: """Register cost function for a specific backend.""" return _REGISTRY.register(f"{backend}{_SEP}{name}")
def _get_available_backends_and_costs(): """Return a dictionary of available backends with their corresponding list of costs. Returns ------- Default dictionary with keys as backend names and values as registered cost functions. """ groups: Dict[str, List[str]] = collections.defaultdict(list) for key in _REGISTRY: back, *name = key.split(_SEP) groups[back].append(_SEP.join(name)) return groups