Tagged arrays#

This example shows how to use the TaggedArray.

A TaggedArray stores the data passed by the user in a unified format before it is handed to the backend.

Imports and data loading#

import warnings

warnings.simplefilter("ignore", FutureWarning)

from moscot import datasets
from moscot.problems.generic import FGWProblem

import numpy as np
import pandas as pd

import scanpy as sc

np.set_printoptions(threshold=1, precision=3)

Simulate data using simulate_data().

adata = datasets.simulate_data(n_distributions=2, key="batch")
sc.pp.pca(adata)
adata
AnnData object with n_obs × n_vars = 40 × 60
    obs: 'batch', 'celltype'
    uns: 'pca'
    obsm: 'X_pca'
    varm: 'PCs'

Prepare the problem#

We instantiate and prepare a FGWProblem to demonstrate the role of the TaggedArray.

fgw = FGWProblem(adata)
fgw = fgw.prepare(key="batch", x_attr="X_pca", y_attr="X_pca", joint_attr="X_pca")

The OTProblem has the attributes xy, x, and y. xy stores the data for the linear term, whereas x and y store the data for the quadratic term on the source and target side, respectively. All three attributes are TaggedArrays.

fgw["0", "1"].xy
TaggedArray(data_src=ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
           [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
           [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
           ...,
           [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
           [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
           [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
          dtype=float32), data_tgt=ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
           [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
           [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
           ...,
           [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
           [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
           [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
          dtype=float32), tag=<Tag.POINT_CLOUD: 'point_cloud'>, cost=<ott.geometry.costs.SqEuclidean object at 0x000002095006ED50>)

Attributes#

Each TaggedArray has attributes data_src, data_tgt, cost, and tag.

fgw["0", "1"].xy.tag, fgw["0", "1"].xy.cost
(<Tag.POINT_CLOUD: 'point_cloud'>,
 <ott.geometry.costs.SqEuclidean at 0x2095006ed50>)

The tag attribute is of type Tag and defines what kind of data is stored in the TaggedArray. Possible tags are cost_matrix, kernel, and point_cloud. Whenever tag='point_cloud', the backend is expected to compute the cost on the fly. The cost attribute should then specify which cost to compute from the point clouds.
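The dispatch on the tag can be pictured with a minimal NumPy sketch. It is not the implementation used by moscot or its backend: the names SketchTag, SketchTaggedArray, and resolve_cost are hypothetical, the squared Euclidean cost is the only cost supported, and only a linear term with both point clouds present is handled.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import numpy as np


class SketchTag(str, Enum):
    """Hypothetical stand-in for the three tags named above."""

    COST_MATRIX = "cost_matrix"
    KERNEL = "kernel"
    POINT_CLOUD = "point_cloud"


@dataclass
class SketchTaggedArray:
    """Hypothetical container mirroring the four attributes of a TaggedArray."""

    data_src: np.ndarray
    data_tgt: Optional[np.ndarray] = None
    tag: SketchTag = SketchTag.POINT_CLOUD
    cost: str = "sq_euclidean"


def sq_euclidean(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Squared Euclidean distance between every row of x and every row of y."""
    diff = x[:, None, :] - y[None, :, :]
    return np.sum(diff**2, axis=-1)


def resolve_cost(arr: SketchTaggedArray) -> np.ndarray:
    """Return a dense cost matrix, computing it only for point clouds."""
    if arr.tag == SketchTag.COST_MATRIX:
        # Already an instantiated cost matrix: use it unchanged.
        return arr.data_src
    if arr.tag == SketchTag.POINT_CLOUD:
        if arr.data_tgt is None or arr.cost != "sq_euclidean":
            raise NotImplementedError(
                "This sketch covers only a linear term with the squared Euclidean cost."
            )
        # Point clouds: the cost is computed on the fly from the coordinates.
        return sq_euclidean(arr.data_src, arr.data_tgt)
    raise NotImplementedError(f"Tag {arr.tag.value!r} is not part of this sketch.")


rng = np.random.default_rng(0)
src, tgt = rng.normal(size=(4, 3)), rng.normal(size=(5, 3))

from_points = resolve_cost(SketchTaggedArray(src, tgt))
from_matrix = resolve_cost(SketchTaggedArray(from_points, tag=SketchTag.COST_MATRIX))
print(from_points.shape, np.array_equal(from_points, from_matrix))
```

Both calls yield the same matrix. The difference lies in who computes it: for a point cloud the cost function is applied when needed, whereas a cost matrix is taken as given.
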

If the TaggedArray corresponds to a linear term, data_src and data_tgt contain the point clouds of the source and the target distribution, respectively.

fgw["0", "1"].xy.data_src, fgw["0", "1"].xy.data_tgt
(ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
            [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
            [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
            ...,
            [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
            [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
            [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
           dtype=float32),
 ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
            [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
            [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
            ...,
            [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
            [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
            [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
           dtype=float32))

If the TaggedArray corresponds to a quadratic term, the cost will be computed pairwise between points of the same distribution. Hence, data_tgt is None.

fgw["0", "1"].x.data_src, fgw["0", "1"].x.data_tgt
(ArrayView([[ 4.718, -0.162, -1.028, ...,  0.521, -0.039, -0.043],
            [-1.683,  0.754, -2.114, ..., -0.348, -0.498, -0.051],
            [ 0.119, -1.908,  2.526, ..., -0.536, -0.61 , -0.274],
            ...,
            [-1.764,  2.958, -0.065, ..., -0.659, -0.011,  0.067],
            [-1.088,  3.231, -0.984, ..., -0.545, -0.262, -0.015],
            [-1.471,  0.102, -0.997, ...,  0.234,  0.048, -0.021]],
           dtype=float32),
 None)
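
To make the difference between the two kinds of term concrete, the following NumPy sketch compares the cost matrices they give rise to. It assumes a squared Euclidean cost and uses random stand-in point clouds rather than the data above; pairwise_sq_euclidean is a hypothetical helper, not a moscot function.

```python
import numpy as np


def pairwise_sq_euclidean(x, y=None):
    """Squared Euclidean costs between rows of x and rows of y (x itself if y is None)."""
    y = x if y is None else y
    diff = x[:, None, :] - y[None, :, :]
    return np.sum(diff**2, axis=-1)


rng = np.random.default_rng(0)
x_src = rng.normal(size=(6, 4))  # stand-in for the source point cloud
x_tgt = rng.normal(size=(8, 4))  # stand-in for the target point cloud

linear_cost = pairwise_sq_euclidean(x_src, x_tgt)  # across the two distributions
quadratic_cost = pairwise_sq_euclidean(x_src)  # within the source distribution

print(linear_cost.shape)  # one row per source point, one column per target point
print(quadratic_cost.shape)  # square: one row and one column per source point
print(np.allclose(quadratic_cost, quadratic_cost.T))  # symmetric by construction
```

Because the within-distribution cost needs only one set of points, a second array would carry no extra information, which matches data_tgt being None for the quadratic term.
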

Modifying the tags#

Whenever tag='cost_matrix', the backend expects an instantiated cost matrix. There are two cases to distinguish. In the first, the user passes custom cost matrices directly; see Custom cost matrices for more information. In this case, cost='custom' must be set.

Setting a custom cost matrix, e.g., via set_xy(), changes the tag of the TaggedArray. Before the custom cost matrix is set, we still have tag='point_cloud' and data_tgt is not None, because it contains the point cloud of the target distribution.

fgw["0", "1"].xy.tag, fgw["0", "1"].xy.data_tgt
(<Tag.POINT_CLOUD: 'point_cloud'>,
 ArrayView([[-1.076, -3.572, -0.315, ...,  0.492, -0.226,  0.194],
            [ 0.852, -1.033,  3.834, ...,  0.252, -0.115,  0.243],
            [-0.411, -2.689, -1.863, ..., -0.347,  0.601,  0.005],
            ...,
            [ 2.024, -1.597, -0.591, ..., -0.114,  0.29 , -0.332],
            [-0.938,  2.426, -0.128, ...,  0.194,  0.829, -0.438],
            [ 2.709, -2.885, -0.925, ...,  0.315,  0.334,  0.078]],
           dtype=float32))

We now construct a random custom cost matrix for the linear term.

rng = np.random.default_rng(seed=42)
obs_names_0 = fgw["0", "1"].adata_src.obs_names
obs_names_1 = fgw["0", "1"].adata_tgt.obs_names

cost_linear_01 = np.abs(rng.normal(size=(len(obs_names_0), len(obs_names_1))))
cm_linear = pd.DataFrame(data=cost_linear_01, index=obs_names_0, columns=obs_names_1)

fgw["0", "1"].set_xy(cm_linear, tag="cost_matrix")
fgw["0", "1"].xy.tag, fgw["0", "1"].xy.data_tgt
(<Tag.COST_MATRIX: 'cost_matrix'>, None)
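
Before passing a custom cost matrix, it can help to check that it lines up with the observations. The sketch below is one way to do this under stated assumptions: check_cost_matrix is a hypothetical helper, not part of moscot, and the requirements it enforces (row and column labels equal to the observation names, finite and non-negative entries) are assumptions chosen to mirror how cm_linear was built above, not documented requirements of set_xy().

```python
import numpy as np
import pandas as pd


def check_cost_matrix(cm, src_names, tgt_names):
    """Raise ValueError unless cm is finite, non-negative, and labelled by the given names."""
    if list(cm.index) != list(src_names):
        raise ValueError("Row labels do not match the source observation names.")
    if list(cm.columns) != list(tgt_names):
        raise ValueError("Column labels do not match the target observation names.")
    values = cm.to_numpy(dtype=float)
    if not np.isfinite(values).all():
        raise ValueError("Cost matrix contains NaN or infinite entries.")
    if (values < 0).any():
        raise ValueError("Cost matrix contains negative entries.")


# Stand-in observation names; in the example above they come from adata_src and adata_tgt.
src_names = [f"cell_{i}" for i in range(3)]
tgt_names = [f"cell_{i}" for i in range(3, 7)]

rng = np.random.default_rng(seed=42)
cm = pd.DataFrame(
    np.abs(rng.normal(size=(len(src_names), len(tgt_names)))),
    index=src_names,
    columns=tgt_names,
)

check_cost_matrix(cm, src_names, tgt_names)  # passes silently

try:
    check_cost_matrix(cm.T, src_names, tgt_names)  # transposed by mistake
except ValueError as err:
    print(err)
```

A check of this kind catches the most common mistake, a transposed matrix, before the problem is solved.
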