Source code for stlearn.pl.trajectory.pseudotime_plot

import matplotlib
import networkx as nx
import numpy as np
from anndata import AnnData
from matplotlib import pyplot as plt
from numpy._typing import NDArray

from stlearn.pl.utils import get_cluster, get_node
from stlearn.utils import _read_graph


def _label_cluster_nodes(
    ax,
    centroids,
    selected_nodes,
    split_node,
    cmap,
    n_colors: int,
    *,
    text_color: str,
    fontsize: float,
    alpha: float,
) -> None:
    """Write the cluster label of every selected node at its centroid.

    Each label is drawn inside a circle whose fill color is taken from
    ``cmap`` at ``int(cluster) / (n_colors - 1)``.
    """
    # Guard the denominator so a one-color palette does not divide by zero.
    denominator = max(n_colors - 1, 1)
    for node, pos in centroids.items():
        if str(node) not in selected_nodes:
            continue
        cluster = get_cluster(str(node), split_node)
        ax.text(
            pos[0],
            pos[1],
            str(cluster),
            color=text_color,
            fontsize=fontsize,
            zorder=100,
            bbox=dict(
                # int() keeps both callers consistent; it is a no-op when the
                # cluster label is already an integer.
                facecolor=cmap(int(cluster) / denominator),
                boxstyle="circle",
                alpha=alpha,
            ),
        )


def pseudotime_plot(
    adata: AnnData,
    library_id: str | None = None,
    use_label: str = "leiden",
    pseudotime_key: str = "dpt_pseudotime",
    list_clusters: str | list[str] | None = None,
    cell_alpha: float = 1.0,
    image_alpha: float = 1.0,
    edge_alpha: float = 0.8,
    node_alpha: float = 1.0,
    spot_size: float | int = 6.5,
    node_size: float = 5,
    show_color_bar: bool = True,
    show_axis: bool = False,
    show_graph: bool = True,
    show_trajectories: bool = False,
    reverse: bool = False,
    show_node: bool = True,
    show_plot: bool = True,
    cropped: bool = True,
    margin: int = 100,
    dpi: int = 150,
    output: str | None = None,
    name: str | None = None,
    ax=None,
) -> AnnData | None:
    """\
    Global trajectory inference plot (Only DPT).

    Spots are drawn at their ``imagecol``/``imagerow`` positions and colored
    by pseudotime; the global graph and/or the pseudo-time-space trajectories
    can be overlaid on top of the tissue image.

    Parameters
    ----------
    adata
        Annotated data matrix.
    library_id
        Library id stored in AnnData. When ``None``, the first key of
        ``adata.uns["spatial"]`` is used.
    use_label
        Use label result of cluster method. Colors are read from
        ``adata.uns[use_label + "_colors"]``.
    pseudotime_key
        Column of ``adata.obs`` holding the pseudotime used to color spots.
    list_clusters
        Choose set of clusters that will display in the plot. A single
        cluster name is accepted; ``None`` selects every category of
        ``adata.obs[use_label]``.
    cell_alpha
        Opacity of the spot.
    image_alpha
        Opacity of the tissue.
    edge_alpha
        Opacity of edge in PAGA graph in the tissue.
    node_alpha
        Opacity of node in PAGA graph in the tissue.
    spot_size
        Size of the spot.
    node_size
        Font size of the node labels of the PAGA graph in the tissue.
    show_color_bar
        Show color bar or not.
    show_axis
        Show axis or not.
    show_graph
        Show PAGA graph or not.
    show_trajectories
        Show the trajectories stored in ``adata.uns["PTS_graph"]`` or not.
    reverse
        Draw the trajectory arrows in the opposite direction.
    show_node
        Show node labels on top of the trajectories or not (only used when
        ``show_trajectories`` is true).
    show_plot
        Show plot or not
    cropped
        Restrict the view to the extent of the spots plus ``margin``.
    margin
        Padding added around the spots when ``cropped`` is true, in the same
        units as ``imagecol``/``imagerow``.
    dpi
        DPI of the output figure.
    output
        The output folder of the plot.
    name
        The filename of the plot. The figure is only saved when both
        ``output`` and ``name`` are given.
    ax
        Existing matplotlib axes to draw on. A new figure is created when
        ``None``.

    Returns
    -------
    The ``adata`` object that was passed in.

    Raises
    ------
    ValueError
        If ``show_trajectories`` is true and ``"PTS_graph"`` is missing from
        ``adata.uns``.
    """
    checked_list_clusters: list[str]
    if list_clusters is None:
        checked_list_clusters = list(adata.obs[use_label].cat.categories)
    elif isinstance(list_clusters, str):
        checked_list_clusters = [list_clusters]
    else:
        checked_list_clusters = list_clusters

    obs = adata.obs
    imagecol = obs["imagecol"]
    imagerow = obs["imagerow"]
    split_node = adata.uns["split_node"]

    # The selected nodes do not change during the call, so look them up once.
    selected_nodes = get_node(checked_list_clusters, split_node)
    query_node = list(map(int, selected_nodes))
    query_set = set(query_node)

    G = _read_graph(adata, "global_graph")
    labels = nx.get_edge_attributes(G, "weight")

    # Keep only edges whose two endpoints both belong to the selected clusters.
    selected_edges = [
        edge
        for edge in G.edges(query_node)
        if edge[0] in query_set and edge[1] in query_set
    ]
    # The weight may be stored under either orientation of the edge; a
    # KeyError is still raised if neither orientation is present.
    edge_widths = [
        (labels[edge] if edge in labels else labels[edge[::-1]]) + 0.5
        for edge in selected_edges
    ]

    # Only create a figure when the caller did not supply axes; otherwise use
    # the figure that owns the supplied axes so saving targets the right one.
    if ax is None:
        fig, a = plt.subplots()
    else:
        a = ax
        fig = a.get_figure()

    centroid_dict: dict[int, NDArray[np.float64]] = adata.uns["centroid_dict"]
    centroid_dict = {int(key): centroid_dict[key] for key in centroid_dict}

    dpt = obs[pseudotime_key]
    vmin = dpt.min()
    vmax = dpt.max()

    # Plot scatter plot based on pixel of spots
    from sklearn.preprocessing import MinMaxScaler

    # NOTE(review): the colors are min-max rescaled to [0, 1] while vmin/vmax
    # come from the raw pseudotime; the two agree only when the pseudotime
    # already spans [0, 1]. Kept as in the original implementation.
    scaled = MinMaxScaler().fit_transform(dpt.values.reshape(-1, 1)).ravel()

    plot = a.scatter(
        imagecol,
        imagerow,
        edgecolor="none",
        alpha=cell_alpha,
        s=spot_size,
        marker="o",
        vmin=vmin,
        vmax=vmax,
        cmap=plt.get_cmap("viridis"),
        c=scaled,
    )

    used_colors = adata.uns[use_label + "_colors"]
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", used_colors)

    if show_graph:
        nx.draw_networkx(
            G,
            pos=centroid_dict,
            node_size=1,
            font_size=0,
            linewidths=2,
            edgelist=selected_edges,
            width=edge_widths,
            alpha=edge_alpha,
            edge_color="#333333",
            ax=a,  # draw on the target axes, not on whatever is current
        )
        _label_cluster_nodes(
            a,
            centroid_dict,
            selected_nodes,
            split_node,
            cmap,
            len(used_colors),
            text_color="white",
            fontsize=node_size,
            alpha=node_alpha,
        )

    if show_trajectories:
        if "PTS_graph" not in adata.uns:
            raise ValueError("Please run stlearn.spatial.trajectory.pseudotimespace!")

        pts_graph = _read_graph(adata, "PTS_graph").copy()
        # 9999 looks like a placeholder node that must not be drawn (TODO
        # confirm); removing the node also removes every edge touching it.
        if 9999 in pts_graph:
            pts_graph.remove_node(9999)

        nx.draw_networkx_edges(
            pts_graph,
            pos=centroid_dict,
            node_size=10,
            alpha=1.0,
            width=2.5,
            edge_color="#f4efd3",
            arrowsize=17,
            arrowstyle="<|-" if reverse else "-|>",
            connectionstyle="arc3,rad=0.2",
            ax=a,
        )

        if show_node:
            _label_cluster_nodes(
                a,
                centroid_dict,
                selected_nodes,
                split_node,
                cmap,
                len(used_colors),
                text_color="black",
                fontsize=8,
                alpha=1,
            )

    if show_color_bar:
        # Attach the color bar to the axes that were actually drawn on.
        cb = fig.colorbar(plot, ax=a)
        cb.outline.set_visible(False)

    if library_id is None:
        library_id = list(adata.uns["spatial"].keys())[0]

    spatial = adata.uns["spatial"][library_id]
    image = spatial["images"][spatial["use_quality"]]
    a.imshow(
        image,
        alpha=image_alpha,
        zorder=-1,
    )

    if not show_axis:
        a.axis("off")

    if cropped:
        a.set_xlim(imagecol.min() - margin, imagecol.max() + margin)
        a.set_ylim(imagerow.min() - margin, imagerow.max() + margin)
        # Flip the y axis so the plot matches image (row-down) orientation.
        a.set_ylim(a.get_ylim()[::-1])

    if output is not None and name is not None:
        fig.savefig(output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0)

    if show_plot:
        plt.show()

    return adata