|
""" |
|
RND1 Generation Configuration. |
|
|
|
This module defines the generation configuration for RND1 models, |
|
controlling the diffusion-based generation process. |
|
""" |
|
|
|
from typing import Optional |
|
from transformers.generation.configuration_utils import GenerationConfig |
|
|
|
|
|
class RND1GenerationConfig(GenerationConfig):
    """
    Configuration class for RND1 generation parameters.

    This class extends the base GenerationConfig to include parameters
    specific to diffusion-based language generation.

    Args:
        max_length: Maximum sequence length.
        num_diffusion_steps: Number of denoising steps in the diffusion process.
        mask_token_id: Token ID used for masking during diffusion.
        temperature: Temperature for sampling (higher = more random). Stored
            on the instance as a ``float``.
        top_k: Optional top-k filtering.
        top_p: Optional nucleus (top-p) filtering.
        greedy: Whether to use greedy decoding (True) or stochastic sampling
            (False). This is the single source of truth for the sampling
            mode: the base-class ``do_sample`` flag is always set to
            ``not greedy``.
        bos_token_id: Beginning-of-sequence token ID, forwarded to
            GenerationConfig.
        eos_token_id: End-of-sequence token ID, forwarded to GenerationConfig.
        pad_token_id: Padding token ID, forwarded to GenerationConfig.
        use_cache: Accepted for signature compatibility only and ignored;
            the base class always receives ``use_cache=False``.
        **kwargs: Additional arguments passed to GenerationConfig. Any
            ``do_sample`` entry is discarded because it is derived from
            ``greedy``.
    """

    def __init__(
        self,
        max_length: int = 256,
        num_diffusion_steps: int = 256,
        mask_token_id: int = 151669,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        greedy: bool = True,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        # `do_sample` is derived from `greedy` below and passed explicitly.
        # A serialized config (e.g. the output of `to_dict`, which is what
        # `from_dict` / `from_pretrained` feed back into the constructor)
        # carries a `do_sample` key. Forwarding it through **kwargs would
        # raise "got multiple values for keyword argument 'do_sample'".
        kwargs.pop("do_sample", None)

        super().__init__(
            max_length=max_length,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=not greedy,
            # The `use_cache` argument is deliberately ignored: caching is
            # always disabled for this config (presumably because diffusion
            # decoding re-processes the full sequence each step — confirm).
            use_cache=False,
            **kwargs,
        )

        # RND1-specific fields, set after the base initializer.
        self.num_diffusion_steps = num_diffusion_steps
        self.mask_token_id = mask_token_id
        self.greedy = greedy
        self.temperature = float(temperature)

    def to_dict(self):
        """Convert configuration to a dictionary, including the RND1-specific fields."""
        output = super().to_dict()
        # Write the diffusion-specific attributes explicitly so they are
        # always present in the serialized form.
        output["num_diffusion_steps"] = self.num_diffusion_steps
        output["mask_token_id"] = self.mask_token_id
        output["greedy"] = self.greedy
        return output