Tagged arrays
This example shows how to use the TaggedArray. A TaggedArray stores the data passed by the user in a unified format before it is handed to the backend.
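To make the idea of a "unified format" concrete, here is a minimal sketch of a tagged-array-like container. It is not moscot's implementation: the class `MiniTaggedArray`, the enum `MiniTag` and the string-valued `cost` field are hypothetical stand-ins that only mirror the attribute names (`data_src`, `data_tgt`, `tag`, `cost`) described later in this tutorial. Two different kinds of user input, point clouds and a precomputed cost matrix, end up in the same structure and differ only in their tag.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import numpy as np


class MiniTag(str, Enum):
    """Hypothetical stand-in for the kinds of data a tagged array can hold."""
    COST_MATRIX = "cost_matrix"
    KERNEL = "kernel"
    POINT_CLOUD = "point_cloud"


@dataclass
class MiniTaggedArray:
    """Hypothetical container: the data plus a tag saying how to interpret it."""
    data_src: np.ndarray
    data_tgt: Optional[np.ndarray] = None
    tag: MiniTag = MiniTag.POINT_CLOUD
    cost: Optional[str] = "sq_euclidean"  # only meaningful for point clouds


rng = np.random.default_rng(0)
src = rng.normal(size=(20, 5))  # 20 source cells, 5 features
tgt = rng.normal(size=(30, 5))  # 30 target cells, same 5 features

# User input 1: two point clouds; computing the cost is left to the backend.
from_points = MiniTaggedArray(data_src=src, data_tgt=tgt)

# User input 2: a precomputed (20 x 30) cost matrix; no cost function needed.
precomputed = ((src[:, None, :] - tgt[None, :, :]) ** 2).sum(-1)
from_matrix = MiniTaggedArray(data_src=precomputed, tag=MiniTag.COST_MATRIX, cost=None)
```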
See also
See Lineage tree for how to use a lineage tree to compute leaf distances.
See Custom cost matrices for how to pass precomputed cost matrices.
See Quadratic problems for an introduction to solving quadratic problems.
See Quadratic problems (advanced) for an advanced example of solving quadratic problems.
Imports and data loading
from moscot import datasets
from moscot.problems.generic import GWProblem
import numpy as np
import pandas as pd
import scanpy as sc
np.set_printoptions(threshold=1, precision=3)
Simulate data using simulate_data().
adata = datasets.simulate_data(n_distributions=2, key="batch")
sc.pp.pca(adata)
adata
AnnData object with n_obs × n_vars = 40 × 60
obs: 'batch', 'celltype'
uns: 'pca'
obsm: 'X_pca'
varm: 'PCs'
Prepare the problem
We instantiate and prepare a GWProblem to demonstrate the role of the TaggedArray.
gwp = GWProblem(adata)
gwp = gwp.prepare(key="batch", joint_attr="X_pca", GW_x="X_pca", GW_y="X_pca")
The OTProblem has the attributes xy, x, and y. The attribute xy stores the data for the linear term, while x and y store the data for the quadratic term of the source and the target distribution, respectively. All three attributes are TaggedArrays.
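As a reminder of what "linear" and "quadratic" term mean, the sketch below evaluates the standard fused Gromov-Wasserstein objective with the square loss for a fixed coupling `P`. The linear term uses a cost between source and target points (the kind of data xy describes); the quadratic term compares the within-source cost with the within-target cost (the kind of data x and y describe). The function names, the weighting by `alpha` and the toy data are assumptions for illustration only; moscot delegates the actual optimization to its backend and is not shown here.

```python
import numpy as np


def sq_euclidean(a, b):
    """Pairwise squared Euclidean distances between the rows of a and b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)


def fgw_objective(C_xy, C_x, C_y, P, alpha=0.5):
    """Fused GW objective with square loss for a fixed coupling P (n x m).

    linear    = sum_ij   C_xy[i, j] * P[i, j]
    quadratic = sum_ijkl (C_x[i, k] - C_y[j, l])**2 * P[i, j] * P[k, l]
    alpha weights the quadratic term, (1 - alpha) the linear one.
    """
    linear = np.sum(C_xy * P)
    # L[i, j, k, l] = (C_x[i, k] - C_y[j, l])**2; only feasible for tiny n and m.
    L = (C_x[:, None, :, None] - C_y[None, :, None, :]) ** 2
    quadratic = np.einsum("ijkl,ij,kl->", L, P, P)
    return (1 - alpha) * linear + alpha * quadratic


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))  # toy source point cloud
y = rng.normal(size=(5, 3))  # toy target point cloud

C_xy = sq_euclidean(x, y)  # linear term: source vs. target, shape (4, 5)
C_x = sq_euclidean(x, x)   # quadratic term: within source, shape (4, 4)
C_y = sq_euclidean(y, y)   # quadratic term: within target, shape (5, 5)
P = np.full((4, 5), 1 / 20)  # independent coupling of two uniform marginals
print(fgw_objective(C_xy, C_x, C_y, P))
```

Setting `alpha=0.0` leaves only the linear term, and `alpha=1.0` leaves only the quadratic (pure Gromov-Wasserstein) term.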
gwp["0", "1"].xy
TaggedArray(data_src=ArrayView([[ 4.718, -0.162, -1.028, ..., 0.521, -0.039, -0.043],
[-1.683, 0.754, -2.114, ..., -0.348, -0.498, -0.051],
[ 0.119, -1.908, 2.526, ..., -0.536, -0.61 , -0.274],
...,
[-1.764, 2.958, -0.065, ..., -0.659, -0.011, 0.067],
[-1.088, 3.231, -0.984, ..., -0.545, -0.262, -0.015],
[-1.471, 0.102, -0.997, ..., 0.234, 0.048, -0.021]],
dtype=float32), data_tgt=ArrayView([[-1.076, -3.572, -0.315, ..., 0.492, -0.226, 0.194],
[ 0.852, -1.033, 3.834, ..., 0.252, -0.115, 0.243],
[-0.411, -2.689, -1.863, ..., -0.347, 0.601, 0.005],
...,
[ 2.024, -1.597, -0.591, ..., -0.114, 0.29 , -0.332],
[-0.938, 2.426, -0.128, ..., 0.194, 0.829, -0.438],
[ 2.709, -2.885, -0.925, ..., 0.315, 0.334, 0.078]],
dtype=float32), tag='point_cloud', cost=<ott.geometry.costs.SqEuclidean object at 0x29e7bb310>)
Attributes
Each TaggedArray has the attributes data_src, data_tgt, cost, and tag.
gwp["0", "1"].xy.tag, gwp["0", "1"].xy.cost
('point_cloud', <ott.geometry.costs.SqEuclidean at 0x29e7bb310>)
The tag attribute is of type Tag and defines what kind of data is stored in the TaggedArray. Possible tags are cost_matrix, kernel, and point_cloud. Whenever tag='point_cloud', the backend is expected to compute the cost on the fly, and the cost attribute specifies which cost function to apply to the point clouds (here, SqEuclidean).
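The sketch below shows one way a backend could dispatch on the tag: precomputed cost matrices and kernels are passed through, while point clouds are turned into a squared Euclidean cost matrix. The function `materialize` and its string arguments are hypothetical and not part of moscot; in moscot the conversion is done by the backend (the `SqEuclidean` object in the output above comes from ott-jax). A real backend may also avoid materializing the full matrix, which this sketch does not attempt.

```python
import numpy as np


def materialize(tag, data_src, data_tgt=None, cost="sq_euclidean"):
    """Hypothetical backend step: turn tagged data into (kind, matrix)."""
    if tag in ("cost_matrix", "kernel"):
        # Precomputed by the user: passed through unchanged.
        return tag, np.asarray(data_src, dtype=float)
    if tag == "point_cloud":
        if cost != "sq_euclidean":
            raise NotImplementedError(f"cost {cost!r} is not covered by this sketch")
        # A missing target means the point cloud is compared with itself.
        tgt = data_src if data_tgt is None else data_tgt
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, clipped at 0 for round-off.
        sq_src = (data_src ** 2).sum(axis=1)[:, None]
        sq_tgt = (tgt ** 2).sum(axis=1)[None, :]
        return "cost_matrix", np.maximum(sq_src + sq_tgt - 2 * data_src @ tgt.T, 0.0)
    raise ValueError(f"unknown tag {tag!r}")


rng = np.random.default_rng(1)
src, tgt = rng.normal(size=(6, 4)), rng.normal(size=(8, 4))
kind, C = materialize("point_cloud", src, tgt)
print(kind, C.shape)
```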
If the TaggedArray corresponds to a linear term, data_src and data_tgt contain the point clouds of the source and the target distribution, respectively.
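For a linear term, the cost compares every source point with every target point, so both clouds must share the same feature space (in this tutorial both batches use the same joint_attr="X_pca"). The toy arrays below are made up for illustration; the resulting matrix has one row per source point and one column per target point. For example, entry (0, 1) is (0 - 2)^2 + (0 - 0)^2 = 4.

```python
import numpy as np

# Toy linear term: both clouds live in the same 2-D feature space.
data_src = np.array([[0.0, 0.0],
                     [1.0, 0.0]])  # 2 source points
data_tgt = np.array([[0.0, 1.0],
                     [2.0, 0.0],
                     [1.0, 1.0]])  # 3 target points

# Squared Euclidean cost between every source and every target point.
diff = data_src[:, None, :] - data_tgt[None, :, :]  # shape (2, 3, 2)
C_xy = (diff ** 2).sum(axis=-1)                     # shape (2, 3)
print(C_xy)
```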
gwp["0", "1"].xy.data_src, gwp["0", "1"].xy.data_tgt
(ArrayView([[ 4.718, -0.162, -1.028, ..., 0.521, -0.039, -0.043],
[-1.683, 0.754, -2.114, ..., -0.348, -0.498, -0.051],
[ 0.119, -1.908, 2.526, ..., -0.536, -0.61 , -0.274],
...,
[-1.764, 2.958, -0.065, ..., -0.659, -0.011, 0.067],
[-1.088, 3.231, -0.984, ..., -0.545, -0.262, -0.015],
[-1.471, 0.102, -0.997, ..., 0.234, 0.048, -0.021]],
dtype=float32),
ArrayView([[-1.076, -3.572, -0.315, ..., 0.492, -0.226, 0.194],
[ 0.852, -1.033, 3.834, ..., 0.252, -0.115, 0.243],
[-0.411, -2.689, -1.863, ..., -0.347, 0.601, 0.005],
...,
[ 2.024, -1.597, -0.591, ..., -0.114, 0.29 , -0.332],
[-0.938, 2.426, -0.128, ..., 0.194, 0.829, -0.438],
[ 2.709, -2.885, -0.925, ..., 0.315, 0.334, 0.078]],
dtype=float32))
If the TaggedArray corresponds to a quadratic term, the cost is computed pairwise between points of the same distribution. Hence, data_tgt is None.
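Because a quadratic term only compares points of one distribution with each other, the source cloud in x and the target cloud in y do not need to share a feature space in the Gromov-Wasserstein setting. The sketch below uses hypothetical toy data in a 3-D and a 10-D space and computes the two intra-distribution cost matrices from a single array each, mirroring data_tgt being None. In this tutorial both quadratic terms come from X_pca, so their dimensions happen to match.

```python
import numpy as np


def pairwise_sq_euclidean(points):
    """Cost between all pairs of points of ONE distribution (data_tgt is None)."""
    diff = points[:, None, :] - points[None, :, :]
    return (diff ** 2).sum(axis=-1)


rng = np.random.default_rng(2)
x_src = rng.normal(size=(5, 3))   # toy source cloud in a 3-D space
y_tgt = rng.normal(size=(7, 10))  # toy target cloud in a different, 10-D space

C_x = pairwise_sq_euclidean(x_src)  # shape (5, 5), symmetric, zero diagonal
C_y = pairwise_sq_euclidean(y_tgt)  # shape (7, 7), symmetric, zero diagonal
print(C_x.shape, C_y.shape)
```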
gwp["0", "1"].x.data_src, gwp["0", "1"].x.data_tgt
(ArrayView([[ 4.718, -0.162, -1.028, ..., 0.521, -0.039, -0.043],
[-1.683, 0.754, -2.114, ..., -0.348, -0.498, -0.051],
[ 0.119, -1.908, 2.526, ..., -0.536, -0.61 , -0.274],
...,
[-1.764, 2.958, -0.065, ..., -0.659, -0.011, 0.067],
[-1.088, 3.231, -0.984, ..., -0.545, -0.262, -0.015],
[-1.471, 0.102, -0.997, ..., 0.234, 0.048, -0.021]],
dtype=float32),
None)