# open-dcoder-0.5B / generation_utils.py
# fredzzp's picture
# Initial model upload with custom code
# aac0a08 verified
# veomni/models/transformers/qwen2/generation_utils.py
import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, is_torchdynamo_compiling, logging
logger = logging.get_logger(__name__)
def top_p_logits(logits, top_p=None):
    """Apply nucleus (top-p) filtering along the last dimension of ``logits``.

    Entries are ranked by descending logit; the smallest prefix whose
    cumulative softmax probability exceeds ``top_p`` is kept (the top-ranked
    entry is always kept) and every other entry is replaced by the most
    negative finite value of ``logits.dtype``.

    Args:
        logits: Tensor of unnormalised scores; filtering runs over dim -1.
        top_p: Cumulative-probability threshold. ``None`` or a value ``>= 1``
            disables filtering and returns ``logits`` unchanged.

    Returns:
        A tensor shaped like ``logits`` with the filtered entries masked out.
    """
    # The default used to fall through to `cumulative_probs > None`, which
    # raises TypeError; treat "no threshold" as a no-op like `top_k_logits`.
    if top_p is None or top_p >= 1:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right by one so the entry that crosses the threshold is kept too,
    # and make sure the top-ranked entry is never removed.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Map the removal flags from sorted order back to the original positions.
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
def top_k_logits(logits, top_k=None):
    """Keep only the ``top_k`` largest entries along the last dimension.

    Args:
        logits: Tensor of unnormalised scores; filtering runs over dim -1.
        top_k: Number of entries to keep. ``None`` or ``0`` disables the
            filter; values larger than the last dimension are clamped to it.

    Returns:
        ``logits`` unchanged when filtering is disabled, otherwise a tensor
        of the same shape in which every entry strictly below the k-th
        largest value is replaced by the most negative finite value of
        ``logits.dtype``.
    """
    filtering_disabled = top_k is None or top_k == 0
    if filtering_disabled:
        return logits

    k = min(top_k, logits.size(-1))
    # Smallest value that still makes the cut, kept as a trailing singleton
    # dimension so it broadcasts against `logits`.
    kth_largest = torch.topk(logits, k).values[..., -1:]
    floor = torch.finfo(logits.dtype).min
    return logits.masked_fill(logits < kth_largest, floor)
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Pick one token per row of ``logits`` and score how confident the pick is.

    Args:
        logits: Unnormalised scores; the last dimension indexes candidates.
        temperature: When ``> 0`` the logits are divided by it and a token is
            drawn from the resulting categorical distribution; otherwise the
            highest-probability token is taken.
        top_p: Optional nucleus threshold, applied only when ``< 1``.
        top_k: Optional top-k cut-off, applied whenever it is not ``None``.
        margin_confidence: Score with (largest prob - second largest prob).
            Takes precedence over ``neg_entropy`` when both are set.
        neg_entropy: Score with ``sum(p * log p)`` (the negative entropy).

    Returns:
        ``(confidence, x0)``: the per-row confidence score and the chosen
        token indices. Without either flag the confidence is the probability
        of the chosen token.
    """
    stochastic = temperature > 0
    if stochastic:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    # Probabilities are always computed in float32.
    probs = torch.softmax(logits.float(), dim=-1)

    if stochastic:
        x0 = dists.Categorical(probs=probs).sample()
    else:
        x0 = probs.max(dim=-1).indices

    if margin_confidence:
        ranked = torch.sort(probs, dim=-1, descending=True).values
        confidence = ranked[..., 0] - ranked[..., 1]
    elif neg_entropy:
        # Clamp before the log so zero probabilities do not yield -inf * 0.
        confidence = (probs * torch.log(probs.clamp(min=1e-10))).sum(dim=-1)
    else:
        confidence = probs.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    return confidence, x0
@dataclass
class MDMModelOutput(ModelOutput):
    """Return container used by the diffusion sampler when
    ``generation_config.return_dict_in_generate`` is set."""

    # Final token-id tensor produced by the sampling loop.
    sequences: torch.LongTensor = None
    # Snapshots of the sequence tensor, one per executed step, or None when
    # `output_history` is off. NOTE(review): the sampler passes a list of
    # token-id tensors here, which does not match the annotation - confirm.
    history: Optional[Tuple[torch.FloatTensor]] = None
class MDMGenerationConfig(GenerationConfig):
    """``GenerationConfig`` extended with the knobs read by the masked-diffusion sampler.

    The base class is initialised with every keyword first; the attributes
    listed below are then (re)assigned from ``kwargs``, falling back to the
    diffusion-specific defaults when a keyword is absent.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        diffusion_defaults = (
            # Token sampling: temperature 0 selects the greedy branch of `sample_tokens`.
            ("temperature", 0.0),
            ("top_p", None),
            ("top_k", None),
            # Schedule: `steps` iterations over timesteps running from 1 down to `eps`.
            ("eps", 1e-3),
            ("steps", 512),
            # Unmasking strategy ('origin' or a confidence-based variant) and
            # the temperature used when sampling which positions to unmask.
            ("alg", 'entropy'),
            ("alg_temp", 0.0),
            # Whether the sampler records the sequence after every step.
            ("output_history", False),
            # Id of the mask token; the sampler refuses to run while it is None.
            ("mask_token_id", None),
        )
        for attr_name, fallback in diffusion_defaults:
            setattr(self, attr_name, kwargs.pop(attr_name, fallback))
class MDMGenerationMixin:
    """
    Mixin class for Masked Diffusion Model generation, adapted from the Dream model's generation utils.

    The host class must be callable as
    ``self(input_ids=..., attention_mask=..., is_causal=False)`` and return an
    object exposing ``.logits``; it must also provide ``self.config.pad_token_id``
    and ``self.generation_config``.
    """

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.LongTensor], Optional[torch.LongTensor]]:
        """Repeat every batch row ``expand_size`` times (rows stay adjacent).

        Returns:
            ``(input_ids, attention_mask)``; an argument that was ``None``
            stays ``None`` and nothing is copied when ``expand_size == 1``.
        """
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _prepare_generation_config(
        self, generation_config: Optional["GenerationConfig"], **kwargs
    ) -> "MDMGenerationConfig":
        """Return a private ``MDMGenerationConfig`` for a single generation call.

        Falls back to ``self.generation_config`` when no config is given,
        converts other config classes through ``to_dict``/``from_dict`` and
        then applies ``kwargs`` via ``generation_config.update``.
        """
        if generation_config is None:
            generation_config = self.generation_config
        if isinstance(generation_config, MDMGenerationConfig):
            # Copy first: the per-call updates below (and the `max_length`
            # override in `diffusion_generate`) must not leak into the model's
            # default config or into the object the caller handed in.
            generation_config = copy.deepcopy(generation_config)
        else:
            # Use MDMGenerationConfig as the target class
            generation_config = MDMGenerationConfig.from_dict(generation_config.to_dict())
        # Update with kwargs
        generation_config.update(**kwargs)
        return generation_config

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional["MDMGenerationConfig"] = None,
        **kwargs,
    ) -> Union["MDMModelOutput", torch.LongTensor]:
        """Generate sequences by iteratively unmasking a mask-padded prompt.

        Args:
            inputs: Prompt token ids, indexed as ``(batch, sequence)``.
            generation_config: Optional config; defaults to
                ``self.generation_config``.
            **kwargs: Forwarded to the config update step. ``attention_mask``
                is additionally read from here and expanded with the inputs.

        Returns:
            Whatever ``_sample`` returns: an ``MDMModelOutput`` when
            ``return_dict_in_generate`` is set, else the token-id tensor.

        Raises:
            ValueError: If ``inputs`` is ``None``.
        """
        # 1. Prepare generation config
        generation_config = self._prepare_generation_config(generation_config, **kwargs)

        # 2. Prepare inputs
        input_ids = inputs
        attention_mask = kwargs.get("attention_mask", None)
        if input_ids is None:
            raise ValueError("`inputs` must be provided for diffusion generation.")
        if generation_config.max_new_tokens is not None:
            generation_config.max_length = input_ids.shape[-1] + generation_config.max_new_tokens

        # 3. Expand inputs for multi-sequence generation
        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # 4. Run the sampling loop
        return self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
        )

    @staticmethod
    def _select_unmask_positions(
        full_confidence: torch.Tensor,
        mask_index: torch.Tensor,
        num_to_unmask: torch.Tensor,
        alg_temp: Optional[float] = None,
    ) -> torch.Tensor:
        """Choose, per row, which masked positions are committed this step.

        Args:
            full_confidence: ``(batch, seq)`` scores, ``-inf`` where the
                position is not masked.
            mask_index: ``(batch, seq)`` bool tensor of masked positions.
            num_to_unmask: ``(batch, 1)`` integer budget for each row.
            alg_temp: When ``> 0`` positions are sampled without replacement
                from ``softmax(full_confidence / alg_temp)``; otherwise the
                highest-scoring positions are taken.

        Returns:
            ``(batch, seq)`` bool tensor with at most ``num_to_unmask[r]``
            true entries in row ``r``, all of them at masked positions.
        """
        selection = torch.zeros_like(mask_index, dtype=torch.bool)
        k = int(num_to_unmask.max())
        if k <= 0:
            # Nothing is due this step (and multinomial rejects num_samples=0).
            return selection

        if alg_temp is not None and alg_temp > 0:
            # Temperature-based sampling of which tokens to unmask
            unmask_probs = F.softmax(full_confidence / alg_temp, dim=-1)
            picks = torch.multinomial(unmask_probs, num_samples=k, replacement=False)
        else:
            # Top-k confidence sampling; columns come back best-first.
            _, picks = torch.topk(full_confidence, k=k, dim=-1)

        # Column j holds each row's (j + 1)-th pick, so trimming the columns to
        # the row budget keeps the best / first-drawn picks. A positional
        # cumsum would instead keep whichever picks sit leftmost in the sequence.
        within_budget = torch.arange(k, device=picks.device).unsqueeze(0) < num_to_unmask
        selection.scatter_(-1, picks, within_budget)
        # A row with fewer masked positions than k also gets -inf picks;
        # never let those touch already-decoded tokens.
        return selection & mask_index

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: "MDMGenerationConfig",
    ) -> Union["MDMModelOutput", torch.LongTensor]:
        """Run the iterative unmasking loop.

        Args:
            input_ids: Prompt token ids, ``(batch, prompt_len)``.
            attention_mask: Accepted for interface compatibility but not used;
                the mask fed to the model is rebuilt from
                ``self.config.pad_token_id``.
            generation_config: Supplies ``max_length``, ``mask_token_id``,
                ``steps``, ``eps``, ``alg``, ``alg_temp``, ``temperature``,
                ``top_p``, ``top_k``, ``output_history`` and
                ``return_dict_in_generate``.

        Returns:
            ``MDMModelOutput(sequences, history)`` when
            ``return_dict_in_generate`` is set, otherwise the token-id tensor.

        Raises:
            ValueError: If ``mask_token_id`` is unset, or if the prompt is
                longer than ``max_length``.
        """
        # Extract params from config
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        if mask_token_id is None:
            raise ValueError("`mask_token_id` must be set in the generation config.")
        prompt_length = input_ids.shape[1]
        if max_length < prompt_length:
            # A negative pad width would silently crop the prompt.
            raise ValueError(
                f"Input length ({prompt_length}) exceeds `max_length` ({max_length}); "
                "set `max_new_tokens` or a larger `max_length`."
            )
        steps = generation_config.steps
        eps = generation_config.eps
        alg = generation_config.alg
        alg_temp = generation_config.alg_temp
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k
        histories = [] if generation_config.output_history else None

        # Pad input_ids to max_length with mask tokens
        x = F.pad(input_ids, (0, max_length - prompt_length), value=mask_token_id)
        device = x.device

        # The model expects a bidirectional mask, so we just use the presence of pad_token_id
        # for the attention mask during generation.
        pad_token_id = self.config.pad_token_id
        gen_attention_mask = (x != pad_token_id).long() if pad_token_id is not None else None

        timesteps = torch.linspace(1, eps, steps + 1, device=device)
        for i in range(steps):
            mask_index = (x == mask_token_id)
            if not mask_index.any():  # Stop if no tokens are masked
                break
            is_last_step = i == steps - 1

            # is_causal=False is crucial for bidirectional attention
            outputs = self(input_ids=x, attention_mask=gen_attention_mask, is_causal=False)
            logits = outputs.logits
            # CRITICAL: Shift logits to predict the next token, aligning with training
            logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
            mask_logits = logits[mask_index]

            t = timesteps[i]
            s = timesteps[i + 1]

            if alg == 'origin':
                # Each masked position is revealed independently with
                # probability 1 - s / t (everything on the last step).
                p_transfer = 1 if is_last_step else 1 - s / t
                x0 = torch.full_like(x[mask_index], fill_value=mask_token_id, dtype=torch.long)
                transfer_index_t_s = torch.rand(*x0.shape, device=device) < p_transfer
                _, sampled_tokens = sample_tokens(
                    mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k
                )
                x0[transfer_index_t_s] = sampled_tokens
                x[mask_index] = x0
            else:
                # Confidence-based sampling (maskgit, entropy, etc.). Exactly one
                # scoring mode is requested: `sample_tokens` lets the margin flag
                # win, so 'entropy' must not set it. Any other `alg` value scores
                # with the probability of the chosen token.
                confidence, x0 = sample_tokens(
                    mask_logits,
                    temperature,
                    top_p,
                    top_k,
                    margin_confidence=(alg == 'topk_margin'),
                    neg_entropy=(alg == 'entropy'),
                )

                if is_last_step:
                    # Commit everything that is still masked so no mask token
                    # survives into the returned sequence.
                    unmask_selection_mask = mask_index
                else:
                    num_masked = mask_index.sum(dim=-1, keepdim=True)
                    num_to_unmask = (num_masked * (1 - s / t)).long()
                    # Place confidence scores back into a full tensor to rank across the sequence
                    full_confidence = torch.full_like(x, -torch.inf, dtype=confidence.dtype)
                    full_confidence[mask_index] = confidence
                    unmask_selection_mask = self._select_unmask_positions(
                        full_confidence, mask_index, num_to_unmask, alg_temp
                    )

                # Place the newly generated tokens (x0) into a full tensor
                x_unmasked_proposals = torch.full_like(x, fill_value=mask_token_id)
                x_unmasked_proposals[mask_index] = x0
                # Update the main tensor `x` with the unmasked tokens
                x[unmask_selection_mask] = x_unmasked_proposals[unmask_selection_mask]

            if histories is not None:
                histories.append(x.clone())

        if generation_config.return_dict_in_generate:
            return MDMModelOutput(sequences=x, history=histories)
        return x