import torch
import torch.nn as nn
from torch_geometric.nn import radius_graph
import scanpy as sc
from huggingface_hub import PyTorchModelHubMixin
from models_cifm.mlp_and_gnn import MLPBiasFree
from models_cifm.egnn_void_invariant import VIEGNNModel
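
# CIFM builds a radius graph over the cells' spatial coordinates, encodes each cell's
# expression profile with an equivariant GNN over that graph, and decodes masked cells
# to predict expression at arbitrary locations. The class below is Hub-loadable via
# PyTorchModelHubMixin.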
class CIFM(
    nn.Module,
    PyTorchModelHubMixin,
    # optionally, you can add metadata which gets pushed to the model card
    repo_url='ynyou/CIFM',
    pipeline_tag='mask-generation',
    license='mit',
):
    def __init__(self, args):
        super().__init__()
        # gene-expression encoder: maps per-cell expression vectors to hidden embeddings
        self.gene_encoder = MLPBiasFree(in_dim=args.in_dim, out_dim=args.hidden_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        # equivariant GNN over the spatial cell graph
        self.model = VIEGNNModel(num_layers=args.num_layer, num_mlp_layers_in_module=args.num_mlp_layers_in_module,
                                 emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=False)
        # decoder GNN and output heads used to reconstruct masked cells
        self.mask_cell_decoder = VIEGNNModel(num_layers=args.num_layer, num_mlp_layers_in_module=args.num_mlp_layers_in_module,
                                             emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=False)
        self.mask_cell_expression = MLPBiasFree(in_dim=args.hidden_dim, out_dim=args.in_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        self.mask_cell_dropout = MLPBiasFree(in_dim=args.hidden_dim, out_dim=args.in_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        # learnable token that replaces the embedding of masked cells
        self.mask_embedding = nn.Embedding(1, args.hidden_dim)
        self.proj = MLPBiasFree(in_dim=args.hidden_dim, out_dim=1, hidden_dim=args.hidden_dim, num_layer=4)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.hidden_dim = args.hidden_dim
        self.radius_spatial_graph = args.radius_spatial_graph
    def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
        """Re-wire the input/output gene channels to a new gene panel.

        Each target channel is initialized with the mean of the pretrained weights of all
        source channels sharing an Ensembl ID with it; unmatched channels stay zero-initialized
        (or randomly initialized if zero_init_for_unmatched_genes is False).
        """
        device = next(self.parameters()).device
        linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
        linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
        linear_out2 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
        if zero_init_for_unmatched_genes:
            linear_in.weight.data.zero_()
            linear_out1.weight.data.zero_()
            linear_out2.weight.data.zero_()

        num_matching = 0
        unmatched_channels = []
        for idx_target, ensembls in enumerate(channel2ensembl_ids_target):
            if len(ensembls) == 0:
                continue
            # collect the pretrained weights of every source channel that maps to the same Ensembl IDs
            embs_in = []
            embs_out1 = []
            embs_out2 = []
            for ensembl in ensembls:
                for idx_source, ensembls_source in enumerate(channel2ensembl_ids_source):
                    if ensembl in ensembls_source:
                        embs_in.append(self.gene_encoder.layers[0].weight.data[:, idx_source])
                        embs_out1.append(self.mask_cell_expression.layers[-1].weight.data[idx_source])
                        embs_out2.append(self.mask_cell_dropout.layers[-1].weight.data[idx_source])
            if len(embs_in) == 0:
                unmatched_channels += ensembls
                continue
            # average the matched source weights into the target channel
            linear_in.weight.data[:, idx_target] = torch.stack(embs_in).mean(dim=0)
            linear_out1.weight.data[idx_target] = torch.stack(embs_out1).mean(dim=0)
            linear_out2.weight.data[idx_target] = torch.stack(embs_out2).mean(dim=0)
            num_matching += 1

        # swap in the re-wired first encoder layer and last decoder layers
        self.gene_encoder.layers[0] = linear_in.to(device)
        self.mask_cell_expression.layers[-1] = linear_out1.to(device)
        self.mask_cell_dropout.layers[-1] = linear_out2.to(device)

        unmatched_channels = list(set(unmatched_channels))
        print(f'matched {num_matching} gene channels out of {len(channel2ensembl_ids_target)}; unmatched channels: {unmatched_channels}')

    def forward(self):
        pass

    def encode(self, expressions, coordinates, edge_index):
        """Encode cells into spatially-contextualized embeddings."""
        embeddings = self.gene_encoder(expressions)
        embeddings, _ = self.model(embeddings, coordinates, edge_index)
        return embeddings

    def encode_decode(self, expressions, coordinates, edge_index, mapping):
        """Encode all cells, mask the cells indexed by `mapping`, and decode their expressions."""
        device = expressions.device
        embeddings = self.encode(expressions, coordinates, edge_index)
        # replace the embeddings of the masked cells with the learnable mask token
        embeddings[mapping] = self.mask_embedding(torch.zeros(1, dtype=torch.int64).to(device))
        embeddings_dec = self.mask_cell_decoder(embeddings, coordinates, edge_index)[0][mapping]
        # predict expression values plus a per-gene dropout probability, and zero out predicted dropouts
        expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
        dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
        expressions_dec[dropouts_dec <= 0.5] = 0
        return expressions_dec

    def embed(self, adata):
        """Embed every cell in `adata` from its expression profile and spatial neighborhood."""
        device = next(self.parameters()).device
        expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
        # lift 2D spatial coordinates to 3D by appending a zero z-coordinate
        coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
        coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
        edge_index = radius_graph(coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)
        embeddings = self.encode(expressions, coordinates, edge_index)
        return embeddings

    def predict_cells_at_locations(self, adata, locations):
        """Predict expression profiles for hypothetical cells placed at the given 2D `locations`."""
        device = next(self.parameters()).device
        locations = torch.tensor(locations, dtype=torch.float32)
        # append zero-expression placeholder cells at the query locations
        expressions = torch.tensor(adata.X.toarray(), dtype=torch.float32).to(device)
        expressions = torch.cat([expressions, torch.zeros(locations.shape[0], expressions.shape[1]).to(device)], dim=0)
        coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
        coordinates = torch.cat([coordinates, locations], dim=0)
        coordinates = torch.cat([coordinates, torch.zeros(coordinates.shape[0], 1)], dim=1).to(device)
        edge_index = radius_graph(coordinates, r=self.radius_spatial_graph, max_num_neighbors=10000, loop=True)
        # the placeholder cells occupy the last len(locations) rows; mask and decode them
        idx_cells_to_predict = torch.arange(expressions.shape[0] - locations.shape[0], expressions.shape[0]).to(device)
        expressions_pred = self.encode_decode(expressions, coordinates, edge_index, idx_cells_to_predict)
        return expressions_pred
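

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; a hedged example).
# Assumptions: the pretrained checkpoint 'ynyou/CIFM' is loadable via
# PyTorchModelHubMixin.from_pretrained; 'example.h5ad' is a hypothetical AnnData
# file with a sparse .X matrix and 2D coordinates in adata.obsm['spatial']; and
# the channel2ensembl_ids_* arguments of channel_matching are dataset-specific
# lists of Ensembl-ID lists, one per gene channel.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    model = CIFM.from_pretrained('ynyou/CIFM')  # provided by PyTorchModelHubMixin
    model.eval()

    adata = sc.read_h5ad('example.h5ad')  # hypothetical input file

    # Optionally re-wire the gene channels to this dataset's panel first:
    # model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)

    with torch.no_grad():
        embeddings = model.embed(adata)  # (num_cells, hidden_dim)
        locations = np.array([[100.0, 100.0], [200.0, 50.0]])  # query coordinates in the slide's units
        expressions_pred = model.predict_cells_at_locations(adata, locations)  # (2, num_genes)
    print(embeddings.shape, expressions_pred.shape)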