# veomni/models/transformers/qwen2/generation_utils.py
import copy
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging

logger = logging.get_logger(__name__)


def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering: keep the smallest set of tokens whose cumulative probability exceeds `top_p`."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the mask one position to the right so the first token crossing the threshold is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits
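
# Worked example for `top_p_logits` (illustrative): with sorted probabilities
# [0.5, 0.25, 0.15, 0.10] and top_p=0.8, the cumulative sums are [0.5, 0.75, 0.9, 1.0];
# only the last token is filtered out, since the one-position shift keeps the first token
# whose cumulative sum crosses the threshold.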


def top_k_logits(logits, top_k=None):
    """Top-k filtering: keep only the `top_k` largest logits and push the rest to the dtype minimum."""
    if top_k is None or top_k == 0:
        return logits
    top_k = min(top_k, logits.size(-1))
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample one token per position and return a `(confidence, tokens)` pair.

    By default `confidence` is the probability of the chosen token; with `margin_confidence`
    it is the top-1/top-2 probability margin, and with `neg_entropy` it is the negative
    entropy of the (filtered) distribution.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits.float(), dim=-1)
    if temperature > 0:
        x0 = dists.Categorical(probs=probs).sample()
    else:
        # Greedy decoding when temperature is 0
        _, x0 = probs.max(dim=-1)
    confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        top1_probs = sorted_probs[..., 0]
        top2_probs = sorted_probs[..., 1]
        confidence = top1_probs - top2_probs
    elif neg_entropy:
        log_probs = torch.log(probs.clamp(min=1e-10))
        confidence = (probs * log_probs).sum(dim=-1)
    return confidence, x0
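
# Worked example for `sample_tokens` (illustrative): for a distribution [0.7, 0.2, 0.1],
# the default confidence is the probability of the sampled token, the margin confidence
# is 0.7 - 0.2 = 0.5, and the negative-entropy confidence is
# 0.7*ln(0.7) + 0.2*ln(0.2) + 0.1*ln(0.1) ≈ -0.80 (closer to 0 means a more peaked distribution).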


@dataclass
class MDMModelOutput(ModelOutput):
    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None


class MDMGenerationConfig(GenerationConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        self.eps: float = kwargs.pop("eps", 1e-3)
        self.steps: int = kwargs.pop("steps", 512)
        self.alg: str = kwargs.pop("alg", "entropy")
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", 0.0)
        self.output_history: bool = kwargs.pop("output_history", False)
        self.mask_token_id = kwargs.pop("mask_token_id", None)
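
# Illustrative construction (assumes the tokenizer defines a mask token):
#   config = MDMGenerationConfig(
#       mask_token_id=tokenizer.mask_token_id,
#       steps=256,
#       temperature=0.2,
#       top_p=0.95,
#       alg="entropy",
#       max_new_tokens=128,
#   )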


class MDMGenerationMixin:
    """
    Mixin class for Masked Diffusion Model generation, adapted from the Dream model's generation utils.
    """

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.LongTensor], Optional[torch.LongTensor]]:
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _prepare_generation_config(
        self, generation_config: Optional[GenerationConfig], **kwargs
    ) -> MDMGenerationConfig:
        if generation_config is None:
            generation_config = self.generation_config
        # Copy so that `update(**kwargs)` does not mutate the model's default config
        generation_config = copy.deepcopy(generation_config)
        # Use MDMGenerationConfig as the target class
        if not isinstance(generation_config, MDMGenerationConfig):
            generation_config = MDMGenerationConfig.from_dict(generation_config.to_dict())
        # Update with kwargs
        generation_config.update(**kwargs)
        return generation_config

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[MDMGenerationConfig] = None,
        **kwargs,
    ) -> Union[MDMModelOutput, torch.LongTensor]:
        # 1. Prepare generation config
        generation_config = self._prepare_generation_config(generation_config, **kwargs)

        # 2. Prepare inputs
        input_ids = inputs
        attention_mask = kwargs.get("attention_mask", None)
        if input_ids is None:
            raise ValueError("`inputs` must be provided for diffusion generation.")
        if generation_config.max_new_tokens is not None:
            generation_config.max_length = input_ids.shape[-1] + generation_config.max_new_tokens

        # 3. Expand inputs for multi-sequence generation
        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # 4. Run the sampling loop
        return self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
        )
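
    # Illustrative call (assumes `model` is a model class that inherits this mixin and
    # `config` is an MDMGenerationConfig like the one sketched above):
    #   output = model.diffusion_generate(
    #       inputs=input_ids,
    #       attention_mask=attention_mask,
    #       generation_config=config,
    #       num_return_sequences=1,
    #       return_dict_in_generate=True,
    #   )
    #   sequences = output.sequences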

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: MDMGenerationConfig,
    ) -> Union[MDMModelOutput, torch.LongTensor]:
        # Extract params from config
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        if mask_token_id is None:
            raise ValueError("`mask_token_id` must be set in the generation config.")
        steps = generation_config.steps
        eps = generation_config.eps
        alg = generation_config.alg
        alg_temp = generation_config.alg_temp
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k
        histories = [] if generation_config.output_history else None

        # Pad input_ids to max_length with mask tokens
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)

        # The model expects a bidirectional mask, so we just use the presence of pad_token_id
        # for the attention mask during generation.
        gen_attention_mask = (x != self.config.pad_token_id).long() if self.config.pad_token_id is not None else None

        timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
        for i in range(steps):
            mask_index = x == mask_token_id
            if not mask_index.any():  # Stop if no tokens are masked
                break

            # is_causal=False is crucial for bidirectional attention
            outputs = self(input_ids=x, attention_mask=gen_attention_mask, is_causal=False)
            logits = outputs.logits
            # CRITICAL: Shift logits to predict the next token, aligning with training
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            mask_logits = logits[mask_index]

            t = timesteps[i]
            s = timesteps[i + 1]

            if alg == "origin":
                p_transfer = 1 - s / t if i < steps - 1 else 1
                x0 = torch.full_like(x[mask_index], fill_value=mask_token_id, device=self.device, dtype=torch.long)
                transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
                _, sampled_tokens = sample_tokens(
                    mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
                )
                x0[transfer_index_t_s] = sampled_tokens
                x[mask_index] = x0
            else:
                # Confidence-based sampling (maskgit_plus, topk_margin, entropy):
                # `topk_margin` ranks positions by the top-1/top-2 probability margin, `entropy`
                # by negative entropy, and `maskgit_plus` by the sampled token's probability.
                is_margin_conf = alg == "topk_margin"
                is_neg_entropy = alg == "entropy"
                confidence, x0 = sample_tokens(
                    mask_logits, temperature, top_p, top_k, margin_confidence=is_margin_conf, neg_entropy=is_neg_entropy
                )

                num_masked = mask_index.sum(dim=-1, keepdim=True)
                gamma = 1 - s / t
                # Unmask a gamma-fraction of the remaining masked tokens; on the final step, unmask them all.
                num_to_unmask = (num_masked * gamma).long() if i < steps - 1 else num_masked

                # Place confidence scores back into a full tensor to find top-k across the sequence
                full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=confidence.dtype)
                full_confidence[mask_index] = confidence

                k = int(num_to_unmask.max().item())
                if k > 0:
                    if alg_temp is not None and alg_temp > 0:
                        # Temperature-based sampling of which tokens to unmask
                        unmask_probs = F.softmax(full_confidence / alg_temp, dim=-1)
                        unmask_indices = torch.multinomial(unmask_probs, num_samples=k, replacement=False)
                    else:
                        # Top-k confidence sampling
                        _, unmask_indices = torch.topk(full_confidence, k=k, dim=-1)

                    # Create a mask for the tokens we are going to unmask
                    rows = torch.arange(x.size(0), device=x.device).unsqueeze(1)
                    unmask_selection_mask = torch.zeros_like(x, dtype=torch.bool)
                    unmask_selection_mask[rows, unmask_indices] = True
                    # Filter indices based on per-row `num_to_unmask`
                    unmask_selection_mask = unmask_selection_mask & (
                        torch.cumsum(unmask_selection_mask.long(), dim=-1) <= num_to_unmask
                    )

                    # Place the newly generated tokens (x0) into a full tensor
                    x_unmasked_proposals = torch.full_like(x, fill_value=mask_token_id)
                    x_unmasked_proposals[mask_index] = x0

                    # Update the main tensor `x` with the unmasked tokens
                    x[unmask_selection_mask] = x_unmasked_proposals[unmask_selection_mask]

            if histories is not None:
                histories.append(x.clone())

        if generation_config.return_dict_in_generate:
            return MDMModelOutput(sequences=x, history=histories)
        else:
            return x
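

# Illustrative, self-contained sanity check of the sampling helpers above; it does not
# exercise the diffusion loop itself, which requires a model class that inherits
# MDMGenerationMixin (the tensor sizes below are arbitrary toy values).
if __name__ == "__main__":
    torch.manual_seed(0)
    demo_logits = torch.randn(2, 8, 32)  # (batch, masked positions, vocab size)
    conf, tokens = sample_tokens(demo_logits, temperature=0.7, top_p=0.9, top_k=10)
    print("tokens:", tokens.shape, "confidence:", conf.shape)  # both torch.Size([2, 8])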