Sequential alignment

In this example we focus on how to use different policies in the AlignmentProblem to produce specific alignments of spatial samples.

Preliminaries

import warnings

warnings.simplefilter(action="ignore", category=UserWarning)


from moscot import datasets
from moscot.problems.space import AlignmentProblem

import squidpy as sq

Simulate data using sim_align().

adata = datasets.sim_align()
adata
AnnData object with n_obs × n_vars = 1200 × 500
    obs: 'batch'
    uns: 'batch_colors'
    obsm: 'spatial'

The adata consists of 3 different slides (batches), each having 400 cells.

sq.pl.spatial_scatter(adata, shape=None, library_id="batch", color="batch")
[Figure: spatial scatter of the three slides, colored by batch]

In the AlignmentProblem, there are two available policies: SequentialPolicy and StarPolicy.
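To make the difference concrete, the batch pairs that each policy solves can be sketched in plain Python. The helpers `star_pairs` and `sequential_pairs` are hypothetical names used only for illustration and are not part of moscot. They merely reproduce the (source, target) keys that show up in `ap.solutions` in the sections that follow, assuming the sequential order is the sorted order of the batch labels.

```python
def star_pairs(batches, reference):
    """Pair every non-reference batch with the reference (illustrative helper, not moscot API)."""
    return [(b, reference) for b in batches if b != reference]


def sequential_pairs(batches):
    """Pair each batch with its successor (illustrative helper, not moscot API).

    Assumption: the sequence follows the sorted order of the batch labels.
    """
    ordered = sorted(batches)
    return list(zip(ordered[:-1], ordered[1:]))


batches = ["0", "1", "2"]
print(star_pairs(batches, reference="2"))  # [('0', '2'), ('1', '2')]
print(sequential_pairs(batches))           # [('0', '1'), ('1', '2')]
```

With three batches both policies solve two problems, but only the star policy contains a direct map from batch 0 to batch 2.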

Star policy

With StarPolicy, a transport map is computed from every other batch to a single reference batch. The reference therefore has to be specified in prepare().

ap = AlignmentProblem(adata=adata)
ap = ap.prepare(batch_key="batch", policy="star", reference="2")
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Normalizing spatial coordinates of `x`.                                                                   
INFO     Normalizing spatial coordinates of `y`.                                                                   
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Normalizing spatial coordinates of `x`.                                                                   
INFO     Normalizing spatial coordinates of `y`.                                                                   
ap = ap.solve()
ap.solutions
INFO     Solving `2` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(400, 400)].                                            
INFO     Solving problem OTProblem[stage='prepared', shape=(400, 400)].                                            
{('1', '2'): OTTOutput[shape=(400, 400), cost=1.1172, converged=True],
 ('0', '2'): OTTOutput[shape=(400, 400), cost=0.2524, converged=True]}

Aligning the slides to one reference

When we call align() on a problem prepared with policy="star", all batches are aligned to the reference chosen for the policy:

ap.align(key_added="star_warp")
sq.pl.spatial_scatter(
    adata,
    shape=None,
    spatial_key="star_warp",
    library_id="batch",
    color="batch",
    title="Alignment 0 -> 2, 1 -> 2",
)
[Figure: Alignment 0 -> 2, 1 -> 2]
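One common way to turn a transport map into aligned coordinates is barycentric projection: each source cell is moved to the coupling-weighted average of the reference coordinates. The NumPy sketch below illustrates that idea on a toy coupling. Whether align() uses exactly this formula internally is an assumption here, and `barycentric_projection` is a hypothetical helper, not a moscot function.

```python
import numpy as np


def barycentric_projection(coupling, ref_coords):
    """Map each source cell to the coupling-weighted mean of the reference coordinates."""
    row_mass = coupling.sum(axis=1, keepdims=True)  # mass sent out by each source cell
    return (coupling @ ref_coords) / row_mass


# Toy coupling between 3 source cells (rows) and 3 reference cells (columns).
coupling = np.array(
    [
        [0.0, 1.0, 0.0],  # source cell 0 maps entirely to reference cell 1
        [0.5, 0.0, 0.5],  # source cell 1 is split between reference cells 0 and 2
        [0.0, 0.0, 1.0],  # source cell 2 maps entirely to reference cell 2
    ]
) / 3.0
ref_coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])

warped = barycentric_projection(coupling, ref_coords)
print(warped)
```

In this toy case, source cell 1 lands halfway between reference cells 0 and 2, which is what a soft (non-permutation) coupling produces.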

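Sequential policy

With SequentialPolicy, transport maps are computed only between consecutive batches, here 0 -> 1 and 1 -> 2, so no reference is passed to prepare().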

ap = AlignmentProblem(adata=adata)
ap = ap.prepare(batch_key="batch", policy="sequential")
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Normalizing spatial coordinates of `x`.                                                                   
INFO     Normalizing spatial coordinates of `y`.                                                                   
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Normalizing spatial coordinates of `x`.                                                                   
INFO     Normalizing spatial coordinates of `y`.                                                                   
ap = ap.solve()
ap.solutions
INFO     Solving `2` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(400, 400)].                                            
INFO     Solving problem OTProblem[stage='prepared', shape=(400, 400)].                                            
{('1', '2'): OTTOutput[shape=(400, 400), cost=1.1172, converged=True],
 ('0', '1'): OTTOutput[shape=(400, 400), cost=1.0933, converged=True]}

Aligning the slides sequentially

When we call align() with a single reference, here reference="2", all of the slides are aligned to that reference:

ap.align(reference="2", key_added="seq_warp")
sq.pl.spatial_scatter(
    adata,
    shape=None,
    spatial_key="seq_warp",
    library_id="batch",
    color="batch",
    title="Alignment 0 -> 1 -> 2",
)
[Figure: Alignment 0 -> 1 -> 2]

Under the sequential policy, no transport map is computed directly from batch 0 to batch 2, so the alignment 0 -> 2 is obtained by chaining the two transport maps 0 -> 1 and 1 -> 2.
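Chaining can be illustrated with two small coupling matrices. The sketch below composes a coupling 0 -> 1 with a coupling 1 -> 2 by rescaling the second one into conditional probabilities over the intermediate batch. This is one standard way to compose couplings and is shown under the assumption that both couplings share the same marginal on batch 1. It is not taken from moscot's implementation, and `chain_couplings` is a hypothetical helper.

```python
import numpy as np


def chain_couplings(p_ab, p_bc):
    """Compose two couplings through their shared intermediate batch.

    Assumes every intermediate cell receives positive mass from p_ab and
    that the rows of p_bc sum to the column sums of p_ab.
    """
    b = p_ab.sum(axis=0)  # mass that p_ab places on each intermediate cell
    # Divide row j of p_bc by b[j] to get conditionals, then propagate the mass.
    return p_ab @ (p_bc / b[:, None])


p_01 = np.array([[0.3, 0.2], [0.1, 0.4]])  # toy coupling 0 -> 1, column sums (0.4, 0.6)
p_12 = np.array([[0.4, 0.0], [0.1, 0.5]])  # toy coupling 1 -> 2, row sums (0.4, 0.6)

p_02 = chain_couplings(p_01, p_12)
print(p_02)
print(p_02.sum(axis=1))  # marginal on batch 0, same as p_01.sum(axis=1)
print(p_02.sum(axis=0))  # marginal on batch 2, same as p_12.sum(axis=0)
```

The composed matrix keeps the source marginal of the first coupling and the target marginal of the second, so it is again a valid coupling between batch 0 and batch 2.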

Now, if we want to see the mapping of only the directly sequential batches, there are two options.

First, we can use align() multiple times, changing the reference batch and subsetting adata for plotting if needed.

ap.align(reference="1", key_added="align_1")
sq.pl.spatial_scatter(
    adata[adata.obs["batch"].isin(("0", "1"))],
    shape=None,
    spatial_key="align_1",
    library_id="batch",
    color="batch",
    title="Alignment 0 -> 1",
)
[Figure: Alignment 0 -> 1]
ap.align(reference="2", key_added="align_2")
sq.pl.spatial_scatter(
    adata[adata.obs["batch"].isin(("1", "2"))],
    shape=None,
    spatial_key="align_2",
    library_id="batch",
    color="batch",
    title="Alignment 1 -> 2",
)
[Figure: Alignment 1 -> 2]

The second option is to instantiate a separate AlignmentProblem for each pair of relevant slides and thus compute the pairwise alignments independently.

ap01 = AlignmentProblem(adata=adata[adata.obs["batch"].isin(("0", "1"))])
ap01 = ap01.prepare(batch_key="batch")
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Normalizing spatial coordinates of `x`.                                                                   
INFO     Normalizing spatial coordinates of `y`.                                                                   
ap01 = ap01.solve()
ap01.solutions
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(400, 400)].                                            
{('0', '1'): OTTOutput[shape=(400, 400), cost=1.0933, converged=True]}
ap12 = AlignmentProblem(adata=adata[adata.obs["batch"].isin(("1", "2"))])
ap12 = ap12.prepare(batch_key="batch")
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`                                                  
INFO     Normalizing spatial coordinates of `x`.                                                                   
INFO     Normalizing spatial coordinates of `y`.                                                                   
ap12 = ap12.solve()
ap12.solutions
INFO     Solving `1` problems                                                                                      
INFO     Solving problem OTProblem[stage='prepared', shape=(400, 400)].                                            
{('1', '2'): OTTOutput[shape=(400, 400), cost=1.1172, converged=True]}
ap01.align(reference="1", key_added="align")
sq.pl.spatial_scatter(
    ap01.adata,
    shape=None,
    spatial_key="align",
    library_id="batch",
    color="batch",
    title="Alignment 0 -> 1",
)
[Figure: Alignment 0 -> 1]
ap12.align(reference="2", key_added="align")
sq.pl.spatial_scatter(
    ap12.adata,
    shape=None,
    spatial_key="align",
    library_id="batch",
    color="batch",
    title="Alignment 1 -> 2",
)
[Figure: Alignment 1 -> 2]