Quadratic problems

This example shows how to solve quadratic problems, such as the LineageProblem, the SpatioTemporalProblem, the MappingProblem, the AlignmentProblem, the GWProblem, and the FGWProblem. The code below uses the generic GWProblem and FGWProblem.

Imports and data loading

import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import FGWProblem, GWProblem

import numpy as np

import scanpy as sc

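Simulate two batches of cells using simulate_data(). Passing quad_term="spatial" adds spatial coordinates in obsm["spatial"], which define the quadratic terms below. The PCA embedding in obsm["X_pca"] is used later as the joint (linear) attribute.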

adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
sc.pp.pca(adata)
adata
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'

Basic parameters

Some parameters of quadratic problems play the same role as in linear problems; see Linear problems for the meaning of epsilon, tau_a, and tau_b. Fused quadratic problems (also called fused Gromov-Wasserstein problems) have an additional parameter alpha, which defines the convex combination of the quadratic term and the linear term; the linear term is defined by joint_attr. Setting alpha = 1 yields the pure quadratic problem and ignores joint_attr. Setting alpha = 0 is not possible; use a linear problem instead.

Below, we solve a GWProblem and an FGWProblem with alpha = 0.5 on the same spatial coordinates and report the largest entry-wise difference between the two transport matrices.

# Pure Gromov-Wasserstein problem: compares the spatial geometry within each batch.
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
)
gwp = gwp.solve(epsilon=1e-1)

# Fused Gromov-Wasserstein problem: additionally uses the PCA coordinates as the linear term.
fgwp = FGWProblem(adata)
fgwp = fgwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
    joint_attr="X_pca",
)
fgwp = fgwp.solve(epsilon=1e-1, alpha=0.5)

# Largest entry-wise difference between the couplings of batches "0" and "1".
max_difference = np.max(
    np.abs(
        gwp["0", "1"].solution.transport_matrix
        - fgwp["0", "1"].solution.transport_matrix
    )
)
print(f"max difference: {max_difference:.6f}")
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].                                              
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].                                              
max difference: 0.021854
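To make the role of alpha concrete, the NumPy sketch below evaluates a schematic fused objective for a fixed coupling \(P\): \(\alpha \sum_{i,j,k,l} (C^x_{ik} - C^y_{jl})^2 P_{ij} P_{kl} + (1 - \alpha) \sum_{i,j} M_{ij} P_{ij}\). Here \(C^x\) and \(C^y\) are distance matrices within each batch (playing the part of x_attr and y_attr) and \(M\) is a cross-batch cost (playing the part of joint_attr). This is an illustration under assumptions, not moscot's implementation: the entropic term controlled by epsilon, the unbalanced terms, and any internal rescaling are left out, the Euclidean and squared Euclidean costs are our choice, and the helper names are hypothetical.

```python
import numpy as np


def pairwise_dist(X):
    """Euclidean distances between all rows of X (the geometry within one batch)."""
    return np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)


def gw_term(Cx, Cy, P):
    """Quadratic term with square loss,
    sum_{i,j,k,l} (Cx[i, k] - Cy[j, l])**2 * P[i, j] * P[k, l],
    expanded so that no 4-dimensional array is materialized."""
    a, b = P.sum(axis=1), P.sum(axis=0)  # marginals of the coupling
    cross = np.sum(P * (Cx @ P @ Cy.T))
    return a @ (Cx**2) @ a + b @ (Cy**2) @ b - 2.0 * cross


def fgw_objective(Cx, Cy, M, P, alpha):
    """Convex combination of the quadratic term and the linear term <M, P>."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must lie in (0, 1]; for alpha = 0 use a linear problem")
    return alpha * gw_term(Cx, Cy, P) + (1.0 - alpha) * np.sum(M * P)


rng = np.random.default_rng(0)
Cx = pairwise_dist(rng.normal(size=(5, 2)))  # stand-in "spatial" coordinates of batch 0
Cy = pairwise_dist(rng.normal(size=(6, 2)))  # stand-in "spatial" coordinates of batch 1
x_feat, y_feat = rng.normal(size=(5, 3)), rng.normal(size=(6, 3))
M = np.linalg.norm(x_feat[:, None] - y_feat[None], axis=-1) ** 2  # cross-batch cost
P = np.full((5, 6), 1.0 / 30)  # independent coupling with uniform marginals

# With alpha = 1 the linear term has zero weight, so rescaling M changes nothing.
print(np.isclose(fgw_objective(Cx, Cy, M, P, 1.0), fgw_objective(Cx, Cy, 10 * M, P, 1.0)))  # True
```

Intermediate values, such as the alpha = 0.5 used above, give the linear term a nonzero weight, so the optimal coupling can shift, which is consistent with the nonzero difference printed above.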

Low-rank solutions

When the dataset is very large, the computational complexity can be reduced by setting rank to a positive integer [Scetbon et al., 2021]. In this case, epsilon can also be set to \(0\), but only the balanced case (\(\text{tau}_a = \text{tau}_b = 1\)) is supported. Moreover, the data must be provided as point clouds, i.e., a precomputed cost matrix cannot be passed.

gwp = gwp.solve(epsilon=1e-2, rank=3)
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
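The low-rank approach of Scetbon et al. (2021) restricts the coupling to the factorized form \(P = Q\,\mathrm{diag}(1/g)\,R^\top\), with nonnegative \(Q \in \mathbb{R}^{n \times r}\), \(R \in \mathbb{R}^{m \times r}\) and \(g \in \mathbb{R}^{r}\), where \(Q\) has marginals \((a, g)\) and \(R\) has marginals \((b, g)\). The NumPy sketch below only illustrates this parametrization, not moscot's low-rank solver. It builds valid factors for fixed uniform marginals, using a Sinkhorn-style rescaling purely as a convenient way to obtain them. It then checks that \(P\) has rank at most \(r\), has the prescribed marginals, and needs far fewer numbers than a dense coupling. The helper names are hypothetical; the sizes mirror the \(20 \times 20\) problem and rank=3 above.

```python
import numpy as np


def scale_to_marginals(K, row, col, n_iter=1000):
    """Sinkhorn-style rescaling of a positive matrix K so that its row sums
    approach `row` and its column sums equal `col` (exact after the last step)."""
    u, v = np.ones_like(row), np.ones_like(col)
    for _ in range(n_iter):
        u = row / (K @ v)
        v = col / (K.T @ u)
    return u[:, None] * K * v[None, :]


def low_rank_coupling(Q, R, g):
    """Assemble P = Q diag(1/g) R^T from its low-rank factors."""
    return (Q / g) @ R.T


rng = np.random.default_rng(0)
n, m, rank = 20, 20, 3
a, b = np.full(n, 1 / n), np.full(m, 1 / m)  # balanced, uniform marginals
g = np.full(rank, 1 / rank)                  # inner marginal of the factorization
Q = scale_to_marginals(rng.random((n, rank)) + 0.1, a, g)
R = scale_to_marginals(rng.random((m, rank)) + 0.1, b, g)
P = low_rank_coupling(Q, R, g)

print(np.linalg.matrix_rank(P) <= rank)  # True
print(np.allclose(P.sum(axis=1), a), np.allclose(P.sum(axis=0), b))
print(n * m, "entries vs", Q.size + R.size + g.size)  # 400 entries vs 123
```

Storing \(Q\), \(R\) and \(g\) takes \((n + m)r + r\) numbers instead of \(nm\), so in this sketch the savings grow with the dataset size while \(r\) stays fixed.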

Scaling the cost

The scale_cost parameter works the same way as for linear problems; see Linear problems for more information. Note that the same scale_cost argument is applied to all cost terms.
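As a rough illustration of applying one scale_cost argument to every cost term, the sketch below divides each cost matrix of a fused problem (the two within-batch terms and the cross-batch term) according to the same rule. The helper scale_costs and its rule names are hypothetical, and computing the statistic separately for each term is an assumption made for this sketch; see Linear problems for the options moscot actually accepts.

```python
import numpy as np

# Statistics a string argument could refer to (illustrative names only).
_STATS = {"mean": np.mean, "median": np.median, "max": np.max}


def scale_costs(costs, scale_cost="mean"):
    """Apply the same scaling rule to every cost matrix (hypothetical helper).

    A number divides every matrix by that constant; a string divides each
    matrix by its own statistic of that kind (an assumption of this sketch).
    """
    if isinstance(scale_cost, str):
        return {name: C / _STATS[scale_cost](C) for name, C in costs.items()}
    return {name: C / scale_cost for name, C in costs.items()}


rng = np.random.default_rng(0)
costs = {
    "xx": 100 * rng.random((4, 4)),  # stand-in for distances within batch 0
    "yy": rng.random((5, 5)),        # stand-in for distances within batch 1
    "xy": 10 * rng.random((4, 5)),   # stand-in for a cross-batch feature cost
}
scaled = scale_costs(costs, "mean")
print({name: round(float(C.mean()), 6) for name, C in scaled.items()})
# {'xx': 1.0, 'yy': 1.0, 'xy': 1.0}
```

In a fused problem, putting the terms on a comparable scale like this helps keep the weighting set by alpha from being dominated by whichever cost happens to use larger units.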