Linear problems (advanced)

This example shows advanced usage of linear problems, such as the TemporalProblem and the SinkhornProblem.

Imports and data loading

from moscot import datasets
from moscot.problems.generic import SinkhornProblem

Simulate data using simulate_data().

adata = datasets.simulate_data(n_distributions=2, key="day")
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'day', 'celltype'


The threshold parameter defines the convergence criterion. In the balanced setting the threshold denotes the deviation between prior and posterior marginals, while in the unbalanced setting the threshold corresponds to a Cauchy sequence stopping criterion.
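
As a rough illustration of the balanced case, the sketch below runs a plain Sinkhorn loop in pure Python and stops once the row marginal of the current transport plan deviates from the prescribed marginal by less than threshold. It is not the ott implementation: the function name sinkhorn, the L1 norm used for the deviation, and the toy cost matrix are assumptions made for this example only.

```python
import math


def sinkhorn(cost, a, b, epsilon=0.5, threshold=1e-3, max_iterations=1000):
    """Toy balanced Sinkhorn; returns (iterations, marginal_error, converged)."""
    n, m = len(a), len(b)
    # Gibbs kernel K_ij = exp(-C_ij / epsilon)
    K = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    err = float("inf")
    for it in range(1, max_iterations + 1):
        # scale rows to match a, then columns to match b
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        # columns now match b exactly, so measure the deviation of the rows from a
        row = [u[i] * sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        err = sum(abs(row[i] - a[i]) for i in range(n))
        if err < threshold:
            return it, err, True
    return max_iterations, err, False


cost = [[0.0, 1.0], [1.0, 0.0]]
a, b = [0.7, 0.3], [0.4, 0.6]
it_loose, err_loose, ok_loose = sinkhorn(cost, a, b, threshold=1e-2)
it_tight, err_tight, ok_tight = sinkhorn(cost, a, b, threshold=1e-6)
```

A tighter threshold can only increase the number of iterations, because the loop returns at the first iteration whose error falls below it.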


Different initializers can help improve convergence. For the full-rank case, the initializer can be set to the trivial initialization method, denoted by default. The gaussian initializer [Thornton and Cuturi, 2022] computes Gaussian approximations of the two point clouds and leverages the closed-form solution of optimal transport problems between Gaussians, while the sorting initializer [Thornton and Cuturi, 2022] solves a simplified (sorting) optimal transport problem and uses its solution as the initializer. For low-rank problems, different initializers are available: random, rank2, k-means or generalized-k-means [Scetbon and Cuturi, 2022].
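
To give some intuition for why a sorting problem is "simplified", the sketch below shows the well-known one-dimensional special case: for equally many points with uniform weights and a convex cost such as the squared distance, the optimal matching pairs the i-th smallest source point with the i-th smallest target point. This is only an illustration of that fact, not the sorting initializer from ott; the helper name sort_matching is hypothetical.

```python
def sort_matching(x, y):
    """Match equally many 1-D points by rank: i-th smallest x to i-th smallest y."""
    if len(x) != len(y):
        raise ValueError("this sketch assumes equally sized point sets")
    order_x = sorted(range(len(x)), key=lambda i: x[i])
    order_y = sorted(range(len(y)), key=lambda j: y[j])
    # map each source index to the target index of the same rank
    return dict(zip(order_x, order_y))


matching = sort_matching([3.0, 1.0, 2.0], [10.0, 30.0, 20.0])
```

The result maps source indices to target indices, so no iterative solver is needed for this special case.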

Some initializers accept additional arguments, which can be provided as a dictionary. For example, min_iterations and max_iterations can be passed to the k_means() algorithm used by the k-means initializer.

For more information, see ott.initializers.linear.

sp = SinkhornProblem(adata)
sp = sp.prepare(key="day")

ik = {"min_iterations": 5, "max_iterations": 200}
sp = sp.solve(epsilon=0, rank=3, initializer="k-means", initializer_kwargs=ik)
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Solving problem OTProblem[stage='prepared', shape=(20, 20)].                                              

Number of iterations

Three iteration parameters can be set. min_iterations is the minimum number of iterations the algorithm runs. max_iterations is the maximum number of iterations; if the convergence criterion is not met after max_iterations, the solver has not converged. inner_iterations is the number of iterations between consecutive checks of the convergence criterion.

If max_iterations is too low, the model won’t converge:

sp = sp.solve(epsilon=1e-3, inner_iterations=1, min_iterations=0, max_iterations=2)
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
WARNING  Solver did not converge                                                                                   
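
How the three iteration parameters interact can be sketched with a toy loop. The function iterate below is hypothetical and does not call moscot or ott: its "error" simply halves at every step, and the exact bookkeeping in the real solver may differ. It only shows that convergence is checked every inner_iterations steps, never before min_iterations, and that the loop gives up after max_iterations.

```python
def iterate(threshold=1e-3, min_iterations=0, max_iterations=100, inner_iterations=10):
    """Toy solver loop; returns (iterations_run, converged)."""
    error = 1.0
    for it in range(1, max_iterations + 1):
        error *= 0.5  # stand-in for one solver update
        is_check_step = it % inner_iterations == 0
        if is_check_step and it >= min_iterations and error < threshold:
            return it, True
    return max_iterations, False


default_run = iterate()                                     # error < 1e-3 from step 10 on
too_short = iterate(inner_iterations=1, max_iterations=2)   # budget exhausted first
sparse_checks = iterate(inner_iterations=7)                 # checked at steps 7, 14, ...
late_start = iterate(min_iterations=25)                     # first eligible check at step 30
```

Checking less often saves work per iteration but can overshoot the first step at which the criterion is met, as the sparse_checks run shows.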

Low-rank hyperparameters

The low-rank algorithm requires additional hyperparameters: gamma, the step size of the mirror descent algorithm, and gamma_rescale, a flag indicating whether to rescale gamma at every iteration. When tuning gamma, we recommend varying it by powers of \(10\). If gamma is too small or too large, the algorithm might not converge.

sp = sp.solve(epsilon=0, rank=3, initializer="random", max_iterations=30, gamma=1000)
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
WARNING  Solver did not converge                                                                                   
sp = sp.solve(epsilon=0, rank=3, initializer="random", max_iterations=30, gamma=10)
INFO     Solving problem OTProblem[stage='solved', shape=(20, 20)].                                                
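
The sensitivity to the step size is a general property of mirror descent and can be reproduced on a much smaller problem. The sketch below is not the low-rank Sinkhorn algorithm: it minimises the toy objective f(p) = 0.5 * ||p - target||^2 over the probability simplex with exponentiated-gradient steps. The function name mirror_descent, the objective, and the values of gamma are assumptions chosen for this illustration, so the useful range of gamma differs from that of the real solver.

```python
import math


def mirror_descent(target, gamma, max_iterations=30, threshold=1e-3):
    """Exponentiated-gradient steps on the simplex; returns (p, converged)."""
    n = len(target)
    log_p = [-math.log(n)] * n  # start from the uniform distribution
    p = [1.0 / n] * n
    for _ in range(max_iterations):
        grad = [p[i] - target[i] for i in range(n)]
        # mirror step in log space: p_i <- p_i * exp(-gamma * grad_i), then renormalise
        log_p = [log_p[i] - gamma * grad[i] for i in range(n)]
        m = max(log_p)
        lse = m + math.log(sum(math.exp(l - m) for l in log_p))
        log_p = [l - lse for l in log_p]
        p = [math.exp(l) for l in log_p]
        # converged once the L1 distance to the known minimiser is small
        if sum(abs(p[i] - target[i]) for i in range(n)) < threshold:
            return p, True
    return p, False


target = [0.2, 0.8]
results = {gamma: mirror_descent(target, gamma)[1] for gamma in (1e-3, 1.0, 1e3)}
```

In this toy setting only gamma = 1.0 reaches the threshold within 30 iterations: the smallest step barely moves away from the starting point, while the largest one keeps overshooting between the corners of the simplex.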

Keyword arguments and implementation details

Whenever the solve() method of a linear problem is called, a backend-specific linear solver is instantiated. Currently, ott is the only supported backend. Its linear solvers are Sinkhorn, which is used whenever rank = -1, and LRSinkhorn, its low-rank counterpart, which is used whenever rank is a positive integer. moscot wraps these classes in SinkhornSolver, which handles both the full-rank and the low-rank solver.
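
The dispatch rule described above can be summarised in a few lines. The helper select_linear_solver is hypothetical and only returns the name of the ott solver class as a string; it is not how SinkhornSolver is implemented internally.

```python
def select_linear_solver(rank: int) -> str:
    """Return the name of the ott solver used for a given rank (illustrative only)."""
    if rank == -1:
        return "Sinkhorn"      # full-rank solver
    if rank > 0:
        return "LRSinkhorn"    # low-rank counterpart
    raise ValueError(f"rank must be -1 or a positive integer, got {rank}")


full_rank = select_linear_solver(-1)
low_rank = select_linear_solver(3)
```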