Spaces:
Running
on
Zero
Running
on
Zero
from transformers.generation.logits_process import LogitsProcessor | |
import torch | |
class SpeechOnlyNGramBlockingLogitsProcessor(LogitsProcessor):
    """Block over-repeated speech tokens while generation is in the speech phase.

    Token ids in ``[0, speech_token_num)`` are treated as speech tokens. When
    the speech phase is active and the most recent token of a row is a speech
    token, two rules may set logits to ``-inf`` for that row:

    1. Consecutive repeats: if the last token has already been emitted
       ``max_repeat`` or more times in a row, it is blocked for the next step.
    2. Sliding-window frequency: every speech token found in the widest
       configured window is blocked if it occurs ``window_repeat`` or more
       times within its own window of the most recent tokens.

    Rows whose last token is not a speech token are left untouched by both
    rules.
    """

    # Per-token window settings used when the caller passes no
    # ``special_token_window_dict`` at all.
    _DEFAULT_SPECIAL_TOKEN_WINDOWS = {1446: (13, 10)}

    def __init__(
        self,
        speech_token_num,
        repeat_times=5,
        special_token_repeat_times_dict=None,
        window_size=8,
        window_repeat=5,
        special_token_window_dict=None,
    ):
        """Configure the repetition limits.

        Args:
            speech_token_num: int, number of speech tokens; a token id in
                ``[0, speech_token_num)`` is considered a speech token.
            repeat_times: int, maximum number of consecutive repeats allowed
                for an ordinary speech token.
            special_token_repeat_times_dict: dict ``{token_id: repeat_times}``
                overriding the consecutive-repeat limit for specific tokens.
            window_size: int, default sliding-window length.
            window_repeat: int, default maximum number of occurrences allowed
                inside the window.
            special_token_window_dict: dict
                ``{token_id: (window_size, window_repeat)}`` overriding the
                window parameters for specific tokens. ``None`` selects the
                built-in default ``{1446: (13, 10)}``; an explicitly passed
                empty dict disables per-token overrides.
        """
        self.speech_token_num = speech_token_num
        self.repeat_times = repeat_times
        self.special_token_repeat_times_dict = special_token_repeat_times_dict or {}
        # Must be toggled from outside via set_phase(); starts in text phase.
        self.speech_phase = False
        self.window_size = window_size
        self.window_repeat = window_repeat
        # Compare against None (not truthiness) so that an explicit ``{}``
        # is honoured instead of being replaced by the default.
        if special_token_window_dict is None:
            special_token_window_dict = dict(self._DEFAULT_SPECIAL_TOKEN_WINDOWS)
        self.special_token_window_dict = special_token_window_dict

    def set_phase(self, speech_phase: bool):
        """Switch the processor on (speech phase) or off (text phase)."""
        self.speech_phase = speech_phase

    @staticmethod
    def _tail(tokens, size):
        """Return the last ``size`` items of ``tokens`` (empty if ``size <= 0``)."""
        # ``tokens[-0:]`` would be the whole list, so guard explicitly.
        return tokens[-size:] if size > 0 else []

    def __call__(self, input_ids, scores):
        """Apply both repetition rules to ``scores`` in place and return it.

        Args:
            input_ids: 2-D tensor of token ids generated so far; unpacked as
                ``(batch_size, seq_len)``.
            scores: next-token scores, indexed as ``scores[batch_idx, token]``.

        Returns:
            The same ``scores`` object, with blocked entries set to ``-inf``.
        """
        if not self.speech_phase:
            # Text phase: nothing to do.
            return scores

        batch_size, seq_len = input_ids.size()
        if seq_len == 0:
            # No history for any row, so there is nothing to inspect.
            return scores

        neg_inf = -float('inf')
        default_window = (self.window_size, self.window_repeat)
        # Loop-invariant: the widest window any token may be judged against.
        widest_window = max(
            self.window_size,
            max([v[0] for v in self.special_token_window_dict.values()], default=0),
        )

        for batch_idx in range(batch_size):
            generated = input_ids[batch_idx].tolist()
            last_token = generated[-1]
            if last_token >= self.speech_token_num:
                continue  # last token is not a speech token

            # Length of the run of identical tokens ending the sequence.
            repeat_count = 0
            for token in reversed(generated):
                if token != last_token:
                    break
                repeat_count += 1

            # Per-token limit if configured, otherwise the general limit.
            max_repeat = self.special_token_repeat_times_dict.get(
                last_token, self.repeat_times
            )
            if repeat_count >= max_repeat:
                scores[batch_idx, last_token] = neg_inf  # forbid another repeat

            # ====== sliding-window frequency suppression ======
            # Candidates are the distinct tokens seen in the widest window;
            # each one is then counted inside its own window.
            for token in set(self._tail(generated, widest_window)):
                if token >= self.speech_token_num:
                    continue
                window_size, window_repeat = self.special_token_window_dict.get(
                    token, default_window
                )
                if self._tail(generated, window_size).count(token) >= window_repeat:
                    scores[batch_idx, token] = neg_inf
            # ====== end of sliding-window frequency suppression ======
        return scores
class OSUM_chat_LogitsProcessor(LogitsProcessor):
    """Restrict the next token to an allowed set once a trigger sequence appears.

    The restriction is applied a single time: after the first successful match
    the processor becomes a no-op until ``init_match_found`` is called.
    """

    def __init__(self, allowed_tokens, sequence_to_match):
        """Store the configuration.

        Args:
            allowed_tokens (list): ids of the tokens permitted at the step
                right after the trigger sequence.
            sequence_to_match (list): token ids that must end the current
                sequence for the restriction to be applied.
        """
        self.sequence_to_match = sequence_to_match
        self.allowed_tokens = allowed_tokens
        # Set once the trigger sequence has been seen.
        self.match_found = False

    def init_match_found(self):
        """Reset the ``match_found`` flag so the processor can trigger again."""
        self.match_found = False

    def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        """Mask out every non-allowed token on the step following the trigger.

        Args:
            input_ids (torch.Tensor): token ids of the sequence so far.
            logits (torch.Tensor): logits for the current step; the code
                indexes the last dimension by token id
                (documented shape: ``[batch_size, vocab_size]``).

        Returns:
            torch.Tensor: ``logits`` unchanged, or a tensor where every entry
            outside ``allowed_tokens`` is ``-inf``.
        """
        # A previous call already matched: pass the logits through untouched.
        if self.match_found:
            return logits

        prefix_len = len(self.sequence_to_match)
        # Not enough tokens yet to contain the trigger sequence.
        if input_ids.shape[-1] < prefix_len:
            return logits

        tail_rows = input_ids[:, -prefix_len:].tolist()
        # NOTE: only the first row of the batch is compared with the trigger.
        trigger_seen = all(
            tail_rows[0][pos] == expected
            for pos, expected in enumerate(self.sequence_to_match)
        )
        if not trigger_seen:
            return logits

        # Boolean mask that is True only at the allowed token columns.
        keep = torch.zeros_like(logits, dtype=torch.bool)
        keep[:, self.allowed_tokens] = True
        # Allowed tokens keep their logits; everything else becomes -inf.
        logits = torch.where(keep, logits, -float('inf'))

        # Remember the match so later steps are not constrained again.
        self.match_found = True
        print("match found!!!!!!!!!!!!!!!!!!!!!!!")
        return logits