Analyzing HSPCs with the TemporalProblem#

In this tutorial, we showcase using gene expression (GEX) and chromatin accessibility (ATAC) to analyze the trajectory of HSPC cells on a time-resolved dataset. The method builds upon [Schiebinger et al., 2019].

Mapping cells across time points#

Given measurements of any modality across time, we can reconstruct trajectories by mapping cells across time points. We assume that the same measurement is available at each time point, which results in a Wasserstein-type (W-type) OT problem.

In its simplest form, we use unbalanced [Chizat et al., 2018] entropic OT [Cuturi, 2013]. The unbalanced formulation accounts for unbalancedness between time points, while the entropic regularization speeds up computations and improves the statistical properties of the solution [Peyré et al., 2019]. We use the following definitions:

  • \(a\) and \(b\): marginal distributions over early and late cells, respectively, representing any prior knowledge, including growth and death rates.

  • \(\varepsilon\): weight given to entropic regularization. Larger values lead to more “blurred” couplings.

  • \(\tau_a\): unbalancedness parameter in the source marginals. \(\tau_a=1\) corresponds to the fully balanced case, while \(\tau_a < 1\) allows unbalancedness, e.g. a cell might have no descendant because it dies, or a cell could have multiple descendants because it proliferates. Note that this is not the same as using prior information about the marginals obtained from the expression of proliferation and apoptosis markers. We normally combine both approaches, but if prior estimates are available, \(\tau_a\) can be chosen closer to 1.

  • \(\tau_b\): unbalancedness parameter in the target marginals. Analogously defined to \(\tau_a\).
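
To make the roles of \(\varepsilon\), \(\tau_a\) and \(\tau_b\) more concrete, here is a minimal NumPy sketch of unbalanced entropic OT using generalized Sinkhorn iterations. It is not how moscot solves the problem internally; the helper name `unbalanced_sinkhorn`, the toy point clouds and all parameter values are assumptions chosen for illustration only.

```python
import numpy as np


def unbalanced_sinkhorn(cost, a, b, epsilon=0.5, tau_a=1.0, tau_b=1.0, n_iter=2000):
    """Illustrative unbalanced entropic OT solver (hypothetical helper, not moscot code)."""
    kernel = np.exp(-cost / epsilon)  # Gibbs kernel; larger epsilon gives a blurrier coupling
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        # tau = 1 enforces the marginal exactly; tau < 1 only penalizes deviations from it
        u = (a / (kernel @ v)) ** tau_a
        v = (b / (kernel.T @ u)) ** tau_b
    return u[:, None] * kernel * v[None, :]


rng = np.random.default_rng(0)
early = rng.normal(size=(5, 2))  # toy "early" cells
late = rng.normal(size=(6, 2))  # toy "late" cells
cost = ((early[:, None, :] - late[None, :, :]) ** 2).sum(axis=-1)
cost = cost / cost.mean()  # normalize the cost by its mean

a = np.full(5, 1 / 5)  # uniform source marginal
b = np.full(6, 1 / 6)  # uniform target marginal

balanced = unbalanced_sinkhorn(cost, a, b, tau_a=1.0, tau_b=1.0)
relaxed = unbalanced_sinkhorn(cost, a, b, tau_a=0.9, tau_b=0.9)

print("balanced row sums:", balanced.sum(axis=1))
print("relaxed row sums: ", relaxed.sum(axis=1))
```

In the balanced run the row sums reproduce \(a\), whereas in the relaxed run individual cells may send more or less mass than their prior marginal prescribes.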


import warnings
from typing import List, Literal, Optional, Tuple

import moscot as mt
import moscot.plotting as mtp
from moscot.problems.time import TemporalProblem
from tqdm.std import TqdmWarning

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt

import scanpy as sc

warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", TqdmWarning)

Dataset description#

The hspc() dataset is a subset of the data from the Open Problems - Multimodal Single-Cell Integration NeurIPS 2022 competition and includes, among others:

Cell type annotation is very coarse in this dataset and was based only on gene expression, resulting in the following annotations:

  • MasP = Mast Cell Progenitors

  • MkP = Megakaryocyte Progenitors

  • NeuP = Neutrophil Progenitors

  • MoP = Monocyte Progenitors

  • EryP = Erythrocyte Progenitors

  • HSC = Hematopoietic Stem Cells

  • BP = B-Cell Progenitors

adata = mt.datasets.hspc()
AnnData object with n_obs × n_vars = 4000 × 2000
    obs: 'day', 'donor', 'cell_type', 'technology', 'n_genes'
    var: 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'cell_type_colors', 'hvg', 'neighbors', 'neighbors_atac', 'pca', 'umap'
    obsm: 'X_lsi', 'X_pca', 'X_umap_ATAC', 'X_umap_GEX', 'peaks_tfidf'
    varm: 'PCs'
    obsp: 'connectivities', 'distances', 'neighbors_atac_connectivities', 'neighbors_atac_distances'

Visualise the data#

First, we consider the cell type annotations in a UMAP which was computed based on gene expression.

sc.pl.embedding(adata, basis="X_umap_GEX", color=["cell_type", "day"])

We can also plot a UMAP which was computed based on the ATAC data.

sc.pl.embedding(adata, basis="X_umap_ATAC", color=["cell_type", "day"])

Creating a latent space incorporating information from both GEX and ATAC#

As we want to leverage information from both modalities, we need to find a shared space of GEX and ATAC. There are different ways to obtain this, e.g., via scvi-tools models. Here, we simply concatenate the PCA-reduced GEX space and the LSI-reduced ATAC space and run PCA on the joint space.

As we want both modalities to have a comparable influence on the transport map, we need to make sure that the total variances are comparable.

np.var(adata.obsm["X_lsi"]), np.var(adata.obsm["X_pca"])
(1.0000002, 39.049347)

Standardization to unit variance is part of LSI, hence we only adapt the PCA-embedding of the gene expression.

adata.obsm["X_pca_scaled"] = StandardScaler().fit_transform(adata.obsm["X_pca"])
adata.obsm["X_shared"] = sc.pp.pca(
    np.concatenate((adata.obsm["X_pca_scaled"], adata.obsm["X_lsi"]), axis=1),
)
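
The effect of the rescaling can be checked on synthetic data. The sketch below uses random arrays as stand-ins for the two embeddings (their shapes and scales are assumptions, not the values of the real dataset) and standardizes each component in the same way `StandardScaler` does by default.

```python
import numpy as np

rng = np.random.default_rng(0)
gex_pca = rng.normal(scale=6.0, size=(100, 5))  # stand-in for X_pca (large variance)
atac_lsi = rng.normal(scale=1.0, size=(100, 4))  # stand-in for X_lsi (roughly unit variance)

# Standardize each GEX component to zero mean and unit variance
gex_scaled = (gex_pca - gex_pca.mean(axis=0)) / gex_pca.std(axis=0)

# Concatenate the two modalities into one joint feature matrix
joint = np.concatenate((gex_scaled, atac_lsi), axis=1)

print("variance before scaling:", np.var(gex_pca))
print("variance after scaling: ", np.var(gex_scaled))
print("ATAC variance:          ", np.var(atac_lsi))
```

After scaling, neither modality dominates the distances computed on the joint space simply because of its numerical range.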

Prepare the TemporalProblem#

First, we instantiate the TemporalProblem.

tp = TemporalProblem(adata)

We can adapt the marginals of the OT problem. By default, marginals are chosen to be uniform. This means that, even if each cell in the early time point has multiple possible descendants in the later time point, the proportions of the descendants sum up to “the mass of one cell”. In developing systems we know that some cells might proliferate more than others. We can account for this by adapting the marginals based on proliferation and apoptosis markers with score_genes_for_marginals().

tp = tp.score_genes_for_marginals(
    gene_set_proliferation="human", gene_set_apoptosis="human"
)
WARNING: genes are not in var_names and ignored: ['ANLN', 'ANP32E', 'ATAD2', 'AURKA', 'AURKB', 'BIRC5', 'BLM', 'BRIP1', 'BUB1', 'CASP8AP2', 'CBX5', 'CCNB2', 'CCNE2', 'CDC20', 'CDC25C', 'CDC45', 'CDC6', 'CDCA2', 'CDCA3', 'CDCA7', 'CDCA8', 'CDK1', 'CENPA', 'CENPF', 'CHAF1B', 'CKAP2', 'CKAP2L', 'CKAP5', 'CKS1B', 'CKS2', 'CLSPN', 'CTCF', 'DLGAP5', 'DSCC1', 'DTL', 'E2F8', 'ECT2', 'EXO1', 'FAM64A', 'FEN1', 'G2E3', 'GAS2L3', 'GINS2', 'GMNN', 'GTSE1', 'HELLS', 'HJURP', 'HMGB2', 'HMMR', 'HN1', 'KIF11', 'KIF20B', 'KIF23', 'KIF2C', 'LBR', 'MCM2', 'MCM4', 'MCM5', 'MCM6', 'MLF1IP', 'MSH2', 'NASP', 'NCAPD2', 'NDC80', 'NEK2', 'NUF2', 'NUSAP1', 'PCNA', 'POLA1', 'POLD3', 'PRIM1', 'PSRC1', 'RAD51', 'RAD51AP1', 'RANGAP1', 'RFC2', 'RPA2', 'RRM1', 'RRM2', 'SLBP', 'SMC4', 'TACC3', 'TIPIN', 'TMPO', 'TOP2A', 'TPX2', 'TTK', 'TUBB4B', 'TYMS', 'UBR7', 'UHRF1', 'UNG', 'USP1', 'WDR76']
WARNING: genes are not in var_names and ignored: ['ADD1', 'AIFM3', 'ANKH', 'ANXA1', 'APP', 'ATF3', 'AVPR1A', 'BAX', 'BCAP31', 'BCL10', 'BCL2L1', 'BCL2L10', 'BCL2L2', 'BGN', 'BID', 'BIK', 'BIRC3', 'BMF', 'BMP2', 'BNIP3L', 'BRCA1', 'BTG2', 'BTG3', 'CASP1', 'CASP2', 'CASP3', 'CASP4', 'CASP6', 'CASP7', 'CASP8', 'CASP9', 'CAV1', 'CCNA1', 'CCND1', 'CCND2', 'CD2', 'CDC25B', 'CDK2', 'CDKN1A', 'CDKN1B', 'CFLAR', 'CREBBP', 'CTH', 'CTNNB1', 'CYLD', 'DAP', 'DAP3', 'DCN', 'DDIT3', 'DFFA', 'DIABLO', 'DNAJA1', 'DNAJC3', 'DNM1L', 'DPYD', 'EBP', 'EGR3', 'ENO2', 'ERBB2', 'ERBB3', 'ETF1', 'F2', 'FAS', 'FASLG', 'FDXR', 'FEZ1', 'GADD45A', 'GADD45B', 'GCH1', 'GNA15', 'GPX1', 'GPX3', 'GPX4', 'GSR', 'GSTM1', 'GUCY2D', 'H1-0', 'HMGB2', 'HMOX1', 'HSPB1', 'IFNB1', 'IFNGR1', 'IGFBP6', 'IL18', 'IL1A', 'IL1B', 'IL6', 'IRF1', 'ISG20', 'KRT18', 'MADD', 'MCL1', 'MGMT', 'MMP2', 'NEFH', 'PAK1', 'PDCD4', 'PDGFRB', 'PEA15', 'PLCB2', 'PLPPR4', 'PMAIP1', 'PPP2R5B', 'PPP3R1', 'PPT1', 'PRF1', 'PSEN1', 'PSEN2', 'RARA', 'RELA', 'RETSAT', 'RHOB', 'RHOT2', 'RNASEL', 'ROCK1', 'SATB1', 'SC5D', 'SLC20A1', 'SMAD7', 'SOD1', 'SOD2', 'SPTAN1', 'TAP1', 'TGFB2', 'TIMP2', 'TNF', 'TNFRSF12A', 'TNFSF10', 'TOP2A', 'TSPO', 'TXNIP', 'VDAC2', 'WEE1', 'XIAP']

Now we can investigate the proliferation and apoptosis markers on the UMAP. Proliferation scores are much higher than apoptosis scores, which can be seen from the range of the scores. This is to be expected, as we are in a developmental setting.

sc.pl.embedding(adata, basis="umap_GEX", color=["proliferation", "apoptosis"])

We use the above scores to adapt the left marginals of our OT problems. By default, this is done via a birth-death process [Schiebinger et al., 2019], but it can also be done in a more tunable and interpretable way.
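
As a rough illustration of the idea behind a birth-death process, the sketch below maps proliferation and apoptosis scores to birth and death rates with a logistic function and turns the resulting growth into a source marginal. The helper names and every numeric constant are hypothetical and do not reproduce the defaults used by moscot.

```python
import numpy as np


def logistic(x, lower, upper, center, width):
    """Smoothly map a score to the interval (lower, upper)."""
    return lower + (upper - lower) / (1.0 + np.exp(-(x - center) / width))


def birth_death_marginals(proliferation, apoptosis, delta_t=1.0):
    """Hypothetical helper: turn marker scores into a source marginal."""
    # All constants below are illustrative assumptions
    birth = logistic(proliferation, lower=0.3, upper=1.7, center=0.25, width=0.05)
    death = logistic(apoptosis, lower=0.3, upper=1.7, center=0.10, width=0.05)
    growth = np.exp((birth - death) * delta_t)  # expected relative number of descendants
    return growth / growth.sum()  # normalize to a probability vector


proliferation = np.array([0.05, 0.20, 0.40])  # toy scores for three cells
apoptosis = np.array([0.05, 0.05, 0.05])
a = birth_death_marginals(proliferation, apoptosis)
print(a)
```

Cells with a higher proliferation score receive more mass in the source marginal and can therefore be coupled to more descendants.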

Now we can prepare the problem. To do so, we need to pass some information:

  • time_key: defines the obs column for the temporal information.

  • joint_attr: defines based on which (latent) space we compute the distances for the OT problem. In this case, we use the latent space constructed above.

tp = tp.prepare(time_key="day", joint_attr="X_shared")

Now we can investigate the prior growth rate estimates.

adata.obs["prior_growth_rates"] = tp.prior_growth_rates
sc.pl.embedding(adata, basis="umap_GEX", color="prior_growth_rates")

Solve the TemporalProblem#

Now we can solve the problem. We set epsilon to be relatively small to get a sparse mapping between cells. Although we have prior estimates of the proliferation of the cells, we still allow these estimates to be adjusted by setting tau_a=0.99. We set tau_b=0.999 because we don't expect many cells to die in this developmental setting. Nevertheless, we avoid the fully balanced case to reduce the influence of possible outliers. scale_cost is set to "mean", which means that the entries of the cost matrix are normalized by their mean to stabilize the optimization.

tp = tp.solve(epsilon=1e-3, tau_a=0.99, tau_b=0.999, scale_cost="mean")
INFO     Solving `3` problems                                                                                      
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(766, 1235)].                                   
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(1235, 1201)].                                  
INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(1201, 798)].                                   
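
One reason to normalize the cost by its mean is that the choice of epsilon then no longer depends on the numerical scale of the embedding. The following NumPy sketch illustrates this with a squared Euclidean cost on toy data; the helper `mean_scaled_cost` is a hypothetical name, not a moscot function.

```python
import numpy as np


def mean_scaled_cost(x, y):
    """Squared Euclidean cost matrix divided by its mean (hypothetical helper)."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return cost / cost.mean()


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))  # toy source embedding
y = rng.normal(size=(5, 3))  # toy target embedding

cost_small = mean_scaled_cost(x, y)
cost_large = mean_scaled_cost(10.0 * x, 10.0 * y)  # same data on a 10x larger scale

print(np.allclose(cost_small, cost_large))
```

Because both cost matrices coincide after scaling, a given epsilon regularizes both problems equally strongly.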

Interpreting the solution#

Growth rates#

We can compare the posterior growth rates with the prior growth rates to see the influence of the unbalancedness parameters tau_a and tau_b. Note that setting tau_a = tau_b = 1 would result in posterior growth rates being equal to prior growth rates.

adata.obs["prior_growth_rates"] = tp.prior_growth_rates
adata.obs["posterior_growth_rates"] = tp.posterior_growth_rates
sc.pl.embedding(
    adata, basis="umap_GEX",
    color=["prior_growth_rates", "posterior_growth_rates"],
)
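
The link between a coupling and growth can be illustrated on a toy transport matrix. The sketch below is a simplified reading, not moscot's exact definition of `posterior_growth_rates`: it compares the mass each early cell sends forward with the mass it would carry under a uniform marginal.

```python
import numpy as np

# Toy coupling between 3 early cells (rows) and 2 late cells (columns)
coupling = np.array(
    [
        [0.30, 0.20],
        [0.10, 0.10],
        [0.05, 0.25],
    ]
)

n_early = coupling.shape[0]
row_mass = coupling.sum(axis=1)  # mass each early cell sends forward
relative_growth = row_mass * n_early  # 1.0 corresponds to "the mass of one cell"

print(relative_growth)  # approximately [1.5, 0.6, 0.9]
```

In the fully balanced case the row sums are fixed by the prior marginal, so the growth read off the coupling cannot differ from the prior.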

To learn more about the influence of each single cell on the resulting coupling, we plot the "cell costs" in the source and target distributions, respectively. High values indicate that a certain cell is unlikely to have a descendant or ancestor, respectively. One should consider removing such cells and rerunning the algorithm.

adata.obs["cell_costs_source"] = tp.cell_costs_source
adata.obs["cell_costs_target"] = tp.cell_costs_target

We can see that there are a few cells which influence the coupling more strongly than others, but the extent is still moderate, so we continue without removing them.
sc.pl.embedding(
    adata, basis="umap_GEX", color=["cell_costs_source", "cell_costs_target"]
)

Identifying ancestry of cells.#

We can now investigate which ancestor populations a certain cell type has. We do this by aggregating the transport matrix by cell type, using cell_transition().

ct_desc = tp.cell_transition(
    2, 3, "cell_type", "cell_type", forward=False, key_added="transitions_2_3"
)
ct_desc = tp.cell_transition(
    3, 4, "cell_type", "cell_type", forward=False, key_added="transitions_3_4"
)
ct_desc = tp.cell_transition(
    4, 7, "cell_type", "cell_type", forward=False, key_added="transitions_4_7"
)
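
Conceptually, aggregating by cell type sums the cell-by-cell coupling over all cells that share a label and then normalizes the result. The NumPy sketch below shows this on a toy coupling; `aggregate_coupling` is a hypothetical helper written for illustration and may differ from moscot's implementation in its details.

```python
import numpy as np


def aggregate_coupling(coupling, source_labels, target_labels, forward=False):
    """Hypothetical helper: sum a cell-by-cell coupling into a type-by-type matrix."""
    source_types = sorted(set(source_labels))
    target_types = sorted(set(target_labels))
    # One-hot membership matrices of shape (cells, cell types)
    src = np.array([[lab == t for t in source_types] for lab in source_labels], dtype=float)
    tgt = np.array([[lab == t for t in target_types] for lab in target_labels], dtype=float)
    agg = src.T @ coupling @ tgt
    if forward:
        # Rows sum to one: where does each source type go?
        agg = agg / agg.sum(axis=1, keepdims=True)
    else:
        # Columns sum to one: where does each target type come from?
        agg = agg / agg.sum(axis=0, keepdims=True)
    return source_types, target_types, agg


coupling = np.array(
    [
        [0.20, 0.05, 0.00],
        [0.10, 0.15, 0.05],
        [0.00, 0.05, 0.40],
    ]
)  # rows: early cells, columns: late cells
source_types, target_types, ancestry = aggregate_coupling(
    coupling, ["HSC", "HSC", "MkP"], ["HSC", "MkP", "MkP"], forward=False
)
print(source_types, target_types)
print(ancestry)
```

With forward=False, each column describes the ancestor composition of one target cell type, which matches how the transition matrices are read in this section.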

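# Create a 1x3 grid of subplots
fig, axes = plt.subplots(ncols=3, figsize=(16, 6))

# Each plot reads the matrix stored under the key passed as key_added above
axes[0] = mtp.cell_transition(
    tp, key="transitions_2_3", ax=axes[0],
    figsize=(5, 5),
)

axes[1] = mtp.cell_transition(
    tp, key="transitions_3_4", ax=axes[1],
    figsize=(5, 5),
)

axes[2] = mtp.cell_transition(
    tp, key="transitions_4_7", ax=axes[2],
    figsize=(5, 5),
)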

The transition matrices show the ancestry of each cell type. For example, looking at the left transition matrix, we see that BPs at time point 3 are mainly derived from HSCs or BPs (first column). Analogously, NeuPs at time point 3 are mainly derived from HSCs and NeuPs (last column).

Visualizing ancestors and descendants#

We can also visualize ancestors and descendants, e.g. on a UMAP, by using pull() and push(), respectively. We start with the ancestors, at time point 4, of the MasP cells at time point 7.

tp.pull(source=4, target=7, data="cell_type", subset="MasP")

fig, axes = plt.subplots(ncols=2, figsize=(20, 6))

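axes[0] = mtp.pull(
    tp, time_points=[7], basis="umap_GEX", ax=axes[0],
    title=["MasP at time 7"],
)
axes[1] = mtp.pull(
    tp, time_points=[4], basis="umap_GEX", ax=axes[1],
    title=["MasP ancestors"],
)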


The left plot above colors all MasP cells at time point 7 in red. All other cells at time point 7 are dark gray, while cells not measured at time point 7 are light gray. The right plot shows the ancestry likelihood of each cell on a palette from red (very likely ancestor) to dark gray (very unlikely ancestor). All cells not belonging to time point 4 are light gray.
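
At the level of the transport matrix, pulling and pushing amount to matrix-vector products with an indicator of the selected cells. The following NumPy sketch shows both directions on a toy coupling; any additional scaling or normalization that moscot applies is deliberately left out, so this is an illustration rather than the library's implementation.

```python
import numpy as np

coupling = np.array(
    [
        [0.20, 0.05, 0.00],
        [0.10, 0.15, 0.05],
        [0.00, 0.05, 0.40],
    ]
)  # rows: early cells, columns: late cells

# Pull: start from an indicator over late cells and map it back in time
late_subset = np.array([0.0, 0.0, 1.0])  # the third late cell belongs to the subset
ancestor_weights = coupling @ late_subset  # one weight per early cell

# Push: start from an indicator over early cells and map it forward in time
early_subset = np.array([1.0, 1.0, 0.0])  # the first two early cells belong to the subset
descendant_weights = coupling.T @ early_subset  # one weight per late cell

print(ancestor_weights)
print(descendant_weights)
```

The resulting weights are what is shown on the embedding: high values mark likely ancestors (pull) or likely descendants (push) of the selected subset.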

We can also visualize the descendants of a certain subgroup. For example, we can visualize the descendants of HSCs at time point 2.

tp.push(source=2, target=3, data="cell_type", subset="HSC")
fig, axes = plt.subplots(ncols=2, figsize=(16, 6))

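axes[0] = mtp.push(
    tp, time_points=[2], basis="umap_GEX", ax=axes[0],
    title=["HSCs at time 2"],
)
axes[1] = mtp.push(
    tp, time_points=[3], basis="umap_GEX", ax=axes[1],
    title=["HSCs descendants"],
)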


To get an idea of the dynamics across all time points, we can visualize the cell type evolution with a sankey() diagram. We choose threshold=0.05, which ignores all transitions with a probability of less than 0.05.

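tp.sankey(source=2, target=7, source_groups="cell_type", target_groups="cell_type", threshold=0.05)
mtp.sankey(tp, dpi=100, figsize=(9, 4))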

Finding new subclusters#

So far, we have only used the cell type annotations to perform trajectory inference. The TemporalProblem is also helpful for detecting trajectories of subpopulations. To illustrate this, we subcluster the HSC population.

As mentioned above, the cell type annotation is based on gene expression only, and hence does not take the ATAC modality into account, while our computed transport maps leverage both modalities. For subclustering we will now also use both modalities. Thus, we compute the neighbors based on our joint embedding.

sc.pp.neighbors(adata, use_rep="X_shared")

new_key = "HSC_subclustered"
sc.tl.leiden(
    adata, restrict_to=("cell_type", ["HSC"]), key_added=new_key, resolution=0.5
)
sc.pl.embedding(adata, color=new_key, basis="umap_GEX")

Now, we can consider the transitions between the subclusters. To this end, we look at where the subclusters of HSCs at time point 3 are mapped to at time point 4.

tp.cell_transition(
    source=3, target=4, target_groups="cell_type", forward=True,
    source_groups={new_key: ["HSC,0", "HSC,1", "HSC,2", "HSC,3", "HSC,4", "HSC,5"]})
mtp.cell_transition(tp, dpi=80, fontsize=14)

We can see that MoP cells mainly evolve from the HSC,5 subcluster while BP cells have progenitors in both HSC,0 and HSC,3.