TemporalProblem.solve(epsilon=0.001, tau_a=1.0, tau_b=1.0, rank=-1, scale_cost='mean', batch_size=None, stage=('prepared', 'solved'), initializer=None, initializer_kwargs=mappingproxy({}), jit=True, threshold=0.001, lse_mode=True, inner_iterations=10, min_iterations=0, max_iterations=2000, device=None, **kwargs)

Solve the temporal problem.

  • epsilon (float) – Entropic regularization.

  • tau_a (float) – Parameter in \((0, 1]\) that defines how unbalanced the problem is on the source marginals. If \(1\), the problem is balanced.

  • tau_b (float) – Parameter in \((0, 1]\) that defines how unbalanced the problem is on the target marginals. If \(1\), the problem is balanced.

  • rank (int) – Rank of the low-rank OT solver [Scetbon et al., 2021]. If \(-1\), full-rank solver [Cuturi, 2013] is used.

  • scale_cost (Union[float, Literal['mean', 'max_cost', 'max_bound', 'max_norm', 'median']]) – How to re-scale the cost matrix. If a float, the cost matrix will be re-scaled as \(\frac{\text{cost}}{\text{scale_cost}}\).

  • batch_size (Optional[int]) – Number of rows/columns of the cost matrix to materialize during the Sinkhorn iterations. Larger values require more memory.

  • stage (Union[Literal['prepared', 'solved'], Tuple[Literal['prepared', 'solved'], ...]]) – Stage by which to filter the problems to be solved.

  • initializer (Union[Literal['default', 'gaussian', 'sorting'], Literal['random', 'rank2', 'k-means', 'generalized-k-means'], None]) – How to initialize the solution. If None, 'default' will be used for a full-rank solver and 'rank2' for a low-rank solver.

  • initializer_kwargs (Mapping[str, Any]) – Keyword arguments for the initializer.

  • jit (bool) – Whether to JIT-compile the underlying ott solver with jit().

  • threshold (float) – Convergence threshold of the Sinkhorn algorithm. In the balanced case, this is typically the deviation between the target marginals and the marginals of the current transport matrix. In the unbalanced case, the relative change between successive solutions is checked.

  • lse_mode (bool) – Whether to use log-sum-exp (LSE) computations for numerical stability.

  • inner_iterations (int) – Compute the convergence criterion once every inner_iterations iterations.

  • min_iterations (int) – Minimum number of Sinkhorn iterations.

  • max_iterations (int) – Maximum number of Sinkhorn iterations.

  • device (Optional[Literal['cpu', 'gpu', 'tpu']]) – Device to which the solution is transferred; see to(). If None, the output stays on the original device.

  • kwargs (Any) – Keyword arguments for solve().
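
To make the interplay of epsilon, scale_cost, threshold, inner_iterations, min_iterations and max_iterations concrete, the following is a minimal pure-Python sketch of balanced, log-domain Sinkhorn iterations (the lse_mode=True style of update). It is not the moscot or ott implementation: the function names sinkhorn and logsumexp, the L1 marginal error, and the handling of scale_cost='mean' as division by the mean cost are assumptions made for illustration, and the unbalanced case (tau_a, tau_b < 1), low-rank solvers and batching are left out.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sinkhorn(cost, a, b, epsilon=1e-3, scale_cost="mean", threshold=1e-3,
             inner_iterations=10, min_iterations=0, max_iterations=2000):
    """Illustrative balanced Sinkhorn on a small dense cost matrix (hypothetical helper)."""
    n, m = len(cost), len(cost[0])
    # scale_cost: divide by the mean cost (assumed meaning of 'mean') or by a float.
    scale = sum(map(sum, cost)) / (n * m) if scale_cost == "mean" else float(scale_cost)
    C = [[c / scale for c in row] for row in cost]
    log_a = [math.log(x) for x in a]
    log_b = [math.log(x) for x in b]
    f, g = [0.0] * n, [0.0] * m  # dual potentials

    def plan():
        # Transport matrix implied by the current potentials.
        return [[math.exp((f[i] + g[j] - C[i][j]) / epsilon) for j in range(m)]
                for i in range(n)]

    error, it = math.inf, 0
    for it in range(1, max_iterations + 1):
        # Log-sum-exp updates: match the target marginals, then the source marginals.
        g = [epsilon * (log_b[j] - logsumexp([(f[i] - C[i][j]) / epsilon for i in range(n)]))
             for j in range(m)]
        f = [epsilon * (log_a[i] - logsumexp([(g[j] - C[i][j]) / epsilon for j in range(m)]))
             for i in range(n)]
        # The convergence criterion is only evaluated every inner_iterations steps.
        if it % inner_iterations == 0:
            P = plan()
            # Deviation between the target marginals and those of the current plan.
            error = sum(abs(sum(P[i][j] for i in range(n)) - b[j]) for j in range(m))
            if it >= min_iterations and error < threshold:
                break
    return plan(), it, error


cost = [[0.0, 1.0], [1.0, 0.0]]
P, n_iter, err = sinkhorn(cost, a=[0.5, 0.5], b=[0.5, 0.5], epsilon=0.1)
print(n_iter, err)
```

Because the error is checked only at multiples of inner_iterations, the loop in this sketch can stop no earlier than the first such multiple that is at least min_iterations, and it always stops after max_iterations steps. Smaller epsilon values give sharper transport plans but generally need more iterations to reach the same threshold.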

Return type:

TemporalProblem

Returns:

Returns self and updates the following fields: