"""
Reference paper:
HINet: Half Instance Normalization Network for Image Restoration (https://arxiv.org/abs/2105.06086)
Code:
https://github.com/megvii-model/HINet/blob/main/basicsr/models/archs/hinet_arch.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
def conv_down(in_chn, out_chn, bias=False):
    """
    Build the strided convolution used for downsampling.

    The layer uses a 4x4 kernel, a stride of 2 and a padding of 1, so an
    even-sized spatial dimension is reduced to half its size.

    Parameters
    ----------
    in_chn
        Number of input channels.
    out_chn
        Number of output channels.
    bias
        Whether the convolution carries a bias term. Default to False.

    Returns
    -------
    nn.Conv2d
        The configured convolutional layer.
    """
    return nn.Conv2d(
        in_chn,
        out_chn,
        kernel_size=4,
        stride=2,
        padding=1,
        bias=bias,
    )
class UNetConvBlock(nn.Module):
    """
    A residual convolutional block used in the U-Net architecture.

    The block applies two 3x3 convolutions with LeakyReLU activations, adds a
    1x1-convolved copy of the input as a shortcut, and optionally returns a
    downsampled copy of the result.

    Parameters
    ----------
    in_size
        Number of input channels.
    out_size
        Number of output channels.
    downsample
        Whether to include a downsampling layer.
    relu_slope
        Slope for the LeakyReLU activation.
    use_HIN
        Whether to use Half Instance Normalization. Default to False.
    use_csff
        Whether to create the two 3x3 convolutions (`csff_enc`, `csff_dec`) that
        fuse the optional `enc` / `dec` inputs of `forward`. Default to False.
    """

    def __init__(self, in_size, out_size, downsample, relu_slope, use_HIN=False, use_csff=False):
        """
        Initialize the `UNetConvBlock` model.
        """
        super(UNetConvBlock, self).__init__()
        # 1x1 convolution that maps the input onto the residual shortcut
        self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)

        self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False)
        self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False)

        # Only the first half of the channels is instance-normalized
        if use_HIN:
            self.norm = nn.InstanceNorm2d(out_size // 2, affine=True)
        self.use_HIN = use_HIN

        self.downsample = downsample
        self.downsample_layer = conv_down(out_size, out_size, bias=False) if downsample else None

        # Created last so that blocks built with the default `use_csff=False`
        # register exactly the same sub-modules, in the same order, as before.
        self.use_csff = use_csff
        if use_csff:
            self.csff_enc = nn.Conv2d(out_size, out_size, 3, 1, 1)
            self.csff_dec = nn.Conv2d(out_size, out_size, 3, 1, 1)

    def forward(self, x, enc=None, dec=None):
        """
        Forward pass for the UNetConvBlock.

        Parameters
        ----------
        x
            Input tensor.
        enc
            Tensor from the corresponding downsampling block. Only fused when
            `dec` is given as well.
        dec
            Tensor from the corresponding upsampling block. Only fused when
            `enc` is given as well.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            The block output; when `downsample` is True, the tuple
            `(downsampled output, output)` instead.

        Raises
        ------
        ValueError
            If both `enc` and `dec` are given but the block was created with
            `use_csff=False`, so no fusion layers exist.
        """
        out = self.conv_1(x)

        if self.use_HIN:
            # Half Instance Normalization: normalize the first half, keep the second
            out_1, out_2 = torch.chunk(out, 2, dim=1)
            out = torch.cat([self.norm(out_1), out_2], dim=1)
        out = self.relu_1(out)
        out = self.relu_2(self.conv_2(out))

        # Residual shortcut
        out += self.identity(x)

        if enc is not None and dec is not None:
            if not self.use_csff:
                raise ValueError(
                    "`enc` and `dec` were provided, but this UNetConvBlock was created with "
                    "`use_csff=False` and has no fusion layers."
                )
            out = out + self.csff_enc(enc) + self.csff_dec(dec)

        if self.downsample:
            return self.downsample_layer(out), out
        return out
class UNetUpBlock(nn.Module):
    """
    An upsampling block used in the U-Net architecture.

    The input is upsampled with a stride-2 transposed convolution, concatenated
    with a skip tensor along the channel axis, and passed through a
    `UNetConvBlock`.

    Parameters
    ----------
    in_size
        Number of input channels.
    out_size
        Number of output channels.
    relu_slope
        Slope for the LeakyReLU activation.
    """

    def __init__(self, in_size, out_size, relu_slope):
        """
        Initialize the `UNetUpBlock` model.
        """
        super().__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
        # The conv block consumes the concatenation [upsampled, bridge], which it
        # expects to have `in_size` channels in total.
        self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope)

    def forward(self, x, bridge):
        """
        Forward pass for the UNetUpBlock.

        Parameters
        ----------
        x
            Input tensor.
        bridge
            Tensor from the corresponding downsampling block.

        Returns
        -------
        torch.Tensor
            Output tensor after applying the upsampling block.
        """
        upsampled = self.up(x)
        merged = torch.cat((upsampled, bridge), dim=1)
        return self.conv_block(merged)
[docs]
class SpaHDmapUnet(nn.Module):
    """
    A deep learning architecture for image and spot expression prediction.

    It integrates Non-negative Matrix Factorization (NMF) and low-rank representation, enabling efficient
    prediction and high-definition pixel-wise embedding output.

    Parameters
    ----------
    rank
        The rank of the low-rank representation. Defaults to 20.
    num_genes
        The number of genes in the dataset. Defaults to 2000.
    num_channels
        The number of channels in the input image. Defaults to 3.
    reference
        Dictionary of query and reference pairs, e.g., {'query1': 'reference1', 'query2': 'reference2'}.
        Only used for multi-section analysis. Defaults to None.

    Example
    -------
    >>> model = SpaHDmapUnet(rank=20, num_genes=1000, num_channels=3)
    >>> image = torch.rand(1, 3, 256, 256)
    >>> image_pred = model(image)  # pretraining stage (`training_mode` is False)
    >>> model.training_mode = True
    >>> vd_score = torch.rand(1, 20, 256, 256)
    >>> image_pred, spot_exp_pred, HR_score = model(image, feasible_coord={}, vd_score=vd_score)
    """

    def __init__(self,
                 rank: int = 20,
                 num_genes: int = 2000,
                 num_channels: int = 3,
                 reference: dict = None):
        """
        Initialize the `SpaHDmapUnet` model.
        """
        super(SpaHDmapUnet, self).__init__()

        self.num_genes = num_genes
        self.rank = rank
        self.num_channels = num_channels

        # Basic U-Net architecture
        wf = prev_channels = 32
        self.depth = 4

        self.down_path = nn.ModuleList()
        self.conv_init = nn.Conv2d(num_channels, wf, 3, 1, 1)
        for i in range(self.depth):
            # With depth 4 this is True for every encoder block
            use_HIN = i <= 4
            # Every encoder block except the last one also returns a downsampled output
            downsample = (i + 1) < self.depth
            self.down_path.append(
                UNetConvBlock(prev_channels, (2 ** i) * wf, downsample, 0.2, use_HIN=use_HIN))
            prev_channels = (2 ** i) * wf

        self.up_path = nn.ModuleList()
        self.skip_conv = nn.ModuleList()
        for i in reversed(range(self.depth - 1)):
            self.up_path.append(UNetUpBlock(prev_channels, (2 ** i) * wf, 0.2))
            self.skip_conv.append(nn.Conv2d((2 ** i) * wf, (2 ** i) * wf, 3, 1, 1))
            prev_channels = (2 ** i) * wf

        # Image reconstruction head used while `training_mode` is False
        self.output = nn.Sequential(
            UNetConvBlock(prev_channels, self.num_channels, False, 0.2),
            nn.Sigmoid()
        )

        # Low-rank representation
        self.low_rank = UNetConvBlock(prev_channels, self.rank, False, 0.2)
        self.image_pred = nn.Sequential(
            UNetConvBlock(self.rank, self.num_channels, False, 0.2),
            nn.Sigmoid()
        )

        # Decoder for Non-negative Matrix Factorization (NMF)
        self.nmf_decoder = nn.Parameter(torch.randn(self.num_genes, self.rank), requires_grad=True)

        self.apply(__initial_weights__)
        # False: `forward` only reconstructs the image; True: full low-rank / NMF path
        self.training_mode = False

        # Remove batch effect
        self.gamma = None
        # Always defined so that `forward` can query it even without `reference`
        self.query2idx = {}
        if reference is not None:
            gamma = torch.zeros(len(reference), self.num_genes)
            self.gamma = nn.Parameter(gamma, requires_grad=True)
            self.query2idx = {query: torch.tensor(idx) for idx, query in enumerate(reference)}

    def forward(self, image, section_name=None, feasible_coord=None, vd_score=None, encode_only=False):
        """
        Forward pass for the SpaHDmapUnet model.

        Parameters
        ----------
        image
            Input image tensor.
        section_name
            Section name for batch effect removal. Default to None.
        feasible_coord
            Dictionary of feasible coordinates. Default to None, which is treated like an empty dictionary.
        vd_score
            Input tensor representing the sequenced spot embeddings. Required when `training_mode` is True.
            Default to None.
        encode_only
            Whether to only perform encoding. Default to False.

        Returns
        -------
        image_pred
            Predicted image.
        spot_exp_pred
            Predicted spot expression (None if no feasible coordinates are provided).
        HR_score
            High-resolution pixel-wise embedding output.

        Notes
        -----
        The returned value depends on the mode:

        - `training_mode` is False and the module is in train mode: `image_pred`.
        - `training_mode` is False and the module is in eval mode: `(image_pred, patch features)`.
        - `training_mode` is True and `encode_only` is True: `HR_score`.
        - otherwise: `(image_pred, spot_exp_pred, HR_score)`.

        Raises
        ------
        ValueError
            If `training_mode` is True and `vd_score` is None.
        """
        x1 = self.conv_init(image)

        # Encoder: keep the pre-downsampling output of each block for the skip connections
        encs = []
        for i, down in enumerate(self.down_path):
            if (i + 1) < self.depth:
                x1, x1_up = down(x1)
                encs.append(x1_up)
            else:
                x1 = down(x1)

        # Decoder: consume the stored encoder outputs from deepest to shallowest
        for i, up in enumerate(self.up_path):
            x1 = up(x1, self.skip_conv[i](encs[-i - 1]))

        # For pretraining stage, only return the predicted image and patch features
        if not self.training_mode:
            image_pred = self.output(x1)
            if self.training:
                return image_pred
            return image_pred, encs[-1]

        if vd_score is None:
            raise ValueError("`vd_score` must be provided when `training_mode` is True.")

        # Low-rank representation
        low_rank_score = self.low_rank(x1)
        vd_score_logit = torch.logit(vd_score, eps=1.388794e-11)
        HR_score = torch.sigmoid(vd_score_logit + low_rank_score)

        # Return high-definition pixel-wise embedding output if only performing encoding
        if encode_only:
            return HR_score

        # Image prediction
        image_pred = self.image_pred(HR_score)

        # If no feasible coordinates are provided (None or empty), return the image prediction
        # and high-definition pixel-wise embedding
        if not feasible_coord:
            return image_pred, None, HR_score

        # Get spot scores through averaging the high-definition pixel-wise embedding output.
        # Only the first sample of the batch is used; each value is presumably a pair of
        # row / column pixel indices belonging to one spot -- verify against the caller.
        spot_score = torch.stack(
            [torch.mean(HR_score[0, :, coord[0], coord[1]], dim=1) for coord in feasible_coord.values()],
            dim=0,
        )

        # Predict spot expression based on multiplying the spot scores with the NMF decoder (all are non-negative)
        nmf_decoder_limited = torch.relu(self.nmf_decoder)
        spot_exp_pred = F.linear(spot_score, nmf_decoder_limited)

        # Remove batch effect
        if self.gamma is not None and section_name in self.query2idx:
            query_idx = self.query2idx[section_name]
            spot_exp_pred = torch.relu(spot_exp_pred + self.gamma[query_idx, :])

        return image_pred, spot_exp_pred, HR_score
class GraphConv(nn.Module):
    """
    A graph convolutional layer for graph neural networks.

    The layer applies a linear projection to the node features and then
    multiplies the result by the adjacency matrix.

    Parameters
    ----------
    input_dim
        The input dimension of the graph convolutional layer.
    output_dim
        The output dimension of the graph convolutional layer.
    """

    def __init__(self, input_dim: int, output_dim: int):
        """
        Initialize the `GraphConv` model.
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, adj):
        """
        Forward pass for the GraphConv layer.

        Parameters
        ----------
        x
            Node feature tensor.
        adj
            Adjacency matrix, passed as the first operand to `torch.sparse.mm`.

        Returns
        -------
        torch.Tensor
            The output tensor of the graph convolutional layer.
        """
        projected = self.linear(x)
        aggregated = torch.sparse.mm(adj, projected)
        return aggregated
[docs]
class GraphAutoEncoder(nn.Module):
    """
    A graph autoencoder for predicting spot embeddings.

    Parameters
    ----------
    adj_matrix
        The adjacency matrix of the graph. It is multiplied with `torch.sparse.mm`,
        so a sparse tensor is expected.
    num_spots
        The number of spots in the dataset.
    rank
        The rank of the graph autoencoder. Defaults to 20.

    Example
    -------
    >>> adj_matrix = torch.eye(10).to_sparse()
    >>> model = GraphAutoEncoder(adj_matrix, num_spots=5, rank=20)
    >>> score = torch.rand(5, 20)
    >>> model(score)
    """

    def __init__(self,
                 adj_matrix: torch.Tensor,
                 num_spots: int,
                 rank: int = 20):
        """
        Initialize the `GraphAutoEncoder` model.
        """
        super(GraphAutoEncoder, self).__init__()

        self.rank = rank
        # Stored as a plain attribute (not a parameter or buffer)
        self.adj_matrix = adj_matrix

        # Parameters of pseudo spots' initial embeddings: one row for every graph node
        # beyond the first `num_spots` sequenced ones
        self.pseudo_score = nn.Parameter(torch.randn((adj_matrix.shape[0] - num_spots, rank)), requires_grad=True)

        # Define graph convolutional layers
        self.gc1 = GraphConv(input_dim=rank, output_dim=64)
        self.gc2 = GraphConv(input_dim=64, output_dim=256)
        self.gc3 = GraphConv(input_dim=256, output_dim=64)
        self.gc4 = GraphConv(input_dim=64, output_dim=rank)

        self.apply(__initial_weights__)

    def forward(self, score):
        """
        Forward pass for the GraphAutoEncoder.

        Parameters
        ----------
        score
            Input tensor representing the sequenced spot embeddings.

        Returns
        -------
        y
            Reconstructed spot embedding whose values are limited to [0, 1].
        """
        # Apply sigmoid to latent strengths to limit their values
        pseudo_score = torch.sigmoid(self.pseudo_score)

        # Concatenate the sequenced and pseudo spot embeddings
        x = torch.cat([score, pseudo_score], dim=0)

        # Graph Convolutional Layers
        x = F.relu(self.gc1(x, self.adj_matrix))
        x = F.relu(self.gc2(x, self.adj_matrix))
        x = F.relu(self.gc3(x, self.adj_matrix))

        # Reconstructed spot embedding whose values are limited to [0, 1].
        # `torch.sigmoid` replaces the deprecated `F.sigmoid` and matches the call above.
        y = torch.sigmoid(self.gc4(x, self.adj_matrix))
        return y
def __initial_weights__(module):
    """
    Initialize the weights of a single module.

    Intended to be passed to `nn.Module.apply`, which calls it on every
    sub-module. Modules of any other type are left untouched.

    - `nn.Conv2d`: Kaiming-normal weights (fan-out, leaky ReLU with slope 0.2), zero bias.
    - `nn.BatchNorm2d`: weight set to 1 and bias set to 0.
    - `nn.Linear`: Kaiming-normal weights (fan-out, ReLU), zero bias.

    Parameters
    ----------
    module
        The module whose parameters are initialized in place.
    """
    if isinstance(module, nn.Conv2d):
        init.kaiming_normal_(module.weight, mode='fan_out', a=0.2, nonlinearity='leaky_relu')
        if module.bias is not None:
            init.constant_(module.bias, 0)
    elif isinstance(module, nn.BatchNorm2d):
        # Layers created with `affine=False` have neither weight nor bias
        if module.weight is not None:
            init.constant_(module.weight, 1)
        if module.bias is not None:
            init.constant_(module.bias, 0)
    elif isinstance(module, nn.Linear):
        init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            init.constant_(module.bias, 0)