import networkx as nx
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from sklearn.neighbors import NearestCentroid
from stlearn.pp import neighbors
from stlearn.spatial.clustering import localization
from stlearn.spatial.morphology import adjust
from stlearn.types import _METHOD
def pseudotime(
    adata: AnnData,
    use_label: str = "leiden",
    eps: float = 20,
    n_neighbors: int = 25,
    use_rep: str = "X_pca",
    threshold: float = 0.01,
    radius: int = 50,
    method: _METHOD = "mean",
    threshold_spots: int = 5,
    use_sme: bool = False,
    reverse: bool = False,
    pseudotime_key: str = "dpt_pseudotime",
    max_nodes: int = 4,
    run_knn: bool = False,
    copy: bool = False,
) -> AnnData | None:
    """\
    Perform pseudotime analysis.

    Splits the clusters in ``use_label`` into spatial sub-clusters, runs PAGA
    and a diffusion map, builds a sub-cluster connectivity graph from the
    thresholded PAGA connectivities, computes diffusion pseudotime and stores
    the available trajectory paths between clusters.

    Parameters
    ----------
    adata:
        Annotated data matrix.
    use_label:
        Use label result of cluster method (a categorical column of
        ``adata.obs``).
    eps:
        The maximum distance between two samples for one to be considered as
        in the neighborhood of the other. This is not a maximum bound on the
        distances of points within a cluster. This is the most important DBSCAN
        parameter to choose appropriately for your data set and distance function.
    n_neighbors:
        Number of neighbours for the kNN graph. Only used when ``run_knn``.
    use_rep:
        Representation used to build the kNN graph. Only used when ``run_knn``.
    threshold:
        Threshold to find the significant connection for PAGA graph; PAGA
        connectivities below this value are treated as absent.
    radius:
        Radius to adjust data for diffusion map. Only used when ``use_sme``.
    method:
        Method to adjust the data. Only used when ``use_sme``.
    threshold_spots:
        Sub-clusters with at most this many spots are left out of
        ``split_node`` and therefore out of the sub-cluster graph.
    use_sme:
        Use adjusted feature by SME normalization or not.
    reverse:
        Reverse the pseudotime score (``adata.obs[pseudotime_key]`` is replaced
        by ``1 - pseudotime``).
    pseudotime_key:
        Column of ``adata.obs`` holding the pseudotime used for reversal and
        for orienting trajectory paths. The default presumably matches the
        column written by ``sc.tl.dpt`` — a different key must already exist
        in ``adata.obs``.
    max_nodes:
        Only paths with fewer than this many nodes are stored.
    run_knn:
        Recompute the neighbourhood graph before running PAGA.
    copy:
        Return a copy instead of writing to adata.

    Notes
    -----
    Each run clears any previously computed values for: X_diffmap,
    X_draw_graph_fr, X_diffmap_morphology, split_node, global_graph,
    centroid_dict, available_paths, threshold_spots, and sub_cluster_labels.
    Results are written to ``adata.uns`` under ``split_node``,
    ``threshold_spots``, ``global_graph``, ``centroid_dict`` and
    ``available_paths``.

    Returns
    -------
    The updated copy of ``adata`` when ``copy`` is True, otherwise ``None``
    (``adata`` is updated in place).
    """
    # Honour ``copy``: every write below goes to the copy, never the caller's object.
    adata = adata.copy() if copy else adata

    # Drop results of a previous run so stale entries never leak into this one.
    keys_obsm = ["X_diffmap", "X_draw_graph_fr", "X_diffmap_morphology"]
    keys_uns = [
        "split_node",
        "global_graph",
        "centroid_dict",
        "available_paths",
        "threshold_spots",
    ]
    keys_obs = ["sub_cluster_labels"]
    for key in keys_obsm:
        adata.obsm.pop(key, None)
    for key in keys_uns:
        adata.uns.pop(key, None)
    for key in keys_obs:
        if key in adata.obs.columns:
            del adata.obs[key]

    # Presumably populates obs["sub_cluster_labels"] and
    # uns[use_label + "_index_dict"], both of which are read below — confirm.
    localization(adata, use_label=use_label, eps=eps)

    if run_knn:
        neighbors(adata, n_neighbors=n_neighbors, use_rep=use_rep, random_state=0)

    sc.tl.paga(adata, groups=use_label)
    # Denoising the graph
    sc.tl.diffmap(adata)
    if use_sme:
        adjust(adata, use_data="X_diffmap", radius=radius, method=method)
        adata.obsm["X_diffmap"] = adata.obsm["X_diffmap_morphology"]

    # Cluster-level PAGA connectivity with weak connections zeroed out.
    cnt_matrix = adata.uns["paga"]["connectivities"].toarray()
    cnt_matrix[cnt_matrix < threshold] = 0.0
    cnt_matrix = pd.DataFrame(cnt_matrix)

    # Map each cluster (by the index given in ``cat_inds``) to the
    # sub-clusters it contains that have more than ``threshold_spots`` spots.
    cat_inds = adata.uns[use_label + "_index_dict"]
    sub_labels = adata.obs["sub_cluster_labels"]
    # Count spots per sub-cluster once instead of filtering the frame for
    # every sub-cluster; lookups use str(sub) exactly like the original filter.
    sub_sizes = sub_labels.value_counts().to_dict()
    split_node = {}
    for label in adata.obs[use_label].cat.categories:
        subs_in_label = sub_labels[adata.obs[use_label] == label].unique()
        split_node[cat_inds[int(label)]] = [
            sub for sub in subs_in_label if sub_sizes.get(str(sub), 0) > threshold_spots
        ]
    adata.uns["threshold_spots"] = threshold_spots
    # split_node has string keys for rest of code/plotting (names as strings)
    adata.uns["split_node"] = {str(k): v for k, v in split_node.items()}

    # Repeat each cluster's row/column once per retained sub-cluster so the
    # cluster connectivity can be relabelled as a sub-cluster connectivity.
    n_clusters = len(cnt_matrix)
    replicate_list = np.repeat(
        np.arange(n_clusters),
        np.array([len(split_node[i]) for i in range(n_clusters)], dtype=int),
    )
    cnt_matrix = cnt_matrix.loc[replicate_list, replicate_list]
    cnt_matrix.columns = replace_with_dict(cnt_matrix.columns, split_node)
    cnt_matrix.index = replace_with_dict(cnt_matrix.index, split_node)
    cnt_matrix = cnt_matrix.loc[
        selection_sort(np.array(cnt_matrix.columns)),
        selection_sort(np.array(cnt_matrix.index)),
    ]

    # Sub-cluster connection graph; node_dict maps the position of a node in
    # the sparse matrix back to its sub-cluster label.
    G = nx.from_pandas_adjacency(cnt_matrix)
    adata.uns["global_graph"] = {
        "graph": nx.to_scipy_sparse_array(G),
        "node_dict": dict(enumerate(G.nodes)),
    }

    # Per-sub-cluster centroid in image coordinates, then snapped to the
    # nearest actual spot of that sub-cluster.
    coords = ["imagecol", "imagerow"]
    clf = NearestCentroid()
    clf.fit(adata.obs[coords].values, sub_labels)
    centroid_dict = dict(zip(clf.classes_.astype(int), clf.centroids_))
    for cl in sub_labels.unique():
        cl_points = adata.obs.loc[sub_labels == cl, coords].values
        dist_2 = np.sum((cl_points - centroid_dict[int(cl)]) ** 2, axis=1)
        centroid_dict[int(cl)] = cl_points[np.argmin(dist_2)]
    adata.uns["centroid_dict"] = centroid_dict

    # Running diffusion pseudo-time. NOTE(review): scanpy's dpt expects a root
    # cell to have been chosen beforehand — presumably done by the caller.
    sc.tl.dpt(adata)
    if reverse:
        adata.obs[pseudotime_key] = 1 - adata.obs[pseudotime_key]
    store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key)
    return adata if copy else None
# Utils
def replace_with_dict(ar, dic):
    """Expand each key of ``dic`` found in ``ar`` into the items of its value.

    Every position of ``ar`` equal to a key receives, in order, the items of
    that key's value (so a key occurring ``n`` times expects an ``n``-item
    value, or one that broadcasts). Positions matching no key stay zero.
    The result is created with ``np.zeros_like(ar)`` and therefore keeps the
    dtype of ``ar``; assigned items are cast to that dtype.
    """
    result = np.zeros_like(ar)
    for lookup, replacement in dic.items():
        hits = ar == lookup
        result[hits] = replacement
    return result
def selection_sort(x):
    """Sort the 1-D sequence ``x`` ascending in place and return it.

    Plain selection sort: for each position, the smallest remaining element
    (first one on ties, as chosen by ``np.argmin``) is swapped into place.
    Works on NumPy arrays and Python lists alike; the same object is returned.
    """
    start = 0
    while start < len(x):
        smallest = start + np.argmin(x[start:])
        x[start], x[smallest] = x[smallest], x[start]
        start += 1
    return x
def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key):
    """Store short trajectory paths between clusters in ``adata.uns``.

    The cluster-level PAGA graph is thresholded and made directed, edges that
    go from a cluster with a higher to one with a lower maximum pseudotime are
    dropped, and every simple path with fewer than ``max_nodes`` nodes is
    stored in ``adata.uns["available_paths"]`` under the key
    ``"<i>_<source>_<target>"`` (``i`` = position of the path in networkx's
    enumeration of all simple paths between that pair).

    Parameters
    ----------
    adata:
        Annotated data matrix with ``adata.uns["paga"]["connectivities"]``
        and ``adata.obs[pseudotime_key]``.
    threshold:
        PAGA connectivities below this value are discarded.
    use_label:
        Cluster column of ``adata.obs``. Graph node ``i`` is matched to the
        rows whose label equals ``str(i)``.
    max_nodes:
        Only paths with strictly fewer nodes than this are stored.
    pseudotime_key:
        Column of ``adata.obs`` holding the pseudotime.
    """
    # Read original PAGA graph and drop weak connections.
    G = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray())
    edge_weights = nx.get_edge_attributes(G, "weight")
    G.remove_edges_from([e for e, w in edge_weights.items() if w < threshold])
    H = G.to_directed()

    # Max pseudotime per cluster. A boolean mask replaces the former
    # string-built DataFrame.query, which failed for label columns whose
    # names are not valid Python identifiers (e.g. "leiden_0.5").
    # NOTE(review): node i is matched to the label str(i); this assumes the
    # labels are "0".."k-1" in PAGA order — confirm for non-leiden labels.
    labels = adata.obs[use_label]
    node_pseudotime = {
        node: adata.obs.loc[labels == str(node), pseudotime_key].max()
        for node in H.nodes
    }

    # Force original PAGA to directed PAGA based on pseudotime: drop edges
    # going backwards in time. Ties (and NaN comparisons) keep both directions.
    H.remove_edges_from(
        [(u, v) for u, v in H.edges if node_pseudotime[u] - node_pseudotime[v] > 0]
    )

    # Extract all available paths shorter than max_nodes nodes.
    all_paths = {}
    for source in H.nodes:
        for target in H.nodes:
            paths = nx.all_simple_paths(H, source=source, target=target)
            for i, path in enumerate(paths):
                if len(path) < max_nodes:
                    all_paths[f"{i}_{source}_{target}"] = path
    adata.uns["available_paths"] = all_paths
    print(
        "All available trajectory paths are stored in adata.uns['available_paths'] "
        f"with length < {max_nodes} nodes"
    )