Source code for moscot.utils.tagged_array

import enum
from dataclasses import dataclass
from typing import Any, Callable, Hashable, Literal, Optional, Tuple, TypeVar, Union

import numpy as np
import scipy.sparse as sp

from anndata import AnnData

from moscot._logging import logger
from moscot._types import ArrayLike, CostFn_t, OttCostFn_t
from moscot.costs import get_cost

# Generic key type; constrained to hashable objects. Used in this module as the
# key type of the ``dict`` that :class:`DistributionCollection` derives from.
K = TypeVar("K", bound=Hashable)

# Names exported by ``from moscot.utils.tagged_array import *``.
__all__ = ["Tag", "TaggedArray", "DistributionContainer", "DistributionCollection"]


@enum.unique
class Tag(str, enum.Enum):
    """Interpretation of the data held by a :class:`~moscot.utils.tagged_array.TaggedArray`.

    Members are also plain strings, so they compare equal to their values
    (e.g. ``Tag.KERNEL == "kernel"``).
    """

    #: Cost matrix.
    COST_MATRIX = "cost_matrix"
    #: Kernel matrix.
    KERNEL = "kernel"
    #: Point cloud.
    POINT_CLOUD = "point_cloud"
    #: Graph distances, means [n+m, n+m] transport matrix.
    GRAPH = "graph"
@dataclass(repr=True, frozen=False)
class TaggedArray:
    """Array-like data together with a tag saying how solvers should read it.

    The data is typically pulled out of an :class:`~anndata.AnnData` object
    (see :meth:`from_adata`) and, depending on :attr:`tag`, is treated as a
    :attr:`cost matrix <is_cost_matrix>`, a :attr:`kernel matrix <is_kernel>`,
    a :attr:`point cloud <is_point_cloud>` or a :attr:`graph <is_graph>`.

    Parameters
    ----------
    data_src
        Source data.
    data_tgt
        Target data.
    tag
        How to interpret :attr:`data_src` and :attr:`data_tgt`.
    cost
        Cost function when ``tag = 'point_cloud'``.
    """

    #: Source data.
    data_src: ArrayLike
    #: Target data.
    data_tgt: Optional[ArrayLike] = None
    #: How to interpret :attr:`data_src` and :attr:`data_tgt`.
    tag: Tag = Tag.POINT_CLOUD
    #: Cost function when ``tag = 'point_cloud'``.
    cost: Optional[Union[str, Callable[..., Any]]] = None

    @staticmethod
    def _extract_data(
        adata: AnnData,
        *,
        attr: Literal["X", "obsp", "obsm", "layers", "uns"],
        key: Optional[str] = None,
        densify: bool = False,
    ) -> ArrayLike:
        """Fetch a 2-dimensional array from ``adata.<attr>`` or ``adata.<attr>[key]``.

        Parameters
        ----------
        adata
            Annotated data object.
        attr
            Name of the attribute of ``adata`` to read.
        key
            Key inside ``attr``; if `None`, the attribute itself is used.
        densify
            Whether to convert a sparse result to a dense array (a warning is logged).

        Returns
        -------
        The extracted array.

        Raises
        ------
        KeyError, IndexError
            If indexing the attribute with ``key`` fails.
        ValueError
            If the extracted data is not 2-dimensional.
        """
        # Human-readable location of the data, used in log and error messages.
        if key is None:
            location = f"adata.{attr}"
        else:
            location = f"adata.{attr}[{key!r}]"

        array = getattr(adata, attr)
        if key is not None:
            try:
                array = array[key]
            except KeyError:
                raise KeyError(f"Unable to fetch data from `{location}`.") from None
            except IndexError:
                raise IndexError(f"Unable to fetch data from `{location}`.") from None

        if densify and sp.issparse(array):
            logger.warning(f"Densifying data in `{location}`")
            array = array.toarray()

        n_dims = array.ndim
        if n_dims != 2:
            raise ValueError(f"Expected `{location}` to have `2` dimensions, found `{array.ndim}`.")
        return array

    def _set_cost(
        self,
        cost: CostFn_t = "sq_euclidean",
        backend: Literal["ott"] = "ott",
        **kwargs: Any,
    ) -> "TaggedArray":
        """Set :attr:`cost` in place and return ``self``.

        ``'geodesic'`` is stored as-is; any other value is resolved through
        :func:`~moscot.costs.get_cost` for the given ``backend``.

        Raises
        ------
        ValueError
            If ``cost = 'custom'``; that case is handled by :meth:`from_adata`.
        """
        if cost == "custom":
            raise ValueError("Custom cost functions are handled in `TaggedArray.from_adata`.")

        resolved = get_cost(cost, backend=backend, **kwargs) if cost != "geodesic" else cost
        self.cost = resolved
        return self

    @classmethod
    def from_adata(
        cls,
        adata: AnnData,
        dist_key: Union[Any, Tuple[Any, Any]],
        attr: Literal["X", "obsp", "obsm", "layers", "uns"],
        tag: Tag = Tag.POINT_CLOUD,
        key: Optional[str] = None,
        cost: CostFn_t = "sq_euclidean",
        backend: Literal["ott"] = "ott",
        **kwargs: Any,
    ) -> "TaggedArray":
        """Create tagged array from :class:`~anndata.AnnData`.

        .. warning::
            Sparse arrays will be densified except when ``tag = 'graph'``.

        Parameters
        ----------
        adata
            Annotated data object.
        dist_key
            Key which determines into which source/target subset ``adata`` belongs.
            For ``tag = 'graph'`` it is combined with ``key`` (a tuple is first joined
            with ``'_'``) to form the lookup key ``'{dist_key}_{key}'``.
        attr
            Attribute of :class:`~anndata.AnnData` used when extracting/computing the cost.
        tag
            Tag used to interpret the extracted data.
        key
            Key in the ``attr`` of :class:`~anndata.AnnData` used when extracting/computing the cost.
        cost
            Cost function to apply to the extracted array, depending on ``tag``:

            - if ``tag = 'graph'``, the ``cost`` has to be ``'geodesic'``.
            - if ``tag = 'cost_matrix'`` and ``cost = 'custom'``, the extracted array is taken
              to already be a cost matrix and must not contain negative values.
            - if ``tag = 'cost_matrix'`` with any other ``cost``, the cost is looked up with
              the ``'moscot'`` backend and called to compute the cost matrix.
            - otherwise (point cloud or kernel), the cost is looked up in the ``backend``.
        backend
            Which backend to use, see :func:`~moscot.backends.utils.get_available_backends`.
        kwargs
            Keyword arguments forwarded to the cost: to the call computing the cost matrix when
            ``tag = 'cost_matrix'``, otherwise to the backend-specific cost lookup.

        Returns
        -------
        The tagged array.

        Raises
        ------
        ValueError
            If ``tag = 'graph'`` and ``cost`` is not ``'geodesic'``, or if a custom cost
            matrix contains negative values.
        """
        if tag == Tag.GRAPH:
            if cost != "geodesic":
                raise ValueError(f"Expected `cost=geodesic`, found `{cost}`.")
            if isinstance(dist_key, tuple):
                dist_key = f"{dist_key[0]}_{dist_key[1]}"
            # Graphs are kept sparse.
            graph = cls._extract_data(adata, attr=attr, key=f"{dist_key}_{key}", densify=False)
            return cls(data_src=graph, tag=Tag.GRAPH, cost="geodesic")

        if tag == Tag.COST_MATRIX:
            if cost != "custom":
                # Let the moscot cost compute the matrix from `adata`.
                compute_cost = get_cost(
                    cost, backend="moscot", adata=adata, attr=attr, key=key, dist_key=dist_key
                )
                computed = compute_cost(**kwargs)
                return cls(data_src=computed, tag=Tag.COST_MATRIX, cost=None)

            # our custom cost functions
            location = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]"
            matrix = cls._extract_data(adata, attr=attr, key=key, densify=True)
            if np.any(matrix < 0):
                raise ValueError(f"Cost matrix in `{location}` contains negative values.")
            return cls(data_src=matrix, tag=Tag.COST_MATRIX, cost=None)

        # tag is either a point cloud or a kernel
        extracted = cls._extract_data(adata, attr=attr, key=key, densify=True)
        backend_cost = get_cost(cost, backend=backend, **kwargs)
        return cls(data_src=extracted, tag=tag, cost=backend_cost)

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the cost matrix.

        For a point cloud this is ``(n_source, n_target)``, where the source data doubles
        as target when :attr:`data_tgt` is `None`; otherwise the shape of :attr:`data_src`.
        """
        if self.tag != Tag.POINT_CLOUD:
            return self.data_src.shape

        source = self.data_src
        target = self.data_src if self.data_tgt is None else self.data_tgt
        return source.shape[0], target.shape[0]

    @property
    def is_cost_matrix(self) -> bool:
        """Whether :attr:`data_src` is a cost matrix."""
        tagged_as_cost = self.tag == Tag.COST_MATRIX
        return tagged_as_cost

    @property
    def is_kernel(self) -> bool:
        """Whether :attr:`data_src` is a kernel matrix."""
        tagged_as_kernel = self.tag == Tag.KERNEL
        return tagged_as_kernel

    @property
    def is_point_cloud(self) -> bool:
        """Whether :attr:`data_src` (and optionally) :attr:`data_tgt` is a point cloud."""
        tagged_as_point_cloud = self.tag == Tag.POINT_CLOUD
        return tagged_as_point_cloud

    @property
    def is_graph(self) -> bool:
        """Whether :attr:`data_src` is a graph."""
        tagged_as_graph = self.tag == Tag.GRAPH
        return tagged_as_graph
@dataclass(frozen=True, repr=True)
class DistributionContainer:
    """Data container for OT problems involving more than two distributions.

    Parameters
    ----------
    xy
        Distribution living in a shared space.
    xx
        Distribution living in an incomparable space.
    a
        Marginals when used as source distribution.
    b
        Marginals when used as target distribution.
    conditions
        Conditions for the distributions.
    cost_xy
        Cost function when in the shared space, `None` if :attr:`xy` is absent.
    cost_xx
        Cost function in the incomparable space, `None` if :attr:`xx` is absent.
    """

    xy: Optional[ArrayLike]
    xx: Optional[ArrayLike]
    a: ArrayLike
    b: ArrayLike
    conditions: Optional[ArrayLike]
    # `from_adata` stores `None` here when the corresponding data is not requested.
    cost_xy: Optional[OttCostFn_t]
    cost_xx: Optional[OttCostFn_t]

    @property
    def contains_linear(self) -> bool:
        """Whether the distribution contains data corresponding to the linear term."""
        return self.xy is not None

    @property
    def contains_quadratic(self) -> bool:
        """Whether the distribution contains data corresponding to the quadratic term."""
        return self.xx is not None

    @property
    def contains_condition(self) -> bool:
        """Whether the distribution contains data corresponding to the condition."""
        return self.conditions is not None

    @staticmethod
    def _extract_data(
        adata: AnnData,
        *,
        attr: Literal["X", "obs", "obsp", "obsm", "var", "varm", "layers", "uns"],
        key: Optional[str] = None,
    ) -> ArrayLike:
        """Fetch a dense 2-dimensional array from ``adata.<attr>`` or ``adata.<attr>[key]``.

        Parameters
        ----------
        adata
            Annotated data object.
        attr
            Name of the attribute of ``adata`` to read.
        key
            Key inside ``attr``; if `None`, the attribute itself is used.

        Returns
        -------
        The extracted array. Sparse data is always densified (a warning is logged).

        Raises
        ------
        KeyError, IndexError
            If indexing the attribute with ``key`` fails.
        ValueError
            If the extracted data is not 2-dimensional.
        """
        modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]"
        data = getattr(adata, attr)

        if key is not None:
            try:
                data = data[key]
            except KeyError:
                raise KeyError(f"Unable to fetch data from `{modifier}`.") from None
            except IndexError:
                raise IndexError(f"Unable to fetch data from `{modifier}`.") from None

        if attr == "obs":
            # Append a trailing axis so that data read from `obs` passes the 2-D check below.
            data = np.asarray(data)[:, None]
        if sp.issparse(data):
            logger.warning(f"Densifying data in `{modifier}`")
            # `.toarray()` instead of the deprecated `.A` shortcut, which newer SciPy
            # releases no longer provide; this also matches `TaggedArray._extract_data`.
            data = data.toarray()
        if data.ndim != 2:
            raise ValueError(f"Expected `{modifier}` to have `2` dimensions, found `{data.ndim}`.")

        return data

    @staticmethod
    def _verify_input(
        xy_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]],
        xy_key: Optional[str],
        xx_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]],
        xx_key: Optional[str],
        conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]],
        conditions_key: Optional[str],
    ) -> Tuple[bool, bool, bool]:
        """Check that every ``(attr, key)`` pair is either fully given or fully omitted.

        Returns
        -------
        Three flags telling whether the linear (``xy``), quadratic (``xx``) and
        condition data were requested, in that order.

        Raises
        ------
        ValueError
            If exactly one element of a pair is `None`.
        """
        # `(x is None) != (y is None)` is true exactly when only one of the two is `None`.
        if (xy_attr is None) != (xy_key is None):
            raise ValueError("Either both `xy_attr` and `xy_key` must be `None` or none of them.")
        if (xx_attr is None) != (xx_key is None):
            # The message names the `xx` pair (it previously referred to `xy` by mistake).
            raise ValueError("Either both `xx_attr` and `xx_key` must be `None` or none of them.")
        if (conditions_attr is None) != (conditions_key is None):
            raise ValueError("Either both `conditions_attr` and `conditions_key` must be `None` or none of them.")

        return xy_attr is not None, xx_attr is not None, conditions_attr is not None

    @classmethod
    def from_adata(
        cls,
        adata: AnnData,
        a: ArrayLike,
        b: ArrayLike,
        xy_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]] = None,
        xy_key: Optional[str] = None,
        xy_cost: CostFn_t = "sq_euclidean",
        xx_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]] = None,
        xx_key: Optional[str] = None,
        xx_cost: CostFn_t = "sq_euclidean",
        conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]] = None,
        conditions_key: Optional[str] = None,
        backend: Literal["ott"] = "ott",
        **kwargs: Any,
    ) -> "DistributionContainer":
        """Create distribution container from :class:`~anndata.AnnData`.

        .. warning::
            Sparse arrays will be always densified.

        Parameters
        ----------
        adata
            Annotated data object.
        a
            Marginals when used as source distribution.
        b
            Marginals when used as target distribution.
        xy_attr
            Attribute of `adata` containing the data for the shared space.
        xy_key
            Key of `xy_attr` containing the data for the shared space.
        xy_cost
            Cost function when in the shared space.
        xx_attr
            Attribute of `adata` containing the data for the incomparable space.
        xx_key
            Key of `xx_attr` containing the data for the incomparable space.
        xx_cost
            Cost function in the incomparable space.
        conditions_attr
            Attribute of `adata` containing the conditions.
        conditions_key
            Key of `conditions_attr` containing the conditions.
        backend
            Backend to use.
        kwargs
            Keyword arguments to pass to the cost functions.

        Returns
        -------
        The distribution container.

        Raises
        ------
        ValueError
            If only one of an ``(attr, key)`` pair is specified.
        """
        contains_linear, contains_quadratic, contains_condition = cls._verify_input(
            xy_attr, xy_key, xx_attr, xx_key, conditions_attr, conditions_key
        )

        if contains_linear:
            xy_data = cls._extract_data(adata, attr=xy_attr, key=xy_key)  # type: ignore[arg-type]
            xy_cost_fn = get_cost(xy_cost, backend=backend, **kwargs)
        else:
            xy_data = None
            xy_cost_fn = None

        if contains_quadratic:
            xx_data = cls._extract_data(adata, attr=xx_attr, key=xx_key)  # type: ignore[arg-type]
            xx_cost_fn = get_cost(xx_cost, backend=backend, **kwargs)
        else:
            xx_data = None
            xx_cost_fn = None

        conditions_data = (
            cls._extract_data(adata, attr=conditions_attr, key=conditions_key)  # type: ignore[arg-type]
            if contains_condition
            else None
        )

        return cls(
            xy=xy_data,
            xx=xx_data,
            a=a,
            b=b,
            conditions=conditions_data,
            cost_xy=xy_cost_fn,
            cost_xx=xx_cost_fn,
        )
class DistributionCollection(dict[K, DistributionContainer]):
    """Dictionary of :class:`DistributionContainer` objects, indexed by hashable keys.

    Both :func:`repr` and :func:`str` show the class name followed by the list of keys.
    """

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        keys = list(self.keys())
        return f"{class_name}{keys}"

    def __str__(self) -> str:
        # The informal string form is the same as the representation.
        return repr(self)