Custom cost matrices

This example shows how to use custom cost matrices.

Although passing the data is recommended over passing pre-computed cost matrices for reasons of computational complexity, we demonstrate below how to pass custom cost matrices.

There are two ways to pass custom cost matrices: either set them for a single OT problem after preparing the problem, or store them in the obsp layer of the AnnData object and reference them in prepare().


Imports and data loading

import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import FGWProblem

import numpy as np
import pandas as pd

import scanpy as sc

Simulate data using simulate_data().

adata = datasets.simulate_data(n_distributions=3, key="batch")
adata
AnnData object with n_obs × n_vars = 60 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'

Prepare the problem

The first option is to prepare the problem in any way and override the cost terms of the individual OT problems afterwards.

fgw = FGWProblem(adata)
fgw = fgw.prepare(key="batch", x_attr="X_pca", y_attr="X_pca")
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`
FGWProblem[('1', '2'), ('0', '1')]

Passing cost matrices one by one

We can pass custom cost matrices by accessing the individual OTProblem. The method set_xy() sets a custom cost matrix for the linear term, set_x() sets a custom cost matrix for the quadratic term of the source distribution, and set_y() does the same for the quadratic term of the target distribution.

When using these methods, we need to pass a DataFrame so that the rows and columns of the cost matrix are labelled with cell names and their order is unambiguous. Below, we retrieve the cell names to construct DataFrames containing the custom cost matrices.
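The reason a labelled DataFrame helps is that pandas aligns on labels, not positions: a matrix whose rows and columns were built in a different order can be brought into the order of the source and target cells with reindex(). The cell names below are hypothetical, and the sketch only shows pandas behaviour, not what moscot does internally.

```python
import numpy as np
import pandas as pd

# Hypothetical cell names of a source and a target distribution.
src_names = pd.Index(["c0", "c1", "c2"])
tgt_names = pd.Index(["d0", "d1"])

# A cost matrix whose rows and columns are deliberately in a different order.
shuffled = pd.DataFrame(
    [[5.0, 4.0], [3.0, 2.0], [1.0, 0.0]],
    index=["c2", "c1", "c0"],
    columns=["d1", "d0"],
)

# Reindexing by label restores the (source, target) order.
aligned = shuffled.reindex(index=src_names, columns=tgt_names)

# A NaN would mean that a cell name was missing from the matrix.
assert not aligned.isna().to_numpy().any()

print(aligned.loc["c0", "d0"])  # the entry for the pair (c0, d0), wherever it was stored
```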

rng = np.random.default_rng(seed=42)
obs_names_0 = fgw["0", "1"].adata_src.obs_names
obs_names_1 = fgw["0", "1"].adata_tgt.obs_names

cost_linear_01 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_1))))
cost_quad_0 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_0))))
np.fill_diagonal(cost_quad_0, 0)
cost_quad_1 = np.abs(rng.normal(size=(len(obs_names_1), len(obs_names_1))))
np.fill_diagonal(cost_quad_1, 0)

cm_linear = pd.DataFrame(data=cost_linear_01, index=obs_names_0, columns=obs_names_1)
cm_quad_0 = pd.DataFrame(data=cost_quad_0, index=obs_names_0, columns=obs_names_0)
cm_quad_1 = pd.DataFrame(data=cost_quad_1, index=obs_names_1, columns=obs_names_1)
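The random quadratic matrices above are placeholders and are not symmetric. In practice, a quadratic cost is usually a distance matrix within one distribution, which is symmetric with a zero diagonal. A minimal sketch, assuming Euclidean distances on some per-cell coordinates; the coordinates, the cell names and the helper pairwise_euclidean are made up for illustration.

```python
import numpy as np
import pandas as pd

def pairwise_euclidean(coords: np.ndarray) -> np.ndarray:
    """Return the (n, n) matrix of Euclidean distances between the rows of coords."""
    diff = coords[:, None, :] - coords[None, :, :]  # shape (n, n, d)
    return np.sqrt((diff**2).sum(axis=-1))

rng = np.random.default_rng(0)
coords = rng.normal(size=(4, 2))  # hypothetical 2-D coordinates of 4 cells
names = pd.Index([f"cell_{i}" for i in range(4)])  # hypothetical cell names

# Label rows and columns with the same cell names, as for cm_quad_0 / cm_quad_1.
cm_quad = pd.DataFrame(pairwise_euclidean(coords), index=names, columns=names)

# A distance matrix is symmetric and has zeros on the diagonal.
assert np.allclose(cm_quad.to_numpy(), cm_quad.to_numpy().T)
assert np.allclose(np.diag(cm_quad.to_numpy()), 0.0)
```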

Now we can set the custom cost matrices:

fgw["0", "1"].set_xy(cm_linear, tag="cost_matrix")
fgw["0", "1"].set_x(cm_quad_0, tag="cost_matrix")
fgw["0", "1"].set_y(cm_quad_1, tag="cost_matrix")

When solving, the custom cost matrices will be used for the problem mapping batch '0' to batch '1', while the problem mapping batch '1' to batch '2' still uses the information passed to prepare().

Cost matrices in the obsp layer

A second way to pass custom cost matrices is via obsp. This is especially useful when saving and loading a model. On the other hand, it can be more difficult to place each cost matrix at the correct position in obsp, so make sure that the entries are in the right order. Below, we construct an obsp layer containing custom cost matrices for both the linear and the quadratic terms of both OT problems, mapping batch '0' to batch '1' and batch '1' to batch '2'.

The layer has one row and one column per cell in adata: the block whose rows belong to batch i and whose columns belong to batch j holds the cost between those cells. Assembling it with np.block, as done here, assumes that the cells in adata are ordered by batch (first batch '0', then '1', then '2').

obs_names_2 = fgw["1", "2"].adata_tgt.obs_names

cost_linear_12 = np.abs(rng.normal(size=(len(obs_names_1), len(obs_names_2))))
cost_quad_2 = np.abs(rng.normal(size=(len(obs_names_2), len(obs_names_2))))
np.fill_diagonal(cost_quad_2, 0)


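# rows and columns: batch '0', batch '1', batch '2'
blocks = [
    [cost_quad_0, cost_linear_01, np.zeros((len(obs_names_0), len(obs_names_2)))],
    [np.zeros((len(obs_names_1), len(obs_names_0))), cost_quad_1, cost_linear_12],
    [
        np.zeros((len(obs_names_2), len(obs_names_0))),
        np.zeros((len(obs_names_2), len(obs_names_1))),
        cost_quad_2,
    ],
]

# check the shapes of the individual cost matrices
for cm in (cost_quad_0, cost_linear_01, cost_quad_1, cost_linear_12, cost_quad_2):
    print(cm.shape)

adata.obsp["cost_matrices"] = np.block(blocks)
print(adata.obsp["cost_matrices"].shape)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(60, 60)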
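np.block only lines up with adata.obs_names if the cells are stored batch by batch. A sketch that avoids this assumption writes each labelled block at the positions of its cells, looked up with pandas Index.get_indexer. The helper assemble_cost_matrix and the toy data are hypothetical and not part of moscot.

```python
from __future__ import annotations

import numpy as np
import pandas as pd

def assemble_cost_matrix(obs_names: pd.Index, blocks: list[pd.DataFrame]) -> np.ndarray:
    """Write each labelled block into an (n_obs, n_obs) zero matrix by cell name.

    obs_names is assumed to be unique; get_indexer raises otherwise.
    """
    full = np.zeros((len(obs_names), len(obs_names)))
    for block in blocks:
        rows = obs_names.get_indexer(block.index)
        cols = obs_names.get_indexer(block.columns)
        if (rows < 0).any() or (cols < 0).any():
            raise KeyError("block contains cell names that are not in obs_names")
        # np.ix_ builds an open mesh, so full[rows[i], cols[j]] = block.iloc[i, j].
        full[np.ix_(rows, cols)] = block.to_numpy()
    return full

# Toy example: cells of two batches are interleaved rather than sorted.
obs_names = pd.Index(["a0", "b0", "a1", "b1"])
linear_ab = pd.DataFrame([[1.0, 2.0], [3.0, 4.0]], index=["a0", "a1"], columns=["b0", "b1"])

full = assemble_cost_matrix(obs_names, [linear_ab])
print(full)
```

In this tutorial, the blocks could be labelled DataFrames such as cm_quad_0, cm_linear and cm_quad_1, and the result could be stored in adata.obsp["cost_matrices"].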

We need to specify in prepare() where to fetch the custom cost matrices from. If we only want to use the custom cost matrix for the linear term, we modify joint_attr as follows:

joint_attr = {"key": "cost_matrices", "tag": "cost_matrix"}
fgw = fgw.prepare(key="batch", joint_attr=joint_attr, x_attr="X_pca", y_attr="X_pca")
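Given the layout described above, the linear term for a pair of batches corresponds to the rows of the source cells and the columns of the target cells of the full matrix. The sketch below shows that slicing with plain NumPy boolean masks on a toy batch column; the helper linear_block is hypothetical, and the sketch illustrates the layout only, as an assumption about which block is read, not moscot's implementation.

```python
import numpy as np
import pandas as pd

# Toy batch labels: 2 cells per batch, stored in batch order.
batch = pd.Series(["0", "0", "1", "1", "2", "2"])
# Stand-in for adata.obsp["cost_matrices"].
full = np.arange(36, dtype=float).reshape(6, 6)

def linear_block(full: np.ndarray, batch: pd.Series, src: str, tgt: str) -> np.ndarray:
    """Return the rows of the source batch and the columns of the target batch."""
    src_mask = (batch == src).to_numpy()
    tgt_mask = (batch == tgt).to_numpy()
    # np.ix_ accepts boolean masks and turns them into an open index mesh.
    return full[np.ix_(src_mask, tgt_mask)]

print(linear_block(full, batch, "0", "1"))
```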

If we only want to use custom cost matrices for the quadratic terms, we modify x_attr and y_attr instead:

x_attr = {
    "attr": "obsp",
    "key": "cost_matrices",
    "tag": "cost_matrix",
    "cost": "custom",
}
y_attr = {
    "attr": "obsp",
    "key": "cost_matrices",
    "tag": "cost_matrix",
    "cost": "custom",
}
fgw = fgw.prepare(key="batch", joint_attr="X_pca", x_attr=x_attr, y_attr=y_attr)
fgw[("0", "1")]
OTProblem[stage='prepared', shape=(20, 20)]

To use custom cost matrices for all terms, we pass all three arguments:

fgw = fgw.prepare(key="batch", joint_attr=joint_attr, x_attr=x_attr, y_attr=y_attr)
fgw[("1", "2")]
OTProblem[stage='prepared', shape=(20, 20)]
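When all three terms come from the same obsp matrix, each pair of batches needs one linear block and two quadratic blocks (source × source and target × target). Below is a hypothetical check, check_cost_blocks, that could be run on the matrix stored in adata.obsp["cost_matrices"] before calling prepare(). The criteria (finite, non-negative, not left at zero, zero diagonal on quadratic blocks) are our own assumptions about a sensible cost, mirroring how the matrices were built above, not requirements stated by moscot.

```python
from __future__ import annotations

import numpy as np
import pandas as pd

def check_cost_blocks(
    full: np.ndarray, batch: pd.Series, pairs: list[tuple[str, str]]
) -> list[str]:
    """Return a list of problems found in the blocks needed for each (src, tgt) pair."""
    issues = []
    for src, tgt in pairs:
        src_idx = np.flatnonzero((batch == src).to_numpy())
        tgt_idx = np.flatnonzero((batch == tgt).to_numpy())
        blocks = {
            "linear": full[np.ix_(src_idx, tgt_idx)],
            "quad_src": full[np.ix_(src_idx, src_idx)],
            "quad_tgt": full[np.ix_(tgt_idx, tgt_idx)],
        }
        for name, block in blocks.items():
            if not np.all(np.isfinite(block)):
                issues.append(f"{src}->{tgt} {name}: non-finite entries")
            if np.any(block < 0):
                issues.append(f"{src}->{tgt} {name}: negative entries")
            if not np.any(block):
                issues.append(f"{src}->{tgt} {name}: all zeros (block probably not filled)")
        for name in ("quad_src", "quad_tgt"):
            if not np.allclose(np.diag(blocks[name]), 0.0):
                issues.append(f"{src}->{tgt} {name}: non-zero diagonal")
    return issues

# Toy matrix laid out like the obsp layer above, with 2 cells per batch.
batch = pd.Series(["0", "0", "1", "1", "2", "2"])
quad = np.array([[0.0, 1.0], [1.0, 0.0]])
lin = np.ones((2, 2))
z = np.zeros((2, 2))
full = np.block([[quad, lin, z], [z, quad, lin], [z, z, quad]])

print(check_cost_blocks(full, batch, [("0", "1"), ("1", "2")]))  # → []
```

Asking for a pair whose linear block was never filled, such as ("0", "2"), reports that block as all zeros.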