This example shows advanced usage of quadratic problems, e.g., the LineageProblem, the SpatioTemporalProblem, the MappingProblem, the AlignmentProblem, and the GWProblem. The examples below use the GWProblem.

```python
from moscot import datasets
from moscot.problems.generic import GWProblem

import scanpy as sc
```


Simulate data with two batches and a spatial quadratic term using datasets.simulate_data().

```python
adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
adata
```

```text
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'
```

```python
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    GW_x={"attr": "obsm", "key": "spatial"},
    GW_y={"attr": "obsm", "key": "spatial"},
)
gwp
```

```text
INFO     Computing pca with n_comps=30 for xy using adata.X
GWProblem[('0', '1')]
```


## Threshold

The threshold parameter defines the convergence criterion. In the balanced setting, the threshold bounds the deviation between the prescribed (prior) marginals and the marginals of the current coupling (posterior marginals). In the unbalanced setting, it acts as a Cauchy-sequence stopping criterion, i.e., it bounds the relative change between two successive iterates.
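To make the two criteria concrete, here is a minimal pure-Python sketch. It assumes couplings and iterates are plain lists and measures errors in the L1 norm; the helper names are hypothetical and not part of moscot or ott, and the exact norms ott uses may differ.

```python
# Hypothetical helpers illustrating the two stopping criteria (not moscot/ott API).
# Assumption: errors are measured in the L1 norm.


def marginal_error(P, a, b):
    """Balanced case: L1 deviation of the coupling's marginals from the targets a and b."""
    rows = [sum(row) for row in P]
    cols = [sum(row[j] for row in P) for j in range(len(P[0]))]
    return (sum(abs(r - ai) for r, ai in zip(rows, a))
            + sum(abs(c - bj) for c, bj in zip(cols, b)))


def cauchy_error(previous, current):
    """Unbalanced case: relative L1 change between two successive iterates."""
    change = sum(abs(p - c) for p, c in zip(previous, current))
    scale = sum(abs(c) for c in current) or 1.0  # guard against division by zero
    return change / scale


def has_converged(error, threshold):
    """Iterations stop once the chosen error drops below the threshold."""
    return error < threshold


a, b = [0.5, 0.5], [0.5, 0.5]
P = [[0.30, 0.25], [0.20, 0.25]]  # rows sum to 0.55 / 0.45, columns to 0.5 / 0.5
print(marginal_error(P, a, b))               # approximately 0.1
print(cauchy_error([1.0, 2.0], [1.0, 2.5]))  # 0.5 / 3.5, approximately 0.143
print(has_converged(marginal_error(P, a, b), threshold=0.01))  # False
```

Which error is monitored depends only on whether the problem is balanced; the threshold is compared against it in the same way in both cases.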

## Initializers

Different initializers can help improve convergence. In the full-rank case only the default initializer is available, so the initializer argument must be set to None.

For low-rank problems, the same initializers as for the linear low-rank solvers are available, and initializer_kwargs can be passed in the same way; see Linear problems (advanced) for more information.
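For intuition about what a low-rank initializer has to produce, here is a hypothetical pure-Python sketch. It assumes the usual low-rank parameterization P = Q diag(1/g) R^T, where Q has row sums a, R has row sums b, and both have column sums g. The trivial outer-product factors below are an illustrative assumption, not one of ott's initializers; they simply recover the independent coupling a b^T.

```python
# Hypothetical sketch (not an ott initializer): trivial factors for a rank-r coupling
#   P = Q diag(1/g) R^T,  Q 1_r = a,  R 1_r = b,  Q^T 1_n = R^T 1_m = g.
# Assumes a and b each sum to 1.


def trivial_low_rank_init(a, b, rank):
    """Return factors (Q, R, g) that satisfy the low-rank marginal constraints."""
    g = [1.0 / rank] * rank                    # uniform inner marginal
    Q = [[ai * gk for gk in g] for ai in a]    # n x r, row i sums to a[i]
    R = [[bj * gk for gk in g] for bj in b]    # m x r, row j sums to b[j]
    return Q, R, g


def materialize(Q, R, g):
    """Form the dense coupling P = Q diag(1/g) R^T (only sensible for small problems)."""
    return [[sum(Q[i][k] * R[j][k] / g[k] for k in range(len(g)))
             for j in range(len(R))]
            for i in range(len(Q))]


Q, R, g = trivial_low_rank_init([0.25, 0.75], [0.5, 0.5], rank=3)
P = materialize(Q, R, g)
print(P)  # approximately [[0.125, 0.125], [0.375, 0.375]], i.e. a b^T
```

Real initializers aim for factors that are more informative than this product coupling, which is why the choice can affect convergence.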

## Number of iterations

A quadratic optimal transport problem is solved by repeatedly linearizing it around the current coupling and solving the resulting linear problem. The parameters min_iterations and max_iterations set a lower and an upper bound on the number of these outer iterations. If max_iterations is too low, the solver may not converge, as in the following example, which allows only a single outer iteration.

```python
gwp = gwp.solve(alpha=0.5, epsilon=1e-1, min_iterations=0, max_iterations=1)
```

```text
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].
WARNING  Solver did not converge
```
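The interplay of min_iterations, max_iterations, and threshold can be sketched with a toy entropic Gromov-Wasserstein solver in pure Python. This is a simplified illustration, not ott's implementation: it assumes symmetric cost matrices, the squared loss, a fixed number of inner Sinkhorn iterations, and an outer stopping rule based on the L1 change of the coupling. All function names are hypothetical.

```python
import math

# Toy entropic Gromov-Wasserstein solver (hypothetical, not ott's implementation).


def sinkhorn(C, a, b, epsilon, n_iter=200):
    """Inner linear solver: entropic OT for cost matrix C with marginals a, b."""
    n, m = len(a), len(b)
    K = [[math.exp(-C[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


def linearized_cost(Cx, Cy, P):
    """Linearize sum_{ijkl} (Cx[i][k] - Cy[j][l])**2 P[i][j] P[k][l] around P.

    For symmetric Cx, Cy this is half the gradient w.r.t. P:
    (Cx**2 p)_i + (Cy**2 q)_j - 2 (Cx P Cy^T)_ij, with p, q the marginals of P.
    """
    n, m = len(P), len(P[0])
    p = [sum(P[k]) for k in range(n)]
    q = [sum(P[k][l] for k in range(n)) for l in range(m)]
    fx = [sum(Cx[i][k] ** 2 * p[k] for k in range(n)) for i in range(n)]
    fy = [sum(Cy[j][l] ** 2 * q[l] for l in range(m)) for j in range(m)]
    CxP = [[sum(Cx[i][k] * P[k][l] for k in range(n)) for l in range(m)] for i in range(n)]
    return [[fx[i] + fy[j] - 2.0 * sum(CxP[i][l] * Cy[j][l] for l in range(m))
             for j in range(m)] for i in range(n)]


def entropic_gw(Cx, Cy, a, b, epsilon, threshold, min_iterations, max_iterations):
    """Outer loop: re-linearize and re-solve until the coupling stops changing."""
    P = [[ai * bj for bj in b] for ai in a]  # start from the product coupling
    converged, n_outer = False, 0
    for it in range(max_iterations):
        P_new = sinkhorn(linearized_cost(Cx, Cy, P), a, b, epsilon)
        change = sum(abs(x - y) for r_new, r in zip(P_new, P) for x, y in zip(r_new, r))
        P, n_outer = P_new, it + 1
        # Stop only after at least min_iterations outer steps.
        if n_outer >= min_iterations and change < threshold:
            converged = True
            break
    return P, converged, n_outer


# Distance matrix of three points on a line (0, 1, 2), used for both spaces.
D = [[abs(x - y) for y in range(3)] for x in range(3)]
w = [1 / 3] * 3
P, converged, n_outer = entropic_gw(D, D, w, w, epsilon=1.0, threshold=1e-8,
                                   min_iterations=0, max_iterations=1)
print(converged, n_outer)  # False 1: one outer iteration cannot meet the threshold here
```

As in the moscot run above, capping the outer loop at a single iteration leaves the coupling still changing, so the run is reported as not converged; raising min_iterations instead forces extra outer steps even when the change is already small.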


## Linear solver keyword arguments

As mentioned above, each outer-loop step of the Gromov-Wasserstein algorithm consists of solving a linear problem. Arguments for this linear solver are passed via linear_solver_kwargs: these are keyword arguments for Sinkhorn in the full-rank case and for LRSinkhorn in the low-rank case. This way, we can also set the minimum and maximum number of iterations, as well as the threshold, of the linear solver:

```python
ls_kwargs = {"min_iterations": 10, "max_iterations": 1000, "threshold": 0.01}
gwp = gwp.solve(
    alpha=0.5,
    epsilon=1e-1,
    threshold=0.1,
    min_iterations=2,
    max_iterations=20,
    linear_solver_kwargs=ls_kwargs,
)
```

```text
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].
WARNING  Solver did not converge
```
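The linear solver's own min_iterations, max_iterations, and threshold play the same role one level down. The pure-Python Sinkhorn sketch below assumes the error is the L1 deviation of the row marginals (the column marginals are matched exactly after each update) and checks it after every iteration; ott's Sinkhorn may check less often and use a different norm. The dictionary mirrors ls_kwargs from the cell above but is passed to a hypothetical function, not to moscot.

```python
import math


def sinkhorn(C, a, b, epsilon, min_iterations=0, max_iterations=1000, threshold=1e-3):
    """Hypothetical entropic OT solver with explicit iteration bounds and threshold."""
    n, m = len(a), len(b)
    K = [[math.exp(-C[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    converged, n_iter = False, 0
    for it in range(max_iterations):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        n_iter = it + 1
        # After the v-update the column marginals are exact, so only rows can deviate.
        row_error = sum(abs(u[i] * sum(K[i][j] * v[j] for j in range(m)) - a[i])
                        for i in range(n))
        if n_iter >= min_iterations and row_error < threshold:
            converged = True
            break
    P = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    return P, converged, n_iter


ls_kwargs = {"min_iterations": 10, "max_iterations": 1000, "threshold": 0.01}
C = [[0.0, 1.0], [1.0, 0.0]]

# Symmetric problem: the marginals are met after the first iteration,
# but min_iterations forces ten iterations anyway.
_, converged, n_iter = sinkhorn(C, [0.5, 0.5], [0.5, 0.5], epsilon=0.5, **ls_kwargs)
print(converged, n_iter)  # True 10

# Skewed marginals with a small epsilon: one iteration is far from enough.
_, converged, n_iter = sinkhorn(C, [0.9, 0.1], [0.5, 0.5], epsilon=0.1, max_iterations=1)
print(converged, n_iter)  # False 1
```

Because every outer step runs this inner loop, a loose inner threshold makes each step cheaper but less accurate, which in turn can affect whether the outer loop meets its own threshold.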


## Low-rank hyperparameters

The low-rank parameters gamma and gamma_rescale have the same meaning as in the linear case; see Linear problems (advanced).

## Keyword arguments and implementation details

Whenever the solve method of a quadratic problem is called, a backend-specific quadratic solver is instantiated. Currently, the ott backend is supported; its quadratic solver is GromovWasserstein, which handles both the full-rank and the low-rank case. moscot wraps this class in GWSolver, which handles both the purely quadratic and the fused quadratic problem.
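To see what separates the purely quadratic from the fused problem, the sketch below evaluates the standard fused Gromov-Wasserstein objective with the squared loss for a given coupling. It assumes the convention that alpha = 1 gives the purely quadratic problem and that smaller alpha puts more weight on the linear term; the exact weighting and normalization in moscot and ott may differ, and the function name is hypothetical.

```python
def fgw_objective(M, Cx, Cy, P, alpha):
    """Hypothetical fused GW objective:
    (1 - alpha) * <M, P> + alpha * sum_{ijkl} (Cx[i][k] - Cy[j][l])**2 * P[i][j] * P[k][l].
    alpha = 1 drops the linear term (purely quadratic problem)."""
    n, m = len(P), len(P[0])
    linear = sum(M[i][j] * P[i][j] for i in range(n) for j in range(m))
    quadratic = sum((Cx[i][k] - Cy[j][l]) ** 2 * P[i][j] * P[k][l]
                    for i in range(n) for j in range(m)
                    for k in range(n) for l in range(m))
    return (1.0 - alpha) * linear + alpha * quadratic


Cx = Cy = [[0.0, 1.0], [1.0, 0.0]]   # identical two-point geometries
M = [[1.0, 0.0], [0.0, 1.0]]         # linear cost between the two point sets
identity = [[0.5, 0.0], [0.0, 0.5]]  # matches point i to point i
product = [[0.25, 0.25], [0.25, 0.25]]

print(fgw_objective(M, Cx, Cy, identity, alpha=1.0))  # 0.0: geometry fully preserved
print(fgw_objective(M, Cx, Cy, identity, alpha=0.5))  # 0.5: only the linear term contributes
print(fgw_objective(M, Cx, Cy, product, alpha=1.0))   # 0.5
```

In this toy case, the identity coupling is optimal for the purely quadratic term but pays the full linear cost, so lowering alpha shifts the balance toward couplings that favor the linear term.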