SpatioTemporalProblem.compute_interpolated_distance(source, intermediate, target, interpolation_parameter=None, n_interpolated_cells=None, account_for_unbalancedness=False, batch_size=256, posterior_marginals=True, seed=None, backend='ott', **kwargs)

Compute the Wasserstein distance between OT-interpolated cells and cells at the intermediate time point.

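This is a validation method. It uses the OT coupling between the source and target distributions to interpolate cells that approximate the intermediate time point, then measures how far these interpolated cells are from the cells actually observed at the intermediate time point.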
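The interpolation step can be sketched in NumPy as follows. This sketch assumes the coupling is available as a dense (n_source, n_target) array. It treats the normalized coupling as a joint distribution over source-target pairs, samples pairs from it, and places each sampled cell on the straight line between the two cells in its pair. The function sample_interpolated_cells is hypothetical and is not part of this API. The actual method materializes the coupling in batches (see batch_size) and may differ in its details.

```python
from typing import Optional

import numpy as np


def sample_interpolated_cells(
    x_source: np.ndarray,
    x_target: np.ndarray,
    coupling: np.ndarray,
    t: float,
    n_samples: int,
    seed: Optional[int] = None,
) -> np.ndarray:
    """Sample interpolated cells from an OT coupling (illustrative sketch, hypothetical helper).

    x_source: (n, d) source cells; x_target: (m, d) target cells;
    coupling: (n, m) nonnegative transport plan; t: interpolation parameter in (0, 1).
    """
    if not 0.0 < t < 1.0:
        raise ValueError("t must lie in the open interval (0, 1)")
    coupling = np.asarray(coupling, dtype=float)
    # Normalize the plan into a joint probability distribution over (source, target) pairs.
    probs = coupling.ravel() / coupling.sum()
    rng = np.random.default_rng(seed)
    # Draw flat indices into the plan, then recover the (source, target) index pairs.
    flat = rng.choice(probs.size, size=n_samples, p=probs)
    i, j = np.unravel_index(flat, coupling.shape)
    # Place each sampled cell on the segment between its source cell and its target cell.
    return (1.0 - t) * x_source[i] + t * x_target[j]


# Toy example: a diagonal coupling matches source cell k with target cell k.
x_source = np.array([[0.0, 0.0], [10.0, 0.0]])
x_target = np.array([[0.0, 2.0], [10.0, 2.0]])
cells = sample_interpolated_cells(x_source, x_target, np.eye(2) / 2, t=0.5, n_samples=5, seed=0)
print(cells.shape)  # (5, 2)
```

Because the diagonal coupling gives zero mass to crossed pairs, every sampled cell here is one of the two midpoints (0, 1) or (10, 1).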

Parameters:

  • source (TypeVar(K, bound=Hashable)) – Key identifying the source distribution.

  • intermediate (TypeVar(K, bound=Hashable)) – Key identifying the intermediate distribution.

  • target (TypeVar(K, bound=Hashable)) – Key identifying the target distribution.

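  • interpolation_parameter (Optional[float]) – Interpolation parameter in \((0, 1)\) that sets the relative weight of the source and target distributions. Values closer to 0 place the interpolated cells closer to the source. If None, the parameter is derived from the time points by linear interpolation, i.e. \((t_{\text{intermediate}} - t_{\text{source}}) / (t_{\text{target}} - t_{\text{source}})\).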

  • n_interpolated_cells (Optional[int]) – Number of interpolated cells to sample. If None, use the number of cells in the intermediate distribution.

  • account_for_unbalancedness (bool) – Whether to account for unbalancedness by assuming exponential cell growth and death.

  • batch_size (int) – Number of rows/columns of the cost matrix to materialize during push() or pull(). Larger values require more memory.

  • posterior_marginals (bool) – Whether to use posterior_growth_rates (True) or prior_growth_rates (False).

  • seed (Optional[int]) – Random seed used when sampling the interpolated cells.

  • backend (Literal['ott']) – Backend used for the distance computation.

  • kwargs (Any) – Keyword arguments for the distance function. The accepted arguments depend on the backend.

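The default for interpolation_parameter can be sketched as follows. This sketch assumes that the distribution keys are numeric time points and that interpolated cells are placed at (1 - t) * source + t * target. The helper default_interpolation_parameter is hypothetical and does not belong to this API.

```python
from typing import Optional


def default_interpolation_parameter(
    source: float,
    intermediate: float,
    target: float,
    interpolation_parameter: Optional[float] = None,
) -> float:
    """Return the interpolation weight t (hypothetical helper, illustrative only)."""
    if interpolation_parameter is not None:
        # An explicit value must lie strictly between 0 and 1.
        if not 0.0 < interpolation_parameter < 1.0:
            raise ValueError("interpolation_parameter must lie in (0, 1)")
        return float(interpolation_parameter)
    if not source < intermediate < target:
        raise ValueError("expected source < intermediate < target")
    # Linear interpolation in time: relative position of the intermediate time point.
    return (intermediate - source) / (target - source)


print(default_interpolation_parameter(0, 1, 2))  # 0.5
print(default_interpolation_parameter(0, 3, 4))  # 0.75
```

With time points 0, 3 and 4, the intermediate time point sits three quarters of the way to the target. The interpolated cells are therefore weighted towards the target distribution.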

Returns:

  The distance between OT-interpolated cells and cells at the intermediate time point. It is recommended to compare this value to the distances computed by compute_time_point_distances() and compute_random_distance().

Return type:

  float
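Comparing against baselines helps separate what the coupling contributes from what linear interpolation alone produces. The sketch below illustrates this on toy data. It computes an exact 2-Wasserstein distance between small, equal-size point clouds with uniform weights. In that setting, optimal transport reduces to an optimal assignment, so brute force over permutations is exact but only feasible for a handful of points. The helper exact_w2 and the toy data are hypothetical. The 'ott' backend uses its own solver rather than this brute-force approach. The crossed pairing stands in for a random-coupling baseline.

```python
import itertools

import numpy as np


def exact_w2(x: np.ndarray, y: np.ndarray) -> float:
    """Exact 2-Wasserstein distance between equal-size, uniformly weighted point clouds.

    With uniform weights and equal sizes, an optimal plan is a permutation, so
    brute force over all assignments is exact (but only feasible for tiny n).
    """
    if x.shape != y.shape:
        raise ValueError("expected point clouds of identical shape")
    n = len(x)
    # Pairwise squared Euclidean costs, shape (n, n).
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    best = min(cost[np.arange(n), list(perm)].sum() for perm in itertools.permutations(range(n)))
    return float(np.sqrt(best / n))


# Toy data: each intermediate cell lies exactly halfway between a source cell and its target cell.
x_source = np.array([[0.0, 0.0], [10.0, 0.0]])
x_target = np.array([[0.0, 2.0], [10.0, 2.0]])
x_intermediate = np.array([[0.0, 1.0], [10.0, 1.0]])
t = 0.5

# Interpolation along the "correct" pairing (0 -> 0, 1 -> 1).
matched = (1 - t) * x_source + t * x_target
# Interpolation along a crossed pairing, standing in for a random-coupling baseline.
crossed = (1 - t) * x_source + t * x_target[::-1]

print(exact_w2(x_intermediate, matched))  # 0.0
print(exact_w2(x_intermediate, crossed))  # 5.0
```

A coupling that captures the true correspondences should give a distance well below the random baseline, as the matched pairing does here.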