# Source code for SpaHDmap.data.data_util

import numpy as np
import pandas as pd
import scanpy as sc
import squidpy as sq
from skimage import io
from typing import Tuple, Optional, Union, List
from skimage import filters
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import anndata
import scipy
import math
import cv2
import warnings
import pickle
import os
from .sparkx import sparkx
from .bsp import bsp
from .color_normalize import color_normalize


def _compute_scaled_bbox(spot_coord: np.ndarray,
                         radius: int,
                         shape: Tuple[int, int]) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Compute a clamped bounding box for scaled spot coordinates.

    Parameters
    ----------
    spot_coord
        Spot coordinates in the same scale as ``shape``.
    radius
        Spot radius in the same scale as ``shape``.
    shape
        Target image shape as ``(height, width)``.
    """
    min_coords = spot_coord.min(0).astype(int)
    max_coords = spot_coord.max(0).astype(int)
    row_range = (max(0, min_coords[0] - radius), min(shape[0], max_coords[0] + radius + 1))
    col_range = (max(0, min_coords[1] - radius), min(shape[1], max_coords[1] + radius + 1))
    return row_range, col_range


def _estimate_background_value(lowres_image: np.ndarray,
                               binary_mask: np.ndarray,
                               outer_mask: np.ndarray) -> np.ndarray:
    """
    Estimate the background color for mask generation from the outer region.

    Falls back to the global image median when the scaled bounding box fully
    consumes the low-resolution canvas.
    """
    background_class = outer_mask & binary_mask
    foreground_class = outer_mask & ~binary_mask

    if not np.any(background_class) and not np.any(foreground_class):
        warnings.warn(
            "No outer-region pixels are available for background estimation after scaling. "
            "Falling back to the global image median."
        )
        return np.median(lowres_image.reshape(-1, lowres_image.shape[-1]), axis=0)

    if not np.any(background_class):
        return np.median(lowres_image[foreground_class], axis=0)

    if not np.any(foreground_class):
        return np.median(lowres_image[background_class], axis=0)

    outer_pixels = cv2.cvtColor(lowres_image, cv2.COLOR_RGB2GRAY)[outer_mask]
    var_background = np.var(outer_pixels[binary_mask[outer_mask]])
    var_foreground = np.var(outer_pixels[~binary_mask[outer_mask]])
    selected = background_class if var_background < var_foreground else foreground_class
    return np.median(lowres_image[selected], axis=0)


def _extract_spatial_coords_from_table(spot_coord: pd.DataFrame) -> pd.DataFrame:
    """
    Extract spatial coordinates from the last two columns of a spot table.

    This keeps the loader format-agnostic. If an input table stores pixel
    coordinates in ``row, col`` order, callers should pass ``swap_coord=False``
    when preparing the section.
    """
    coords = spot_coord.iloc[:, -2:].copy()
    coords.columns = ['x_coord', 'y_coord']
    return coords


class STData:
    """
    A class for handling and managing spatial transcriptomics data.

    Parameters
    ----------
    adata
        AnnData object containing the spatial transcriptomics data.
    section_name
        Name of the tissue section.
    radius
        Radius of spots in original scale.
    scale_rate
        Scale rate for adjusting coordinates and image size.
    select_hvgs
        Whether to select highly variable genes (HVGs) from the data.
    gene_list
        List of genes to arrange the data by. If provided, select_hvgs will be set to False.
        Missing genes will be added with zero expression.
    swap_coord
        Whether to swap the x and y coordinates.
    create_mask
        Whether to create a mask for the image.
    image_type
        Type of the image ('HE' or 'Immunofluorescence'). If None, will be auto-detected.
    color_norm
        Whether to apply Reinhard color normalization. Only works for H&E images.

    Notes
    -----
    ``row_range`` and ``col_range`` are half-open ``(start, stop)`` pixel
    ranges on the scaled image, and ``mask`` is the tissue mask cropped to
    those ranges.
    """

    def __init__(self,
                 adata: "anndata.AnnData",
                 section_name: str,
                 radius: float,
                 scale_rate: float = 1.,
                 select_hvgs: bool = True,
                 gene_list: Optional[List[str]] = None,
                 swap_coord: bool = True,
                 create_mask: bool = True,
                 image_type: Optional[str] = None,
                 color_norm: bool = False):
        # Process the AnnData object
        self.adata = preprocess_adata(adata, select_hvgs, swap_coord, gene_list)

        # Extract image and spot coordinates; the raw image is removed from
        # ``uns`` because the processed copy is kept in ``self.image``.
        spatial_key = list(adata.uns['spatial'].keys())[0]
        image = adata.uns['spatial'][spatial_key]['images']['orires'].copy()
        spot_coord = adata.obsm['spatial']
        del self.adata.uns['spatial'][spatial_key]['images']['orires']

        # Initialize the STData object
        self.section_name = section_name
        self.scale_rate = scale_rate
        self.original_radius = radius
        # int() guards against numpy scalars, whose round() may not return an int.
        self.radius = int(round(radius / scale_rate))
        self.kernel_size = self.radius // 2 * 2 + 1

        # Initialize the scores and save_paths dictionaries
        self.save_paths = None
        self.scores = {'NMF': None, 'GCN': None, 'VD': None, 'SpaHDmap': None, 'SpaHDmap_spot': None}
        self.clusters = {'NMF': None, 'SpaHDmap': None}
        self.tissue_coord = None
        self.X = {}

        # Preprocess the image and spot expression data
        self._preprocess(spot_coord, image, create_mask, image_type, color_norm)

    @property
    def spot_exp(self):
        """Expression matrix ``adata.X`` as a dense array."""
        # issparse also covers scipy sparse *arrays*, not only spmatrix.
        return self.adata.X.toarray() if scipy.sparse.issparse(self.adata.X) else self.adata.X

    @property
    def genes(self):
        """Gene names (``adata.var_names``) as a list."""
        return self.adata.var_names.tolist()

    @staticmethod
    def load(path: str) -> 'STData':
        """
        Load STData object from file.

        Parameters
        ----------
        path
            Path to load the STData object from. Should end with '.st'

        Returns
        -------
        STData
            Loaded STData object
        """
        if not path.endswith('.st'):
            path = path + '.st'
        # NOTE: pickle can execute arbitrary code while loading; only load
        # .st files from trusted sources.
        with open(path, 'rb') as f:
            st_data = pickle.load(f)
        # Restore memory efficient sparse matrix format
        if scipy.sparse.issparse(st_data.adata.X):
            st_data.adata.X = st_data.adata.X.tocsc()
        return st_data

    def save(self, path: str):
        """
        Save STData object to file.

        Parameters
        ----------
        path
            Path to save the STData object. Should end with '.st'
        """
        if not path.endswith('.st'):
            path = path + '.st'
        # Convert sparse matrix to csr format for better pickling
        if scipy.sparse.issparse(self.adata.X):
            self.adata.X = self.adata.X.tocsr()
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    def show(self, scale: float = 4.):
        """
        Visualizes the spots and the tissue mask on the image.

        Parameters
        ----------
        scale
            The scale factor for visualization.
        """
        # Prepare image for plotting
        img_display = np.transpose(self.image, (1, 2, 0))

        # Crop the image to the region of interest
        img_cropped = img_display[self.row_range[0]:self.row_range[1], self.col_range[0]:self.col_range[1]]

        # Scale down the image
        new_shape = (int(img_cropped.shape[1] / scale), int(img_cropped.shape[0] / scale))
        img_scaled = cv2.resize(img_cropped, new_shape, interpolation=cv2.INTER_AREA)

        # Create a figure with two subplots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        fig.suptitle(f'Section: {self.section_name}')

        # --- First subplot: Spots on Image ---
        ax1.imshow(img_scaled)
        ax1.set_title('Spots on Image')
        ax1.axis('off')

        # Draw circles for each spot
        for spot in self.spot_coord:
            row, col = spot
            # Adjust coordinates based on the crop range and scale
            adj_row = (row - self.row_range[0]) / scale
            adj_col = (col - self.col_range[0]) / scale
            # matplotlib uses (x, y) which corresponds to (col, row)
            circle = patches.Circle((adj_col, adj_row), radius=self.radius / scale, color='r', alpha=0.3)
            ax1.add_patch(circle)

        # --- Second subplot: Mask on Image ---
        ax2.imshow(img_scaled)
        ax2.set_title('Mask on Image')
        ax2.axis('off')

        # Scale down the mask and create overlay
        mask_scaled = cv2.resize(self.mask.astype(np.uint8), new_shape,
                                 interpolation=cv2.INTER_NEAREST).astype(bool)
        mask_overlay = np.zeros((mask_scaled.shape[0], mask_scaled.shape[1], 4), dtype=float)
        mask_overlay[mask_scaled] = [0, 1, 0, 0.4]  # Green, 40% transparent where mask is True
        ax2.imshow(mask_overlay)

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    def _preprocess(self, spot_coord: np.ndarray, image: np.ndarray, create_mask: bool, image_type: str,
                    color_norm: bool):
        """
        Preprocess spot_coord and prepare the feasible domain and process the image.

        Parameters
        ----------
        spot_coord
            Array of original spot coordinates.
        image
            Original image data.
        create_mask
            Whether to create a mask for the image.
        image_type
            Type of the image ('HE' or 'Immunofluorescence').
        color_norm
            Whether to apply Reinhard color normalization.

        Raises
        ------
        ValueError
            If no spot bounding-box pixel lies inside the scaled image.
        """
        # Process the spot coordinates
        self.spot_coord = spot_coord / self.scale_rate - 1
        self.num_spots = self.spot_coord.shape[0]

        # Process the image: scale each channel by its maximum. An all-zero
        # channel is divided by 1 instead of 0 so it stays 0 rather than NaN.
        channel_max = np.max(image, axis=(0, 1), keepdims=True)
        channel_max = np.where(channel_max == 0, 1, channel_max)
        image = (image / channel_max).astype(np.float32)
        hires_shape = (math.ceil(image.shape[0] / self.scale_rate), math.ceil(image.shape[1] / self.scale_rate))
        bg_lowres_shape = (math.ceil(image.shape[0] / 16), math.ceil(image.shape[1] / 16))
        tmp_row_range, tmp_col_range = _compute_scaled_bbox(self.spot_coord, self.radius, hires_shape)

        original_radius = getattr(self, 'original_radius', max(1, round(self.radius * self.scale_rate)))
        # The user-supplied radius may be a float; round it up so the bounding
        # box (and the ``//16`` slices below) use integer indices.
        original_radius = int(math.ceil(original_radius))
        original_spot_coord = spot_coord - 1
        bg_row_range, bg_col_range = _compute_scaled_bbox(original_spot_coord, original_radius, image.shape[:2])

        if self.scale_rate != 1:
            hires_image = cv2.resize(image, (hires_shape[1], hires_shape[0]),
                                     interpolation=cv2.INTER_AREA).astype(np.float32)
        else:
            hires_image = image
        bg_lowres_image = cv2.resize(image, (bg_lowres_shape[1], bg_lowres_shape[0]),
                                     interpolation=cv2.INTER_AREA).astype(np.float32)

        self.image_type = _classify_image_type(bg_lowres_image) if image_type is None else image_type
        print(f"Processing image, seems to be {self.image_type} image.")

        self.image = np.transpose(hires_image, (2, 0, 1))

        if create_mask:
            # Create masks for outer regions
            gray = cv2.cvtColor(bg_lowres_image, cv2.COLOR_RGB2GRAY)

            ## Apply Otsu's thresholding
            thresh = filters.threshold_otsu(gray)
            binary_mask = gray > thresh
            outer_mask = np.ones(bg_lowres_shape, dtype=np.bool_)
            outer_mask[bg_row_range[0] // 16:bg_row_range[1] // 16, bg_col_range[0] // 16:bg_col_range[1] // 16] = 0

            ## Determine background value from the remaining outer region
            background_value = _estimate_background_value(bg_lowres_image, binary_mask, outer_mask)

            # Create mask of image
            mask, tmp_mask = np.zeros(hires_shape, dtype=np.bool_), np.zeros(hires_shape, dtype=np.bool_)
            mask[np.where(np.mean(np.abs(hires_image - background_value[None, None, :]), axis=2) > 0.075)] = 1

            ## Overlap mask with spot coordinates
            tmp_mask[tmp_row_range[0]:tmp_row_range[1], tmp_col_range[0]:tmp_col_range[1]] = 1
            mask = np.logical_and(mask, tmp_mask)

            ## Close and open the mask
            if self.image_type == 'Immunofluorescence':
                mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE,
                                        np.ones((self.radius * 4, self.radius * 4), np.uint8)).astype(np.bool_)
            else:
                mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE,
                                        np.ones((self.radius // 4, self.radius // 4), np.uint8)).astype(np.bool_)
            # NOTE(review): the opening step is applied for both image types;
            # confirm this matches the intended pipeline.
            mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN,
                                    np.ones((self.radius // 2, self.radius // 2), np.uint8)).astype(np.bool_)

            ## Get row and column ranges and final mask
            mask_idx = np.where(mask == 1)
            if mask_idx[0].size == 0 or mask_idx[1].size == 0:
                warnings.warn(
                    "Auto-generated tissue mask is empty after scaling. "
                    "Falling back to the clamped spot bounding box; please verify image_path and scale_rate."
                )
                mask = tmp_mask.copy()
                mask_idx = np.where(mask == 1)
                if mask_idx[0].size == 0:
                    raise ValueError(
                        "No spot bounding-box pixels fall inside the scaled image; "
                        "please verify image_path, spot coordinates and scale_rate."
                    )

            # Half-open ranges (stop = last index + 1) so slicing keeps the
            # last tissue row/column, matching the create_mask=False branch.
            self.row_range = (int(np.min(mask_idx[0])), int(np.max(mask_idx[0])) + 1)
            self.col_range = (int(np.min(mask_idx[1])), int(np.max(mask_idx[1])) + 1)
            self.mask = mask[self.row_range[0]:self.row_range[1], self.col_range[0]:self.col_range[1]]
        else:
            self.mask = mask = np.ones(hires_shape, dtype=np.bool_)
            self.row_range = (0, hires_shape[0])
            self.col_range = (0, hires_shape[1])

        # Apply color normalization after mask creation
        if color_norm and self.image_type == 'HE':
            print("Applying Reinhard color normalization...")
            # Convert image back to (H, W, C) format for color normalization
            rgb_image = np.transpose(self.image, (1, 2, 0)) * 255.0
            rgb_image = rgb_image.astype(np.uint8)
            # Apply Reinhard normalization using the (full-size) mask
            cnorm_image = color_normalize(rgb_image, mask)
            # Convert back to (C, H, W) format
            # NOTE(review): the value range of color_normalize's output is not
            # visible here; confirm it matches the [0, 1] float image above.
            self.image = np.transpose(cnorm_image, (2, 0, 1))
        elif color_norm and self.image_type != 'HE':
            print(f"Color normalization is only supported for H&E images, skipping for {self.image_type} image.")

        # Create feasible domain: tissue pixels not covered by any spot square
        self.feasible_domain = mask.copy()
        for (row, col) in self.spot_coord:
            row, col = round(row), round(col)
            row_range = np.arange(max(row - self.radius, 0), min(row + self.radius + 1, hires_shape[0]))
            col_range = np.arange(max(col - self.radius, 0), min(col + self.radius + 1, hires_shape[1]))
            self.feasible_domain[np.ix_(row_range, col_range)] = 0

    def __repr__(self):
        """
        Return a string representation of the STData object.
        """
        available = ', '.join(score for score, value in self.scores.items() if value is not None)
        return (f"STData object for section: {self.section_name}\n"
                f"Number of spots: {self.num_spots}\n"
                f"Number of genes: {len(self.genes)}\n"
                f"Image shape: {self.image.shape}\n"
                f"Scale rate: {self.scale_rate}\n"
                f"Spot radius: {self.radius}\n"
                f"Image type: {self.image_type}\n"
                f"Available scores: {available}")

    def __str__(self):
        """
        Return a string with a summary of the STData object (same as ``repr``).
        """
        return self.__repr__()

    def __getstate__(self):
        """Custom pickling to handle unpicklable objects."""
        state = self.__dict__.copy()
        # Remove unpicklable objects
        if hasattr(self, 'save_paths'):
            state['save_paths'] = None
        return state

    def __setstate__(self, state):
        """Custom unpickling to restore object state."""
        self.__dict__.update(state)
        # Initialize empty save paths if needed; getattr also tolerates
        # pickles that never stored ``save_paths``.
        if not getattr(self, 'save_paths', None):
            self.save_paths = {'NMF': None, 'GCN': None, 'VD': None, 'SpaHDmap': None, 'SpaHDmap_spot': None}
def _classify_image_type(image): """ Classify an image as either H&E stained or high dynamic range Immunofluorescence. Parameters ---------- image The input image. Can be high bit depth. Returns ------- str 'HE' for H&E stained images, 'Immunofluorescence' for Immunofluorescence images. """ # Calculate histogram hist, bin_edges = np.histogram(image.flatten(), bins=1000, range=(0, 1)) # Calculate metrics low_intensity_ratio = np.sum(hist[:100]) / np.sum(hist) high_intensity_ratio = np.sum(hist[-100:]) / np.sum(hist) # Check for characteristics of Immunofluorescence images if low_intensity_ratio > 0.5 and high_intensity_ratio < 0.05: return 'Immunofluorescence' return 'HE' def read_10x_data(data_path: str) -> anndata.AnnData: """ Read 10x Visium spatial transcriptomics data. Parameters ---------- data_path Path to the 10x Visium data directory. Returns ------- anndata.AnnData AnnData object containing the spatial transcriptomics data. """ adata = sc.read_visium(data_path) return adata def read_from_image_and_coord(image_path: str, coord_path: str, exp_path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Read data from separate image, coordinate, and expression files. Parameters ---------- image_path Path to the H&E image file. coord_path Path to the spot coordinates file. exp_path Path to the gene expression file. Returns ------- Tuple Tuple containing: - image (np.ndarray): H&E image data. - spot_coord (np.ndarray): Spot coordinates. - spot_exp (np.ndarray): Gene expression data. 
""" # Read image image = io.imread(image_path) # Read spot coordinates spot_coord = pd.read_csv(coord_path, index_col=0).values # Read gene expression data spot_exp = pd.read_csv(exp_path, index_col=0).values return image, spot_coord, spot_exp def preprocess_adata(adata: anndata.AnnData, select_hvgs: bool = True, swap_coord: bool = True, gene_list: Optional[List[str]] = None) -> anndata.AnnData: """ Preprocess the spatial transcriptomics data, including normalization and SVG selection using squidpy. Parameters ---------- adata AnnData object containing the spatial transcriptomics data. select_hvgs Whether to select highly variable genes (HVGs). swap_coord Whether to swap the x and y coordinates. gene_list List of genes to arrange the data by. If provided, select_hvgs will be set to False. Missing genes will be added with zero expression. Returns ------- anndata.AnnData Preprocessed AnnData object. """ print(f"Pre-processing gene expression data for {adata.shape[0]} spots and {adata.shape[1]} genes.") # Swap x and y coordinates if swap_coord: adata.obsm['spatial'] = adata.obsm['spatial'][:, ::-1] print("Swapping x and y coordinates.") else: warnings.warn("Coordinates are not swapped. 
Make sure the coordinates are in the correct order.") # Normalize data if adata.X.max() < 20: warnings.warn("Data seems to be already normalized, skipping pre-processing.") else: adata.var_names_make_unique() sc.pp.filter_cells(adata, min_genes=3) sc.pp.filter_genes(adata, min_cells=3) sc.pp.normalize_total(adata, target_sum=1e4) sc.pp.log1p(adata) # Handle gene_list arrangement before select_hvgs if gene_list is not None: print(f"Arranging genes according to provided gene list with {len(gene_list)} genes.") select_hvgs = False # Override select_hvgs when gene_list is provided # Save spatial information and other uns data before processing saved_uns = adata.uns.copy() saved_obsm = adata.obsm.copy() saved_obs = adata.obs.copy() # Get current genes current_genes = adata.var_names.tolist() missing_genes = [gene for gene in gene_list if gene not in current_genes] if missing_genes: print(f"Adding {len(missing_genes)} missing genes with zero expression.") # Create a new AnnData object with missing genes (all zeros) # Match the data format (sparse or dense) of the original adata n_obs = adata.n_obs if scipy.sparse.issparse(adata.X): # If original data is sparse, create sparse matrix missing_data = scipy.sparse.csr_matrix((n_obs, len(missing_genes))) else: # If original data is dense, create dense matrix missing_data = np.zeros((n_obs, len(missing_genes)), dtype=adata.X.dtype) missing_adata = anndata.AnnData(X=missing_data) missing_adata.var_names = missing_genes missing_adata.obs_names = adata.obs_names # Concatenate the original data with the missing genes adata = anndata.concat([adata, missing_adata], axis=1, merge='same') # Reorder genes according to gene_list adata = adata[:, gene_list] # Restore the saved information adata.uns = saved_uns adata.obsm = saved_obsm adata.obs = saved_obs print(f"Data rearranged to match gene list order with {adata.shape[1]} genes.") if select_hvgs: sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=10000, subset=True) return 
adata
def prepare_stdata(section_name: str = None,
                   st_path: str = None,
                   image_path: str = None,
                   adata: "sc.AnnData" = None,
                   select_hvgs: bool = True,
                   scale_rate: float = 1,
                   radius: float = None,
                   swap_coord: bool = True,
                   create_mask: bool = True,
                   image_type: str = None,
                   color_norm: bool = False,
                   gene_list: List[str] = None,
                   **kwargs):
    """
    Prepare an STData object from various data sources, with a specific loading priority.

    This function orchestrates the loading and preprocessing of spatial transcriptomics data to create a unified
    STData object. It can handle several input formats, including a pre-saved STData object, an AnnData object,
    10x Visium data directories, or separate files for expression, coordinates, and imaging.

    The function follows a specific priority for loading the gene expression data:

    - **st_path**: If provided, it will first attempt to load a serialized STData object.
    - **adata**: If `st_path` is not given or fails, it will use a provided AnnData object.
    - **visium_path**: If `adata` is not provided, it will look for a 10x Visium data directory.
    - **spot_coord_path & spot_exp_path**: If none of the above are available, it will load the data from separate
      coordinate and expression files.

    Internal processing steps include:

    - **Data Reading**: Loads data based on the priority scheme.
    - **Gene Expression Processing**: Normalizes and log-transforms the expression data. Optionally, it selects
      highly variable genes (HVGs) or arranges genes by ``gene_list``.
    - **Image Processing**: Reads the high-resolution image, creates a tissue mask, and can apply color
      normalization for H&E images.
    - **Coordinate Handling**: Adjusts spot coordinates based on the scale rate and can swap row/column coordinates
      if needed, usually it has to be performed for the 10X Visium data.

    Parameters
    ----------
    section_name
        The name for the tissue section. This is a required parameter.
    st_path
        Path to a saved `.st` file to load a pre-existing STData object. The file is unpickled, so only load
        trusted files.
    image_path
        Path to the high-resolution tissue image file. Required unless loading from `st_path`.
    adata
        An AnnData object containing expression data and spatial coordinates.
    select_hvgs
        Whether to select highly variable genes (HVGs).
    scale_rate
        The factor by which to scale the input image and coordinates. This is always interpreted relative to
        ``image_path``. For example, if you want a target resolution of ``0.5 um/px``, ``image_path`` should point
        to the original full-resolution image and ``scale_rate`` should be computed against that image's native
        microns-per-pixel value.
    radius
        The radius of the spots in the original, unscaled image. This is required when loading data from
        `spot_coord_path` and `spot_exp_path`. A radius found in the AnnData scale factors takes precedence.
    swap_coord
        Whether to swap the row and column coordinates.
    create_mask
        Whether to create a binary mask of the tissue from the image.
    image_type
        The type of imaging data, either 'HE' or 'Immunofluorescence'. If None, it will be auto-detected.
    color_norm
        Whether to apply Reinhard color normalization. This is only applicable to H&E images.
    gene_list
        A specific list of genes to use. If provided, `select_hvgs` is ignored.
    **kwargs :
        Additional keyword arguments for different loading schemes.

        - `visium_path` (str): Path to a 10x Visium data directory.
        - `spot_coord_path` (str): Path to the spot coordinates file (`.csv` or `.parquet`; spot names first,
          x/y coordinates as the last two columns).
        - `spot_exp_path` (str): Path to the gene expression file (e.g., `.h5`).

    Returns
    -------
    STData
        A fully prepared STData object ready for analysis.

    Raises
    ------
    ValueError
        If required inputs are missing, a file format is unsupported, or the expression and coordinate files share
        no spot names.
    """
    # Try loading from st_path first if provided
    if st_path is not None:
        print(f"*** Loading saved STData from {st_path} ***")
        try:
            st_data = STData.load(st_path)
            if section_name and section_name != st_data.section_name:
                st_data.section_name = section_name
                print(f"Updated section name to {section_name}")
            return st_data
        except Exception as e:
            # Deliberate best effort: fall through to the other data sources.
            print(f"Failed to load .st file: {e}")
            print("Falling back to other data sources...")

    # Check if image_path is provided when needed
    if image_path is None:
        raise ValueError("image_path is required when st_path is not provided or loading fails")

    if adata is not None:
        # Use the provided AnnData
        print(f"*** Reading and preparing AnnData for section {section_name} ***")
    elif kwargs.get('visium_path') is not None:
        # Read 10x Visium data if AnnData is not available
        visium_path = kwargs['visium_path']
        print(f"*** Reading and preparing Visium data for section {section_name} ***")
        count_files = [f for f in os.listdir(visium_path) if f.endswith('.h5')]
        if not count_files:
            raise ValueError("No count file found in the Visium directory.")
        adata = sc.read_visium(visium_path, count_file=count_files[0])
    else:
        # Read from scratch if neither AnnData nor 10x Visium is available
        print(f"*** Reading and preparing data from scratch for section {section_name} ***")
        spot_coord_path = kwargs.get('spot_coord_path')
        spot_exp_path = kwargs.get('spot_exp_path')
        if not all([spot_coord_path, spot_exp_path]):
            raise ValueError("Missing required paths for reading from scratch.")

        if '.h5' in spot_exp_path:
            try:
                adata = sc.read_h5ad(spot_exp_path)
            except TypeError:
                try:
                    adata = sc.read_10x_h5(spot_exp_path)
                except Exception as e:
                    raise ValueError(f"Unsupported file format for spot_exp_path: {e}")
        elif '.csv' in spot_exp_path:
            adata = sc.read_csv(spot_exp_path)
        else:
            try:
                adata = sc.read(spot_exp_path)
            except Exception as e:
                raise ValueError(f"Unsupported file format for spot_exp_path: {e}")

        if spot_coord_path.endswith('.csv'):
            spot_coord = pd.read_csv(spot_coord_path, index_col=0)
        elif spot_coord_path.endswith('.parquet'):
            import pyarrow.parquet as pq
            spot_coord = pq.read_table(spot_coord_path).to_pandas()
            # Use the first column (spot names) as the index. The remaining
            # columns are kept: only the last two are read as coordinates, and
            # dropping another column would break a plain [name, x, y] table.
            spot_coord = spot_coord.set_index(spot_coord.columns[0])
        else:
            raise ValueError("Unsupported file format for spot_coord_path. We suggest transforming the spot coordinates into a .csv file with spot names as index and x/y coordinates as the last two columns.")
        spot_coord = _extract_spatial_coords_from_table(spot_coord)

        # Keep only the spots present in both the expression and coordinate tables
        common_index = adata.obs_names.intersection(spot_coord.index)
        if len(common_index) == 0:
            raise ValueError("No overlapping spot names between spot_exp_path and spot_coord_path; "
                             "please check that both files use the same spot identifiers.")
        # Copy so the obsm/uns writes below target an independent object, not a view.
        adata = adata[common_index, :].copy()
        spot_coord = spot_coord.loc[common_index]

        # Add spot coordinates to adata
        adata.obsm['spatial'] = spot_coord.loc[adata.obs_names].values

    # Attach the full-resolution image to adata
    image = io.imread(image_path)
    if 'spatial' not in adata.uns:
        adata.uns['spatial'] = {section_name: {'images': {'orires': image}}}
    else:
        section_id = list(adata.uns['spatial'].keys())[0]
        try:
            radius = round(adata.uns['spatial'][section_id]['scalefactors']['spot_diameter_fullres'] / 2)
            print(f"Spot radius found in AnnData: {radius}")
        except KeyError:
            if radius is not None:
                warnings.warn("Radius is specified but not found in AnnData. Using the specified radius instead.")
        # setdefault tolerates spatial entries that carry no 'images' dict yet.
        adata.uns['spatial'][section_id].setdefault('images', {})['orires'] = image

    if radius is None:
        warnings.warn("Radius is not found, using default radius of 65.")
        radius = 65

    # Create STData object
    st_data = STData(adata, select_hvgs=select_hvgs, section_name=section_name, scale_rate=scale_rate,
                     radius=radius, swap_coord=swap_coord, create_mask=create_mask, image_type=image_type,
                     color_norm=color_norm, gene_list=gene_list)

    return st_data
def select_svgs(section: Union['STData', List['STData']], n_top_genes: int = 3000, method: str = 'moran'):
    """
    Select the top SVGs based on Moran's I, SPARK-X or BSP for a given section or list of sections.

    Update each section's AnnData object to only include the selected SVGs.

    Parameters
    ----------
    section
        STData object or list of STData objects.
    n_top_genes
        Number of top SVGs to select.
    method
        Method to use for selecting SVGs. Either 'moran', 'sparkx' or 'bsp'.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported methods or no section is given.
    """
    # Validate up front so an invalid method is rejected even when every
    # overlapping gene would be selected anyway.
    if method not in ('moran', 'sparkx', 'bsp'):
        raise ValueError("Invalid method. Choose either 'moran', 'sparkx' or 'bsp'.")

    sections = section if isinstance(section, list) else [section]
    if not sections:
        raise ValueError("At least one section is required to select SVGs.")

    # Find the overlap of genes across all sections and compute spatial neighbors
    shared_genes = set(sections[0].genes)
    for sec in sections:
        shared_genes.intersection_update(sec.genes)
        sq.gr.spatial_neighbors(sec.adata)
    # Keep the first section's gene order: iterating a set of strings is not
    # stable across interpreter runs (hash randomisation), which would make
    # the selected genes' order non-reproducible.
    overlap_genes = list(dict.fromkeys(gene for gene in sections[0].genes if gene in shared_genes))

    if len(sections) > 1:
        print(f"Find {len(overlap_genes)} overlapping genes across {len(sections)} sections.")

    if len(overlap_genes) <= n_top_genes:
        # If the number of overlapping genes is at most n_top_genes, select all of them
        warnings.warn(
            "Number of genes is less than the specified number of top genes, using all genes.")
        selected_genes = overlap_genes
    elif method == 'moran':
        # Compute Moran's I for overlapping genes across all sections (higher is better)
        moran_i_values = []
        for sec in sections:
            sq.gr.spatial_autocorr(sec.adata, mode="moran", genes=overlap_genes)
            moran_i_values.append(sec.adata.uns['moranI']['I'])

        combined_moran_i = pd.concat(moran_i_values, axis=1, keys=[s.section_name for s in sections])
        combined_moran_i['mean_rank'] = combined_moran_i.mean(axis=1).rank(method='dense', ascending=False)
        selected_genes = combined_moran_i.sort_values('mean_rank').head(n_top_genes).index.tolist()
    elif method == 'sparkx':
        # Compute SPARK-X p-values for overlapping genes across all sections (lower is better)
        sparkx_pvals = []
        for sec in sections:
            counts = sec.adata[:, overlap_genes].X
            location = sec.adata.obsm['spatial']
            pvals = sparkx(counts, location)
            sparkx_pvals.append(pd.Series(pvals, index=overlap_genes))

        combined_sparkx = pd.concat(sparkx_pvals, axis=1, keys=[s.section_name for s in sections])
        combined_sparkx['mean_rank'] = combined_sparkx.mean(axis=1).rank(method='dense', ascending=True)
        selected_genes = combined_sparkx.sort_values('mean_rank').head(n_top_genes).index.tolist()
    else:
        # method == 'bsp': BSP p-values for overlapping genes (lower is better)
        bsp_pvals = []
        for sec in sections:
            counts = sec.adata[:, overlap_genes].X
            location = sec.adata.obsm['spatial']
            pvals = bsp(location, counts)
            bsp_pvals.append(pd.Series(pvals, index=overlap_genes))

        combined_bsp = pd.concat(bsp_pvals, axis=1, keys=[s.section_name for s in sections])
        combined_bsp['mean_rank'] = combined_bsp.mean(axis=1).rank(method='dense', ascending=True)
        selected_genes = combined_bsp.sort_values('mean_rank').head(n_top_genes).index.tolist()

    # Update each section's AnnData object with the selected SVGs; copy so the
    # section owns an independent object instead of a view.
    for sec in sections:
        sec.adata = sec.adata[:, selected_genes].copy()

    print(f"Selected {len(selected_genes)} SVGs using {method} method.")