This example shows how to solve quadratic problems, e.g., the LineageProblem, the SpatioTemporalProblem, the MappingProblem, the AlignmentProblem, and the GWProblem.

```python
from moscot import datasets
from moscot.problems.generic import GWProblem

import numpy as np

import scanpy as sc
```


Simulate data using `simulate_data()`.

```python
adata = datasets.simulate_data(n_distributions=2, key="batch", quad_term="spatial")
print(adata)
```

```text
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'spatial', 'X_pca'
    varm: 'PCs'
```


## Basic parameters

Some parameters of quadratic problems play the same role as in linear problems; see Linear problems for the role of `epsilon`, `tau_a`, and `tau_b`. Fused quadratic problems (also called Fused Gromov-Wasserstein) have an additional parameter, `alpha`, which defines the convex combination of the quadratic term and the linear term, where the linear term is defined by `joint_attr`. Setting `alpha = 1` considers only the pure quadratic problem and ignores `joint_attr`. Setting `alpha = 0` is not possible; for a purely linear problem, use a linear problem instead.
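
To make the role of `alpha` concrete, here is a minimal NumPy sketch of a fused Gromov-Wasserstein objective evaluated for a fixed coupling `T`. It assumes a squared loss for the quadratic term and the convention that `alpha` weights the quadratic term while `1 - alpha` weights the linear term. The function `fgw_objective`, the coordinates, and the cost matrices are hypothetical illustrations, not moscot's implementation, and moscot's backend may normalize the terms differently.

```python
import numpy as np


def fgw_objective(T, C_x, C_y, M, alpha):
    """Fused Gromov-Wasserstein cost of a coupling T (illustrative sketch).

    T   : (n, m) coupling between the two distributions
    C_x : (n, n) intra-distribution cost of the source (e.g. spatial distances)
    C_y : (m, m) intra-distribution cost of the target
    M   : (n, m) linear cost between source and target features (the joint_attr term)
    """
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must lie in (0, 1]; use a linear problem for alpha = 0")
    # Quadratic term with squared loss: sum_{i,j,k,l} (C_x[i,k] - C_y[j,l])**2 * T[i,j] * T[k,l]
    diff = C_x[:, None, :, None] - C_y[None, :, None, :]  # shape (n, m, n, m)
    quad = np.einsum("ijkl,ij,kl->", diff**2, T, T)
    # Linear term: <M, T>
    lin = np.sum(M * T)
    return alpha * quad + (1.0 - alpha) * lin


rng = np.random.default_rng(0)
n, m = 5, 4
x = rng.normal(size=(n, 2))  # hypothetical spatial coordinates of the source
y = rng.normal(size=(m, 2))  # hypothetical spatial coordinates of the target
C_x = np.linalg.norm(x[:, None] - x[None], axis=-1)
C_y = np.linalg.norm(y[:, None] - y[None], axis=-1)
M = rng.random((n, m))              # hypothetical feature cost
T = np.full((n, m), 1.0 / (n * m))  # independent coupling with uniform marginals

pure_gw = fgw_objective(T, C_x, C_y, M, alpha=1.0)
fused = fgw_objective(T, C_x, C_y, M, alpha=0.5)

# With alpha = 1 the linear cost M has no influence on the objective
print(np.isclose(pure_gw, fgw_objective(T, C_x, C_y, 10 * M, alpha=1.0)))
```

Changing `M` leaves the `alpha = 1` value untouched but shifts the `alpha = 0.5` value, which mirrors why `joint_attr` is ignored in the pure quadratic case.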

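```python
# Pure Gromov-Wasserstein: alpha=1 uses only the spatial (quadratic) term
gwp = GWProblem(adata)
gwp = gwp.prepare(
    key="batch",
    GW_x={"attr": "obsm", "key": "spatial"},
    GW_y={"attr": "obsm", "key": "spatial"},
)
gwp = gwp.solve(alpha=1.0, epsilon=1e-1)

# Fused Gromov-Wasserstein: alpha=0.5 also uses the linear term defined by joint_attr
fgwp = GWProblem(adata)
fgwp = fgwp.prepare(
    key="batch",
    GW_x={"attr": "obsm", "key": "spatial"},
    GW_y={"attr": "obsm", "key": "spatial"},
    joint_attr="X_pca",
)
fgwp = fgwp.solve(epsilon=1e-1, alpha=0.5)
```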

```python
max_difference = np.max(
    np.abs(
        gwp["0", "1"].solution.transport_matrix
        - fgwp["0", "1"].solution.transport_matrix
    )
)
print(f"{max_difference:.6f}")
```

```text
INFO     Computing pca with n_comps=30 for xy using adata.X
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].
0.021854
```


## Low-rank solutions

When the dataset is very large, the computational complexity can be reduced by setting `rank` to a positive integer. In this case, `epsilon` can also be set to $0$, but only the balanced case ($\text{tau}_a = \text{tau}_b = 1$) is supported. Moreover, the data must be provided as point clouds, i.e., a precomputed cost matrix cannot be passed.
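
Where the savings come from can be illustrated with the low-rank coupling factorization of Scetbon et al., $P = Q \operatorname{diag}(1/g) R^\top$ with $Q \in \mathbb{R}^{n \times r}$, $R \in \mathbb{R}^{m \times r}$ and $g \in \mathbb{R}^r$. We assume, without verifying it here, that moscot's low-rank backend uses a factorization of this form. The NumPy sketch below does not solve a low-rank problem. It only builds feasible factors for balanced marginals with a small Sinkhorn projection (the helper `project` is hypothetical) and checks that $P$ keeps the marginals $a$ and $b$ while storing only $(n + m + 1)\,r$ numbers instead of $n m$.

```python
import numpy as np


def project(K, row_marginal, col_marginal, n_iter=1000):
    """Rescale a positive matrix K so its row/column sums match the given marginals (Sinkhorn)."""
    u = np.ones(K.shape[0])
    v = np.ones(K.shape[1])
    for _ in range(n_iter):
        u = row_marginal / (K @ v)
        v = col_marginal / (K.T @ u)
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
n, m, r = 200, 150, 3    # r plays the role of `rank`
a = np.full(n, 1.0 / n)  # balanced case: both marginals sum to 1
b = np.full(m, 1.0 / m)
g = np.full(r, 1.0 / r)  # inner marginal shared by both factors

Q = project(rng.random((n, r)) + 0.1, a, g)  # rows of Q sum to a, columns to g
R = project(rng.random((m, r)) + 0.1, b, g)  # rows of R sum to b, columns to g

# Dense coupling, formed here only to check the marginals
P = Q @ np.diag(1.0 / g) @ R.T
print(np.allclose(P.sum(axis=1), a), np.allclose(P.sum(axis=0), b))
print("dense entries:", n * m, "low-rank entries:", (n + m + 1) * r)
```

Because $R^\top \mathbf{1} = g$, the row sums of $P$ reduce to $Q \mathbf{1} = a$, and symmetrically the column sums reduce to $R \mathbf{1} = b$. That is why both marginals are preserved even though $P$ is never stored densely by a low-rank solver.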

```python
gwp = gwp.solve(epsilon=1e-2, rank=3)
```

```text
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].
```


## Scaling the cost

The `scale_cost` parameter works the same way as for linear problems; see Linear problems for more information. Note that the same `scale_cost` argument is applied to all cost terms.
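
As an illustration of applying one scaling argument to every cost term, here is a hypothetical NumPy helper `scale_cost_matrix` that mirrors the string options `"mean"`, `"median"` and `"max_cost"`. These options are an assumption for this sketch; see Linear problems for the values moscot actually accepts. In this sketch each matrix is divided by its own statistic, so the two quadratic costs and the linear cost end up on comparable scales even when the raw coordinates differ by orders of magnitude.

```python
import numpy as np


def scale_cost_matrix(C, scale_cost):
    """Divide a cost matrix by a summary statistic of itself (illustrative sketch)."""
    if scale_cost == "mean":
        factor = np.mean(C)
    elif scale_cost == "median":
        factor = np.median(C)
    elif scale_cost == "max_cost":
        factor = np.max(C)
    else:
        raise ValueError(f"unknown scale_cost: {scale_cost!r}")
    return C / factor


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2))          # hypothetical source coordinates
y = rng.normal(size=(4, 2)) * 100.0  # hypothetical target coordinates on a much larger scale
C_x = np.linalg.norm(x[:, None] - x[None], axis=-1)
C_y = np.linalg.norm(y[:, None] - y[None], axis=-1)
M = rng.random((6, 4))               # hypothetical linear (feature) cost

# The same argument is applied to every cost term
scaled = [scale_cost_matrix(C, "max_cost") for C in (C_x, C_y, M)]
print([float(C.max()) for C in scaled])
```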