|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
class _EmbeddingPretrainedModel(nn.Module):
    """Token embedding followed by a square linear projection.

    Maps a ``(batch, length)`` tensor of vocabulary indices to a
    ``(batch, length, embedding_dim)`` tensor. Weights receive PyTorch's
    default (random) initialization on construction; nothing is loaded here.
    """

    def __init__(self, vocab_size, embedding_dim, max_len):
        super().__init__()
        # max_len is accepted for signature compatibility but is not used.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        return self.fc(self.embedding(x))


class PretrainingPDeepPP:
    """Blend per-position ESM representations with a learned token embedding.

    The blend weight ``esm_ratio`` may be passed to the constructor or assigned
    later, but it must be set before calling :meth:`create_embeddings`.
    """

    def __init__(self, embedding_dim=1280, target_length=33, esm_ratio=None, device=None):
        """Initialize the PretrainingPDeepPP helper.

        Args:
            embedding_dim: Width of the embedding vectors. It must equal the
                width of the extracted ESM layer, because the two are summed.
            target_length: Number of positions kept per sequence.
            esm_ratio: Weight of the ESM representation relative to the
                embedding representation (may be assigned externally later).
            device: Torch device to use. Defaults to CUDA when available,
                otherwise CPU.
        """
        self.embedding_dim = embedding_dim
        self.target_length = target_length
        self.esm_ratio = esm_ratio
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def extract_esm_representations(self, sequences, esm_model, batch_converter, batch_size=32,
                                    *, repr_layer=33, verbose=False):
        """Extract per-token ESM representations with exactly ``target_length`` rows each.

        Args:
            sequences: Iterable of sequence strings.
            esm_model: Callable invoked as
                ``esm_model(tokens, repr_layers=[repr_layer], return_contacts=False)``.
                It must return a mapping with
                ``result["representations"][repr_layer]`` indexable per sequence.
            batch_converter: Callable that takes a list of ``(label, sequence)``
                pairs and returns a 3-tuple whose last item is the token tensor.
            batch_size: Number of sequences per forward pass.
            repr_layer: Layer index whose representations are taken.
            verbose: When True, print the input sequences and batch size.

        Returns:
            Tensor of shape ``(num_sequences, target_length, dim)`` on
            ``self.device``. Each entry holds the first ``target_length`` token
            positions produced by the model. Entries whose token tensor is
            shorter are zero-padded at the end.

        Raises:
            ValueError: If no representations were produced (e.g. empty input).
        """
        if verbose:
            print("Sequences to process:", sequences)
            print("Batch size:", batch_size)

        labeled_sequences = [(None, seq) for seq in sequences]
        sequence_representations = []

        # range() only yields starts < len, so every slice below is non-empty.
        for start in range(0, len(labeled_sequences), batch_size):
            batch = labeled_sequences[start:start + batch_size]
            _, _, batch_tokens = batch_converter(batch)
            batch_tokens = batch_tokens.to(self.device)

            with torch.no_grad():
                results = esm_model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)

            for token_repr in results["representations"][repr_layer]:
                # NOTE(review): if the batch converter prepends a BOS token,
                # row 0 here is that special token. ESM row i would then align
                # with residue i-1, while the embedding branch's row i is
                # residue i. Kept as-is for compatibility — confirm intent.
                rows = token_repr[:self.target_length]
                missing = self.target_length - rows.shape[0]
                if missing > 0:
                    # The token tensor can be shorter than target_length, e.g.
                    # when every sequence in the batch is short. Zero-pad so all
                    # entries share one shape; torch.stack and the later blend
                    # with the embedding branch require it.
                    padding = rows.new_zeros((missing,) + tuple(rows.shape[1:]))
                    rows = torch.cat([rows, padding], dim=0)
                sequence_representations.append(rows)

        if not sequence_representations:
            raise ValueError("No ESM representations were generated. Check your input sequences and batch processing logic.")

        return torch.stack(sequence_representations)

    def pad_sequences(self, sequences, max_len=None, pad_value=0):
        """Pack integer sequences into a ``(len(sequences), max_len)`` long tensor.

        Sequences longer than ``max_len`` are truncated. Shorter ones are filled
        on the right with ``pad_value``.

        Args:
            sequences: Iterable of integer sequences (e.g. lists of indices).
            max_len: Output width. Defaults to the longest sequence, or 0 for
                empty input.
            pad_value: Fill value for unused positions.

        Returns:
            ``torch.long`` tensor of shape ``(len(sequences), max_len)``.
        """
        sequences = list(sequences)
        if max_len is None:
            max_len = max((len(seq) for seq in sequences), default=0)

        padded_sequences = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
        for i, seq in enumerate(sequences):
            # Truncate first: assigning a longer row into a max_len-wide slot
            # would raise a shape mismatch.
            row = torch.as_tensor(seq[:max_len], dtype=torch.long)
            padded_sequences[i, :row.shape[0]] = row
        return padded_sequences

    def seq_to_indices(self, seq, vocab_dict):
        """Map each character of ``seq`` to its index in ``vocab_dict`` (0 if absent)."""
        return [vocab_dict.get(char, 0) for char in seq]

    def create_embeddings(self, sequences, vocab, esm_model, esm_alphabet, batch_size=16):
        """Create blended embeddings, weighted by the instance's ``esm_ratio``.

        Computes ``esm_ratio * ESM + (1 - esm_ratio) * embedding``, where the
        embedding branch is a newly constructed ``nn.Embedding`` +
        ``nn.Linear`` model applied to the padded vocabulary indices.

        Args:
            sequences: Iterable of input sequence strings.
            vocab: Ordered collection of characters; position = index.
            esm_model: ESM model, passed to :meth:`extract_esm_representations`.
            esm_alphabet: Object whose ``get_batch_converter()`` returns the
                batch converter for ``esm_model``.
            batch_size: Number of sequences per ESM forward pass.

        Returns:
            Tensor of shape ``(num_sequences, target_length, embedding_dim)``
            on ``self.device`` (assuming the ESM layer width equals
            ``embedding_dim``).

        Raises:
            ValueError: If ``esm_ratio`` is unset, or no ESM representations
                were produced.
        """
        if self.esm_ratio is None:
            raise ValueError("esm_ratio is not set. Please assign a value before creating embeddings.")

        # Materialize once: sequences is iterated twice below, so a generator
        # would be exhausted before the ESM pass.
        sequences = list(sequences)

        vocab_dict = {char: i for i, char in enumerate(vocab)}
        # NOTE(review): index 0 serves as vocab[0], unknown characters, and
        # padding alike — confirm this collision is acceptable.
        indices = [self.seq_to_indices(seq, vocab_dict) for seq in sequences]
        indices_padded = self.pad_sequences(indices, max_len=self.target_length)

        # NOTE(review): a fresh randomly initialized model is built on every
        # call. Separate calls (e.g. train vs. test splits) therefore use
        # different embedding weights unless the torch RNG is seeded
        # identically beforehand — confirm this is intended.
        embedding_model = _EmbeddingPretrainedModel(len(vocab), self.embedding_dim, self.target_length).to(self.device)

        esm_representations = self.extract_esm_representations(
            sequences,
            esm_model,
            esm_alphabet.get_batch_converter(),
            batch_size=batch_size,
        )

        with torch.no_grad():
            embedding_output = embedding_model(indices_padded.to(self.device))

        return self.esm_ratio * esm_representations + (1 - self.esm_ratio) * embedding_output