Adding and removing problems
This example shows how to add or remove individual subproblems.
Adding a single problem can be useful for fine-tuning, and some downstream functions require it, e.g., compute_interpolated_distance().
Imports and data loading
from moscot import datasets
from moscot.problems.time import TemporalProblem
Simulate data using simulate_data().
adata = datasets.simulate_data(n_distributions=4, key="day")
adata
AnnData object with n_obs × n_vars = 80 × 60
obs: 'day', 'celltype'
Prepare and solve the problem
We prepare the TemporalProblem with day as the time key, solve it, and print the solution of each subproblem.
tp = TemporalProblem(adata).prepare(time_key="day").solve(epsilon=1e-2)
for key, subprob in tp.problems.items():
    print(f"key: {key}, solution: {subprob.solution}")
INFO Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
INFO Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
INFO Solving problem BirthDeathProblem[stage='prepared', shape=(20, 20)].
key: (0, 1), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (1, 2), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (2, 3), solution: OTTOutput[shape=(20, 20), cost=0.7871, converged=True]
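The printed solutions include a convergence flag, which can guide which subproblems to revisit. As a minimal sketch, the snippet below collects the keys whose solution did not converge. It uses hypothetical stand-in objects (SimpleNamespace) in place of real moscot subproblems so that it runs on its own, and it assumes only that each solution exposes a boolean converged attribute, as the output above suggests. The helper name unconverged_keys is made up for illustration.

```python
from types import SimpleNamespace


def unconverged_keys(problems):
    """Return the keys of subproblems whose solution did not converge."""
    return [key for key, subprob in problems.items() if not subprob.solution.converged]


# Hypothetical stand-ins for tp.problems; the flags are made up for illustration.
toy_problems = {
    (0, 1): SimpleNamespace(solution=SimpleNamespace(converged=True)),
    (1, 2): SimpleNamespace(solution=SimpleNamespace(converged=False)),
    (2, 3): SimpleNamespace(solution=SimpleNamespace(converged=True)),
}

to_resolve = unconverged_keys(toy_problems)
print(to_resolve)
```

With a solved TemporalProblem, the same helper could presumably be called as unconverged_keys(tp.problems).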
Re-solving a subproblem
We might want to solve one of the subproblems again, for example because the solver did not converge or because we want to try different parameters. Here, we experiment with unbalancedness in the solution between days 2 and 3: we extract that subproblem and solve it again.
extracted_problem = tp.problems[2, 3]
extracted_problem = extracted_problem.solve(epsilon=1e-2, tau_a=0.95, tau_b=0.95)
extracted_problem.solution
OTTOutput[shape=(20, 20), cost=0.39, converged=True]
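To build some intuition for what tau_a and tau_b do, here is a minimal pure-Python sketch of entropic optimal transport with scaling iterations. It assumes the common parametrization in which tau lies in (0, 1], enters as an exponent in the scaling update, and recovers the balanced problem at tau = 1. The function sinkhorn_scaling and the toy marginals are hypothetical and only illustrate the idea; this is not the solver moscot calls.

```python
import math


def sinkhorn_scaling(a, b, cost, epsilon, tau_a=1.0, tau_b=1.0, n_iter=200):
    """Entropic OT via scaling iterations; tau < 1 relaxes the matching marginal."""
    n, m = len(a), len(b)
    kernel = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        # Update v first, then u, so the source marginal is exact when tau_a == 1.
        kt_u = [sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
        v = [(b[j] / kt_u[j]) ** tau_b for j in range(m)]
        k_v = [sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        u = [(a[i] / k_v[i]) ** tau_a for i in range(n)]
    # Transport plan: diag(u) @ kernel @ diag(v).
    return [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]


# Toy marginals with different total mass (source: 1.0, target: 2.0).
a, b = [0.5, 0.5], [1.0, 1.0]
cost = [[0.0, 1.0], [1.0, 0.0]]

balanced = sinkhorn_scaling(a, b, cost, epsilon=0.5)
unbalanced = sinkhorn_scaling(a, b, cost, epsilon=0.5, tau_a=0.95, tau_b=0.95)

balanced_rows = [sum(row) for row in balanced]
unbalanced_rows = [sum(row) for row in unbalanced]
print(balanced_rows, unbalanced_rows)
```

In the balanced run the row sums reproduce the source marginal and the mismatch in total mass is left unresolved. With tau_a = tau_b = 0.95, both marginals are only softly enforced, so the row sums settle strictly between the source mass of 0.5 and the target mass of 1.0 per entry.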
After re-solving the subproblem, we add it back to the TemporalProblem. Because the key (2, 3) already exists, we pass overwrite=True.
tp = tp.add_problem((2, 3), extracted_problem, overwrite=True)
for key, subprob in tp.problems.items():
    print(f"key: {key}, solution: {subprob.solution}")
key: (0, 1), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (1, 2), solution: OTTOutput[shape=(20, 20), cost=0.7858, converged=True]
key: (2, 3), solution: OTTOutput[shape=(20, 20), cost=0.39, converged=True]
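Removing a subproblem is the counterpart of adding one. The sketch below pictures the bookkeeping as a keyed mapping with an overwrite guard. ToyCompoundProblem is a hypothetical stand-in, not moscot's implementation. moscot's compound problems are expected to offer a matching remove_problem method, but treat its exact signature as an assumption and check the API reference before relying on it.

```python
class ToyCompoundProblem:
    """Hypothetical stand-in for a compound problem; not moscot's implementation."""

    def __init__(self):
        self.problems = {}

    def add_problem(self, key, problem, overwrite=False):
        # Refuse to silently replace an existing subproblem.
        if key in self.problems and not overwrite:
            raise KeyError(f"Problem {key} already exists; pass overwrite=True to replace it.")
        self.problems[key] = problem
        return self  # returning self allows chaining

    def remove_problem(self, key):
        # Raises KeyError if the key is absent.
        del self.problems[key]
        return self


toy = ToyCompoundProblem()
toy.add_problem((0, 1), "first").add_problem((1, 2), "second")

try:
    toy.add_problem((1, 2), "replacement")  # no overwrite: should be rejected
    guarded = False
except KeyError:
    guarded = True

toy.add_problem((1, 2), "replacement", overwrite=True)
toy.remove_problem((0, 1))
print(guarded, toy.problems)
```

The guard mirrors why overwrite=True was needed when adding the re-solved (2, 3) subproblem back: without it, an existing key would not be replaced.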