File size: 2,647 Bytes
278d275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.

"""

RND1 Generation Configuration.



This module defines the generation configuration for RND1 models,

controlling the diffusion-based generation process.

"""

from typing import Optional
from transformers.generation.configuration_utils import GenerationConfig


class RND1GenerationConfig(GenerationConfig):
    """
    Configuration class for RND1 generation parameters.

    This class extends the base GenerationConfig to include parameters
    specific to diffusion-based language generation.

    Args:
        max_length: Maximum sequence length
        num_diffusion_steps: Number of denoising steps in the diffusion process
        mask_token_id: Token ID used for masking during diffusion
        temperature: Temperature for sampling (higher = more random)
        top_k: Optional top-k filtering
        top_p: Optional nucleus (top-p) filtering
        greedy: Whether to use greedy decoding (True) or stochastic sampling
            (False). The base-class ``do_sample`` flag is always set to
            ``not greedy``.
        bos_token_id: Beginning-of-sequence token ID, forwarded to GenerationConfig
        eos_token_id: End-of-sequence token ID, forwarded to GenerationConfig
        pad_token_id: Padding token ID, forwarded to GenerationConfig
        use_cache: Accepted so callers may pass it, but ignored; the base
            class is always initialised with ``use_cache=False``.
        **kwargs: Additional arguments passed to GenerationConfig. A
            ``do_sample`` entry is discarded because that flag is derived
            from ``greedy``.
    """

    def __init__(
        self,
        max_length: int = 256,
        num_diffusion_steps: int = 256,
        mask_token_id: int = 151669,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        greedy: bool = True,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        # `do_sample` is passed explicitly below; leaving a caller-supplied copy
        # in kwargs (e.g. when re-creating a config from a serialized dict that
        # carries the key) would raise "got multiple values for keyword
        # argument 'do_sample'". `greedy` is the authoritative flag.
        kwargs.pop("do_sample", None)

        # Caching is forced off for RND generation: the `use_cache` argument is
        # captured by the signature (so it can never appear in kwargs) and is
        # intentionally not forwarded.
        super().__init__(
            max_length=max_length,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=not greedy,
            use_cache=False,
            **kwargs,
        )

        # RND-specific parameters
        self.num_diffusion_steps = num_diffusion_steps
        self.mask_token_id = mask_token_id
        self.greedy = greedy

    def to_dict(self) -> dict:
        """Return the base-class dictionary with the RND-specific fields set.

        The three RND-specific attributes are written explicitly so they are
        present in the result regardless of what the base implementation emits.
        """
        output = super().to_dict()
        output["num_diffusion_steps"] = self.num_diffusion_steps
        output["mask_token_id"] = self.mask_token_id
        output["greedy"] = self.greedy
        return output