Passing callbacks in prepare()

This example shows how to pass different callbacks to prepare().

A callback specifies which computation is run on X to obtain the joint cost when the problem is prepared. Callbacks can be set separately for the linear term (xy_callback) and for the quadratic terms (x_callback, y_callback).

Imports and data loading

import warnings

warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)

from moscot import datasets
from moscot.problems.space import MappingProblem
from moscot.utils.tagged_array import TaggedArray

import numpy as np
import pandas as pd
from sklearn.decomposition import SparsePCA

import anndata
import scanpy as sc
adata_sc = datasets.drosophila(spatial=False)
adata_sp = datasets.drosophila(spatial=True)
adata_sc, adata_sp
(AnnData object with n_obs × n_vars = 1297 × 2000
     obs: 'n_counts'
     var: 'n_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
     uns: 'hvg', 'log1p', 'pca'
     obsm: 'X_pca'
     varm: 'PCs'
     layers: 'counts',
 AnnData object with n_obs × n_vars = 3039 × 82
     obs: 'n_counts'
     var: 'n_counts'
     uns: 'log1p', 'pca'
     obsm: 'X_pca', 'spatial'
     varm: 'PCs'
     layers: 'counts')

Spatial normalization

With normalize_spatial=True, which is the default, the spatial coordinates are standardized. The standard deviation of the prepared data confirms this:

mp = MappingProblem(adata_sc=adata_sc, adata_sp=adata_sp)
mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, normalize_spatial=True)
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Normalizing spatial coordinates of `x`.                                                                   
mp[("src", "tgt")].x.data_src.std()
1.0000000000000002
mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, normalize_spatial=False)
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
mp[("src", "tgt")].x.data_src.std()
66.97163996056013

The normalize_spatial argument effectively uses the "spatial-norm" callback.
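
As a rough illustration of what standardizing coordinates means, the sketch below shifts and scales an array so that it has zero mean and unit standard deviation overall. It assumes a single global mean and standard deviation, which is consistent with the `.std()` value of 1 shown above but is not taken from moscot's source. The `standardize` helper and the random coordinates are made up for this example.

```python
import numpy as np


def standardize(coords: np.ndarray) -> np.ndarray:
    """Shift and scale so the array as a whole has mean 0 and std 1."""
    coords = np.asarray(coords, dtype=float)
    # Assumption: one global mean/std over all entries, not per axis.
    return (coords - coords.mean()) / coords.std()


rng = np.random.default_rng(0)
spatial = rng.uniform(0.0, 200.0, size=(100, 3))  # made-up coordinates
normalized = standardize(spatial)
print(normalized.mean(), normalized.std())
```

Standardizing puts the spatial term on a scale that does not depend on the original coordinate units.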

Passing callbacks

PCA computation in gene space

To create a joint PCA embedding of the two datasets in gene space, pass xy_callback="local-pca". The PCA is then computed on X for each pair of distributions.

mp = mp.prepare(
    sc_attr={"attr": "obsm", "key": "X_pca"},
    xy_callback="local-pca",
)
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Normalizing spatial coordinates of `x`.                                                                   

The callback creates a point cloud that contains PCA projections of the data.
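
The idea of a joint embedding can be sketched with plain NumPy: stack both matrices, fit one PCA on the stacked data, and slice the projections back into a source and a target part. This is an illustrative sketch under the assumption that both matrices share the same columns (genes); it is not moscot's implementation, and `joint_pca` and the random inputs are hypothetical.

```python
import numpy as np


def joint_pca(src: np.ndarray, tgt: np.ndarray, n_comps: int = 30):
    """Project two matrices with shared columns onto common principal axes."""
    stacked = np.vstack([src, tgt]).astype(float)
    centered = stacked - stacked.mean(axis=0)
    # Economy SVD: the rows of `vt` are the principal axes.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    n_comps = min(n_comps, vt.shape[0])
    projected = centered @ vt[:n_comps].T
    n_src = src.shape[0]
    # Rows keep their stacking order, so slicing recovers each dataset.
    return projected[:n_src], projected[n_src:]


rng = np.random.default_rng(0)
src = rng.normal(size=(50, 20))  # made-up "source" cells x genes
tgt = rng.normal(size=(80, 20))  # made-up "target" cells x genes
src_emb, tgt_emb = joint_pca(src, tgt, n_comps=5)
print(src_emb.shape, tgt_emb.shape)
```

Because both parts are projected onto the same axes, distances between source and target points are meaningful, which is what a point-cloud cost for the linear term requires.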

mp[("src", "tgt")].xy.tag
<Tag.POINT_CLOUD: 'point_cloud'>
mp.solve()
mp.solutions
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(3039, 1297)].                                          
{('src', 'tgt'): OTTOutput[shape=(3039, 1297), cost=1.6444, converged=True]}

Using geodesic costs

To use geodesic costs defined on a graph, we can create the underlying graph (here in gene expression space) using xy_callback="graph-construction". Note that the cost has to be set explicitly.

mp = mp.prepare(
    sc_attr={"attr": "obsm", "key": "X_pca"},
    normalize_spatial=False,
    xy_callback="graph-construction",
    cost={"xy": "geodesic", "x": "sq_euclidean", "y": "sq_euclidean"},
)
INFO     Computing graph construction for `xy` using `X_pca`                                                       

We can then verify that a graph has been constructed:

mp[("src", "tgt")].xy.tag
<Tag.GRAPH: 'graph'>
mp.solve()
mp.solutions
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(3039, 1297)].                                          
{('src', 'tgt'): OTTOutput[shape=(3039, 1297), cost=1.3147, converged=True]}
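
For intuition only: a geodesic distance on a graph can be read as the length of the shortest path between two nodes. The sketch below computes such distances with Dijkstra's algorithm on a tiny, made-up weighted graph. It is not how moscot or its backend evaluates the "geodesic" cost, which may rely on a different approximation; `shortest_path_lengths` and the example graph are hypothetical.

```python
import heapq


def shortest_path_lengths(graph, source):
    """Dijkstra's algorithm on a dict-of-dicts graph with non-negative weights."""
    dist = {source: 0.0}
    heap = [(0.0, source)]
    while heap:
        d, node = heapq.heappop(heap)
        if d > dist.get(node, float("inf")):
            continue  # stale heap entry, a shorter path was already found
        for neighbor, weight in graph.get(node, {}).items():
            new_d = d + weight
            if new_d < dist.get(neighbor, float("inf")):
                dist[neighbor] = new_d
                heapq.heappush(heap, (new_d, neighbor))
    return dist


# Made-up undirected graph: node -> {neighbor: edge weight}
graph = {
    "a": {"b": 1.0, "c": 4.0},
    "b": {"a": 1.0, "c": 1.5},
    "c": {"a": 4.0, "b": 1.5, "d": 2.0},
    "d": {"c": 2.0},
}
distances = shortest_path_lengths(graph, "a")
print(distances)
```

Note that the path from "a" to "c" through "b" is shorter than the direct edge, which is the kind of structure a Euclidean cost on coordinates cannot express.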

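Alternatively, a custom graph can be supplied with set_graph_xy(). Here the graph is a neighborhood graph computed over the concatenated datasets: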

adata_concat = anndata.concat([adata_sp, adata_sc])
sc.pp.neighbors(adata_concat, use_rep="X_pca")
df_graph = pd.DataFrame(
    index=adata_concat.obs_names,
    columns=adata_concat.obs_names,
    data=adata_concat.obsp["connectivities"].toarray().astype("float64"),
)
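
To make the structure of such a graph concrete, here is a minimal NumPy sketch of a symmetric k-nearest-neighbor adjacency matrix. It is a simplification: the `connectivities` matrix computed by `sc.pp.neighbors` holds weighted entries rather than 0/1 values, and `knn_adjacency` and the toy points are hypothetical.

```python
import numpy as np


def knn_adjacency(points: np.ndarray, k: int) -> np.ndarray:
    """Binary, symmetrized k-nearest-neighbor adjacency matrix."""
    points = np.asarray(points, dtype=float)
    diff = points[:, None, :] - points[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # a point is not its own neighbor
    nearest = np.argsort(dist, axis=1)[:, :k]
    adj = np.zeros(dist.shape, dtype=float)
    rows = np.repeat(np.arange(len(points)), k)
    adj[rows, nearest.ravel()] = 1.0
    # Keep an edge if either endpoint lists the other as a neighbor.
    return np.maximum(adj, adj.T)


points = np.array([[0.0], [1.0], [2.5], [10.0]])  # made-up 1-D points
adj = knn_adjacency(points, k=1)
print(adj)
```

The result is a square matrix with one row and one column per observation, the same layout as the DataFrame built above from the concatenated datasets.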

First, the problem is prepared with the default ("sq_euclidean") cost, and it is then overwritten by set_graph_xy():

mp = mp.prepare(
    sc_attr={"attr": "obsm", "key": "X_pca"},
    normalize_spatial=False,
)
mp[("src", "tgt")].set_graph_xy(df_graph, cost="geodesic")
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
mp[("src", "tgt")].xy.tag
<Tag.GRAPH: 'graph'>
mp.solve()
mp.solutions
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(3039, 1297)].                                          
{('src', 'tgt'): OTTOutput[shape=(3039, 1297), cost=1.3147, converged=True]}

Custom callback function

A callable can also be passed as a custom callback. This example uses the scikit-learn SparsePCA estimator.

The callback function receives term: Literal["xy", "x", "y"], problem.adata_src, problem.adata_tgt as arguments, as well as any keyword arguments passed in xy_callback_kwargs. It should return a moscot.utils.tagged_array.TaggedArray.

mp = mp.prepare(
    sc_attr={"attr": "obsm", "key": "X_pca"},
    normalize_spatial=False,
    xy_callback=lambda term, src, tgt: TaggedArray(
        *np.split(
            SparsePCA().fit_transform(np.vstack([src.X.toarray(), tgt.X.toarray()])),
            [src.shape[0]],
        )
    ),
)
mp[("src", "tgt")].xy.tag
<Tag.POINT_CLOUD: 'point_cloud'>
mp.solve()
mp.solutions
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(3039, 1297)].                                          
{('src', 'tgt'): OTTOutput[shape=(3039, 1297), cost=1.6591, converged=True]}
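
The custom callback in this section follows a stack, transform, split pattern: both matrices are stacked row-wise, transformed together, and split again at the number of source rows. The sketch below isolates that pattern with NumPy only. `embed_jointly` is a hypothetical helper, and the column-centering lambda merely stands in for a fitted estimator such as SparsePCA.

```python
import numpy as np


def embed_jointly(src: np.ndarray, tgt: np.ndarray, transform):
    """Apply `transform` to the stacked rows and split the result back."""
    stacked = np.vstack([src, tgt])
    embedded = transform(stacked)
    # Split at the source row count: first part is source, second is target.
    src_part, tgt_part = np.split(embedded, [src.shape[0]])
    return src_part, tgt_part


src = np.arange(6.0).reshape(3, 2)        # made-up source data, 3 rows
tgt = np.arange(6.0, 10.0).reshape(2, 2)  # made-up target data, 2 rows
src_part, tgt_part = embed_jointly(src, tgt, lambda x: x - x.mean(axis=0))
print(src_part.shape, tgt_part.shape)
```

The pattern only works when the transform preserves the number and order of rows, which holds for `fit_transform` of scikit-learn decomposition estimators.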