Linear problems

This example shows how to solve linear problems, such as the TemporalProblem and the SinkhornProblem.


Imports and data loading

from moscot import datasets
from moscot.problems.generic import SinkhornProblem

import numpy as np

We simulate a dataset with two distributions, labeled by the "day" column of adata.obs, using simulate_data().

adata = datasets.simulate_data(n_distributions=2, key="day")
adata
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'day', 'celltype'

Basic parameters

  • epsilon is the entropic regularization parameter. The lower epsilon, the sparser the transport map, but the longer the algorithm takes to converge.

  • tau_a and tau_b are the unbalancedness parameters of the source and the target distribution, respectively. \(\text{tau}_a = 1\) means that the source marginals have to be fully satisfied, while \(0 < \text{tau}_a < 1\) relaxes this constraint. Analogously, tau_b controls the marginals of the target distribution. We demonstrate the effect of tau_a and tau_b with the SinkhornProblem.
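To build intuition for epsilon, here is a minimal NumPy sketch of balanced, log-domain Sinkhorn iterations on a hypothetical random cost matrix. It is not moscot's solver; the helper names, the random data, and the 1e-6 threshold used to count "non-negligible" entries are illustrative assumptions. Decreasing epsilon should concentrate the transport map on fewer entries and increase the number of iterations needed to reach the tolerance.

```python
import numpy as np


def logsumexp(x, axis):
    # Numerically stable log(sum(exp(x))) along `axis`.
    m = x.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(x - m).sum(axis=axis))


def sinkhorn_log(C, a, b, epsilon, tol=1e-6, max_iter=20_000):
    """Balanced entropic OT via log-domain Sinkhorn; returns (transport map, iterations)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for it in range(1, max_iter + 1):
        # Alternate updates of the dual potentials of the source (f) and target (g).
        f = epsilon * (np.log(a) - logsumexp((g[None, :] - C) / epsilon, axis=1))
        g = epsilon * (np.log(b) - logsumexp((f[:, None] - C) / epsilon, axis=0))
        P = np.exp((f[:, None] + g[None, :] - C) / epsilon)
        # The g-update makes the column sums exact, so only the row sums are checked.
        if np.abs(P.sum(axis=1) - a).max() < tol:
            break
    return P, it


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 5))  # hypothetical source cells
y = rng.normal(size=(20, 5))  # hypothetical target cells
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
C /= C.mean()
a = np.full(20, 1 / 20)
b = np.full(20, 1 / 20)

results = {}
for eps in (1.0, 0.1, 0.01):
    P, n_iter = sinkhorn_log(C, a, b, eps)
    results[eps] = (n_iter, (P > 1e-6).mean())
    print(f"epsilon={eps}: {n_iter} iterations, {results[eps][1]:.0%} of entries > 1e-6")
```

The log-domain form is used because, for small epsilon, exp(-C / epsilon) quickly becomes too small to handle reliably in the plain scaling form.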

If the prior marginals a and b of the source and the target distribution, respectively, are not passed, they default to uniform.

sp = SinkhornProblem(adata)
sp = sp.prepare(key="day")
print(sp[0, 1].a[:5], sp[0, 1].b[:5])
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
[0.05 0.05 0.05 0.05 0.05] [0.05 0.05 0.05 0.05 0.05]
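The uniform default printed above is simply 1/n for n cells: each of the two simulated distributions contains 20 cells, so every cell receives a prior mass of 1/20 = 0.05. A plain NumPy sketch of that default, independent of moscot:

```python
import numpy as np

n_source = n_target = 20  # cells per distribution in the simulated data
a = np.full(n_source, 1 / n_source)  # uniform prior source marginals
b = np.full(n_target, 1 / n_target)  # uniform prior target marginals
print(a[:5], b[:5])  # [0.05 0.05 0.05 0.05 0.05] [0.05 0.05 0.05 0.05 0.05]
print(a.sum(), b.sum())
```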

First, we solve the problem in a balanced manner. The posterior marginals of the solution are the row sums (source) and the column sums (target) of the transport matrix. In the balanced case, they match the prior marginals up to a small error, and this error defines the convergence criterion.

sp = sp.solve(epsilon=1e-2, tau_a=1, tau_b=1)
sp[0, 1].solution.a[:5], sp[0, 1].solution.b[:5]
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].                                              
(Array([0.04999981, 0.05000006, 0.04999965, 0.04999992, 0.04999999],      dtype=float32),
 Array([0.05004844, 0.04996916, 0.04996588, 0.04997035, 0.04996975],      dtype=float32))
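The link between a transport matrix, its posterior marginals, and the balanced convergence criterion can be sketched in plain NumPy. The function below is a textbook scaling-form Sinkhorn loop, not moscot's implementation; the random cost matrix, epsilon = 0.2, and the tolerance are arbitrary choices for illustration, and the exact error norm used by moscot's backend may differ.

```python
import numpy as np


def sinkhorn(C, a, b, epsilon, tol=1e-6, max_iter=10_000):
    """Balanced Sinkhorn in scaling form, stopped once the marginal error drops below tol."""
    K = np.exp(-C / epsilon)  # Gibbs kernel
    v = np.ones_like(b)
    for _ in range(max_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
        P = u[:, None] * K * v[None, :]
        # Posterior marginals: row sums (source) and column sums (target).
        err = max(np.abs(P.sum(axis=1) - a).max(), np.abs(P.sum(axis=0) - b).max())
        if err < tol:
            break
    return P


rng = np.random.default_rng(0)
C = rng.uniform(size=(20, 20))  # hypothetical cost matrix with entries in [0, 1)
a = np.full(20, 1 / 20)
b = np.full(20, 1 / 20)

P = sinkhorn(C, a, b, epsilon=0.2)
print(P.sum(axis=1)[:5], P.sum(axis=0)[:5])
```

As in the moscot output above, the column sums match almost exactly after the last update, while the row sums carry the remaining error that the stopping rule monitors.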

If we solve an unbalanced problem, the posterior marginals can differ substantially from the prior marginals.

sp = sp.solve(epsilon=1e-2, tau_a=0.9, tau_b=0.99)
sp[0, 1].solution.a[:5], sp[0, 1].solution.b[:5]
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
(Array([0.02987743, 0.02270868, 0.03522239, 0.00888421, 0.03033637],      dtype=float32),
 Array([0.02467274, 0.02773538, 0.02489461, 0.02233962, 0.02674251],      dtype=float32))
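The effect of tau_a and tau_b can be reproduced with a small unbalanced variant of the log-domain Sinkhorn update, in which each balanced potential update is multiplied by tau, so that tau = 1 recovers the balanced case. This is one common form of unbalanced Sinkhorn; we assume, without having checked the backend, that moscot's solver follows the same convention. The data and the fixed iteration count are hypothetical.

```python
import numpy as np


def logsumexp(x, axis):
    # Numerically stable log(sum(exp(x))) along `axis`.
    m = x.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(x - m).sum(axis=axis))


def unbalanced_sinkhorn(C, a, b, epsilon, tau_a, tau_b, n_iter=5_000):
    """Log-domain Sinkhorn with KL-relaxed marginals; tau_a = tau_b = 1 is balanced."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        # Each balanced update is damped by tau, which relaxes the corresponding marginal.
        f = tau_a * epsilon * (np.log(a) - logsumexp((g[None, :] - C) / epsilon, axis=1))
        g = tau_b * epsilon * (np.log(b) - logsumexp((f[:, None] - C) / epsilon, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / epsilon)


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 5))  # hypothetical source cells
y = rng.normal(size=(20, 5))  # hypothetical target cells
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
C /= C.mean()
a = np.full(20, 1 / 20)
b = np.full(20, 1 / 20)

P_bal = unbalanced_sinkhorn(C, a, b, epsilon=0.1, tau_a=1.0, tau_b=1.0)
P_unb = unbalanced_sinkhorn(C, a, b, epsilon=0.1, tau_a=0.9, tau_b=0.99)
print("balanced:  ", P_bal.sum(axis=1)[:5], P_bal.sum(axis=0)[:5])
print("unbalanced:", P_unb.sum(axis=1)[:5], P_unb.sum(axis=0)[:5])
```

With tau_a = 0.9, the source side is relaxed more strongly than the target side (tau_b = 0.99), so the row sums are expected to drift further from the prior than the column sums.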

Low-rank solutions

When the dataset is very large, the computational complexity can be reduced by setting rank to a positive integer [Scetbon et al., 2021]. In this case, epsilon can also be set to \(0\), but only the balanced case (\(\text{tau}_a = \text{tau}_b = 1\)) is supported. The rank should be much smaller than the number of cells in both the source and the target distribution.

sp = sp.solve(epsilon=0, rank=3)
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
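Low-rank solvers in the spirit of Scetbon et al. (2021) represent the coupling as P = Q diag(1/g) R^T, with an n × r factor Q, an m × r factor R, and r latent marginals g, so only (n + m + 1)·r numbers are stored instead of n·m. The NumPy sketch below builds a feasible coupling of that form from random factors and checks its rank and marginals. It does not optimize the factors, which is what an actual low-rank solver does, and all names in it are hypothetical.

```python
import numpy as np


def scale_to_marginals(M, r, c, n_iter=500):
    """Rescale a positive matrix so that its row sums are r and its column sums are c."""
    for _ in range(n_iter):
        M = M * (r / M.sum(axis=1))[:, None]
        M = M * (c / M.sum(axis=0))[None, :]
    return M


rng = np.random.default_rng(0)
n, m, rank = 20, 20, 3
a = np.full(n, 1 / n)
b = np.full(m, 1 / m)
g = np.full(rank, 1 / rank)  # marginals of the r latent components

# Factors with Q 1 = a, R 1 = b and Q^T 1 = R^T 1 = g (hypothetical random start).
Q = scale_to_marginals(rng.uniform(size=(n, rank)), a, g)
R = scale_to_marginals(rng.uniform(size=(m, rank)), b, g)

# The dense coupling is formed here only to check its properties.
P = Q @ np.diag(1 / g) @ R.T
print(np.linalg.matrix_rank(P), P.sum(axis=1)[:3], P.sum(axis=0)[:3])
print("stored values:", Q.size + R.size + g.size, "vs dense:", n * m)
```

Because Q^T 1 = R^T 1 = g, the row sums of P reduce to Q 1 = a and the column sums to R 1 = b, so the factorized coupling respects both prior marginals.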

Scaling the cost

scale_cost rescales the cost matrix, which often helps the algorithm converge. Any number can be passed, but the cost matrix can also be scaled by, e.g., its mean, median, or maximum. We recommend the mean, since it can be computed without instantiating the cost matrix, which reduces the computational complexity. Moreover, it is more robust to outliers than, e.g., scaling by the maximum. Note that the optimal transport solution changes with the choice of scaling, because for a fixed epsilon, rescaling the cost effectively changes the strength of the entropic regularization:

sp = sp.solve(epsilon=1e-2, scale_cost="mean")
tm_mean = sp[0, 1].solution.transport_matrix
tm_mean[:3, :3]
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
Array([[2.5468854e-16, 3.0775851e-08, 3.3745863e-18],
       [5.2715894e-20, 1.0699465e-21, 4.7493223e-02],
       [1.2640489e-05, 9.2377137e-09, 2.2984659e-03]], dtype=float32)
sp = sp.solve(epsilon=1e-2, scale_cost="max_cost")
tm_max = sp[0, 1].solution.transport_matrix
tm_max[:3, :3]
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
Array([[2.1109562e-11, 4.0229529e-06, 1.1995808e-12],
       [1.0332924e-13, 8.7991570e-15, 4.4901680e-02],
       [1.5114920e-04, 1.4853405e-06, 4.1852128e-03]], dtype=float32)

To quantify the influence of the scaling, we compute the correlation between the flattened transport matrices.

correlation = np.corrcoef(tm_mean.flatten(), tm_max.flatten())[0, 1]
correlation
0.9929824680375936
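The recommendation above to scale by the mean rests on the fact that, at least for the squared Euclidean cost, the mean can be computed without forming the n × m cost matrix, because it decomposes into per-cell statistics. The NumPy sketch below checks this identity on hypothetical random embeddings; that moscot evaluates its mean scaling this way, and that the cost is squared Euclidean, are assumptions here. The median and maximum are taken from the full matrix for comparison.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 30))  # hypothetical source embedding, e.g. 30 PCs
y = rng.normal(size=(400, 30))  # hypothetical target embedding

# mean_ij ||x_i - y_j||^2 = mean_i ||x_i||^2 + mean_j ||y_j||^2 - 2 <mean_i x_i, mean_j y_j>
# This needs only O((n + m) d) work and memory, never the n x m matrix.
mean_cost = (
    (x**2).sum(axis=1).mean()
    + (y**2).sum(axis=1).mean()
    - 2 * x.mean(axis=0) @ y.mean(axis=0)
)

# The median and maximum, in contrast, are computed here from the full cost matrix.
C = (x**2).sum(axis=1)[:, None] + (y**2).sum(axis=1)[None, :] - 2 * x @ y.T
print(mean_cost, C.mean(), np.median(C), C.max())

C_by_mean = C / mean_cost  # analogous to scale_cost="mean"
C_by_max = C / C.max()     # analogous to scale_cost="max_cost"
```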