from __future__ import annotations
import abc
import copy
import functools
from typing import Any, Callable, Iterable, Literal, Optional, Union
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import LinearOperator
from moscot._logging import logger
from moscot._types import ArrayLike, Device_t, DTypeLike
__all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput"]
def _mass_select_block(rows: np.ndarray, *, value: float, max_k: Optional[int]) -> sp.csr_matrix:
    """Sparsify a block of rows by the fraction of row mass that is retained.

    Within every row the entries are ranked from largest to smallest. The shortest leading run of that
    ranking whose cumulative sum reaches ``value`` times the row sum is kept (the entry that crosses the
    target is included), all remaining entries are set to :math:`0`.

    Parameters
    ----------
    rows
        Dense block of shape ``[b, m]``; each row is a row of the transport matrix.
    value
        Fraction of each row's mass to retain. Any ``value >= 1`` selects all entries of a row.
    max_k
        If not :obj:`None`, upper bound on the number of entries kept per row.

    Returns
    -------
    Sparsified block as a :class:`~scipy.sparse.csr_matrix` of shape ``[b, m]``.
    """
    n_rows, n_cols = rows.shape

    # column indices of every row, ordered so that the largest entry comes first
    ranking = np.flip(np.argsort(rows, axis=1), axis=1)
    ranked = np.take_along_axis(rows, ranking, axis=1)
    row_mass = ranked.sum(axis=1, keepdims=True)
    running = np.cumsum(ranked, axis=1)

    if value >= 1.0:
        # an unreachable target keeps the whole row, independent of rounding in the cumulative sum
        goal = np.full_like(row_mass, np.inf)
    else:
        goal = value * row_mass

    # entries strictly below the target, plus the one that crosses it
    n_keep = np.count_nonzero(running < goal, axis=1) + 1
    # rows without positive mass keep nothing
    n_keep = np.where(row_mass.ravel() <= 0.0, 0, n_keep)
    if max_k is not None:
        n_keep = np.minimum(n_keep, max_k)
    n_keep = np.minimum(n_keep, n_cols)

    # mask in ranked order, scattered back to the original column positions
    kept_ranked = np.arange(n_cols)[None, :] < n_keep[:, None]
    mask = np.zeros((n_rows, n_cols), dtype=bool)
    np.put_along_axis(mask, ranking, kept_ranked, axis=1)
    return sp.csr_matrix(np.where(mask, rows, 0.0))
def _sparsify_block(
    rows: np.ndarray, *, mode: str, thr: Optional[float], value: Optional[float], max_k: Optional[int]
) -> sp.csr_matrix:
    """Sparsify one block of transport-matrix rows.

    Parameters
    ----------
    rows
        Block of rows to sparsify; it is not modified.
    mode
        If ``'mass'``, the per-row mass criterion is used, any other mode zeroes entries below ``thr``.
    thr
        Threshold below which entries are set to :math:`0`. Not used if ``mode = 'mass'``.
    value
        Fraction of the row mass to keep. Only used if ``mode = 'mass'``.
    max_k
        Maximum number of entries kept per row. Only used if ``mode = 'mass'``.

    Returns
    -------
    The sparsified block as a :class:`~scipy.sparse.csr_matrix`.
    """
    if mode != "mass":
        # work on a copy so the caller's block stays untouched
        dense = np.array(rows)
        below = dense < thr
        dense[below] = 0.0
        return sp.csr_matrix(dense)

    assert value is not None  # validated in `sparsify`
    return _mass_select_block(rows, value=value, max_k=max_k)
class BaseSolverOutput(abc.ABC):
    """Abstract interface shared by all solver outputs."""

    @abc.abstractmethod
    def push(self, x: ArrayLike, **kwargs) -> ArrayLike:
        """Push ``x`` through the solution; implemented by subclasses."""

    @abc.abstractmethod
    def _apply_forward(self, x: ArrayLike) -> ArrayLike:
        """Apply the transport in the forward direction to ``x``."""

    @property
    @abc.abstractmethod
    def shape(self) -> tuple[int, int]:
        """Shape of the problem."""

    @property
    @abc.abstractmethod
    def converged(self) -> bool:
        """Whether the solver converged."""

    @abc.abstractmethod
    def to(self: BaseSolverOutput, device: Optional[Device_t] = None) -> BaseSolverOutput:
        """Transfer self to another compute device.

        Parameters
        ----------
        device
            Device where to transfer the solver output. If :obj:`None`, use the default device.

        Returns
        -------
        Self transferred to the ``device``.
        """

    def _format_params(self, fmt: Callable[[Any], str]) -> str:
        # comma-separated ``name=value`` pairs, every value rendered with ``fmt``
        entries = (("shape", self.shape),)
        pieces = [f"{key}={fmt(value)}" for key, value in entries]
        return ", ".join(pieces)

    def __repr__(self) -> str:
        described = self._format_params(repr)
        return f"{self.__class__.__name__}[{described}]"

    def __str__(self) -> str:
        described = self._format_params(str)
        return f"{self.__class__.__name__}[{described}]"
[docs]
class BaseDiscreteSolverOutput(BaseSolverOutput, abc.ABC):
    """Base class for all discrete solver outputs."""

    @abc.abstractmethod
    def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
        """Apply :attr:`transport_matrix` to an array of shape ``[n, d]`` or ``[m, d]``."""

    @property
    @abc.abstractmethod
    def transport_matrix(self) -> ArrayLike:
        """Transport matrix of shape ``[n, m]``."""

    @property
    @abc.abstractmethod
    def cost(self) -> float:
        """Regularized :term:`OT` cost."""

    @property
    @abc.abstractmethod
    def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]:
        """:term:`Dual potentials` :math:`f` and :math:`g`.

        Only valid for the :term:`Sinkhorn` algorithm.
        """

    @property
    @abc.abstractmethod
    def shape(self) -> tuple[int, int]:
        """Shape of the :attr:`transport_matrix`."""

    @property
    @abc.abstractmethod
    def is_linear(self) -> bool:
        """Whether the output is a solution to a :term:`linear problem`."""

    @property
    def rank(self) -> int:
        """Rank of the :attr:`transport_matrix`."""
        return -1

    @property
    def is_low_rank(self) -> bool:
        """Whether the :attr:`transport_matrix` is :term:`low-rank`."""
        return self.rank > -1

    @abc.abstractmethod
    def _ones(self, n: int) -> ArrayLike:
        """Generate vector of 1s of shape ``[n,]``."""

    def push(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike:
        """Push mass through the :attr:`transport_matrix`.

        It is equivalent to :math:`T^T x` but without instantiating the transport matrix :math:`T`, if possible.

        Parameters
        ----------
        x
            Array of shape ``[n,]`` or ``[n, d]`` to push.
        scale_by_marginals
            Whether to scale by the source marginals :attr:`a`.

        Returns
        -------
        Array of shape ``[m,]`` or ``[m, d]``, depending on the shape of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is not 1D or 2D, or if its first dimension does not match the number of source points.
        """
        if x.ndim not in (1, 2):
            raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
        if x.shape[0] != self.shape[0]:
            raise ValueError(f"Expected array to have shape `({self.shape[0]}, ...)`, found `{x.shape}`.")
        if scale_by_marginals:
            x = self._scale_by_marginals(x, forward=True)
        return self._apply(x, forward=True)

    def pull(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike:
        """Pull mass through the :attr:`transport_matrix`.

        It is equivalent to :math:`T x` but without instantiating the transport matrix :math:`T`, if possible.

        Parameters
        ----------
        x
            Array of shape ``[m,]`` or ``[m, d]`` to pull.
        scale_by_marginals
            Whether to scale by the target marginals :attr:`b`.

        Returns
        -------
        Array of shape ``[n,]`` or ``[n, d]``, depending on the shape of ``x``.

        Raises
        ------
        ValueError
            If ``x`` is not 1D or 2D, or if its first dimension does not match the number of target points.
        """
        if x.ndim not in (1, 2):
            raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
        if x.shape[0] != self.shape[1]:
            raise ValueError(f"Expected array to have shape `({self.shape[1]}, ...)`, found `{x.shape}`.")
        if scale_by_marginals:
            x = self._scale_by_marginals(x, forward=False)
        return self._apply(x, forward=False)

    def as_linear_operator(self, scale_by_marginals: bool = False) -> LinearOperator:
        """Transform :attr:`transport_matrix` into a linear operator.

        Parameters
        ----------
        scale_by_marginals
            Whether to scale by :term:`marginals`.

        Returns
        -------
        The :attr:`transport_matrix` as a linear operator.
        """
        push = functools.partial(self.push, scale_by_marginals=scale_by_marginals)
        pull = functools.partial(self.pull, scale_by_marginals=scale_by_marginals)
        # push: a @ X (rmatvec)
        # pull: X @ a (matvec)
        return LinearOperator(shape=self.shape, dtype=self.dtype, matvec=pull, rmatvec=push)

    def chain(self, outputs: Iterable[BaseDiscreteSolverOutput], scale_by_marginals: bool = False) -> LinearOperator:
        """Chain subsequent applications of :attr:`transport_matrix`.

        Parameters
        ----------
        outputs
            Sequence of transport matrices to chain.
        scale_by_marginals
            Whether to scale by :term:`marginals`.

        Returns
        -------
        The chained transport matrices as a linear operator.
        """
        op = self.as_linear_operator(scale_by_marginals)
        for out in outputs:
            op *= out.as_linear_operator(scale_by_marginals)
        return op

    def sparsify(
        self,
        mode: Literal["threshold", "percentile", "min_row", "mass"],
        value: Optional[float] = None,
        batch_size: int = 1024,
        n_samples: Optional[int] = None,
        seed: Optional[int] = None,
        max_k: Optional[int] = None,
    ) -> MatrixSolverOutput:
        """Sparsify the :attr:`transport_matrix`.

        This function sets entries of the transport matrix to :math:`0` according to ``mode`` and
        returns a :class:`~moscot.base.output.MatrixSolverOutput` with sparsified transport matrix stored
        as a :class:`~scipy.sparse.csr_matrix`. The transport matrix is materialized in row blocks of
        ``batch_size`` rows, so peak memory is bounded by ``batch_size``.

        .. warning::
            This function only serves for interfacing software which has to instantiate the transport matrix,
            :mod:`moscot` never uses the sparsified transport matrix.

        Parameters
        ----------
        mode
            How to determine the entries that are set to :math:`0`. Valid options are:

            - `'threshold'` - ``value`` is the threshold below which entries are set to :math:`0`.
            - `'percentile'` - ``value`` is the percentile in :math:`[0, 100]` of the :attr:`transport_matrix`
              below which entries are set to :math:`0`.
            - `'min_row'` - ``value`` is not used, it is chosen such that each row has at least 1 non-zero entry.
              The threshold is the smallest row-wise maximum and does not depend on ``batch_size``.
            - `'mass'` - per row, keep the largest entries capturing a fraction ``value`` of the row's mass
              (at most ``max_k`` entries per row); ``value`` must be in :math:`(0, 1]`.
        value
            Value to use for sparsification. Its meaning depends on ``mode`` (see above).
        batch_size
            How many rows to materialize at a time when sparsifying the :attr:`transport_matrix`.
            Must be positive.
        n_samples
            If ``mode = 'percentile'``, determine the number of samples based on which the percentile is computed
            stochastically. Note this means that a matrix of shape `[n_samples, min(transport_matrix.shape)]`
            has to be instantiated. If `None`, ``n_samples`` is set to ``batch_size``.
        seed
            Random seed needed for sampling if ``mode = 'percentile'``.
        max_k
            Maximum number of entries to keep per row. Only valid when ``mode = 'mass'``.

        Returns
        -------
        Output with sparsified transport matrix.

        Raises
        ------
        ValueError
            If ``batch_size`` is not positive, if ``value`` is missing or invalid for the chosen ``mode``,
            if ``max_k`` is not positive, or if ``max_k`` is passed with a ``mode`` other than ``'mass'``.
        NotImplementedError
            If ``mode`` is not one of the supported options.
        """
        n, m = self.shape

        # A non-positive step would make the batch loops below empty (or fail inside `range`),
        # silently producing an all-zero matrix.
        if batch_size < 1:
            raise ValueError(f"`batch_size` must be a positive integer, found `{batch_size}`.")

        # Validate every argument up front: the threshold computation below already has to
        # materialize (parts of) the transport matrix and should not run for an invalid call.
        if mode == "mass":
            if value is None or not 0.0 < value <= 1.0:
                raise ValueError("If `mode = 'mass'`, `value` must be in `(0, 1]`.")
            if max_k is not None and max_k <= 0:
                raise ValueError(f"`max_k` must be a positive integer, found `{max_k}`.")
        elif mode in ("threshold", "percentile"):
            if value is None:
                raise ValueError(f"If `mode = '{mode}'`, `value` cannot be `None`.")
        elif mode != "min_row":
            raise NotImplementedError(f"Mode `{mode}` is not yet implemented.")
        if mode != "mass" and max_k is not None:
            raise ValueError("`max_k` is only supported with `mode = 'mass'`.")

        thr: Optional[float] = None
        if mode == "threshold":
            thr = value
        elif mode == "percentile":
            rng = np.random.RandomState(seed=seed)
            n_samples = n_samples if n_samples is not None else batch_size
            k = min(n_samples, n)
            # one indicator vector per sampled column of the transport matrix
            x = np.zeros((m, k))
            sampled = rng.choice(m, size=k)
            x[sampled, np.arange(k)] = 1.0
            res = self.pull(x, scale_by_marginals=False)  # tmap @ indicator_vectors
            thr = np.percentile(res, value)
        elif mode == "min_row":
            # The threshold is the smallest row-wise maximum. Columns arrive in batches, so the
            # maximum of each row has to be accumulated over all batches before taking the minimum;
            # reducing batch by batch would make the threshold depend on `batch_size`.
            row_max = np.full((n,), -np.inf)
            for start in range(0, m, batch_size):
                x = np.eye(m, min(batch_size, m - start), -start)
                res = np.asarray(self.pull(x, scale_by_marginals=False))  # tmap @ indicator_vectors
                row_max = np.maximum(row_max, res.max(axis=1))
            thr = float(row_max.min()) if n > 0 and m > 0 else np.inf

        # Always iterate over source rows so that each block holds rows of the transport matrix.
        # This keeps the per-row criteria (e.g. `mass`) well-defined and peak memory at `[batch_size, m]`.
        tmaps_sparse: list[sp.csr_matrix] = []
        for batch in range(0, n, batch_size):
            cols = min(batch_size, n - batch)
            x = np.eye(n, cols, -batch, dtype=float)
            rows = np.asarray(self.push(x, scale_by_marginals=False)).T  # [cols, m] = rows of `T`
            tmaps_sparse.append(_sparsify_block(rows, mode=mode, thr=thr, value=value, max_k=max_k))

        transport_matrix = sp.vstack(tmaps_sparse).tocsr() if tmaps_sparse else sp.csr_matrix((n, m))
        return MatrixSolverOutput(
            transport_matrix=transport_matrix,
            cost=self.cost,
            converged=self.converged,
            is_linear=self.is_linear,
        )

    @property
    def a(self) -> ArrayLike:
        """:term:`Marginals` of the source distribution.

        If the output of an :term:`unbalanced OT problem`, these are the posterior marginals.
        """
        return self.pull(self._ones(self.shape[1]))

    @property
    def b(self) -> ArrayLike:
        """:term:`Marginals` of the target distribution.

        If the output of an :term:`unbalanced OT problem`, these are the posterior marginals.
        """
        return self.push(self._ones(self.shape[0]))

    @property
    def dtype(self) -> DTypeLike:
        """Underlying data type."""
        return self.a.dtype

    def _format_params(self, fmt: Callable[[Any], str]) -> str:
        # comma-separated ``name=value`` pairs used by ``__repr__``/``__str__``
        params = {"shape": self.shape, "cost": round(self.cost, 4), "converged": self.converged}
        return ", ".join(f"{name}={fmt(val)}" for name, val in params.items())

    def _scale_by_marginals(self, x: ArrayLike, *, forward: bool, eps: float = 1e-12) -> ArrayLike:
        # alt. we could use the public push/pull
        marginals = self.a if forward else self.b
        if x.ndim == 2:
            marginals = marginals[:, None]
        # `eps` guards against division by zero for points without any mass
        return x / (marginals + eps)

    def __bool__(self) -> bool:
        return self.converged
[docs]
class MatrixSolverOutput(BaseDiscreteSolverOutput):
    """:term:`OT` solution with a materialized transport matrix.

    Parameters
    ----------
    transport_matrix
        Transport matrix of shape ``[n, m]``.
    cost
        Cost of an :term:`OT` problem.
    converged
        Whether the solution converged.
    is_linear
        Whether this is a solution to a :term:`linear problem`.
    """

    # TODO(michalk8): don't provide defaults?
    def __init__(
        self,
        transport_matrix: Union[ArrayLike, sp.spmatrix],
        *,
        cost: float = np.nan,
        converged: bool = True,
        is_linear: bool = True,
    ):
        super().__init__()
        self._transport_matrix = transport_matrix
        self._cost = cost
        self._converged = converged
        self._is_linear = is_linear

    def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike:
        """Multiply ``x`` by the transposed (``forward``) or plain transport matrix."""
        if forward:
            return self.transport_matrix.T @ x
        return self.transport_matrix @ x

    def _apply_forward(self, x: ArrayLike) -> ArrayLike:
        """Apply the transport matrix in the forward direction."""
        return self._apply(x, forward=True)

    @property
    def transport_matrix(self) -> ArrayLike:
        """Transport matrix of shape ``[n, m]``, as passed to the constructor."""
        return self._transport_matrix

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the :attr:`transport_matrix`."""
        return self.transport_matrix.shape

    def to(
        self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None
    ) -> BaseDiscreteSolverOutput:
        """Cast the :attr:`transport_matrix` to another data type.

        Parameters
        ----------
        device
            Not supported; a warning is logged and the argument is ignored if it is not :obj:`None`.
        dtype
            Data type to cast the :attr:`transport_matrix` to. If :obj:`None`, nothing is cast.

        Returns
        -------
        Self if ``dtype`` is :obj:`None`, otherwise a shallow copy holding the cast transport matrix.
        """
        if device is not None:
            logger.warning(f"`{self!r}` does not support the `device` argument, ignoring.")
        if dtype is None:
            return self

        obj = copy.copy(self)
        obj._transport_matrix = obj.transport_matrix.astype(dtype)
        return obj

    @property
    def cost(self) -> float:
        """Cost of the :term:`OT` problem, as passed to the constructor."""
        return self._cost

    @property
    def converged(self) -> bool:
        """Whether the solution converged, as passed to the constructor."""
        return self._converged

    @property
    def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]:
        """Always :obj:`None`, a materialized transport matrix carries no potentials."""
        return None

    @property
    def is_linear(self) -> bool:
        """Whether this is a solution to a :term:`linear problem`."""
        return self._is_linear

    def _ones(self, n: int) -> ArrayLike:
        """Generate vector of 1s of shape ``[n,]`` with the data type of the :attr:`transport_matrix`."""
        tmat = self.transport_matrix
        # SciPy sparse matrices (e.g. the result of `sparsify`) are multiplied with NumPy arrays,
        # so they must not fall through to the JAX branch, which would require `jax` to be installed.
        if isinstance(tmat, np.ndarray) or sp.issparse(tmat):
            return np.ones((n,), dtype=tmat.dtype)

        import jax.numpy as jnp

        return jnp.ones((n,), dtype=tmat.dtype)