File size: 2,509 Bytes
3e7a3bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2011e8a
3e7a3bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
RND1 Generation Configuration.

This module defines the generation configuration for RND1 models,
controlling the diffusion-based generation process.
"""

from typing import Optional
from transformers.generation.configuration_utils import GenerationConfig


class RND1GenerationConfig(GenerationConfig):
    """
    Configuration class for RND1 generation parameters.

    This class extends the base GenerationConfig to include parameters
    specific to diffusion-based language generation.

    Args:
        max_length: Maximum sequence length
        num_diffusion_steps: Number of denoising steps in the diffusion process
        mask_token_id: Token ID used for masking during diffusion
        temperature: Temperature for sampling (higher = more random)
        top_k: Optional top-k filtering
        top_p: Optional nucleus (top-p) filtering
        greedy: Whether to use greedy decoding (True) or stochastic sampling (False).
            The base-class ``do_sample`` flag is always derived as ``not greedy``.
        bos_token_id: Optional beginning-of-sequence token ID, forwarded to GenerationConfig
        eos_token_id: Optional end-of-sequence token ID, forwarded to GenerationConfig
        pad_token_id: Optional padding token ID, forwarded to GenerationConfig
        use_cache: Accepted for signature compatibility but ignored;
            ``use_cache=False`` is always forwarded to GenerationConfig
        **kwargs: Additional arguments passed to GenerationConfig. A
            ``do_sample`` entry is accepted only if it agrees with ``greedy``.

    Raises:
        ValueError: If ``do_sample`` is supplied in ``kwargs`` and contradicts
            ``greedy``.
    """

    def __init__(
        self,
        max_length: int = 256,
        num_diffusion_steps: int = 256,
        mask_token_id: int = 151669,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        greedy: bool = True,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        # `do_sample` is derived from `greedy`. A dict built from an existing
        # config can carry both keys; forwarding `do_sample` via **kwargs next
        # to the explicit keyword below would raise
        # "TypeError: ... got multiple values for keyword argument 'do_sample'".
        do_sample = not greedy
        requested_do_sample = kwargs.pop("do_sample", None)
        if requested_do_sample is not None and bool(requested_do_sample) != do_sample:
            raise ValueError(
                f"do_sample={requested_do_sample!r} conflicts with greedy={greedy!r}; "
                "use `greedy` to choose between greedy decoding and sampling."
            )

        # `use_cache` is captured by the named parameter above (so it can never
        # appear in kwargs) and is deliberately ignored: RND1 never caches.
        super().__init__(
            max_length=max_length,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            use_cache=False,  # Always False for RND1
            **kwargs,
        )

        # RND1-specific parameters
        self.num_diffusion_steps = num_diffusion_steps
        self.mask_token_id = mask_token_id
        self.greedy = greedy
        self.temperature = float(temperature)  # Ensure it's a float

    def to_dict(self) -> dict:
        """Convert configuration to dictionary.

        Returns:
            The dictionary produced by the base class, with the RND1-specific
            keys ``num_diffusion_steps``, ``mask_token_id`` and ``greedy`` set
            from this instance.
        """
        output = super().to_dict()
        output["num_diffusion_steps"] = self.num_diffusion_steps
        output["mask_token_id"] = self.mask_token_id
        output["greedy"] = self.greedy
        return output