# Custom cost matrices

This example shows how to use custom cost matrices.

Although passing the data rather than pre-computed cost matrices is recommended for computational reasons (a dense cost matrix grows quadratically with the number of cells), the following shows how to pass custom cost matrices.
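For a rough sense of scale: a dense `float64` cost matrix between `n` source and `m` target cells occupies `8 * n * m` bytes, whereas the two point clouds it would be computed from (with `d` features per cell, e.g. PCA components) occupy `8 * (n + m) * d` bytes. The sketch below only does that arithmetic; the helper names, the 20 000-cell size and `d = 30` are illustrative assumptions, not figures from this tutorial.

```python
def dense_cost_bytes(n: int, m: int, itemsize: int = 8) -> int:
    """Memory of a dense n x m cost matrix (float64 by default)."""
    return n * m * itemsize


def point_cloud_bytes(n: int, m: int, d: int, itemsize: int = 8) -> int:
    """Memory of the source and target point clouds with d features per cell."""
    return (n + m) * d * itemsize


# 20 cells per batch (as in this tutorial) vs. a hypothetical larger dataset
for n_cells in (20, 20_000):
    dense = dense_cost_bytes(n_cells, n_cells)
    cloud = point_cloud_bytes(n_cells, n_cells, d=30)
    print(f"{n_cells} x {n_cells}: cost matrix {dense / 1e6:.2f} MB, point clouds {cloud / 1e6:.2f} MB")
```

For the tiny batches used here the dense matrix is actually smaller than the point clouds, so pre-computed costs are harmless in this example; the quadratic growth only becomes a problem as the number of cells increases.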

There are two ways to pass custom cost matrices: setting them on the individual subproblems after the problem has been prepared, or storing them in `adata.obsp` and referencing them in `prepare()`.

```python
import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import FGWProblem

import numpy as np
import pandas as pd

import scanpy as sc
```

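Simulate data using `simulate_data()`, here with three batches ('0', '1' and '2') of 20 cells each, stored in `adata.obs["batch"]`:

```python
adata = datasets.simulate_data(n_distributions=3, key="batch")
adata
```

```
AnnData object with n_obs × n_vars = 60 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'
```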

## Prepare the problem

The first option is to prepare the problem as usual and then override the cost terms of the individual OT subproblems.

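```python
fgw = FGWProblem(adata)
# no joint_attr is given, so the linear term uses a PCA of adata.X (see the log below)
fgw = fgw.prepare(key="batch", x_attr="X_pca", y_attr="X_pca")
fgw
```

```
INFO     Computing pca with n_comps=30 for xy using adata.X
INFO     Computing pca with n_comps=30 for xy using adata.X
FGWProblem[('1', '2'), ('0', '1')]
```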

## Passing cost matrices one by one

We can pass custom cost matrices by accessing the individual `OTProblem`. The method `set_xy()` sets a custom cost matrix for the linear term, `set_x()` sets one for the quadratic term of the source distribution, and `set_y()` does the same for the quadratic term of the target distribution.
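The three terms expect matrices of different shapes. Writing `n_src` and `n_tgt` for the number of cells in the source and target batch, the linear term compares every source cell with every target cell (`n_src x n_tgt`), while each quadratic term compares cells within one distribution (`n_src x n_src` for the source, `n_tgt x n_tgt` for the target). The sketch below builds random, labeled matrices of these shapes with plain NumPy/pandas; the cell names and the helper `random_cost` are hypothetical and only illustrate the expected layout.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(seed=0)

# hypothetical cell names: a source batch with 3 cells, a target batch with 4 cells
src_names = pd.Index([f"src_{i}" for i in range(3)])
tgt_names = pd.Index([f"tgt_{j}" for j in range(4)])


def random_cost(rows: pd.Index, cols: pd.Index) -> pd.DataFrame:
    """Hypothetical helper: non-negative random costs labeled by cell names."""
    data = np.abs(rng.normal(size=(len(rows), len(cols))))
    return pd.DataFrame(data, index=rows, columns=cols)


cost_xy = random_cost(src_names, tgt_names)  # linear term: source x target
cost_x = random_cost(src_names, src_names)   # quadratic term, source: source x source
cost_y = random_cost(tgt_names, tgt_names)   # quadratic term, target: target x target

print(cost_xy.shape, cost_x.shape, cost_y.shape)
```

In the tutorial's setting, frames shaped like `cost_xy`, `cost_x` and `cost_y` would go to `set_xy()`, `set_x()` and `set_y()` respectively.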

When using these methods, we need to pass a `DataFrame` so that the rows and columns of the cost matrix are matched to the correct cells. Below, we retrieve the cell names of each batch and use them to label the custom cost matrix.
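Labels matter because a bare array carries no information about which cell each row or column belongs to, whereas a labeled matrix built in any order can be realigned to the order a problem expects. A small pandas sketch of that idea, with hypothetical cell names; it illustrates label-based alignment in general and does not describe how moscot processes the `DataFrame` internally.

```python
import pandas as pd

# order of the cells as stored in the (hypothetical) source and target batches
src_order = pd.Index(["cell_a", "cell_b", "cell_c"])
tgt_order = pd.Index(["cell_x", "cell_y"])

# a cost matrix whose rows and columns were built in a different, shuffled order
shuffled = pd.DataFrame(
    [[5.0, 6.0], [1.0, 2.0], [3.0, 4.0]],
    index=["cell_c", "cell_a", "cell_b"],
    columns=["cell_y", "cell_x"],
)

# label-based selection puts every entry back next to its pair of cells
aligned = shuffled.loc[src_order, tgt_order]
print(aligned)
```

Every cost stays attached to the same (source cell, target cell) pair; only the row and column order changes.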

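```python
# cell names of each batch, used to label the rows and columns of the cost matrices
obs_names_0 = adata[adata.obs["batch"] == "0"].obs_names
obs_names_1 = adata[adata.obs["batch"] == "1"].obs_names
obs_names_2 = adata[adata.obs["batch"] == "2"].obs_names

rng = np.random.default_rng(seed=42)
# random non-negative costs between batch '0' (rows) and batch '1' (columns)
cost_linear_01 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_1))))
cm_linear = pd.DataFrame(data=cost_linear_01, index=obs_names_0, columns=obs_names_1)
```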

Now we can set the custom cost matrix for the linear term of the subproblem mapping batch '0' to batch '1':

```python
# replace the linear cost of the ('0', '1') subproblem with the custom matrix
fgw["0", "1"].set_xy(cm_linear, tag="cost_matrix")
```

When the problem is solved, the custom cost matrix is used for the linear term of the problem mapping batch '0' to batch '1', while the problem mapping batch '1' to batch '2' still uses the information passed to `prepare()`.

## Cost matrices in the obsp layer

A second way to pass custom cost matrices is via `adata.obsp`. This is especially useful when saving and loading a model, since the cost matrices are then stored in the AnnData object itself. On the other hand, it can be harder to place each cost matrix in the correct part of the `obsp` matrix, so make sure that the rows and columns of every block follow the order of `adata.obs_names`. Below, we construct an `obsp` entry containing custom cost matrices for both the linear and the quadratic terms of both OT problems, mapping batch '0' to batch '1' and batch '1' to batch '2'.
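If the cells are not stored batch by batch, assembling the `obsp` matrix from positional blocks silently misplaces entries. One order-independent alternative, sketched here on a toy example, is to start from an all-zero `DataFrame` indexed by all cell names and write each block by label with `.loc`; its `.to_numpy()` then follows the row order of that index. The cell names, batch labels and the helper `random_block` are hypothetical, and this is a swapped-in technique, not the approach used in this tutorial.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(seed=0)

# hypothetical cells whose batches are interleaved rather than contiguous
obs = pd.DataFrame(
    {"batch": ["0", "1", "0", "2", "1", "2"]},
    index=[f"cell_{i}" for i in range(6)],
)
names = {b: obs.index[(obs["batch"] == b).to_numpy()] for b in ("0", "1", "2")}


def random_block(rows: pd.Index, cols: pd.Index) -> np.ndarray:
    """Hypothetical helper: non-negative random costs of shape (len(rows), len(cols))."""
    return np.abs(rng.normal(size=(len(rows), len(cols))))


# all-zero cell x cell matrix, labeled in the order of obs (i.e. of adata.obs_names)
full = pd.DataFrame(0.0, index=obs.index, columns=obs.index)

# linear terms of the problems ('0', '1') and ('1', '2'), written by label
lin_01 = random_block(names["0"], names["1"])
lin_12 = random_block(names["1"], names["2"])
full.loc[names["0"], names["1"]] = lin_01
full.loc[names["1"], names["2"]] = lin_12

# quadratic terms: one square block per batch
for b in ("0", "1", "2"):
    full.loc[names[b], names[b]] = random_block(names[b], names[b])

cost_matrices = full.to_numpy()  # rows and columns follow obs.index
print(cost_matrices.shape)
```

Because `full` is indexed by all cell names in their stored order, the resulting array lines up with `adata.obs_names` by construction, whatever the batch layout.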

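```python
# linear term for the problem mapping batch '1' to batch '2'
cost_linear_12 = np.abs(rng.normal(size=(len(obs_names_1), len(obs_names_2))))
# quadratic terms: one square cost matrix per batch
cost_quad_0 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_0))))
cost_quad_1 = np.abs(rng.normal(size=(len(obs_names_1), len(obs_names_1))))
cost_quad_2 = np.abs(rng.normal(size=(len(obs_names_2), len(obs_names_2))))

for cost in (cost_linear_01, cost_linear_12, cost_quad_0, cost_quad_1, cost_quad_2):
    print(cost.shape)

# assemble the (n_obs, n_obs) matrix; assumes adata stores the cells batch by batch
# ('0', then '1', then '2'), so that the blocks line up with adata.obs_names
n0, n1, n2 = len(obs_names_0), len(obs_names_1), len(obs_names_2)
blocks = [
    [cost_quad_0, cost_linear_01, np.zeros((n0, n2))],
    [np.zeros((n1, n0)), cost_quad_1, cost_linear_12],
    [np.zeros((n2, n0)), np.zeros((n2, n1)), cost_quad_2],
]
cost_matrices = np.block(blocks)
print(cost_matrices.shape)
adata.obsp["cost_matrices"] = cost_matrices
```

```
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(60, 60)
```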
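Since every subproblem reads its costs from one shared matrix, it can be worth checking that each block sits where it should. For a problem mapping batch `src` to batch `tgt`, the linear term corresponds to the rows of `src` cells and the columns of `tgt` cells, and the two quadratic terms to the diagonal blocks (`src`, `src`) and (`tgt`, `tgt`). The sketch below slices those blocks back out with boolean masks on the batch labels; it uses small hypothetical arrays in place of `adata`, and the helper `get_block` is ours. It illustrates the layout described in this section, not moscot's internal extraction code.

```python
import numpy as np

# hypothetical batch labels, in the order of the rows/columns of the obsp matrix
batch = np.array(["0", "0", "1", "1", "1", "2", "2"])
n0, n1, n2 = ((batch == b).sum() for b in ("0", "1", "2"))

# toy blocks: linear terms between consecutive batches, constant quadratic blocks per batch
lin_01 = np.arange(n0 * n1, dtype=float).reshape(n0, n1)
lin_12 = np.arange(n1 * n2, dtype=float).reshape(n1, n2) + 100.0
quad = {b: np.full((n, n), 1000.0 + i) for i, (b, n) in enumerate(zip("012", (n0, n1, n2)))}

full = np.block([
    [quad["0"], lin_01, np.zeros((n0, n2))],
    [np.zeros((n1, n0)), quad["1"], lin_12],
    [np.zeros((n2, n0)), np.zeros((n2, n1)), quad["2"]],
])


def get_block(matrix: np.ndarray, labels: np.ndarray, src: str, tgt: str) -> np.ndarray:
    """Hypothetical helper: rows of `src` cells, columns of `tgt` cells."""
    return matrix[np.ix_(labels == src, labels == tgt)]


# for the problem ('1', '2'): linear term, source quadratic term, target quadratic term
for rows, cols in (("1", "2"), ("1", "1"), ("2", "2")):
    print((rows, cols), get_block(full, batch, rows, cols).shape)
```

Applied to this tutorial, the same check would slice `adata.obsp["cost_matrices"]` with `adata.obs["batch"]` as labels and compare, for example, the ('0', '1') block against `cost_linear_01`.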

In `prepare()`, we then need to specify where the custom cost matrices are stored. To use only the custom linear cost matrix, we pass `joint_attr` as a dictionary that points to `adata.obsp["cost_matrices"]` via `key` and marks it as a cost matrix via `tag`:

```python
joint_attr = {"key": "cost_matrices", "tag": "cost_matrix"}
fgw = fgw.prepare(key="batch", joint_attr=joint_attr, x_attr="X_pca", y_attr="X_pca")
```

To use only the custom quadratic cost matrices, we instead modify `x_attr` and `y_attr` and keep `joint_attr="X_pca"` for the linear term:

```python
x_attr = {
    "attr": "obsp",
    "key": "cost_matrices",
    "tag": "cost_matrix",
    "cost": "custom",
}
y_attr = {
    "attr": "obsp",
    "key": "cost_matrices",
    "tag": "cost_matrix",
    "cost": "custom",
}
fgw = fgw.prepare(key="batch", joint_attr="X_pca", x_attr=x_attr, y_attr=y_attr)
fgw[("0", "1")]
```

```
OTProblem[stage='prepared', shape=(20, 20)]
```

To use custom cost matrices for all three terms, we pass all three dictionaries:

```python
fgw = fgw.prepare(key="batch", joint_attr=joint_attr, x_attr=x_attr, y_attr=y_attr)
fgw[("1", "2")]
```

```
OTProblem[stage='prepared', shape=(20, 20)]
```