"""
RND1 sampling module for masked diffusion generation.

This module implements entropy-based token selection for iterative denoising
in diffusion language models. Supports both greedy and stochastic sampling
with optional prefix/suffix constraints and infilling.
"""

import torch
import torch.nn as nn
from typing import Optional, Union


def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Apply top-k filtering to logits, setting all non-top-k values to -inf."""
    top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
    filtered_logits = torch.full_like(logits, float('-inf'))
    filtered_logits.scatter_(-1, top_k_indices, top_k_values)
    return filtered_logits
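
# Worked example (illustrative): for logits [[2.0, 1.0, 0.5]] and k=2, the two
# largest values survive and the rest are driven to -inf, so a subsequent
# softmax assigns them zero probability:
#     apply_top_k_filtering(torch.tensor([[2.0, 1.0, 0.5]]), k=2)
#     # -> tensor([[2.0, 1.0, -inf]])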


def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Apply top-p (nucleus) filtering to logits, masking tokens beyond the
    cumulative probability threshold with -inf."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mask tokens whose cumulative probability exceeds p. Shift the mask right
    # so the first token crossing the threshold is kept, and always keep the
    # most likely token.
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Scatter the sorted-order mask back to the original vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, float('-inf'))
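
# Worked example (illustrative): with probabilities roughly [0.5, 0.3, 0.15, 0.05]
# and p=0.9, the cumulative sums are [0.5, 0.8, 0.95, 1.0]; the right-shift
# keeps the token that first crosses 0.9, so the three most likely tokens
# survive and only the last is masked to -inf.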


@torch.no_grad()
def diffusion_sample(
    model: nn.Module,
    seq_len: int = 256,
    num_steps: int = 256,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    greedy: bool = True,
    mask_token_id: int = 151669,
    prefix_ids: Optional[torch.LongTensor] = None,
    suffix_ids: Optional[torch.LongTensor] = None,
    infill_length: Optional[int] = None,
    eos_token_id: int = 151645,
    pad_token_id: Optional[int] = None,
    bos_token_id: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    generator: Optional[torch.Generator] = None,
    visualizer: Optional['TerminalVisualizer'] = None,
) -> torch.LongTensor:
    """
    Perform masked diffusion sampling with entropy-based token selection.

    Args:
        model: The RND1 language model
        seq_len: Target sequence length
        num_steps: Number of denoising steps
        top_k: Optional top-k filtering for sampling (None = no filtering)
        top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
            When both top_k and top_p are set, top_k is applied first, then top_p
        temperature: Temperature for sampling (higher = more random, lower = more
            deterministic). Values close to 0 are clamped to 1e-8 to avoid
            division by zero
        greedy: Whether to use greedy decoding (True) or stochastic sampling (False)
        mask_token_id: Token ID for masked positions (default: 151669)
        prefix_ids: Optional prefix token IDs to preserve
        suffix_ids: Optional suffix token IDs to preserve
        infill_length: Length of the infill region between prefix and suffix
        eos_token_id: End-of-sequence token ID (default: 151645)
        pad_token_id: Padding token ID (default: None, uses 0 if needed)
        bos_token_id: Beginning-of-sequence token ID (default: None)
        device: Device for computation (None = infer from model)
        generator: Optional torch.Generator for reproducible sampling
        visualizer: Optional TerminalVisualizer for live visualization

    Returns:
        Generated token IDs as a LongTensor of shape [1, seq_len]
    """
    model.eval()

    if device is None:
        device = next(model.parameters()).device
    else:
        device = torch.device(device)

    if pad_token_id is None:
        pad_token_id = 0

    if prefix_ids is not None or suffix_ids is not None:
        # Infilling mode: build [BOS][prefix][MASK * infill_length][suffix][EOS].
        if prefix_ids is not None:
            prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device)
            pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
        else:
            pre_len = 0

        if suffix_ids is not None:
            suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device)
            suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
        else:
            suf_len = 0

        reserved = (1 if bos_token_id is not None else 0) + (1 if eos_token_id is not None else 0)
        used = pre_len + suf_len + reserved

        if used > seq_len:
            raise ValueError(
                f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
                f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
                f"Please increase seq_len or reduce input lengths."
            )
        elif used == seq_len:
            raise ValueError(
                f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
                f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
                f"Need at least 1 position for generation."
            )
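
        # Resulting layout (illustrative): with seq_len=10, pre_len=3, suf_len=2,
        # and both BOS and EOS set, infill_length becomes 3 and the sequence is
        #     [BOS][p0][p1][p2][MASK][MASK][MASK][s0][s1][EOS]
        # Only the [MASK] span between prefix and suffix is denoised.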
        infill_length = min(infill_length or (seq_len - used), seq_len - used)

        x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
        pos = 0
        if bos_token_id is not None:
            x[0, pos] = bos_token_id
            pos += 1
        if pre_len > 0:
            x[0, pos:pos + pre_len] = prefix_ids.flatten()[:pre_len]
            pos += pre_len
        fill_start, fill_end = pos, pos + infill_length
        x[0, fill_start:fill_end] = mask_token_id
        pos = fill_end
        if suf_len > 0:
            x[0, pos:pos + suf_len] = suffix_ids.flatten()[:suf_len]
            pos += suf_len

        if eos_token_id is not None and pos < seq_len:
            if isinstance(eos_token_id, (list, tuple)):
                x[0, pos] = eos_token_id[0]
            else:
                x[0, pos] = eos_token_id

        # Only the infill span may be denoised.
        init_maskable = torch.zeros_like(x, dtype=torch.bool)
        init_maskable[0, fill_start:fill_end] = True
    else:
        # Unconditional generation: start from a fully masked sequence.
        x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
        if bos_token_id is not None:
            x[0, 0] = bos_token_id
        if eos_token_id is not None:
            if isinstance(eos_token_id, (list, tuple)):
                x[0, -1] = eos_token_id[0]
            else:
                x[0, -1] = eos_token_id
        init_maskable = x.eq(mask_token_id)

    # Special tokens are never denoised.
    if bos_token_id is not None:
        init_maskable[:, 0] = False
    if eos_token_id is not None:
        if isinstance(eos_token_id, (list, tuple)):
            for eos_id in eos_token_id:
                init_maskable &= x.ne(eos_id)
        else:
            init_maskable &= x.ne(eos_token_id)
    init_maskable &= x.ne(pad_token_id)

    maskable = init_maskable.clone()
    xt = x.clone()

    if visualizer:
        visualizer.start_visualization(xt, maskable, num_steps)

    def forward_scores(tokens):
        """Compute predictions and entropy scores for next tokens."""
        try:
            model_output = model(input_ids=tokens)
        except TypeError:
            # Fall back to positional input for models without an
            # `input_ids` keyword.
            model_output = model(tokens)

        safe_temperature = max(temperature, 1e-8)
        logits = model_output.logits / safe_temperature

        # Apply top-k first, then top-p, when both are requested.
        if top_k is not None and top_k > 0:
            logits = apply_top_k_filtering(logits, top_k)

        if top_p is not None and 0 < top_p < 1.0:
            logits = apply_top_p_filtering(logits, top_p)

        logp = torch.log_softmax(logits, dim=-1)

        if greedy:
            pred_next = logp.argmax(-1)
        else:
            if generator is not None:
                # Categorical.sample() does not accept a generator, so sample
                # via multinomial for reproducibility.
                probs = logp.exp()
                pred_next = torch.multinomial(
                    probs.view(-1, probs.size(-1)), 1, generator=generator
                ).squeeze(-1).view(probs.shape[:-1])
            else:
                pred_next = torch.distributions.Categorical(logits=logp).sample()

        # Log-probability of each predicted token.
        conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)

        p = logp.exp()
        ent_next = -(p * logp).sum(-1)
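
        # ent_next is the Shannon entropy H = -sum_v p(v) * log p(v) at each
        # position; low entropy marks positions the model is confident about,
        # which are committed first in the denoising loop below.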

        # Logits at position i predict the token at position i + 1, so shift
        # predictions and scores right by one to index them by target position.
        pred_i = tokens.clone()
        conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
        ent_i = torch.zeros_like(ent_next)

        pred_i[:, 1:] = pred_next[:, :-1]
        conf_i[:, 1:] = conf_next[:, :-1]
        ent_i[:, 1:] = ent_next[:, :-1]

        return pred_i, conf_i, ent_i

    # Initial predictions and scores for the starting sequence.
    pred_i, conf_i, ent_i = forward_scores(xt)
    total_masked = init_maskable.sum(1, keepdim=True)
    finf = torch.finfo(conf_i.dtype)

    # Denoise over num_steps: at each step, keep the cutoff_len highest-entropy
    # positions masked and commit predictions everywhere else.
    for step in range(num_steps - 1, 0, -1):
        rate = step / num_steps
        cutoff_len = (total_masked * rate).long().clamp(min=0)
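        # e.g., with 100 initially masked tokens and num_steps=10, cutoff_len
        # goes 90, 80, ..., 10 across steps, so roughly 10 tokens are committed
        # per iteration and the final 10 are filled after the loop.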

        # Rank still-masked positions by entropy; non-maskable positions get
        # the minimum finite score so they are never selected.
        sel_scores = ent_i.masked_fill(~maskable, -finf.max)
        B = sel_scores.shape[0]
        k_max = cutoff_len.max().item()
        if k_max > 0:
            _, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
            keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
            for b in range(B):
                k_b = int(cutoff_len[b].item())
                if k_b > 0:
                    keep_mask[b, idx[b, :k_b]] = True
        else:
            keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)

        # Commit predictions at positions leaving the masked set.
        to_unmask = maskable & ~keep_mask
        if to_unmask.any():
            xt[to_unmask] = pred_i[to_unmask]
            maskable[to_unmask] = False

        if visualizer:
            visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)

        # Re-score the partially denoised sequence for the next step.
        if maskable.any():
            pred_i, conf_i, ent_i = forward_scores(xt)

    # Fill any positions still masked after the final step.
    if maskable.any():
        xt[maskable] = pred_i[maskable]

    if visualizer:
        visualizer.stop_visualization()

    return xt