moscot.problems.spatiotemporal.SpatioTemporalProblem.compute_interpolated_distance
SpatioTemporalProblem.compute_interpolated_distance(source, intermediate, target, interpolation_parameter=None, n_interpolated_cells=None, account_for_unbalancedness=False, batch_size=256, posterior_marginals=True, seed=None, backend='ott', **kwargs)
Compute Wasserstein distance between OT-interpolated and intermediate cells.
See also
TODO(MUCDK): create an example showing the usage.
This is a validation method that interpolates cells between the source and target distributions, leveraging the OT coupling to approximate cells at the intermediate time point.
- Parameters:
  - source (TypeVar(K, bound=Hashable)) – Key identifying the source distribution.
  - intermediate (TypeVar(K, bound=Hashable)) – Key identifying the intermediate distribution.
  - target (TypeVar(K, bound=Hashable)) – Key identifying the target distribution.
  - interpolation_parameter (Optional[float]) – Interpolation parameter in \((0, 1)\) defining the weight of the source and target distributions. If None, it is computed by linear interpolation from the source, intermediate and target time points.
  - n_interpolated_cells (Optional[int]) – Number of cells used for interpolation. If None, use the number of cells in the intermediate distribution.
  - account_for_unbalancedness (bool) – Whether to account for unbalancedness by assuming exponential cell growth and death.
  - batch_size (int) – Number of rows/columns of the cost matrix to materialize during push() or pull(). Larger values require more memory.
  - posterior_marginals (bool) – Whether to use posterior_growth_rates or prior_growth_rates. TODO(MUCDK): needs more explanation.
  - seed (Optional[int]) – Random seed used when sampling the interpolated cells.
  - backend (Literal['ott']) – Backend used for the distance computation.
  - kwargs (Any) – Keyword arguments for the distance function, depending on the backend:
    - 'ott': sinkhorn_divergence().
  - self (TemporalMixinProtocol[K, B])
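The interpolation step that these parameters control (sample source/target pairs from the OT coupling, then place each sampled cell on the straight line between its pair) can be sketched in plain NumPy. This is an illustrative sketch under stated assumptions, not moscot's implementation: the helper interpolate_cells and all of its arguments are hypothetical, the coupling is a dense source × target array, time points are numeric, the default interpolation parameter is assumed to be (t_intermediate - t_source) / (t_target - t_source), and the optional growth correction assumes exponential growth with strictly positive per-cell growth rates defined as row mass of the coupling divided by prior mass.

```python
import numpy as np


def interpolate_cells(
    x_source,
    x_target,
    coupling,
    t_source,
    t_intermediate,
    t_target,
    interpolation_parameter=None,
    n_interpolated_cells=100,
    growth_rates=None,
    seed=None,
):
    """Hypothetical helper (not part of moscot): sample OT-interpolated cells.

    x_source: (n, d) array, x_target: (m, d) array, coupling: (n, m) array.
    growth_rates: optional (n,) array of strictly positive per-cell growth over
    the whole source -> target interval (row mass of coupling / prior mass).
    """
    # Assumed default: place the intermediate time point linearly between
    # the source and target time points.
    if interpolation_parameter is None:
        interpolation_parameter = (t_intermediate - t_source) / (t_target - t_source)
    t = float(interpolation_parameter)
    if not 0.0 < t < 1.0:
        raise ValueError("interpolation_parameter must lie in (0, 1).")

    weights = np.asarray(coupling, dtype=float)
    if growth_rates is not None:
        # Exponential growth assumption: row i holds prior mass * g_i at the
        # target time, but only prior mass * g_i**t at fraction t, so row i is
        # rescaled by g_i**(t - 1).
        g = np.asarray(growth_rates, dtype=float)
        weights = weights * np.power(g, t - 1.0)[:, None]
    probs = weights.ravel() / weights.sum()

    # Draw (source, target) index pairs with probability proportional to the weights.
    rng = np.random.default_rng(seed)
    flat_idx = rng.choice(probs.size, size=n_interpolated_cells, p=probs)
    rows, cols = np.unravel_index(flat_idx, weights.shape)

    # Straight-line (displacement) interpolation between each sampled pair.
    return (1.0 - t) * x_source[rows] + t * x_target[cols]


# Usage: identity coupling, target = source shifted by (4, 0).
rng = np.random.default_rng(0)
x_src = rng.normal(size=(50, 2))
x_tgt = x_src + np.array([4.0, 0.0])
x_mid = interpolate_cells(
    x_src, x_tgt, np.eye(50) / 50, t_source=0, t_intermediate=1, t_target=2,
    n_interpolated_cells=200, seed=0,
)
print(x_mid.shape)  # (200, 2)
```

With time points 0, 1 and 2 the assumed default gives t = 0.5, so every interpolated cell in the usage example is a source cell shifted by (2, 0). Fixing seed makes the sampled pairs reproducible.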
- Return type:
- Returns:
  The distance between OT-interpolated cells and cells at the intermediate time point. It is recommended to compare this to the distances computed by compute_time_point_distances() and compute_random_distance().
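With backend='ott', kwargs are forwarded to sinkhorn_divergence(). To illustrate what such a distance measures, here is a simplified NumPy stand-in for a debiased Sinkhorn divergence between two point clouds, such as the interpolated and the intermediate cells. It is a sketch under stated assumptions, not the ott implementation: it assumes uniform weights, squared Euclidean cost and a fixed number of log-domain Sinkhorn iterations, and it uses the transport term <P, C> rather than the full regularized objective, so its values will not match ott's. The function names are hypothetical.

```python
import numpy as np


def _logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def entropic_transport_cost(x, y, epsilon=0.5, n_iter=500):
    """Hypothetical helper: <P, C> for the entropic OT plan between uniform clouds."""
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # squared Euclidean
    log_a = np.full(len(x), -np.log(len(x)))
    log_b = np.full(len(y), -np.log(len(y)))
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(n_iter):
        # Log-domain Sinkhorn updates; P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps).
        f = -epsilon * _logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        g = -epsilon * _logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
    log_plan = (f[:, None] + g[None, :] - cost) / epsilon + log_a[:, None] + log_b[None, :]
    return float(np.sum(np.exp(log_plan) * cost))


def sinkhorn_divergence_sketch(x, y, epsilon=0.5, n_iter=500):
    """Hypothetical helper: debiased OT(x, y) - (OT(x, x) + OT(y, y)) / 2."""
    return (
        entropic_transport_cost(x, y, epsilon, n_iter)
        - 0.5 * entropic_transport_cost(x, x, epsilon, n_iter)
        - 0.5 * entropic_transport_cost(y, y, epsilon, n_iter)
    )


# Usage: a cloud close to the "intermediate" cells scores lower than a far one.
rng = np.random.default_rng(0)
intermediate = rng.normal(size=(40, 2))
close = intermediate + np.array([0.1, 0.0])
far = intermediate + np.array([3.0, 0.0])
print(sinkhorn_divergence_sketch(close, intermediate) < sinkhorn_divergence_sketch(far, intermediate))  # True
```

For a cloud translated by a vector s, this quantity approaches ‖s‖² as the iterations converge, which is why the cloud shifted by (3, 0) scores far higher than the one shifted by (0.1, 0). A single value on its own is hard to judge, so a baseline such as those from compute_time_point_distances() or compute_random_distance() gives it a scale.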