Source code for stlearn.pl.trajectory.local_plot

import matplotlib.pyplot as plt
import numpy as np
from anndata import AnnData
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d


def local_plot(
    adata: AnnData,
    use_cluster: int,
    use_label: str = "leiden",
    reverse: bool = False,
    cluster: int = 0,
    data_alpha: float = 1.0,
    arrow_alpha: float = 1.0,
    branch_alpha: float = 0.2,
    spot_size: float | int = 1,
    show_color_bar: bool = True,
    show_axis: bool = False,
    show_plot: bool = True,
    name: str | None = None,
    dpi: int = 150,
    output: str | None = None,
    copy: bool = False,
) -> AnnData | None:
    """\
    Local spatial trajectory inference plot.

    Draws the spots of one cluster in a 3D axes (image column on x, image
    row on z, a flat plane at y = 0), coloured by ``dpt_pseudotime``, and
    connects the centroids of its sufficiently large sub-clusters with
    arched arrows labelled by the absolute spatial-temporal distance.

    Parameters
    ----------
    adata
        Annotated data matrix. Reads ``obs[use_label]``,
        ``obs["sub_cluster_labels"]``, ``obs["imagecol"]``,
        ``obs["imagerow"]``, ``obs["dpt_pseudotime"]`` and
        ``uns["threshold_spots"]``, ``uns["centroid_dict"]``,
        ``uns["ST_distance_matrix"]``, ``uns["nonabs_dpt_distance_matrix"]``.
    use_cluster
        Cluster to display; compared as ``str(use_cluster)`` against
        ``obs[use_label]``.
    use_label
        Column of ``adata.obs`` holding the clustering result.
    reverse
        Flip the sign of the signed pseudotime distance, which reverses the
        direction of every arrow.
    cluster
        Spots whose ``obs[use_label]`` value differs from this are drawn as
        the grey background.
        NOTE(review): the selected cluster is matched via ``str(use_cluster)``
        while this value is compared as-is; with string labels and the int
        default this presumably selects every spot - confirm intent.
    data_alpha
        Opacity of the spots of the selected cluster.
    arrow_alpha
        Opacity of the arrows.
    branch_alpha
        Currently unused by this function.
    spot_size
        Size of the spots.
    show_color_bar
        Show the pseudotime color bar or not.
    show_axis
        Show grid lines and ticks or not.
    show_plot
        Call ``plt.show()`` or not.
    name
        File name used when saving; falls back to ``use_label`` (with a
        printed notice) when ``output`` is given but ``name`` is not.
    dpi
        Resolution used when the figure is saved.
    output
        Directory to save the figure into. Nothing is saved when ``None``.
    copy
        Currently unused by this function.

    Returns
    -------
    The ``adata`` object that was passed in.
    """
    obs = adata.obs
    selected = obs[obs[use_label] == str(use_cluster)]
    ref_cluster = adata[list(selected.index)]

    fig = plt.figure(figsize=(5, 5))
    # Figure.gca() no longer accepts keyword arguments in current matplotlib;
    # add_subplot creates the same single 3D axes on the fresh figure.
    ax = fig.add_subplot(projection="3d")

    # Collect the sub-clusters of the selected cluster that are large enough,
    # together with their centroid and their row/column in the distance
    # matrices (assigned in order of appearance).
    sub_labels = obs["sub_cluster_labels"]
    classes_ = []
    centroids_ = []
    order_dict = {}
    centroid_dict = None
    for label in ref_cluster.obs["sub_cluster_labels"].unique():
        n_spots = int((sub_labels == str(label)).sum())
        if n_spots <= adata.uns["threshold_spots"]:
            continue
        if centroid_dict is None:
            # Normalise the keys to int once instead of on every iteration.
            raw_centroids = adata.uns["centroid_dict"]
            centroid_dict = {int(key): raw_centroids[key] for key in raw_centroids}
        order_dict[int(label)] = len(classes_)
        classes_.append(label)
        centroids_.append(centroid_dict[int(label)])

    stdm = adata.uns["ST_distance_matrix"]
    non_abs_dpt = adata.uns["nonabs_dpt_distance_matrix"]

    # One arched arrow per unordered pair of retained sub-clusters.
    n_centroids = len(centroids_)
    for first in range(n_centroids - 1):
        row = order_dict[int(classes_[first])]
        for second in range(first + 1, n_centroids):
            col = order_dict[int(classes_[second])]
            m = stdm[row, col]
            dpt_distance = non_abs_dpt[row, col]

            # The arch rises out of the y = 0 plane proportionally to |m|.
            y = calculate_y(np.abs(m))
            x = np.linspace(centroids_[first][0], centroids_[second][0], 1000)
            z = np.linspace(centroids_[first][1], centroids_[second][1], 1000)

            if reverse:
                dpt_distance = -dpt_distance

            # A short segment at the middle of the arch carries the arrow
            # head; the sign of the pseudotime distance picks its direction.
            if dpt_distance <= 0:
                start, end = 500, 520
            else:
                start, end = 520, 500

            arrow = Arrow3D(
                [x[start], x[end]],
                [y[start], y[end]],
                [z[start], z[end]],
                zorder=10,
                mutation_scale=7,
                lw=1,
                arrowstyle="->",
                color="r",
                alpha=arrow_alpha,
            )
            ax.add_artist(arrow)

            ax.text(
                x[500],
                y[500] - 0.15,
                z[500],
                np.round(np.abs(m), 3),
                color="black",
                size=5,
                zorder=100,
            )

    # Spots of the selected cluster, coloured by pseudotime.
    sc = ax.scatter(
        ref_cluster.obs["imagecol"],
        0,
        ref_cluster.obs["imagerow"],
        c=ref_cluster.obs["dpt_pseudotime"],
        s=spot_size,
        cmap="viridis",
        zorder=0,
        alpha=data_alpha,
    )

    # Faint grey background spots.
    background = obs[obs[use_label] != cluster]
    _ = ax.scatter(
        background["imagecol"],
        0,
        background["imagerow"],
        c="grey",
        s=spot_size,
        zorder=0,
        alpha=0.1,
    )

    if show_color_bar:
        cb = fig.colorbar(sc, cax=fig.add_axes([0.1, 0.3, 0.03, 0.5]))
        cb.outline.set_visible(False)

    ax.set_ylim([-1, 0])
    ax.set_xlim([min(obs["imagecol"]) - 10, max(obs["imagecol"]) + 10])
    ax.set_zlim([min(obs["imagerow"]) - 10, max(obs["imagerow"]) + 10])

    if not show_axis:
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            # Make the grid lines transparent and drop the ticks.
            axis._axinfo["grid"]["color"] = (1, 1, 1, 0)
            axis.set_ticks([])

    ax.invert_zaxis()
    ax.patch.set_visible(False)

    if show_plot:
        plt.show()

    if output is not None:
        if name is None:
            print("The file name is not defined!")
            name = use_label
        fig.savefig(output + "/" + name, dpi=dpi, bbox_inches="tight", pad_inches=0)

    return adata
def calculate_y(m):
    """Return 1000 heights of an arch that dips to ``-m`` at both ends.

    The profile is a cosine sampled on [-0.5, 0.5], rescaled so that its
    minimum maps to 0 and its maximum to 1, then multiplied by ``-m``.
    The returned array therefore runs from 0 at both ends to ``-m`` at the
    midpoint.

    Parameters
    ----------
    m
        Scale of the arch (callers pass an absolute distance).

    Returns
    -------
    numpy.ndarray of length 1000.
    """
    samples = np.linspace(-0.5, 0.5, 1000)
    profile = np.cos(np.absolute(samples))
    return -(m) * (profile - np.min(profile)) / np.ptp(profile)


class Arrow3D(FancyArrowPatch):
    """A 2D ``FancyArrowPatch`` positioned from a 3D start/end point.

    The 3D endpoints are stored on construction and projected into the 2D
    display plane of the owning 3D axes each time the arrow is drawn.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The real 2D positions are only known at draw time, so start with a
        # placeholder segment.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """Project the stored 3D endpoints and return the minimum depth.

        Newer matplotlib 3D axes call this hook on every patch they draw;
        ``renderer`` is accepted for versions that still pass it. The
        projection matrix is taken from the owning axes rather than from
        the renderer, which no longer carries it.
        """
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.get_proj())
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return np.min(zs)

    def draw(self, renderer):
        # Refresh the 2D positions so the arrow is placed correctly even when
        # the axes did not call do_3d_projection beforehand.
        self.do_3d_projection(renderer)
        super().draw(renderer)