Source code for stlearn.pl.trajectory.transition_markers_plot

from decimal import Decimal, InvalidOperation

import matplotlib.pyplot as plt
from anndata import AnnData


def transition_markers_plot(
    adata: AnnData,
    trajectory: str,
    top_genes: int = 10,
    dpi: int = 150,
    output: str | None = None,
    name: str | None = None,
) -> AnnData | None:
    """\
    Plot transition marker.

    Draws two horizontal bar charts side by side: the `top_genes` rows of
    `adata.uns[trajectory]` with the most negative ``score`` on the left and
    the rows with the highest non-negative ``score`` on the right. Each bar is
    annotated with the row's ``gene`` and its ``p-value`` (formatted as
    ``.2E``). The figure is shown with `plt.show()` and, when `output` is
    given, also saved to ``output + "/" + name``.

    Parameters
    ----------
    adata
        Annotated data matrix.
    trajectory
        Name of a clade/branch user wants to plot transition markers.
        Must be a key of `adata.uns`.
    top_genes
        Top genes users want to display in the plot.
    dpi
        The resolution of the plot.
    output
        The output folder of the plot. Nothing is saved when `None`.
    name
        The filename of the plot. Defaults to `trajectory`.

    Returns
    -------
    Anndata
        The `adata` object that was passed in.

    Raises
    ------
    ValueError
        If `trajectory` is not a key of `adata.uns`.
    """
    if trajectory not in adata.uns:
        raise ValueError(
            "Please input the right trajectory name - not found in adata.uns!"
        )

    markers = adata.uns[trajectory]
    pos = (
        markers[markers["score"] >= 0]
        .sort_values("score", ascending=False)
        .reset_index(drop=True)[:top_genes]
    )
    neg = (
        markers[markers["score"] < 0]
        .sort_values("score")
        .reset_index(drop=True)[:top_genes]
    )

    # Reverse so the strongest marker ends up on the top-most drawn bar of
    # the rows that actually exist (bars are drawn bottom-up from y = 0).
    pos = pos[::-1]
    neg = neg[::-1]

    y = range(top_genes)
    x1 = list(neg["score"])
    x2 = list(pos["score"])
    # Pad with zero-width bars so both panels always have `top_genes` rows.
    x1 += [0] * (top_genes - len(x1))
    x2 += [0] * (top_genes - len(x2))

    fig, axes = plt.subplots(ncols=2, sharey=True)
    axes[0].barh(y, x1, align="center", color="#fb687a")
    axes[1].barh(y, x2, align="center", color="#31a2fb")
    fig.subplots_adjust(wspace=0)

    for ax in axes:
        for side in ("left", "right", "top"):
            ax.spines[side].set_visible(False)
        ax.get_yaxis().set_ticks([])
        ax.tick_params(axis="both", which="both", length=0)

    # A panel with only zero-width bars has no meaningful x scale.
    for ax, scores in zip(axes, (x1, x2)):
        if all(value == 0 for value in scores):
            ax.get_xaxis().set_ticks([])

    _annotate_bars(axes[1], pos, horizontal_alignment="left", offset=0.01)
    _annotate_bars(axes[0], neg, horizontal_alignment="right", offset=-0.01)

    plt.figtext(0.5, 0.9, trajectory, ha="center", va="center")
    axes[0].set_xlabel("Spearman correlation coefficient")
    # Centre the shared x label under the seam between the two panels.
    axes[0].xaxis.set_label_coords(1, -0.1)
    axes[0].grid(False)
    axes[1].grid(False)

    if name is None:
        name = trajectory
    if output is not None:
        fig.savefig(output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0)

    plt.show()
    return adata


def _bar_labels(table, n_bars):
    """Return one ``(gene_name, p_value)`` string pair per bar.

    Labels are read positionally from the ``gene`` and ``p-value`` columns of
    `table`. Bars without a matching row (zero-padded bars) get empty strings,
    as do all bars when either column is missing. A p-value that cannot be
    parsed as a decimal number yields an empty p-value label but keeps the
    gene name.
    """
    try:
        genes = list(table["gene"])
        p_values = list(table["p-value"])
    except KeyError:
        # Best effort: draw the bars without annotations.
        return [("", "")] * n_bars

    labels = []
    for i in range(n_bars):
        try:
            gene_name = genes[i]
            raw_p_value = p_values[i]
        except IndexError:
            # Padded bar: there is no row behind it.
            labels.append(("", ""))
            continue
        try:
            p_value = "{:.2E}".format(Decimal(str(raw_p_value)))
        except InvalidOperation:
            p_value = ""
        labels.append((gene_name, p_value))
    return labels


def _annotate_bars(ax, table, horizontal_alignment, offset):
    """Write the gene name at the tip and the p-value at the base of each bar.

    `offset` is added to both x positions; its sign selects the direction in
    which the text is nudged away from the bar edge.
    """
    rects = ax.patches
    alignment = {
        "horizontalalignment": horizontal_alignment,
        "verticalalignment": "center",
    }
    for rect, (gene_name, p_value) in zip(rects, _bar_labels(table, len(rects))):
        y_mid = rect.get_y() + rect.get_height() / 2.0
        ax.text(
            rect.get_x() + rect.get_width() + offset,
            y_mid,
            gene_name,
            **alignment,
            size=6,
        )
        ax.text(
            rect.get_x() + offset,
            y_mid,
            p_value,
            color="w",
            **alignment,
            size=6,
        )