This example shows how to add or remove single problems.

Adding a single problem can be useful for fine-tuning, and it is sometimes required by downstream functions, e.g., compute_interpolated_distance().


```python
from moscot import datasets
from moscot.problems.time import TemporalProblem
```


We simulate data with four time points using simulate_data(); the time point of each cell is stored in adata.obs["day"].

```python
adata = datasets.simulate_data(n_distributions=4, key="day")
adata
```

```
AnnData object with n_obs × n_vars = 80 × 60
    obs: 'day', 'celltype'
```
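The simulated dataset contains 80 cells spread over four days. With time_key="day", the subproblems of a TemporalProblem are keyed by pairs of consecutive days, as the keys printed after solving show. The following is a minimal NumPy sketch of that pairing, not moscot's implementation; the `days` array is a hypothetical stand-in for adata.obs["day"], with 20 cells per day to match the 80 cells and (20, 20) shapes reported in this example.

```python
import numpy as np

# Hypothetical stand-in for adata.obs["day"]: four days with 20 cells each,
# consistent with the 80 cells and (20, 20) subproblem shapes in this example.
days = np.repeat([0, 1, 2, 3], 20)

# Sort the distinct time points and pair each one with the next.
time_points = np.unique(days).tolist()
keys = list(zip(time_points[:-1], time_points[1:]))

# Each subproblem couples the cells of its source day with those of its target day.
shapes = {(s, t): (int((days == s).sum()), int((days == t).sum())) for s, t in keys}

print(keys)    # [(0, 1), (1, 2), (2, 3)]
print(shapes)  # {(0, 1): (20, 20), (1, 2): (20, 20), (2, 3): (20, 20)}
```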


## Prepare and solve the problem

Let’s prepare and solve the problem.

```python
# Prepare the subproblems between time points and solve all of them
tp = TemporalProblem(adata).prepare(time_key="day").solve(epsilon=1e-2)

# Inspect the solution of each subproblem
for key, subprob in tp.problems.items():
    print(f"key: {key}, solution: {subprob.solution}")
```

```
INFO     Computing pca with n_comps=30 for xy using adata.X
INFO     Computing pca with n_comps=30 for xy using adata.X
INFO     Computing pca with n_comps=30 for xy using adata.X
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
key: (0, 1), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (1, 2), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (2, 3), solution: OTTOutput[shape=(20, 20), cost=0.7871, converged=True]
```
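Each printed solution reports whether the solver converged. When there are many subproblems, a small helper can list the ones worth re-solving. This is a sketch that assumes each subproblem exposes `.solution.converged`, as the OTTOutput repr above suggests (check the moscot API reference); `non_converged_keys` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins so the snippet runs without moscot.

```python
from types import SimpleNamespace


def non_converged_keys(problems):
    """Return the keys whose solution reports converged=False.

    Assumes each value exposes `.solution.converged` (an assumption based on
    the OTTOutput repr, not a documented guarantee).
    """
    return [key for key, subprob in problems.items() if not subprob.solution.converged]


# Hypothetical stand-ins for tp.problems, so the sketch runs without moscot.
fake_problems = {
    (0, 1): SimpleNamespace(solution=SimpleNamespace(converged=True)),
    (1, 2): SimpleNamespace(solution=SimpleNamespace(converged=False)),
}
print(non_converged_keys(fake_problems))  # [(1, 2)]
```

With a solved TemporalProblem, the call would be `non_converged_keys(tp.problems)`; here it would return an empty list, since all three solutions converged.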


## Re-solving a subproblem

We might want to solve one of the subproblems again, for example because the solver did not converge or because we want to try different parameters. Here, we experiment with unbalancedness in the transport between days 2 and 3: we extract the corresponding subproblem and solve it again with tau_a and tau_b below 1.

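```python
# Extract the subproblem that maps day 2 to day 3
extracted_problem = tp.problems[2, 3]
# tau_a = tau_b = 1 is balanced; values below 1 relax the source/target marginals
extracted_problem = extracted_problem.solve(epsilon=1e-2, tau_a=0.95, tau_b=0.95)
extracted_problem.solution
```

```
OTTOutput[shape=(20, 20), cost=0.39, converged=True]
```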
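In moscot, tau_a = tau_b = 1 corresponds to a balanced problem, and values below 1 soften the marginal constraints. To build intuition, the sketch below is a plain NumPy scaling-form Sinkhorn in which each marginal update is raised to the power tau, the standard update for KL-penalized marginals with tau = rho / (rho + epsilon). It is not moscot's or OTT's solver, and the random point clouds, uniform weights, and epsilon = 0.1 are illustrative assumptions.

```python
import numpy as np


def sinkhorn(a, b, C, epsilon, tau_a=1.0, tau_b=1.0, n_iter=5000):
    """Scaling-form Sinkhorn; tau = 1 enforces that marginal exactly."""
    K = np.exp(-C / epsilon)  # Gibbs kernel
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        # With tau < 1 the marginal becomes a soft KL penalty, which damps
        # each scaling update by the exponent tau.
        u = (a / (K @ v)) ** tau_a
        v = (b / (K.T @ u)) ** tau_b
    return u[:, None] * K * v[None, :]  # coupling diag(u) K diag(v)


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 2))           # toy source cells
y = rng.normal(loc=1.0, size=(20, 2))  # toy target cells
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
C /= C.max()                           # scale costs to [0, 1]
a = np.full(20, 1 / 20)                # uniform source weights
b = np.full(20, 1 / 20)                # uniform target weights

P_bal = sinkhorn(a, b, C, epsilon=0.1)
P_unb = sinkhorn(a, b, C, epsilon=0.1, tau_a=0.95, tau_b=0.95)

# Largest deviation of the row sums from the source marginal
print("balanced:  ", np.abs(P_bal.sum(axis=1) - a).max())
print("unbalanced:", np.abs(P_unb.sum(axis=1) - a).max())
```

The balanced coupling matches `a` up to numerical tolerance, while the unbalanced one is allowed to deviate, trading marginal fidelity for a cheaper transport.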


After re-solving the subproblem, we add it back to the TemporalProblem, passing overwrite=True to replace the existing subproblem stored under the key (2, 3).

```python
# Replace the (2, 3) subproblem with the re-solved one
tp = tp.add_problem((2, 3), extracted_problem, overwrite=True)

for key, subprob in tp.problems.items():
    print(f"key: {key}, solution: {subprob.solution}")
```

```
key: (0, 1), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (1, 2), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (2, 3), solution: OTTOutput[shape=(20, 20), cost=0.39, converged=True]
```
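Only the entry for (2, 3) changed. The bookkeeping behind add_problem can be pictured as a keyed mapping of subproblems. The class below is a hypothetical toy, not moscot's code; in particular, it assumes that adding under an existing key without overwrite=True is refused with a KeyError, which is consistent with the call above passing overwrite=True. Check the moscot API reference for the exact behavior and exception type.

```python
class ToyCompoundProblem:
    """Hypothetical stand-in for a compound problem's keyed subproblems."""

    def __init__(self, problems):
        self.problems = dict(problems)

    def add_problem(self, key, problem, overwrite=False):
        # Assumption: replacing an existing key is refused unless overwrite=True.
        if key in self.problems and not overwrite:
            raise KeyError(f"Problem {key} already exists; pass overwrite=True to replace it.")
        self.problems[key] = problem
        return self  # returning self supports the `tp = tp.add_problem(...)` pattern


toy = ToyCompoundProblem({(0, 1): "cost=0.7858", (2, 3): "cost=0.7871"})
try:
    toy.add_problem((2, 3), "cost=0.39")  # refused: key exists, no overwrite
except KeyError as err:
    print("refused:", err)

toy = toy.add_problem((2, 3), "cost=0.39", overwrite=True)
print(toy.problems[(2, 3)])  # cost=0.39
```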