This example shows how to solve quadratic problems, e.g., the LineageProblem, the SpatioTemporalProblem, the MappingProblem, the AlignmentProblem, the GWProblem, and the FGWProblem.

import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import FGWProblem, GWProblem

import numpy as np

import scanpy as sc


Simulate data using simulate_data().

adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
adata

AnnData object with n_obs × n_vars = 40 × 60
obs: 'batch', 'celltype'
uns: 'pca'
obsm: 'spatial', 'X_pca'
varm: 'PCs'


## Basic parameters

Some parameters of quadratic problems play the same role as in linear problems; see Linear problems for the role of epsilon, tau_a, and tau_b. Fused quadratic problems (also called Fused Gromov-Wasserstein problems) have an additional parameter, alpha, which defines the convex combination between the quadratic term and the linear term. The linear term is defined by joint_attr. Setting alpha = 1 considers only the quadratic term and ignores joint_attr. Setting alpha = 0 is not possible; for a purely linear problem, use one of the linear problem classes instead.

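# Pure Gromov-Wasserstein problem: only the quadratic (spatial) term is used.
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
)
gwp = gwp.solve(epsilon=1e-1)

# Fused problem: joint_attr adds a linear term, weighted against the quadratic term by alpha.
fgwp = FGWProblem(adata)
fgwp = fgwp.prepare(
    key="batch",
    x_attr={"attr": "obsm", "key": "spatial"},
    y_attr={"attr": "obsm", "key": "spatial"},
    joint_attr="X_pca",
)
fgwp = fgwp.solve(epsilon=1e-1, alpha=0.5)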

max_difference = np.max(
    np.abs(
        gwp["0", "1"].solution.transport_matrix
        - fgwp["0", "1"].solution.transport_matrix
    )
)
print(f"max difference: {max_difference:.6f}")

INFO     Solving 1 problems
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

INFO     Solving 1 problems
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].
max difference: 0.021854
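
To make the role of alpha concrete, the following NumPy sketch evaluates a fused objective for a fixed coupling. It assumes a squared-difference quadratic loss and the convention described above (alpha = 1 keeps only the quadratic term). The function name `fused_objective` and the toy data are hypothetical, and the solver backend used by moscot may normalize or weight the terms differently.

```python
import numpy as np


def fused_objective(P, C_lin, D_x, D_y, alpha):
    """Evaluate (1 - alpha) * linear + alpha * quadratic for a coupling P.

    Hypothetical helper for illustration only, not part of moscot.
    """
    # Linear term: sum_ij C_lin[i, j] * P[i, j]
    linear = np.sum(C_lin * P)
    # Quadratic term: sum_ijkl (D_x[i, k] - D_y[j, l])**2 * P[i, j] * P[k, l]
    diff = D_x[:, None, :, None] - D_y[None, :, None, :]  # shape (n, m, n, m)
    quadratic = np.einsum("ijkl,ij,kl->", diff**2, P, P)
    return (1.0 - alpha) * linear + alpha * quadratic


def pairwise_sq_dists(points):
    """Squared Euclidean distances between all rows of `points`."""
    delta = points[:, None, :] - points[None, :, :]
    return np.sum(delta**2, axis=-1)


rng = np.random.default_rng(0)
n, m = 3, 4
D_x = pairwise_sq_dists(rng.random((n, 2)))  # intra-domain cost, source
D_y = pairwise_sq_dists(rng.random((m, 2)))  # intra-domain cost, target
C_lin = rng.random((n, m))                   # cross-domain (joint) cost
P = np.full((n, m), 1.0 / (n * m))           # uniform coupling as a toy input

pure_quadratic = fused_objective(P, C_lin, D_x, D_y, alpha=1.0)
half_and_half = fused_objective(P, C_lin, D_x, D_y, alpha=0.5)
print(f"alpha=1.0: {pure_quadratic:.4f}, alpha=0.5: {half_and_half:.4f}")
```

With alpha = 1.0 the value does not depend on `C_lin` at all, which mirrors how joint_attr is ignored in the pure quadratic case.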


## Low-rank solutions

For very large datasets, the computational complexity can be reduced by setting rank to a positive integer. In this case, epsilon can also be set to $$0$$, but only the balanced case ($$\text{tau}_a = \text{tau}_b = 1$$) is supported. Moreover, the data has to be provided as point clouds, i.e., no precomputed cost matrix can be passed.

gwp = gwp.solve(epsilon=1e-2, rank=3)

INFO     Solving 1 problems
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].
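
As a rough illustration of why a low rank helps, the sketch below builds a coupling in the factored form P = Q diag(1/g) Rᵀ. This form is an assumption about how low-rank solvers typically parameterize the coupling, not a description of moscot internals, and all variable names are hypothetical. The factors need (n + m + 1) · rank numbers instead of n · m, and products with P can be computed without ever forming the full matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, rank = 20, 20, 3

# Inner marginal g: a probability vector of length `rank`.
g = np.full(rank, 1.0 / rank)

# Nonnegative factors whose columns sum to g (toy values, not a solver output).
Q = rng.random((n, rank))
Q *= g / Q.sum(axis=0)
R = rng.random((m, rank))
R *= g / R.sum(axis=0)

# Outer marginals implied by the factors.
a = Q.sum(axis=1)
b = R.sum(axis=1)

# Dense coupling, formed here only to check the factorization.
P = (Q / g) @ R.T

# Applying P to a vector through the factors costs O((n + m) * rank).
v = rng.random(m)
fast_product = Q @ ((R.T @ v) / g)

print("dense entries:", n * m, "| factored entries:", (n + m + 1) * rank)
print("factored product matches dense:", np.allclose(P @ v, fast_product))
```

For the 20 × 20 toy problem here the saving is small, but the gap grows with the number of cells because the factored storage is linear in n and m.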


## Scaling the cost

The scale_cost parameter works the same way as for linear problems; see Linear problems for more information. Note that all cost terms are scaled by the same argument.
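
The effect of applying one scaling rule to every cost term can be sketched in plain NumPy. The helper `scale_cost_matrix` and its option names ("mean", "max") are hypothetical and only mimic the idea; the options actually accepted by scale_cost are listed in the Linear problems example.

```python
import numpy as np


def scale_cost_matrix(cost, how):
    """Divide a cost matrix by its mean, its maximum, or a fixed number.

    Hypothetical helper for illustration only, not part of moscot.
    """
    if how == "mean":
        return cost / cost.mean()
    if how == "max":
        return cost / cost.max()
    return cost / float(how)


def pairwise_sq_dists(points):
    """Squared Euclidean distances between all rows of `points`."""
    delta = points[:, None, :] - points[None, :, :]
    return np.sum(delta**2, axis=-1)


rng = np.random.default_rng(0)
# Two point clouds on very different scales.
cost_x = pairwise_sq_dists(rng.normal(size=(5, 2)))
cost_y = pairwise_sq_dists(100.0 * rng.normal(size=(6, 2)))

# The same rule is applied to each cost term.
scaled_x = scale_cost_matrix(cost_x, "mean")
scaled_y = scale_cost_matrix(cost_y, "mean")

print(f"raw means:    {cost_x.mean():.3f}, {cost_y.mean():.3f}")
print(f"scaled means: {scaled_x.mean():.3f}, {scaled_y.mean():.3f}")
```

After scaling, both matrices have comparable magnitude, so a single epsilon has a similar regularizing effect on each term.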