import abc
import contextlib
import itertools
import operator
from typing import (
Any,
Dict,
Generic,
Hashable,
Iterable,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
import networkx as nx
import numpy as np
import pandas as pd
from anndata import AnnData
from moscot import _constants
from moscot._types import ArrayLike, Policy_t
# Type of the policy categories (the nodes of the policy graph); any hashable value is allowed.
K = TypeVar("K", bound=Hashable)

# Public API of this module.
__all__ = [
    "SubsetPolicy",
    "OrderedPolicy",
    "StarPolicy",
    "ExternalStarPolicy",
    "SequentialPolicy",
    "TriangularPolicy",
    "ExplicitPolicy",
    "DummyPolicy",
    "FormatterMixin",
    "create_policy",
]
class FormatterMixin(abc.ABC):
    """Mixin for policies that need to format the values appearing in their plan steps."""

    @abc.abstractmethod
    def _format(self, value: Any, *, is_source: bool) -> Any:
        """Format ``value``.

        Parameters
        ----------
        value
            Value to format.
        is_source
            Whether ``value`` is the source (``True``) or the target (``False``) of a step.

        Returns
        -------
        The formatted value.
        """
class SubsetPolicy(Generic[K], abc.ABC):
    r"""Base policy class.

    A policy groups observations by the categories of a categorical annotation and connects pairs
    of categories by directed ``(source, target)`` steps, stored in the policy graph.

    Parameters
    ----------
    adata
        Annotated data object or a categorical data.
    key
        Key in :attr:`~anndata.AnnData.obs` where the categorical data is stored.
    verify_integrity
        Whether to check that the data has :math:`\ge 2` categories.

    Examples
    --------
    - See :doc:`../../notebooks/examples/problems/400_subset_policy` on how to use different policies.
    """

    def __init__(
        self,
        adata: Union[AnnData, pd.Series, pd.Categorical],
        key: Optional[str] = None,
        verify_integrity: bool = True,
    ):
        if isinstance(adata, AnnData):
            try:
                data = adata.obs[key]
            except KeyError:
                raise KeyError(f"Unable to find data in `adata.obs[{key!r}]`.") from None
        else:
            data = adata
        # Always wrap in a `pd.Series`: the `.cat` accessor used below (and in `reverse` of ordered
        # policies) does not exist on a bare `pd.Categorical`, which is an accepted input type.
        self._data = pd.Series(data).astype("category")  # TODO(@MUCDK): catch conversion error
        self._graph: Set[Tuple[K, K]] = set()
        self._cat = tuple(self._data.cat.categories)
        self._subset_key: Optional[str] = key
        if verify_integrity and len(self._cat) < 2:
            raise ValueError(
                f"Policy must contain at least `2` different values, found `{len(self._cat)}`.\n"
                "Is it possible that there is only one `batch` in `batch_key`?"
            )

    @abc.abstractmethod
    def _create_graph(self, **kwargs: Any) -> Set[Tuple[K, K]]:
        """Create a policy graph."""

    @abc.abstractmethod
    def _plan(self, **kwargs: Any) -> Sequence[Tuple[K, K]]:
        """Compute a sequence of steps based on the policy graph."""

    def create_graph(self, **kwargs: Any) -> "SubsetPolicy[K]":
        """Create a policy graph.

        Parameters
        ----------
        kwargs
            Keyword arguments passed to the policy-specific graph construction.

        Returns
        -------
        Return self.

        Raises
        ------
        ValueError
            If the created graph is empty.
        """
        graph = self._create_graph(**kwargs)
        if not len(graph):
            raise ValueError("The policy graph is empty.")
        self._graph = graph
        return self

    def plan(
        self,
        filter: Optional[Sequence[Tuple[K, K]]] = None,  # noqa: A002
        explicit_steps: Optional[Sequence[Tuple[K, K]]] = None,
        **kwargs: Any,
    ) -> Sequence[Tuple[K, K]]:
        """Compute a sequence of steps based on the policy graph.

        Useful when calling :meth:`create_masks`.

        Parameters
        ----------
        filter
            Steps to keep, all other steps are dropped. If :obj:`None`, keep all the steps.
            Ignored when ``explicit_steps`` is specified.
        explicit_steps
            Precomputed sequence of steps to use. It is validated and returned unchanged.
        kwargs
            Additional keyword arguments passed to the policy-specific planner.

        Returns
        -------
        Sequence of steps.

        Raises
        ------
        ValueError
            If ``explicit_steps`` is empty, refers to unknown categories or does not connect its
            first source to its last target, or if no steps remain after filtering.
        """
        if explicit_steps is not None:
            # Checked up-front; indexing `explicit_steps[0]` below would otherwise raise `IndexError`.
            if not len(explicit_steps):
                raise ValueError("Explicitly specified steps must not be empty.")
            G = nx.DiGraph()
            G.add_edges_from(explicit_steps)
            if not set(G.nodes).issubset(self._cat):
                raise ValueError(
                    f"Explicitly specified steps `{set(explicit_steps)}` must be a subset of `{self._cat}`."
                )
            src = explicit_steps[0][0]
            tgt = explicit_steps[-1][1]
            if not nx.has_path(G, src, tgt):
                raise ValueError(f"Explicitly specified steps do not form a connected path from `{src}` to `{tgt}`.")
            return explicit_steps

        plan = self._plan(**kwargs)
        # TODO(michalk8): ensure unique
        if filter is not None:
            plan = self._filter_plan(plan, filter=filter)
        if not len(plan):
            raise ValueError("Unable to create a plan, no steps were selected after filtering.")
        return plan

    def _filter_plan(
        self, plan: Sequence[Tuple[K, K]], filter: Sequence[Tuple[K, K]]  # noqa: A002
    ) -> Sequence[Tuple[K, K]]:
        # Keep only the steps listed in `filter`, preserving the order of `plan`.
        return [step for step in plan if step in filter]

    def create_mask(self, value: Union[K, Sequence[K]], *, allow_empty: bool = False) -> ArrayLike:
        """Create a mask used to subset the data.

        Parameters
        ----------
        value
            Values in the data which determine the mask. A string or a non-iterable value is
            treated as a single category.
        allow_empty
            Whether to allow empty mask.

        Returns
        -------
        Boolean mask of the same shape as the data.

        Raises
        ------
        ValueError
            If the mask is empty and ``allow_empty = False``.
        """
        if isinstance(value, str) or not isinstance(value, Iterable):
            mask = self._data == value
        else:
            mask = self._data.isin(value)
        if not allow_empty and not np.sum(mask):
            raise ValueError("Unable to construct an empty mask, use `allow_empty=True` to override.")
        return np.asarray(mask)

    def create_masks(self, discard_empty: bool = True) -> Dict[Tuple[K, K], Tuple[ArrayLike, ArrayLike]]:
        """Create masks based on the policy graph.

        Parameters
        ----------
        discard_empty
            Whether to remove empty masks.

        Returns
        -------
        Masks for each edge in the policy graph.

        Raises
        ------
        ValueError
            If all edges were discarded because of empty masks.
        """
        res = {}
        for a, b in self._graph:
            try:
                mask_a = self.create_mask(a, allow_empty=not discard_empty)
                mask_b = self.create_mask(b, allow_empty=not discard_empty)
                res[a, b] = mask_a, mask_b
            except ValueError as e:
                # Only the empty-mask error means "discard this edge"; anything else is re-raised.
                if "Unable to construct an empty mask" not in str(e):
                    raise
        if not res:
            # can only happen when `discard_empty=True`
            raise ValueError("All empty masks were discarded.")
        return res

    def add_node(self, node: Tuple[K, K], only_existing: bool = False) -> "SubsetPolicy[K]":
        """Add a node to the policy graph.

        Parameters
        ----------
        node
            Node to add, a ``(source, target)`` step.
        only_existing
            Whether to allow creating new nodes or only connect existing ones.

        Returns
        -------
        Add the ``node`` and return self.

        Raises
        ------
        ValueError
            If ``source == target`` or, when ``only_existing = True``, if either end is not a category.
        """
        src, tgt = node
        if src == tgt:
            raise ValueError(f"Unable to add `{src, tgt}` node, self connections are disallowed.")
        if only_existing and (src not in self._cat or tgt not in self._cat):
            raise ValueError(
                f"Unable to add `{src}` or `{tgt}` node(s) that are not already present "
                f"in the policy graph, use `only_existing=False` to override."
            )
        self._graph.add(node)
        return self

    def remove_node(self, node: Tuple[K, K]) -> "SubsetPolicy[K]":
        """Remove a node from the policy graph.

        Parameters
        ----------
        node
            Node to remove.

        Returns
        -------
        Remove the ``node``, if present and return self.
        """
        with contextlib.suppress(KeyError):
            self._graph.remove(node)
        return self

    @property
    def categories(self) -> Sequence[K]:
        """Categories in the policy."""
        return self._cat

    @property
    def key(self) -> Optional[str]:
        """Key in :attr:`~anndata.AnnData.obs` defining the policy."""
        return self._subset_key
class OrderedPolicy(SubsetPolicy[K], abc.ABC):
    """Base ordered policy.

    The plan is the shortest path in the policy graph from a start to an end category, which
    default to the first and the last category, respectively.

    Parameters
    ----------
    adata
        Annotated data object or an ordered categorical data.
    kwargs
        Additional keyword arguments.
    """

    def _plan(
        self, forward: bool = True, start: Optional[K] = None, end: Optional[K] = None, **_: Any
    ) -> Sequence[Tuple[K, K]]:
        # `_graph` is initialised to an empty set (never `None`), so check for emptiness; otherwise a
        # missing graph only surfaces as an obscure `networkx` error in `shortest_path` below.
        if not self._graph:
            raise RuntimeError("Construct the policy graph first.")
        if start is None:
            start = self._cat[0]
        if end is None:
            end = self._cat[-1]
        if start == end:
            raise ValueError(f"Start node `{start}` is the same as the end node `{end}`.")

        # TODO: add Graph for undirected
        G = nx.DiGraph()
        G.add_edges_from(self._graph)
        path = nx.shortest_path(G, start, end)
        steps = list(zip(path[:-1], path[1:]))
        # `forward=False` only reverses the order of the steps, not their direction.
        return steps if forward else steps[::-1]

    def reverse(self) -> "OrderedPolicy[K]":
        """Reverse the policy.

        Returns
        -------
        A new policy of the same type whose categories are in reversed order.
        """
        cats = self._data.cat.categories
        data = self._data.cat.reorder_categories(list(reversed(cats)))
        return type(self)(data, key=self.key, verify_integrity=False)
class SimplePlanPolicy(SubsetPolicy[K], abc.ABC):
    """Policy whose plan consists of all edges of the policy graph, in set iteration order."""

    def _plan(self, **_: Any) -> Sequence[Tuple[K, K]]:
        # Planner keyword arguments are accepted for API compatibility and ignored.
        steps: List[Tuple[K, K]] = [*self._graph]
        return steps
class StarPolicy(OrderedPolicy[K]):
    r"""Policy with a star topology.

    Every category other than the ``reference`` is connected to the ``reference``.

    Parameters
    ----------
    adata
        Annotated data object or a categorical data.
    key
        Key in :attr:`~anndata.AnnData.obs` where the categorical data is stored.
    verify_integrity
        Whether to check that the data has :math:`\ge 2` categories.
    """

    def _create_graph(self, reference: K, **kwargs: Any) -> Set[Tuple[K, K]]:  # type: ignore[override]
        if reference not in self._cat:
            raise ValueError(f"Reference `{reference}` is not in valid nodes: `{self._cat}`.")
        return {(c, reference) for c in self._cat if c != reference}

    def _plan(self, **_: Any) -> Sequence[Tuple[K, K]]:
        # A star has no single start-to-end path: the plan is every edge pointing to the reference.
        # The shortest-path plan inherited from `OrderedPolicy` would either fail (no path between
        # two leaves) or return a single step when the reference is the last category.
        return list(self._graph)

    def _filter_plan(
        self, plan: Sequence[Tuple[K, K]], filter: Sequence[Union[K, Tuple[K, K]]]  # noqa: A002
    ) -> Sequence[Tuple[K, K]]:
        # Filter entries may be source categories or full `(source, reference)` steps;
        # only the source part is compared.
        filter = [src[0] if isinstance(src, tuple) else src for src in filter]  # noqa: A001
        return [(src, ref) for src, ref in plan if src in filter]

    def add_node(self, node: Union[K, Tuple[K, K]], only_existing: bool = False) -> "StarPolicy[K]":
        """Add a node to the policy graph.

        Parameters
        ----------
        node
            Either a ``(source, target)`` step or only the source, in which case the target is
            the current :attr:`reference`.
        only_existing
            Whether to allow creating new nodes or only connect existing ones.

        Returns
        -------
        Return self.
        """
        if not isinstance(node, tuple):
            node = (node, self.reference)
        return super().add_node(node, only_existing=only_existing)  # type: ignore[return-value]

    def remove_node(self, node: Union[K, Tuple[K, K]]) -> "StarPolicy[K]":
        """Remove a node from the policy graph.

        Parameters
        ----------
        node
            Either a ``(source, target)`` step or only the source, in which case the target is
            the current :attr:`reference`.

        Returns
        -------
        Remove the ``node``, if present and return self.
        """
        if not isinstance(node, tuple):
            node = (node, self.reference)
        return super().remove_node(node)  # type: ignore[return-value]

    @property
    def reference(self) -> K:
        """Central node, i.e. the target of an (arbitrary) edge in the policy graph.

        Raises
        ------
        ValueError
            If the policy graph is empty.
        """
        for _, ref in self._graph:
            return ref
        raise ValueError("Graph is empty.")
class ExternalStarPolicy(FormatterMixin, StarPolicy[K]):
    """Policy with star topology and external central node.

    The central node is not one of the categories in the data. Internally it is represented by a
    private sentinel object, which is replaced by ``tgt_name`` in the plan.

    Parameters
    ----------
    adata
        Annotated data object.
    tgt_name
        Name of the central node.
    kwargs
        Additional keyword arguments.
    """

    # Internal stand-in for the external central node; a fresh object, so it equals no category.
    _SENTINEL = object()

    def __init__(self, adata: Union[AnnData, pd.Series, pd.Categorical], tgt_name: K = "ref", **kwargs: Any):
        super().__init__(adata, **kwargs)
        self._tgt_name = tgt_name

    def _format(self, value: K, *, is_source: bool) -> K:
        # Sources are real categories and stay as they are; the only valid target is the sentinel.
        if is_source:
            return value
        if value is self._SENTINEL:
            return self._tgt_name
        raise ValueError(f"Expected value to be `{self._SENTINEL}`, found `{value}`.")

    def _create_graph(self, **_: Any) -> Set[Tuple[K, object]]:  # type: ignore[override]
        return {(c, self._SENTINEL) for c in self._cat if c != self._SENTINEL}

    def _plan(self, **_: Any) -> Sequence[Tuple[K, K]]:
        return [(src, self._format(tgt, is_source=False)) for (src, tgt) in self._graph]

    def add_node(self, node: Union[K, Tuple[K, K]], only_existing: bool = False) -> "ExternalStarPolicy[K]":
        """Add a node to the policy graph.

        A ``(source, tgt_name)`` step is ignored and self is returned unchanged.

        Parameters
        ----------
        node
            Either a ``(source, target)`` step or only the source.
        only_existing
            Whether to allow creating new nodes or only connect existing ones.

        Returns
        -------
        Return self.
        """
        if isinstance(node, tuple):
            _, tgt = node
            # TODO(michalk8): tgt can be undefined
            # Compare by value: an identity check (`is`) misses equal but distinct objects (e.g.
            # dynamically built strings), which would add an edge `_format` later rejects.
            if tgt == self._tgt_name:
                return self
        return super().add_node(node, only_existing=only_existing)  # type: ignore[return-value]

    def create_masks(self, discard_empty: bool = True) -> Dict[Tuple[K, K], Tuple[ArrayLike, ArrayLike]]:
        """Create masks based on the policy graph.

        Parameters
        ----------
        discard_empty
            Ignored. The mask of the external central node never selects any observation, so empty
            masks are always kept.

        Returns
        -------
        Masks for each edge in the policy graph.
        """
        del discard_empty
        return super().create_masks(discard_empty=False)
class SequentialPolicy(OrderedPolicy[K]):
    """Policy which connects immediate successors.

    Each category is connected to the category that follows it in the categorical order.

    Parameters
    ----------
    adata
        Annotated data object or a categorical data.
    kwargs
        Additional keyword arguments.
    """

    def _create_graph(self, *_: Any, **__: Any) -> Set[Tuple[K, K]]:
        # Pair every category with its successor: (c0, c1), (c1, c2), ...
        return set(zip(self._cat[:-1], self._cat[1:]))
class TriangularPolicy(OrderedPolicy[K]):
    """Policy which connects all preceding/subsequent nodes.

    Nodes are compared by their position in the categorical order (the same order used by
    :class:`SequentialPolicy`), not by their values.

    Parameters
    ----------
    adata
        Annotated data object or a categorical data.
    upper
        Whether to use subsequent nodes instead of the preceding ones.
    kwargs
        Additional keyword arguments.
    """

    def __init__(self, adata: Union[AnnData, pd.Series, pd.Categorical], upper: bool = True, **kwargs: Any):
        super().__init__(adata, **kwargs)
        self._upper = upper
        self._comparator = operator.lt if upper else operator.gt

    def _create_graph(self, **__: Any) -> Set[Tuple[K, K]]:
        # Compare positions rather than values: this respects an explicit category order, works
        # for non-comparable categories and makes `reverse` actually reverse the edges.
        indexed = list(enumerate(self._cat))
        return {(a, b) for (i, a), (j, b) in itertools.product(indexed, indexed) if self._comparator(i, j)}

    def reverse(self) -> "TriangularPolicy[K]":
        """Reverse the policy.

        Returns
        -------
        A new policy with categories in reversed order and the same ``upper`` setting.
        """
        cats = self._data.cat.categories
        data = self._data.cat.reorder_categories(list(reversed(cats)))
        # The generic `OrderedPolicy.reverse` would silently reset `upper` to its default.
        return type(self)(data, upper=self._upper, key=self.key, verify_integrity=False)
class ExplicitPolicy(SimplePlanPolicy[K]):
    r"""Explicitly specified policy.

    The policy graph is passed directly in :meth:`create_graph` as ``subset``.

    Parameters
    ----------
    adata
        Annotated data object or a categorical data.
    key
        Key in :attr:`~anndata.AnnData.obs` where the categorical data is stored.
    verify_integrity
        Whether to check that the data has :math:`\ge 2` categories.
    """

    def _create_graph(  # type: ignore[override]
        self, subset: Optional[Sequence[Tuple[K, K]]] = None, **_: Any
    ) -> Set[Tuple[K, K]]:
        # `subset` defaults to `None` so that a missing argument reaches this informative error
        # instead of a generic `TypeError`.
        if subset is None:
            raise ValueError("No steps specifying the explicit policy.")
        return set(subset)  # pass-through; `create_graph` rejects an empty graph
class DummyPolicy(FormatterMixin, SubsetPolicy[str]):
    """Policy with a single step in which every observation is both source and target.

    All observations are assigned one internal sentinel category, so the policy graph consists of
    a single edge whose masks select every observation. The plan is the single step
    ``(src_name, tgt_name)``.

    Parameters
    ----------
    adata
        Annotated data object or a categorical data; only its length is used.
    src_name
        Name of the source in the plan.
    tgt_name
        Name of the target in the plan.
    kwargs
        Additional keyword arguments. ``verify_integrity`` is ignored because the data always
        contains a single category.
    """

    # Internal category shared by all observations.
    _SENTINEL = object()

    def __init__(
        self,
        adata: Union[AnnData, pd.Series, pd.Categorical],
        src_name: Literal["src"] = "src",
        tgt_name: Literal["tgt"] = "tgt",
        **kwargs: Any,
    ):
        # Integrity verification would always fail (one category); drop a user-supplied value
        # instead of failing with "got multiple values for keyword argument".
        kwargs.pop("verify_integrity", None)
        super().__init__(pd.Series([self._SENTINEL] * len(adata)), verify_integrity=False, **kwargs)
        # Expose the user-facing names as categories instead of the internal sentinel.
        self._cat = (src_name, tgt_name)
        self._src_name = src_name
        self._tgt_name = tgt_name

    def _create_graph(self, **__: Any) -> Set[Tuple[object, object]]:  # type: ignore[override]
        return {(self._SENTINEL, self._SENTINEL)}

    def _plan(self, **_: Any) -> List[Tuple[str, str]]:
        return [(self._src_name, self._tgt_name)]

    def _format(self, _: Any, *, is_source: bool) -> str:
        # The value itself is irrelevant; only its role in the step matters.
        return self._src_name if is_source else self._tgt_name

    def _filter_plan(
        self, plan: Sequence[Tuple[K, K]], filter: Sequence[Tuple[K, K]]  # noqa: A002
    ) -> Sequence[Tuple[K, K]]:
        # The single step is never filtered out.
        return plan
# TODO(michalk8): in the future, use Registry
def create_policy(
    kind: Policy_t,
    adata: Union[AnnData, pd.Series, pd.Categorical],
    **kwargs: Any,
) -> SubsetPolicy[K]:
    """Create a policy.

    Parameters
    ----------
    kind
        What policy to create.
    adata
        Annotated data object.
    kwargs
        Additional keyword arguments passed to the policy's constructor.

    Returns
    -------
    The policy.

    Raises
    ------
    NotImplementedError
        If ``kind`` does not match any known policy.

    Notes
    -----
    - See :doc:`../../notebooks/examples/problems/400_subset_policy` on how to use different policies.
    """
    # Ordered (constant, factory) pairs; the first constant equal to `kind` wins.
    factories = (
        (_constants.SEQUENTIAL, lambda: SequentialPolicy(adata, **kwargs)),
        (_constants.STAR, lambda: StarPolicy(adata, **kwargs)),
        (_constants.EXTERNAL_STAR, lambda: ExternalStarPolicy(adata, **kwargs)),
        (_constants.TRIU, lambda: TriangularPolicy(adata, **kwargs, upper=True)),
        (_constants.TRIL, lambda: TriangularPolicy(adata, **kwargs, upper=False)),
        (_constants.EXPLICIT, lambda: ExplicitPolicy(adata, **kwargs)),
    )
    for name, build in factories:
        if kind == name:
            return build()
    raise NotImplementedError(f"Policy `{kind}` is not yet implemented.")