Source code for stlearn.spatial.sme.pseudo_spot

from pathlib import Path
from typing import Literal

import numpy as np
import pandas as pd
import scipy
from anndata import AnnData
from scipy.sparse import csr_matrix

import stlearn

from ._weighting_matrix import (
    _PLATFORM,
    _WEIGHTING_MATRIX,
    impute_neighbour,
    weight_matrix_imputed,
)

# Allowed values for the ``copy`` argument of ``pseudo_spot``: selects whether
# only the imputed pseudo spots or the original + imputed spots are returned.
_COPY = Literal["pseudo_spot_adata", "combined_adata"]


def _offset_spot_grids(
    base: pd.DataFrame, offsets: list[tuple[float, float]]
) -> pd.DataFrame:
    """Stack one shifted copy of ``base`` per ``(d_row, d_col)`` offset.

    ``base`` must hold ``array_row`` and ``array_col`` columns. The copies are
    concatenated in the order of ``offsets``, the previous index is moved into
    an ``index`` column, and rows landing on the same grid position are
    collapsed (the last occurrence is kept).
    """
    shifted = [
        base.assign(
            array_row=base["array_row"] + d_row,
            array_col=base["array_col"] + d_col,
        )
        for d_row, d_col in offsets
    ]
    grid = pd.concat(shifted).reset_index()
    # drop_duplicates returns a new frame; the result has to be kept.
    return grid.drop_duplicates(subset=["array_row", "array_col"], keep="last")


def pseudo_spot(
    adata: AnnData,
    tile_path: Path | str = Path("/tmp/tiles"),
    use_data: str = "raw",
    crop_size: str | int = "auto",
    platform: _PLATFORM = "Visium",
    weights: _WEIGHTING_MATRIX = "weights_matrix_all",
    copy: _COPY = "pseudo_spot_adata",
) -> AnnData | None:
    """\
    Improve spatial resolution by imputing (creating) new spots from existing
    ones using spatial, morphological, and expression (SME) information.

    Parameters
    ----------
    adata
        Annotated data matrix. It is copied; the caller's object is not
        modified by this function.
    tile_path
        Path to save spot image tiles.
    use_data
        Input data, can be `raw` counts (``adata.X``) or the key of any
        ``adata.obsm`` entry, e.g. log transformed data or a dimension
        reduced space (`X_pca`, `X_umap`).
    crop_size
        Size of tiles. If `auto`, the crop size is derived from the fitted
        array-to-pixel scale.
    platform
        `Visium` or `Old_ST`.
    weights
        Weighting matrix for imputation.
        if `weights_matrix_all`, matrix combined all information from spatial
        location (S), tissue morphological feature (M) and gene expression (E)
        if `weights_matrix_pd_md`, matrix combined information from spatial
        location (S), tissue morphological feature (M)
    copy
        Selects the returned object.
        if `pseudo_spot_adata`, the imputed Anndata only
        if `combined_adata`, the original data merged with the imputed
        Anndata.

    Returns
    -------
    Anndata

    Raises
    ------
    ValueError
        If `platform` is not supported or `crop_size` is neither `auto` nor
        an integer.
    TypeError
        If `use_data` is `raw` and ``adata.X`` is not a sparse matrix, a
        NumPy array or a DataFrame.
    """
    import math

    from scipy.sparse import issparse
    from scipy.spatial import cKDTree
    from sklearn.linear_model import LinearRegression

    adata = adata.copy() if copy else adata

    # Candidate pseudo-spot positions are the existing array coordinates
    # shifted by a fraction of the grid spacing.
    if platform == "Visium":
        img_row = adata.obs["imagerow"]
        img_col = adata.obs["imagecol"]
        array_row = adata.obs["array_row"]
        array_col = adata.obs["array_col"]
        obs_df = _offset_spot_grids(
            adata.obs[["array_row", "array_col"]],
            [(2 / 3, 0), (-2 / 3, 0)],
        )
    elif platform == "Old_ST":
        img_row = adata.obs["imagerow"]
        img_col = adata.obs["imagecol"]
        # Old_ST spot names are parsed as "<col>x<row>".
        array_row = np.array(
            [float(name.split("x")[1]) for name in adata.obs_names]
        )
        array_col = np.array(
            [float(name.split("x")[0]) for name in adata.obs_names]
        )
        obs_df = _offset_spot_grids(
            pd.DataFrame({"array_row": array_row, "array_col": array_col}),
            [
                (-1 / 2, 0),  # left
                (1 / 2, 0),  # right
                (0, -1 / 2),  # up
                (0, 1 / 2),  # down
                (-1 / 2, -1 / 2),  # left up
                (1 / 2, -1 / 2),  # right up
                (-1 / 2, 1 / 2),  # left down
                (1 / 2, 1 / 2),  # right down
            ],
        )
    else:
        raise ValueError(f"""\
                {platform!r} does not support.
                """)

    # Fit array coordinate -> image pixel maps on the real spots, then use
    # them to place the pseudo spots on the image.
    reg_row = LinearRegression().fit(
        np.asarray(array_row, dtype=np.float64).reshape(-1, 1), img_row
    )
    reg_col = LinearRegression().fit(
        np.asarray(array_col, dtype=np.float64).reshape(-1, 1), img_col
    )
    # coef_ is a 1-element array for a single feature; take the scalar so the
    # arithmetic below does not rely on implicit array-to-scalar conversion.
    row_scale = float(np.ravel(reg_row.coef_)[0])
    col_scale = float(np.ravel(reg_col.coef_)[0])
    obs_df["imagerow"] = obs_df["array_row"] * row_scale + reg_row.intercept_
    obs_df["imagecol"] = obs_df["array_col"] * col_scale + reg_col.intercept_

    # Count the real spots within one diagonal grid unit of every pseudo
    # spot. The tree only holds real spots, so no index filtering is needed.
    unit = math.hypot(row_scale, col_scale)
    spot_tree = cKDTree(adata.obs[["imagecol", "imagerow"]].to_numpy())
    obs_df["n_neighbour"] = spot_tree.query_ball_point(
        obs_df[["imagecol", "imagerow"]].to_numpy(),
        round(unit),
        return_length=True,
    )

    # Keep pseudo spots that have more than one real neighbour.
    obs_df = obs_df.loc[obs_df["n_neighbour"] > 1, :].reset_index()
    obs_df.index = obs_df.index.map(lambda x: "Pseudo_Spot_" + str(x))

    impute_df = pd.DataFrame(0, index=obs_df.index, columns=adata.var_names)
    pseudo_spot_adata = AnnData(impute_df, obs=obs_df)
    pseudo_spot_adata.uns["spatial"] = adata.uns["spatial"]

    actual_crop_size: int
    if crop_size == "auto":
        actual_crop_size = round(unit / 2)
    elif isinstance(crop_size, int):
        actual_crop_size = crop_size
    else:
        raise ValueError(f"crop_size must be 'auto' or an integer, got {crop_size}")

    # Resolve the expression input before the tiling step so an unsupported
    # input fails before any tiles are written.
    if use_data == "raw":
        if issparse(adata.X):
            count_embed = adata.X.toarray()
        elif isinstance(adata.X, np.ndarray):
            count_embed = adata.X
        elif isinstance(adata.X, pd.DataFrame):
            count_embed = adata.X.values
        else:
            raise TypeError(f"{type(adata.X)} is not a valid type")
    else:
        count_embed = adata.obsm[use_data]

    stlearn.pp.tiling(pseudo_spot_adata, tile_path, crop_size=actual_crop_size)
    stlearn.pp.extract_feature(pseudo_spot_adata)

    weight_matrix_imputed(adata, pseudo_spot_adata, platform=platform)
    impute_neighbour(pseudo_spot_adata, count_embed=count_embed, weights=weights)

    # The imputed matrix is read back from obsm["imputed_data"] and must match
    # the (spots x genes) shape to become the new X.
    assert pseudo_spot_adata.shape == pseudo_spot_adata.obsm["imputed_data"].shape

    pseudo_spot_adata.X = pseudo_spot_adata.obsm["imputed_data"]

    # Drop pseudo spots whose imputed values sum to zero or less.
    pseudo_spot_adata = pseudo_spot_adata[np.sum(pseudo_spot_adata.X, axis=1) > 0]

    print("Done")

    if copy == "pseudo_spot_adata":
        return pseudo_spot_adata
    else:
        return _merge(adata, pseudo_spot_adata)
def _merge(
    adata1: AnnData,
    adata2: AnnData,
    copy: bool = True,
) -> AnnData | None:
    """Stack two AnnData objects row-wise into a new AnnData.

    The expression tables (``to_df()``) and the ``obs`` tables of ``adata1``
    and ``adata2`` are concatenated in that order, and ``uns["spatial"]`` is
    taken from ``adata1``.

    Parameters
    ----------
    adata1
        First dataset; its rows come first and it supplies ``uns["spatial"]``.
    adata2
        Second dataset, appended below ``adata1``.
    copy
        If true, return the merged object; otherwise return ``None``.

    Returns
    -------
    The merged AnnData, or ``None`` when ``copy`` is false.
    """
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
    merged_df = pd.concat([adata1.to_df(), adata2.to_df()])
    merged_df_obs = pd.concat([adata1.obs, adata2.obs])

    merged_adata = AnnData(merged_df, obs=merged_df_obs)
    merged_adata.uns["spatial"] = adata1.uns["spatial"]

    return merged_adata if copy else None