Source code for moscot.backends.ott.solver

import abc
import inspect
import types
from typing import Any, Literal, Mapping, Optional, Set, Tuple, Union

import jax
import jax.numpy as jnp
from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
from moscot.backends.ott._utils import (
    _instantiate_geodesic_cost,
    alpha_to_fused_penalty,
    check_shapes,
    ensure_2d,
)
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
from moscot.base.problems._utils import TimeScalesHeatKernel
from moscot.base.solver import OTSolver
from moscot.costs import get_cost
from moscot.utils.tagged_array import TaggedArray

# Names exported by this module.
__all__ = ["SinkhornSolver", "GWSolver"]

# Any of the :mod:`ott` solver classes that can be stored in ``OTTJaxSolver._solver``.
OTTSolver_t = Union[
    sinkhorn.Sinkhorn,
    sinkhorn_lr.LRSinkhorn,
    gromov_wasserstein.GromovWasserstein,
    gromov_wasserstein_lr.LRGromovWasserstein,
]
# Either a linear or a quadratic :mod:`ott` problem, as stored in ``OTTJaxSolver._problem``.
OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem]
# Annotation of the ``scale_cost`` arguments: a float or one of the listed string literals.
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]


class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
    """Common base of the solvers backed by :mod:`ott` :cite:`cuturi2022optimal`.

    Parameters
    ----------
    jit
        Whether the :attr:`solver` is wrapped in :func:`~jax.jit` before it is called.
    """

    def __init__(self, jit: bool = True):
        super().__init__()
        self._jit = jit
        # All of the following start out empty; nothing in this class assigns them again.
        self._solver: Optional[OTTSolver_t] = None
        self._problem: Optional[OTTProblem_t] = None
        self._a: Optional[jnp.ndarray] = None
        self._b: Optional[jnp.ndarray] = None

    def _create_geometry(
        self,
        x: TaggedArray,
        *,
        is_linear_term: bool,
        epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
        relative_epsilon: Optional[bool] = None,
        scale_cost: Scale_t = 1.0,
        batch_size: Optional[int] = None,
        problem_shape: Optional[Tuple[int, int]] = None,
        t: Optional[float] = None,
        directed: bool = True,
        **kwargs: Any,
    ) -> geometry.Geometry:
        """Build the :mod:`ott` geometry described by a tagged array.

        Parameters
        ----------
        x
            Tagged array. Depending on its tag, a point cloud, a cost-matrix geometry,
            a kernel-matrix geometry or a graph geometry is created.
        is_linear_term
            Only forwarded to :meth:`_create_graph_geometry`.
        epsilon
            Forwarded to the geometry constructor (or to :meth:`_create_graph_geometry`).
        relative_epsilon
            Forwarded to the geometry constructor (or to :meth:`_create_graph_geometry`).
        scale_cost
            Forwarded to the geometry constructor (or to :meth:`_create_graph_geometry`).
        batch_size
            Forwarded to :class:`~ott.geometry.pointcloud.PointCloud`; unused for other tags.
        problem_shape
            Only forwarded to :meth:`_create_graph_geometry`.
        t
            Only forwarded to :meth:`_create_graph_geometry`.
        directed
            Only forwarded to :meth:`_create_graph_geometry`.
        kwargs
            For point clouds whose cost is given as a string, keyword arguments for
            :func:`~moscot.costs.get_cost`. For graphs, forwarded to
            :meth:`_create_graph_geometry`. Unused otherwise.

        Returns
        -------
        The geometry.

        Raises
        ------
        TypeError
            If the cost of a point cloud does not resolve to a :class:`~ott.geometry.costs.CostFn`.
        ValueError
            If the source and the target point clouds differ in their second dimension.
        NotImplementedError
            If the tag of ``x`` is none of the handled ones.
        """
        if not x.is_point_cloud:
            # The array is converted before the remaining tags are inspected.
            arr = ensure_2d(x.data_src, reshape=False)
            shared = {"epsilon": epsilon, "relative_epsilon": relative_epsilon, "scale_cost": scale_cost}
            if x.is_cost_matrix:
                return geometry.Geometry(cost_matrix=arr, **shared)
            if x.is_kernel:
                return geometry.Geometry(kernel_matrix=arr, **shared)
            if x.is_graph:
                return self._create_graph_geometry(
                    is_linear_term=is_linear_term,
                    x=x,
                    arr=arr,
                    problem_shape=problem_shape,
                    t=t,
                    directed=directed,
                    **shared,
                    **kwargs,
                )
            raise NotImplementedError(f"Creating geometry from `tag={x.tag!r}` is not yet implemented.")

        # Resolve the cost: a name is looked up, a missing cost falls back to the squared Euclidean one.
        cost_fn = x.cost
        if isinstance(cost_fn, str):
            cost_fn = get_cost(cost_fn, backend="ott", **kwargs)
        elif cost_fn is None:
            cost_fn = costs.SqEuclidean()
        if not isinstance(cost_fn, costs.CostFn):
            raise TypeError(f"Expected `cost_fn` to be `ott.geometry.costs.CostFn`, found `{type(cost_fn)}`.")

        # The target is converted first, then the source.
        tgt = x.data_tgt
        if tgt is not None:
            tgt = ensure_2d(tgt, reshape=True)
        src = ensure_2d(x.data_src, reshape=True)
        if tgt is not None and src.shape[1] != tgt.shape[1]:
            raise ValueError(
                f"Expected `x/y` to have the same number of dimensions, found `{src.shape[1]}/{tgt.shape[1]}`."
            )

        return pointcloud.PointCloud(
            src,
            y=tgt,
            cost_fn=cost_fn,
            epsilon=epsilon,
            relative_epsilon=relative_epsilon,
            scale_cost=scale_cost,
            batch_size=batch_size,
        )

    def _solve(  # type: ignore[override]
        self,
        prob: OTTProblem_t,
        **kwargs: Any,
    ) -> Union[OTTOutput, GraphOTTOutput]:
        """Run the :attr:`solver` on a problem and wrap what it returns.

        Parameters
        ----------
        prob
            Problem to solve.
        kwargs
            Keyword arguments for the call of the :attr:`solver`.

        Returns
        -------
        A :class:`~moscot.backends.ott.output.GraphOTTOutput` if ``prob`` is a linear problem
        defined on a :class:`~ott.geometry.geodesic.Geodesic` geometry, otherwise
        an :class:`~moscot.backends.ott.output.OTTOutput`.
        """
        run = self.solver
        if self._jit:
            run = jax.jit(run)
        result = run(prob, **kwargs)

        on_graph = isinstance(prob, linear_problem.LinearProblem) and isinstance(prob.geom, geodesic.Geodesic)
        if not on_graph:
            return OTTOutput(result)
        # The shape is taken from the marginals stored on the solver, not from the problem.
        return GraphOTTOutput(result, shape=(len(self._a), len(self._b)))  # type: ignore[arg-type]

    def _create_graph_geometry(
        self,
        is_linear_term: bool,
        x: TaggedArray,
        arr: jax.Array,
        problem_shape: Optional[Tuple[int, int]],
        t: Optional[float],
        epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
        relative_epsilon: Optional[bool] = None,
        scale_cost: Scale_t = 1.0,
        directed: bool = True,
        **kwargs: Any,
    ) -> geometry.Geometry:
        """Build a geometry from a graph; only the ``'geodesic'`` cost is handled.

        Parameters
        ----------
        is_linear_term
            Forwarded as is for quadratic problems; linear problems always pass :obj:`True`.
        x
            Tagged array of the graph; its ``cost`` is checked and its ``shape`` is the
            fallback for ``problem_shape`` in quadratic problems.
        arr
            The graph as an array.
        problem_shape
            Shape of the problem. For quadratic problems, :obj:`None` is replaced by ``x.shape``.
        t
            If :obj:`None` in a linear problem, ``epsilon / 4.0`` is passed as ``t`` to
            :meth:`~ott.geometry.geodesic.Geodesic.from_graph`. Otherwise forwarded as is.
        epsilon
            See ``t``; otherwise forwarded as is.
        relative_epsilon
            Forwarded as is.
        scale_cost
            Forwarded as is.
        directed
            Forwarded as is.
        kwargs
            Additional keyword arguments for the geometry construction.

        Returns
        -------
        The geometry.

        Raises
        ------
        ValueError
            If both ``t`` and ``epsilon`` are :obj:`None` in a linear problem.
        NotImplementedError
            If the cost is not ``'geodesic'`` or if the problem kind is neither
            ``'linear'`` nor ``'quadratic'``.
        """
        if x.cost != "geodesic":
            raise NotImplementedError(f"If the geometry is a graph, `cost` must be `geodesic`, found `{x.cost}`.")

        kind = self.problem_kind
        if kind == "linear":
            if t is None:
                if epsilon is None:
                    raise ValueError("`epsilon` cannot be `None`.")
                return geodesic.Geodesic.from_graph(arr, t=epsilon / 4.0, directed=directed, **kwargs)
            linear_flag = True
        elif kind == "quadratic":
            if problem_shape is None:
                problem_shape = x.shape
            linear_flag = is_linear_term
        else:
            raise NotImplementedError(f"Invalid problem kind `{kind}`.")

        return _instantiate_geodesic_cost(
            arr=arr,
            problem_shape=problem_shape,  # type: ignore[arg-type]
            t=t,
            is_linear_term=linear_flag,
            epsilon=epsilon,
            relative_epsilon=relative_epsilon,
            scale_cost=scale_cost,
            directed=directed,
            **kwargs,
        )

    @property
    def solver(self) -> OTTSolver_t:
        """The underlying :mod:`ott` solver."""
        return self._solver

    @property
    def rank(self) -> int:
        """Rank of the :attr:`solver`, `-1` if it does not define one."""
        return getattr(self.solver, "rank", -1)

    @property
    def is_low_rank(self) -> bool:
        """Whether the :attr:`rank` is larger than `-1`."""
        return self.rank > -1


class SinkhornSolver(OTTJaxSolver):
    """Solver for the :term:`linear problem`.

    The (Kantorovich relaxed) :term:`OT` problem is defined by two distributions in the same space.
    The goal is a probabilistic map from the source to the target distribution which minimizes
    the (weighted) sum of the distances between the coupled data points of the two distributions.

    Parameters
    ----------
    jit
        Whether to :func:`~jax.jit` the :attr:`solver`.
    rank
        Rank of the solver. If `-1`, use :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` :cite:`cuturi:2013`,
        otherwise, use :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn` :cite:`scetbon:21a`.
    epsilon
        Additional epsilon regularization for the low-rank approach.
    initializer
        Initializer for :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or
        :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`, depending on the ``rank``.
        If :obj:`None`, ``'rank2'`` is used in the low-rank case and ``'default'`` otherwise.
    initializer_kwargs
        Keyword arguments for the initializer.
    kwargs
        Keyword arguments for :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or
        :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhorn`, depending on the ``rank``.
    """

    def __init__(
        self,
        jit: bool = True,
        rank: int = -1,
        epsilon: float = 0.0,
        initializer: SinkhornInitializer_t = None,
        initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        **kwargs: Any,
    ):
        super().__init__(jit=jit)
        low_rank = rank > -1
        if initializer is None:
            initializer = "rank2" if low_rank else "default"

        if low_rank:
            # Defaults of the low-rank solver; values passed by the caller take precedence.
            kwargs.setdefault("gamma", 10)
            kwargs.setdefault("gamma_rescale", True)
            self._solver = sinkhorn_lr.LRSinkhorn(
                rank=rank, epsilon=epsilon, initializer=initializer, kwargs_init=initializer_kwargs, **kwargs
            )
        else:
            # ``epsilon`` is not used in the full-rank case.
            self._solver = sinkhorn.Sinkhorn(initializer=initializer, kwargs_init=initializer_kwargs, **kwargs)

    def _prepare(
        self,
        a: jnp.ndarray,
        b: jnp.ndarray,
        xy: Optional[TaggedArray] = None,
        x: Optional[TaggedArray] = None,
        y: Optional[TaggedArray] = None,
        # geometry
        epsilon: Union[float, epsilon_scheduler.Epsilon] = None,
        relative_epsilon: Optional[bool] = None,
        batch_size: Optional[int] = None,
        scale_cost: Scale_t = 1.0,
        cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        cost_matrix_rank: Optional[int] = None,
        time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
        # problem
        **kwargs: Any,
    ) -> linear_problem.LinearProblem:
        """Create the linear problem and store it, together with the marginals, on the solver.

        Parameters
        ----------
        a
            Source marginals.
        b
            Target marginals.
        xy
            Data of the linear term; must not be :obj:`None`.
        x
            Ignored.
        y
            Ignored.
        epsilon
            Forwarded to the geometry creation.
        relative_epsilon
            Forwarded to the geometry creation.
        batch_size
            Forwarded to the geometry creation.
        scale_cost
            Forwarded to the geometry creation.
        cost_kwargs
            Additional keyword arguments for the geometry creation.
        cost_matrix_rank
            If not :obj:`None`, the geometry is converted using ``to_LRCGeometry`` with this rank.
        time_scales_heat_kernel
            Its ``xy`` field is forwarded as ``t`` to the geometry creation.
        kwargs
            Keyword arguments for :class:`~ott.problems.linear.linear_problem.LinearProblem`.

        Returns
        -------
        The linear problem.

        Raises
        ------
        ValueError
            If ``xy`` is :obj:`None`.
        """
        del x, y  # only ``xy`` is used for the linear problem
        if time_scales_heat_kernel is None:
            time_scales_heat_kernel = TimeScalesHeatKernel(None, None, None)
        if xy is None:
            raise ValueError(f"Unable to create geometry from `xy={xy}`.")

        self._a, self._b = a, b
        geom = self._create_geometry(
            xy,
            is_linear_term=True,
            epsilon=epsilon,
            relative_epsilon=relative_epsilon,
            batch_size=batch_size,
            problem_shape=(len(a), len(b)),
            scale_cost=scale_cost,
            t=time_scales_heat_kernel.xy,
            **cost_kwargs,
        )
        if cost_matrix_rank is not None:
            geom = geom.to_LRCGeometry(rank=cost_matrix_rank)

        if isinstance(geom, geodesic.Geodesic):
            # Extend both marginals with zeros to the combined length of source and target;
            # the unpadded ones stay available in ``_a`` and ``_b``.
            # NOTE(review): presumably the geodesic geometry spans source and target nodes jointly - confirm.
            zeros_src, zeros_tgt = jnp.zeros_like(a), jnp.zeros_like(b)
            a = jnp.concatenate((a, zeros_tgt), axis=0)
            b = jnp.concatenate((zeros_src, b), axis=0)

        self._problem = linear_problem.LinearProblem(geom, a=a, b=b, **kwargs)
        return self._problem

    @property
    def xy(self) -> Optional[geometry.Geometry]:
        """Geometry defining the linear term."""
        problem = self._problem
        return None if problem is None else problem.geom

    @property
    def problem_kind(self) -> ProblemKind_t:  # noqa: D102
        return "linear"

    @classmethod
    def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
        """Return a set of geometry and problem keyword names, and the set ``{"epsilon"}``."""
        # NOTE(review): ``_prepare`` has no parameter called ``t`` - confirm this entry is still needed.
        geometry_names = {
            "epsilon",
            "relative_epsilon",
            "batch_size",
            "scale_cost",
            "cost_kwargs",
            "cost_matrix_rank",
            "t",
        }
        # Everything the linear problem accepts, except for the geometry which is created in ``_prepare``.
        problem_names = set(inspect.signature(linear_problem.LinearProblem).parameters) - {"geom"}
        return geometry_names | problem_names, {"epsilon"}
class GWSolver(OTTJaxSolver):
    """Solver for the :term:`quadratic problem` :cite:`memoli:2011`.

    The :term:`Gromov-Wasserstein (GW) <Gromov-Wasserstein>` problem involves two distributions in
    possibly two different spaces. Points in the source distribution are matched to points in the target
    distribution by comparing the relative location of the points within each distribution.

    Parameters
    ----------
    jit
        Whether to :func:`~jax.jit` the :attr:`solver`.
    rank
        Rank of the solver. If `-1` use the full-rank :term:`GW <Gromov-Wasserstein>` :cite:`peyre:2016`,
        otherwise, use the low-rank approach :cite:`scetbon:21b`.
    initializer
        Initializer for the low-rank solver; if :obj:`None`, ``'rank2'`` is used.
        In the full-rank case, this value is not used and :obj:`None` is passed as the initializer.
    initializer_kwargs
        Keyword arguments for the ``initializer``.
    linear_solver_kwargs
        Keyword arguments for the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver
        created in the full-rank case. Not used in the low-rank case.
    kwargs
        Keyword arguments for :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` or
        :class:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGromovWasserstein`, depending on the ``rank``.
    """

    def __init__(
        self,
        jit: bool = True,
        rank: int = -1,
        initializer: QuadInitializer_t = None,
        initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        **kwargs: Any,
    ):
        super().__init__(jit=jit)
        if rank > -1:
            # Defaults of the low-rank solver; values passed by the caller take precedence.
            kwargs.setdefault("gamma", 10)
            kwargs.setdefault("gamma_rescale", True)
            initializer = "rank2" if initializer is None else initializer
            self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
                rank=rank,
                initializer=initializer,
                kwargs_init=initializer_kwargs,
                **kwargs,
            )
        else:
            linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
            # The user-supplied ``initializer`` is deliberately not forwarded in the full-rank case.
            # NOTE(review): presumably because its options only apply to the low-rank solver - confirm.
            self._solver = gromov_wasserstein.GromovWasserstein(
                rank=rank,
                linear_ot_solver=linear_ot_solver,
                quad_initializer=None,
                kwargs_init=initializer_kwargs,
                **kwargs,
            )

    def _prepare(
        self,
        a: jnp.ndarray,
        b: jnp.ndarray,
        xy: Optional[TaggedArray] = None,
        x: Optional[TaggedArray] = None,
        y: Optional[TaggedArray] = None,
        # geometry
        epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
        relative_epsilon: Optional[bool] = None,
        batch_size: Optional[int] = None,
        scale_cost: Scale_t = 1.0,
        cost_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        cost_matrix_rank: Optional[int] = None,
        time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
        # problem
        alpha: float = 0.5,
        **kwargs: Any,
    ) -> quadratic_problem.QuadraticProblem:
        """Create the quadratic problem and store it, together with the marginals, on the solver.

        Parameters
        ----------
        a
            Source marginals.
        b
            Target marginals.
        xy
            Data of the linear term. If :obj:`None`, no linear term is used.
        x
            Data of the first quadratic term; must not be :obj:`None`.
        y
            Data of the second quadratic term; must not be :obj:`None`.
        epsilon
            Forwarded to the creation of every geometry.
        relative_epsilon
            Forwarded to the creation of every geometry.
        batch_size
            Forwarded to the creation of every geometry.
        scale_cost
            Forwarded to the creation of every geometry.
        cost_kwargs
            Additional keyword arguments for the creation of every geometry.
        cost_matrix_rank
            If not :obj:`None`, forwarded as a keyword argument to the creation of every geometry.
        time_scales_heat_kernel
            Its ``x``, ``y`` and ``xy`` fields are forwarded as ``t`` to the creation of
            the respective geometry.
        alpha
            If `1.0`, no linear term is used, even if ``xy`` is given. Otherwise, it is converted
            to the fused penalty of the problem.
        kwargs
            Keyword arguments for :class:`~ott.problems.quadratic.quadratic_problem.QuadraticProblem`.

        Returns
        -------
        The quadratic problem.

        Raises
        ------
        ValueError
            If ``x`` or ``y`` is :obj:`None`.
        """
        # The marginals are stored before the inputs are validated.
        self._a = a
        self._b = b
        time_scales_heat_kernel = (
            TimeScalesHeatKernel(None, None, None) if time_scales_heat_kernel is None else time_scales_heat_kernel
        )
        if x is None or y is None:
            raise ValueError(f"Unable to create geometry from `x={x}`, `y={y}`.")

        geom_kwargs: dict[str, Any] = {
            "epsilon": epsilon,
            "relative_epsilon": relative_epsilon,
            "batch_size": batch_size,
            "scale_cost": scale_cost,
            **cost_kwargs,
        }
        if cost_matrix_rank is not None:
            # NOTE(review): `_create_geometry` has no `cost_matrix_rank` parameter, so this value
            # arrives in its `**kwargs` and is treated like a cost keyword argument, whereas
            # `SinkhornSolver._prepare` calls `to_LRCGeometry` instead - confirm this is intended.
            geom_kwargs["cost_matrix_rank"] = cost_matrix_rank

        geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
        geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
        if alpha == 1.0 or xy is None:  # GW
            # arbitrary fused penalty; must be positive
            geom_xy, fused_penalty = None, 1.0
        else:  # FGW
            fused_penalty = alpha_to_fused_penalty(alpha)
            geom_xy = self._create_geometry(
                xy,
                t=time_scales_heat_kernel.xy,
                problem_shape=(x.shape[0], y.shape[0]),
                is_linear_term=True,
                **geom_kwargs,
            )
            check_shapes(geom_xx, geom_yy, geom_xy)

        self._problem = quadratic_problem.QuadraticProblem(
            geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty, a=self._a, b=self._b, **kwargs
        )
        return self._problem

    @property
    def x(self) -> Optional[geometry.Geometry]:
        """The first geometry defining the quadratic term."""
        return None if self._problem is None else self._problem.geom_xx

    @property
    def y(self) -> Optional[geometry.Geometry]:
        """The second geometry defining the quadratic term."""
        return None if self._problem is None else self._problem.geom_yy

    @property
    def xy(self) -> Optional[geometry.Geometry]:
        """Geometry defining the linear term in the :term:`FGW <fused Gromov-Wasserstein>`."""
        return None if self._problem is None else self._problem.geom_xy

    @property
    def is_fused(self) -> Optional[bool]:
        """Whether the solver is fused; :obj:`None` if no problem has been created yet."""
        return None if self._problem is None else (self.xy is not None)

    @property
    def problem_kind(self) -> ProblemKind_t:  # noqa: D102
        return "quadratic"

    @classmethod
    def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
        """Return a set of geometry and problem keyword names, and the set ``{"epsilon"}``."""
        geom_kwargs = {"epsilon", "relative_epsilon", "batch_size", "scale_cost", "cost_kwargs", "cost_matrix_rank"}
        problem_kwargs = set(inspect.signature(quadratic_problem.QuadraticProblem).parameters.keys())
        # The geometries and the fused penalty are created in `_prepare`; `alpha` is accepted in their place.
        problem_kwargs -= {"geom_xx", "geom_yy", "geom_xy", "fused_penalty"}
        problem_kwargs |= {"alpha"}
        return geom_kwargs | problem_kwargs, {"epsilon"}