Source code for stlearn.plotting.mask_plot

import matplotlib
from matplotlib import pyplot as plt

from typing import Optional, Union
from anndata import AnnData


def plot_mask(
    adata: AnnData,
    library_id: Optional[str] = None,
    show_spot: bool = True,
    spot_alpha: float = 1.0,
    cmap: str = "vega_20_scanpy",
    tissue_alpha: float = 1.0,
    mask_alpha: float = 0.5,
    spot_size: Union[float, int] = 6.5,
    show_legend: bool = True,
    name: str = "mask_plot",
    dpi: int = 150,
    output: Optional[str] = None,
    show_axis: bool = False,
    show_plot: bool = True,
) -> Optional[AnnData]:
    """\
    Mask plot for spatial transcriptomics data.

    Draws the spots coloured by their mask annotation on top of the tissue
    image and the mask image. The function reads ``adata.obs["imagecol"]``,
    ``adata.obs["imagerow"]``, ``adata.obs["mask_annotation"]``,
    ``adata.obs["mask_annotation_code"]``, ``adata.uns["spatial"]`` and
    ``adata.uns["mask_annotation"]``.

    Parameters
    ----------
    adata
        Annotated data matrix.
    library_id
        Library id stored in AnnData. When ``None``, the first key of
        ``adata.uns["spatial"]`` is used.
    show_spot
        Show spot or not.
    spot_alpha
        Opacity of the spot.
    cmap
        Name of the palette to use. One of ``vega_10_scanpy``,
        ``vega_20_scanpy``, ``default_102``, ``default_28``, ``jana_40``
        or ``default``.
    tissue_alpha
        Opacity of the tissue.
    mask_alpha
        Opacity of the mask.
    spot_size
        Size of the spot.
    show_legend
        Show legend or not.
    name
        Name of the output figure file.
    dpi
        DPI of the output figure. Also written to
        ``plt.rcParams["figure.dpi"]``.
    output
        Directory in which the figure is saved as ``name``. Nothing is
        saved when ``None``.
    show_axis
        Show axis or not.
    show_plot
        Show plot or not.

    Returns
    -------
    Nothing

    Raises
    ------
    ValueError
        If ``cmap`` is not one of the supported palette names.
    """
    import os

    from scanpy.plotting import palettes
    from stlearn.plotting import palettes_st

    # Every supported palette is stored under an attribute with the same name
    # as the ``cmap`` string, so only the owning module has to be looked up.
    palette_sources = {
        "vega_10_scanpy": palettes,
        "vega_20_scanpy": palettes,
        "default_102": palettes,
        "default_28": palettes,
        "jana_40": palettes_st,
        "default": palettes_st,
    }
    source = palette_sources.get(cmap) if isinstance(cmap, str) else None
    if source is None:
        raise ValueError("We only support " + ", ".join(palette_sources))
    palette = getattr(source, cmap)

    # ``from_list`` already returns a Colormap, so it is used directly rather
    # than being passed through the removed ``matplotlib.cm.get_cmap``.
    color_map = matplotlib.colors.LinearSegmentedColormap.from_list("", palette)
    last_step = len(palette) - 1

    key = "mask_annotation"
    code_col = key + "_code"
    white = (1.0, 1.0, 1.0, 1.0)

    def _color_for(code):
        # Code -1 is drawn in opaque white; any other code is mapped onto the
        # colormap by its position within the palette.
        if code == -1:
            return white
        return color_map(int(code) / last_step)

    plt.rcParams["figure.dpi"] = dpi

    # Option for turning off showing figure
    plt.ioff()

    # Initialize matplotlib
    fig, a = plt.subplots()

    # Frame the spots with a 5% margin; the y axis is inverted so that it
    # follows image row order.
    x_min, x_max = adata.obs["imagecol"].min(), adata.obs["imagecol"].max()
    y_min, y_max = adata.obs["imagerow"].min(), adata.obs["imagerow"].max()
    add_x = (x_max - x_min) // 20
    add_y = (y_max - y_min) // 20
    a.set_xlim([x_min - add_x, x_max + add_x])
    a.set_ylim([y_max + add_y, y_min - add_y])

    if show_spot:
        # observed=True skips categories without spots, which would otherwise
        # produce empty groups that cannot be indexed.
        for _, spots in adata.obs.groupby(key, observed=True):
            # Plot scatter plot based on pixel of spots; the first spot's
            # code decides the colour of the whole group.
            a.scatter(
                spots["imagecol"],
                spots["imagerow"],
                edgecolor="none",
                alpha=spot_alpha,
                s=spot_size,
                marker="o",
                c=[_color_for(spots[code_col].iloc[0])],
            )

    if show_legend:
        from matplotlib.patches import Patch

        # One legend entry per annotation label, coloured like its spots.
        first_codes = adata.obs.groupby(key, observed=True)[code_col].first()
        legend_elements = [
            Patch(color=_color_for(code), label=label)
            for label, code in first_codes.items()
        ]
        a.legend(
            handles=legend_elements,
            loc="center left",
            bbox_to_anchor=(1, 0.5),
            prop=dict(size=8),
        )

    if library_id is None:
        library_id = list(adata.uns["spatial"].keys())[0]

    library = adata.uns["spatial"][library_id]
    image = library["images"][library["use_quality"]]

    if not show_axis:
        a.axis("off")

    # Overlay the tissue image
    a.imshow(
        image,
        alpha=tissue_alpha,
        zorder=-1,
    )

    # Overlay the mask image stored under the annotation key
    a.imshow(
        adata.uns[key],
        alpha=mask_alpha,
        zorder=-1,
    )

    if output is not None:
        fig.savefig(
            os.path.join(output, name),
            dpi=dpi,
            bbox_inches="tight",
            pad_inches=0,
        )

    if show_plot:
        plt.show()