import torch
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch.nn.functional as F

from networks import graph


class GraphConvolution(nn.Module):
    """Single graph-convolution layer: output = adj @ (input @ weight) (+ bias)."""

    def __init__(self, in_features, out_features, bias=False):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, input, adj=None, relu=False):
        support = torch.matmul(input, self.weight)
        if adj is not None:
            output = torch.matmul(adj, support)   # propagate features along the graph
        else:
            output = support                      # no adjacency given: plain linear map
        if self.bias is not None:
            return output + self.bias
        if relu:
            return F.relu(output)
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'


class Featuremaps_to_Graph(nn.Module):
    """Pool a feature map into graph-node embeddings via a learned soft assignment."""

    def __init__(self, input_channels, hidden_layers, nodes=7):
        super(Featuremaps_to_Graph, self).__init__()
        self.pre_fea = Parameter(torch.FloatTensor(input_channels, nodes))
        self.weight = Parameter(torch.FloatTensor(input_channels, hidden_layers))
        self.reset_parameters()

    def forward(self, input):
        n, c, h, w = input.size()
        input1 = input.view(n, c, h * w)
        input1 = input1.transpose(1, 2)                      # n x hw x c
        # Feature maps to nodes: soft-assign each pixel to a node, then pool.
        fea_node = torch.matmul(input1, self.pre_fea)        # n x hw x nodes
        weight_node = torch.matmul(input1, self.weight)      # n x hw x hidden_layers
        fea_node = F.softmax(fea_node, dim=-1)               # assignment over nodes
        graph_node = F.relu(torch.matmul(fea_node.transpose(1, 2), weight_node))
        return graph_node                                    # n x nodes x hidden_layers

    def reset_parameters(self):
        for ww in self.parameters():
            torch.nn.init.xavier_uniform_(ww)


class Featuremaps_to_Graph_transfer(nn.Module):
    """Same as Featuremaps_to_Graph, but the pixel-to-node assignment matrix is
    adapted from a source graph's assignment (pre_fea) through a small transfer MLP."""

    def __init__(self, input_channels, hidden_layers, nodes=7, source_nodes=20):
        super(Featuremaps_to_Graph_transfer, self).__init__()
        self.pre_fea = Parameter(torch.FloatTensor(input_channels, nodes))
        self.weight = Parameter(torch.FloatTensor(input_channels, hidden_layers))
        self.pre_fea_transfer = nn.Sequential(nn.Linear(source_nodes, source_nodes),
                                              nn.LeakyReLU(inplace=True),
                                              nn.Linear(source_nodes, nodes),
                                              nn.LeakyReLU(inplace=True))
        self.reset_parameters()

    def reset_parameters(self):
        # Only the projection matrices are initialised here; the transfer MLP keeps
        # the default nn.Linear initialisation.
        torch.nn.init.xavier_uniform_(self.pre_fea)
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, source_pre_fea):
        self.pre_fea.data = self.pre_fea_learn(source_pre_fea)
        n, c, h, w = input.size()
        input1 = input.view(n, c, h * w)
        input1 = input1.transpose(1, 2)                      # n x hw x c
        # Feature maps to nodes, with the transferred assignment matrix.
        fea_node = torch.matmul(input1, self.pre_fea)        # n x hw x nodes
        weight_node = torch.matmul(input1, self.weight)      # n x hw x hidden_layers
        # Note: normalised over spatial positions (dim=1), unlike Featuremaps_to_Graph,
        # which normalises over nodes (dim=-1).
        fea_node = F.softmax(fea_node, dim=1)
        graph_node = F.relu(torch.matmul(fea_node.transpose(1, 2), weight_node))
        return graph_node                                    # n x nodes x hidden_layers

    def pre_fea_learn(self, input):
        pre_fea = self.pre_fea_transfer.forward(input.unsqueeze(0)).squeeze(0)
        return self.pre_fea.data + pre_fea
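

# Minimal usage sketch for the transfer pooling above: pool a feature map into 7
# target nodes, adapting the pixel-to-node assignment from a 20-node source graph's
# pre_fea. All shapes here are illustrative assumptions, not values taken from any
# training configuration in this repository.
def _example_featuremaps_to_graph_transfer():
    source = Featuremaps_to_Graph(input_channels=256, hidden_layers=128, nodes=20)
    target = Featuremaps_to_Graph_transfer(input_channels=256, hidden_layers=128,
                                           nodes=7, source_nodes=20)
    fmap = torch.randn(1, 256, 32, 32)                  # batch x channels x h x w
    node_feats = target(fmap, source.pre_fea.detach())  # 1 x 7 x 128
    return node_feats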


class Graph_to_Featuremaps(nn.Module):
    """Project graph-node embeddings back onto the spatial feature map by scoring
    every (pixel, node) pair from the concatenated pixel and node features."""

    def __init__(self, input_channels, output_channels, hidden_layers, nodes=7):
        super(Graph_to_Featuremaps, self).__init__()
        self.node_fea = Parameter(torch.FloatTensor(input_channels + hidden_layers, 1))
        self.weight = Parameter(torch.FloatTensor(hidden_layers, output_channels))
        self.reset_parameters()

    def reset_parameters(self):
        for ww in self.parameters():
            torch.nn.init.xavier_uniform_(ww)

    def forward(self, input, res_feature):
        '''
        :param input: 1 x batch x nodes x hidden_layers (or batch x nodes x hidden_layers)
        :param res_feature: batch x channels x h x w
        :return: batch x channels x h x w (assumes output_channels == channels)
        '''
        batchi, channeli, hi, wi = res_feature.size()
        try:
            _, batch, nodes, hidden = input.size()
        except ValueError:
            # also accept a 3-D node tensor
            input = input.unsqueeze(0)
            _, batch, nodes, hidden = input.size()
        assert batch == batchi

        # Pair every pixel with every node: batch x hw x nodes x (channels + hidden).
        input1 = input.transpose(0, 1).expand(batch, hi * wi, nodes, hidden)
        res_feature_after_view = res_feature.view(batch, channeli, hi * wi).transpose(1, 2)
        res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch, hi * wi, nodes, channeli)
        new_fea = torch.cat((res_feature_after_view1, input1), dim=3)

        new_node = torch.matmul(new_fea, self.node_fea)      # batch x hw x nodes x 1
        new_weight = torch.matmul(input, self.weight)        # 1 x batch x nodes x output_channels
        new_node = new_node.view(batch, hi * wi, nodes)
        new_node = F.softmax(new_node, dim=-1)               # attention over nodes per pixel

        feature_out = torch.matmul(new_node, new_weight)     # 1 x batch x hw x output_channels
        feature_out = feature_out.transpose(2, 3).contiguous().view(res_feature.size())
        return F.relu(feature_out)
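

# The memory-saving variant below relies on the fact that a linear projection of a
# concatenation splits into a sum of projections:
#   cat([x, h], dim=-1) @ cat([w_x, w_h], dim=0) == x @ w_x + h @ w_h
# so the batch x hw x nodes x (channels + hidden) tensor never has to be built.
# Small self-check of that identity on random tensors (shapes are arbitrary):
def _check_split_projection_identity():
    x = torch.randn(2, 5, 3)                              # e.g. pixel features
    h = torch.randn(2, 5, 4)                              # e.g. node hidden features
    w_x, w_h = torch.randn(3, 1), torch.randn(4, 1)
    joint = torch.matmul(torch.cat((x, h), dim=-1), torch.cat((w_x, w_h), dim=0))
    split = torch.matmul(x, w_x) + torch.matmul(h, w_h)
    return torch.allclose(joint, split, atol=1e-5)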


class Graph_to_Featuremaps_savemem(nn.Module):
    """Memory-saving version of Graph_to_Featuremaps: instead of concatenating pixel
    and node features before the scoring projection, it projects the two parts
    separately and sums the results (the same linear map, without materialising the
    batch x hw x nodes x (channels + hidden) tensor)."""

    def __init__(self, input_channels, output_channels, hidden_layers, nodes=7):
        super(Graph_to_Featuremaps_savemem, self).__init__()
        self.node_fea_for_res = Parameter(torch.FloatTensor(input_channels, 1))
        self.node_fea_for_hidden = Parameter(torch.FloatTensor(hidden_layers, 1))
        self.weight = Parameter(torch.FloatTensor(hidden_layers, output_channels))
        self.reset_parameters()

    def reset_parameters(self):
        for ww in self.parameters():
            torch.nn.init.xavier_uniform_(ww)

    def forward(self, input, res_feature):
        '''
        :param input: 1 x batch x nodes x hidden_layers (or batch x nodes x hidden_layers)
        :param res_feature: batch x channels x h x w
        :return: batch x channels x h x w (assumes output_channels == channels)
        '''
        batchi, channeli, hi, wi = res_feature.size()
        try:
            _, batch, nodes, hidden = input.size()
        except ValueError:
            # also accept a 3-D node tensor
            input = input.unsqueeze(0)
            _, batch, nodes, hidden = input.size()
        assert batch == batchi

        input1 = input.transpose(0, 1).expand(batch, hi * wi, nodes, hidden)
        res_feature_after_view = res_feature.view(batch, channeli, hi * wi).transpose(1, 2)
        res_feature_after_view1 = res_feature_after_view.unsqueeze(2).expand(batch, hi * wi, nodes, channeli)

        # Split projection: equivalent to cat((res, hidden), -1) @ node_fea, but cheaper.
        new_node1 = torch.matmul(res_feature_after_view1, self.node_fea_for_res)
        new_node2 = torch.matmul(input1, self.node_fea_for_hidden)
        new_node = new_node1 + new_node2                     # batch x hw x nodes x 1

        new_weight = torch.matmul(input, self.weight)        # 1 x batch x nodes x output_channels
        new_node = new_node.view(batch, hi * wi, nodes)
        new_node = F.softmax(new_node, dim=-1)               # attention over nodes per pixel

        feature_out = torch.matmul(new_node, new_weight)     # 1 x batch x hw x output_channels
        feature_out = feature_out.transpose(2, 3).contiguous().view(res_feature.size())
        return F.relu(feature_out)


class Graph_trans(nn.Module):
    """Transfer node features from a source graph (begin_nodes) onto a target graph
    (end_nodes) through a fixed or learnable transfer adjacency matrix."""

    def __init__(self, in_features, out_features, begin_nodes=7, end_nodes=2, bias=False, adj=None):
        super(Graph_trans, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if adj is not None:
            h, w = adj.size()
            assert (h == end_nodes) and (w == begin_nodes)
            self.adj = torch.autograd.Variable(adj, requires_grad=False)
        else:
            self.adj = Parameter(torch.FloatTensor(end_nodes, begin_nodes))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        # self.reset_parameters()  # note: not called here; callers initialise the weights

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, relu=False, adj_return=False, adj=None):
        support = torch.matmul(input, self.weight)
        if adj is None:
            adj = self.adj
        adj1 = self.norm_trans_adj(adj)
        output = torch.matmul(adj1, support)                 # ... x end_nodes x out_features
        if adj_return:
            # cache a similarity matrix between the transferred node embeddings
            output1 = F.normalize(output, p=2, dim=-1)
            self.adj_mat = torch.matmul(output1, output1.transpose(-2, -1))
        if self.bias is not None:
            return output + self.bias
        if relu:
            return F.relu(output)
        return output

    def get_adj_mat(self):
        adj = graph.normalize_adj_torch(F.relu(self.adj_mat))
        return adj

    def get_encode_adj(self):
        return self.adj

    def norm_trans_adj(self, adj):
        # row-normalise the (non-negative) transfer adjacency with a softmax
        adj = F.relu(adj)
        r = F.softmax(adj, dim=-1)
        return r
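

# Minimal usage sketch for Graph_trans. The node counts, feature width and the
# uniform transfer adjacency are illustrative assumptions only.
def _example_graph_trans():
    trans = Graph_trans(in_features=128, out_features=128, begin_nodes=7, end_nodes=2)
    trans.reset_parameters()                  # __init__ leaves the weight uninitialised
    source_nodes = torch.randn(1, 7, 128)
    transfer_adj = torch.ones(2, 7)           # softmax-normalised inside forward()
    target_nodes = trans(source_nodes, relu=True, adj=transfer_adj)   # 1 x 2 x 128
    return target_nodes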


if __name__ == '__main__':
    # quick smoke test: 7 graph nodes with 128-d features through one GCN layer
    node_input = torch.randn(7, 128)
    en = GraphConvolution(128, 128)
    a = en(node_input)
    print(a.size())
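    # End-to-end sketch (all shapes below are assumptions chosen for illustration,
    # not values from any configuration): feature map -> graph nodes -> one graph
    # convolution -> back to a feature map.
    fea2graph = Featuremaps_to_Graph(input_channels=256, hidden_layers=128, nodes=7)
    graph2fea = Graph_to_Featuremaps_savemem(input_channels=256, output_channels=256,
                                             hidden_layers=128, nodes=7)
    fmap = torch.randn(2, 256, 16, 16)
    node_feats = fea2graph(fmap)              # 2 x 7 x 128
    node_feats = en(node_feats, relu=True)    # 2 x 7 x 128 (reuses the layer above)
    recon = graph2fea(node_feats, fmap)       # 2 x 256 x 16 x 16
    print(node_feats.size(), recon.size())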