import types
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Type, Union
from anndata import AnnData
from moscot import _constants
from moscot._types import (
CostKwargs_t,
OttCostFn_t,
OttCostFnMap_t,
Policy_t,
ProblemStage_t,
QuadInitializer_t,
ScaleCost_t,
SinkhornInitializer_t,
)
from moscot.base.problems.compound_problem import B, Callback_t, CompoundProblem, K
from moscot.base.problems.problem import CondOTProblem, OTProblem
from moscot.problems._utils import (
handle_conditional_attr,
handle_cost,
handle_cost_tmp,
handle_joint_attr,
handle_joint_attr_tmp,
)
from moscot.problems.generic._mixins import GenericAnalysisMixin
# Public names exported by this module.
__all__ = [
    "SinkhornProblem",
    "GWProblem",
    "GENOTLinProblem",
    "FGWProblem",
]
def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, str]:
    """Normalize the specification of a quadratic-term attribute into a :class:`dict`.

    Parameters
    ----------
    z
        Either a :class:`str`, used as the ``'key'`` of a point cloud stored under ``'obsm'``,
        or a mapping, which is shallow-copied as is.

    Returns
    -------
    A new dictionary describing where the data is stored.

    Raises
    ------
    TypeError
        If ``z`` is neither a :class:`str` nor a mapping (e.g., :obj:`None`).
    """
    if isinstance(z, str):
        # cost handled by handle_cost
        return dict(attr="obsm", key=z, tag="point_cloud")
    if not isinstance(z, Mapping):
        raise TypeError("`x_attr` and `y_attr` must be of type `str` or `dict` if no callback is provided.")
    return dict(z)
[docs]
class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]):  # type: ignore[misc]
    """Class for solving a :term:`linear problem`.

    Parameters
    ----------
    adata
        Annotated data object.
    kwargs
        Keyword arguments for :class:`~moscot.base.problems.CompoundProblem`.
    """

    def __init__(self, adata: AnnData, **kwargs: Any):
        super().__init__(adata, **kwargs)

    def prepare(
        self,
        key: str,
        joint_attr: Optional[Union[str, Mapping[str, Any]]] = None,
        policy: Literal["sequential", "explicit", "star"] = "sequential",
        cost: OttCostFn_t = "sq_euclidean",
        cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
        a: Optional[Union[bool, str]] = None,
        b: Optional[Union[bool, str]] = None,
        xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        subset: Optional[Sequence[Tuple[K, K]]] = None,
        reference: Optional[Any] = None,
    ) -> "SinkhornProblem[K, B]":
        r"""Prepare the individual :term:`linear subproblems <linear problem>`.

        .. seealso::
            - See :doc:`../notebooks/examples/problems/200_custom_cost_matrices` on how to pass custom cost matrices.
            - TODO(michalk8): add an example that shows how to pass different costs (with kwargs).

        Parameters
        ----------
        key
            Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
        joint_attr
            How to get the data for the :term:`linear term`:

            - :obj:`None` - `PCA <https://en.wikipedia.org/wiki/Principal_component_analysis>`_
              on :attr:`~anndata.AnnData.X` is computed.
            - :class:`str` - key in :attr:`~anndata.AnnData.obsm` where the data is stored.
            - :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and key in
              :class:`~anndata.AnnData`, and optionally ``'tag'`` from the
              :class:`tags <moscot.utils.tagged_array.Tag>`.

            By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
        policy
            Rule which defines how to construct the subproblems using :attr:`obs['{key}'] <anndata.AnnData.obs>`.
            Valid options are:

            - ``'sequential'`` - align subsequent categories.
            - ``'explicit'`` - explicit sequence of subsets passed via ``subset = [(b3, b0), ...]``.
            - ``'star'`` - also accepted; presumably aligns all categories to ``reference``
              (verify against :class:`~moscot.utils.subset_policy.SubsetPolicy`).
        cost
            Cost function to use. Valid options are:

            - :class:`str` - name of the cost function, see :func:`~moscot.costs.get_available_costs`.
            - :class:`dict` - a dictionary with the following keys and values:

              - ``'xy'`` - cost function for the :term:`linear term`, same as above.
        cost_kwargs
            Keyword arguments for the :class:`~moscot.base.cost.BaseCost` or any backend-specific cost.
        a
            Source :term:`marginals`. Valid options are:

            - :class:`str` - key in :attr:`~anndata.AnnData.obs` where the source marginals are stored.
            - :class:`bool` - if :obj:`True`,
              :meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
              otherwise use uniform marginals.
            - :obj:`None` - uniform marginals.
        b
            Target :term:`marginals`. Valid options are:

            - :class:`str` - key in :attr:`~anndata.AnnData.obs` where the target marginals are stored.
            - :class:`bool` - if :obj:`True`,
              :meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
              otherwise use uniform marginals.
            - :obj:`None` - uniform marginals.
        xy_callback
            Callback function used to prepare the data in the :term:`linear term`.
        xy_callback_kwargs
            Keyword arguments for the ``xy_callback``.
        subset
            Pairs of categories to align, used by the ``'explicit'`` policy.
        reference
            Forwarded unchanged to :meth:`~moscot.base.problems.CompoundProblem.prepare`;
            presumably the reference category of the ``'star'`` policy (to be confirmed there).

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`problems` - the prepared subproblems.
        - :attr:`solutions` - set to an empty :class:`dict`.
        - :attr:`batch_key` - set to ``key``.
        - :attr:`stage` - set to ``'prepared'``.
        - :attr:`problem_kind` - set to ``'linear'``.
        """
        self.batch_key = key
        xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs)
        # Only the linear term is used here: the quadratic slots are passed empty and their results discarded.
        xy, _, _ = handle_cost(
            xy=xy,
            x={},
            y={},
            cost=cost,
            cost_kwargs=cost_kwargs,
            xy_callback=xy_callback,
        )
        return super().prepare(  # type: ignore[return-value]
            key=key,
            policy=policy,
            xy=xy,
            a=a,
            b=b,
            xy_callback=xy_callback,
            xy_callback_kwargs=xy_callback_kwargs,
            reference=reference,
            subset=subset,
        )

    def solve(
        self,
        epsilon: float = 1e-3,
        tau_a: float = 1.0,
        tau_b: float = 1.0,
        rank: int = -1,
        scale_cost: ScaleCost_t = "mean",
        batch_size: Optional[int] = None,
        stage: Union[ProblemStage_t, Tuple[ProblemStage_t, ...]] = ("prepared", "solved"),
        initializer: SinkhornInitializer_t = None,
        initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        jit: bool = True,
        threshold: float = 1e-3,
        lse_mode: bool = True,
        inner_iterations: int = 10,
        min_iterations: Optional[int] = None,
        max_iterations: Optional[int] = None,
        device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
        **kwargs: Any,
    ) -> "SinkhornProblem[K, B]":
        r"""Solve the individual :term:`linear subproblems <linear problem>` \
        using the :term:`Sinkhorn` algorithm :cite:`cuturi:2013`.

        .. seealso::
            - See :doc:`../notebooks/examples/solvers/100_linear_problem_basic` on how to specify
              the most important parameters.
            - See :doc:`../notebooks/examples/solvers/200_linear_problems_advanced` on how to specify
              additional parameters, such as the ``initializer``.

        Parameters
        ----------
        epsilon
            :term:`Entropic regularization`.
        tau_a
            Parameter in :math:`(0, 1]` that defines how much :term:`unbalanced <unbalanced OT problem>` is the problem
            on the source :term:`marginals`. If :math:`1`, the problem is :term:`balanced <balanced OT problem>`.
        tau_b
            Parameter in :math:`(0, 1]` that defines how much :term:`unbalanced <unbalanced OT problem>` is the problem
            on the target :term:`marginals`. If :math:`1`, the problem is :term:`balanced <balanced OT problem>`.
        rank
            Rank of the :term:`low-rank OT` solver :cite:`scetbon:21a`.
            If :math:`-1`, full-rank solver :cite:`cuturi:2013` is used.
        scale_cost
            How to re-scale the cost matrix. If a :class:`float`, the cost matrix
            will be re-scaled as :math:`\frac{\text{cost}}{\text{scale_cost}}`.
        batch_size
            Number of rows/columns of the cost matrix to materialize during the :term:`Sinkhorn` iterations.
            Larger value will require more memory.
        stage
            Stage by which to filter the :attr:`problems` to be solved.
        initializer
            How to initialize the solution. If :obj:`None`, ``'default'`` will be used for a full-rank solver and
            ``'rank2'`` for a low-rank solver.
        initializer_kwargs
            Keyword arguments for the ``initializer``.
        jit
            Whether to :func:`~jax.jit` the underlying :mod:`ott` solver.
        threshold
            Convergence threshold of the :term:`Sinkhorn` algorithm. In the :term:`balanced <balanced OT problem>` case,
            this is typically the deviation between the target :term:`marginals` and the marginals of the current
            :term:`transport matrix`. In the :term:`unbalanced <unbalanced OT problem>` case, the relative change
            between the successive solutions is checked.
        lse_mode
            Whether to use `log-sum-exp (LSE)
            <https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations>`_
            computations for numerical stability.
        inner_iterations
            Compute the convergence criterion every ``inner_iterations``.
        min_iterations
            Minimum number of :term:`Sinkhorn` iterations.
        max_iterations
            Maximum number of :term:`Sinkhorn` iterations.
        device
            Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`.
            If :obj:`None`, keep the output on the original device.
        kwargs
            Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`solutions` - the :term:`OT` solutions for each subproblem.
        - :attr:`stage` - set to ``'solved'``.
        """  # noqa: D205
        return super().solve(  # type: ignore[return-value]
            epsilon=epsilon,
            tau_a=tau_a,
            tau_b=tau_b,
            rank=rank,
            scale_cost=scale_cost,
            batch_size=batch_size,
            stage=stage,
            initializer=initializer,
            initializer_kwargs=initializer_kwargs,
            jit=jit,
            threshold=threshold,
            lse_mode=lse_mode,
            inner_iterations=inner_iterations,
            min_iterations=min_iterations,
            max_iterations=max_iterations,
            device=device,
            **kwargs,
        )

    @property
    def _base_problem_type(self) -> Type[B]:
        # Type of the individual subproblems.
        return OTProblem  # type: ignore[return-value]

    @property
    def _valid_policies(self) -> Tuple[Policy_t, ...]:
        # Policies accepted by `prepare`.
        return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR  # type: ignore[return-value]
[docs]
class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]):  # type: ignore[misc]
    """Class for solving the :term:`GW <Gromov-Wasserstein>` problem.

    :meth:`solve` always uses ``alpha = 1.0`` and :meth:`prepare` passes no data for the :term:`linear term`;
    for the :term:`FGW <fused Gromov-Wasserstein>` problem, see :class:`FGWProblem`.

    Parameters
    ----------
    adata
        Annotated data object.
    kwargs
        Keyword arguments for :class:`~moscot.base.problems.CompoundProblem`.
    """

    def __init__(self, adata: AnnData, **kwargs: Any):
        super().__init__(adata, **kwargs)

    def prepare(
        self,
        key: str,
        x_attr: Optional[Union[str, Mapping[str, Any]]] = None,
        y_attr: Optional[Union[str, Mapping[str, Any]]] = None,
        policy: Literal["sequential", "explicit", "star"] = "sequential",
        cost: OttCostFnMap_t = "sq_euclidean",
        cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
        a: Optional[Union[bool, str]] = None,
        b: Optional[Union[bool, str]] = None,
        subset: Optional[Sequence[Tuple[K, K]]] = None,
        reference: Optional[Any] = None,
        x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
    ) -> "GWProblem[K, B]":
        """Prepare the individual :term:`quadratic subproblems <quadratic problem>`.

        .. seealso::
            - TODO(michalk8): add an example how to pass `x_attr/y_attr`.

        Parameters
        ----------
        key
            Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
        x_attr
            How to get the data for the source :term:`quadratic term`:

            - :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
            - :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and key in
              :class:`~anndata.AnnData`, and optionally ``'tag'`` from the
              :class:`tags <moscot.utils.tagged_array.Tag>`.
            - :obj:`None` - ``x_callback`` must be passed.

            By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
        y_attr
            How to get the data for the target :term:`quadratic term`:

            - :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
            - :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key
              in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`.
            - :obj:`None` - ``y_callback`` must be passed.

            By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
        policy
            Rule which defines how to construct the subproblems. Valid options are:

            - ``'sequential'`` - align subsequent categories in :attr:`obs['{key}'] <anndata.AnnData.obs>`.
            - ``'explicit'`` - explicit sequence of subsets passed via ``subset = [(b3, b0), ...]``.
            - ``'star'`` - also accepted; presumably aligns all categories to ``reference``
              (verify against :class:`~moscot.utils.subset_policy.SubsetPolicy`).
        cost
            Cost function to use. Valid options are:

            - :class:`str` - name of the cost function for all terms, see :func:`~moscot.costs.get_available_costs`.
            - :class:`dict` - a dictionary with the following keys and values:

              - ``'xy'`` - cost function for the :term:`linear term`.
              - ``'x'`` - cost function for the source :term:`quadratic term`.
              - ``'y'`` - cost function for the target :term:`quadratic term`.
        cost_kwargs
            Keyword arguments for the :class:`~moscot.base.cost.BaseCost` or any backend-specific cost.
        a
            Source :term:`marginals`. Valid options are:

            - :class:`str` - key in :attr:`~anndata.AnnData.obs` where the source marginals are stored.
            - :class:`bool` - if :obj:`True`,
              :meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
              otherwise use uniform marginals.
            - :obj:`None` - uniform marginals.
        b
            Target :term:`marginals`. Valid options are:

            - :class:`str` - key in :attr:`~anndata.AnnData.obs` where the target marginals are stored.
            - :class:`bool` - if :obj:`True`,
              :meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
              otherwise use uniform marginals.
            - :obj:`None` - uniform marginals.
        subset
            Pairs of categories to align, used by the ``'explicit'`` policy.
        reference
            Forwarded unchanged to :meth:`~moscot.base.problems.CompoundProblem.prepare`;
            presumably the reference category of the ``'star'`` policy (to be confirmed there).
        x_callback
            Callback function used to prepare the data in the source :term:`quadratic term`.
        y_callback
            Callback function used to prepare the data in the target :term:`quadratic term`.
        x_callback_kwargs
            Keyword arguments for the ``x_callback``.
        y_callback_kwargs
            Keyword arguments for the ``y_callback``.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`problems` - the prepared subproblems.
        - :attr:`solutions` - set to an empty :class:`dict`.
        - :attr:`batch_key` - key in :attr:`~anndata.AnnData.obs` where batches are stored.
        - :attr:`stage` - set to ``'prepared'``.
        - :attr:`problem_kind` - set to ``'quadratic'``.

        Raises
        ------
        TypeError
            If ``x_attr`` (``y_attr``) is neither a :class:`str` nor a mapping
            while ``x_callback`` (``y_callback``) is :obj:`None`.
        """
        self.batch_key = key
        # When a callback is given, it is responsible for the data, so the attribute description stays empty.
        x = set_quad_defaults(x_attr) if x_callback is None else {}
        y = set_quad_defaults(y_attr) if y_callback is None else {}
        # No data is passed for the linear term.
        xy, x, y = handle_cost(
            xy={},
            x=x,
            y=y,
            cost=cost,
            cost_kwargs=cost_kwargs,
            x_callback=x_callback,
            y_callback=y_callback,
        )
        return super().prepare(  # type: ignore[return-value]
            key=key,
            xy=xy,
            x=x,
            y=y,
            policy=policy,
            a=a,
            b=b,
            x_callback=x_callback,
            y_callback=y_callback,
            x_callback_kwargs=x_callback_kwargs,
            y_callback_kwargs=y_callback_kwargs,
            subset=subset,
            reference=reference,
        )

    def solve(
        self,
        epsilon: float = 1e-3,
        tau_a: float = 1.0,
        tau_b: float = 1.0,
        rank: int = -1,
        scale_cost: ScaleCost_t = "mean",
        batch_size: Optional[int] = None,
        stage: Union[ProblemStage_t, Tuple[ProblemStage_t, ...]] = ("prepared", "solved"),
        initializer: QuadInitializer_t = None,
        initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        jit: bool = True,
        min_iterations: Optional[int] = None,
        max_iterations: Optional[int] = None,
        threshold: float = 1e-3,
        linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
        **kwargs: Any,
    ) -> "GWProblem[K, B]":
        r"""Solve the individual :term:`quadratic subproblems <quadratic problem>`.

        .. seealso::
            - See :doc:`../notebooks/examples/solvers/300_quad_problems_basic` on how to specify
              the most important parameters.
            - See :doc:`../notebooks/examples/solvers/400_quad_problems_advanced` on how to specify
              additional parameters, such as the ``initializer``.

        Parameters
        ----------
        epsilon
            :term:`Entropic regularization`.
        tau_a
            Parameter in :math:`(0, 1]` that defines how much :term:`unbalanced <unbalanced OT problem>` is the problem
            on the source :term:`marginals`. If :math:`1`, the problem is :term:`balanced <balanced OT problem>`.
        tau_b
            Parameter in :math:`(0, 1]` that defines how much :term:`unbalanced <unbalanced OT problem>` is the problem
            on the target :term:`marginals`. If :math:`1`, the problem is :term:`balanced <balanced OT problem>`.
        rank
            Rank of the :term:`low-rank OT` solver :cite:`scetbon:21b`.
            If :math:`-1`, full-rank solver :cite:`peyre:2016` is used.
        scale_cost
            How to re-scale the cost matrices. If a :class:`float`, the cost matrices
            will be re-scaled as :math:`\frac{\text{cost}}{\text{scale_cost}}`.
        batch_size
            Number of rows/columns of the cost matrix to materialize during the solver iterations.
            Larger value will require more memory.
        stage
            Stage by which to filter the :attr:`problems` to be solved.
        initializer
            How to initialize the solution. If :obj:`None`, ``'default'`` will be used for a full-rank solver and
            ``'rank2'`` for a low-rank solver.
        initializer_kwargs
            Keyword arguments for the ``initializer``.
        jit
            Whether to :func:`~jax.jit` the underlying :mod:`ott` solver.
        min_iterations
            Minimum number of :term:`(fused) GW <Gromov-Wasserstein>` iterations.
        max_iterations
            Maximum number of :term:`(fused) GW <Gromov-Wasserstein>` iterations.
        threshold
            Convergence threshold of the :term:`GW <Gromov-Wasserstein>` solver.
        linear_solver_kwargs
            Keyword arguments for the inner :term:`linear problem` solver.
        device
            Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`.
            If :obj:`None`, keep the output on the original device.
        kwargs
            Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`solutions` - the :term:`OT` solutions for each subproblem.
        - :attr:`stage` - set to ``'solved'``.
        """
        return super().solve(  # type: ignore[return-value]
            alpha=1.0,  # fixed: not exposed as a parameter of this method
            epsilon=epsilon,
            tau_a=tau_a,
            tau_b=tau_b,
            rank=rank,
            scale_cost=scale_cost,
            batch_size=batch_size,
            stage=stage,
            initializer=initializer,
            initializer_kwargs=initializer_kwargs,
            jit=jit,
            min_iterations=min_iterations,
            max_iterations=max_iterations,
            threshold=threshold,
            linear_solver_kwargs=linear_solver_kwargs,
            device=device,
            **kwargs,
        )

    @property
    def _base_problem_type(self) -> Type[B]:
        # Type of the individual subproblems.
        return OTProblem  # type: ignore[return-value]

    @property
    def _valid_policies(self) -> Tuple[Policy_t, ...]:
        # Policies accepted by `prepare`.
        return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR  # type: ignore[return-value]
[docs]
class FGWProblem(GWProblem[K, B]):
    """Class for solving the :term:`FGW <fused Gromov-Wasserstein>` problem.

    Parameters
    ----------
    adata
        Annotated data object.
    kwargs
        Keyword arguments for :class:`~moscot.base.problems.CompoundProblem`.
    """

    def prepare(
        self,
        key: str,
        joint_attr: Optional[Union[str, Mapping[str, Any]]] = None,
        x_attr: Optional[Union[str, Mapping[str, Any]]] = None,
        y_attr: Optional[Union[str, Mapping[str, Any]]] = None,
        policy: Literal["sequential", "explicit", "star"] = "sequential",
        cost: OttCostFnMap_t = "sq_euclidean",
        cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
        a: Optional[Union[bool, str]] = None,
        b: Optional[Union[bool, str]] = None,
        xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
        xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        subset: Optional[Sequence[Tuple[K, K]]] = None,
        reference: Optional[Any] = None,
    ) -> "FGWProblem[K, B]":
        """Prepare the individual :term:`quadratic subproblems <quadratic problem>`.

        .. seealso::
            - TODO(michalk8): add an example how to pass `x_attr/y_attr`.

        Parameters
        ----------
        key
            Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
        joint_attr
            How to get the data for the :term:`linear term` in the :term:`fused <fused Gromov-Wasserstein>` case:

            - :obj:`None` - `PCA <https://en.wikipedia.org/wiki/Principal_component_analysis>`_
              on :attr:`~anndata.AnnData.X` is computed.
            - :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
            - :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key
              in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`.

            By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
        x_attr
            How to get the data for the source :term:`quadratic term`:

            - :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
            - :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and key in
              :class:`~anndata.AnnData`, and optionally ``'tag'`` from the
              :class:`tags <moscot.utils.tagged_array.Tag>`.
            - :obj:`None` - ``x_callback`` must be passed.

            By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
        y_attr
            How to get the data for the target :term:`quadratic term`:

            - :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
            - :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key
              in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`.
            - :obj:`None` - ``y_callback`` must be passed.

            By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
        policy
            Rule which defines how to construct the subproblems. Valid options are:

            - ``'sequential'`` - align subsequent categories in :attr:`obs['{key}'] <anndata.AnnData.obs>`.
            - ``'explicit'`` - explicit sequence of subsets passed via ``subset = [(b3, b0), ...]``.
            - ``'star'`` - also accepted; presumably aligns all categories to ``reference``
              (verify against :class:`~moscot.utils.subset_policy.SubsetPolicy`).
        cost
            Cost function to use. Valid options are:

            - :class:`str` - name of the cost function for all terms, see :func:`~moscot.costs.get_available_costs`.
            - :class:`dict` - a dictionary with the following keys and values:

              - ``'xy'`` - cost function for the :term:`linear term`.
              - ``'x'`` - cost function for the source :term:`quadratic term`.
              - ``'y'`` - cost function for the target :term:`quadratic term`.
        cost_kwargs
            Keyword arguments for the :class:`~moscot.base.cost.BaseCost` or any backend-specific cost.
        a
            Source :term:`marginals`. Valid options are:

            - :class:`str` - key in :attr:`~anndata.AnnData.obs` where the source marginals are stored.
            - :class:`bool` - if :obj:`True`,
              :meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
              otherwise use uniform marginals.
            - :obj:`None` - uniform marginals.
        b
            Target :term:`marginals`. Valid options are:

            - :class:`str` - key in :attr:`~anndata.AnnData.obs` where the target marginals are stored.
            - :class:`bool` - if :obj:`True`,
              :meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
              otherwise use uniform marginals.
            - :obj:`None` - uniform marginals.
        xy_callback
            Callback function used to prepare the data in the :term:`linear term`.
        x_callback
            Callback function used to prepare the data in the source :term:`quadratic term`.
        y_callback
            Callback function used to prepare the data in the target :term:`quadratic term`.
        xy_callback_kwargs
            Keyword arguments for the ``xy_callback``.
        x_callback_kwargs
            Keyword arguments for the ``x_callback``.
        y_callback_kwargs
            Keyword arguments for the ``y_callback``.
        subset
            Pairs of categories to align, used by the ``'explicit'`` policy.
        reference
            Forwarded unchanged to :meth:`~moscot.base.problems.CompoundProblem.prepare`;
            presumably the reference category of the ``'star'`` policy (to be confirmed there).

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`problems` - the prepared subproblems.
        - :attr:`solutions` - set to an empty :class:`dict`.
        - :attr:`batch_key` - key in :attr:`~anndata.AnnData.obs` where batches are stored.
        - :attr:`stage` - set to ``'prepared'``.
        - :attr:`problem_kind` - set to ``'quadratic'``.

        Raises
        ------
        TypeError
            If ``x_attr`` (``y_attr``) is neither a :class:`str` nor a mapping
            while ``x_callback`` (``y_callback``) is :obj:`None`.
        """
        self.batch_key = key
        # When a callback is given, it is responsible for the data, so the attribute description stays empty.
        x = set_quad_defaults(x_attr) if x_callback is None else {}
        y = set_quad_defaults(y_attr) if y_callback is None else {}
        xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs)
        xy, x, y = handle_cost(
            xy=xy,
            x=x,
            y=y,
            cost=cost,
            x_callback=x_callback,
            y_callback=y_callback,
            xy_callback=xy_callback,
            cost_kwargs=cost_kwargs,
        )
        # Called explicitly on `CompoundProblem`: `GWProblem.prepare` takes no `joint_attr`/`xy_callback`
        # and always passes an empty linear term.
        return CompoundProblem.prepare(
            self,  # type: ignore[return-value, arg-type]
            key=key,
            xy=xy,
            x=x,
            y=y,
            policy=policy,
            a=a,
            b=b,
            reference=reference,
            subset=subset,  # type: ignore[arg-type]
            x_callback=x_callback,
            y_callback=y_callback,
            xy_callback=xy_callback,
            x_callback_kwargs=x_callback_kwargs,
            y_callback_kwargs=y_callback_kwargs,
            xy_callback_kwargs=xy_callback_kwargs,
        )

    def solve(
        self,
        alpha: float = 0.5,
        epsilon: float = 1e-3,
        tau_a: float = 1.0,
        tau_b: float = 1.0,
        rank: int = -1,
        scale_cost: ScaleCost_t = "mean",
        batch_size: Optional[int] = None,
        stage: Union[ProblemStage_t, Tuple[ProblemStage_t, ...]] = ("prepared", "solved"),
        initializer: QuadInitializer_t = None,
        initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        jit: bool = True,
        min_iterations: Optional[int] = None,
        max_iterations: Optional[int] = None,
        threshold: float = 1e-3,
        linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
        **kwargs: Any,
    ) -> "FGWProblem[K, B]":
        r"""Solve the individual :term:`quadratic subproblems <quadratic problem>`.

        .. seealso::
            - See :doc:`../notebooks/examples/solvers/300_quad_problems_basic` on how to specify
              the most important parameters.
            - See :doc:`../notebooks/examples/solvers/400_quad_problems_advanced` on how to specify
              additional parameters, such as the ``initializer``.

        Parameters
        ----------
        alpha
            Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and
            the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while
            :math:`\alpha \to 0` corresponds to the pure :term:`linear problem`.
        epsilon
            :term:`Entropic regularization`.
        tau_a
            Parameter in :math:`(0, 1]` that defines how much :term:`unbalanced <unbalanced OT problem>` is the problem
            on the source :term:`marginals`. If :math:`1`, the problem is :term:`balanced <balanced OT problem>`.
        tau_b
            Parameter in :math:`(0, 1]` that defines how much :term:`unbalanced <unbalanced OT problem>` is the problem
            on the target :term:`marginals`. If :math:`1`, the problem is :term:`balanced <balanced OT problem>`.
        rank
            Rank of the :term:`low-rank OT` solver :cite:`scetbon:21b`.
            If :math:`-1`, full-rank solver :cite:`peyre:2016` is used.
        scale_cost
            How to re-scale the cost matrices. If a :class:`float`, the cost matrices
            will be re-scaled as :math:`\frac{\text{cost}}{\text{scale_cost}}`.
        batch_size
            Number of rows/columns of the cost matrix to materialize during the solver iterations.
            Larger value will require more memory.
        stage
            Stage by which to filter the :attr:`problems` to be solved.
        initializer
            How to initialize the solution. If :obj:`None`, ``'default'`` will be used for a full-rank solver and
            ``'rank2'`` for a low-rank solver.
        initializer_kwargs
            Keyword arguments for the ``initializer``.
        jit
            Whether to :func:`~jax.jit` the underlying :mod:`ott` solver.
        min_iterations
            Minimum number of :term:`(fused) GW <Gromov-Wasserstein>` iterations.
        max_iterations
            Maximum number of :term:`(fused) GW <Gromov-Wasserstein>` iterations.
        threshold
            Convergence threshold of the :term:`GW <Gromov-Wasserstein>` solver.
        linear_solver_kwargs
            Keyword arguments for the inner :term:`linear problem` solver.
        device
            Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`.
            If :obj:`None`, keep the output on the original device.
        kwargs
            Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`.

        Returns
        -------
        Returns self and updates the following fields:

        - :attr:`solutions` - the :term:`OT` solutions for each subproblem.
        - :attr:`stage` - set to ``'solved'``.

        Raises
        ------
        ValueError
            If ``alpha == 1.0``, since the problem would then be equivalent to the :class:`GWProblem`.
        """
        if alpha == 1.0:
            raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.")
        # Called explicitly on `CompoundProblem`: `GWProblem.solve` hard-codes `alpha=1.0`.
        return CompoundProblem.solve(
            self,  # type: ignore[return-value, arg-type]
            alpha=alpha,
            epsilon=epsilon,
            tau_a=tau_a,
            tau_b=tau_b,
            rank=rank,
            scale_cost=scale_cost,
            batch_size=batch_size,
            stage=stage,
            initializer=initializer,
            initializer_kwargs=initializer_kwargs,
            jit=jit,
            min_iterations=min_iterations,
            max_iterations=max_iterations,
            threshold=threshold,
            linear_solver_kwargs=linear_solver_kwargs,
            device=device,
            **kwargs,
        )

    @property
    def _base_problem_type(self) -> Type[B]:
        # Type of the individual subproblems.
        return OTProblem  # type: ignore[return-value]

    @property
    def _valid_policies(self) -> Tuple[Policy_t, ...]:
        # Policies accepted by `prepare`.
        return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR  # type: ignore[return-value]
[docs]
class GENOTLinProblem(CondOTProblem):
    """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems."""

    def prepare(
        self,
        key: str,
        joint_attr: Union[str, Mapping[str, Any]],
        conditional_attr: Union[str, Mapping[str, Any]],
        policy: Literal["sequential", "star", "explicit"] = "sequential",
        a: Optional[str] = None,
        b: Optional[str] = None,
        cost: OttCostFn_t = "sq_euclidean",
        cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
        **kwargs: Any,
    ) -> "GENOTLinProblem":
        """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.

        Parameters
        ----------
        key
            Policy key, forwarded as ``policy_key``. Also stored in :attr:`batch_key`.
        joint_attr
            Specification of the ``xy`` data, resolved together with ``kwargs``
            by :func:`~moscot.problems._utils.handle_joint_attr_tmp`.
        conditional_attr
            Specification of the conditions, resolved
            by :func:`~moscot.problems._utils.handle_conditional_attr`.
        policy
            Rule which defines how to construct the subproblems.
        a
            Forwarded unchanged to :meth:`~moscot.base.problems.CondOTProblem.prepare`;
            presumably the source :term:`marginals` (to be confirmed there).
        b
            Forwarded unchanged to :meth:`~moscot.base.problems.CondOTProblem.prepare`;
            presumably the target :term:`marginals` (to be confirmed there).
        cost
            Cost function to use, resolved by :func:`~moscot.problems._utils.handle_cost_tmp`.
        cost_kwargs
            Keyword arguments for the cost function.
        kwargs
            Keyword arguments for :meth:`~moscot.base.problems.CondOTProblem.prepare`,
            after being processed by :func:`~moscot.problems._utils.handle_joint_attr_tmp`.

        Returns
        -------
        The result of :meth:`~moscot.base.problems.CondOTProblem.prepare`.
        """
        # NOTE(review): 'star' is accepted by the signature, but `_valid_policies` below does not list it - confirm
        # which of the two is intended.
        self.batch_key = key
        xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs)
        conditions = handle_conditional_attr(conditional_attr)
        xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs)
        return super().prepare(
            policy_key=key,
            policy=policy,
            xy=xy,
            xx=xx,
            conditions=conditions,
            a=a,
            b=b,
            **kwargs,
        )

    def solve(
        self,
        batch_size: int = 1024,
        seed: int = 0,
        iterations: int = 25000,  # TODO(@MUCDK): rename to max_iterations
        valid_freq: int = 50,
        valid_sinkhorn_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
        train_size: float = 1.0,
        **kwargs: Any,
    ) -> "GENOTLinProblem":
        """Solve the problem using the ``'GENOTLinSolver'``.

        All arguments are forwarded to :meth:`~moscot.base.problems.CondOTProblem.solve`; their exact
        semantics are defined there.

        Parameters
        ----------
        batch_size
            Forwarded as ``batch_size``.
        seed
            Forwarded as ``seed``.
        iterations
            Forwarded as ``n_iters``.
        valid_freq
            Forwarded as ``valid_freq``.
        valid_sinkhorn_kwargs
            Forwarded as ``valid_sinkhorn_kwargs``. The default is an empty, read-only mapping.
        train_size
            Forwarded as ``train_size``.
        kwargs
            Additional keyword arguments for :meth:`~moscot.base.problems.CondOTProblem.solve`.

        Returns
        -------
        The result of :meth:`~moscot.base.problems.CondOTProblem.solve`.
        """
        return super().solve(
            batch_size=batch_size,
            # tau_a=tau_a,  # TODO: unbalancedness handler
            # tau_b=tau_b,
            seed=seed,
            n_iters=iterations,
            valid_freq=valid_freq,
            valid_sinkhorn_kwargs=valid_sinkhorn_kwargs,
            train_size=train_size,
            solver_name="GENOTLinSolver",
            **kwargs,
        )

    @property
    def _base_problem_type(self) -> Type[CondOTProblem]:
        # Type of the underlying problem.
        return CondOTProblem

    @property
    def _valid_policies(self) -> Tuple[Policy_t, ...]:
        # Policies considered valid for this problem.
        return _constants.SEQUENTIAL, _constants.EXPLICIT  # type: ignore[return-value]