Adding and removing problems

This example shows how to add or remove single problems.

Adding a single problem can be useful for fine-tuning, and some downstream functions require it, e.g., compute_interpolated_distance().

Imports and data loading

from moscot import datasets
from moscot.problems.time import TemporalProblem

Simulate data with four distributions using simulate_data(). The time point of each cell is stored in adata.obs["day"].

adata = datasets.simulate_data(n_distributions=4, key="day")
adata
AnnData object with n_obs × n_vars = 80 × 60
    obs: 'day', 'celltype'

Prepare and solve the problem

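We prepare the problem with time_key="day" and solve all of its subproblems with epsilon=1e-2. Each subproblem couples the cells of one day with those of the next day.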

tp = TemporalProblem(adata).prepare(time_key="day").solve(epsilon=1e-2)

for key, subprob in tp.problems.items():
    print(f"key: {key}, solution: {subprob.solution}")
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
key: (0, 1), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (1, 2), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (2, 3), solution: OTTOutput[shape=(20, 20), cost=0.7871, converged=True]
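The keys in the output pair each day with the next one. The plain-Python sketch below reproduces that pairing from a list of day labels; it does not call moscot. It assumes 20 cells per day, which is what the 80 observations and the shape=(20, 20) entries in the log suggest. In moscot, this pairing is, to our understanding, controlled by the policy argument of prepare(); the sketch only mimics a sequential ordering.

```python
from collections import Counter

# Hypothetical stand-in for adata.obs["day"]: 4 days with 20 cells each,
# matching the 80 cells and the (20, 20) shapes in the log above.
day_labels = [day for day in range(4) for _ in range(20)]

# Count cells per time point and sort the distinct time points.
cells_per_day = Counter(day_labels)
time_points = sorted(cells_per_day)

# Pair each time point with the next one (a sequential ordering).
keys = list(zip(time_points, time_points[1:]))

# Each subproblem couples the cells of its source day with those of its target day.
shapes = {(src, tgt): (cells_per_day[src], cells_per_day[tgt]) for src, tgt in keys}

for key, shape in shapes.items():
    print(f"key: {key}, shape: {shape}")
```

The printed keys match the (0, 1), (1, 2), (2, 3) keys of tp.problems above.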

Re-solving a subproblem

We might want to solve one of the subproblems again, for example because the solver did not converge or because we want to try different parameters. Here, we experiment with unbalancedness for the subproblem between days 2 and 3: we extract it and solve it again with tau_a=0.95 and tau_b=0.95 (values below 1 relax the marginal constraints).

extracted_problem = tp.problems[2, 3]
extracted_problem = extracted_problem.solve(epsilon=1e-2, tau_a=0.95, tau_b=0.95)
extracted_problem.solution
OTTOutput[shape=(20, 20), cost=0.39, converged=True]
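To build some intuition for tau_a and tau_b, here is a standalone NumPy sketch of entropic OT solved with scaling (Sinkhorn) iterations, where each scaling update is raised to the power tau. With tau_a = tau_b = 1 the plan matches both marginals; with values below 1 the marginal constraints become soft penalties, so the plan's row and column sums may deviate from a and b. As far as we know this mirrors how OTT-JAX parameterizes unbalancedness, but it is a simplified illustration, not the solver moscot actually calls. The point clouds, the sinkhorn() helper, and epsilon=0.1 are arbitrary choices for this example.

```python
import numpy as np


def sinkhorn(a, b, cost, epsilon=0.1, tau_a=1.0, tau_b=1.0, n_iter=2000):
    """Scaling-form Sinkhorn; tau_a = tau_b = 1 gives balanced entropic OT."""
    kernel = np.exp(-cost / epsilon)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        # With tau < 1, each update only partially enforces its marginal,
        # which corresponds to a soft (KL-penalized) marginal constraint.
        u = (a / (kernel @ v)) ** tau_a
        v = (b / (kernel.T @ u)) ** tau_b
    return u[:, None] * kernel * v[None, :]  # transport plan


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))        # toy "source day" cells
y = rng.normal(size=(5, 2)) + 0.5  # toy "target day" cells, shifted
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean
cost /= cost.max()                 # scale costs to [0, 1]
a = np.full(5, 1 / 5)
b = np.full(5, 1 / 5)

plan_balanced = sinkhorn(a, b, cost)
plan_unbalanced = sinkhorn(a, b, cost, tau_a=0.95, tau_b=0.95)
print("balanced row sums:  ", plan_balanced.sum(axis=1).round(4))
print("unbalanced row sums:", plan_unbalanced.sum(axis=1).round(4))
print("total mass:", plan_balanced.sum().round(4), plan_unbalanced.sum().round(4))
```

Comparing the printed row sums shows how the unbalanced plan is allowed to deviate from the uniform marginal a.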

After re-solving the subproblem, we add it back to the TemporalProblem. Because a subproblem with key (2, 3) already exists, we pass overwrite=True to replace it.

tp = tp.add_problem((2, 3), extracted_problem, overwrite=True)
for key, subprob in tp.problems.items():
    print(f"key: {key}, solution: {subprob.solution}")
key: (0, 1), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (1, 2), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (2, 3), solution: OTTOutput[shape=(20, 20), cost=0.39, converged=True]
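Only the (2, 3) entry changed, which suggests that add_problem() behaves like replacing an entry in a dictionary of subproblems. The sketch below uses a hypothetical ToyCompoundProblem class, not moscot's implementation, to illustrate the assumed semantics: adding an existing key fails unless overwrite=True, and removing a problem deletes its key. The remove_problem() method of the toy class is likewise hypothetical.

```python
class ToyCompoundProblem:
    """Hypothetical stand-in: a dict of subproblems keyed by (source, target)."""

    def __init__(self):
        self.problems = {}

    def add_problem(self, key, problem, overwrite=False):
        # Refuse to silently replace an existing subproblem unless asked to.
        if key in self.problems and not overwrite:
            raise KeyError(f"Problem {key} already exists; pass overwrite=True to replace it.")
        self.problems[key] = problem
        return self  # returning self allows tp = tp.add_problem(...)

    def remove_problem(self, key):
        # Drop a subproblem; raises KeyError if the key is unknown.
        del self.problems[key]
        return self


tp = ToyCompoundProblem()
for key in [(0, 1), (1, 2), (2, 3)]:
    tp.add_problem(key, f"solution for {key}")

try:
    tp.add_problem((2, 3), "re-solved")  # no overwrite: rejected
except KeyError as err:
    print("refused:", err)

tp = tp.add_problem((2, 3), "re-solved", overwrite=True)
tp = tp.remove_problem((0, 1))
print(tp.problems)
```

Returning self from both methods keeps the chained, reassigning style used with tp above.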