""" RND1 Generation Configuration. This module defines the generation configuration for RND1 models, controlling the diffusion-based generation process. """ from typing import Optional from transformers.generation.configuration_utils import GenerationConfig class RND1GenerationConfig(GenerationConfig): """ Configuration class for RND1 generation parameters. This class extends the base GenerationConfig to include parameters specific to diffusion-based language generation. Args: max_length: Maximum sequence length num_diffusion_steps: Number of denoising steps in the diffusion process mask_token_id: Token ID used for masking during diffusion temperature: Temperature for sampling (higher = more random) top_k: Optional top-k filtering top_p: Optional nucleus (top-p) filtering greedy: Whether to use greedy decoding (True) or stochastic sampling (False) **kwargs: Additional arguments passed to GenerationConfig """ def __init__( self, max_length: int = 256, num_diffusion_steps: int = 256, mask_token_id: int = 151669, # Default for Qwen-based RND1 models temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, greedy: bool = True, seed: Optional[int] = None, # For reproducible generation bos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, use_cache: bool = False, **kwargs, ): # Force no caching for RND1 generation - remove from kwargs if present kwargs.pop('use_cache', None) super().__init__( max_length=max_length, bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, temperature=temperature, top_k=top_k, top_p=top_p, do_sample=not greedy, use_cache=False, # Always False for RND1 **kwargs, ) # RND1-specific parameters self.num_diffusion_steps = num_diffusion_steps self.mask_token_id = mask_token_id self.greedy = greedy self.temperature = float(temperature) # Ensure it's a float self.seed = seed def to_dict(self): """Convert configuration to dictionary.""" output = super().to_dict() output["num_diffusion_steps"] = self.num_diffusion_steps output["mask_token_id"] = self.mask_token_id output["greedy"] = self.greedy if self.seed is not None: output["seed"] = self.seed return output