# PDeepPP_neuro / DataProcessor_pdeeppp.py
# Uploaded by fondress with huggingface_hub (commit 2b1de9f, verified).
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding
class PDeepPPProcessor(ProcessorMixin):
    """Pre-process protein sequences into fixed-length fragments for PDeepPP.

    Two modes are supported (selected in ``__call__``):

    * ``"PTM"`` - one fragment per S/T/Y residue, with that residue placed at
      index ``target_length // 2`` of the fragment.
    * ``"BPS"`` - fragments that cover each sequence with a sliding window.

    Missing positions are filled with ``pad_char``. Fragments are exactly
    ``target_length`` characters long as long as ``pad_char`` is a single
    character.
    """

    # Residues that become fragment centres in PTM mode.
    _PTM_RESIDUES = frozenset("STY")

    def __init__(self, pad_char="X", target_length=33):
        """Store the padding character and the fragment length.

        Args:
            pad_char: Character used to fill positions outside a sequence.
                Expected to be a single character.
            target_length: Length of every fragment produced.
        """
        # NOTE(review): ProcessorMixin.__init__ is not called; this processor
        # holds no tokenizer/feature extractor. Confirm that base-class helpers
        # (e.g. save_pretrained / from_pretrained) are not relied on.
        self.pad_char = pad_char
        self.target_length = target_length

    def pad_sequence(self, seq):
        """Return ``seq`` padded or truncated to ``target_length``.

        Shorter sequences are padded on both sides with ``pad_char``. When the
        total padding is odd, the extra character goes on the right. Longer
        sequences are truncated from the right, so the prefix is kept.
        """
        if len(seq) < self.target_length:
            total_padding = self.target_length - len(seq)
            left_padding = total_padding // 2
            right_padding = total_padding - left_padding
            seq = self.pad_char * left_padding + seq + self.pad_char * right_padding
        return seq[: self.target_length]

    def extract_ptm_sequences(self, sequences):
        """Extract one window per S/T/Y residue, centred on that residue.

        For each S, T or Y at position ``i``, the window covers
        ``seq[i - target_length // 2 : i - target_length // 2 + target_length]``.
        The residue therefore always lands at index ``target_length // 2`` of
        the fragment. Positions that fall before the start or after the end of
        the sequence are filled with ``pad_char`` on that side only.

        Args:
            sequences: Iterable of sequence strings. Matching is case-sensitive:
                only uppercase S/T/Y are extracted.

        Returns:
            List of fragments, ordered by sequence and then by residue position.
        """
        half = self.target_length // 2
        ptm_data = []
        for seq in sequences:
            seq_len = len(seq)
            for i, residue in enumerate(seq):
                if residue not in self._PTM_RESIDUES:
                    continue
                start = i - half
                end = start + self.target_length
                # Pad only the side(s) that run past the sequence boundary.
                # Clamping the slice and then padding both sides evenly would
                # shift residues near either end away from the centre.
                left_pad = self.pad_char * max(0, -start)
                right_pad = self.pad_char * max(0, end - seq_len)
                core = seq[max(0, start):min(seq_len, end)]
                ptm_data.append(left_pad + core + right_pad)
        return ptm_data

    def extract_bps_sequences(self, sequences, overlapping=True, step_size=5):
        """Cut sequences into fixed-length windows for bioactive-peptide (BPS) mode.

        A sequence shorter than ``target_length`` becomes a single fragment,
        padded via ``pad_sequence``. A longer sequence is cut with a window of
        width ``target_length``. The window starts at offset 0 and advances by
        ``step_size`` when ``overlapping`` is true, or by ``target_length``
        otherwise. It stops at the last offset where a full window still fits,
        so residues after the final window are not included in any fragment.

        Args:
            sequences: Iterable of sequence strings.
            overlapping: Whether consecutive windows may overlap.
            step_size: Window stride used when ``overlapping`` is true.

        Returns:
            List of fragments, in sequence order.

        Raises:
            ValueError: If ``overlapping`` is true and ``step_size`` is not positive.
        """
        if overlapping and step_size < 1:
            # A zero or negative stride would crash inside range() or silently
            # yield no windows for long sequences.
            raise ValueError("step_size must be a positive integer when overlapping=True.")
        stride = step_size if overlapping else self.target_length

        bioactive_data = []
        for seq in sequences:
            if len(seq) < self.target_length:
                # Too short for a single window: pad it up to target_length.
                bioactive_data.append(self.pad_sequence(seq))
                continue
            last_start = len(seq) - self.target_length
            # Every slice here is exactly target_length long, so no padding is needed.
            bioactive_data.extend(
                seq[start:start + self.target_length]
                for start in range(0, last_start + 1, stride)
            )
        return bioactive_data

    def __call__(
        self,
        sequences,
        mode,  # no default: the caller must choose the mode explicitly
        overlapping=True,
        step_size=5,
        **kwargs
    ):
        """Pre-process protein sequences into fixed-length fragments.

        Args:
            sequences: List of sequence strings, or a single sequence string.
            mode: Processing mode, ``"PTM"`` or ``"BPS"``. Required.
            overlapping: BPS mode only. Whether sliding windows overlap.
            step_size: BPS mode only. Window stride when ``overlapping`` is true.
            **kwargs: Accepted but ignored.

        Returns:
            A ``BatchEncoding`` with the single key ``"raw_sequences"``,
            holding the list of processed fragments.

        Raises:
            ValueError: If ``mode`` is not ``"PTM"`` or ``"BPS"``, if no
                fragments were produced, or (BPS) if ``step_size`` is invalid.
        """
        # Accept a single sequence string as a one-element batch.
        if isinstance(sequences, str):
            sequences = [sequences]

        if mode == "PTM":
            processed_sequences = self.extract_ptm_sequences(sequences)
        elif mode == "BPS":
            processed_sequences = self.extract_bps_sequences(
                sequences,
                overlapping=overlapping,
                step_size=step_size,
            )
        else:
            raise ValueError("Invalid mode. Please choose 'PTM' or 'BPS'.")

        if len(processed_sequences) == 0:
            raise ValueError("No sequences processed. Check input data and processing logic.")

        # Only the pre-processed strings are returned; no tokenization happens here.
        model_inputs = {
            "raw_sequences": processed_sequences,
        }
        return BatchEncoding(data=model_inputs)