Source code for moscot.base.problems._mixins

from __future__ import annotations

import types
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Iterable,
    Literal,
    Mapping,
    Optional,
    Protocol,
    Sequence,
    Union,
)

import numpy as np
import pandas as pd
from scipy.sparse.linalg import LinearOperator

import scanpy as sc
from anndata import AnnData

from moscot import _constants
from moscot._types import ArrayLike, Numeric_t, Str_Dict_t
from moscot.base.output import BaseSolverOutput
from moscot.base.problems._utils import (
    _check_argument_compatibility_cell_transition,
    _correlation_test,
    _get_df_cell_transition,
    _order_transition_matrix,
    _validate_annotations,
    _validate_args_cell_transition,
)
from moscot.base.problems.compound_problem import ApplyOutput_t, B, K
from moscot.plotting._utils import set_plotting_vars
from moscot.utils.data import transcription_factors
from moscot.utils.subset_policy import SubsetPolicy

__all__ = ["AnalysisMixin"]


class AnalysisMixinProtocol(Protocol[K, B]):
    """Protocol class."""

    adata: AnnData
    _policy: SubsetPolicy[K]
    solutions: dict[tuple[K, K], BaseSolverOutput]
    problems: dict[tuple[K, K], B]

    def _apply(
        self,
        data: Optional[Union[str, ArrayLike]] = None,
        source: Optional[K] = None,
        target: Optional[K] = None,
        forward: bool = True,
        return_all: bool = False,
        scale_by_marginals: bool = False,
        **kwargs: Any,
    ) -> ApplyOutput_t[K]: ...

    def _interpolate_transport(
        self: AnalysisMixinProtocol[K, B],
        path: Sequence[tuple[K, K]],
        scale_by_marginals: bool = True,
    ) -> LinearOperator: ...

    def _flatten(
        self: AnalysisMixinProtocol[K, B],
        data: dict[K, ArrayLike],
        *,
        key: Optional[str],
    ) -> ArrayLike: ...

    def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
        """Push distribution."""
        ...

    def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
        """Pull distribution."""
        ...

    def _cell_transition(
        self: AnalysisMixinProtocol[K, B],
        source: K,
        target: K,
        source_groups: Str_Dict_t,
        target_groups: Str_Dict_t,
        aggregation_mode: Literal["annotation", "cell"] = "annotation",
        key_added: Optional[str] = _constants.CELL_TRANSITION,
        **kwargs: Any,
    ) -> pd.DataFrame: ...

    def _cell_transition_online(
        self: AnalysisMixinProtocol[K, B],
        key: Optional[str],
        source: K,
        target: K,
        source_groups: Str_Dict_t,
        target_groups: Str_Dict_t,
        forward: bool = False,  # return value will be row-stochastic if forward=True, else column-stochastic
        aggregation_mode: Literal["annotation", "cell"] = "annotation",
        other_key: Optional[str] = None,
        other_adata: Optional[str] = None,
        batch_size: Optional[int] = None,
        normalize: bool = True,
    ) -> pd.DataFrame: ...
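
    # Note on `forward` (hedged illustration; the numbers below are made up): with
    # source annotations {A, B} and target annotations {X, Y}, `forward=True`
    # produces a row-stochastic transition matrix, e.g.
    #
    #          X    Y
    #     A  0.9  0.1   <- each row sums to 1: where does the mass from A end up?
    #     B  0.3  0.7
    #
    # while `forward=False` produces a column-stochastic matrix instead (each
    # column sums to 1: where did the mass arriving at X come from?).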

    def _annotation_mapping(
        self: AnalysisMixinProtocol[K, B],
        mapping_mode: Literal["sum", "max"],
        annotation_label: str,
        source: K,
        target: K,
        key: str | None = None,
        forward: bool = True,
        other_adata: Optional[str] = None,
        scale_by_marginals: bool = True,
        batch_size: int | None = None,
        cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
    ) -> pd.DataFrame: ...

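
# Note (hedged): `AnalysisMixinProtocol` exists purely for static typing. The
# mixin's methods annotate `self` with the protocol rather than with the mixin
# class itself, so a type checker can verify access to attributes such as
# `adata`, `solutions`, and `_policy`, which the concrete problem class
# provides. A minimal sketch of the pattern (`MyProblem` and `MyOTProblem` are
# illustrative names, not part of this module):
#
#     class MyProblem(AnalysisMixin[str, MyOTProblem]):
#         adata: AnnData
#         _policy: SubsetPolicy[str]
#         solutions: dict[tuple[str, str], BaseSolverOutput]
#         problems: dict[tuple[str, str], MyOTProblem]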

class AnalysisMixin(Generic[K, B]):
    """Base Analysis Mixin."""

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

    def _cell_transition(
        self: AnalysisMixinProtocol[K, B],
        source: K,
        target: K,
        source_groups: Str_Dict_t,
        target_groups: Str_Dict_t,
        aggregation_mode: Literal["annotation", "cell"] = "annotation",
        key_added: Optional[str] = _constants.CELL_TRANSITION,
        **kwargs: Any,
    ) -> pd.DataFrame:
        if aggregation_mode == "annotation" and (source_groups is None or target_groups is None):
            raise ValueError(
                "If `aggregation_mode='annotation'`, `source_groups` and `target_groups` cannot be `None`."
            )
        if aggregation_mode == "cell" and source_groups is None and target_groups is None:
            raise ValueError("At least one of `source_groups` and `target_groups` must be specified.")
        _check_argument_compatibility_cell_transition(
            source_annotation=source_groups,
            target_annotation=target_groups,
            aggregation_mode=aggregation_mode,
            **kwargs,
        )
        tm = self._cell_transition_online(
            source=source,
            target=target,
            source_groups=source_groups,
            target_groups=target_groups,
            aggregation_mode=aggregation_mode,
            **kwargs,
        )
        if key_added is not None:
            forward = kwargs.pop("forward")
            if aggregation_mode == "cell" and "cell" in self.adata.obs:
                raise KeyError(f"Aggregation is already present in `adata.obs[{aggregation_mode!r}]`.")
            plot_vars = {
                "source": source,
                "target": target,
                "source_groups": source_groups if (not forward or aggregation_mode == "annotation") else "cell",
                "target_groups": target_groups if (forward or aggregation_mode == "annotation") else "cell",
                "transition_matrix": tm,
            }
            set_plotting_vars(
                self.adata,
                _constants.CELL_TRANSITION,
                key=key_added,
                value=plot_vars,
            )
        return tm

    def _cell_transition_online(
        self: AnalysisMixinProtocol[K, B],
        key: Optional[str],
        source: K,
        target: K,
        source_groups: Str_Dict_t,
        target_groups: Str_Dict_t,
        forward: bool = False,  # return value will be row-stochastic if forward=True, else column-stochastic
        aggregation_mode: Literal["annotation", "cell"] = "annotation",
        other_key: Optional[str] = None,
        other_adata: Optional[str] = None,
        batch_size: Optional[int] = None,
        normalize: bool = True,
        **_: Any,
    ) -> pd.DataFrame:
        source_annotation_key, source_annotations, source_annotations_ordered = _validate_args_cell_transition(
            self.adata, source_groups
        )
        target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition(
            self.adata if other_adata is None else other_adata, target_groups
        )
        df_source = _get_df_cell_transition(
            self.adata,
            [source_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key],
            key,
            source,
        )
        df_target = _get_df_cell_transition(
            self.adata if other_adata is None else other_adata,
            [target_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key],
            key if other_adata is None else other_key,
            target,
        )
        source_annotations_verified, target_annotations_verified = _validate_annotations(
            df_source=df_source,
            df_target=df_target,
            source_annotation_key=source_annotation_key,
            target_annotation_key=target_annotation_key,
            source_annotations=source_annotations,
            target_annotations=target_annotations,
            aggregation_mode=aggregation_mode,
            forward=forward,
        )

        if aggregation_mode == "annotation":
            df_target["distribution"] = 0
            df_source["distribution"] = 0
            tm = pd.DataFrame(
                np.zeros((len(source_annotations_verified), len(target_annotations_verified))),
                index=source_annotations_verified,
                columns=target_annotations_verified,
            )
            if forward:
                tm = self._annotation_aggregation_transition(  # type: ignore[attr-defined]
                    source=source,
                    target=target,
                    annotation_key=source_annotation_key,
                    annotations_1=source_annotations_verified,
                    annotations_2=target_annotations_verified,
                    df=df_target,
                    tm=tm,
                    forward=True,
                )
            else:
                tm = self._annotation_aggregation_transition(  # type: ignore[attr-defined]
                    source=source,
                    target=target,
                    annotation_key=target_annotation_key,
                    annotations_1=target_annotations_verified,
                    annotations_2=source_annotations_verified,
                    df=df_source,
                    tm=tm,
                    forward=False,
                )
        elif aggregation_mode == "cell":
            tm = pd.DataFrame(columns=target_annotations_verified if forward else source_annotations_verified)
            if forward:
                tm = self._cell_aggregation_transition(  # type: ignore[attr-defined]
                    source=source,
                    target=target,
                    annotation_key=target_annotation_key,
                    annotations_1=source_annotations_verified,
                    annotations_2=target_annotations_verified,
                    df_1=df_target,
                    df_2=df_source,
                    tm=tm,
                    batch_size=batch_size,
                    forward=True,
                )
            else:
                tm = self._cell_aggregation_transition(  # type: ignore[attr-defined]
                    source=source,
                    target=target,
                    annotation_key=source_annotation_key,
                    annotations_1=target_annotations_verified,
                    annotations_2=source_annotations_verified,
                    df_1=df_source,
                    df_2=df_target,
                    tm=tm,
                    batch_size=batch_size,
                    forward=False,
                )
        else:
            raise NotImplementedError(f"Aggregation mode `{aggregation_mode!r}` is not yet implemented.")

        if normalize:
            tm = tm.div(tm.sum(axis=1), axis=0)
        return _order_transition_matrix(
            tm=tm,
            source_annotations_verified=source_annotations_verified,
            target_annotations_verified=target_annotations_verified,
            source_annotations_ordered=source_annotations_ordered,
            target_annotations_ordered=target_annotations_ordered,
            forward=forward,
        )

    def _annotation_mapping(
        self: AnalysisMixinProtocol[K, B],
        mapping_mode: Literal["sum", "max"],
        annotation_label: str,
        source: K,
        target: K,
        key: str | None = None,
        forward: bool = True,
        other_adata: str | None = None,
        scale_by_marginals: bool = True,
        batch_size: int | None = None,
        cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
    ) -> pd.DataFrame:
        if mapping_mode == "sum":
            cell_transition_kwargs = dict(cell_transition_kwargs)
            cell_transition_kwargs.setdefault("aggregation_mode", "cell")  # aggregation mode should be set to cell
            cell_transition_kwargs.setdefault("key", key)
            cell_transition_kwargs.setdefault("source", source)
            cell_transition_kwargs.setdefault("target", target)
            cell_transition_kwargs.setdefault("other_adata", other_adata)
            cell_transition_kwargs.setdefault("forward", not forward)
            cell_transition_kwargs.setdefault("batch_size", batch_size)
            if forward:
                cell_transition_kwargs.setdefault("source_groups", annotation_label)
                cell_transition_kwargs.setdefault("target_groups", None)
                axis = 0  # rows
            else:
                cell_transition_kwargs.setdefault("source_groups", None)
                cell_transition_kwargs.setdefault("target_groups", annotation_label)
                axis = 1  # columns
            out: pd.DataFrame = self._cell_transition(**cell_transition_kwargs)
            return out.idxmax(axis=axis).to_frame(name=annotation_label)
        if mapping_mode == "max":
            out = []
            if forward:
                source_df = _get_df_cell_transition(
                    self.adata,
                    annotation_keys=[annotation_label],
                    filter_key=key,
                    filter_value=source,
                )
                out_len = self.solutions[(source, target)].shape[1]
                batch_size = batch_size if batch_size is not None else out_len
                for batch in range(0, out_len, batch_size):
                    tm_batch: ArrayLike = self.pull(
                        source=source,
                        target=target,
                        data=None,
                        subset=(batch, batch_size),
                        normalize=True,
                        return_all=False,
                        scale_by_marginals=scale_by_marginals,
                        split_mass=True,
                        key_added=None,
                    )
                    v = np.array(tm_batch.argmax(0))
                    out.extend(source_df[annotation_label][v[i]] for i in range(len(v)))
            else:
                target_df = _get_df_cell_transition(
                    self.adata if other_adata is None else other_adata,
                    annotation_keys=[annotation_label],
                    filter_key=key,
                    filter_value=target,
                )
                out_len = self.solutions[(source, target)].shape[0]
                batch_size = batch_size if batch_size is not None else out_len
                for batch in range(0, out_len, batch_size):
                    tm_batch: ArrayLike = self.push(  # type: ignore[no-redef]
                        source=source,
                        target=target,
                        data=None,
                        subset=(batch, batch_size),
                        normalize=True,
                        return_all=False,
                        scale_by_marginals=scale_by_marginals,
                        split_mass=True,
                        key_added=None,
                    )
                    v = np.array(tm_batch.argmax(0))
                    out.extend(target_df[annotation_label][v[i]] for i in range(len(v)))
            categories = pd.Categorical(out)
            return pd.DataFrame(categories, columns=[annotation_label])
        raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.")

    def _sample_from_tmap(
        self: AnalysisMixinProtocol[K, B],
        source: K,
        target: K,
        n_samples: int,
        source_dim: int,
        target_dim: int,
        batch_size: int = 256,
        account_for_unbalancedness: bool = False,
        interpolation_parameter: Optional[Numeric_t] = None,
        seed: Optional[int] = None,
    ) -> tuple[list[Any], list[ArrayLike]]:
        rng = np.random.RandomState(seed)
        if account_for_unbalancedness and interpolation_parameter is None:
            raise ValueError("When accounting for unbalancedness, the interpolation parameter must be provided.")
        if interpolation_parameter is not None and not (0 < interpolation_parameter < 1):
            raise ValueError(
                f"Expected interpolation parameter to be in interval `(0, 1)`, found `{interpolation_parameter}`."
            )
        mass = np.ones(target_dim)
        if account_for_unbalancedness and interpolation_parameter is not None:
            col_sums = self._apply(
                source=source,
                target=target,
                normalize=True,
                forward=True,
                scale_by_marginals=False,
                explicit_steps=[(source, target)],
            )
            if TYPE_CHECKING:
                assert isinstance(col_sums, np.ndarray)
            col_sums = np.asarray(col_sums).squeeze() + 1e-12
            mass = mass / np.power(col_sums, 1 - interpolation_parameter)

        row_probability = np.asarray(
            self._apply(
                source=source,
                target=target,
                data=mass,
                normalize=True,
                forward=False,
                scale_by_marginals=False,
                explicit_steps=[(source, target)],
            )
        ).squeeze()
        rows_sampled = rng.choice(source_dim, p=row_probability / row_probability.sum(), size=n_samples)
        rows, counts = np.unique(rows_sampled, return_counts=True)
        all_cols_sampled: list[str] = []
        for batch in range(0, len(rows), batch_size):
            rows_batch = rows[batch : batch + batch_size]
            counts_batch = counts[batch : batch + batch_size]
            data = np.zeros((source_dim, len(rows_batch)))
            data[rows_batch, range(len(rows_batch))] = 1

            col_p_given_row = np.asarray(
                self._apply(
                    source=source,
                    target=target,
                    data=data,
                    normalize=True,
                    forward=True,
                    scale_by_marginals=False,
                    explicit_steps=[(source, target)],
                )
            ).squeeze()
            if account_for_unbalancedness:
                if TYPE_CHECKING:
                    assert isinstance(col_sums, np.ndarray)
                col_p_given_row = col_p_given_row / col_sums[:, None]

            cols_sampled = [
                rng.choice(a=target_dim, size=counts_batch[i], p=col_p_given_row[:, i] / col_p_given_row[:, i].sum())
                for i in range(len(rows_batch))
            ]
            all_cols_sampled.extend(cols_sampled)
        return rows, all_cols_sampled  # type: ignore[return-value]

    def _interpolate_transport(
        self: AnalysisMixinProtocol[K, B],
        # TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key)
        path: Sequence[tuple[K, K]],
        scale_by_marginals: bool = True,
        **_: Any,
    ) -> LinearOperator:
        """Interpolate transport matrix."""
        if TYPE_CHECKING:
            assert isinstance(self._policy, SubsetPolicy)
        # TODO(@MUCDK, @giovp, discuss what exactly this function should do, seems like it could be more generic)
        fst, *rest = path
        return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)

    def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
        tmp = np.full(len(self.adata), np.nan)
        for k, v in data.items():
            mask = self.adata.obs[key] == k
            tmp[mask] = np.squeeze(v)
        return tmp

    def _annotation_aggregation_transition(
        self: AnalysisMixinProtocol[K, B],
        source: K,
        target: K,
        annotation_key: str,
        annotations_1: list[Any],
        annotations_2: list[Any],
        df: pd.DataFrame,
        tm: pd.DataFrame,
        forward: bool,
    ) -> pd.DataFrame:
        if not forward:
            tm = tm.T
        func = self.push if forward else self.pull
        for subset in annotations_1:
            result = func(  # TODO(@MUCDK) check how to make compatible with all policies
                source=source,
                target=target,
                data=annotation_key,
                subset=subset,
                normalize=True,
                return_all=False,
                scale_by_marginals=False,
                split_mass=False,
                key_added=None,
            )
            df["distribution"] = result
            cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True)
            cell_dist /= cell_dist.sum()
            tm.loc[subset, :] = [
                cell_dist.loc[annotation, "distribution"] if annotation in cell_dist.distribution.index else 0
                for annotation in annotations_2
            ]
        return tm

    def _cell_aggregation_transition(
        self: AnalysisMixinProtocol[K, B],
        source: str,
        target: str,
        annotation_key: str,
        # TODO(MUCDK): unused variables, del below
        annotations_1: list[Any],
        annotations_2: list[Any],
        df_1: pd.DataFrame,
        df_2: pd.DataFrame,
        tm: pd.DataFrame,
        batch_size: Optional[int],
        forward: bool,
    ) -> pd.DataFrame:
        func = self.push if forward else self.pull
        if batch_size is None:
            batch_size = len(df_2)
        for batch in range(0, len(df_2), batch_size):
            result = func(  # TODO(@MUCDK) check how to make compatible with all policies
                source=source,
                target=target,
                data=None,
                subset=(batch, batch_size),
                normalize=True,
                return_all=False,
                scale_by_marginals=False,
                split_mass=True,
                key_added=None,
            )
            current_cells = df_2.iloc[range(batch, min(batch + batch_size, len(df_2)))].index.tolist()
            df_1.loc[:, current_cells] = result
            to_app = df_1[df_1[annotation_key].isin(annotations_2)].groupby(annotation_key).sum().transpose()
            tm = pd.concat([tm, to_app], verify_integrity=True, axis=0)
            df_1 = df_1.drop(current_cells, axis=1)
        return tm

    # adapted from:
    # https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392
    def compute_feature_correlation(
        self: AnalysisMixinProtocol[K, B],
        obs_key: str,
        corr_method: Literal["pearson", "spearman"] = "pearson",
        significance_method: Literal["fisher", "perm_test"] = "fisher",
        annotation: Optional[dict[str, Iterable[str]]] = None,
        layer: Optional[str] = None,
        features: Optional[Union[list[str], Literal["human", "mouse", "drosophila"]]] = None,
        confidence_level: float = 0.95,
        n_perms: int = 1000,
        seed: Optional[int] = None,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Compute correlation of push-forward or pull-back distribution with features.

        Correlates a feature, e.g., counts of a gene, with the probabilities of cells mapped to a set of cells,
        such as the push-forward or pull-back distribution.

        .. seealso::
            - TODO: create and link an example

        Parameters
        ----------
        obs_key
            Key in :attr:`~anndata.AnnData.obs` containing the push-forward or pull-back distribution.
        corr_method
            Which type of correlation to compute, either ``'pearson'`` or ``'spearman'``.
        significance_method
            Mode to use when calculating p-values and confidence intervals. Valid options are:

            - ``'fisher'`` - Fisher transformation :cite:`fisher:21`.
            - ``'perm_test'`` - `permutation test <https://en.wikipedia.org/wiki/Permutation_test>`_.
        annotation
            How to subset the data when computing the correlation:

            - :obj:`None` - do not subset the data.
            - :class:`dict` - a dictionary of length 1 whose key is a categorical column in
              :attr:`~anndata.AnnData.obs` and whose value is a subset of its categories.
        layer
            Key in :attr:`~anndata.AnnData.layers` from which to get the expression.
            If :obj:`None`, use :attr:`~anndata.AnnData.X`.
        features
            Features in :class:`~anndata.AnnData` to correlate with :attr:`obs['{obs_key}'] <anndata.AnnData.obs>`:

            - :obj:`None` - all features from :attr:`~anndata.AnnData.var` will be taken into account.
            - :class:`list` - subset of :attr:`~anndata.AnnData.var_names` or columns of :attr:`~anndata.AnnData.obs`.
            - ``'human'``, ``'mouse'``, or ``'drosophila'`` - restrict the features to the transcription factors
              from :func:`~moscot.utils.data.transcription_factors`.
        confidence_level
            Confidence level for the confidence interval calculation. Must be in interval :math:`[0, 1]`.
        n_perms
            Number of permutations to use when ``significance_method = 'perm_test'``.
        seed
            Random seed when ``significance_method = 'perm_test'``.
        kwargs
            Keyword arguments for parallelization, e.g., ``n_jobs``.

        Returns
        -------
        Dataframe of shape ``(n_features, 5)``, one row per feature, containing the following columns:

        - ``'corr'`` - correlation between the count data and push/pull distributions.
        - ``'pval'`` - calculated p-values for double-sided test.
        - ``'qval'`` - corrected p-values using the `Benjamini-Hochberg
          <https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini%E2%80%93Hochberg_procedure>`_
          method at :math:`0.05` level.
        - ``'ci_low'`` - lower bound of the ``confidence_level`` correlation confidence interval.
        - ``'ci_high'`` - upper bound of the ``confidence_level`` correlation confidence interval.
        """
        if obs_key not in self.adata.obs:
            raise KeyError(f"Unable to access data in `adata.obs[{obs_key!r}]`.")

        if annotation is not None:
            annotation_key, annotation_vals = next(iter(annotation.items()))
            if annotation_key not in self.adata.obs:
                raise KeyError(f"Unable to access data in `adata.obs[{annotation_key!r}]`.")
            if not isinstance(annotation_vals, Iterable):
                raise TypeError("`annotation` expected to be dictionary of length 1 with value being an iterable.")
            adata = self.adata[self.adata.obs[annotation_key].isin(annotation_vals)]
        else:
            adata = self.adata

        adata = adata[~adata.obs[obs_key].isnull()]
        if not adata.n_obs:
            raise ValueError(f"`adata.obs[{obs_key!r}]` only contains NaN values.")
        distribution: pd.DataFrame = adata.obs[[obs_key]]

        if isinstance(features, str):
            tfs = transcription_factors(organism=features)
            features = list(set(tfs).intersection(adata.var_names))
            if not features:
                raise KeyError("No common transcription factors found in the database.")
        elif features is None:
            features = list(self.adata.var_names)

        return _correlation_test(
            X=sc.get.obs_df(adata, keys=features, layer=layer).values,
            Y=distribution,
            feature_names=features,
            corr_method=corr_method,
            significance_method=significance_method,
            confidence_level=confidence_level,
            n_perms=n_perms,
            seed=seed,
            **kwargs,
        )

    def compute_entropy(
        self: AnalysisMixinProtocol[K, B],
        source: K,
        target: K,
        forward: bool = True,
        key_added: Optional[str] = "conditional_entropy",
        batch_size: Optional[int] = None,
        c: float = 1e-10,
        **kwargs: Any,
    ) -> Optional[pd.DataFrame]:
        """Compute the conditional entropy per cell.

        The conditional entropy reflects the uncertainty of the mapping of a single cell.

        Parameters
        ----------
        source
            Source key.
        target
            Target key.
        forward
            If :obj:`True`, compute the conditional entropy of a cell in the source distribution,
            else the conditional entropy of a cell in the target distribution.
        key_added
            Key in :attr:`~anndata.AnnData.obs` where the entropy is stored.
        batch_size
            Batch size for the computation of the entropy. If :obj:`None`, the entire dataset is used.
        c
            Constant added to each row of the transport matrix to avoid numerical instability.
        kwargs
            Keyword arguments for :func:`~scipy.stats.entropy`.

        Returns
        -------
        :obj:`None` if ``key_added`` is not :obj:`None`. Otherwise, returns a data frame of shape ``(n_cells, 1)``
        containing the conditional entropy per cell.
        """
        from scipy import stats

        filter_value = source if forward else target
        df = pd.DataFrame(
            index=self.adata[self.adata.obs[self._policy.key] == filter_value, :].obs_names,
            columns=[key_added] if key_added is not None else ["entropy"],
        )
        batch_size = batch_size if batch_size is not None else len(df)
        func = self.push if forward else self.pull
        for batch in range(0, len(df), batch_size):
            cond_dists = func(
                source=source,
                target=target,
                data=None,
                subset=(batch, batch_size),
                normalize=True,
                return_all=False,
                scale_by_marginals=False,
                split_mass=True,
                key_added=None,
            )
            df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs)  # type: ignore[operator]
        if key_added is not None:
            self.adata.obs[key_added] = df
        return df if key_added is None else None
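
    # Usage sketch (hedged; `prob` and the keys `0`/`1` are illustrative): the
    # per-cell uncertainty of the forward mapping can be written to `adata.obs`
    # and inspected afterwards:
    #
    #     prob.compute_entropy(source=0, target=1, forward=True,
    #                          key_added="conditional_entropy")
    #     prob.adata.obs["conditional_entropy"].describe()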