moscot.problems.generic.SinkhornProblem.solve
- SinkhornProblem.solve(epsilon=0.001, tau_a=1.0, tau_b=1.0, rank=-1, scale_cost='mean', batch_size=None, stage=('prepared', 'solved'), initializer=None, initializer_kwargs=mappingproxy({}), jit=True, threshold=0.001, lse_mode=True, inner_iterations=10, min_iterations=None, max_iterations=None, device=None, **kwargs)
Solve the individual linear subproblems using the Sinkhorn algorithm [Cuturi, 2013].
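For intuition, here is a minimal NumPy sketch of the balanced Sinkhorn iteration [Cuturi, 2013]. It is not moscot's implementation; the library's own solver is backed by OTT-JAX and differs in details. The functions `sinkhorn` and `_logsumexp` are hypothetical. Their keyword names only mirror `epsilon`, `threshold`, `inner_iterations`, `max_iterations` and `lse_mode` from the signature. The convergence check assumes an L1 error on the target marginal.

```python
import numpy as np


def _logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn(cost: np.ndarray, a: np.ndarray, b: np.ndarray, epsilon: float = 1e-3,
             threshold: float = 1e-3, inner_iterations: int = 10,
             max_iterations: int = 2000, lse_mode: bool = True) -> np.ndarray:
    """Balanced entropic OT plan via Sinkhorn iterations (hypothetical sketch)."""
    f, g = np.zeros_like(a), np.zeros_like(b)  # dual potentials (lse_mode=True)
    u, v = np.ones_like(a), np.ones_like(b)    # scaling vectors (lse_mode=False)
    # the Gibbs kernel can underflow to 0 when epsilon is small
    kernel = None if lse_mode else np.exp(-cost / epsilon)

    def current_plan() -> np.ndarray:
        if lse_mode:
            return np.exp((f[:, None] + g[None, :] - cost) / epsilon)
        return u[:, None] * kernel * v[None, :]

    for it in range(1, max_iterations + 1):
        if lse_mode:
            # update g to match the target marginal b, then f to match the source marginal a
            g = epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
            f = epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        else:
            v = b / (kernel.T @ u)
            u = a / (kernel @ v)
        # evaluate the convergence criterion only every `inner_iterations` steps
        if it % inner_iterations == 0:
            # rows match `a` exactly after the f/u update, so measure the error on `b`
            err = np.abs(current_plan().sum(axis=0) - b).sum()
            if err < threshold:
                break
    return current_plan()


# Usage on two small random point clouds.
rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(4, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
cost = cost / cost.mean()                    # analogous to scale_cost='mean'
a, b = np.full(5, 1 / 5), np.full(4, 1 / 4)  # uniform marginals
kw = dict(epsilon=0.1, threshold=1e-4, max_iterations=10_000)
plan = sinkhorn(cost, a, b, **kw)
plan_kernel = sinkhorn(cost, a, b, lse_mode=False, **kw)
```

In exact arithmetic both modes produce the same iterates, because \(u = e^{f/\varepsilon}\) and \(v = e^{g/\varepsilon}\). The log-domain variant avoids forming \(e^{-\text{cost}/\varepsilon}\), which underflows for small \(\varepsilon\) such as the default \(0.001\).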
- Parameters:
epsilon (float) – Entropic regularization.
tau_a (float) – Parameter in \((0, 1]\) that defines how unbalanced the problem is on the source marginals. If \(1\), the problem is balanced.
tau_b (float) – Parameter in \((0, 1]\) that defines how unbalanced the problem is on the target marginals. If \(1\), the problem is balanced.
rank (int) – Rank of the low-rank OT solver [Scetbon et al., 2021]. If \(-1\), the full-rank solver [Cuturi, 2013] is used.
scale_cost (Union[float, Literal['mean', 'max_cost', 'max_bound', 'max_norm', 'median']]) – How to re-scale the cost matrix. If a float, the cost matrix will be re-scaled as \(\frac{\mathrm{cost}}{\mathrm{scale\_cost}}\).
batch_size (Optional[int]) – Number of rows/columns of the cost matrix to materialize during the Sinkhorn iterations. Larger values require more memory.
stage (Union[Literal['prepared', 'solved'], Tuple[Literal['prepared', 'solved'], ...]]) – Stage by which to filter the problems to be solved.
initializer (Union[SinkhornInitializer, Literal['default', 'gaussian', 'sorting'], None]) – How to initialize the solution. If None, 'default' will be used for a full-rank solver and 'rank2' for a low-rank solver.
initializer_kwargs (Mapping[str, Any]) – Keyword arguments for the initializer.
threshold (float) – Convergence threshold of the Sinkhorn algorithm. In the balanced case, this is typically the deviation between the target marginals and the marginals of the current transport matrix. In the unbalanced case, the relative change between successive solutions is checked.
lse_mode (bool) – Whether to use log-sum-exp (LSE) computations for numerical stability.
inner_iterations (int) – Compute the convergence criterion every inner_iterations iterations.
min_iterations (Optional[int]) – Minimum number of Sinkhorn iterations.
max_iterations (Optional[int]) – Maximum number of Sinkhorn iterations.
device (Optional[Literal['cpu', 'gpu', 'tpu']]) – Transfer the solution to a different device, see to(). If None, keep the output on the original device.
**kwargs (Any) – The description is missing.
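To make tau_a, tau_b and scale_cost concrete, here is a hedged NumPy sketch of the log-domain Sinkhorn update with relaxed marginals. It assumes the OTT-JAX convention \(\tau = \rho / (\rho + \varepsilon)\), where \(\rho\) is the KL penalty on a marginal, so \(\tau = 1\) recovers the balanced problem. `rescale_cost` and `unbalanced_sinkhorn` are hypothetical helpers, not moscot API. Only the float, 'mean', 'max_cost' and 'median' options of scale_cost are mirrored. The stopping rule, a relative change of the dual potentials, is one assumed reading of "relative change between successive solutions".

```python
import numpy as np


def _logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def rescale_cost(cost: np.ndarray, scale_cost="mean") -> np.ndarray:
    """Divide the cost by a float or by a statistic of the cost (hypothetical helper)."""
    if isinstance(scale_cost, (int, float)):
        return cost / scale_cost
    statistic = {"mean": np.mean, "max_cost": np.max, "median": np.median}[scale_cost]
    return cost / statistic(cost)


def unbalanced_sinkhorn(cost, a, b, epsilon=1e-3, tau_a=1.0, tau_b=1.0,
                        threshold=1e-3, inner_iterations=10, max_iterations=2000):
    """Log-domain Sinkhorn with KL-relaxed marginals (hypothetical sketch)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    f_prev, g_prev = f, g
    for it in range(1, max_iterations + 1):
        # tau < 1 shrinks each update, so the marginals are only matched approximately
        g = tau_b * epsilon * (np.log(b) - _logsumexp((f[:, None] - cost) / epsilon, axis=0))
        f = tau_a * epsilon * (np.log(a) - _logsumexp((g[None, :] - cost) / epsilon, axis=1))
        if it % inner_iterations == 0:
            # relative change of the potentials since the previous check
            change = np.abs(f - f_prev).sum() + np.abs(g - g_prev).sum()
            scale = np.abs(f).sum() + np.abs(g).sum() + 1e-12
            if change / scale < threshold:
                break
            f_prev, g_prev = f, g
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


# Usage: the same data solved balanced (tau = 1) and unbalanced (tau = 0.5).
rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(4, 2))
cost = rescale_cost(((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1), "mean")
a, b = np.full(5, 1 / 5), np.full(4, 1 / 4)
plan_bal = unbalanced_sinkhorn(cost, a, b, epsilon=0.1)
plan_unb = unbalanced_sinkhorn(cost, a, b, epsilon=0.1, tau_a=0.5, tau_b=0.5)
```

With tau_a = tau_b = 1 the row sums of `plan_bal` equal `a`. With \(\tau = 0.5\) the marginal constraints become soft penalties, so the row sums of `plan_unb` generally differ from `a`.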
- Return type:
SinkhornProblem[TypeVar(K, bound=Hashable), TypeVar(B, bound=OTProblem)]
- Returns:
Returns self and updates the following fields: