Quadratic problems (advanced)

This example shows advanced usage of quadratic problems, e.g., the LineageProblem, the SpatioTemporalProblem, the MappingProblem, the AlignmentProblem, the GWProblem, and the FGWProblem.

Imports and data loading

import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import GWProblem

import scanpy as sc

Simulate data using simulate_data().

adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
sc.pp.pca(adata)
adata
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
)
gwp
GWProblem[('0', '1')]

Threshold

The threshold parameter defines the convergence criterion. In the balanced setting, the threshold is the tolerance on the deviation between the prior and posterior marginals. In the unbalanced setting, it serves as a Cauchy-sequence stopping criterion, i.e., iterations stop once consecutive iterates change by less than the threshold.
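To make the two kinds of criteria concrete, here is a minimal NumPy sketch. The helper names `marginal_error` and `cauchy_error` are hypothetical, and the exact quantities and norms that ott monitors may differ. The sketch only illustrates the two types of test that get compared against `threshold`.

```python
import numpy as np

def marginal_error(P, a, b):
    """Balanced-style check (hypothetical helper): L1 deviation of the
    coupling's marginals from the prior marginals a and b."""
    return np.abs(P.sum(axis=1) - a).sum() + np.abs(P.sum(axis=0) - b).sum()

def cauchy_error(x_prev, x_curr):
    """Cauchy-style check (hypothetical helper): largest change between
    two consecutive iterates."""
    return np.abs(x_curr - x_prev).max()

threshold = 0.1

# A product coupling reproduces its marginals exactly, so the balanced check passes.
a = np.full(4, 0.25)
b = np.full(4, 0.25)
P = np.outer(a, b)
balanced_converged = marginal_error(P, a, b) < threshold

# Two iterates that barely moved satisfy the Cauchy-style check.
x_prev = np.array([1.0, 2.0])
x_curr = np.array([1.0, 2.05])
unbalanced_converged = cauchy_error(x_prev, x_curr) < threshold
```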

Initializers

Different initializers can help improve convergence. In the full-rank case, only the default initializer is available, so the initializer argument must be set to None.

For low-rank problems, the same initializers as for the linear low-rank solvers are available, and initializer_kwargs can be passed in the same way; see Linear problems (advanced) for more information.
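The next sketch shows what a full-rank initialization can look like. It assumes the common choice of linearizing the square-loss Gromov-Wasserstein cost around the product coupling a b^T. As far as we understand, ott's default quadratic initializer does this, but treat that as an assumption. The helpers `product_coupling` and `linearized_cost` are hypothetical names used for illustration only.

```python
import numpy as np

def product_coupling(a, b):
    """Independent coupling a b^T; it satisfies both marginal constraints by construction."""
    return np.outer(a, b)

def linearized_cost(Dx, Dy, P):
    """Square-loss GW cost linearized around coupling P (hypothetical helper):
    C[i, j] = sum_{k,l} (Dx[i, k] - Dy[j, l])**2 * P[k, l]."""
    a, b = P.sum(axis=1), P.sum(axis=0)
    # Expand the square: row term + column term - 2 * cross term.
    return (Dx**2 @ a)[:, None] + (Dy**2 @ b)[None, :] - 2.0 * Dx @ P @ Dy.T

# Toy data: two 1-D point clouds with uniform weights.
x = np.linspace(0.0, 1.0, 5)[:, None]
y = np.linspace(0.0, 1.0, 4)[:, None] ** 2
Dx = np.abs(x - x.T)
Dy = np.abs(y - y.T)
a = np.full(5, 1 / 5)
b = np.full(4, 1 / 4)

P0 = product_coupling(a, b)
C0 = linearized_cost(Dx, Dy, P0)  # cost matrix of the first linear problem
```

The expanded form avoids building the full n x m x n x m loss tensor. Only the three matrix products are needed.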

Number of iterations

A quadratic optimal transport problem is solved by iteratively solving a sequence of linearized problems, each of which is updated from the previous solution. Here, min_iterations is a lower bound and max_iterations an upper bound on the number of these outer iterations. If max_iterations is too low, the solver might not converge, as in the following call with max_iterations=1.

gwp = gwp.solve(epsilon=1e-1, min_iterations=0, max_iterations=1)
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].                                              
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING  Solver did not converge                                                                                   
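The outer loop can be illustrated with a minimal NumPy sketch of entropic Gromov-Wasserstein with square loss. This is not ott's implementation. The inner Sinkhorn uses a fixed number of iterations, the stopping rule (maximum change of the coupling) is our own choice, and constant factors in the linearized cost are absorbed into `epsilon`. The parameter names only mirror `min_iterations`, `max_iterations`, and `threshold` from the calls in this example.

```python
import numpy as np

def sinkhorn(C, a, b, epsilon, n_iter=200):
    """Inner solver: plain entropic Sinkhorn with a fixed iteration count (kept simple on purpose)."""
    K = np.exp(-(C - C.min()) / epsilon)  # a constant shift of C leaves the optimal plan unchanged
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

def entropic_gw(Dx, Dy, a, b, epsilon=0.1, min_iterations=0, max_iterations=20, threshold=1e-6):
    """Outer loop: re-linearize the square-loss GW cost around the current coupling,
    solve the resulting linear problem, and stop once the coupling stops changing."""
    P = np.outer(a, b)  # start from the product coupling
    converged, n_outer = False, 0
    for n_outer in range(1, max_iterations + 1):
        C = (Dx**2 @ a)[:, None] + (Dy**2 @ b)[None, :] - 2.0 * Dx @ P @ Dy.T
        P_new = sinkhorn(C, a, b, epsilon)
        delta = np.abs(P_new - P).max()
        P = P_new
        # Convergence is only accepted after min_iterations outer steps.
        if n_outer >= min_iterations and delta < threshold:
            converged = True
            break
    return P, converged, n_outer

x = np.linspace(0.0, 1.0, 5)[:, None]
y = np.linspace(0.0, 1.0, 4)[:, None] ** 2
Dx, Dy = np.abs(x - x.T), np.abs(y - y.T)
a, b = np.full(5, 1 / 5), np.full(4, 1 / 4)

# Mirrors max_iterations=1: the loop stops after one step whether or not it has converged.
P1, converged1, n1 = entropic_gw(Dx, Dy, a, b, epsilon=0.1, min_iterations=0, max_iterations=1)
# Mirrors min_iterations=2, max_iterations=20: at least two outer steps are always taken.
P2, converged2, n2 = entropic_gw(Dx, Dy, a, b, epsilon=0.1, threshold=0.1, min_iterations=2, max_iterations=20)
```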

Linear solver keyword arguments

As mentioned above, each outer-loop step of the Gromov-Wasserstein algorithm consists of solving a linear problem. Arguments for the linear solver can be passed via linear_solver_kwargs. These are keyword arguments for Sinkhorn in the full-rank case and for LRSinkhorn in the low-rank case. This way, we can also set the minimum and maximum number of iterations and the threshold of the linear solver:

# keyword arguments forwarded to the inner (linear) solver
ls_kwargs = {"min_iterations": 10, "max_iterations": 1000, "threshold": 0.01}
gwp = gwp.solve(
    epsilon=1e-1,
    threshold=0.1,
    min_iterations=2,
    max_iterations=20,
    linear_solver_kwargs=ls_kwargs,
)
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
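To see what these inner-solver settings control, here is a minimal NumPy sketch of a balanced entropic Sinkhorn solver that accepts the same three keys as `ls_kwargs`. It is an illustration under our own assumptions, not ott's `Sinkhorn`. In particular, the stopping rule here (L1 error of the column marginal) is our choice, and ott's exact error measure may differ.

```python
import numpy as np

def sinkhorn(C, a, b, epsilon=0.1, min_iterations=10, max_iterations=1000, threshold=0.01):
    """Balanced entropic Sinkhorn with iteration bounds and a marginal-error stopping rule."""
    K = np.exp(-(C - C.min()) / epsilon)  # a constant shift of C leaves the optimal plan unchanged
    u, v = np.ones_like(a), np.ones_like(b)
    P = u[:, None] * K * v[None, :]
    err, n_iter = np.inf, 0
    for n_iter in range(1, max_iterations + 1):
        v = b / (K.T @ u)
        u = a / (K @ v)
        P = u[:, None] * K * v[None, :]
        # The u-update makes row sums exact, so track how far the column sums are from b.
        err = np.abs(P.sum(axis=0) - b).sum()
        if n_iter >= min_iterations and err < threshold:
            break
    return P, err, n_iter

rng = np.random.default_rng(0)
C = rng.random((6, 5))
a, b = np.full(6, 1 / 6), np.full(5, 1 / 5)

# Same keys as the ls_kwargs dictionary passed to gwp.solve above.
ls_kwargs = {"min_iterations": 10, "max_iterations": 1000, "threshold": 0.01}
P, err, n_iter = sinkhorn(C, a, b, epsilon=0.1, **ls_kwargs)
```

Enforcing `min_iterations` prevents the loop from stopping on an early, accidentally small error. `max_iterations` caps the cost of each outer step.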

Low-rank hyperparameters

The parameters gamma and gamma_rescale have the same meaning as in the linear case; see the example Linear problems (advanced).

Keyword arguments and implementation details

Whenever the solve method of a quadratic problem is called, a backend-specific quadratic solver is instantiated. Currently, only ott is supported. Its quadratic solver is GromovWasserstein, which handles both the full-rank and the low-rank case. moscot wraps this class in GWSolver, which handles both the purely quadratic and the fused quadratic problem.
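For reference, here is a small NumPy sketch of the square-loss (fused) Gromov-Wasserstein objective that such a solver minimizes, shown without the entropic term. The weighting is an assumption for this sketch: `alpha = 1` gives the purely quadratic problem, and `alpha < 1` mixes in a linear cost `C_lin`. Check the `GWSolver` and ott documentation for the exact convention. The function name `fgw_objective` is hypothetical.

```python
import numpy as np

def fgw_objective(P, Dx, Dy, C_lin=None, alpha=1.0):
    """Square-loss (fused) GW objective of coupling P, without the entropic term.
    alpha=1.0 is the purely quadratic (GW) problem; alpha<1 adds a weighted linear term."""
    # L[i, j, k, l] = (Dx[i, k] - Dy[j, l])**2, fine for small toy problems only
    L = (Dx[:, None, :, None] - Dy[None, :, None, :]) ** 2
    quad = np.einsum("ijkl,ij,kl->", L, P, P)
    if alpha == 1.0:
        return quad
    if C_lin is None:
        raise ValueError("C_lin is required when alpha < 1")
    return alpha * quad + (1.0 - alpha) * np.sum(C_lin * P)

x = np.linspace(0.0, 1.0, 4)[:, None]
D = np.abs(x - x.T)
a = np.full(4, 1 / 4)

gw_identity = fgw_objective(np.diag(a), D, D)     # matching a space to itself: zero distortion
gw_product = fgw_objective(np.outer(a, a), D, D)  # independent coupling: positive distortion
```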