# RND1-Base-0910 / sampling.py
"""
RND1 sampling module for masked diffusion generation.
This module implements entropy-based token selection for iterative denoising
in diffusion language models. Supports both greedy and stochastic sampling
with optional prefix/suffix constraints and infilling.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union
def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
"""
    Apply top-k filtering to logits, setting all non-top-k values to -inf.
"""
top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
filtered_logits = torch.full_like(logits, float('-inf'))
filtered_logits.scatter_(-1, top_k_indices, top_k_values)
return filtered_logits
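

# A minimal usage sketch (illustrative; `_demo_top_k_filtering` is a hypothetical
# helper, not part of the original module). With k=2, only the two largest logits
# survive and every other position becomes -inf.
def _demo_top_k_filtering() -> None:
    logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
    filtered = apply_top_k_filtering(logits, k=2)
    print(filtered)  # -> [[-inf, 3.0, 2.0, -inf]]: only 3.0 and 2.0 are kept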
def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
"""
    Apply top-p (nucleus) filtering to logits, setting tokens outside the nucleus to -inf.
"""
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Remove tokens with cumulative probability above the threshold, shifting the
    # mask right so that the first token crossing the threshold is still kept
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False  # Always keep the most probable token
    # Map the removal mask from sorted order back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
return logits.masked_fill(indices_to_remove, float('-inf'))
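

# A minimal usage sketch (illustrative; `_demo_top_p_filtering` is a hypothetical
# helper, not part of the original module). Softmax of these logits is roughly
# [0.64, 0.24, 0.09, 0.03], so with p=0.7 the nucleus covers the first two tokens.
def _demo_top_p_filtering() -> None:
    logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
    filtered = apply_top_p_filtering(logits, p=0.7)
    print(filtered)  # -> [[2.0, 1.0, -inf, -inf]]: tokens outside the nucleus removed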
@torch.no_grad()
def diffusion_sample(
model: nn.Module,
seq_len: int = 256,
num_steps: int = 256,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: float = 1.0,
greedy: bool = True,
mask_token_id: int = 151669,
prefix_ids: Optional[torch.LongTensor] = None,
suffix_ids: Optional[torch.LongTensor] = None,
infill_length: Optional[int] = None,
eos_token_id: int = 151645,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
generator: Optional[torch.Generator] = None,
visualizer: Optional['TerminalVisualizer'] = None,
) -> torch.LongTensor:
"""
Perform masked diffusion sampling with entropy-based token selection.
Args:
model: The RND1 language model
seq_len: Target sequence length
num_steps: Number of denoising steps
top_k: Optional top-k filtering for sampling (None = no filtering)
top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
When both top_k and top_p are set, top_k is applied first, then top_p
temperature: Temperature for sampling (higher = more random, lower = more deterministic)
Values close to 0 are clamped to 1e-8 to avoid division by zero
greedy: Whether to use greedy sampling (True) or stochastic (False)
mask_token_id: Token ID for masked positions (default: 151669)
prefix_ids: Optional prefix token IDs to preserve
suffix_ids: Optional suffix token IDs to preserve
infill_length: Length of infill region between prefix/suffix
eos_token_id: End of sequence token ID (default: 151645)
pad_token_id: Padding token ID (default: None, uses 0 if needed)
bos_token_id: Beginning of sequence token ID (default: None)
device: Device for computation (None = infer from model)
generator: Optional torch generator for reproducible sampling
visualizer: Optional TerminalVisualizer for live visualization
Returns:
        Generated token IDs as a LongTensor of shape (1, seq_len).
        See the `__main__` usage sketch at the end of this file for an example call.
"""
model.eval()
if device is None:
device = next(model.parameters()).device
else:
device = torch.device(device)
dtype = next(model.parameters()).dtype
if pad_token_id is None:
pad_token_id = 0
# Build initial masked sequence
# When prefix_ids is provided, we create a sequence of length seq_len where:
# - The prefix occupies the first pre_len positions
# - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated
if prefix_ids is not None or suffix_ids is not None:
if prefix_ids is not None:
prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device)
pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
else:
pre_len = 0
if suffix_ids is not None:
suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device)
suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
else:
suf_len = 0
reserved = (1 if bos_token_id is not None else 0) + (1 if eos_token_id is not None else 0)
used = pre_len + suf_len + reserved
if used > seq_len:
raise ValueError(
f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
f"Please increase seq_len or reduce input lengths."
)
elif used == seq_len:
raise ValueError(
f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
f"Need at least 1 position for generation."
)
infill_length = min(infill_length or (seq_len - used), seq_len - used)
x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
pos = 0
if bos_token_id is not None:
            x[0, pos] = bos_token_id
            pos += 1
if pre_len > 0:
            x[0, pos:pos+pre_len] = prefix_ids.flatten()[:pre_len]
            pos += pre_len
fill_start, fill_end = pos, pos + infill_length
x[0, fill_start:fill_end] = mask_token_id
pos = fill_end
if suf_len > 0:
            x[0, pos:pos+suf_len] = suffix_ids.flatten()[:suf_len]
            pos += suf_len
if eos_token_id is not None and pos < seq_len:
if isinstance(eos_token_id, (list, tuple)):
x[0, pos] = eos_token_id[0]
else:
x[0, pos] = eos_token_id
init_maskable = torch.zeros_like(x, dtype=torch.bool)
init_maskable[0, fill_start:fill_end] = True
else:
x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
if bos_token_id is not None:
x[0, 0] = bos_token_id
if eos_token_id is not None:
# If eos_token_id is a list, use the first one
if isinstance(eos_token_id, (list, tuple)):
x[0, -1] = eos_token_id[0]
else:
x[0, -1] = eos_token_id
init_maskable = x.eq(mask_token_id)
if bos_token_id is not None:
init_maskable[:, 0] = False
if eos_token_id is not None:
# Handle both single token and list of tokens
if isinstance(eos_token_id, (list, tuple)):
for eos_id in eos_token_id:
init_maskable &= x.ne(eos_id)
else:
init_maskable &= x.ne(eos_token_id)
init_maskable &= x.ne(pad_token_id)
maskable = init_maskable.clone()
xt = x.clone()
if visualizer:
visualizer.start_visualization(xt, maskable, num_steps)
def forward_scores(tokens):
"""Compute predictions and entropy scores for next tokens."""
# Try with input_ids parameter first (standard HF models)
try:
model_output = model(input_ids=tokens)
except TypeError:
# Fall back to positional argument
model_output = model(tokens)
safe_temperature = max(temperature, 1e-8) # Prevent division by zero
logits = model_output.logits / safe_temperature
# Note: When both top_k and top_p are provided, they are applied sequentially:
# First top_k filters to k tokens, then top_p filters from those k tokens
if top_k is not None and top_k > 0:
logits = apply_top_k_filtering(logits, top_k)
if top_p is not None and 0 < top_p < 1.0:
logits = apply_top_p_filtering(logits, top_p)
logp = torch.log_softmax(logits, dim=-1)
if greedy:
pred_next = logp.argmax(-1)
else:
# Sample from categorical distribution with proper RNG handling
if generator is not None:
# Use multinomial with generator for reproducible sampling
probs = logp.exp()
pred_next = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=generator).squeeze(-1).view(probs.shape[:-1])
else:
pred_next = torch.distributions.Categorical(logits=logp).sample()
conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
p = logp.exp()
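        # Shannon entropy H = -sum_v p(v) * log p(v) of each position's predictive
        # distribution; lower entropy means a more confident prediction.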
ent_next = -(p * logp).sum(-1)
# Shift predictions: pos i predicts token i+1
pred_i = tokens.clone()
conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
ent_i = torch.zeros_like(ent_next)
pred_i[:, 1:] = pred_next[:, :-1]
conf_i[:, 1:] = conf_next[:, :-1]
ent_i[:, 1:] = ent_next[:, :-1]
return pred_i, conf_i, ent_i
pred_i, conf_i, ent_i = forward_scores(xt)
total_masked = init_maskable.sum(1, keepdim=True)
finf = torch.finfo(conf_i.dtype)
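    # Unmasking schedule (descriptive note): at step `step`, counting down from
    # num_steps - 1 to 1, a fraction step/num_steps of the originally masked positions
    # stays masked. The positions kept masked are those with the HIGHEST predictive
    # entropy, so the most confident predictions are committed first; anything still
    # masked after the loop is filled in by the final pass below.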
for step in range(num_steps - 1, 0, -1):
rate = step / num_steps
cutoff_len = (total_masked * rate).long().clamp(min=0)
# Choose HIGH-entropy tokens to keep masked
sel_scores = ent_i.masked_fill(~maskable, -finf.max)
B, L = sel_scores.shape
k_max = cutoff_len.max().item()
if k_max > 0:
            _, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
for b in range(B):
k_b = int(cutoff_len[b].item())
if k_b > 0:
keep_mask[b, idx[b, :k_b]] = True
else:
keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
to_unmask = maskable & ~keep_mask
if to_unmask.any():
xt[to_unmask] = pred_i[to_unmask]
maskable[to_unmask] = False
if visualizer:
visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)
if maskable.any():
pred_i, conf_i, ent_i = forward_scores(xt)
if maskable.any():
xt[maskable] = pred_i[maskable]
if visualizer:
visualizer.stop_visualization()
return xt
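

# A minimal end-to-end usage sketch (illustrative only). The checkpoint id and the
# AutoModelForMaskedLM / trust_remote_code loading path below are assumptions -- load
# RND1 and its tokenizer however the repository README specifies; only the call to
# diffusion_sample itself is defined in this module.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    model_id = "radicalnumerics/RND1-Base-0910"  # assumed checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)

    prefix = tokenizer("The capital of France is", return_tensors="pt").input_ids
    output_ids = diffusion_sample(
        model,
        seq_len=64,
        num_steps=64,
        greedy=True,
        prefix_ids=prefix,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))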