Source code for moscot.base.problems.compound_problem

import abc
import types
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generic,
    Hashable,
    Iterator,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import scipy.sparse as sp

from anndata import AnnData

from moscot._logging import logger
from moscot._types import ArrayLike, Policy_t, ProblemStage_t
from moscot.base.output import BaseSolverOutput
from moscot.base.problems._utils import attributedispatch, require_prepare
from moscot.base.problems.manager import ProblemManager
from moscot.base.problems.problem import BaseProblem, OTProblem
from moscot.utils.subset_policy import (
    DummyPolicy,
    ExplicitPolicy,
    FormatterMixin,
    OrderedPolicy,
    StarPolicy,
    SubsetPolicy,
    create_policy,
)
from moscot.utils.tagged_array import Tag, TaggedArray

# Public names exported by this module.
__all__ = ["BaseCompoundProblem", "CompoundProblem"]

# Type of the keys identifying the source/target of a subproblem; must be hashable
# because pairs ``(source, target)`` are used as dictionary keys.
K = TypeVar("K", bound=Hashable)
# Type of the individual subproblems managed by a compound problem.
B = TypeVar("B", bound=OTProblem)
# Signature of a user-supplied callback; it is invoked as
# ``callback(term, adata_src, adata_tgt, **kwargs)`` and may return ``None``.
Callback_t = Callable[[Literal["xy", "x", "y"], AnnData, Optional[AnnData]], Optional[TaggedArray]]
# Result of push/pull: either a single array or a mapping from keys to arrays.
ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]]


class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
    """Base class for all biological problems.

    This class translates a biological problem to multiple :term:`OT` problems.

    Parameters
    ----------
    adata
        Annotated data object.
    kwargs
        Keyword arguments for :class:`~moscot.base.problems.BaseProblem`.
    """

    def __init__(self, adata: AnnData, **kwargs: Any):
        super().__init__(**kwargs)
        self._adata = adata
        # Instantiated in :meth:`prepare`; ``None`` means that nothing has been prepared yet.
        self._problem_manager: Optional[ProblemManager[K, B]] = None

    @abc.abstractmethod
    def _create_problem(self, src: K, tgt: K, src_mask: ArrayLike, tgt_mask: ArrayLike, **kwargs: Any) -> B:
        """Create an :term:`OT` subproblem.

        Parameters
        ----------
        src
            Source key identifying the subproblem.
        tgt
            Target key identifying the subproblem.
        src_mask
            Source mask used to subset :attr:`adata`.
        tgt_mask
            Target mask used to subset :attr:`adata`.
        kwargs
            Additional keyword arguments.

        Returns
        -------
        The subproblem.
        """

    @abc.abstractmethod
    def _create_policy(
        self,
        policy: Policy_t,
        **kwargs: Any,
    ) -> SubsetPolicy[K]:
        """Create a policy used to split :attr:`adata`.

        Only policies specified by :attr:`_valid_policies` will be passed to this function.

        Parameters
        ----------
        policy
            Name of the policy.
        kwargs
            Keyword arguments for :class:`~moscot.utils.subset_policy.SubsetPolicy`.

        Returns
        -------
        The policy.
        """

    @property
    @abc.abstractmethod
    def _valid_policies(self) -> Tuple[Policy_t, ...]:
        """Valid policies for this problem."""

    def _callback_handler(
        self,
        term: Literal["xy", "x", "y"],
        key_1: K,
        key_2: K,
        problem: B,
        *,
        callback: Optional[Union[Literal["local-pca", "spatial-norm", "graph-construction"], Callback_t]] = None,
        **kwargs: Any,
    ) -> Optional[TaggedArray]:
        """Resolve and invoke a data-preparation callback for one term of a subproblem.

        Parameters
        ----------
        term
            Term for which the data is prepared.
        key_1
            Source key of the subproblem. Not used by this implementation.
        key_2
            Target key of the subproblem. Not used by this implementation.
        problem
            The subproblem; provides the built-in callbacks and the source/target data.
        callback
            Either ``None``, the name of a built-in callback of ``problem``, or a callable.
        kwargs
            Keyword arguments for the callback.

        Returns
        -------
        ``None`` if ``callback`` is ``None``, otherwise whatever the callback returns.

        Raises
        ------
        TypeError
            If ``callback`` is neither a known name nor a callable.
        """
        if callback is None:
            return None

        # translate the names of the built-in callbacks to the subproblem's bound methods
        if callback == "local-pca":
            callback = problem._local_pca_callback
        elif callback == "spatial-norm":
            callback = problem._spatial_norm_callback
        elif callback == "graph-construction":
            callback = problem._graph_construction_callback

        if not callable(callback):
            raise TypeError("Callback is not a function.")
        return callback(term, problem.adata_src, problem.adata_tgt, **kwargs)

    # TODO(michalk8): refactor me
    def _create_problems(
        self,
        xy: Mapping[str, Any] = types.MappingProxyType({}),
        x: Mapping[str, Any] = types.MappingProxyType({}),
        y: Mapping[str, Any] = types.MappingProxyType({}),
        xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        **kwargs: Any,
    ) -> Dict[Tuple[K, K], B]:
        """Create and prepare one subproblem for every pair of masks produced by the policy.

        Parameters
        ----------
        xy
            Keyword arguments describing the linear term, passed to the subproblem's ``prepare``.
        x
            Keyword arguments describing the source quadratic term.
        y
            Keyword arguments describing the target quadratic term.
        xy_callback
            Callback producing the data for the linear term.
        x_callback
            Callback producing the data for the source quadratic term.
        y_callback
            Callback producing the data for the target quadratic term.
        xy_callback_kwargs
            Keyword arguments for ``xy_callback``.
        x_callback_kwargs
            Keyword arguments for ``x_callback``.
        y_callback_kwargs
            Keyword arguments for ``y_callback``.
        kwargs
            Keyword arguments for the subproblem's ``prepare`` method.

        Returns
        -------
        The prepared subproblems, keyed by the (possibly formatted) source and target keys.
        """
        # imported locally, as in the original implementation
        from moscot.base.problems.birth_death import BirthDeathProblem

        if TYPE_CHECKING:
            assert isinstance(self._policy, SubsetPolicy)

        problems: Dict[Tuple[K, K], B] = {}
        for (src, tgt), (src_mask, tgt_mask) in self._policy.create_masks().items():
            if isinstance(self._policy, FormatterMixin):
                src_name = self._policy._format(src, is_source=True)
                tgt_name = self._policy._format(tgt, is_source=False)
            else:
                src_name, tgt_name = src, tgt

            problem = self._create_problem(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)

            xy_data = self._callback_handler(
                term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs
            )
            x_data = self._callback_handler(
                term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs
            )
            y_data = self._callback_handler(
                term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs
            )

            # Build the per-subproblem arguments from the caller's mappings on every iteration,
            # so that data computed for one subproblem never leaks into the next one. The check
            # is `is not None` because a callback result must not be judged by its truthiness.
            xy_term = xy if xy_data is None else {**xy, "tagged_array": xy_data}
            x_term = x if x_data is None else {**x, "tagged_array": x_data}
            y_term = y if y_data is None else {**y, "tagged_array": y_data}

            prepare_kwargs = dict(kwargs)
            if isinstance(problem, BirthDeathProblem):
                prepare_kwargs["proliferation_key"] = self.proliferation_key  # type: ignore[attr-defined]
                prepare_kwargs["apoptosis_key"] = self.apoptosis_key  # type: ignore[attr-defined]

            problems[src_name, tgt_name] = problem.prepare(xy=xy_term, x=x_term, y=y_term, **prepare_kwargs)

        return problems

    def prepare(
        self,
        policy: Policy_t,
        key: Optional[str],
        subset: Optional[Sequence[Tuple[K, K]]] = None,
        reference: Optional[Any] = None,
        xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        **kwargs: Any,
    ) -> "BaseCompoundProblem[K, B]":
        """Prepare the individual :term:`OT` subproblems.

        .. seealso::
            - See :doc:`../notebooks/examples/problems/400_subset_policy`
              on how to use different policies.

        Parameters
        ----------
        policy
            Rule which defines how to construct the subproblems.
        key
            Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
        subset
            Subset of :attr:`obs['{key}'] <anndata.AnnData.obs>` for the
            :class:`~moscot.utils.subset_policy.ExplicitPolicy`. Only used when ``policy = 'explicit'``.
        reference
            Reference for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
            Only used when ``policy = 'star'``.
        xy_callback
            Callback function used to prepare the data in the :term:`linear term`.
        x_callback
            Callback function used to prepare the data in the source :term:`quadratic term`.
        y_callback
            Callback function used to prepare the data in the target :term:`quadratic term`.
        xy_callback_kwargs
            Keyword arguments for the ``xy_callback``.
        x_callback_kwargs
            Keyword arguments for the ``x_callback``.
        y_callback_kwargs
            Keyword arguments for the ``y_callback``.
        kwargs
            Keyword arguments for the subproblems' :meth:`~moscot.base.problems.OTProblem.prepare` method.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`problems` - the prepared subproblems.
        - :attr:`solutions` - set to an empty :class:`dict`.
        - :attr:`stage` - set to ``'prepared'``.
        - :attr:`problem_kind` - kind of the :term:`OT` problem.
        """
        self._ensure_valid_policy(policy)
        policy = self._create_policy(policy=policy, key=key)
        if TYPE_CHECKING:
            assert isinstance(policy, SubsetPolicy)

        if isinstance(policy, ExplicitPolicy):
            policy = policy.create_graph(subset=subset)
        elif isinstance(policy, StarPolicy):
            policy = policy.create_graph(reference=reference)
        else:
            policy = policy.create_graph()

        # TODO(michalk8): manager must be currently instantiated first, since `_create_problems` accesses the policy
        # when refactoring the callback, consider changing this
        self._problem_manager = ProblemManager(self, policy=policy)
        problems = self._create_problems(
            xy_callback=xy_callback,
            x_callback=x_callback,
            y_callback=y_callback,
            xy_callback_kwargs=xy_callback_kwargs,
            x_callback_kwargs=x_callback_kwargs,
            y_callback_kwargs=y_callback_kwargs,
            **kwargs,
        )
        self._problem_manager.add_problems(problems)

        # we assume that all subproblems are of the same kind, so only the first one is inspected;
        # note that the stage is only updated when at least one subproblem exists
        for p in self.problems.values():
            self._problem_kind = p._problem_kind
            self._stage = "prepared"
            break
        return self

    def solve(
        self,
        stage: Union[ProblemStage_t, Tuple[ProblemStage_t, ...]] = ("prepared", "solved"),
        **kwargs: Any,
    ) -> "BaseCompoundProblem[K, B]":
        """Solve the individual :term:`OT` subproblems.

        .. seealso::
            - See :doc:`../notebooks/examples/solvers/100_linear_problems_basic`
              for an introduction on how to solve linear problems.
            - See :doc:`../notebooks/examples/solvers/300_quad_problems_basic`
              for an introduction on how to solve quadratic problems.

        Parameters
        ----------
        stage
            Stage by which to filter the :attr:`problems` to be solved.
        kwargs
            Keyword arguments for the subproblems' :meth:`~moscot.base.problems.OTProblem.solve` method.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`solutions` - the :term:`OT` solutions for each subproblem.
        - :attr:`stage` - set to ``'solved'``.
        """
        if TYPE_CHECKING:
            assert isinstance(self._problem_manager, ProblemManager)
        problems = self._problem_manager.get_problems(stage=stage)

        logger.info(f"Solving `{len(problems)}` problems")
        for problem in problems.values():
            logger.info(f"Solving problem {problem}.")
            _ = problem.solve(**kwargs)

        self._stage = "solved"
        return self

    @attributedispatch(attr="_policy")
    def _apply(self, *_args: Any, **_kwargs: Any) -> ApplyOutput_t[K]:
        """Push or pull data through the subproblems; dispatched on the type of :attr:`_policy`."""
        raise NotImplementedError(type(self._policy))

    @_apply.register(DummyPolicy)
    @_apply.register(StarPolicy)
    def _(
        self,
        source: Optional[K] = None,
        target: Optional[K] = None,
        data: Optional[Union[str, ArrayLike]] = None,
        forward: bool = True,
        scale_by_marginals: bool = False,
        return_all: bool = False,
        **kwargs: Any,
    ) -> ApplyOutput_t[K]:
        """Apply every subproblem selected by ``source`` independently (``target`` is ignored).

        Returns a dictionary keyed by source if ``return_all = True``, otherwise the result
        of the last subproblem in the plan.
        """
        del target
        if TYPE_CHECKING:
            assert isinstance(self._policy, StarPolicy)

        res = {}
        # the plan is filtered by a list of sources
        source = source if isinstance(source, list) else [source]
        for src, tgt in self._policy.plan(
            explicit_steps=kwargs.pop("explicit_steps", None),
            filter=source,  # type: ignore [arg-type]
        ):
            problem = self.problems[src, tgt]
            fun = problem.push if forward else problem.pull
            res[src] = fun(data=data, scale_by_marginals=scale_by_marginals, **kwargs)

        if return_all:
            return res
        if not res:
            # previously this surfaced as an `UnboundLocalError` on `src`
            raise ValueError(f"No subproblems were found for source `{source!r}`.")
        return res[src]

    @_apply.register(ExplicitPolicy)
    @_apply.register(OrderedPolicy)
    def _(
        self,
        source: Optional[K] = None,
        target: Optional[K] = None,
        data: Optional[Union[str, ArrayLike]] = None,
        forward: bool = True,
        scale_by_marginals: bool = False,
        return_all: bool = False,
        **kwargs: Any,
    ) -> ApplyOutput_t[K]:
        """Chain the subproblems along the plan between ``source`` and ``target``.

        The output of each step is used as the input of the next one. Returns a dictionary
        with the initial mass and all intermediate results if ``return_all = True``,
        otherwise only the result of the last step.
        """
        # for an explicit policy, default to the single step `source -> target`
        explicit_steps = kwargs.pop(
            "explicit_steps", [[source, target]] if isinstance(self._policy, ExplicitPolicy) else None
        )
        if TYPE_CHECKING:
            assert isinstance(self._policy, OrderedPolicy)

        (src, tgt), *rest = self._policy.plan(
            forward=forward,
            start=source,
            end=target,
            explicit_steps=explicit_steps,
        )
        problem = self.problems[src, tgt]
        adata = problem.adata_src if forward else problem.adata_tgt
        current_mass = problem._get_mass(adata, data=data, **kwargs)
        # the initial mass is stored under the key where the first step starts
        res = {src if forward else tgt: current_mass}

        for _src, _tgt in [(src, tgt)] + rest:
            problem = self.problems[_src, _tgt]
            fun = problem.push if forward else problem.pull
            res[_tgt if forward else _src] = current_mass = fun(
                current_mass, scale_by_marginals=scale_by_marginals, **kwargs
            )

        return res if return_all else current_mass

    # TODO(michalk8): better description of `source/target` (also in other places).
    def push(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]:
        """Push mass from source to target.

        Parameters
        ----------
        args
            Positional arguments for the policy-specific implementation, in the order
            ``source``, ``target``, ``data``.
        kwargs
            Keyword arguments for the policy-specific implementation, such as
            ``scale_by_marginals`` or ``return_all``. The keys ``'return_data'`` and
            ``'key_added'`` are discarded here; they should be handled by an overriding method.

        Returns
        -------
        The pushed mass; a dictionary of results if ``return_all = True``.
        """
        _ = kwargs.pop("return_data", None)
        _ = kwargs.pop("key_added", None)  # this should be handled by overriding method
        return self._apply(*args, forward=True, **kwargs)

    def pull(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> ApplyOutput_t[K]:
        """Pull mass from target to source.

        Parameters
        ----------
        args
            Positional arguments for the policy-specific implementation, in the order
            ``source``, ``target``, ``data``.
        kwargs
            Keyword arguments for the policy-specific implementation, such as
            ``scale_by_marginals`` or ``return_all``. The keys ``'return_data'`` and
            ``'key_added'`` are discarded here; they should be handled by an overriding method.

        Returns
        -------
        The pulled mass; a dictionary of results if ``return_all = True``.
        """
        _ = kwargs.pop("return_data", None)
        _ = kwargs.pop("key_added", None)  # this should be handled by overriding method
        return self._apply(*args, forward=False, **kwargs)

    @property
    def problems(self) -> Dict[Tuple[K, K], B]:
        """:term:`OT` subproblems that define the biological problem."""
        if self._problem_manager is None:
            return {}
        return self._problem_manager.problems

    @require_prepare
    def add_problem(
        self,
        key: Tuple[K, K],
        problem: B,
        *,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> "BaseCompoundProblem[K, B]":
        """Add a subproblem.

        .. seealso::
            - See :doc:`../notebooks/examples/problems/300_adding_and_removing_problems`
              on how to add subproblems.

        Parameters
        ----------
        key
            Key in :attr:`problems` where to add the subproblem.
        problem
            Subproblem to add.
        overwrite
            Whether to overwrite an existing subproblem in :attr:`problems`.
        kwargs
            Additional keyword arguments.

        Returns
        -------
        Self and updates the following fields:

        - :attr:`problems`
        """
        if TYPE_CHECKING:
            assert isinstance(self._problem_manager, ProblemManager)
        self._problem_manager.add_problem(key, problem, overwrite=overwrite, **kwargs)
        return self

    @require_prepare
    def remove_problem(self, key: Tuple[K, K]) -> "BaseCompoundProblem[K, B]":
        """Remove a subproblem.

        .. seealso::
            - See :doc:`../notebooks/examples/problems/300_adding_and_removing_problems`
              on how to remove subproblems.

        Parameters
        ----------
        key
            Key of the subproblem to remove.

        Returns
        -------
        Self and updates the following fields:

        - :attr:`problems`
        """
        if TYPE_CHECKING:
            assert isinstance(self._problem_manager, ProblemManager)
        self._problem_manager.remove_problem(key)
        return self

    @property
    def solutions(self) -> Dict[Tuple[K, K], BaseSolverOutput]:
        """Solutions to the :attr:`problems`."""
        if self._problem_manager is None:
            return {}
        return self._problem_manager.solutions

    @property
    def adata(self) -> AnnData:
        """Annotated data object."""
        return self._adata

    @property
    def _policy(self) -> Optional[SubsetPolicy[K]]:
        # the policy is owned by the problem manager, which only exists after `prepare`
        if self._problem_manager is None:
            return None
        return self._problem_manager.policy

    def _ensure_valid_policy(self, policy: Policy_t) -> None:
        # an empty `_valid_policies` disables the validation
        if self._valid_policies and policy not in self._valid_policies:
            raise ValueError(f"Invalid policy `{policy!r}`. Valid policies are: `{self._valid_policies}`.")

    def __getitem__(self, item: Tuple[K, K]) -> B:
        return self.problems[item]

    def __contains__(self, key: Tuple[K, K]) -> bool:
        return key in self.problems

    def __len__(self) -> int:
        return len(self.problems)

    def __iter__(self) -> Iterator[Tuple[K, K]]:
        return iter(self.problems)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}{list(self.problems.keys())}"

    def __str__(self) -> str:
        return repr(self)
class CompoundProblem(BaseCompoundProblem[K, B], abc.ABC):
    """Base class for all biological problems.

    This class translates a biological problem to multiple :term:`OT` problems.

    Parameters
    ----------
    adata
        Annotated data object.
    kwargs
        Keyword arguments for :class:`~moscot.base.problems.BaseCompoundProblem`.
    """

    @property
    @abc.abstractmethod
    def _base_problem_type(self) -> Type[B]:
        """Class used to instantiate the individual subproblems."""

    def _create_problem(self, src: K, tgt: K, src_mask: ArrayLike, tgt_mask: ArrayLike, **kwargs: Any) -> B:
        """Instantiate a subproblem of type :attr:`_base_problem_type` on :attr:`adata`.

        Parameters
        ----------
        src
            Source key identifying the subproblem.
        tgt
            Target key identifying the subproblem.
        src_mask
            Source mask used to subset :attr:`adata`.
        tgt_mask
            Target mask used to subset :attr:`adata`.
        kwargs
            Keyword arguments for the subproblem's constructor.

        Returns
        -------
        The subproblem.
        """
        return self._base_problem_type(
            self.adata,
            src_obs_mask=src_mask,
            tgt_obs_mask=tgt_mask,
            src_key=src,
            tgt_key=tgt,
            **kwargs,
        )

    def _create_policy(
        self,
        policy: Policy_t,
        key: Optional[str] = None,
        **_: Any,
    ) -> SubsetPolicy[K]:
        """Create the policy used to split :attr:`adata`.

        Parameters
        ----------
        policy
            Name of the policy. Anything that is not a :class:`str` results in an
            :class:`~moscot.utils.subset_policy.ExplicitPolicy`.
        key
            Key passed to the policy.

        Returns
        -------
        The policy.
        """
        if isinstance(policy, str):
            return create_policy(policy, adata=self.adata, key=key)
        return ExplicitPolicy(self.adata, key=key)

    def _callback_handler(
        self,
        term: Literal["xy", "x", "y"],
        key_1: K,
        key_2: K,
        problem: B,
        *,
        callback: Optional[Union[Literal["local-pca", "cost-matrix"], Callback_t]] = None,
        **kwargs: Any,
    ) -> Optional[TaggedArray]:
        """Handle the ``'cost-matrix'`` callback and defer everything else to the parent class.

        Parameters
        ----------
        term
            Term for which the data is prepared.
        key_1
            Source key of the subproblem.
        key_2
            Target key of the subproblem.
        problem
            The subproblem.
        callback
            ``'cost-matrix'``, or anything accepted by the parent implementation.
        kwargs
            Keyword arguments for the callback.

        Returns
        -------
        The data produced by the callback, if any.
        """
        if callback == "cost-matrix":
            return self._cost_matrix_callback(term=term, key_1=key_1, key_2=key_2, **kwargs)
        return super()._callback_handler(
            term=term, key_1=key_1, key_2=key_2, problem=problem, callback=callback, **kwargs
        )

    def _cost_matrix_callback(
        self, term: Literal["xy", "x", "y"], *, key: str, key_1: K, key_2: Optional[K] = None, **_: Any
    ) -> Optional[TaggedArray]:
        """Extract a precomputed cost matrix from :attr:`adata.obsp <anndata.AnnData.obsp>`.

        Parameters
        ----------
        term
            Term for which the cost matrix is extracted.
        key
            Key in :attr:`~anndata.AnnData.obsp` where the cost matrix is stored.
        key_1
            Source key; selects the rows for ``'xy'`` and the rows/columns for ``'x'``.
        key_2
            Target key; selects the columns for ``'xy'`` and the rows/columns for ``'y'``.
            If ``None`` and ``term = 'y'``, ``key_1`` is used instead.

        Returns
        -------
        The dense cost matrix tagged as :attr:`~moscot.utils.tagged_array.Tag.COST_MATRIX`.

        Raises
        ------
        KeyError
            If ``key`` is not in :attr:`~anndata.AnnData.obsp`.
        ValueError
            If ``term = 'xy'`` and ``key_2`` is ``None``, or if ``term`` is invalid.
        """
        if TYPE_CHECKING:
            assert isinstance(self._policy, SubsetPolicy)

        try:
            data = self.adata.obsp[key]
        except KeyError:
            raise KeyError(f"Unable to fetch data from `adata.obsp[{key!r}]`.") from None

        if term == "xy":
            mask = self._policy.create_mask(key_1, allow_empty=False)
            if key_2 is None:
                raise ValueError("If `term` is `xy`, `key_2` cannot be `None`.")
            mask_2 = self._policy.create_mask(key_2, allow_empty=False)

            linear_cost_matrix = data[mask, :][:, mask_2]
            if sp.issparse(linear_cost_matrix):
                logger.warning("Linear cost matrix being densified.")
                # `.toarray()` instead of the deprecated `.A` attribute
                linear_cost_matrix = linear_cost_matrix.toarray()
            return TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)

        if term in ("x", "y"):
            # The target quadratic term must be restricted to the target cells. The caller
            # passes `key_1=source, key_2=target` for every term, so `'y'` uses `key_2`.
            quad_key = key_2 if (term == "y" and key_2 is not None) else key_1
            mask = self._policy.create_mask(quad_key, allow_empty=False)

            quad_cost_matrix = data[mask, :][:, mask]
            if sp.issparse(quad_cost_matrix):
                logger.warning("Quadratic cost matrix being densified.")
                quad_cost_matrix = quad_cost_matrix.toarray()
            return TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)

        raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")