Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,961 Bytes
841f290 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from transformers.generation.logits_process import LogitsProcessor
import torch
class SpeechOnlyNGramBlockingLogitsProcessor(LogitsProcessor):
    """Suppresses over-repeated speech tokens during the speech phase.

    Token ids in ``[0, speech_token_num)`` are treated as speech tokens.
    While ``speech_phase`` is True, two suppression rules set logits to
    ``-inf``:

    1. Consecutive-repeat blocking: the last generated speech token is
       blocked once it has repeated ``max_repeat`` times in a row.
    2. Sliding-window frequency blocking: any speech token occurring at
       least ``window_repeat`` times in the last ``window_size`` tokens
       is blocked.
    """

    def __init__(
        self,
        speech_token_num,
        repeat_times=5,
        special_token_repeat_times_dict=None,
        window_size=8,
        window_repeat=5,
        special_token_window_dict=None
    ):
        """
        Args:
            speech_token_num (int): number of speech tokens; ids in
                ``[0, speech_token_num)`` are treated as speech tokens.
            repeat_times (int): max allowed consecutive repeats for an
                ordinary speech token.
            special_token_repeat_times_dict (dict | None):
                ``{token_id: repeat_times}`` per-token overrides of the
                consecutive-repeat limit.
            window_size (int): default sliding-window length.
            window_repeat (int): default max occurrences inside the window.
            special_token_window_dict (dict | None):
                ``{token_id: (window_size, window_repeat)}`` per-token
                window overrides. NOTE(review): when None (or empty) this
                defaults to the hard-coded ``{1446: (13, 10)}`` — kept
                for backward compatibility; pass a non-empty dict to
                override it.
        """
        self.speech_token_num = speech_token_num
        self.repeat_times = repeat_times
        self.special_token_repeat_times_dict = special_token_repeat_times_dict or {}
        # Must be toggled externally via set_phase(); suppression is only
        # active while this flag is True.
        self.speech_phase = False
        self.window_size = window_size
        self.window_repeat = window_repeat
        self.special_token_window_dict = special_token_window_dict or {1446: (13, 10)}

    def set_phase(self, speech_phase: bool):
        """Enable (True) or disable (False) speech-token suppression."""
        self.speech_phase = speech_phase

    def __call__(self, input_ids, scores):
        """Apply both suppression rules to ``scores`` in place.

        Args:
            input_ids: (batch, seq_len) tensor of generated token ids.
            scores: (batch, vocab) next-step logits; modified in place.

        Returns:
            The (possibly modified) ``scores`` tensor.
        """
        if not self.speech_phase:
            # Text phase: leave the logits untouched.
            return scores
        batch_size, seq_len = input_ids.size()
        if seq_len == 0:
            # Nothing generated yet; identical for every batch row, so
            # bail out before the loop.
            return scores
        # Longest lookback any window rule can need. Loop-invariant, so
        # compute it once instead of once per batch row.
        lookback = max(
            self.window_size,
            max((w for w, _ in self.special_token_window_dict.values()), default=0),
        )
        for batch_idx in range(batch_size):
            generated = input_ids[batch_idx].tolist()
            last_token = generated[-1]
            if last_token >= self.speech_token_num:
                continue  # last token is not a speech token
            # Count how many times the last token repeats consecutively
            # at the tail of the sequence.
            repeat_count = 1
            for i in range(seq_len - 2, -1, -1):
                if generated[i] == last_token:
                    repeat_count += 1
                else:
                    break
            # Per-token override of the consecutive-repeat limit.
            max_repeat = self.special_token_repeat_times_dict.get(last_token, self.repeat_times)
            if repeat_count >= max_repeat:
                scores[batch_idx, last_token] = -float('inf')  # block further repeats
            # ====== sliding-window frequency suppression ======
            # Examine every distinct speech token within the largest
            # possible window, each with its own window parameters.
            window_tokens = set(generated[-lookback:])
            for token in window_tokens:
                if token >= self.speech_token_num:
                    continue
                window_size, window_repeat = self.special_token_window_dict.get(
                    token, (self.window_size, self.window_repeat)
                )
                window = generated[-window_size:]
                if window.count(token) >= window_repeat:
                    scores[batch_idx, token] = -float('inf')
            # ====== end sliding-window suppression ======
        return scores
class OSUM_chat_LogitsProcessor(LogitsProcessor):
    """Restricts generation to ``allowed_tokens`` immediately after a
    trigger sequence appears.

    Once the last ``len(sequence_to_match)`` generated tokens equal
    ``sequence_to_match``, every token outside ``allowed_tokens`` gets a
    ``-inf`` logit for that step. The restriction is one-shot:
    ``match_found`` latches True and the processor becomes a no-op until
    ``init_match_found`` resets it.
    """

    def __init__(self, allowed_tokens, sequence_to_match):
        """
        Args:
            allowed_tokens (list): token ids permitted at the step right
                after the trigger sequence.
            sequence_to_match (list): trigger sequence of token ids.
        """
        self.allowed_tokens = allowed_tokens
        self.sequence_to_match = sequence_to_match
        # Latched once the trigger is seen; the processor then stops
        # checking (one-shot restriction).
        self.match_found = False

    def init_match_found(self):
        """Reset the one-shot latch so the processor is active again."""
        self.match_found = False

    def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        """Mask disallowed tokens at the step following the trigger.

        Args:
            input_ids (torch.Tensor): generated token ids,
                shape (batch_size, seq_len).
            logits (torch.Tensor): current-step logits,
                shape (batch_size, vocab_size).

        Returns:
            torch.Tensor: the (possibly masked) logits.
        """
        # After the first match this is a permanent no-op until
        # init_match_found() is called.
        if self.match_found:
            return logits
        sequence_length = len(self.sequence_to_match)
        if input_ids.shape[-1] >= sequence_length:
            # NOTE(review): only batch row 0 is inspected, but the mask
            # is applied to every row — this assumes batch_size == 1;
            # confirm against callers.
            # An empty trigger matches immediately (preserves the
            # original all()-over-empty-range semantics).
            tail = input_ids[0, -sequence_length:].tolist() if sequence_length else []
            if tail == self.sequence_to_match:
                # Keep original logits for allowed tokens, -inf elsewhere.
                mask = torch.zeros_like(logits, dtype=torch.bool)
                mask[:, self.allowed_tokens] = True
                logits = torch.where(mask, logits, -float('inf'))
                # Latch: restriction fires exactly once.
                self.match_found = True
        return logits
|