|
import os, sys |
|
import shutil |
|
import glob |
|
import torch |
|
import numpy as np |
|
import copy |
|
from itertools import groupby |
|
from operator import itemgetter |
|
import json |
|
import re |
|
import random |
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
from tqdm import tqdm |
|
import Bio |
|
from icecream import ic |
|
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') |
|
|
|
conversion = 'ARNDCQEGHILKMFPSTWYVX-' |
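# Column i of an L x 21 logit/probability matrix corresponds to conversion[i];
# the trailing '-' (gap) character is not part of the 21-dimensional representation used below.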
|
|
|
|
|
|
|
|
|
|
|
|
|
class Potential: |
|
|
|
    def get_gradients(self, seq):
|
''' |
|
EVERY POTENTIAL CLASS MUST RETURN GRADIENTS |
|
''' |
|
|
|
sys.exit('ERROR POTENTIAL HAS NOT BEEN IMPLEMENTED') |
|
|
|
|
|
class AACompositionalBias(Potential): |
|
""" |
|
T = number of timesteps to set up diffuser with |
|
|
|
schedule = type of noise schedule to use linear, cosine, gaussian |
|
|
|
noise = type of ditribution to sample from; DEFAULT - normal_gaussian |
|
|
|
""" |
|
|
|
def __init__(self, args, features, potential_scale, DEVICE): |
|
|
|
self.L = features['L'] |
|
self.DEVICE = DEVICE |
|
self.frac_seq_to_weight = args['frac_seq_to_weight'] |
|
self.add_weight_every_n = args['add_weight_every_n'] |
|
self.aa_weights_json = args['aa_weights_json'] |
|
self.one_weight_per_position = args['one_weight_per_position'] |
|
self.aa_weight = args['aa_weight'] |
|
self.aa_spec = args['aa_spec'] |
|
self.aa_composition = args['aa_composition'] |
|
self.potential_scale = potential_scale |
|
|
|
self.aa_weights_to_add = [0 for l in range(21)] |
|
self.aa_max_potential = None |
|
|
|
|
|
if self.aa_weights_json != None: |
|
with open(self.aa_weights_json, 'r') as f: |
|
aa_weights = json.load(f) |
|
else: |
|
aa_weights = {} |
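        # aa_weights_json is expected to map single-letter residue codes to bias weights,
        # e.g. {"W": 2.0, "K": -1.0} (keys must be letters present in `conversion`).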
|
|
|
        aa_weights_to_add = [0 for l in range(21)]
        for k,v in aa_weights.items():
            aa_weights_to_add[conversion.index(k)] = v

        self.aa_weights_to_add = torch.tensor(aa_weights_to_add)[None].repeat(self.L,1).to(self.DEVICE, non_blocking=True)
|
|
|
|
|
if self.add_weight_every_n > 1 or self.frac_seq_to_weight > 0: |
|
|
|
assert (self.add_weight_every_n > 1) ^ (self.frac_seq_to_weight > 0), 'use either --add_weight_every_n or --frac_seq_to_weight but not both' |
|
weight_mask = torch.zeros_like(self.aa_weights_to_add) |
|
            if self.add_weight_every_n > 1:
|
idxs_to_unmask = torch.arange(0,self.L,self.add_weight_every_n) |
|
else: |
|
indexs = np.arange(0,self.L).tolist() |
|
idxs_to_unmask = random.sample(indexs,int(self.frac_seq_to_weight*self.L)) |
|
idxs_to_unmask.sort() |
|
|
|
weight_mask[idxs_to_unmask,:] = 1 |
|
self.aa_weights_to_add *= weight_mask |
|
|
|
            if self.one_weight_per_position:
|
for p in range(self.aa_weights_to_add.shape[0]): |
|
where_ones = torch.where(self.aa_weights_to_add[p,:] > 0)[0].tolist() |
|
if len(where_ones) > 0: |
|
w_sample = random.sample(where_ones,1)[0] |
|
self.aa_weights_to_add[p,:w_sample] = 0 |
|
self.aa_weights_to_add[p,w_sample+1:] = 0 |
|
|
|
elif self.aa_spec != None: |
|
|
|
assert self.aa_weight != None, 'please specify --aa_weight' |
|
|
|
|
|
repeat_len = len(self.aa_spec) |
|
weight_split = [float(x) for x in self.aa_weight.split(',')] |
|
|
|
aa_idxs = [] |
|
for k,c in enumerate(self.aa_spec): |
|
if c != 'X': |
|
assert c in conversion, f'the letter you have chosen is not an amino acid: {c}' |
|
aa_idxs.append((k,conversion.index(c))) |
|
|
|
            if len(weight_split) > 1:
|
assert len(aa_idxs) == len(weight_split), f'need to give same number of weights as AAs in weight spec' |
|
|
|
self.aa_weights_to_add = torch.zeros(self.L,21) |
|
|
|
for p,w in zip(aa_idxs,weight_split): |
|
x,a = p |
|
self.aa_weights_to_add[x,a] = w |
|
|
|
self.aa_weights_to_add = self.aa_weights_to_add[:repeat_len,:].repeat(self.L//repeat_len+1,1)[:self.L].to(self.DEVICE, non_blocking=True) |
|
|
|
elif self.aa_composition != None: |
|
|
|
self.aa_comp = [(x[0],float(x[1:])) for x in self.aa_composition.split(',')] |
|
self.aa_max_potential = 0 |
|
assert sum([f for aa,f in self.aa_comp]) <= 1, f'total sequence fraction specified in aa_composition is > 1' |
|
|
|
else: |
|
sys.exit(f'You are missing an argument to use the aa_bias potential') |
|
|
|
def get_gradients(self, seq): |
|
''' |
|
seq = L,21 |
|
|
|
return gradients to update the sequence with for the next pass |
|
''' |
|
|
|
if self.aa_max_potential != None: |
|
soft_seq = torch.softmax(seq, dim=1) |
|
print('ADDING SOFTMAXED SEQUENCE POTENTIAL') |
|
|
|
aa_weights_to_add_list = [] |
|
for aa,f in self.aa_comp: |
|
aa_weights_to_add_copy = self.aa_weights_to_add.clone() |
|
|
|
soft_seq_tmp = soft_seq.clone().detach().requires_grad_(True) |
|
aa_idx = conversion.index(aa) |
|
|
|
|
|
where_add = torch.topk(soft_seq_tmp[:,aa_idx], int(f*self.L))[1] |
|
|
|
|
|
aa_potential = torch.zeros(21) |
|
aa_potential[conversion.index(aa)] = 1.0 |
|
aa_potential = aa_potential.repeat(self.L,1).to(self.DEVICE, non_blocking=True) |
|
|
|
|
|
aa_comp_loss = torch.sum(torch.sum((aa_potential - soft_seq_tmp)**2, dim=1)**0.5) |
|
|
|
|
|
aa_comp_loss.backward() |
|
update_grads = soft_seq_tmp.grad |
|
|
|
for k in range(self.L): |
|
if k in where_add: |
|
aa_weights_to_add_copy[k,:] = -update_grads[k,:]*self.potential_scale |
|
else: |
|
aa_weights_to_add_copy[k,:] = update_grads[k,:]*self.potential_scale |
|
aa_weights_to_add_list.append(aa_weights_to_add_copy) |
|
|
|
            aa_weights_to_add_array = torch.stack(aa_weights_to_add_list)
|
self.aa_weights_to_add = torch.mean(aa_weights_to_add_array.float(), 0) |
|
|
|
|
|
return self.aa_weights_to_add |
|
|
|
|
|
class HydrophobicBias(Potential): |
|
""" |
|
Calculate loss with respect to soft_seq of the sequence hydropathy index (Kyte and Doolittle, 1986). |
|
|
|
T = number of timesteps to set up diffuser with |
|
|
|
schedule = type of noise schedule to use linear, cosine, gaussian |
|
|
|
noise = type of ditribution to sample from; DEFAULT - normal_gaussian |
|
|
|
""" |
|
def __init__(self, args, features, potential_scale, DEVICE): |
|
|
|
self.target_score = args['hydrophobic_score'] |
|
self.potential_scale = potential_scale |
|
self.loss_type = args['hydrophobic_loss_type'] |
|
print(f'USING {self.loss_type} LOSS TYPE...') |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.alpha_1 = list("ARNDCQEGHILKMFPSTWYVX") |
|
|
|
|
|
self.gravy_dict = {'C': 2.5, 'D': -3.5, 'S': -0.8, 'Q': -3.5, 'K': -3.9, |
|
'I': 4.5, 'P': -1.6, 'T': -0.7, 'F': 2.8, 'N': -3.5, |
|
'G': -0.4, 'H': -3.2, 'L': 3.8, 'R': -4.5, 'W': -0.9, |
|
'A': 1.8, 'V':4.2, 'E': -3.5, 'Y': -1.3, 'M': 1.9, 'X': 0, '-': 0} |
|
|
|
self.gravy_list = [self.gravy_dict[a] for a in self.alpha_1] |
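        # Expected per-position hydropathy is softmax(logits) . gravy_list; the sequence-level
        # GRAVY score optimized in get_gradients is the mean of these per-position expectations.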
|
|
|
|
|
|
|
|
|
print(f'GUIDING SEQUENCES TO HAVE TARGET GRAVY SCORE OF: {self.target_score}') |
|
return None |
|
|
|
|
|
def get_gradients(self, seq): |
|
""" |
|
Calculate gradients with respect to GRAVY index of input seq. |
|
Uses a MSE loss. |
|
|
|
Arguments |
|
--------- |
|
seq : tensor |
|
L X 21 logits after saving seq_out from xt |
|
|
|
Returns |
|
------- |
|
gradients : list of tensors |
|
gradients of soft_seq with respect to loss on partial_charge |
|
""" |
|
|
|
gravy_matrix = torch.tensor(self.gravy_list)[None].repeat(seq.shape[0],1).to(DEVICE) |
|
|
|
|
|
        soft_seq = torch.softmax(seq, dim=-1).to(DEVICE).requires_grad_()
|
|
|
|
|
if self.loss_type == 'simple': |
|
gravy_score = torch.mean(torch.sum(soft_seq*gravy_matrix,dim=-1), dim=0) |
|
loss = ((gravy_score - self.target_score)**2)**0.5 |
|
|
|
|
|
loss.backward() |
|
|
|
|
|
self.gradients = soft_seq.grad |
|
|
|
|
|
|
|
|
|
|
|
elif self.loss_type == 'complex': |
|
loss = torch.mean((torch.sum(soft_seq*gravy_matrix, dim = -1) - self.target_score)**2) |
|
|
|
|
|
loss.backward() |
|
|
|
|
|
self.gradients = soft_seq.grad |
|
|
|
|
|
|
|
|
|
return -self.gradients*self.potential_scale |
|
|
|
|
|
class ChargeBias(Potential): |
|
""" |
|
Calculate losses and get gradients with respect to soft_seq for the sequence charge at a given pH. |
|
|
|
T = number of timesteps to set up diffuser with |
|
|
|
schedule = type of noise schedule to use linear, cosine, gaussian |
|
|
|
noise = type of ditribution to sample from; DEFAULT - normal_gaussian |
|
|
|
""" |
|
def __init__(self, args, features, potential_scale, DEVICE): |
|
|
|
self.target_charge = args['target_charge'] |
|
self.pH = args['target_pH'] |
|
self.loss_type = args['charge_loss_type'] |
|
self.potential_scale = potential_scale |
|
self.L = features['L'] |
|
self.DEVICE = DEVICE |
|
|
|
|
|
|
|
|
|
|
|
|
|
pos_pKs_list = [[0.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.98, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]] |
|
neg_pKs_list = [[0.0, 0.0, 0.0, 4.05, 9.0, 0.0, 4.45, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0]] |
|
cterm_pKs_list = [[0.0, 0.0, 0.0, 4.55, 0.0, 0.0, 4.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]] |
|
nterm_pKs_list = [[7.59, 0.0, 0.0, 0.0, 0.0, 0.0, 7.7, 0.0, 0.0, 0.0, 0.0, 0.0, 7.0, 0.0, 8.36, 6.93, 6.82, 0.0, 0.0, 7.44, 0.0]] |
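        # The four pK rows above are indexed in the order of `conversion` (ARNDCQEGHILKMFPSTWYVX).
        # They appear to follow the Bjellqvist-style pKa set used by Biopython's IsoelectricPoint
        # (side-chain pKs plus residue-specific N-/C-terminal pKs).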
|
|
|
|
|
self.cterm_pKs = torch.tensor(cterm_pKs_list) |
|
self.nterm_pKs = torch.tensor(nterm_pKs_list) |
|
self.pos_pKs = torch.tensor(pos_pKs_list) |
|
self.neg_pKs = torch.tensor(neg_pKs_list) |
|
|
|
|
|
pos_pKs_repeat = self.pos_pKs.repeat(self.L - 2, 1) |
|
neg_pKs_repeat = self.neg_pKs.repeat(self.L - 2, 1) |
|
|
|
|
|
self.pos_pKs_matrix = torch.cat((torch.zeros_like(self.nterm_pKs), pos_pKs_repeat, self.nterm_pKs)).to(DEVICE) |
|
self.neg_pKs_matrix = torch.cat((self.cterm_pKs, neg_pKs_repeat, torch.zeros_like(self.cterm_pKs))).to(DEVICE) |
|
|
|
|
|
self.cterm_charged_idx = torch.nonzero(self.cterm_pKs) |
|
self.cterm_neutral_idx = torch.nonzero(self.cterm_pKs == 0) |
|
self.nterm_charged_idx = torch.nonzero(self.nterm_pKs) |
|
self.nterm_neutral_idx = torch.nonzero(self.nterm_pKs == 0) |
|
self.pos_pKs_idx = torch.tensor([[1, 8, 11]]) |
|
self.neg_pKs_idx = torch.tensor([[3, 4, 6, 18]]) |
|
self.neutral_pKs_idx = torch.tensor([[0, 2, 5, 7, 9, 10, 12, 13, 14, 15, 16, 17, 19, 20]]) |
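        # Index groups above, in `conversion` order: positive = R(1), H(8), K(11);
        # negative = D(3), C(4), E(6), Y(18); all remaining indices are treated as neutral.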
|
|
|
|
|
|
|
|
|
print(f'OPTIMIZING SEQUENCE TO HAVE CHARGE = {self.target_charge}\nAT pH = {self.pH}' ) |
|
|
|
def sum_tensor_indices(self, indices, tensor): |
|
total = 0 |
|
for idx in indices: |
|
i, j = idx[0], idx[1] |
|
total += tensor[i][j] |
|
return total |
|
|
|
def sum_tensor_indices_2(self, indices, tensor): |
|
|
|
j = indices.clone().detach().long().to(self.DEVICE) |
|
|
|
row_sums = tensor[:, j].sum(dim=-1) |
|
|
|
|
|
return row_sums.reshape(-1, 1).clone().detach() |
|
|
|
|
|
def make_table(self, L): |
|
""" |
|
Make table of all (positive, neutral, negative) charges -> (i, j, k) |
|
such that: |
|
i + j + k = L |
|
(1 * i) + (0 * j) + (-1 * k) = target_charge |
|
|
|
Arguments: |
|
L: int |
|
- length of sequence, defined as seq.shape[0] |
|
target_charge : float |
|
- Target charge for the sequence to be guided towards |
|
|
|
Returns: |
|
table: N x 3 tensor |
|
- All combinations of i, j, k such that the above conditions are satisfied |
|
""" |
|
|
|
table = [] |
|
for i in range(L): |
|
for j in range(L): |
|
for k in range(L): |
|
|
|
|
|
if i+j+k == L and i-k == self.target_charge and i != 0 and j != 0 and k != 0: |
|
table.append([i,j,k]) |
|
return torch.tensor(np.array(table)) |
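    # Example (illustrative): for L = 10 and target_charge = 2, make_table returns rows such as
    # [3, 6, 1] (3 positive, 6 neutral, 1 negative), since 3 + 6 + 1 = 10 and 3 - 1 = 2.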
|
|
|
|
|
def classify_resis(self, seq): |
|
""" |
|
Classify each position in seq as either positive, neutral, or negative. |
|
Classification = max( [sum(positive residue logits), sum(neutral residue logits), sum(negative residue logits)] ) |
|
|
|
Arguments: |
|
seq: L x 21 tensor |
|
- sequence logits from the model |
|
|
|
        Returns:
            charges: tensor
                - 1 x 3 tensor counting total # of each charge type in the input sequence
                - charges[0] = # positive residues
                - charges[1] = # neutral residues
                - charges[2] = # negative residues
            classification_df: pandas DataFrame
                - L x 5 per-position table with summed probabilities ('sum_pos', 'sum_neutral', 'sum_neg'),
                  their maximum ('max'), and the assigned class ('classification': 1 positive, 0 neutral, -1 negative)
        """
|
L = seq.shape[0] |
|
|
|
soft_seq = torch.softmax(seq.clone(),dim=-1).requires_grad_(requires_grad=True).to(self.DEVICE) |
|
|
|
|
|
|
|
sum_cterm_charged = self.sum_tensor_indices(self.cterm_charged_idx, soft_seq).item() |
|
|
|
|
|
sum_cterm_neutral = self.sum_tensor_indices(self.cterm_neutral_idx, soft_seq).item() |
|
|
|
|
|
|
|
cterm_max = max(sum_cterm_charged, sum_cterm_neutral) |
|
|
|
if cterm_max == sum_cterm_charged: |
|
cterm_class = torch.tensor([[-1]]).to(self.DEVICE) |
|
else: |
|
cterm_class = torch.tensor([[0]]).to(self.DEVICE) |
|
|
|
cterm_df = torch.tensor([[0, sum_cterm_neutral, sum_cterm_charged, cterm_max, cterm_class]]).to(self.DEVICE) |
|
|
|
|
|
sum_pos = self.sum_tensor_indices_2(self.pos_pKs_idx, soft_seq[1:L-1, ...]).to(self.DEVICE) |
|
|
|
sum_neg = self.sum_tensor_indices_2(self.neg_pKs_idx, soft_seq[1:L-1, ...]).to(self.DEVICE) |
|
|
|
sum_neutral = self.sum_tensor_indices_2(self.neutral_pKs_idx, soft_seq[1:L-1, ...]).to(self.DEVICE) |
|
|
|
|
|
|
|
middle_max, _ = torch.max(torch.stack((sum_pos, sum_neg, sum_neutral), dim=-1), dim=-1) |
|
middle_max = middle_max.to(self.DEVICE) |
|
|
|
middle_class = torch.zeros((L - 2, 1), dtype=torch.long).to(self.DEVICE) |
|
|
|
middle_class[sum_neg == middle_max] = -1 |
|
middle_class[sum_neutral == middle_max] = 0 |
|
middle_class[sum_pos == middle_max] = 1 |
|
|
|
|
|
middle_df = pd.DataFrame((torch.cat((sum_pos, sum_neutral, sum_neg, middle_max, middle_class), dim=-1)).detach().cpu().numpy()) |
|
middle_df.rename(columns={0: 'sum_pos', |
|
1: 'sum_neutral', 2: 'sum_neg', 3: 'middle_max', 4: 'middle_classified'}, |
|
inplace=True, errors='raise') |
|
|
|
|
|
sum_nterm_charged = self.sum_tensor_indices(self.nterm_charged_idx, soft_seq).to(self.DEVICE) |
|
|
|
sum_nterm_neutral = self.sum_tensor_indices(self.nterm_neutral_idx, soft_seq).to(self.DEVICE) |
|
|
|
|
|
|
|
        nterm_max = max(sum_nterm_charged, sum_nterm_neutral)
        if nterm_max == sum_nterm_charged:
            # the N-terminal ionizable group is basic, so a charged N-terminus counts as positive
            nterm_class = torch.tensor([[1]]).to(self.DEVICE)
        else:
            nterm_class = torch.tensor([[0]]).to(self.DEVICE)
        nterm_df = torch.tensor([[sum_nterm_charged, sum_nterm_neutral, 0, nterm_max, nterm_class]]).to(self.DEVICE)
|
|
|
|
|
middle_df_2 = (torch.cat((sum_pos, sum_neutral, sum_neg, middle_max, middle_class), dim=-1)).to(self.DEVICE) |
|
|
|
full_tens_np = torch.cat((cterm_df, middle_df_2, nterm_df), dim = 0).detach().cpu().numpy() |
|
classification_df = pd.DataFrame(full_tens_np) |
|
classification_df.rename(columns={0: 'sum_pos', |
|
1: 'sum_neutral', 2: 'sum_neg', 3: 'max', 4: 'classification'}, |
|
inplace=True, errors='raise') |
|
|
|
charge_classification = torch.cat((cterm_class, middle_class, nterm_class), dim = 0).to(self.DEVICE) |
|
charges = [torch.sum(charge_classification == 1).item(), torch.sum(charge_classification == 0).item(), torch.sum(charge_classification == -1).item()] |
|
|
|
|
|
|
|
return torch.tensor(charges), classification_df |
|
|
|
def get_target_charge_ratios(self, table, charges): |
|
""" |
|
Find closest distance between x, y, z in table and i, j, k in charges |
|
|
|
Arguments: |
|
table: N x 3 tensor of all combinations of positive, neutral, and negative charges that obey the conditions in make_table |
|
charges: 1 x 3 tensor |
|
- 1 x 3 tensor counting total # of each charge type in the input sequence |
|
- charges[0] = # positive residues |
|
- charges[1] = # neutral residues |
|
- charges[2] = # negative residues |
|
|
|
Returns: |
|
target_charge_tensor: tensor |
|
- 1 x 3 tensor of closest row in table that matches charges of input sequence |
|
""" |
|
|
|
diff = table - charges |
|
|
|
|
|
sq_distance = torch.sum(diff ** 2, dim=-1) |
|
|
|
|
|
min_idx = torch.argmin(sq_distance) |
|
|
|
|
|
        return table[min_idx]
|
|
|
def draft_resis(self, classification_df, target_charge_tensor): |
|
""" |
|
Based on target_charge_tensor, draft the top (i, j, k) positive, neutral, and negative positions from |
|
charge_classification and return the idealized guided_charge_classification. |
|
guided_charge_classification will determine whether the gradients should be positive or negative |
|
|
|
Draft pick algorithm for determining gradient guided_charge_classification: |
|
1) Define how many positive, negative, and neutral charges are needed |
|
2) Current charge being drafted = sign of target charge, otherwise opposite charge |
|
3) From the classification_df of the currently sampled sequence, choose the position with the highest probability of being current_charge |
|
4) Make that residue +1, 0, or -1 in guided_charge_classification to dictate the sign of gradients |
|
5) Keep drafting that residue charge until it is used up, then move to the next type |
|
|
|
Arguments: |
|
            classification_df: pandas DataFrame
                - L x 5 per-position table from classify_resis ('sum_pos', 'sum_neutral', 'sum_neg', 'max', 'classification')
|
target_charge_tensor: tensor |
|
- 1 x 3 tensor of closest row in table that matches charges of input sequence |
|
|
|
Returns: |
|
guided_charge_classification: L x 1 tensor |
|
- L x 1 tensor populated with 1 = positive, 0 = neutral, -1 = negative |
|
- in get_gradients, multiply the gradients by guided_charge_classification to determine which direction |
|
the gradients should guide toward based on the current sequence distribution and the target charge |
|
""" |
|
charge_dict = {'pos': 0, 'neutral': 0, 'neg': 0} |
|
|
|
charge_dict['pos'] = target_charge_tensor[0].detach().clone() |
|
charge_dict['neutral'] = target_charge_tensor[1].detach().clone() |
|
charge_dict['neg'] = target_charge_tensor[2].detach().clone() |
|
|
|
if self.target_charge > 0: |
|
start_charge = 'pos' |
|
elif self.target_charge < 0: |
|
start_charge = 'neg' |
|
else: |
|
start_charge = 'neutral' |
|
|
|
|
|
guided_charge_classification = torch.zeros((classification_df.shape[0], 1)) |
|
|
|
|
|
draft_charge = start_charge |
|
while charge_dict[draft_charge] > 0: |
|
|
|
max_residue_idx = classification_df.loc[:, ['sum_' + draft_charge]].idxmax()[0] |
|
|
|
|
|
|
|
|
|
if draft_charge == 'pos': |
|
guided_charge_classification[max_residue_idx] = 1 |
|
elif draft_charge == 'neg': |
|
guided_charge_classification[max_residue_idx] = -1 |
|
else: |
|
guided_charge_classification[max_residue_idx] = 0 |
|
|
|
classification_df = classification_df.drop(max_residue_idx) |
|
|
|
|
|
charge_dict[draft_charge] -= 1 |
|
|
|
|
|
if charge_dict[draft_charge] == 0: |
|
if draft_charge == start_charge: |
|
draft_charge = 'neg' if start_charge == 'pos' else 'pos' |
|
elif draft_charge == 'neg': |
|
draft_charge = 'pos' |
|
elif draft_charge == 'pos': |
|
draft_charge = 'neg' |
|
else: |
|
draft_charge = 'neutral' |
|
|
|
return guided_charge_classification.requires_grad_() |
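    # Example (illustrative): with target_charge_tensor = [3, 6, 1], the draft loop marks the 3
    # positions most likely to be positive with +1 and the most likely negative position with -1,
    # leaving the remaining positions at 0.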
|
|
|
def get_gradients(self, seq): |
|
""" |
|
Calculate gradients with respect to SEQUENCE CHARGE at pH. |
|
Uses a MSE loss. |
|
|
|
Arguments |
|
--------- |
|
seq : tensor |
|
L X 21 logits after saving seq_out from xt |
|
|
|
Returns |
|
------- |
|
gradients : list of tensors |
|
gradients of soft_seq with respect to loss on partial_charge |
|
""" |
|
|
|
|
|
        soft_seq = torch.softmax(seq, dim=-1).to(self.DEVICE).requires_grad_()
|
|
|
|
|
        pos_charge = torch.where(self.pos_pKs_matrix != 0, 1.0 / (10 ** (self.pH - self.pos_pKs_matrix) + 1.0), 0.0)
        neg_charge = torch.where(self.neg_pKs_matrix != 0, 1.0 / (10 ** (self.neg_pKs_matrix - self.pH) + 1.0), 0.0)
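        # Henderson-Hasselbalch fractional charges used above:
        #   positive groups: fraction protonated   = 1 / (1 + 10**(pH - pKa))
        #   negative groups: fraction deprotonated = 1 / (1 + 10**(pKa - pH))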
|
|
|
|
|
|
|
if self.loss_type == 'simple': |
|
|
|
partial_charge = torch.sum((soft_seq*(pos_charge - neg_charge)).requires_grad_(requires_grad=True)) |
|
|
|
print(f'CURRENT PARTIAL CHARGE: {partial_charge.item()}') |
|
|
|
loss = ((partial_charge - self.target_charge)**2)**0.5 |
|
|
|
|
|
loss.backward() |
|
|
|
self.gradients = soft_seq.grad |
|
|
|
|
|
|
|
|
|
|
|
        elif self.loss_type == 'simple2':

            partial_charge = torch.sum((soft_seq*(pos_charge - neg_charge)).requires_grad_(requires_grad=True))

            print(f'CURRENT PARTIAL CHARGE: {partial_charge.item()}')

            loss = ((partial_charge - self.target_charge)**2)**0.5

            loss.backward()

            self.gradients = soft_seq.grad
|
|
|
|
|
|
|
|
|
|
|
elif self.loss_type == 'complex': |
|
|
|
table = self.make_table(seq.shape[0]) |
|
charges, classification_df = self.classify_resis(seq) |
|
target_charge_tensor = self.get_target_charge_ratios(table, charges) |
|
guided_charge_classification = self.draft_resis(classification_df, target_charge_tensor) |
|
|
|
|
|
soft_partial_charge = (soft_seq*(pos_charge - neg_charge)) |
|
|
|
|
|
partial_charge = torch.sum(soft_partial_charge, dim=-1).requires_grad_() |
|
|
|
|
|
print(f'CURRENT PARTIAL CHARGE: {partial_charge.sum().item()}') |
|
|
|
|
|
|
|
|
|
loss = torch.mean(((guided_charge_classification.to(self.DEVICE) - partial_charge.unsqueeze(1).to(self.DEVICE))**2)**0.5) |
|
|
|
|
|
|
|
loss.backward() |
|
|
|
self.gradients = soft_seq.grad |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return -self.gradients*self.potential_scale |
|
|
|
class PSSMbias(Potential): |
|
|
|
def __init__(self, args, features, potential_scale, DEVICE): |
|
|
|
self.features = features |
|
self.args = args |
|
self.potential_scale = potential_scale |
|
self.DEVICE = DEVICE |
|
self.PSSM = np.loadtxt(args['PSSM'], delimiter=",", dtype=float) |
|
self.PSSM = torch.from_numpy(self.PSSM).to(self.DEVICE) |
|
|
|
def get_gradients(self, seq): |
|
print(seq.shape) |
|
|
|
|
|
return self.PSSM*self.potential_scale |
|
|
|
|
|
POTENTIALS = {'aa_bias':AACompositionalBias, 'charge':ChargeBias, 'hydrophobic':HydrophobicBias, 'PSSM':PSSMbias} |
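
# Minimal usage sketch (illustrative; the exact args/features dicts are assumptions here --
# in practice they are supplied by the sampling script that owns these potentials):
#
#   args = {'hydrophobic_score': -0.5, 'hydrophobic_loss_type': 'simple'}
#   features = {'L': 100}
#   potential = POTENTIALS['hydrophobic'](args, features, potential_scale=1.0, DEVICE=DEVICE)
#   seq_logits = torch.randn(features['L'], 21, device=DEVICE)
#   grad_bias = potential.get_gradients(seq_logits)  # L x 21 bias applied to the logits at the next step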
|
|