# Source code for moscot.utils.tagged_array

import enum
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Tuple, Union

import numpy as np
import scipy.sparse as sp

from anndata import AnnData

from moscot._logging import logger
from moscot._types import ArrayLike, CostFn_t
from moscot.costs import get_cost

__all__ = ["Tag", "TaggedArray"]


@enum.unique
class Tag(str, enum.Enum):
    """Tag in the :class:`~moscot.utils.tagged_array.TaggedArray`.

    Each member is also a :class:`str`, so it compares equal to its plain string value
    (e.g. ``Tag.KERNEL == "kernel"``).
    """

    # Cost matrix.
    COST_MATRIX = "cost_matrix"
    # Kernel matrix.
    KERNEL = "kernel"
    # Point cloud.
    POINT_CLOUD = "point_cloud"
    # Graph distances, means [n+m, n+m] transport matrix.
    GRAPH = "graph"
@dataclass(frozen=False, repr=True)
class TaggedArray:
    """Interface to interpret array-like data for :mod:`moscot.solvers`.

    It is used to extract array-like data stored in :class:`~anndata.AnnData` and interpret it as either
    :attr:`cost matrix <is_cost_matrix>`, :attr:`kernel matrix <is_kernel>` or
    a :attr:`point cloud <is_point_cloud>`, depending on the :attr:`tag`.

    Parameters
    ----------
    data_src
        Source data.
    data_tgt
        Target data.
    tag
        How to interpret :attr:`data_src` and :attr:`data_tgt`.
    cost
        Cost function when ``tag = 'point_cloud'``.
    """

    data_src: ArrayLike  #: Source data.
    data_tgt: Optional[ArrayLike] = None  #: Target data.
    tag: Tag = Tag.POINT_CLOUD  #: How to interpret :attr:`data_src` and :attr:`data_tgt`.
    cost: Optional[Union[str, Callable[..., Any]]] = None  #: Cost function when ``tag = 'point_cloud'``.

    @staticmethod
    def _extract_data(
        adata: AnnData,
        *,
        attr: Literal["X", "obsp", "obsm", "layers", "uns"],
        key: Optional[str] = None,
    ) -> ArrayLike:
        """Fetch ``adata.<attr>`` (or ``adata.<attr>[key]``) as a dense 2-dimensional array.

        Parameters
        ----------
        adata
            Annotated data object.
        attr
            Attribute of ``adata`` to read.
        key
            Optional key used to index into ``adata.<attr>``.

        Returns
        -------
        The extracted data; sparse inputs are densified.

        Raises
        ------
        KeyError
            If indexing ``adata.<attr>`` with ``key`` raises :class:`KeyError`.
        IndexError
            If indexing ``adata.<attr>`` with ``key`` raises :class:`IndexError`.
        ValueError
            If the extracted data is not 2-dimensional.
        """
        modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]"
        data = getattr(adata, attr)

        try:
            if key is not None:
                data = data[key]
        except KeyError:
            raise KeyError(f"Unable to fetch data from `{modifier}`.") from None
        except IndexError:
            raise IndexError(f"Unable to fetch data from `{modifier}`.") from None

        if sp.issparse(data):
            logger.warning(f"Densifying data in `{modifier}`")
            # `.toarray()` works for both `scipy.sparse` matrices and arrays; the `.A` shortcut
            # was removed in SciPy 1.14 and never existed on sparse arrays.
            data = data.toarray()
        if data.ndim != 2:
            raise ValueError(f"Expected `{modifier}` to have `2` dimensions, found `{data.ndim}`.")

        return data

    def _set_cost(
        self,
        cost: CostFn_t = "sq_euclidean",
        backend: Literal["ott"] = "ott",
        **kwargs: Any,
    ) -> "TaggedArray":
        """Set :attr:`cost` in place and return ``self``.

        ``'geodesic'`` is stored as the plain string; any other name is resolved via
        ``get_cost(cost, backend=backend, **kwargs)``.

        Raises
        ------
        ValueError
            If ``cost = 'custom'``; custom costs are handled in :meth:`from_adata`.
        """
        if cost == "custom":
            raise ValueError("Custom cost functions are handled in `TaggedArray.from_adata`.")
        if cost != "geodesic":
            cost = get_cost(cost, backend=backend, **kwargs)
        self.cost = cost
        return self

    @classmethod
    def from_adata(
        cls,
        adata: AnnData,
        dist_key: Union[Any, Tuple[Any, Any]],
        attr: Literal["X", "obsp", "obsm", "layers", "uns"],
        tag: Tag = Tag.POINT_CLOUD,
        key: Optional[str] = None,
        cost: CostFn_t = "sq_euclidean",
        backend: Literal["ott"] = "ott",
        **kwargs: Any,
    ) -> "TaggedArray":
        """Create tagged array from :class:`~anndata.AnnData`.

        .. warning::
            Sparse arrays will be always densified.

        Parameters
        ----------
        adata
            Annotated data object.
        dist_key
            Key which determines into which source/target subset ``adata`` belongs.
        attr
            Attribute of :class:`~anndata.AnnData` used when extracting/computing the cost.
        tag
            Tag used to interpret the extracted data.
        key
            Key in the ``attr`` of :class:`~anndata.AnnData` used when extracting/computing the cost.
            For ``tag = 'graph'``, the data is read from ``adata.<attr>['<dist_key>_<key>']``, where a tuple
            ``dist_key`` is first joined as ``'<dist_key[0]>_<dist_key[1]>'``.
        cost
            Cost function to apply to the extracted array, depending on ``tag``:

            - if ``tag = 'point_cloud'`` or ``tag = 'kernel'``, it is extracted from the ``backend``.
            - if ``tag = 'graph'`` the ``cost`` has to be ``'geodesic'``.
            - if ``tag = 'cost_matrix'`` and ``cost = 'custom'``, the extracted array is already assumed to be
              a cost matrix and must not contain negative values. Otherwise,
              :class:`~moscot.base.cost.BaseCost` is used to compute the cost matrix (always with the
              ``'moscot'`` backend, regardless of ``backend``).
        backend
            Which backend to use, see :func:`~moscot.backends.utils.get_available_backends`.
        kwargs
            Keyword arguments for the :class:`~moscot.base.cost.BaseCost` or any backend-specific cost.

        Returns
        -------
        The tagged array.

        Raises
        ------
        ValueError
            If ``tag = 'graph'`` and ``cost != 'geodesic'``, if a custom cost matrix contains negative values,
            or if the extracted data is not 2-dimensional.
        KeyError
            If the requested data cannot be found in ``adata``.
        """
        if tag == Tag.GRAPH:
            if cost == "geodesic":
                dist_key = f"{dist_key[0]}_{dist_key[1]}" if isinstance(dist_key, tuple) else dist_key
                data = cls._extract_data(adata, attr=attr, key=f"{dist_key}_{key}")
                return cls(data_src=data, tag=Tag.GRAPH, cost="geodesic")
            raise ValueError(f"Expected `cost=geodesic`, found `{cost}`.")

        if tag == Tag.COST_MATRIX:
            if cost == "custom":  # our custom cost functions
                modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]"
                data = cls._extract_data(adata, attr=attr, key=key)
                if np.any(data < 0):
                    raise ValueError(f"Cost matrix in `{modifier}` contains negative values.")
                return cls(data_src=data, tag=Tag.COST_MATRIX, cost=None)

            # Precompute the full cost matrix eagerly; the user-selected `backend` is not used here.
            cost_fn = get_cost(cost, backend="moscot", adata=adata, attr=attr, key=key, dist_key=dist_key)
            cost_matrix = cost_fn(**kwargs)
            return cls(data_src=cost_matrix, tag=Tag.COST_MATRIX, cost=None)

        # tag is either a point cloud or a kernel
        data = cls._extract_data(adata, attr=attr, key=key)
        cost_fn = get_cost(cost, backend=backend, **kwargs)
        return cls(data_src=data, tag=tag, cost=cost_fn)

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the cost matrix.

        For point clouds this is ``(n_source_points, n_target_points)``, where the target defaults to the
        source when :attr:`data_tgt` is ``None``; otherwise it is ``data_src.shape``.
        """
        if self.tag == Tag.POINT_CLOUD:
            x, y = self.data_src, (self.data_src if self.data_tgt is None else self.data_tgt)
            return x.shape[0], y.shape[0]

        return self.data_src.shape  # type: ignore[return-value]

    @property
    def is_cost_matrix(self) -> bool:
        """Whether :attr:`data_src` is a cost matrix."""
        return self.tag == Tag.COST_MATRIX

    @property
    def is_kernel(self) -> bool:
        """Whether :attr:`data_src` is a kernel matrix."""
        return self.tag == Tag.KERNEL

    @property
    def is_point_cloud(self) -> bool:
        """Whether :attr:`data_src` (and optionally) :attr:`data_tgt` is a point cloud."""
        return self.tag == Tag.POINT_CLOUD

    @property
    def is_graph(self) -> bool:
        """Whether :attr:`data_src` is a graph."""
        return self.tag == Tag.GRAPH