# Source: RND1-Base-0910 / generation_config.py
# Author: fedebotu
# Commit 2011e8a: "[Chore] propagate changes"
"""
RND1 Generation Configuration.
This module defines the generation configuration for RND1 models,
controlling the diffusion-based generation process.
"""
from typing import Optional
from transformers.generation.configuration_utils import GenerationConfig
class RND1GenerationConfig(GenerationConfig):
    """
    Configuration class for RND1 generation parameters.

    This class extends the base GenerationConfig to include parameters
    specific to diffusion-based language generation.

    Args:
        max_length: Maximum sequence length
        num_diffusion_steps: Number of denoising steps in the diffusion process
        mask_token_id: Token ID used for masking during diffusion
        temperature: Temperature for sampling (higher = more random); coerced
            to ``float``
        top_k: Optional top-k filtering
        top_p: Optional nucleus (top-p) filtering
        greedy: Whether to use greedy decoding (True) or stochastic sampling
            (False). This is the single source of truth for ``do_sample``,
            which is always set to ``not greedy``; a ``do_sample`` value
            supplied through ``**kwargs`` (e.g. from a serialized config) is
            ignored.
        bos_token_id: Optional beginning-of-sequence token ID
        eos_token_id: Optional end-of-sequence token ID
        pad_token_id: Optional padding token ID
        use_cache: Accepted for signature compatibility but ignored; caching
            is always disabled (``use_cache=False``) for RND1 generation.
        **kwargs: Additional arguments passed to GenerationConfig
    """

    def __init__(
        self,
        max_length: int = 256,
        num_diffusion_steps: int = 256,
        mask_token_id: int = 151669,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        greedy: bool = True,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        # `do_sample` is derived from `greedy` below. Drop any explicit value so
        # it is not passed to GenerationConfig twice: a saved non-greedy config
        # serializes `do_sample=True`, and re-instantiating it with
        # `cls(**config_dict)` would otherwise raise
        # "got multiple values for keyword argument 'do_sample'".
        kwargs.pop("do_sample", None)

        # `use_cache` is a named parameter, so it never lands in `kwargs`; its
        # value is deliberately ignored because RND1 never uses caching.

        # Coerce before calling the base class so it stores/validates a float.
        temperature = float(temperature)

        super().__init__(
            max_length=max_length,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=not greedy,
            use_cache=False,  # Always False for RND1
            **kwargs,
        )

        # RND1-specific parameters
        self.num_diffusion_steps = num_diffusion_steps
        self.mask_token_id = mask_token_id
        self.greedy = greedy
        self.temperature = temperature

    def to_dict(self):
        """Convert configuration to dictionary.

        Starts from the base-class serialization and explicitly writes the
        RND1-specific fields so they are always present in the output.
        """
        output = super().to_dict()
        output["num_diffusion_steps"] = self.num_diffusion_steps
        output["mask_token_id"] = self.mask_token_id
        output["greedy"] = self.greedy
        return output