Source code for stlearn.plotting.cci_plot

from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import matplotlib
import pandas as pd
import numpy as np
import networkx as nx
import math
import matplotlib.patches as patches
from numba.typed import List
import seaborn as sns
import sys
from anndata import AnnData
from typing import Optional, Union

from typing import Optional, Union, Mapping  # Special
from typing import Sequence, Iterable  # ABCs
from typing import Tuple  # Classes

import warnings

from .classes import CciPlot, LrResultPlot
from .classes_bokeh import BokehSpatialCciPlot, BokehLRPlot
from ._docs import doc_spatial_base_plot, doc_het_plot, doc_lr_plot
from ..utils import Empty, _empty, _AxesSubplot, _docs_params
from .utils import get_cmap, check_cmap, get_colors
from .cluster_plot import cluster_plot
from .deconvolution_plot import deconvolution_plot
from .gene_plot import gene_plot
from stlearn.plotting.utils import get_colors
import stlearn.plotting.cci_plot_helpers as cci_hs
from .cci_plot_helpers import (
    get_int_df,
    add_arrows,
    create_flat_df,
    _box_map,
    chordDiagram,
)
from scipy.stats import gaussian_kde

import importlib

importlib.reload(cci_hs)

from bokeh.io import push_notebook, output_notebook
from bokeh.plotting import show

#### Functions for visualising the overall LR results and diagnostics.


def lr_diagnostics(
    adata,
    highlight_lrs: list = None,
    n_top: int = None,
    color0: str = "turquoise",
    color1: str = "plum",
    figsize: tuple = (10, 4),
    lr_text_fp: dict = None,
    show: bool = True,
):
    """Diagnostic plot relating technical features of LRs to LR rank.

    Two scatter panels are drawn side by side with ``cci_hs.lr_scatter``.
    Left panel ("nonzero-median"): average of the median of nonzero
    expressing spots for the ligand and the receptor on the y-axis, against
    the LR rank by no. of significant spots on the x-axis. Right panel
    ("zero-prop"): average of the proportion of zeros for the ligand and
    receptor gene on the y-axis.

    Parameters
    ----------
    adata: AnnData
        The data object on which st.tl.cci.run has been applied; must hold
        adata.uns['lr_summary'].
    highlight_lrs: list
        List of LRs to highlight, will add text and change point color for
        these LR pairs.
    n_top: int
        The number of LRs to display. If None shows all rows of
        adata.uns['lr_summary'].
    color0: str
        The color of the nonzero-median scatter plot (left panel).
    color1: str
        The color of the zero-prop scatter plot (right panel).
    figsize: tuple
        Size of the figure; (width, height).
    lr_text_fp: dict
        Font dict for the LR text if highlight_lrs not None.
    show: bool
        Whether to show the plot; if False the figure and axes are returned.

    Returns
    -------
    Figure, Axes
        Figure and the pair of axes of the plot, only if show=False.
    """
    if n_top is None:
        n_top = adata.uns["lr_summary"].shape[0]

    fig, axes = plt.subplots(ncols=2, figsize=figsize)

    # One (statistic, colour) pair per panel, left to right.
    panels = (("nonzero-median", color0), ("zero-prop", color1))
    for panel_ax, (stat, panel_color) in zip(axes, panels):
        cci_hs.lr_scatter(
            adata,
            stat,
            highlight_lrs=highlight_lrs,
            n_top=n_top,
            color=panel_color,
            ax=panel_ax,
            lr_text_fp=lr_text_fp,
            show=False,
        )

    if not show:
        return fig, axes
    plt.show()
def lr_summary(
    adata,
    n_top: int = 50,
    highlight_lrs: list = None,
    y: str = "n_spots_sig",
    color: str = "gold",
    figsize: tuple = None,
    highlight_color: str = "red",
    max_text: int = 50,
    lr_text_fp: dict = None,
    ax: Axes = None,
    show: bool = True,
):
    """Plotting the top LRs ranked by number of significant spots.

    Thin wrapper that validates ``y`` and delegates to ``cci_hs.lr_scatter``.

    Parameters
    ----------
    adata: AnnData
        The data object on which st.tl.cci.run has been applied.
    n_top: int
        The no. of LRs to plot.
    highlight_lrs: list
        A list of LRs to highlight on the plot, will add text and change the
        color of points for these LRs. Useful for highlighting LRs of
        interest.
    y: str
        The statistic used to rank the LRs; must be one of 'n_spots',
        'n_spots_sig' or 'n_spots_sig_pval'. Default is the no. of
        significant spots.
    color: str
        The color of the points.
    figsize: tuple
        Size of the figure; (width, height).
    highlight_color: str
        Only relevant if highlight_lrs specified; controls colour of LRs to
        highlight.
    max_text: int
        If n_top is above this limit, stop showing text to indicate the LR
        names. Allows to see global shape without crowding with LR name text.
    lr_text_fp: dict
        Matplotlib font dictionary specifying text details, eg fontsize.
    ax: Axes
        Axes on which to draw the scatter plot; if not inputted constructs
        own.
    show: bool
        Whether to show the plot, if False will return the ax.

    Returns
    -------
    Axes
        If show=False, returns the ax for additional modification (the value
        is whatever cci_hs.lr_scatter returns).

    Raises
    ------
    Exception
        If y is not one of the allowed statistics.
    """
    rank_columns = ["n_spots", "n_spots_sig", "n_spots_sig_pval"]
    if y not in rank_columns:
        raise Exception(f"Got {y} for y; must be one of {rank_columns}")

    # LR names are only drawn for every point while the plot is small enough.
    label_every_point = n_top <= max_text

    scatter_options = dict(
        n_top=n_top,
        color=color,
        show_all=label_every_point,
        figsize=figsize,
        highlight_lrs=highlight_lrs,
        ax=ax,
        lr_text_fp=lr_text_fp,
        highlight_color=highlight_color,
        show=show,
    )
    return cci_hs.lr_scatter(adata, y, **scatter_options)
def lr_n_spots(
    adata,
    n_top: int = 100,
    font_dict: dict = None,
    xtick_dict: dict = None,
    bar_width: float = 1,
    max_text: int = 50,
    non_sig_color: str = "dodgerblue",
    sig_color: str = "springgreen",
    figsize: tuple = (6, 4),
    show_title: bool = True,
    show: bool = True,
):
    """Bar plot showing for each LR no. of sig versus non-sig spots.

    Bars are stacked: the non-significant count (n_spots - n_spots_sig) at
    the bottom and the significant count on top, in the row order of
    adata.uns['lr_summary'].

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run has been applied.
    n_top: int
        The no. of LRs to display, in the order of adata.uns['lr_summary'].
    font_dict: dict
        Dictionary specifying matplotlib font parameters e.g. weight, used
        for the axis labels and title. Defaults to bold, size 12.
    xtick_dict: dict
        Dictionary specifying matplotlib font parameters for x labels e.g.
        weight. Defaults to bold, rotated 90 degrees, size 6.
    bar_width: float
        Width of each bar in the bar plot.
    max_text: int
        If n_top exceeds this number, stop showing the LR text, since can
        cause crowding.
    non_sig_color: str
        Specifies the color for bar plot proportion indicating no. of
        non-sig spots.
    sig_color: str
        Specifies color for bars indicating the sig. spots counts.
    figsize: tuple
        Specifies figure dimensions.
    show_title: bool
        Whether to show title on outputted plot.
    show: bool
        Whether to show the plot; if false returns the figure & axes for
        further modification.

    Returns
    -------
    Fig, Axes
        Figure & axes with the plot drawn on; only if show=False. Else None.
    """
    if font_dict is None:
        font_dict = {"weight": "bold", "size": 12}
    if xtick_dict is None:
        xtick_dict = {"fontweight": "bold", "rotation": 90, "size": 6}

    summary = adata.uns["lr_summary"]
    lr_names = summary.index.values[0:n_top]
    sig_counts = summary.loc[:, "n_spots_sig"].values
    non_sig_counts = summary.loc[:, "n_spots"].values - sig_counts

    # Restrict everything to the first n_top rows once, up front.
    top_ranks = list(range(len(sig_counts)))[0:n_top]
    top_sig = sig_counts[0:n_top]
    top_non_sig = non_sig_counts[0:n_top]

    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(top_ranks, top_non_sig, bar_width, color=non_sig_color)
    ax.bar(top_ranks, top_sig, bar_width, bottom=top_non_sig, color=sig_color)

    ax.set_ylabel("n_spots", font_dict)
    ax.set_xlabel("LRs Ranked (n_spots_sig)", font_dict)
    if show_title:
        ax.set_title("Signficant and non-signficant spots per LR", font_dict)
    ax.legend(labels=["non-sig", "sig"], loc="upper right")

    # LR names on the x-axis only while they remain readable.
    if n_top <= max_text:
        ax.set_xticks(top_ranks)
        ax.set_xticklabels(lr_names, fontdict=xtick_dict)

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    if not show:
        return fig, ax
    plt.show()
def lr_go(
    adata,
    n_top: int = 20,
    highlight_go: list = None,
    figsize=(6, 4),
    rot: float = 50,
    lr_text_fp: dict = None,
    highlight_color: str = "yellow",
    max_text: int = 50,
    show: bool = True,
):
    """Plots the results from the LR GO analysis.

    GO terms are taken in the row order of adata.uns['lr_go']; the y-axis is
    -log10 of the 'p.adjust' column and the point size follows the 'Count'
    column.

    Parameters
    ----------
    adata: AnnData
        Data object on which st.tl.cci.run & st.tl.cci.run_lr_go has been
        called.
    n_top: int
        Specifies the no. of GO terms to show.
    highlight_go: list
        Names of GO terms to highlight in different color & text.
    figsize: tuple
        Size of the figure.
    rot: float
        Rotation of the text labels for the GO terms.
    lr_text_fp: dict
        Dictionary specifying matplotlib text params that control rendering
        of the labels on the points.
    highlight_color: str
        The color to plot the GO terms specified in highlight_go.
    max_text: int
        If n_top exceeds this number, stop showing the GO text, since can
        cause crowding.
    show: bool
        Whether to show the plot.

    Returns
    -------
    The return value of cci_hs.rank_scatter, passed through unchanged so
    that show=False is usable like in the sibling plotting functions
    (presumably the Axes when show=False and None otherwise; confirm against
    the helper).

    Raises
    ------
    Exception
        If adata.uns has no 'lr_go' entry.
    """
    # Making sure LR GO has been run #
    if "lr_go" not in adata.uns:
        raise Exception("Need to run st.tl.cci.run_lr_go() first!")

    go_results = adata.uns["lr_go"]
    gos = go_results.loc[:, "Description"].values.astype(str)
    y = -np.log10(go_results.loc[:, "p.adjust"].values)
    sizes = go_results.loc[:, "Count"].values

    # Previously the helper's result was dropped, so show=False handed the
    # caller nothing to modify further.
    return cci_hs.rank_scatter(
        gos[0:n_top],
        y[0:n_top],
        point_sizes=sizes[0:n_top],
        highlight_items=highlight_go,
        lr_text_fp=lr_text_fp,
        highlight_color=highlight_color,
        figsize=figsize,
        y_label="-log10(padjs)",
        x_label="GO Rank",
        height=6,
        color="deepskyblue",
        rot=rot,
        width_ratio=0.4,
        show=show,
        point_size_name="n-genes",
        show_all=n_top <= max_text,
    )
def cci_check(
    adata: AnnData,
    use_label: str,
    figsize=(16, 10),
    cell_label_size=20,
    axis_text_size=18,
    tick_size=14,
    show=True,
):
    """Checks relationship between no. of significant CCI-LR interactions and
    cell type frequency.

    Draws a bar chart of cell type frequency (sorted ascending) with, on a
    twin y-axis, a line showing for each cell type the number of non-zero
    entries in its row and column of every per-LR interaction table.

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run & st.tl.cci.run_cci has been performed;
        adata.obs[use_label] must be categorical and
        adata.uns[f'per_lr_cci_{use_label}'] must exist.
    use_label: str
        The cell type label information used when running st.tl.cci.run_cci
    figsize: tuple
        Size of outputted figure.
    cell_label_size: int
        Size of the cell labels put on top of the bar chart.
    axis_text_size: int
        Size of the axis text.
    tick_size: int
        Size of the ticks displayed at bottom of chart.
    show: bool
        Whether to show the plot or not; if false returns figure & axes.

    Returns
    -------
    Figure, Ax1, Ax2
        The figure, axes for the barchart, and twin axes for the lineplot;
        only if show=False.
    """
    labels = adata.obs[use_label].values.astype(str)
    label_set = np.array(list(adata.obs[use_label].cat.categories))
    colors = get_colors(adata, use_label)
    xs = np.arange(len(label_set))
    int_dfs = adata.uns[f"per_lr_cci_{use_label}"]

    # The boolean "has interaction" matrices do not depend on the label, so
    # build them once instead of once per (label, LR) pair.
    lr_tables = [
        (int_dfs[lr].index.values, int_dfs[lr].values > 0) for lr in int_dfs
    ]

    # Counting!!! #
    cell_counts = []  # Cell type frequencies
    cell_sigs = []  # Cell type significant interactions
    for label in label_set:
        cell_counts.append(np.count_nonzero(labels == label))

        int_count = 0
        for index_values, int_bool in lr_tables:
            label_index = np.where(index_values == label)[0][0]
            int_count += np.count_nonzero(int_bool[label_index, :])
            int_count += np.count_nonzero(int_bool[:, label_index])
            # The diagonal entry sits in both the row and the column;
            # remove it once to prevent double counts.
            int_count -= int(int_bool[label_index, label_index])
        cell_sigs.append(int_count)

    cell_counts = np.array(cell_counts)
    cell_sigs = np.array(cell_sigs)

    # Rank cell types by frequency and reorder everything consistently.
    order = np.argsort(cell_counts)
    cell_counts = cell_counts[order]
    cell_sigs = cell_sigs[order]
    colors = np.array(colors)[order]
    label_set = label_set[order]

    # Plotting bar plot #
    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(xs, cell_counts, color=colors)
    text_dist = max(cell_counts) * 0.015
    fontdict = {"fontweight": "bold", "fontsize": cell_label_size}
    for x, count, label in zip(xs, cell_counts, label_set):
        ax.text(
            x,
            count + text_dist,
            label,
            rotation=90,
            fontdict=fontdict,
        )
    axis_text_fp = {"fontweight": "bold", "fontsize": axis_text_size}
    ax.set_ylabel("Cell type frequency", color="black", **axis_text_fp)
    ax.spines["top"].set_visible(False)
    ax.tick_params(labelsize=tick_size)
    ax.set_xlabel("Cell type rank", **axis_text_fp)

    # Line-plot of the interaction counts #
    ax2 = ax.twinx()
    ax2.set_ylabel("CCI-LR interactions", color="blue", **axis_text_fp)
    ax2.plot(xs, cell_sigs, color="blue", linewidth=2)
    ax2.tick_params(axis="y", labelcolor="blue", labelsize=tick_size)
    ax2.spines["top"].set_visible(False)
    ax2.tick_params(labelsize=tick_size)

    fig.tight_layout()

    if show:
        plt.show()
    else:
        return fig, ax, ax2
# Functions for visualising the LR results per spot.
def lr_result_plot(
    adata: AnnData,
    use_lr: Optional[str] = None,
    use_result: Optional[str] = "lr_sig_scores",
    # plotting param
    title: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    cmap: Optional[str] = "Spectral_r",
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    show_plot: Optional[bool] = True,
    show_axis: Optional[bool] = False,
    show_image: Optional[bool] = True,
    show_color_bar: Optional[bool] = True,
    zoom_coord: Optional[float] = None,
    crop: Optional[bool] = True,
    margin: Optional[float] = 100,
    size: Optional[float] = 7,
    image_alpha: Optional[float] = 1.0,
    cell_alpha: Optional[float] = 1.0,
    use_raw: Optional[bool] = False,
    fname: Optional[str] = None,
    dpi: Optional[int] = 120,
    contour: bool = False,
    step_size: Optional[int] = None,
    vmin: float = None,
    vmax: float = None,
):
    """Plots the per spot statistics for given LR.

    All arguments are forwarded positionally to ``LrResultPlot``, which does
    the drawing; nothing is returned.

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run has been performed
    use_lr: str
        LR to show results for.
    use_result: str
        LR matrix in data.obsm; 'lr_scores', 'lr_sig_scores', 'p_vals',
        'p_adjs' '-log10(p_adjs)'.
    title: str
        Plot title.
    figsize: tuple
        Figure size.
    cmap: str
        Color of points.
    ax: Axes
        Axes on which to plot.
    fig: Figure
        Figure associated with axes.
    show_plot: bool
        Whether to show plot or not.
    show_axis: bool
        Whether to show axis or not.
    show_image: bool
        Whether to plot the image.
    show_color_bar: bool
        Whether to show the color bar.
    zoom_coord: float
        Forwarded unchanged to LrResultPlot (presumably the coordinates to
        zoom in on; confirm against LrResultPlot).
    crop: bool
        Whether to crop the image down to match spatial coordinates of gene
        spot present.
    margin: float
        Margin around the points for the image.
    size: float
        Size of the points.
    image_alpha: float
        Transparency of image.
    cell_alpha: float
        Transparency of points.
    use_raw: bool
        Whether to use adata.raw or not.
    fname: str
        Name of file to save plot to.
    dpi: int
        Plot saving quality.
    contour: bool
        Whether to plot as contour.
    step_size: int
        Step size if contour==True
    vmin: float
        Minimum value of scale bar.
    vmax: float
        Maximum value of scale bar.
    """
    # NOTE: the ax/fig annotations use the public Axes/Figure classes; the
    # previous `matplotlib.axes._subplots.Axes` is a private path that is
    # evaluated at import time and is not available in newer matplotlib.
    LrResultPlot(
        adata,
        use_lr,
        use_result,
        # plotting param
        title,
        figsize,
        cmap,
        None,  # NOTE(review): presumably the list_clusters slot; confirm.
        ax,
        fig,
        show_plot,
        show_axis,
        show_image,
        show_color_bar,
        crop,
        zoom_coord,
        margin,
        size,
        image_alpha,
        cell_alpha,
        use_raw,
        fname,
        dpi,
        # cci_rank param
        contour,
        step_size,
        vmin,
        vmax,
    )
# @_docs_params(het_plot=doc_lr_plot)
def lr_plot(
    adata: AnnData,
    lr: str,
    min_expr: float = 0,
    sig_spots=True,
    use_label: str = None,
    outer_mode: str = "continuous",
    l_cmap=None,
    r_cmap=None,
    lr_cmap=None,
    inner_cmap=None,
    inner_size_prop: float = 0.25,
    middle_size_prop: float = 0.5,
    outer_size_prop: float = 1,
    pt_scale: int = 100,
    title="",
    show_image: bool = True,
    show_arrows: bool = False,
    fig: Figure = None,
    ax: Axes = None,
    arrow_head_width: float = 4,
    arrow_width: float = 0.001,
    arrow_cmap: str = None,
    arrow_vmax: float = None,
    sig_cci: bool = False,
    lr_colors: dict = None,
    figsize: tuple = (6.4, 4.8),
    use_mix: str = None,
    # plotting params
    **kwargs,
) -> Optional[AnnData]:
    """Creates different kinds of spatial visualisations for the LR analysis
    results. To see combinations of parameters refer to stLearn CCI tutorial.

    Parameters
    ----------
    adata: Anndata
        Data on which st.tl.cci.run has been performed; extra options
        unlocked below when have performed st.tl.cci.run_cci as well.
    lr: str
        The LR to display results for, formatted as 'ligand_receptor'.
    min_expr: float
        The minimum expr above which LR considered expressed when plotting
        binary LR expression.
    sig_spots: bool
        Whether to subset to significant spots or not.
    use_label: str
        Either a column of adata.obs (cell type labels, drawn with
        cluster_plot) or one of the LR statistics names (drawn with
        lr_result_plot); shown as the inner point.
    outer_mode: str
        The mode for the larger points when displaying LR expression; can
        either be 'binary' or 'continuous' or None. 'binary' discretizes
        each spot as expressing L, R, both, or neither. 'continuous' plots
        two points for each spot: the ligand expression with the outer point
        size and the receptor expression with the middle point size. None
        plots no ligand/receptor expression.
    l_cmap: str
        Cmap for coloring the ligand expression, only if
        outer_mode=='continuous'. Defaults to a black-to-red gradient.
    r_cmap: str
        Cmap for coloring the receptor expression, only if
        outer_mode=='continuous'. Defaults to a black-to-green gradient.
    lr_cmap: str
        Cmap for the binary labels; if None the colors from lr_colors are
        used instead.
    inner_cmap: str
        Cmap for the inner point (the use_label information).
    inner_size_prop: float
        Proportion of the inner point size when plotting two points for one
        spot. Scale of 0 to 1.
    middle_size_prop: float
        Controls size of middle point if specifying parameters that plot 3
        points per spot to display multiple information. Scale 0 to 1.
    outer_size_prop: float
        Point size of the outer point.
    pt_scale: float
        Overall size of point.
    title: str
        Title of figure.
    show_image: bool
        Whether to show the background image.
    show_arrows: bool
        Whether to plot arrows indicating interactions between spots;
        requires LR results for this pair.
    fig: Figure
        Figure to draw on. A new figure and axes are created unless both fig
        and ax are given.
    ax: Axes
        Axes to draw on.
    arrow_head_width: float
        Width of arrow head; only if show_arrows is true.
    arrow_width: float
        Width of the arrow body; only if show_arrows is true.
    arrow_cmap: str
        Cmap to color arrows; default is black arrows, but if specified will
        color the arrow by the average expression of the ligand and receptor
        of the spots connected by the arrow.
    arrow_vmax: float
        Maximum value of the arrow colour bar.
    sig_cci: bool
        Whether to only show results which involve signficant
        celltype-celltype interactions; particularly relevant when plotting
        the arrows.
    lr_colors: dict
        Specifies the colors when plotting with outer_mode='binary' and no
        lr_cmap. Keys are the ligand name, the receptor name, the full lr
        string, and '' (spots expressing neither); values are colors.
    figsize: tuple
        (width, height) of figure if not inputted.
    use_mix: str
        Key of deconvolution results in adata.uns; in this function it is
        only checked for presence.
    kwargs:
        Extra arguments passed to the plotting functions used internally.

    Returns
    -------
    None; everything is drawn onto fig/ax.

    Raises
    ------
    Exception
        On inconsistent inputs, e.g. asking for significant spots or arrows
        without LR results, an unknown outer_mode/use_label/use_mix, or a
        ligand/receptor that is not in adata.var_names.
    """
    # Input checking #
    l, r = lr.split("_")
    ran_lr = "lr_summary" in adata.uns
    ran_sig = ran_lr and "n_spots_sig" in adata.uns["lr_summary"].columns
    # Whether this particular pair has a row in the LR summary.
    lr_tested = ran_lr and lr in adata.uns["lr_summary"].index
    if lr_tested:
        if ran_sig:
            # Look the column up by name rather than by position.
            lr_sig = adata.uns["lr_summary"].loc[lr, "n_spots_sig"] > 0
        else:
            lr_sig = True
    else:
        lr_sig = False

    if sig_spots and not ran_lr:
        raise Exception(
            "No LR results testing results found, "
            "please run st.tl.cci_rank.run first, or set sig_spots=False."
        )
    elif sig_spots and not lr_sig:
        raise Exception(
            "LR has no significant spots, to visualise anyhow set "
            "sig_spots=False"
        )

    # Arrows are drawn from the per-spot significance mask, which only
    # exists when this LR pair has results; fail before drawing anything.
    if show_arrows and not lr_tested:
        raise Exception(
            "Cannot show arrows without LR results for this LR pair, "
            "please run st.tl.cci.run first, or set show_arrows=False."
        )

    # Making sure have run_cci first with respective labelling #
    if (
        show_arrows
        and sig_cci
        and use_label
        and f"per_lr_cci_{use_label}" not in adata.uns
    ):
        raise Exception(
            "Cannot subset arrow interactions to significant ccis "
            "without performing st.tl.run_cci with "
            f"use_label={use_label} first."
        )

    # Getting which are the allowed stats for the lr to plot #
    if not ran_sig:
        lr_use_labels = ["lr_scores"]
    else:
        lr_use_labels = [
            "lr_scores",
            "p_val",
            "p_adj",
            "-log10(p_adj)",
            "lr_sig_scores",
        ]

    if use_mix is not None and use_mix not in adata.uns:
        raise Exception(
            "Specified use_mix, but no deconvolution results added "
            f"to adata.uns matching the use_mix ({use_mix}) key."
        )
    elif (
        use_label is not None
        and use_label in lr_use_labels
        and ran_sig
        and not lr_sig
    ):
        raise Exception(
            "Since use_label refers to lr stats & ran permutation testing, "
            "LR needs to be significant to view stats."
        )
    elif (
        use_label is not None
        and use_label not in adata.obs.keys()
        and use_label not in lr_use_labels
    ):
        raise Exception(
            f"use_label must be in adata.obs or "
            f"one of lr stats: {lr_use_labels}."
        )

    out_options = ["binary", "continuous", None]
    if outer_mode not in out_options:
        raise Exception(f"{outer_mode} should be one of {out_options}")

    if l not in adata.var_names or r not in adata.var_names:
        raise Exception("L or R not found in adata.var_names.")

    # Whether to show just the significant spots or all spots. The mask is
    # only computed when results exist, so sig_spots=False also works on
    # data without LR results (sig_spots=True implies lr_tested, see above).
    if lr_tested:
        lr_index = np.where(adata.uns["lr_summary"].index.values == lr)[0][0]
        sig_bool = adata.obsm["lr_sig_scores"][:, lr_index] > 0
    else:
        sig_bool = None
    adata_full = adata
    if sig_spots:
        adata = adata[sig_bool, :]

    # Dealing with the axis #
    if fig is None or ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    expr = adata.to_df()
    l_expr = expr.loc[:, l].values
    r_expr = expr.loc[:, r].values

    # Adding binary points of the ligand/receptor pair #
    if outer_mode == "binary":
        l_bool, r_bool = l_expr > min_expr, r_expr > min_expr
        # First matching condition wins: both -> lr, only L -> l,
        # only R -> r, neither -> "".
        lr_binary_labels = np.select(
            [l_bool & r_bool, l_bool, r_bool], [lr, l, r], default=""
        )
        lr_binary_labels = pd.Series(
            lr_binary_labels, index=adata.obs_names
        ).astype("category")
        adata.obs[f"{lr}_binary_labels"] = lr_binary_labels

        if lr_cmap is None:
            lr_cmap = "default"  # This gets ignored due to setting colours below
            if lr_colors is None:
                lr_colors = {
                    l: matplotlib.colors.to_hex("r"),
                    r: matplotlib.colors.to_hex("limegreen"),
                    lr: matplotlib.colors.to_hex("b"),
                    "": "#836BC6",  # Neutral color in H&E images.
                }

            label_set = adata.obs[f"{lr}_binary_labels"].cat.categories
            adata.uns[f"{lr}_binary_labels_colors"] = [
                lr_colors[label] for label in label_set
            ]
        else:
            lr_cmap = check_cmap(lr_cmap)

        cluster_plot(
            adata,
            use_label=f"{lr}_binary_labels",
            cmap=lr_cmap,
            size=outer_size_prop * pt_scale,
            crop=False,
            ax=ax,
            fig=fig,
            show_image=show_image,
            show_plot=False,
            **kwargs,
        )

    # Showing continuous gene expression of the LR pair #
    elif outer_mode == "continuous":
        if l_cmap is None:
            l_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
                "lcmap", [(0, 0, 0), (0.5, 0, 0), (0.75, 0, 0), (1, 0, 0)]
            )
        else:
            l_cmap = check_cmap(l_cmap)
        if r_cmap is None:
            r_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
                "rcmap", [(0, 0, 0), (0, 0.5, 0), (0, 0.75, 0), (0, 1, 0)]
            )
        else:
            r_cmap = check_cmap(r_cmap)

        gene_plot(
            adata,
            gene_symbols=l,
            size=outer_size_prop * pt_scale,
            cmap=l_cmap,
            color_bar_label=l,
            ax=ax,
            fig=fig,
            crop=False,
            show_image=show_image,
            **kwargs,
        )
        gene_plot(
            adata,
            gene_symbols=r,
            size=middle_size_prop * pt_scale,
            cmap=r_cmap,
            color_bar_label=r,
            ax=ax,
            fig=fig,
            crop=False,
            show_image=show_image,
            **kwargs,
        )

    # Adding the cell type labels #
    if use_label is not None:
        if use_label in lr_use_labels:
            inner_cmap = inner_cmap if inner_cmap is not None else "copper"
            # NOTE(review): use_label is not forwarded as use_result, so
            # lr_result_plot draws its default statistic whichever LR stat
            # was requested; the stat names here also differ from the ones
            # lr_result_plot documents. Confirm the intended mapping.
            lr_result_plot(
                adata,
                use_lr=lr,
                show_image=show_image,
                cmap=inner_cmap,
                crop=False,
                ax=ax,
                fig=fig,
                size=inner_size_prop * pt_scale,
                **kwargs,
            )
        else:
            inner_cmap = inner_cmap if inner_cmap is not None else "default"
            cluster_plot(
                adata,
                use_label=use_label,
                cmap=inner_cmap,
                size=inner_size_prop * pt_scale,
                crop=False,
                ax=ax,
                fig=fig,
                show_image=show_image,
                show_plot=False,
                **kwargs,
            )

    # Adding in labels which show the interactions between signicant spots &
    # neighbours
    if show_arrows:
        # Arrows always use the full data, not the significant-spot subset.
        # NOTE(review): .toarray() assumes adata.X is a sparse matrix.
        l_expr = adata_full[:, l].X.toarray()[:, 0]
        r_expr = adata_full[:, r].X.toarray()[:, 0]

        if sig_cci:
            int_df = adata.uns[f"per_lr_cci_{use_label}"][lr]
        else:
            int_df = None

        cci_hs.add_arrows(
            adata_full,
            l_expr,
            r_expr,
            min_expr,
            sig_bool,
            fig,
            ax,
            use_label,
            int_df,
            arrow_head_width,
            arrow_width,
            arrow_cmap,
            arrow_vmax,
        )

    fig.suptitle(title)
#### het_plot currently out of date;
#### from old data structure when only test individual LRs.
@_docs_params(spatial_base_plot=doc_spatial_base_plot, het_plot=doc_het_plot)
def het_plot(
    adata: AnnData,
    # plotting param
    title: Optional[str] = None,
    figsize: Optional[Tuple[float, float]] = None,
    cmap: Optional[str] = "Spectral_r",
    use_label: Optional[str] = None,
    list_clusters: Optional[list] = None,
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    show_plot: Optional[bool] = True,
    show_axis: Optional[bool] = False,
    show_image: Optional[bool] = True,
    show_color_bar: Optional[bool] = True,
    zoom_coord: Optional[float] = None,
    crop: Optional[bool] = True,
    margin: Optional[float] = 100,
    size: Optional[float] = 7,
    image_alpha: Optional[float] = 1.0,
    cell_alpha: Optional[float] = 1.0,
    use_raw: Optional[bool] = False,
    fname: Optional[str] = None,
    dpi: Optional[int] = 120,
    # cci_rank param
    use_het: Optional[str] = "het",
    contour: bool = False,
    step_size: Optional[int] = None,
    vmin: float = None,
    vmax: float = None,
) -> Optional[AnnData]:
    """\
    Allows the visualization of significant cell-cell interaction
    as the values of dot points or contour in the Spatial
    transcriptomics array.


    Parameters
    -------------------------------------
    {spatial_base_plot}
    {het_plot}

    Examples
    -------------------------------------
    >>> import stlearn as st
    >>> adata = st.datasets.example_bcba()
    >>> pvalues = "lr_pvalues"
    >>> st.pl.het_plot(adata, use_het = pvalues)

    """
    # The docstring above is a template completed by the _docs_params
    # decorator, so it must not contain any other curly braces.
    # The ax/fig annotations use the public Axes/Figure classes; the former
    # `matplotlib.axes._subplots.Axes` is a private path evaluated at import
    # time and is not available in newer matplotlib.
    CciPlot(
        adata,
        title=title,
        figsize=figsize,
        cmap=cmap,
        use_label=use_label,
        list_clusters=list_clusters,
        ax=ax,
        fig=fig,
        show_plot=show_plot,
        show_axis=show_axis,
        show_image=show_image,
        show_color_bar=show_color_bar,
        zoom_coord=zoom_coord,
        crop=crop,
        margin=margin,
        size=size,
        image_alpha=image_alpha,
        cell_alpha=cell_alpha,
        use_raw=use_raw,
        fname=fname,
        dpi=dpi,
        use_het=use_het,
        contour=contour,
        step_size=step_size,
        vmin=vmin,
        vmax=vmax,
    )


# Functions relating to visualising celltype-celltype interactions after
# calling: st.tl.cci.run_cci
def ccinet_plot(
    adata: AnnData,
    use_label: str,
    lr: str = None,
    pos: dict = None,
    return_pos: bool = False,
    cmap: str = "default",
    font_size: int = 12,
    node_size_exp: int = 1,
    node_size_scaler: int = 1,
    min_counts: int = 0,
    sig_interactions: bool = True,
    fig: matplotlib.figure.Figure = None,
    ax: matplotlib.axes.Axes = None,
    pad=0.25,
    title: str = None,
    figsize: tuple = (10, 10),
):
    """Circular celltype-celltype interaction network based on LR-CCI analysis.
    The size of the nodes drawn for each cell type indicates the total no. of
    spot interactions that cell type is involved in; while the color of the
    arrows between nodes is coloured by the total no. of interactions
    between those particular cell types.

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run & st.tl.cci.run_cci has been applied.
    use_label: str
        Indicates the cell type labels or deconvolution results used for
        cell-cell interaction counting by LR pairs.
    lr: str
        The LR pair to visualise the cci network for. If None, will use spot
        cci counts across all LR pairs from adata.uns[f'lr_cci_{use_label}'].
    pos: dict
        Positions to draw each cell type, format as outputted from running
        networkx.circular_layout(graph). If not inputted will be generated.
    return_pos: bool
        Whether to return the positions of the cell types drawn or not;
        useful for input back into this function via the 'pos' parameter to
        get consistent positioning of the cell types when plotting for
        different LR pairs.
    cmap: str
        Cmap to use when generating the cell colors, if not already specified
        by adata.uns[f'{use_label}_colors'].
    font_size: int
        Size of the cell type labels.
    node_size_exp: int
        Increases difference between node sizes by this exponent.
    node_size_scaler: float
        Scaler to multiply by node sizes to increase/decrease size.
    min_counts: int
        Minimum no. of LR interactions for connection to be drawn; only
        counts strictly above this value are kept.
    sig_interactions: bool
        Whether to only show significant CCIs, or all observed interactions
        (forwarded to get_int_df).
    fig: Figure
        Figure to draw on. A new figure and axes are created unless both fig
        and ax are given.
    ax: Axes
        Axes to draw on.
    pad: float
        Padding added to each side of the x and y limits after drawing.
    title: str
        Figure title; forwarded to get_int_df, whose returned title is used.
    figsize: tuple
        (width, height) of the figure if fig/ax are not inputted.

    Returns
    -------
    pos: dict
        Dictionary of positions where the nodes are drawn, only if
        return_pos is True; useful for consistent layouts.
    """
    cmap, _ = get_cmap(cmap)
    # Making sure adata in correct state that this function should run #
    if f"lr_cci_{use_label}" not in adata.uns:
        raise Exception(
            "Need to first call st.tl.run_cci with the equivalnt "
            "use_label to visualise cell-cell interactions."
        )
    elif lr is not None and lr not in adata.uns[f"per_lr_cci_{use_label}"]:
        raise Exception(
            f"{lr} not found in {f'per_lr_cci_{use_label}'}, "
            "suggesting no significant interactions."
        )

    # Either plotting overall interactions, or just for a particular LR #
    int_df, title = get_int_df(adata, lr, use_label, sig_interactions, title)

    # Creating the interaction graph #
    all_set = int_df.index.values
    int_matrix = int_df.values
    graph = nx.MultiDiGraph()
    int_bool = int_matrix > min_counts
    int_matrix = int_matrix * int_bool  # zero out the filtered connections
    for i, cell_A in enumerate(all_set):
        if cell_A not in graph:
            graph.add_node(cell_A)
        for j, cell_B in enumerate(all_set):
            if int_bool[i, j]:
                graph.add_edge(cell_A, cell_B, weight=int_matrix[i, j])

    # Determining graph layout, node sizes, & edge colours #
    if pos is None:
        pos = nx.circular_layout(graph)  # position the nodes using the layout

    # Row index of each cell type in the matrix (first occurrence).
    index_of = {}
    for i, name in enumerate(all_set):
        index_of.setdefault(name, i)

    sent = int_matrix.sum(axis=1)  # interactions sent by each cell type
    received = int_matrix.sum(axis=0)  # interactions received
    # Self-interactions sit in both sums; subtract them once.
    involvement = sent + received - int_matrix.diagonal()
    total = int_matrix.sum()

    node_names = list(graph.nodes.keys())
    node_indices = np.asarray([index_of[name] for name in node_names], dtype=int)
    if total:
        share = involvement[node_indices] / total
    else:
        # No interaction passed the filter: avoid a 0/0 giving NaN sizes.
        share = np.zeros(len(node_indices), dtype=float)
    node_sizes = (share * 10000 * node_size_scaler) ** node_size_exp
    node_sizes[node_sizes == 0] = 0.1  # pseudocount

    edges = list(graph.edges.items())
    edge_weights = []
    for (sender, receiver, *_), attrs in edges:
        trans_i, receive_i = index_of[sender], index_of[receiver]
        # Everything the sender sends plus everything the receiver receives;
        # their shared cell is subtracted so it isn't double counted.
        e_total = (
            sent[trans_i] + received[receive_i] - int_matrix[trans_i, receive_i]
        )
        edge_weights.append(attrs["weight"] / e_total)

    # Determining node colors #
    nodes = np.unique(list(graph.nodes.keys()))
    node_colors = get_colors(adata, use_label, cmap, label_set=nodes)
    if not np.all(np.array(node_names) == nodes):
        # Colors come back in sorted-name order; realign to the graph order.
        nodes_indices = [np.where(nodes == node)[0][0] for node in node_names]
        node_colors = np.array(node_colors)[nodes_indices]

    #### Drawing the graph #####
    if fig is None or ax is None:
        fig, ax = plt.subplots(figsize=figsize, facecolor=[0.7, 0.7, 0.7, 0.4])

    # Adding in the self-loops #
    z = 55
    for edge, weight in zip(edges, edge_weights):
        cell_type = edge[0][0]
        if cell_type != edge[0][1]:
            continue
        x, y = pos[cell_type]
        # Same orientation (mod 360) as atan(y / x) with +180 for x > 0, but
        # atan2 is also defined, and continuous, at x == 0.
        angle = math.degrees(math.atan2(y, x)) + 180
        arc = patches.Arc(
            xy=(x, y),
            width=0.3,
            height=0.025,
            lw=5,
            # Direct colormap attribute, as used for edge_cmap below.
            ec=plt.cm.Blues(weight),
            angle=angle,
            theta1=z,
            theta2=360 - z,
        )
        ax.add_patch(arc)

    # Drawing the main components of the graph #
    nx.draw_networkx(
        graph,
        pos,
        node_size=node_sizes,
        node_color=node_colors,
        arrowstyle="->",
        arrowsize=50,
        width=5,
        font_size=font_size,
        font_weight="bold",
        edge_color=edge_weights,
        edge_cmap=plt.cm.Blues,
        ax=ax,
    )
    fig.suptitle(title, fontsize=30)
    plt.tight_layout()

    # Adding padding #
    xlims = ax.get_xlim()
    ax.set_xlim(xlims[0] - pad, xlims[1] + pad)
    ylims = ax.get_ylim()
    ax.set_ylim(ylims[0] - pad, ylims[1] + pad)

    if return_pos:
        return pos
def cci_map(
    adata: AnnData,
    use_label: str,
    lr: str = None,
    ax: Axes = None,
    show: bool = False,
    figsize: tuple = None,
    cmap: str = "Spectral_r",
    sig_interactions: bool = True,
    title=None,
):
    """Heatmap visualising sender->receivers of cell type interactions.

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run & st.tl.cci.run_cci has been applied.
    use_label: str
        Indicates the cell type labels or deconvolution results used for
        cell-cell interaction counting by LR pairs.
    lr: str
        The LR pair to visualise the sender->receiver interactions for. If
        None, will use all pairs via adata.uns[f'lr_cci_{use_label}'].
    ax: Axes
        Axes on which to plot the heatmap, if None then generates own.
    show: bool
        Whether to show the plot or not; if not, then returns ax.
    figsize: tuple
        (width, height), specifies the dimensions of the figure. Only
        relevant if ax=None. If None, the matplotlib default (6.4, 4.8) is
        grown with the number of cell types.
    cmap: str
        Cmap used to color the number of LR interactions.
    sig_interactions: bool
        Whether to only show significant CCIs, or all observed interactions.
    title: None
        Title to display over the heatmap. If not provided, will be
        determined based on the run parameters.

    Returns
    -------
    ax: Axes
        Axes where the heatmap was drawn if show=False.
    """
    # Either plotting overall interactions, or just for a particular LR #
    int_df, title = get_int_df(adata, lr, use_label, sig_interactions, title)

    if figsize is None:  # Adjust size depending on no. cell types
        n_types = int_df.shape[0]
        figsize = (6.4 + n_types * 0.1, 4.8 + n_types * 0.05)

    # Rank by total interactions #
    int_vals = int_df.values
    # Sent + received per cell type, counting self-interactions only once.
    total_ints = int_vals.sum(axis=1) + int_vals.sum(axis=0) - int_vals.diagonal()
    order = np.argsort(-total_ints)
    # Rows by descending total; columns in the reverse of that order.
    int_df = int_df.iloc[order, order[::-1]]

    # Reformat the interaction df #
    flat_df = create_flat_df(int_df)

    ax = _box_map(
        flat_df["x"],
        flat_df["y"],
        flat_df["value"].astype(int),
        ax=ax,
        figsize=figsize,
        cmap=cmap,
    )

    ax.set_ylabel("Sender")
    ax.set_xlabel("Receiver")
    plt.suptitle(title)

    if show:
        plt.show()
    else:
        return ax
def lr_cci_map(
    adata: AnnData,
    use_label: str,
    lrs: Optional[Union[list, np.ndarray]] = None,
    n_top_lrs: int = 5,
    n_top_ccis: int = 15,
    min_total: int = 0,
    ax: matplotlib.axes.Axes = None,
    figsize: tuple = (6.48, 4.8),
    show: bool = False,
    cmap: str = "Spectral_r",
    square_scaler: int = 700,
    sig_interactions: bool = True,
):
    """Heatmap of interaction counts. Rows are lrs and columns are
    celltype->celltype interactions.

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run & st.tl.cci.run_cci has been applied.
    use_label: str
        Indicates the cell type labels or deconvolution results used for
        the cell-cell interaction counting by LR pairs.
    lrs: list-like
        LR pairs to show in the heatmap, if None then the first n_top_lrs
        entries of adata.uns['lr_summary'] are used.
    n_top_lrs: int
        Indicates how many top lrs to show; is ignored if lrs is not None.
    n_top_ccis: int
        Indicates maximum no. of CCIs to show.
    min_total: int
        Minimum total no. of interactions (summed over the LRs) a
        celltype->celltype pair must have to be shown. The default of 0
        filters nothing.
    ax: Axes
        Axes on which to draw the heatmap, is generated internally if None.
    figsize: tuple
        (width, height), only relevant if ax=None.
    show: bool
        Whether to show the plot or not, if not returns ax.
    cmap: str
        Cmap used to color the number of LR interactions.
    square_scaler: int
        Scaler to size the squares displayed.
    sig_interactions: bool
        Whether to only show significant CCIs, or all observed interactions.

    Returns
    -------
    ax: matplotlib.axes.Axes
        Axes where the heatmap was drawn on if show=False; None if show=True.

    Raises
    ------
    Exception
        If no LR or no CCI remains to be displayed after filtering.
    """
    if sig_interactions:
        lr_int_dfs = adata.uns[f"per_lr_cci_{use_label}"]
    else:
        lr_int_dfs = adata.uns[f"per_lr_cci_raw_{use_label}"]

    if lrs is None:
        lrs = np.array(list(lr_int_dfs.keys()))
    else:
        lrs = np.array(lrs)
        n_top_lrs = len(lrs)

    # Creating a new int_df with lrs as rows & cell-cell as column #
    # Cell types are read from the first per-LR table; every table is
    # assumed to share this index order -- TODO confirm against run_cci.
    cell_types = list(lr_int_dfs.values())[0].index.values.astype(str)
    n_types = len(cell_types)

    # Column labels are independent of the LR, so build them once. The
    # nesting (row cell type outer, column cell type inner) matches the
    # row-major flattening used below.
    ccis = ["->".join([cell_i, cell_j]) for cell_i in cell_types for cell_j in cell_types]

    new_ints = np.zeros((len(lrs), n_types**2))
    for lr_i, lr in enumerate(lrs):
        new_ints[lr_i, :] = lr_int_dfs[lr].values[:n_types, :n_types].ravel()
    new_int_df = pd.DataFrame(new_ints, index=lrs, columns=ccis)

    # Filtering out ccis which have few LR interactions: drop those below
    # min_total, then keep at most the n_top_ccis largest #
    total_ints = new_int_df.values.sum(axis=0)
    order = np.argsort(-total_ints)
    order = order[total_ints[order] >= min_total]
    new_int_df = new_int_df.iloc[:, order[0:n_top_ccis]]

    # Getting the top_lrs to display: the first n_top_lrs entries of
    # lr_summary (presumably already ranked best-first -- verify) #
    if n_top_lrs < len(lrs):
        top_lrs = adata.uns["lr_summary"].index.values[0:n_top_lrs]
        new_int_df = new_int_df.loc[top_lrs, :]

    # Ordering by the no. of interactions #
    cci_ints = new_int_df.values.sum(axis=0)
    cci_order = np.argsort(-cci_ints)
    lr_ints = new_int_df.values.sum(axis=1)
    lr_order = np.argsort(-lr_ints)
    new_int_df = new_int_df.iloc[lr_order, cci_order]

    # Fail early with the informative message rather than handing an
    # empty table to the flattening/plotting helpers.
    if new_int_df.shape[0] == 0 or new_int_df.shape[1] == 0:
        raise Exception(f"No interactions greater than min: {min_total}")

    # Getting a flat version of the array for plotting #
    flat_df = create_flat_df(new_int_df.transpose())
    if flat_df.shape[0] == 0 or flat_df.shape[1] == 0:
        raise Exception(f"No interactions greater than min: {min_total}")

    ax = _box_map(
        flat_df["x"],
        flat_df["y"],
        flat_df["value"].astype(int),
        ax=ax,
        cmap=cmap,
        figsize=figsize,
        square_scaler=square_scaler,
    )

    ax.set_ylabel("LR-pair")
    ax.set_xlabel("Cell-cell interaction")

    if show:
        plt.show()
    else:
        return ax
def lr_chord_plot(
    adata: AnnData,
    use_label: str,
    lr: str = None,
    min_ints: int = 2,
    n_top_ccis: int = 10,
    cmap: str = "default",
    sig_interactions: bool = True,
    label_size: int = 10,
    label_rotation: float = 0,
    title: str = None,
    figsize: tuple = (8, 8),
    show: bool = True,
):
    """Chord diagram of interactions between cell types.

    Note that interaction is measured as the total no. of edges connecting
    two cell types expressing the ligand and/or receptor in significant
    neighbourhoods for given LR pair.

    The chord diagram is read as follows:
    Each cell type has a labelled edge taking up a proportion of the outer
    circle. Chords connecting cell type edges are coloured by the dominant
    sending cell. Each chord linking cell types has an asymmetric shape.
    For two cell types, A and B, the side of the chord attached to edge A is
    sized by the total interactions from B->A, where B is expressing the
    ligand & A is expressing the receptor. Hence, the proportion of a cell
    type's edge in the chordplot circle represents the total input signals to
    that cell type; while the area of the chordplot circle taken up by the
    outputted chords from a given cell type represents the total output
    signals from that cell type.

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run & st.tl.cci.run_cci has been applied.
    use_label: str
        Indicates the cell type labels or deconvolution results used for
        cell-cell interaction counting by LR pairs.
    lr: str
        The LR pair to visualise the CCIs for. If None, will use all pairs
        via adata.uns[f'lr_cci_{use_label}'].
    min_ints: int
        Cell types are shown only if their total no. of interactions is
        strictly greater than this value.
    n_top_ccis: int
        Maximum no. of cell types to show; the ones with the most total
        interactions are kept.
    cmap: str
        Cmap to use to get colors if colors not already in
        adata.uns[f'{use_label}_colors']
    sig_interactions: bool
        Whether to show only significant CCIs or all interaction counts.
    label_size: int
        The size of the cell type labels to render.
    label_rotation: float
        Rotation of the cell type label text.
    title: str
        The title above the plot; informative default is determined based on
        input.
    figsize: tuple
        Figure dimensions.
    show: bool
        Show or not; if not return figure & axes.

    Returns
    -------
    fig: matplotlib.figure.Figure
        Figure on which the chord diagram was drawn if show=False.
    ax: matplotlib.axes.Axes
        Axes where the chord diagram was drawn on if show=False.

    Notes
    -----
    If no cell type passes the filters, a message is printed and None is
    returned (regardless of `show`) without creating a figure.
    """
    # Either plotting overall interactions, or just for a particular LR #
    int_df, title = get_int_df(adata, lr, use_label, sig_interactions, title)
    int_df = int_df.transpose()

    flux = int_df.values
    # Row sum + column sum, counting the diagonal only once #
    total_ints = flux.sum(axis=1) + flux.sum(axis=0) - flux.diagonal()

    keep = np.where(total_ints > min_ints)[0]
    # Limit the no. of cell types for good display #
    if len(keep) > n_top_ccis:
        keep = np.argsort(-total_ints)[0:n_top_ccis]

    # Filter any with all zeros after filtering, i.e. cell types that only
    # interacted with cell types that have just been dropped. The explicit
    # bool dtype keeps the mask valid when `keep` is empty. #
    all_zero = np.array(
        [np.all(np.logical_and(flux[i, keep] == 0, flux[keep, i] == 0)) for i in keep],
        dtype=bool,
    )
    keep = keep[~all_zero]

    if len(keep) == 0:  # If we don't keep anything, warn the user
        print(
            f"Warning: for {lr} at the current min_ints ({min_ints}), there "
            f"are no interaction to display. Adjust min_ints to a lower value"
            f" to visualise chordplot for this LR."
        )
        return

    # Only create the figure once we know there is something to draw, so
    # the early return above does not leave an empty figure open.
    fig = plt.figure(figsize=figsize)

    flux = flux[:, keep]
    flux = flux[keep, :].astype(float)

    # Add pseudocount to row/column which has all zeros for the incoming
    # so can make the connection between the two
    for i in range(flux.shape[0]):
        if np.all(flux[i, :] == 0):
            flux[i, flux[:, i] > 0] += sys.float_info.min
        elif np.all(flux[:, i] == 0):
            flux[flux[i, :] > 0, i] += sys.float_info.min

    cell_names = int_df.index.values.astype(str)[keep]

    # Retrieving colors of cell types #
    colors = get_colors(adata, use_label, cmap=cmap, label_set=cell_names)

    ax = plt.axes([0, 0, 1, 1])
    nodePos = chordDiagram(flux, ax, lim=1.25, colors=colors)
    ax.axis("off")

    prop = dict(fontsize=label_size, ha="center", va="center")
    for i, cell_name in enumerate(cell_names):
        x, y = nodePos[i][0:2]
        rotation = nodePos[i][2]
        # Prevent text going upside down at certain rotations
        if (18 < rotation < 90 and label_rotation != 0) or (90 < rotation < 120):
            label_rotation_ = -label_rotation
        else:
            label_rotation_ = label_rotation
        ax.text(x, y, cell_name, rotation=rotation + label_rotation_, **prop)

    fig.suptitle(title, fontsize=12, fontweight="bold")

    if show:
        plt.show()
    else:
        return fig, ax
def grid_plot(
    adata,
    use_label: str = None,
    n_row: int = 10,
    n_col: int = 10,
    size: int = 1,
    figsize=(4.5, 4.5),
    show: bool = False,
):
    """Plots grid over the top of spatial data to show how cells will be
    grouped if gridded.

    Parameters
    ----------
    adata: AnnData
        Data providing the spot coordinates in adata.obs['imagecol'] and
        adata.obs['imagerow'].
    use_label: str
        Categorical column in adata.obs used to color the points, using the
        colors in adata.uns[f'{use_label}_colors']. If None, or if those
        colors are not present, points are colored by their estimated
        density instead.
    n_row: int
        The number of rows in the grid.
    n_col: int
        The number of columns in the grid.
    size: int
        Marker size passed to the scatter plot.
    figsize: tuple
        (width, height) of the figure.
    show: bool
        Whether to show the plot; if not, returns the figure & axes.

    Returns
    -------
    fig: matplotlib.figure.Figure
        Figure on which the grid plot was drawn if show=False.
    ax: matplotlib.axes.Axes
        Axes where the grid plot was drawn on if show=False.
    """
    xs, ys = adata.obs["imagecol"].values, adata.obs["imagerow"].values

    # Only the bin edges are needed to draw the grid lines #
    _, xedges, yedges = np.histogram2d(xs, ys, bins=[n_col, n_row])
    xmin, xmax = min(xedges), max(xedges)
    ymin, ymax = min(yedges), max(yedges)

    # Choosing the point colors #
    colors_key = f"{use_label}_colors"
    if use_label is not None and colors_key in adata.uns:
        palette = adata.uns[colors_key]
        categories = adata.obs[use_label].cat.categories
        color_map = {ct: palette[i] for i, ct in enumerate(categories)}
        cell_colors = [color_map[ct] for ct in adata.obs[use_label]]
    else:
        # Otherwise plot by cell density; this is also the fallback when a
        # label is given but has no stored colors, which would otherwise
        # leave the colors undefined. #
        stack = np.vstack([xs, ys])
        cell_colors = gaussian_kde(stack)(stack)

    fig, ax = plt.subplots(figsize=figsize)

    # y is negated so that the image row axis points downwards #
    ax.scatter(xs, -ys, s=size, c=cell_colors)
    ax.vlines(xedges, -ymin, -ymax, color="#36454F")
    ax.hlines(-yedges, xmin, xmax, color="#36454F")

    if show:
        plt.show()
    else:
        return fig, ax


####################### Bokeh Interactive Plots ################################
def lr_plot_interactive(adata: AnnData):
    """Interactive Bokeh view of the LR scores for significant spots.

    The plot is rendered inline in a notebook.

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run has been applied.
    """
    lr_viewer = BokehLRPlot(adata)

    # Route Bokeh output to the notebook before displaying the app #
    output_notebook()
    show(lr_viewer.app, notebook_handle=True)
def spatialcci_plot_interactive(adata: AnnData):
    """Interactive Bokeh view of the significant CCIs in spatial context.

    The plot is rendered inline in a notebook.

    Parameters
    ----------
    adata: AnnData
        Data on which st.tl.cci.run & st.tl.cci.run_cci has been applied.
    """
    cci_viewer = BokehSpatialCciPlot(adata)

    # Route Bokeh output to the notebook before displaying the app #
    output_notebook()
    show(cci_viewer.app, notebook_handle=True)
# def het_plot_interactive(adata: AnnData):
#     bokeh_object = BokehCciPlot(adata)
#     output_notebook()
#     show(bokeh_object.app, notebook_handle=True)

# Bokeh & old grid plots;
# have not been tested since the multi-LR testing implementation.

# def het_plot_interactive(adata: AnnData):
#     bokeh_object = BokehCciPlot(adata)
#     output_notebook()
#     show(bokeh_object.app, notebook_handle=True)

# def grid_plot(
#     adata: AnnData,
#     use_het: str = None,
#     num_row: int = 10,
#     num_col: int = 10,
#     vmin: float = None,
#     vmax: float = None,
#     cropped: bool = True,
#     margin: int = 100,
#     dpi: int = 100,
#     name: str = None,
#     output: str = None,
#     copy: bool = False,
# ) -> Optional[AnnData]:
#
#     """
#     Cell diversity plot for spatial transcriptomics data.
#
#     Parameters
#     ----------
#     adata: Annotated data matrix.
#     use_het: Cluster heterogeneity count results from tl.cci_rank.het
#     num_row: int Number of grids on height
#     num_col: int Number of grids on width
#     cropped crop image or not.
#     margin margin used in cropping.
#     dpi: Set dpi as the resolution for the plot.
#     name: Name of the output figure file.
#     output: Save the figure as file or not.
#     copy: Return a copy instead of writing to adata.
#
#     Returns
#     -------
#     Nothing
#     """
#
#     try:
#         import seaborn as sns
#     except:
#         raise ImportError("Please run `pip install seaborn`")
#     plt.subplots()
#
#     sns.heatmap(
#         pd.DataFrame(np.array(adata.obsm[use_het]).reshape(num_col, num_row)).T,
#         vmin=vmin,
#         vmax=vmax,
#     )
#     plt.axis("equal")
#
#     if output is not None:
#         plt.savefig(
#             output + "/" + name + "_heatmap.pdf",
#             dpi=dpi,
#             bbox_inches="tight",
#             pad_inches=0,
#         )
#
#     plt.show()