from __future__ import annotations
import types
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterable,
Literal,
Mapping,
Optional,
Sequence,
Union,
)
import numpy as np
import pandas as pd
from scipy.sparse.linalg import LinearOperator
import scanpy as sc
from moscot import _constants
from moscot._types import ArrayLike, Numeric_t, Str_Dict_t
from moscot.base.problems._utils import (
_check_argument_compatibility_cell_transition,
_correlation_test,
_get_df_cell_transition,
_order_transition_matrix,
_validate_annotations,
_validate_args_cell_transition,
)
from moscot.base.problems.compound_problem import B, K
from moscot.base.problems.problem import (
AbstractPushPullAdata,
AbstractSolutionsProblems,
)
from moscot.plotting._utils import set_plotting_vars
from moscot.utils.data import transcription_factors
from moscot.utils.subset_policy import SubsetPolicy
__all__ = ["AnalysisMixin"]
[docs]
class AnalysisMixin(Generic[K, B], AbstractPushPullAdata, AbstractSolutionsProblems):
"""Base Analysis Mixin."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Forward all positional and keyword arguments to the next initializer in the MRO."""
    super().__init__(*args, **kwargs)
def _cell_transition(
    self,
    source: K,
    target: Optional[K],
    source_groups: Str_Dict_t,
    target_groups: Str_Dict_t,
    aggregation_mode: Literal["annotation", "cell"] = "annotation",
    key_added: Optional[str] = _constants.CELL_TRANSITION,
    **kwargs: Any,
) -> pd.DataFrame:
    """Compute a cell transition matrix and optionally store it for plotting.

    Parameters
    ----------
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    source_groups
        Annotation specification for the source cells. Must not be :obj:`None` if
        ``aggregation_mode = 'annotation'``.
    target_groups
        Annotation specification for the target cells. Must not be :obj:`None` if
        ``aggregation_mode = 'annotation'``.
    aggregation_mode
        Either ``'annotation'`` or ``'cell'``. In ``'cell'`` mode, at least one of ``source_groups`` and
        ``target_groups`` must be given.
    key_added
        Key under which the plotting variables are stored through ``set_plotting_vars``.
        If :obj:`None`, nothing is stored.
    kwargs
        Forwarded to ``_check_argument_compatibility_cell_transition`` and to
        :meth:`_cell_transition_online`. The ``forward`` entry, if present, also decides which side is
        reported as ``'cell'`` in the stored plotting variables.

    Returns
    -------
    The transition matrix returned by :meth:`_cell_transition_online`.

    Raises
    ------
    ValueError
        If the group arguments are incompatible with ``aggregation_mode``.
    KeyError
        If ``key_added`` is not :obj:`None`, ``aggregation_mode = 'cell'`` and ``adata.obs`` already
        contains a column called ``'cell'``.
    """
    if aggregation_mode == "annotation" and (source_groups is None or target_groups is None):
        raise ValueError(
            "If `aggregation_mode='annotation'`, `source_groups` and `target_groups` cannot be `None`."
        )
    if aggregation_mode == "cell" and source_groups is None and target_groups is None:
        raise ValueError("At least one of `source_groups` and `target_groups` must be specified.")

    _check_argument_compatibility_cell_transition(
        source_annotation=source_groups,
        target_annotation=target_groups,
        aggregation_mode=aggregation_mode,
        **kwargs,
    )
    tm = self._cell_transition_online(
        source=source,
        target=target,
        source_groups=source_groups,
        target_groups=target_groups,
        aggregation_mode=aggregation_mode,
        **kwargs,
    )
    if key_added is None:
        return tm

    # `forward` is optional for `_cell_transition_online` (default `False`); mirror that default here
    # instead of failing with a bare `KeyError` after the matrix has already been computed.
    forward = kwargs.get("forward", False)
    if aggregation_mode == "cell" and "cell" in self.adata.obs:
        raise KeyError(f"Aggregation is already present in `adata.obs[{aggregation_mode!r}]`.")

    keep_groups = aggregation_mode == "annotation"
    plot_vars = {
        "source": source,
        "target": target,
        # in 'cell' mode, the side that is not aggregated is reported as "cell"
        "source_groups": source_groups if (keep_groups or not forward) else "cell",
        "target_groups": target_groups if (keep_groups or forward) else "cell",
        "transition_matrix": tm,
    }
    set_plotting_vars(
        self.adata,
        _constants.CELL_TRANSITION,
        key=key_added,
        value=plot_vars,
    )
    return tm
def _annotation_aggregation_transition(
    self,
    annotations_1: list[Any],
    annotations_2: list[Any],
    df: pd.DataFrame,
    func: Callable[..., ArrayLike],
) -> pd.DataFrame:
    """Aggregate moved mass per annotation into a transition matrix.

    For every entry of ``annotations_1``, ``func(subset=entry)`` is evaluated and its output is summed
    within each group of equal labels in ``df``. The sums belonging to the labels in ``annotations_2``
    form one row of the result.

    Parameters
    ----------
    annotations_1
        Labels passed one at a time to ``func`` as ``subset``; they index the rows of the result.
    annotations_2
        Labels indexing the columns of the result. Labels that do not occur in ``df`` get zeros.
    df
        Labels, one per element of the output of ``func``. Entries with a missing label are ignored.
    func
        Callable accepting the keyword ``subset`` and returning one weight per entry of ``df``.

    Returns
    -------
    Data frame of shape ``(len(annotations_1), len(annotations_2))``.
    """
    codes, uniques = pd.factorize(df.values)
    n_uniques = len(uniques)
    # missing labels are encoded as -1, which `np.bincount` rejects, so they are masked out
    labelled = codes >= 0
    codes_labelled = codes[labelled]

    label_to_code = {label: code for code, label in enumerate(uniques)}
    column_codes = np.array([label_to_code.get(label, -1) for label in annotations_2], dtype=int)
    present = column_codes >= 0

    tm_arr = np.zeros((len(annotations_1), len(annotations_2)))
    for row, subset in enumerate(annotations_1):
        # `reshape(-1)` instead of `squeeze()`: a single-cell result must stay 1-dimensional
        weights = np.asarray(func(subset=subset), dtype=float).reshape(-1)
        sums = np.bincount(codes_labelled, weights=weights[labelled], minlength=n_uniques)
        tm_arr[row, present] = sums[column_codes[present]]

    return pd.DataFrame(tm_arr, index=annotations_1, columns=annotations_2)
def _cell_transition_online(
    self,
    key: Optional[str],
    source: K,
    target: Optional[K],
    source_groups: Str_Dict_t,
    target_groups: Str_Dict_t,
    forward: bool = False,  # return value will be row-stochastic if forward=True, else column-stochastic
    aggregation_mode: Literal["annotation", "cell"] = "annotation",
    other_key: Optional[str] = None,
    other_adata: Optional[str] = None,
    batch_size: Optional[int] = None,
    normalize: bool = True,
    **_: Any,
) -> pd.DataFrame:
    """Compute the transition matrix between ``source`` and ``target`` by moving mass.

    Parameters
    ----------
    key
        Filter key handed to ``_get_df_cell_transition`` for :attr:`adata`.
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    source_groups
        Annotation specification for the source cells.
    target_groups
        Annotation specification for the target cells.
    forward
        If :obj:`True`, mass is pushed, otherwise it is pulled.
    aggregation_mode
        ``'annotation'`` aggregates both sides by annotation, ``'cell'`` keeps one side at
        single-cell resolution.
    other_key
        Filter key used for the target cells when ``other_adata`` is given.
    other_adata
        Object holding the target annotations; :attr:`adata` is used if :obj:`None`.
    batch_size
        Number of cells processed at once in ``'cell'`` mode.
    normalize
        If :obj:`True`, every row of the aggregated matrix is divided by its sum.

    Returns
    -------
    The matrix returned by ``_order_transition_matrix``.

    Raises
    ------
    NotImplementedError
        If ``aggregation_mode`` is neither ``'annotation'`` nor ``'cell'``.
    """
    target_adata = self.adata if other_adata is None else other_adata
    target_filter_key = key if other_adata is None else other_key
    unified_key = "new_annotation"

    src_key, src_annotations, src_ordered = _validate_args_cell_transition(self.adata, source_groups)
    tgt_key, tgt_annotations, tgt_ordered = _validate_args_cell_transition(target_adata, target_groups)

    # both frames use the same column name from here on
    df_source = _get_df_cell_transition(self.adata, [src_key], key, source)
    df_source = df_source.rename(columns={src_key: unified_key})
    df_target = _get_df_cell_transition(target_adata, [tgt_key], target_filter_key, target)
    df_target = df_target.rename(columns={tgt_key: unified_key})

    src_verified, tgt_verified = _validate_annotations(
        df_source=df_source,
        df_target=df_target,
        source_annotation_key=unified_key,
        target_annotation_key=unified_key,
        source_annotations=src_annotations,
        target_annotations=tgt_annotations,
        aggregation_mode=aggregation_mode,
        forward=forward,
    )

    # everything that depends on the direction is decided in one place
    if forward:
        move_op = self.push
        df_from, df_to = df_source, df_target[unified_key]
        data_key = src_key
        row_annotations, col_annotations = src_verified, tgt_verified
    else:
        move_op = self.pull
        df_from, df_to = df_target, df_source[unified_key]
        data_key = tgt_key
        row_annotations, col_annotations = tgt_verified, src_verified

    shared_kwargs = {
        "source": source,
        "target": target,
        "normalize": True,
        "return_all": False,
        "scale_by_marginals": False,
        "key_added": None,
    }

    if aggregation_mode == "annotation":
        tm = self._annotation_aggregation_transition(
            annotations_1=row_annotations,
            annotations_2=col_annotations,
            df=df_to,
            func=partial(move_op, data=data_key, split_mass=False, **shared_kwargs),
        )
    elif aggregation_mode == "cell":
        tm = self._cell_aggregation_transition(
            df_from=df_from,
            df_to=df_to,
            annotations=col_annotations,
            batch_size=batch_size,
            func=partial(move_op, data=None, split_mass=True, **shared_kwargs),
        )
    else:
        raise NotImplementedError(f"Aggregation mode `{aggregation_mode!r}` is not yet implemented.")

    if normalize:
        row_sums = tm.sum(axis=1)
        tm = tm.div(row_sums, axis=0)

    return _order_transition_matrix(
        tm=tm,
        source_annotations_verified=src_verified,
        target_annotations_verified=tgt_verified,
        source_annotations_ordered=src_ordered,
        target_annotations_ordered=tgt_ordered,
        forward=forward,
    )
def _annotation_mapping(
    self,
    mapping_mode: Literal["sum", "max"],
    annotation_label: str,
    source: K,
    target: K,
    key: str | None = None,
    forward: bool = True,
    other_adata: str | None = None,
    scale_by_marginals: bool = True,
    batch_size: int | None = None,
    cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame:
    """Transfer the labels in ``annotation_label`` from one side of a transport map to the other.

    Parameters
    ----------
    mapping_mode
        How a label is chosen for a cell:

        - ``'sum'`` - compute a cell-level transition matrix via :meth:`_cell_transition` and pick the
          label with the largest aggregated value.
        - ``'max'`` - pick the label of the single cell with the largest value in the moved
          distribution.
    annotation_label
        Column holding the labels to transfer; also the name of the returned column.
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    key
        Filter key used to select the cells that carry the labels.
    forward
        If :obj:`True`, labels are taken from the source cells, otherwise from the target cells.
    other_adata
        Object holding the target labels; :attr:`adata` is used if :obj:`None`.
    scale_by_marginals
        Passed to :meth:`push`/:meth:`pull` in ``'max'`` mode.
    batch_size
        Number of cells processed at once. If :obj:`None`, all cells are processed together.
    cell_transition_kwargs
        Extra keyword arguments for :meth:`_cell_transition` in ``'sum'`` mode. Explicitly given
        entries take precedence over the values derived from the other arguments.

    Returns
    -------
    Data frame with the single column ``annotation_label`` containing the transferred labels.

    Raises
    ------
    NotImplementedError
        If ``mapping_mode`` is neither ``'sum'`` nor ``'max'``.
    """
    if mapping_mode == "sum":
        cell_transition_kwargs = dict(cell_transition_kwargs)
        cell_transition_kwargs.setdefault("aggregation_mode", "cell")  # aggregation mode should be set to cell
        cell_transition_kwargs.setdefault("key", key)
        cell_transition_kwargs.setdefault("source", source)
        cell_transition_kwargs.setdefault("target", target)
        cell_transition_kwargs.setdefault("other_adata", other_adata)
        cell_transition_kwargs.setdefault("forward", not forward)
        cell_transition_kwargs.setdefault("batch_size", batch_size)
        if forward:
            cell_transition_kwargs.setdefault("source_groups", annotation_label)
            cell_transition_kwargs.setdefault("target_groups", None)
            axis = 0  # rows
        else:
            cell_transition_kwargs.setdefault("source_groups", None)
            cell_transition_kwargs.setdefault("target_groups", annotation_label)
            axis = 1  # columns
        transition: pd.DataFrame = self._cell_transition(**cell_transition_kwargs)
        return transition.idxmax(axis=axis).to_frame(name=annotation_label)

    if mapping_mode == "max":
        if forward:
            # labels live on the source cells; one pulled column per target cell
            labels_adata, labels_filter = self.adata, source
            move_op, solution_axis = self.pull, 1
        else:
            # labels live on the target cells; one pushed column per source cell
            labels_adata = self.adata if other_adata is None else other_adata
            labels_filter = target
            move_op, solution_axis = self.push, 0

        labels_df = _get_df_cell_transition(
            labels_adata,
            annotation_keys=[annotation_label],
            filter_key=key,
            filter_value=labels_filter,
        )
        # Plain array: `argmax` yields positions, and indexing a string-indexed Series with integers
        # relies on a deprecated positional fallback of `Series.__getitem__`.
        labels = labels_df[annotation_label].to_numpy()

        out_len = self.solutions[(source, target)].shape[solution_axis]
        # `max(..., 1)` keeps `range` valid for an empty problem
        step = batch_size if batch_size is not None else max(out_len, 1)

        out: list[Any] = []
        for start in range(0, out_len, step):
            tm_batch: ArrayLike = move_op(
                source=source,
                target=target,
                data=None,
                # never request more cells than are left in the last batch
                subset=(start, min(step, out_len - start)),
                normalize=True,
                return_all=False,
                scale_by_marginals=scale_by_marginals,
                split_mass=True,
                key_added=None,
            )
            best = np.asarray(tm_batch.argmax(0)).reshape(-1)
            out.extend(labels[best])

        return pd.DataFrame(pd.Categorical(out), columns=[annotation_label])

    raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.")
def _sample_from_tmap(
    self,
    source: K,
    target: K,
    n_samples: int,
    source_dim: int,
    target_dim: int,
    batch_size: int = 256,
    account_for_unbalancedness: bool = False,
    interpolation_parameter: Optional[Numeric_t] = None,
    seed: Optional[int] = None,
) -> tuple[list[Any], list[ArrayLike]]:
    """Sample pairs of (source row, target columns) from a transport map.

    Rows are drawn according to the result of applying the map backwards to a mass vector over the
    target. For every distinct sampled row, as many columns as the row was drawn are then sampled
    from the result of applying the map forwards to the indicator of that row.

    Parameters
    ----------
    source
        Key identifying the source distribution.
    target
        Key identifying the target distribution.
    n_samples
        Number of rows to draw.
    source_dim
        Number of source cells.
    target_dim
        Number of target cells.
    batch_size
        Number of distinct rows handled in one forward application of the map.
    account_for_unbalancedness
        Whether to rescale by the column sums of the map. Requires ``interpolation_parameter``.
    interpolation_parameter
        Value in :math:`(0, 1)`; the initial target mass is divided by the column sums raised to
        ``1 - interpolation_parameter``.
    seed
        Seed for the random number generator.

    Returns
    -------
    The distinct sampled rows and, for each of them, an array with the sampled columns.

    Raises
    ------
    ValueError
        If ``account_for_unbalancedness`` is set without ``interpolation_parameter``, or if
        ``interpolation_parameter`` lies outside :math:`(0, 1)`.
    """
    rng = np.random.RandomState(seed)
    if account_for_unbalancedness and interpolation_parameter is None:
        raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.")
    if interpolation_parameter is not None and not (0 < interpolation_parameter < 1):
        raise ValueError(
            f"Expected interpolation parameter to be in interval `(0, 1)`, found `{interpolation_parameter}`."
        )

    mass = np.ones(target_dim)
    if account_for_unbalancedness and interpolation_parameter is not None:
        col_sums = self._apply(
            source=source,
            target=target,
            normalize=True,
            forward=True,
            scale_by_marginals=False,
            explicit_steps=[(source, target)],
        )
        if TYPE_CHECKING:
            assert isinstance(col_sums, np.ndarray)
        # small constant guards against division by zero for empty columns
        col_sums = np.asarray(col_sums).reshape(-1) + 1e-12
        mass = mass / np.power(col_sums, 1 - interpolation_parameter)

    row_probability = np.asarray(
        self._apply(
            source=source,
            target=target,
            data=mass,
            normalize=True,
            forward=False,
            scale_by_marginals=False,
            explicit_steps=[(source, target)],
        )
    ).reshape(-1)

    rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples)
    rows, counts = np.unique(rows_sampled, return_counts=True)

    all_cols_sampled: list[ArrayLike] = []
    for batch in range(0, len(rows), batch_size):
        rows_batch = rows[batch : batch + batch_size]
        counts_batch = counts[batch : batch + batch_size]
        n_batch = len(rows_batch)

        # one indicator column per distinct sampled row
        data = np.zeros((source_dim, n_batch))
        data[rows_batch, range(n_batch)] = 1

        # An explicit 2-D shape is required here: `squeeze()` would collapse a batch that holds a
        # single row to 1-D and break the per-column access below.
        col_p_given_row = np.asarray(
            self._apply(
                source=source,
                target=target,
                data=data,
                normalize=True,
                forward=True,
                scale_by_marginals=False,
                explicit_steps=[(source, target)],
            )
        ).reshape(target_dim, n_batch)
        if account_for_unbalancedness:
            if TYPE_CHECKING:
                assert isinstance(col_sums, np.ndarray)
            col_p_given_row = col_p_given_row / col_sums[:, None]

        all_cols_sampled.extend(
            rng.choice(a=target_dim, size=count, p=probs / probs.sum())
            for count, probs in zip(counts_batch, col_p_given_row.T)
        )
    return rows, all_cols_sampled  # type: ignore[return-value]
def _interpolate_transport(
    self,
    # TODO(@giovp): rename this to 'explicit_steps'; pass to policy.plan() and reintroduce (source_key, target_key)
    path: Sequence[tuple[K, K]],
    scale_by_marginals: bool = True,
    **_: Any,
) -> LinearOperator:
    """Chain the solutions along ``path`` into a single operator.

    Parameters
    ----------
    path
        Keys into :attr:`solutions`, in the order in which they are chained. Must not be empty.
    scale_by_marginals
        Passed through to the ``chain`` method of the first solution.

    Returns
    -------
    Whatever ``chain`` of the first solution returns for the remaining solutions.
    """
    if TYPE_CHECKING:
        assert isinstance(self._policy, SubsetPolicy)
    # TODO(@MUCDK, @giovp, discuss what exactly this function should do, seems like it could be more generic)
    first_step, *later_steps = path
    chain = self.solutions[first_step].chain
    later_solutions = [self.solutions[step] for step in later_steps]
    return chain(later_solutions, scale_by_marginals=scale_by_marginals)
def _flatten(self, data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
    """Scatter per-group values into one array aligned with :attr:`adata`.

    Parameters
    ----------
    data
        Mapping from a group label to the values of the cells carrying that label.
    key
        Column in ``adata.obs`` holding the group labels.

    Returns
    -------
    Array with one entry per cell; cells whose label is not a key of ``data`` stay ``NaN``.
    """
    flat = np.full(len(self.adata), np.nan)
    for label, values in data.items():
        in_group = (self.adata.obs[key] == label).to_numpy()
        flat[in_group] = np.squeeze(values)
    return flat
def _cell_aggregation_transition(
    self,
    df_from: pd.DataFrame,
    df_to: pd.DataFrame,
    annotations: list[Any],
    batch_size: Optional[int],
    func: Callable[..., ArrayLike],
) -> pd.DataFrame:
    """Aggregate the mass moved from single cells by the annotation of the receiving cells.

    The cells of ``df_from`` are processed in batches. For each batch, ``func`` is called with
    ``subset=(start, size)`` and every column of its output is summed within each group of equal
    labels in ``df_to``.

    Parameters
    ----------
    df_from
        Cells from which mass is moved; its index becomes the index of the result.
    df_to
        Labels of the receiving cells, one per row of the output of ``func``. Entries with a missing
        label are ignored.
    annotations
        Labels indexing the columns of the result. Labels that do not occur in ``df_to`` get zeros.
    batch_size
        Number of cells per call to ``func``. If :obj:`None`, all cells are processed at once.
    func
        Callable accepting the keyword ``subset`` and returning an array with one row per entry of
        ``df_to`` and one column per cell of the batch.

    Returns
    -------
    Data frame of shape ``(len(df_from), len(annotations))``.
    """
    codes_to, uniques_to = pd.factorize(df_to.values)
    n_uniques = len(uniques_to)
    # missing labels are encoded as -1, which `np.bincount` rejects, so they are masked out
    labelled = codes_to >= 0
    codes_labelled = codes_to[labelled]

    label_to_code = {label: code for code, label in enumerate(uniques_to)}
    column_codes = np.array([label_to_code.get(label, -1) for label in annotations], dtype=int)
    present = column_codes >= 0

    n_from_cells = len(df_from)
    if batch_size is None:
        # `max(..., 1)` keeps `range` valid when there are no cells to process
        batch_size = max(n_from_cells, 1)

    tm_arr = np.zeros((n_from_cells, len(annotations)))
    for batch_start in range(0, n_from_cells, batch_size):
        n_batch = min(batch_size, n_from_cells - batch_start)
        # result has one row per receiving cell and one column per cell of the batch
        result = np.asarray(func(subset=(batch_start, n_batch)))[labelled]
        for offset in range(n_batch):
            sums = np.bincount(codes_labelled, weights=result[:, offset], minlength=n_uniques)
            tm_arr[batch_start + offset, present] = sums[column_codes[present]]

    return pd.DataFrame(tm_arr, index=df_from.index, columns=annotations)
# adapted from:
# https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392
def compute_feature_correlation(
    self,
    obs_key: str,
    corr_method: Literal["pearson", "spearman"] = "pearson",
    significance_method: Literal["fisher", "perm_test"] = "fisher",
    annotation: Optional[dict[str, Iterable[str]]] = None,
    layer: Optional[str] = None,
    features: Optional[Union[list[str], Literal["human", "mouse", "drosophila"]]] = None,
    confidence_level: float = 0.95,
    n_perms: int = 1000,
    seed: Optional[int] = None,
    **kwargs: Any,
) -> pd.DataFrame:
    """Compute correlation of push-forward or pull-back distribution with features.

    Correlates a feature, e.g., counts of a gene, with probabilities of cells mapped to a set of cells such as
    the push-forward or pull-back distributions.

    .. seealso::
        - TODO: create and link an example

    Parameters
    ----------
    obs_key
        Key in :attr:`~anndata.AnnData.obs` containing the push-forward or pull-back distribution.
    corr_method
        Which type of correlation to compute, either ``'pearson'`` or ``'spearman'``.
    significance_method
        Mode to use when calculating p-values and confidence intervals. Valid options are:

        - ``'fisher'`` - Fisher transformation :cite:`fisher:21`.
        - ``'perm_test'`` - `permutation test <https://en.wikipedia.org/wiki/Permutation_test>`_.
    annotation
        How to subset the data when computing the correlation:

        - :obj:`None` - do not subset the data.
        - :class:`dict` - a dictionary with one key corresponding to a categorical column in
          :attr:`~anndata.AnnData.obs` and values to a subset of categories. Only the first item
          is used.
    layer
        Key in :attr:`~anndata.AnnData.layers` from which to get the expression.
        If :obj:`None`, use :attr:`~anndata.AnnData.X`.
    features
        Features in :class:`~anndata.AnnData` to correlate with
        :attr:`obs['{obs_key}'] <anndata.AnnData.obs>`:

        - :obj:`None` - all features from :attr:`~anndata.AnnData.var` will be taken into account.
        - :obj:`list` - subset of :attr:`~anndata.AnnData.var_names` or :attr:`~anndata.AnnData.obs_names`.
        - ``'human'``, ``'mouse'``, or ``'drosophila'`` - the features are subsetted to the transcription factors
          from :func:`~moscot.utils.data.transcription_factors`.
    confidence_level
        Confidence level for the confidence interval calculation. Must be in interval :math:`[0, 1]`.
    n_perms
        Number of permutations to use when ``significance_method = 'perm_test'``.
    seed
        Random seed when ``significance_method = 'perm_test'``.
    kwargs
        Keyword arguments for parallelization, e.g., ``n_jobs``.

    Returns
    -------
    Dataframe of shape ``(n_features, 5)`` containing the following columns, one for each feature:

    - ``'corr'`` - correlation between the count data and push/pull distributions.
    - ``'pval'`` - calculated p-values for double-sided test.
    - ``'qval'`` - corrected p-values using the `Benjamini-Hochberg
      <https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini%E2%80%93Hochberg_procedure>`_ method
      at :math:`0.05` level.
    - ``'ci_low'`` - lower bound of the ``confidence_level`` correlation confidence interval.
    - ``'ci_high'`` - upper bound of the ``confidence_level`` correlation confidence interval.

    Raises
    ------
    KeyError
        If ``obs_key`` or the key of ``annotation`` is not in :attr:`~anndata.AnnData.obs`, or if no
        transcription factor of the requested organism is among the variable names.
    TypeError
        If the value of ``annotation`` is not iterable.
    ValueError
        If ``annotation`` is an empty dictionary, or if no cell with a non-missing value in
        ``obs_key`` is left after subsetting.
    """
    if obs_key not in self.adata.obs:
        raise KeyError(f"Unable to access data in `adata.obs[{obs_key!r}]`.")

    if annotation is not None:
        if not annotation:
            # without this guard, `next(iter(...))` below would leak a bare `StopIteration`
            raise ValueError("`annotation` expected to be dictionary of length 1, found an empty dictionary.")
        annotation_key, annotation_vals = next(iter(annotation.items()))
        if annotation_key not in self.adata.obs:
            raise KeyError(f"Unable to access data in `adata.obs[{annotation_key!r}]`.")
        if not isinstance(annotation_vals, Iterable):
            raise TypeError("`annotation` expected to be dictionary of length 1 with value being an iterable.")
        adata = self.adata[self.adata.obs[annotation_key].isin(annotation_vals)]
    else:
        adata = self.adata

    # cells without a value in `obs_key` cannot contribute to the correlation
    adata = adata[~adata.obs[obs_key].isnull()]
    if not adata.n_obs:
        raise ValueError(f"`adata.obs[{obs_key!r}]` only contains NaN values.")
    distribution: pd.DataFrame = adata.obs[[obs_key]]

    if isinstance(features, str):
        tfs = transcription_factors(organism=features)
        features = list(set(tfs).intersection(adata.var_names))
        if not features:
            raise KeyError("No common transcription factors found in the data base.")
    elif features is None:
        features = list(self.adata.var_names)

    return _correlation_test(
        X=sc.get.obs_df(adata, keys=features, layer=layer).values,
        Y=distribution,
        feature_names=features,
        corr_method=corr_method,
        significance_method=significance_method,
        confidence_level=confidence_level,
        n_perms=n_perms,
        seed=seed,
        **kwargs,
    )
[docs]
def compute_entropy(
    self,
    source: K,
    target: K,
    forward: bool = True,
    key_added: Optional[str] = "conditional_entropy",
    batch_size: Optional[int] = None,
    c: float = 1e-10,
    **kwargs: Any,
) -> Optional[pd.DataFrame]:
    """Compute the conditional entropy per cell.

    The conditional entropy reflects the uncertainty of the mapping of a single cell.

    Parameters
    ----------
    source
        Source key.
    target
        Target key.
    forward
        If `True`, computes the conditional entropy of a cell in the source distribution, else the
        conditional entropy of a cell in the target distribution.
    key_added
        Key in :attr:`~anndata.AnnData.obs` where the entropy is stored.
    batch_size
        Batch size for the computation of the entropy. If :obj:`None`, the entire dataset is used.
    c
        Constant added to each row of the transport matrix to avoid numerical instability.
    kwargs
        Kwargs for :func:`~scipy.stats.entropy`.

    Returns
    -------
    :obj:`None` if ``key_added`` is not None. Otherwise, returns a data frame of shape ``(n_cells, 1)`` containing
    the conditional entropy per cell.
    """
    from scipy import stats

    filter_value = source if forward else target
    column = key_added if key_added is not None else "entropy"
    obs_names = self.adata[self.adata.obs[self._policy.key] == filter_value, :].obs_names
    # explicit float dtype: an untyped empty frame would store the entropies as `object`
    df = pd.DataFrame(index=obs_names, columns=[column], dtype=float)

    n_cells = len(df)
    # `max(..., 1)` keeps `range` valid when no cell matches `filter_value`
    step = batch_size if batch_size is not None else max(n_cells, 1)
    func = self.push if forward else self.pull

    for start in range(0, n_cells, step):
        # never request more cells than are left in the last batch
        n_batch = min(step, n_cells - start)
        cond_dists = func(
            source=source,
            target=target,
            data=None,
            subset=(start, n_batch),
            normalize=True,
            return_all=False,
            scale_by_marginals=False,
            split_mass=True,
            key_added=None,
        )
        df.iloc[start : start + n_batch, 0] = stats.entropy(cond_dists + c, **kwargs)

    if key_added is None:
        return df
    # cells outside of the selected distribution are left as NaN by index alignment
    self.adata.obs[key_added] = df[key_added]
    return None