# Linear problems

This example shows how to solve linear problems, e.g., the TemporalProblem and the SinkhornProblem.

```python
from moscot import datasets
from moscot.problems.generic import SinkhornProblem

import numpy as np
```


Simulate data using simulate_data().

```python
adata = datasets.simulate_data(n_distributions=2, key="day")
adata
```

```
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'day', 'celltype'
```


## Basic parameters

• epsilon is the entropy regularization parameter. The lower epsilon is, the sparser the transport map, but the longer the algorithm takes to converge.

• tau_a and tau_b denote the unbalancedness parameters in the source and the target distribution, respectively. $$\text{tau}_a = 1$$ means the source marginals have to be fully satisfied while $$0 < \text{tau}_a < 1$$ relaxes this condition. Analogously, tau_b affects the marginals of the target distribution. We demonstrate the effect of tau_a and tau_b with the SinkhornProblem.
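To build intuition for the epsilon trade-off, here is a minimal Sinkhorn sketch in plain numpy on random toy data (purely illustrative; this is not how moscot's solvers are implemented):

```python
import numpy as np

def sinkhorn(C, a, b, epsilon, n_iters=2000):
    # entropic OT: plan P = diag(u) @ K @ diag(v) with K = exp(-C / epsilon)
    K = np.exp(-C / epsilon)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)    # match the source marginals
        v = b / (K.T @ u)  # match the target marginals
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
x, y = rng.normal(size=(20, 2)), rng.normal(size=(20, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
a = b = np.full(20, 1 / 20)  # uniform marginals

P_small = sinkhorn(C, a, b, epsilon=0.05)
P_large = sinkhorn(C, a, b, epsilon=5.0)
# a smaller epsilon concentrates mass on fewer entries, i.e., a sparser map
print((P_small > 1e-4).sum(), "vs", (P_large > 1e-4).sum(), "significant entries")
```

Lowering epsilon further makes the plan sparser still, at the price of more iterations and, eventually, numerical underflow in exp(-C / epsilon), which is why practical solvers work in log space.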

Whenever the prior marginals a and b of the source and the target distribution, respectively, are not passed, they are set to be uniform.

```python
sp = SinkhornProblem(adata)
sp = sp.prepare(key="day")
print(sp[0, 1].a[:5], sp[0, 1].b[:5])
```

```
INFO     Computing pca with n_comps=30 for xy using adata.X
[0.05 0.05 0.05 0.05 0.05] [0.05 0.05 0.05 0.05 0.05]
```
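With no prior marginals passed and 20 cells per distribution, every weight is simply $$1/20 = 0.05$$, matching the output above:

```python
import numpy as np

n = 20  # cells per distribution in the simulated data above
a = np.full(n, 1 / n)  # uniform prior marginals
print(a[:5])  # → [0.05 0.05 0.05 0.05 0.05]
```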


First, we solve the problem in a balanced manner, such that the posterior marginals of the solution (the row sums and the column sums of the transport matrix, for the source and the target marginals, respectively) equal the prior marginals up to small errors, which define the convergence criterion in the balanced case.

```python
sp = sp.solve(epsilon=1e-2, tau_a=1, tau_b=1)
sp[0, 1].solution.a[:5], sp[0, 1].solution.b[:5]
```

```
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].

(Array([0.04999981, 0.05000006, 0.04999965, 0.04999992, 0.04999999], dtype=float32),
 Array([0.05004844, 0.04996916, 0.04996588, 0.04997035, 0.04996975], dtype=float32))
```
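The posterior marginals are simply the row and column sums of the transport matrix. As a quick illustration with the independent coupling, a valid balanced plan for uniform marginals (not the solver's output):

```python
import numpy as np

a = np.full(20, 0.05)  # prior source marginals
b = np.full(20, 0.05)  # prior target marginals
P = np.outer(a, b)  # independent coupling: a balanced transport plan

# posterior marginals: row sums (source) and column sums (target)
assert np.allclose(P.sum(axis=1), a)
assert np.allclose(P.sum(axis=0), b)
```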


If we solve an unbalanced problem, the posterior marginals will be different.

```python
sp = sp.solve(epsilon=1e-2, tau_a=0.9, tau_b=0.99)
sp[0, 1].solution.a[:5], sp[0, 1].solution.b[:5]
```

```
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].

(Array([0.02987743, 0.02270868, 0.03522239, 0.00888421, 0.03033637], dtype=float32),
 Array([0.02467274, 0.02773538, 0.02489461, 0.02233962, 0.02674251], dtype=float32))
```
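The relaxation behind these numbers can be sketched with KL-relaxed scaling iterations, where an exponent tau < 1 softens the marginal updates (again a toy numpy illustration, not moscot's solver):

```python
import numpy as np

def unbalanced_sinkhorn(C, a, b, epsilon, tau_a, tau_b, n_iters=2000):
    # scaling iterations with KL-relaxed marginals; tau_a = tau_b = 1
    # recovers the balanced Sinkhorn updates
    K = np.exp(-C / epsilon)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):
        u = (a / (K @ v)) ** tau_a
        v = (b / (K.T @ u)) ** tau_b
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
x, y = rng.normal(size=(20, 2)), rng.normal(size=(20, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
a = b = np.full(20, 1 / 20)

P_bal = unbalanced_sinkhorn(C, a, b, epsilon=0.5, tau_a=1.0, tau_b=1.0)
P_unb = unbalanced_sinkhorn(C, a, b, epsilon=0.5, tau_a=0.9, tau_b=0.99)
# balanced: posterior marginals match the priors; unbalanced: they deviate
print(np.abs(P_bal.sum(1) - a).max(), np.abs(P_unb.sum(1) - a).max())
```

The smaller tau is, the more the posterior marginals are allowed to deviate from the priors.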


## Low-rank solutions

Whenever the dataset is very large, the computational complexity can be reduced by setting rank to a positive integer. In that case, epsilon can also be set to $$0$$, but only the balanced case ($$\text{tau}_a = \text{tau}_b = 1$$) is supported. The rank should be significantly smaller than the number of cells in both the source and the target distribution.

```python
sp = sp.solve(epsilon=0, rank=3)
```

```
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].
```
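The saving comes from storing the coupling in factored form, e.g., as thin factors Q (n × r) and R (m × r) plus r inner weights, instead of a dense n × m matrix. A back-of-the-envelope count with hypothetical sizes:

```python
# hypothetical sizes: 100k cells per distribution, rank 10
n, m, r = 100_000, 100_000, 10

full = n * m                  # entries of the dense transport matrix
low_rank = n * r + m * r + r  # entries of Q (n x r), R (m x r), g (r,)
print(f"dense: {full:.1e} entries, low-rank: {low_rank:.1e} entries")
```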


## Scaling the cost

scale_cost scales the cost matrix, which often helps the algorithm converge. While any number can be passed, it is also possible to scale the cost matrix by, e.g., its mean, median, or maximum. We recommend scaling by the mean, as it can be computed without instantiating the cost matrix, which reduces the computational complexity. It is also more robust to outliers than, e.g., scaling by the maximum. Note that the optimal transport solution is not stable across different scalings:
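To see why the mean can be computed without instantiating the cost matrix, note that for a squared Euclidean cost (an assumption for this sketch) the mean over all pairs decomposes into per-point statistics:

```python
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.normal(size=(200, 30)), rng.normal(size=(150, 30))

# mean of the squared-Euclidean cost matrix without materializing it:
# mean_ij ||x_i - y_j||^2 = mean_i ||x_i||^2 + mean_j ||y_j||^2 - 2 <mean(x), mean(y)>
mean_cost = (x**2).sum(1).mean() + (y**2).sum(1).mean() - 2 * x.mean(0) @ y.mean(0)

# full 200 x 150 cost matrix, only for checking the identity
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
assert np.isclose(mean_cost, C.mean())
```

The streaming formula needs only O(n + m) memory, whereas the full matrix needs O(nm).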

```python
sp = sp.solve(epsilon=1e-2, scale_cost="mean")
tm_mean = sp[0, 1].solution.transport_matrix
tm_mean[:3, :3]
```

```
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].

Array([[2.5468854e-16, 3.0775851e-08, 3.3745863e-18],
       [5.2715894e-20, 1.0699465e-21, 4.7493223e-02],
       [1.2640489e-05, 9.2377137e-09, 2.2984659e-03]], dtype=float32)
```

```python
sp = sp.solve(epsilon=1e-2, scale_cost="max_cost")
tm_max = sp[0, 1].solution.transport_matrix
tm_max[:3, :3]
```

```
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].

Array([[2.1109562e-11, 4.0229529e-06, 1.1995808e-12],
       [1.0332924e-13, 8.7991570e-15, 4.4901680e-02],
       [1.5114920e-04, 1.4853405e-06, 4.1852128e-03]], dtype=float32)
```


We can compute the correlation between the two flattened transport matrices to get an idea of the influence of the different scalings.

```python
correlation = np.corrcoef(tm_mean.flatten(), tm_max.flatten())[0, 1]
correlation
```

```
0.9929824680375936
```