from __future__ import annotations
import abc
import copy
import functools
from abc import abstractmethod
from typing import Any, Callable, Iterable, Literal, Optional, Union
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import LinearOperator
from moscot._logging import logger
from moscot._types import ArrayLike, Device_t, DTypeLike
__all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput", "BaseNeuralOutput"]
class BaseSolverOutput(abc.ABC):
    """Abstract interface shared by every solver output."""

    @abc.abstractmethod
    def push(self, x: ArrayLike, **kwargs) -> ArrayLike:
        """Push the solution based on a condition."""

    @abc.abstractmethod
    def _apply_forward(self, x: ArrayLike) -> ArrayLike:
        """Apply the transport matrix in the forward direction."""

    @property
    @abc.abstractmethod
    def shape(self) -> tuple[int, int]:
        """Shape of the problem."""

    @property
    @abc.abstractmethod
    def converged(self) -> bool:
        """Whether the solver converged."""

    @abc.abstractmethod
    def to(self: BaseSolverOutput, device: Optional[Device_t] = None) -> BaseSolverOutput:
        """Transfer self to another compute device.

        Parameters
        ----------
        device
            Device where to transfer the solver output. If :obj:`None`, use the default device.

        Returns
        -------
        Self transferred to the ``device``.
        """

    def _format_params(self, fmt: Callable[[Any], str]) -> str:
        # Render every ``name=value`` pair with the supplied formatter
        # (``repr`` or ``str``) and join them into one comma-separated string.
        fields = {"shape": self.shape}
        rendered = [f"{key}={fmt(value)}" for key, value in fields.items()]
        return ", ".join(rendered)

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}[{self._format_params(repr)}]"

    def __str__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}[{self._format_params(str)}]"
class BaseDiscreteSolverOutput(BaseSolverOutput, abc.ABC):
    """Base class for all discrete solver outputs."""

    @abc.abstractmethod
    def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
        """Apply :attr:`transport_matrix` to an array of shape ``[n, d]`` or ``[m, d]``."""

    @property
    @abc.abstractmethod
    def transport_matrix(self) -> ArrayLike:
        """Transport matrix of shape ``[n, m]``."""

    @property
    @abc.abstractmethod
    def cost(self) -> float:
        """Regularized :term:`OT` cost."""

    @property
    @abc.abstractmethod
    def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]:
        """:term:`Dual potentials` :math:`f` and :math:`g`.

        Only valid for the :term:`Sinkhorn` algorithm.
        """

    @property
    @abc.abstractmethod
    def shape(self) -> tuple[int, int]:
        """Shape of the :attr:`transport_matrix`."""

    @property
    @abc.abstractmethod
    def is_linear(self) -> bool:
        """Whether the output is a solution to a :term:`linear problem`."""

    @property
    def rank(self) -> int:
        """Rank of the :attr:`transport_matrix`.

        The default of ``-1`` marks the output as not :term:`low-rank`, see :attr:`is_low_rank`.
        """
        return -1

    @property
    def is_low_rank(self) -> bool:
        """Whether the :attr:`transport_matrix` is :term:`low-rank`."""
        return self.rank > -1

    @abc.abstractmethod
    def _ones(self, n: int) -> ArrayLike:
        """Generate vector of 1s of shape ``[n,]``."""

    def push(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike:
        """Push mass through the :attr:`transport_matrix`.

        It is equivalent to :math:`T^T x` but without instantiating the transport matrix :math:`T`, if possible.

        Parameters
        ----------
        x
            Array of shape ``[n,]`` or ``[n, d]`` to push.
        scale_by_marginals
            Whether to scale by the source marginals :attr:`a`.

        Returns
        -------
        Array of shape ``[m,]`` or ``[m, d]``, depending on the shape of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is not 1D/2D or its first dimension differs from ``n``.
        """
        if x.ndim not in (1, 2):
            raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
        if x.shape[0] != self.shape[0]:
            raise ValueError(f"Expected array to have shape `({self.shape[0]}, ...)`, found `{x.shape}`.")
        if scale_by_marginals:
            x = self._scale_by_marginals(x, forward=True)
        return self._apply(x, forward=True)

    def pull(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike:
        """Pull mass through the :attr:`transport_matrix`.

        It is equivalent to :math:`T x` but without instantiating the transport matrix :math:`T`, if possible.

        Parameters
        ----------
        x
            Array of shape ``[m,]`` or ``[m, d]`` to pull.
        scale_by_marginals
            Whether to scale by the target marginals :attr:`b`.

        Returns
        -------
        Array of shape ``[n,]`` or ``[n, d]``, depending on the shape of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is not 1D/2D or its first dimension differs from ``m``.
        """
        if x.ndim not in (1, 2):
            raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
        if x.shape[0] != self.shape[1]:
            raise ValueError(f"Expected array to have shape `({self.shape[1]}, ...)`, found `{x.shape}`.")
        if scale_by_marginals:
            x = self._scale_by_marginals(x, forward=False)
        return self._apply(x, forward=False)

    def as_linear_operator(self, scale_by_marginals: bool = False) -> LinearOperator:
        """Transform :attr:`transport_matrix` into a linear operator.

        Parameters
        ----------
        scale_by_marginals
            Whether to scale by :term:`marginals`.

        Returns
        -------
        The :attr:`transport_matrix` as a linear operator.
        """
        push = functools.partial(self.push, scale_by_marginals=scale_by_marginals)
        pull = functools.partial(self.pull, scale_by_marginals=scale_by_marginals)
        # push: a @ X (rmatvec)
        # pull: X @ a (matvec)
        return LinearOperator(shape=self.shape, dtype=self.dtype, matvec=pull, rmatvec=push)

    def chain(self, outputs: Iterable[BaseDiscreteSolverOutput], scale_by_marginals: bool = False) -> LinearOperator:
        """Chain subsequent applications of :attr:`transport_matrix`.

        Parameters
        ----------
        outputs
            Sequence of transport matrices to chain.
        scale_by_marginals
            Whether to scale by :term:`marginals`.

        Returns
        -------
        The chained transport matrices as a linear operator.
        """
        op = self.as_linear_operator(scale_by_marginals)
        for out in outputs:
            op *= out.as_linear_operator(scale_by_marginals)
        return op

    def sparsify(
        self,
        mode: Literal["threshold", "percentile", "min_row"],
        value: Optional[float] = None,
        batch_size: int = 1024,
        n_samples: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> MatrixSolverOutput:
        """Sparsify the :attr:`transport_matrix`.

        This function sets all entries of the transport matrix below a certain threshold to :math:`0` and
        returns a :class:`~moscot.base.output.MatrixSolverOutput` with sparsified transport matrix stored
        as a :class:`~scipy.sparse.csr_matrix`.

        .. warning::
            This function only serves for interfacing software which has to instantiate the transport matrix,
            :mod:`moscot` never uses the sparsified transport matrix.

        Parameters
        ----------
        mode
            How to determine the value below which entries are set to :math:`0`. Valid options are:

            - `'threshold'` - ``value`` is the threshold below which entries are set to :math:`0`.
            - `'percentile'` - ``value`` is the percentile in :math:`[0, 100]` of the :attr:`transport_matrix`.
              below which entries are set to :math:`0`.
            - `'min_row'` - ``value`` is not used, it is chosen such that each row has at least 1 non-zero entry.
        value
            Value to use for sparsification.
        batch_size
            How many rows to materialize when sparsifying the :attr:`transport_matrix`. Must be positive.
        n_samples
            If ``mode = 'percentile'``, determine the number of samples based on which the percentile is computed
            stochastically. Note this means that a matrix of shape ``[n, min(n_samples, n)]``
            has to be instantiated. If `None`, ``n_samples`` is set to ``batch_size``.
        seed
            Random seed needed for sampling if ``mode = 'percentile'``.

        Returns
        -------
        Output with sparsified transport matrix.

        Raises
        ------
        ValueError
            If ``batch_size`` is not positive, or if ``value`` is `None` for the
            `'threshold'` and `'percentile'` modes.
        NotImplementedError
            If ``mode`` is not one of the valid options.
        """
        if batch_size <= 0:
            # `range` with a zero step or an empty batch loop would otherwise fail with an opaque error
            raise ValueError(f"Expected `batch_size` to be positive, found `{batch_size}`.")

        n, m = self.shape
        if mode == "threshold":
            if value is None:
                raise ValueError("If `mode = 'threshold'`, `value` cannot be `None`.")
            thr = value
        elif mode == "percentile":
            if value is None:
                raise ValueError("If `mode = 'percentile'`, `value` cannot be `None`.")
            rng = np.random.RandomState(seed=seed)
            n_samples = n_samples if n_samples is not None else batch_size
            k = min(n_samples, n)
            # one-hot columns selecting `k` randomly chosen (with replacement) columns of the transport matrix
            x = np.zeros((m, k))
            rows = rng.choice(m, size=k)
            x[rows, np.arange(k)] = 1.0
            res = self.pull(x, scale_by_marginals=False)  # tmap @ indicator_vectors
            thr = np.percentile(res, value)
        elif mode == "min_row":
            # Largest threshold that keeps >= 1 entry per row is `min_i max_j T[i, j]`.
            # The row maximum must be accumulated over *all* column batches before taking
            # the minimum, otherwise the threshold depends on `batch_size`.
            row_max = None
            for start in range(0, m, batch_size):
                x = np.eye(m, min(batch_size, m - start), -start)
                res = np.asarray(self.pull(x, scale_by_marginals=False))  # tmap @ indicator_vectors
                batch_max = res.max(axis=1)
                row_max = batch_max if row_max is None else np.maximum(row_max, batch_max)
            thr = np.inf if row_max is None else float(row_max.min())
        else:
            raise NotImplementedError(f"Mode `{mode}` is not yet implemented.")

        # materialize along the smaller dimension: rows via `push` (stacked vertically)
        # or columns via `pull` (stacked horizontally)
        if n < m:
            k, func, fn_stack = n, self.push, sp.vstack
        else:
            k, func, fn_stack = m, self.pull, sp.hstack

        tmaps_sparse: list[sp.csr_matrix] = []
        for start in range(0, k, batch_size):
            x = np.eye(k, min(batch_size, k - start), -start, dtype=float)
            res = np.array(func(x, scale_by_marginals=False))
            res[res < thr] = 0.0
            # `push` yields `[m, batch]`, i.e. transposed rows of the transport matrix
            tmaps_sparse.append(sp.csr_matrix(res.T if n < m else res))

        return MatrixSolverOutput(
            # request CSR explicitly so that the documented storage format holds for both stackers
            transport_matrix=fn_stack(tmaps_sparse, format="csr"),
            cost=self.cost,
            converged=self.converged,
            is_linear=self.is_linear,
        )

    @property
    def a(self) -> ArrayLike:
        """:term:`Marginals` of the source distribution.

        If the output of an :term:`unbalanced OT problem`, these are the posterior marginals.
        """
        return self.pull(self._ones(self.shape[1]))

    @property
    def b(self) -> ArrayLike:
        """:term:`Marginals` of the target distribution.

        If the output of an :term:`unbalanced OT problem`, these are the posterior marginals.
        """
        return self.push(self._ones(self.shape[0]))

    @property
    def dtype(self) -> DTypeLike:
        """Underlying data type."""
        # derived from the source marginals, i.e. this evaluates :attr:`a`
        return self.a.dtype

    def _format_params(self, fmt: Callable[[Any], str]) -> str:
        params = {"shape": self.shape, "cost": round(self.cost, 4), "converged": self.converged}
        return ", ".join(f"{name}={fmt(val)}" for name, val in params.items())

    def _scale_by_marginals(self, x: ArrayLike, *, forward: bool, eps: float = 1e-12) -> ArrayLike:
        # alt. we could use the public push/pull
        marginals = self.a if forward else self.b
        if x.ndim == 2:
            # broadcast the marginals over the trailing dimension of `x`
            marginals = marginals[:, None]
        # `eps` guards against division by zero marginals
        return x / (marginals + eps)

    def __bool__(self) -> bool:
        return self.converged
class MatrixSolverOutput(BaseDiscreteSolverOutput):
    """:term:`OT` solution with a materialized transport matrix.

    Parameters
    ----------
    transport_matrix
        Transport matrix of shape ``[n, m]``.
    cost
        Cost of an :term:`OT` problem.
    converged
        Whether the solution converged.
    is_linear
        Whether this is a solution to a :term:`linear problem`.
    """

    # TODO(michalk8): don't provide defaults?
    def __init__(
        self,
        transport_matrix: Union[ArrayLike, sp.spmatrix],
        *,
        cost: float = np.nan,
        converged: bool = True,
        is_linear: bool = True,
    ):
        super().__init__()
        self._transport_matrix = transport_matrix
        self._cost = cost
        self._converged = converged
        self._is_linear = is_linear

    def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
        # forward (push) multiplies by the transposed matrix, backward (pull) by the matrix itself
        if forward:
            return self.transport_matrix.T @ x
        return self.transport_matrix @ x

    def _apply_forward(self, x: ArrayLike) -> ArrayLike:
        return self._apply(x, forward=True)

    @property
    def transport_matrix(self) -> Union[ArrayLike, sp.spmatrix]:  # noqa: D102
        return self._transport_matrix

    @property
    def shape(self) -> tuple[int, int]:  # noqa: D102
        return self.transport_matrix.shape

    def to(  # noqa: D102
        self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None
    ) -> BaseDiscreteSolverOutput:
        if device is not None:
            logger.warning(f"`{self!r}` does not support the `device` argument, ignoring.")
        if dtype is None:
            return self

        # shallow copy so that the original output keeps its transport matrix untouched
        obj = copy.copy(self)
        obj._transport_matrix = obj.transport_matrix.astype(dtype)
        return obj

    @property
    def cost(self) -> float:  # noqa: D102
        return self._cost

    @property
    def converged(self) -> bool:  # noqa: D102
        return self._converged

    @property
    def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]:  # noqa: D102
        return None

    @property
    def is_linear(self) -> bool:  # noqa: D102
        return self._is_linear

    def _ones(self, n: int) -> ArrayLike:
        tmat = self.transport_matrix
        # NumPy arrays and SciPy sparse matrices both live on the host, so use NumPy
        # ones for them; this avoids importing JAX just to multiply a sparse matrix
        if isinstance(tmat, np.ndarray) or sp.issparse(tmat):
            return np.ones((n,), dtype=tmat.dtype)

        # anything else is treated as a JAX array; imported lazily
        import jax.numpy as jnp

        return jnp.ones((n,), dtype=tmat.dtype)
class BaseNeuralOutput(BaseSolverOutput, abc.ABC):
    """Base class for neural outputs."""

    @abc.abstractmethod
    def project_to_transport_matrix(
        self,
        source: Optional[ArrayLike] = None,
        target: Optional[ArrayLike] = None,
        condition: Optional[ArrayLike] = None,
        save_transport_matrix: bool = False,
        batch_size: int = 1024,
        k: int = 30,
        length_scale: Optional[float] = None,
        seed: int = 42,
    ) -> sp.csr_matrix:
        """Project transport matrix.

        Abstract hook that subclasses must implement; annotated to return a
        :class:`~scipy.sparse.csr_matrix`.

        NOTE(review): the meaning of the individual parameters is not defined
        here - confirm against the concrete implementations.
        """