Quadratic problems

This example shows how to solve quadratic problems, e.g., the LineageProblem, the SpatioTemporalProblem, the MappingProblem, the AlignmentProblem, and the GWProblem.


Imports and data loading

import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import GWProblem

import numpy as np

import scanpy as sc

Simulate data using simulate_data().

adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
sc.pp.pca(adata)
adata
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'

Basic parameters

Some parameters of quadratic problems play the same role as in linear problems, so we refer to Linear problems for the role of epsilon, tau_a, and tau_b. Fused quadratic problems (also referred to as Fused Gromov-Wasserstein) have an additional parameter, alpha, which defines the convex combination between the quadratic term and the linear term. The linear term is defined by joint_attr. Setting alpha = 1 considers only the pure quadratic problem and ignores joint_attr. Setting alpha = 0 is not possible; in that case, use a linear problem instead.
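To make the role of alpha concrete, here is a minimal NumPy sketch of one common textbook form of the fused objective for a fixed coupling `p`: `alpha * quadratic + (1 - alpha) * linear`. The helper names (`quadratic_term`, `linear_term`, `fused_objective`) and the toy data are hypothetical and not part of moscot. The exact weighting and normalization used internally by moscot and its OTT-JAX backend may differ. The sketch only illustrates why alpha = 1 makes the linear cost irrelevant and why alpha = 0 falls outside the quadratic setting.

```python
import numpy as np


def quadratic_term(cx: np.ndarray, cy: np.ndarray, p: np.ndarray) -> float:
    """Gromov-Wasserstein term: sum_{i,j,k,l} (cx[i,k] - cy[j,l])**2 * p[i,j] * p[k,l]."""
    # diff[i, j, k, l] = cx[i, k] - cy[j, l]
    diff = cx[:, None, :, None] - cy[None, :, None, :]
    return float(np.einsum("ijkl,ij,kl->", diff**2, p, p))


def linear_term(m: np.ndarray, p: np.ndarray) -> float:
    """Linear (Wasserstein) term: sum_{i,j} m[i,j] * p[i,j]."""
    return float(np.sum(m * p))


def fused_objective(cx, cy, m, p, alpha: float) -> float:
    """Hypothetical helper: convex combination of the quadratic and the linear term."""
    if not 0.0 < alpha <= 1.0:
        # alpha = 0 would leave only the linear term, i.e. a linear problem
        raise ValueError("alpha must lie in (0, 1]; use a linear problem for alpha = 0")
    return alpha * quadratic_term(cx, cy, p) + (1.0 - alpha) * linear_term(m, p)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))  # toy "spatial" coordinates of the source batch
y = rng.normal(size=(4, 2))  # toy "spatial" coordinates of the target batch
cx = np.sum((x[:, None] - x[None]) ** 2, axis=-1)  # intra-batch squared distances
cy = np.sum((y[:, None] - y[None]) ** 2, axis=-1)
m = rng.random((5, 4))  # stand-in for a cross-batch cost, e.g. on a joint embedding
p = np.full((5, 4), 1.0 / 20)  # uniform (independent) coupling

print(fused_objective(cx, cy, m, p, alpha=1.0))  # does not depend on m
print(fused_objective(cx, cy, m, p, alpha=0.5))  # quadratic and linear term both contribute
```

Changing `m` leaves the alpha = 1.0 value unchanged, which mirrors the statement that joint_attr is ignored in the pure quadratic problem.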

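# Pure Gromov-Wasserstein problem on the spatial coordinates of each batch
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
)
# alpha=1.0: only the quadratic (spatial) term enters the objective
gwp = gwp.solve(alpha=1.0, epsilon=1e-1)

# Fused Gromov-Wasserstein: joint_attr adds a linear term on adata.obsm["X_pca"]
fgwp = GWProblem(adata)
fgwp = fgwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
    joint_attr="X_pca",
)
# alpha=0.5 gives the quadratic and the linear term equal weight in the convex combination
fgwp = fgwp.solve(epsilon=1e-1, alpha=0.5)

# Compare the couplings between batch "0" and batch "1"
max_difference = np.max(
    np.abs(
        gwp["0", "1"].solution.transport_matrix
        - fgwp["0", "1"].solution.transport_matrix
    )
)
print(f"max difference: {max_difference:.6f}")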
INFO     Ordering Index(['0-0', '1-0', '2-0', '3-0', '4-0', '5-0', '6-0', '7-0', '8-0', '9-0',
                '10-0', '11-0', '12-0', '13-0', '14-0', '15-0', '16-0', '17-0', '18-0',
                '19-0', '0-1', '1-1', '2-1', '3-1', '4-1', '5-1', '6-1', '7-1', '8-1',
                '9-1', '10-1', '11-1', '12-1', '13-1', '14-1', '15-1', '16-1', '17-1',
                '18-1', '19-1'],
               dtype='object') in ascending order.
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO     Solving `1` problems
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
INFO     Ordering Index(['0-0', '1-0', '2-0', '3-0', '4-0', '5-0', '6-0', '7-0', '8-0', '9-0',
                '10-0', '11-0', '12-0', '13-0', '14-0', '15-0', '16-0', '17-0', '18-0',
                '19-0', '0-1', '1-1', '2-1', '3-1', '4-1', '5-1', '6-1', '7-1', '8-1',
                '9-1', '10-1', '11-1', '12-1', '13-1', '14-1', '15-1', '16-1', '17-1',
                '18-1', '19-1'],
               dtype='object') in ascending order.
INFO     Solving `1` problems
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].
max difference: 0.021854

Low-rank solutions

Whenever the dataset is very large, the computational complexity can be reduced by setting rank to a positive integer [Scetbon et al., 2021]. In this case, epsilon can also be set to 0, but only the balanced case (tau_a = tau_b = 1) is supported. Moreover, the data has to be provided as point clouds, i.e., no precomputed cost matrix can be passed.
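The memory savings come from representing the coupling in factored form. Following Scetbon et al., 2021, a rank-r coupling can be written as P = Q diag(1/g) R^T, where Q (n x r) has marginals (a, g) and R (m x r) has marginals (b, g). The NumPy sketch below builds random factors with the right marginals and checks that P is a valid balanced coupling of rank at most r. It mirrors the 20 x 20 problem and rank=3 used in this section. The helper `scale_to_marginals` is a hypothetical Sinkhorn-style scaling used only to produce valid factors. This sketch does not show how moscot or OTT-JAX optimize Q, R, and g.

```python
import numpy as np


def scale_to_marginals(k, row, col, n_iter=2000):
    """Hypothetical helper: Sinkhorn-style scaling of a positive matrix to given row/column sums."""
    u = np.ones(k.shape[0])
    v = np.ones(k.shape[1])
    for _ in range(n_iter):
        u = row / (k @ v)
        v = col / (k.T @ u)
    return u[:, None] * k * v[None, :]


rng = np.random.default_rng(0)
n, m, rank = 20, 20, 3
a = np.full(n, 1.0 / n)  # source marginal (balanced case)
b = np.full(m, 1.0 / m)  # target marginal
g = np.full(rank, 1.0 / rank)  # inner marginal shared by both factors

q_fac = scale_to_marginals(rng.random((n, rank)) + 0.1, a, g)  # n x r, marginals (a, g)
r_fac = scale_to_marginals(rng.random((m, rank)) + 0.1, b, g)  # m x r, marginals (b, g)

# Low-rank coupling P = Q diag(1/g) R^T; its marginals are a and b
p = q_fac @ np.diag(1.0 / g) @ r_fac.T

print("rank of P:", np.linalg.matrix_rank(p))
print("dense coupling entries:", n * m)
print("factor entries:", q_fac.size + r_fac.size + g.size)
```

For large n and m, storing r * (n + m + 1) numbers instead of n * m is what makes the low-rank solver attractive.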

gwp = gwp.solve(epsilon=1e-2, rank=3)
INFO     Solving `1` problems
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].

Scaling the cost

The scale_cost parameter works the same way as for linear problems; see Linear problems for more information. Note that all cost terms are scaled using the same argument.
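As an illustration, the following NumPy sketch assumes one plausible reading of "the same argument": the single scale_cost value is applied to every cost term, and each term (the two intra-batch costs and the linear cost) is divided by its own statistic. The helper `scale_cost` and the options it accepts are hypothetical. See Linear problems for the options moscot actually supports.

```python
import numpy as np


def scale_cost(cost: np.ndarray, how) -> np.ndarray:
    """Hypothetical helper: divide a cost matrix by a scalar or by a statistic of itself."""
    if isinstance(how, (int, float)):
        return cost / how
    reducers = {"mean": np.mean, "median": np.median, "max": np.max}
    return cost / reducers[how](cost)


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2)) * 100.0  # toy source coordinates on a large scale
y = rng.normal(size=(5, 2))  # toy target coordinates on a small scale
cx = np.sum((x[:, None] - x[None]) ** 2, axis=-1)  # intra-source squared distances
cy = np.sum((y[:, None] - y[None]) ** 2, axis=-1)  # intra-target squared distances
m = rng.random((6, 5)) * 10.0  # stand-in for the linear (joint_attr) cost

how = "max"  # one argument, applied to every cost term
scaled = {name: scale_cost(c, how) for name, c in {"x": cx, "y": cy, "joint": m}.items()}
for name, c in scaled.items():
    print(name, c.max())  # each term now has a maximum of 1
```

Under this assumption, scaling puts costs that live on very different scales (here, coordinates differing by a factor of about 100) onto a comparable range before they are combined.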