def get_solver(
    problem_kind: ProblemKind_t,
    *,
    backend: str = "ott",
    return_class: bool = False,
    **kwargs: Any,
) -> Any:
    """Look up the solver for a problem kind in a registered backend.

    Parameters
    ----------
    problem_kind
        Kind of the problem. It is forwarded to the factory registered for ``backend``.
        The ``'ott'`` factory in this module handles ``'linear'`` and ``'quadratic'``.
    backend
        Name of a backend registered via :func:`register_solver`.
    return_class
        If ``True``, return the solver class instead of an instance of it.
    kwargs
        Keyword arguments used to instantiate the solver. They are ignored when
        ``return_class = True``.

    Returns
    -------
    The solver class if ``return_class = True``, otherwise ``solver_class(**kwargs)``.

    Raises
    ------
    ValueError
        If ``backend`` is not registered.
    NotImplementedError
        Propagated from the backend factory when it cannot handle ``problem_kind``.
        The ``'ott'`` factory raises it.
    """
    if backend not in _REGISTRY:
        # List the registered names so the caller can immediately see the valid choices.
        raise ValueError(f"Backend `{backend!r}` is not available. Valid options are: {list(_REGISTRY)}.")
    # The registry entry is a factory mapping a problem kind to a solver class.
    solver_class = _REGISTRY[backend](problem_kind)
    return solver_class if return_class else solver_class(**kwargs)
def register_solver(
    backend: str,
) -> Callable[[Literal["linear", "quadratic"]], Union[Type["ott.SinkhornSolver"], Type["ott.GWSolver"]]]:
    """Build the decorator that stores a solver factory in the registry under ``backend``.

    Parameters
    ----------
    backend
        Name under which the decorated factory is registered.

    Returns
    -------
    The object returned by ``_REGISTRY.register(backend)``. It is applied as a decorator to a
    factory that maps a problem kind to a solver class.
    """
    registering_decorator = _REGISTRY.register(backend)
    return registering_decorator  # type: ignore[return-value]


@register_solver("ott")  # type: ignore[arg-type]
def _(problem_kind: Literal["linear", "quadratic"]) -> Union[Type["ott.SinkhornSolver"], Type["ott.GWSolver"]]:
    """Resolve the ``'ott'`` backend's solver class for ``problem_kind``.

    Parameters
    ----------
    problem_kind
        ``'linear'`` selects ``ott.SinkhornSolver``; ``'quadratic'`` selects ``ott.GWSolver``.

    Returns
    -------
    The matching solver class (not an instance).

    Raises
    ------
    NotImplementedError
        If ``problem_kind`` matches neither supported kind.
    """
    # Deferred import: the backend module is only loaded when this factory actually runs.
    from moscot.backends import ott

    # Ordered (kind, attribute-name) pairs, checked with ``==`` in the same order as an if-chain.
    kind_to_attr = (("linear", "SinkhornSolver"), ("quadratic", "GWSolver"))
    for kind, attr_name in kind_to_attr:
        if problem_kind == kind:
            return getattr(ott, attr_name)
    raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}` problem.")
def get_available_backends() -> Tuple[str, ...]:
    """Return the names of all registered backends, in the registry's iteration order."""
    return tuple(_REGISTRY)