GWProblem.solve(epsilon=0.001, tau_a=1.0, tau_b=1.0, rank=-1, scale_cost='mean', batch_size=None, stage=('prepared', 'solved'), initializer=None, initializer_kwargs=mappingproxy({}), jit=True, min_iterations=5, max_iterations=50, threshold=0.001, linear_solver_kwargs=mappingproxy({}), device=None, **kwargs)

Solve the individual quadratic subproblems.

Parameters:

  • epsilon (float) – Entropic regularization.

  • tau_a (float) – Parameter in \((0, 1]\) that defines how unbalanced the problem is on the source marginals. If \(1\), the problem is balanced.

  • tau_b (float) – Parameter in \((0, 1]\) that defines how unbalanced the problem is on the target marginals. If \(1\), the problem is balanced.

  • rank (int) – Rank of the low-rank OT solver [Scetbon et al., 2021]. If \(-1\), full-rank solver [Peyré et al., 2016] is used.

  • scale_cost (Union[float, Literal['mean', 'max_cost', 'max_bound', 'max_norm', 'median']]) – How to re-scale the cost matrices. If a float, the cost matrices will be re-scaled as \(\frac{\text{cost}}{\text{scale\_cost}}\).

  • batch_size (Optional[int]) – Number of rows/columns of the cost matrix to materialize during the solver iterations. Larger values require more memory.

  • stage (Union[Literal['prepared', 'solved'], Tuple[Literal['prepared', 'solved'], ...]]) – Stage by which to filter the problems to be solved.

  • initializer (Optional[Literal['random', 'rank2', 'k-means', 'generalized-k-means']]) – How to initialize the solution. If None, 'default' will be used for a full-rank solver and 'rank2' for a low-rank solver.

  • initializer_kwargs (Mapping[str, Any]) – Keyword arguments for the initializer.

  • jit (bool) – Whether to jit() the underlying ott solver.

  • min_iterations (int) – Minimum number of (fused) GW iterations.

  • max_iterations (int) – Maximum number of (fused) GW iterations.

  • threshold (float) – Convergence threshold of the GW solver.

  • linear_solver_kwargs (Mapping[str, Any]) – Keyword arguments for the inner linear problem solver.

  • device (Optional[Literal['cpu', 'gpu', 'tpu']]) – Transfer the solution to a different device, see to(). If None, keep the output on the original device.

  • kwargs (Any) – Keyword arguments for solve().
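The roles of epsilon, tau_a and tau_b can be illustrated on a single inner linear problem. The following is a minimal NumPy sketch of entropic Sinkhorn scaling, not moscot's or OTT's implementation; the function name unbalanced_sinkhorn is hypothetical, and it assumes the KL-relaxed formulation in which tau_a and tau_b act as exponents on the scaling updates, so that tau = 1 enforces the marginals exactly.

```python
import numpy as np


def unbalanced_sinkhorn(cost, a, b, epsilon, tau_a=1.0, tau_b=1.0, n_iters=1000):
    """Hypothetical entropic Sinkhorn solver on a dense cost matrix.

    Assumption: tau_a / tau_b are exponents on the scaling updates
    (KL-relaxed marginals); tau = 1 recovers balanced Sinkhorn.
    """
    kernel = np.exp(-cost / epsilon)  # Gibbs kernel K = exp(-C / epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        # Each update pulls one marginal of diag(u) K diag(v) towards a or b;
        # an exponent below 1 only partially enforces that marginal.
        u = (a / (kernel @ v)) ** tau_a
        v = (b / (kernel.T @ u)) ** tau_b
    return u[:, None] * kernel * v[None, :]  # plan P = diag(u) K diag(v)
```

Smaller epsilon gives sharper (less diffuse) plans but a kernel that underflows more easily, which is one reason cost rescaling matters in practice.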
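The scale_cost option divides a cost matrix by a single factor before solving. The sketch below covers only the float, 'mean', 'max_cost' and 'median' cases, assuming they divide by the given number or by the mean, maximum or median entry of the matrix; rescale_cost is a hypothetical helper, and 'max_bound' and 'max_norm' are deliberately left out because they depend on how the underlying geometry is defined.

```python
import numpy as np


def rescale_cost(cost, scale_cost="mean"):
    """Hypothetical helper mirroring a subset of the scale_cost options."""
    if isinstance(scale_cost, (int, float)):
        factor = float(scale_cost)  # explicit divisor: cost / scale_cost
    elif scale_cost == "mean":
        factor = np.mean(cost)
    elif scale_cost == "max_cost":
        factor = np.max(cost)
    elif scale_cost == "median":
        factor = np.median(cost)
    else:
        # 'max_bound' / 'max_norm' are not reproduced in this sketch.
        raise NotImplementedError(f"scale_cost={scale_cost!r} is not covered here")
    return cost / factor
```

Rescaling keeps the cost on a predictable scale, so a fixed epsilon such as the default 0.001 has a comparable effect across datasets.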
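The batch_size parameter trades memory for recomputation: instead of storing the full cost matrix, a solver can rebuild it block by block whenever it needs a kernel-vector product. A minimal NumPy sketch, assuming squared-Euclidean costs between two point clouds x and y (the function name is hypothetical and this is not moscot's code):

```python
import numpy as np


def kernel_matvec_batched(x, y, v, epsilon, batch_size=None):
    """Compute exp(-C / epsilon) @ v for squared-Euclidean C, one row block at a time."""
    n = x.shape[0]
    if batch_size is None:
        batch_size = n  # materialize the whole cost matrix at once
    out = np.empty(n)
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        # Only a (stop - start, m) block of the cost matrix exists at any time.
        block = ((x[start:stop, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
        out[start:stop] = np.exp(-block / epsilon) @ v
    return out
```

The result does not depend on batch_size; only peak memory (one block instead of the full matrix) and the amount of recomputation change.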
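Together, min_iterations, max_iterations and threshold decide when the outer (fused) GW loop stops. Assuming the usual convention that the convergence test is only applied once min_iterations have run, the stopping rule behaves like the hypothetical loop below, where step is any user-supplied update returning a new state and an error estimate.

```python
def run_until_converged(step, state, min_iterations=5, max_iterations=50, threshold=1e-3):
    """Hypothetical outer loop; step(state) must return (new_state, error)."""
    n_iter = 0
    error = float("inf")
    while n_iter < max_iterations:
        state, error = step(state)
        n_iter += 1
        # Stop early only after the minimum number of iterations has been run.
        if n_iter >= min_iterations and error < threshold:
            break
    return state, n_iter, error
```

For example, with step halving a scalar and reporting the size of the change, the defaults stop after 10 iterations, threshold=1.0 stops at the minimum of 5, and max_iterations=3 stops at 3 regardless of the error.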

Return type:

GWProblem[TypeVar(K, bound=Hashable), TypeVar(B, bound=OTProblem)]


Returns self and updates the following fields: