# OSUM-EChat / patches / custom_speech_ngram_blocking.py
# Repository page residue, preserved as comments: "xlgeng's picture" /
# "开始部署" (English: "start deployment") / 841f290
from transformers.generation.logits_process import LogitsProcessor
import torch
class SpeechOnlyNGramBlockingLogitsProcessor(LogitsProcessor):
    """Suppress degenerate repetition of speech tokens during generation.

    Two independent rules are applied, and only while ``speech_phase`` is
    True (the flag is toggled from outside through :meth:`set_phase`):

    1. Consecutive-repeat rule: if the most recent token is a speech token
       and already ends a run of at least ``max_repeat`` identical tokens,
       that token is banned for the next step.
    2. Sliding-window rule: every speech token that already occurs at least
       ``window_repeat`` times within its trailing window is banned for the
       next step.

    A token id is treated as a speech token when it is strictly below
    ``speech_token_num``. Rows whose last token is not a speech token are
    left untouched by both rules.
    """

    def __init__(
        self,
        speech_token_num,
        repeat_times=5,
        special_token_repeat_times_dict=None,
        window_size=8,
        window_repeat=5,
        special_token_window_dict=None
    ):
        """
        Args:
            speech_token_num: int, number of speech tokens; token ids in
                ``[0, speech_token_num)`` are treated as speech tokens.
            repeat_times: int, maximum allowed number of consecutive
                repetitions for an ordinary speech token.
            special_token_repeat_times_dict: dict ``{token_id: repeat_times}``
                overriding the consecutive-repeat limit for specific tokens.
            window_size: int, default sliding-window length.
            window_repeat: int, default maximum number of occurrences
                allowed inside the window.
            special_token_window_dict: dict
                ``{token_id: (window_size, window_repeat)}`` overriding the
                window parameters for specific tokens. Any falsy value
                (``None`` or an explicitly empty dict) falls back to the
                built-in default ``{1446: (13, 10)}``.
        """
        self.speech_token_num = speech_token_num
        self.repeat_times = repeat_times
        self.special_token_repeat_times_dict = special_token_repeat_times_dict or {}
        # Must be controlled from outside via set_phase().
        self.speech_phase = False
        self.window_size = window_size
        self.window_repeat = window_repeat
        self.special_token_window_dict = special_token_window_dict or {1446: (13, 10)}

    def set_phase(self, speech_phase: bool):
        """Enable (True) or disable (False) blocking for subsequent steps."""
        self.speech_phase = speech_phase

    def __call__(self, input_ids, scores):
        """Ban over-repeated speech tokens for the next generation step.

        Args:
            input_ids: 2-D tensor of token ids, one row per batch element.
            scores: next-token scores indexed as ``[batch_idx, token_id]``;
                modified in place.

        Returns:
            The same ``scores`` tensor, with banned entries set to ``-inf``.
        """
        if not self.speech_phase:
            # Text phase: nothing to do.
            return scores

        _, seq_len = input_ids.size()
        if seq_len == 0:
            return scores

        # These limits do not depend on the batch row, so compute them once
        # per call instead of once per row.
        max_window = max(
            self.window_size,
            max((spec[0] for spec in self.special_token_window_dict.values()), default=0),
        )
        max_repeat_limit = max(
            self.repeat_times,
            max(self.special_token_repeat_times_dict.values(), default=0),
        )
        # Neither rule ever needs to look further back than this many tokens,
        # so only that tail is copied to the host, in one batched conversion.
        lookback = max(1, max_window, max_repeat_limit)
        recent_rows = input_ids[:, -lookback:].tolist()
        neg_inf = -float('inf')

        for batch_idx, recent in enumerate(recent_rows):
            last_token = recent[-1]
            if last_token >= self.speech_token_num:
                continue  # last token is not a speech token

            # ----- consecutive-repeat rule -----
            # Length of the run of `last_token` that ends the sequence.
            run_length = 0
            for token in reversed(recent):
                if token != last_token:
                    break
                run_length += 1
            max_repeat = self.special_token_repeat_times_dict.get(last_token, self.repeat_times)
            if run_length >= max_repeat:
                scores[batch_idx, last_token] = neg_inf  # forbid another repeat

            # ----- sliding-window frequency rule -----
            if max_window <= 0:
                continue
            # Candidates: every distinct token inside the widest window.
            for token in set(recent[-max_window:]):
                if token >= self.speech_token_num:
                    continue
                window_size, window_repeat = self.special_token_window_dict.get(
                    token, (self.window_size, self.window_repeat)
                )
                if window_size <= 0:
                    # A non-positive size disables the check; note that
                    # `seq[-0:]` would select the whole history, not nothing.
                    continue
                if recent[-window_size:].count(token) >= window_repeat:
                    scores[batch_idx, token] = neg_inf
        return scores
class OSUM_chat_LogitsProcessor(LogitsProcessor):
    """One-shot constraint on the token that follows a trigger sequence.

    The first time the trailing tokens of batch row 0 equal
    ``sequence_to_match``, every logit outside ``allowed_tokens`` is replaced
    by ``-inf`` (for all rows) and the processor marks itself as done. After
    that it passes logits through unchanged until :meth:`init_match_found`
    is called.
    """

    def __init__(self, allowed_tokens, sequence_to_match):
        """
        Args:
            allowed_tokens (list): ids of the tokens permitted at the step
                right after the trigger sequence.
            sequence_to_match (list): the preceding token sequence that
                triggers the restriction.
        """
        self.sequence_to_match = sequence_to_match
        self.allowed_tokens = allowed_tokens
        # Becomes True once the trigger sequence has been seen.
        self.match_found = False

    def init_match_found(self):
        """Reset the match flag so the trigger sequence is looked for again."""
        self.match_found = False

    def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        """Mask disallowed tokens when the trigger sequence has just been produced.

        Args:
            input_ids (torch.Tensor): token ids generated so far.
            logits (torch.Tensor): logits for the current step, indexed as
                ``[batch, vocab]``.

        Returns:
            torch.Tensor: ``logits`` unchanged, or a masked copy in which only
            ``allowed_tokens`` keep their original values.
        """
        # Already triggered once: skip detection entirely.
        if self.match_found:
            return logits

        pattern_len = len(self.sequence_to_match)
        # Not enough tokens yet to contain the trigger sequence.
        if input_ids.shape[-1] < pattern_len:
            return logits

        tail_rows = input_ids[:, -pattern_len:].tolist()
        # Only the first batch row is compared against the trigger sequence.
        for pos in range(pattern_len):
            if not (tail_rows[0][pos] == self.sequence_to_match[pos]):
                return logits

        # Boolean mask: True for allowed token columns, False elsewhere.
        keep = torch.zeros_like(logits, dtype=torch.bool)
        keep[:, self.allowed_tokens] = True
        # Allowed tokens keep their logits; everything else becomes -inf.
        logits = torch.where(keep, logits, -float('inf'))

        # Remember that the trigger has fired so later steps are untouched.
        self.match_found = True
        print("match found!!!!!!!!!!!!!!!!!!!!!!!")
        return logits