Handling marginals

This example shows how to explicitly pass the marginals.

Marginals define the weight of each cell within a distribution of cells. In many cases, marginals are chosen to be uniform because all cells are considered equally important. In other cases, prior knowledge justifies adapting the marginals.

For example, score_genes_for_marginals() computes the marginals such that cells expressing proliferation marker genes receive a higher weight, since they are assumed to have multiple descendants. As another example, cells that are outliers in space can be given a lower weight so that they do not influence the mapping too much. Here, we demonstrate how to pass marginals using the AlignmentProblem.

Imports and data loading

from moscot import datasets
from moscot.problems.space import AlignmentProblem

import numpy as np

Simulate data using sim_align().

adata = datasets.sim_align()
adata
AnnData object with n_obs × n_vars = 1200 × 500
    obs: 'batch'
    uns: 'batch_colors'
    obsm: 'spatial'

Uniform marginals

ap = AlignmentProblem(adata)
ap = ap.prepare(batch_key="batch", policy="sequential")
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`

If marginals are not specified, they are assumed to be uniform.

ap["0", "1"].a[:5], ap["0", "1"].b[:5]
(array([0.0025, 0.0025, 0.0025, 0.0025, 0.0025]),
 array([0.0025, 0.0025, 0.0025, 0.0025, 0.0025]))
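The value 0.0025 is consistent with a uniform weight of 1/n for n = 400 cells per distribution. The cell count is inferred from the printed values and the dataset size, not stated explicitly above. A minimal NumPy sketch of uniform marginals, independent of moscot:

```python
import numpy as np

# Assumed number of cells in one batch, inferred from 1 / 0.0025.
n_cells = 400

# Uniform marginals: every cell gets the same weight 1 / n_cells,
# so the weights of one distribution sum to 1.
uniform = np.full(n_cells, 1.0 / n_cells)

print(uniform[:5])
print(uniform.sum())
```

With uniform weights, no cell is favored over another when mass is transported between the two distributions.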

User-defined marginals

To specify the marginals, we store them in obs. Suppose we want to assign less weight to cell 400 in the source distribution.

adata.obs["source_marginals"] = np.ones(adata.n_obs)
adata.obs.loc["400", "source_marginals"] = 0.5
adata.obs.head()
     batch  source_marginals
400      0               0.5
401      0               1.0
402      0               1.0
403      0               1.0
404      0               1.0

Similarly, we want to assign less weight to cell 397-1 in the target distribution.

adata.obs["target_marginals"] = np.ones(adata.n_obs)
adata.obs.loc["397-1", "target_marginals"] = 0.5
adata.obs.tail()
       batch  source_marginals  target_marginals
395-1      2               1.0               1.0
396-1      2               1.0               1.0
397-1      2               1.0               0.5
398-1      2               1.0               1.0
399-1      2               1.0               1.0

We then pass the names of these obs columns to prepare() via the a and b arguments.

ap2 = AlignmentProblem(adata)
ap2 = ap2.prepare(batch_key="batch", a="source_marginals", b="target_marginals")
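INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`
INFO     Computing pca with `n_comps=30` for `xy` using `adata.X`

The source marginals of the first problem and the target marginals of the last problem now contain the user-defined weights.

ap2["0", "1"].a[:5], ap2["1", "2"].b[-5:]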
(array([0.5, 1. , 1. , 1. , 1. ]), array([1. , 1. , 0.5, 1. , 1. ]))

Note that cell 397-1 belongs to batch 2, so it never appears in a source distribution under the sequential policy. Similarly, cells belonging to batch 0 are never part of a target distribution. Also note that the scale of the marginals influences the convergence criterion. We therefore recommend normalizing the marginals to sum to 1.
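One way to follow the normalization recommendation is to rescale the weights within each batch before passing them on. The sketch below uses a small hypothetical DataFrame as a stand-in for adata.obs. The column name source_marginals_norm is made up for illustration, and normalizing per batch assumes that each batch forms one distribution, as in the sequential setup above.

```python
import numpy as np
import pandas as pd

# Toy stand-in for adata.obs: two batches of four cells each (hypothetical data).
obs = pd.DataFrame(
    {
        "batch": ["0"] * 4 + ["1"] * 4,
        "source_marginals": [0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    }
)

# Total weight of the batch each cell belongs to, aligned with the rows of obs.
batch_totals = obs.groupby("batch")["source_marginals"].transform("sum")

# Divide each weight by its batch total so that every batch sums to 1.
obs["source_marginals_norm"] = obs["source_marginals"] / batch_totals

print(obs.groupby("batch")["source_marginals_norm"].sum())
```

Applied to the real adata.obs, the normalized column could then presumably be passed in the same way as above, for example a="source_marginals_norm". The relative weights within a batch are unchanged; only their overall scale is fixed.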