File size: 6,525 Bytes
6788772
 
 
 
 
552cf9a
 
6788772
 
 
 
 
 
 
 
 
 
 
 
 
 
4981657
6788772
 
 
 
 
4981657
6788772
 
 
 
 
 
 
2f63b5b
c23f8b5
6788772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be4e6f5
6788772
 
 
 
 
 
 
 
 
 
 
c23f8b5
 
 
6788772
 
9565da7
6788772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be4e6f5
6788772
be4e6f5
6788772
 
 
 
 
 
 
 
 
 
 
 
 
 
c23f8b5
 
6788772
 
c23f8b5
6788772
be4e6f5
6788772
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
import torch.nn as nn
from torch_geometric.nn import radius_graph
import scanpy as sc
from huggingface_hub import PyTorchModelHubMixin
from models_cifm.mlp_and_gnn import MLPBiasFree
from models_cifm.egnn_void_invariant import VIEGNNModel


class CIFM(
    nn.Module,
    PyTorchModelHubMixin,
    # optionally, you can add metadata which gets pushed to the model card
    repo_url='ynyou/CIFM',
    pipeline_tag='mask-generation',
    license='mit',
):
    """Gene-expression encoder / masked-cell decoder built on two VIEGNN stacks.

    Sub-modules (all created in ``__init__``):

    * ``gene_encoder``: bias-free MLP, ``in_dim -> hidden_dim``.
    * ``model``: VIEGNN encoder operating on the spatial graph.
    * ``mask_cell_decoder``: VIEGNN decoder applied after masked cells have
      been overwritten with ``mask_embedding``.
    * ``mask_cell_expression`` / ``mask_cell_dropout``: bias-free MLP heads,
      ``hidden_dim -> in_dim``; the second one gates the first in
      ``encode_decode``.
    * ``proj``: bias-free MLP ``hidden_dim -> 1``; not used by any method of
      this class.

    ``args`` must provide ``in_dim``, ``hidden_dim``, ``num_layer``,
    ``num_mlp_layers_in_module`` and ``radius_spatial_graph``.
    """

    # Neighbour cap handed to ``radius_graph`` when building spatial graphs.
    _MAX_NUM_NEIGHBORS = 10000

    def __init__(self, args):
        super().__init__()
        # NOTE: construction order is kept stable so that random parameter
        # initialisation is reproducible for a fixed seed.
        self.gene_encoder = MLPBiasFree(in_dim=args.in_dim, out_dim=args.hidden_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        self.model = VIEGNNModel(num_layers=args.num_layer, num_mlp_layers_in_module=args.num_mlp_layers_in_module,
                emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=False)
        self.mask_cell_decoder = VIEGNNModel(num_layers=args.num_layer, num_mlp_layers_in_module=args.num_mlp_layers_in_module,
                emb_dim=args.hidden_dim, in_dim=args.hidden_dim, out_dim=args.hidden_dim, residual=False)
        self.mask_cell_expression = MLPBiasFree(in_dim=args.hidden_dim, out_dim=args.in_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        self.mask_cell_dropout = MLPBiasFree(in_dim=args.hidden_dim, out_dim=args.in_dim, hidden_dim=args.hidden_dim, num_layer=args.num_mlp_layers_in_module)
        self.mask_embedding = nn.Embedding(1, args.hidden_dim)
        self.proj = MLPBiasFree(in_dim=args.hidden_dim, out_dim=1, hidden_dim=args.hidden_dim, num_layer=4)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        self.hidden_dim = args.hidden_dim
        self.radius_spatial_graph = args.radius_spatial_graph

    def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
        """Re-wire the input/output gene layers to a new channel layout.

        The first layer of ``gene_encoder`` and the last layers of
        ``mask_cell_expression`` and ``mask_cell_dropout`` are replaced by new
        bias-free linear layers with ``len(channel2ensembl_ids_target)``
        gene channels. For each target channel, the weights of every source
        channel sharing one of its Ensembl IDs are averaged and copied over.

        Args:
            channel2ensembl_ids_target: one collection of Ensembl IDs per
                target channel. Channels with an empty collection are skipped.
            channel2ensembl_ids_source: one collection of Ensembl IDs per
                source channel, indexed like the current weights. A bare
                string is treated as a single ID. IDs are compared by
                equality and are assumed to be hashable (e.g. strings).
            zero_init_for_unmatched_genes: when true, weights of target
                channels without any match are zero; otherwise they keep the
                default ``nn.Linear`` initialisation.

        Returns:
            None. A summary line is printed.
        """
        device = next(self.parameters()).device

        # Weights of the layers that are about to be replaced.
        src_in = self.gene_encoder.layers[0].weight.data
        src_out1 = self.mask_cell_expression.layers[-1].weight.data
        src_out2 = self.mask_cell_dropout.layers[-1].weight.data

        num_target = len(channel2ensembl_ids_target)
        linear_in = nn.Linear(num_target, self.hidden_dim, bias=False)
        linear_out1 = nn.Linear(self.hidden_dim, num_target, bias=False)
        linear_out2 = nn.Linear(self.hidden_dim, num_target, bias=False)
        if zero_init_for_unmatched_genes:
            linear_in.weight.data.zero_()
            linear_out1.weight.data.zero_()
            linear_out2.weight.data.zero_()

        # Build the Ensembl-ID -> source-channel lookup once instead of
        # rescanning every source channel for every target ID. Each source
        # channel is listed at most once per ID, in ascending channel order.
        source_lookup = {}
        for idx_source, source_ids in enumerate(channel2ensembl_ids_source):
            if isinstance(source_ids, str):
                source_ids = (source_ids,)
            for ensembl in set(source_ids):
                source_lookup.setdefault(ensembl, []).append(idx_source)

        num_matching = 0
        unmatched_channels = []
        for idx_target, ensembls in enumerate(channel2ensembl_ids_target):
            if len(ensembls) == 0:
                continue

            # A source channel hit through several IDs is counted several
            # times, i.e. it gets a proportionally larger weight in the mean.
            matched = [
                idx_source
                for ensembl in ensembls
                for idx_source in source_lookup.get(ensembl, ())
            ]
            if not matched:
                unmatched_channels += ensembls
                continue

            picks = torch.tensor(matched, dtype=torch.long, device=src_in.device)
            linear_in.weight.data[:, idx_target] = src_in[:, picks].mean(dim=1)
            linear_out1.weight.data[idx_target] = src_out1[picks.to(src_out1.device)].mean(dim=0)
            linear_out2.weight.data[idx_target] = src_out2[picks.to(src_out2.device)].mean(dim=0)

            num_matching += 1

        # Keep the replacement layers on the device AND in the dtype of the
        # layers they replace, so a non-float32 model stays consistent.
        self.gene_encoder.layers[0] = linear_in.to(device=device, dtype=src_in.dtype)
        self.mask_cell_expression.layers[-1] = linear_out1.to(device=device, dtype=src_out1.dtype)
        self.mask_cell_dropout.layers[-1] = linear_out2.to(device=device, dtype=src_out2.dtype)

        unmatched_channels = list(set(unmatched_channels))
        print('matching', num_matching, 'gene channels out of', len(channel2ensembl_ids_target), '; unmatched channels:', unmatched_channels)

    def forward(self):
        """Intentionally a no-op (returns ``None``).

        Use ``encode``, ``encode_decode``, ``embed`` or
        ``predict_cells_at_locations`` instead.
        """
        pass

    def encode(self, expressions, coordinates, edge_index):
        """Embed cells: ``gene_encoder`` followed by the ``model`` GNN.

        Only the first element returned by ``self.model`` is kept.
        """
        embeddings = self.gene_encoder(expressions)
        embeddings, _ = self.model(embeddings, coordinates, edge_index)
        return embeddings

    def encode_decode(self, expressions, coordinates, edge_index, mapping):
        """Predict expression for the cells selected by ``mapping``.

        All cells are encoded, the embeddings at ``mapping`` are overwritten
        with the learned mask embedding, and the decoder output at ``mapping``
        is turned into expression values. Entries whose dropout-head sigmoid
        is ``<= 0.5`` are set to zero.

        Returns:
            Tensor with one row per entry of ``mapping``.
        """
        device = expressions.device

        embeddings = self.encode(expressions, coordinates, edge_index)
        mask_token = self.mask_embedding(torch.zeros(1, dtype=torch.int64, device=device))
        embeddings[mapping] = mask_token  # broadcast over all masked cells
        embeddings_dec = self.mask_cell_decoder(embeddings, coordinates, edge_index)[0][mapping]

        expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
        dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))

        # Out-of-place masking: same values as assigning zeros in place, but
        # does not overwrite the ReLU output that autograd may have saved.
        return expressions_dec.masked_fill(dropouts_dec <= 0.5, 0.0)

    @staticmethod
    def _expression_tensor(adata, device):
        """Return ``adata.X`` as a dense float32 tensor on ``device``.

        ``X`` is densified through ``toarray()`` only when it offers that
        method, so both sparse and already-dense matrices are accepted.
        """
        matrix = adata.X
        if hasattr(matrix, 'toarray'):
            matrix = matrix.toarray()
        return torch.tensor(matrix, dtype=torch.float32).to(device)

    def _spatial_graph(self, coordinates, device):
        """Append a zero column to ``coordinates`` and build the radius graph.

        Returns the padded coordinates (moved to ``device``) together with the
        ``edge_index`` produced by ``radius_graph`` (self-loops included).
        """
        zero_column = torch.zeros(coordinates.shape[0], 1)
        coordinates = torch.cat([coordinates, zero_column], dim=1).to(device)
        edge_index = radius_graph(
            coordinates,
            r=self.radius_spatial_graph,
            max_num_neighbors=self._MAX_NUM_NEIGHBORS,
            loop=True,
        )
        return coordinates, edge_index

    def embed(self, adata):
        """Compute embeddings for every cell in ``adata``.

        Args:
            adata: object exposing ``X`` (expression matrix, sparse or dense)
                and ``obsm['spatial']`` (per-cell coordinates) - presumably an
                AnnData; confirm against callers.

        Returns:
            The tensor returned by ``encode``, on the model's device.
        """
        device = next(self.parameters()).device

        expressions = self._expression_tensor(adata, device)
        coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
        coordinates, edge_index = self._spatial_graph(coordinates, device)

        return self.encode(expressions, coordinates, edge_index)

    def predict_cells_at_locations(self, adata, locations):
        """Predict expression profiles for new cells placed at ``locations``.

        The observed cells of ``adata`` are extended with one all-zero
        placeholder cell per requested location; the placeholders are then
        masked and decoded by ``encode_decode``.

        Args:
            adata: object exposing ``X`` and ``obsm['spatial']`` (see
                ``embed``).
            locations: array-like of coordinates, one row per new cell, with
                the same number of columns as ``adata.obsm['spatial']``.

        Returns:
            Tensor with one predicted expression row per location.
        """
        device = next(self.parameters()).device

        locations = torch.tensor(locations, dtype=torch.float32)
        num_new = locations.shape[0]

        observed = self._expression_tensor(adata, device)
        placeholders = torch.zeros(num_new, observed.shape[1]).to(device)
        expressions = torch.cat([observed, placeholders], dim=0)

        coordinates = torch.tensor(adata.obsm['spatial'], dtype=torch.float32)
        coordinates = torch.cat([coordinates, locations], dim=0)
        coordinates, edge_index = self._spatial_graph(coordinates, device)

        # The placeholder cells occupy the trailing ``num_new`` rows.
        num_total = expressions.shape[0]
        idx_cells_to_predict = torch.arange(num_total - num_new, num_total).to(device)

        return self.encode_decode(expressions, coordinates, edge_index, idx_cells_to_predict)