athms committed
Commit 32b8af1 · verified · 1 Parent(s): cac427a

Upload generation_utils.py with huggingface_hub

Files changed (1):
  1. generation_utils.py +225 -0
generation_utils.py ADDED
@@ -0,0 +1,225 @@
+ """
+ RND1 Generation Utilities.
+
+ This module provides generation utilities and mixins for RND1 models,
+ including the main GenerationMixin class that integrates with HuggingFace.
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Union, Dict, Any
+ from transformers import GenerationMixin as HFGenerationMixin
+ from transformers.generation import GenerationConfig
+
+ from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering
+
+
+ class RND1GenerationMixin(HFGenerationMixin):
+     """
+     Generation mixin for RND1 models.
+
+     This mixin provides generation methods compatible with HuggingFace's
+     generation API while using RND1's diffusion-based sampling internally.
+     """
+
+     def generate(
+         self,
+         inputs: Optional[torch.LongTensor] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         # RND1-specific parameters
+         prefix_ids: Optional[torch.LongTensor] = None,
+         suffix_ids: Optional[torch.LongTensor] = None,
+         infill_length: Optional[int] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         **kwargs,  # Accept all kwargs for compatibility with pipelines
+     ) -> Union[torch.LongTensor, Dict[str, Any]]:
+         """
+         Generate text using RND1's diffusion-based sampling.
+
+         Follows HuggingFace's standard generate API, using diffusion sampling
+         internally. Supports both standard generation and infilling.
+
+         Args:
+             inputs: Input token IDs to use as the prefix (standard HF parameter)
+             generation_config: Generation configuration object
+             prefix_ids: Alternative to inputs, used for infilling tasks
+             suffix_ids: Optional suffix for infilling tasks
+             infill_length: Length of the infill region (for infilling)
+             return_dict_in_generate: Whether to return a GenerateDecoderOnlyOutput
+             **kwargs: Additional arguments (accepted for compatibility)
+
+         Returns:
+             Generated token IDs, or a GenerateDecoderOnlyOutput
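+
+         Example (illustrative sketch only; the checkpoint id below is a
+         placeholder, and loading via trust_remote_code is an assumption):
+             >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+             >>> tok = AutoTokenizer.from_pretrained("radicalnumerics/RND1-Base")
+             >>> model = AutoModelForCausalLM.from_pretrained(
+             ...     "radicalnumerics/RND1-Base", trust_remote_code=True
+             ... )
+             >>> ids = tok("The capital of France is", return_tensors="pt").input_ids
+             >>> out = model.generate(inputs=ids, max_new_tokens=64)
+             >>> print(tok.decode(out[0], skip_special_tokens=True))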
+         """
+         if generation_config is not None:
+             gen_config = generation_config
+             model_kwargs = kwargs.copy()
+         else:
+             # Only prepare a config from kwargs if no config was provided
+             gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)
+
+         device = next(self.parameters()).device
+
+         if inputs is not None:
+             prefix_ids = inputs.to(device)
+         elif prefix_ids is not None:
+             prefix_ids = prefix_ids.to(device)
+
+         if suffix_ids is not None:
+             suffix_ids = suffix_ids.to(device)
+
+         # Explicit None checks rather than `or`, so a legitimate token id of 0
+         # is not silently replaced by the fallback.
+         eos_token_id = (gen_config.eos_token_id
+                         if gen_config.eos_token_id is not None
+                         else getattr(self.config, "eos_token_id", 151645))
+         pad_token_id = (gen_config.pad_token_id
+                         if gen_config.pad_token_id is not None
+                         else getattr(self.config, "pad_token_id", None))
+         bos_token_id = (gen_config.bos_token_id
+                         if gen_config.bos_token_id is not None
+                         else getattr(self.config, "bos_token_id", None))
+         mask_token_id = getattr(gen_config, "mask_token_id",
+                                 getattr(self.config, "mask_token_id", 151669))
+
+         if infill_length is not None and prefix_ids is not None:
+             # Infilling mode: use the specified infill_length
+             prefix_len = prefix_ids.shape[1]
+             suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
+             seq_len = prefix_len + infill_length + suffix_len
+         else:
+             # Standard generation mode
+             if prefix_ids is not None:
+                 prefix_len = prefix_ids.shape[1]
+                 if gen_config.max_new_tokens is not None:
+                     seq_len = prefix_len + gen_config.max_new_tokens
+                 else:
+                     seq_len = gen_config.max_length or self.config.max_position_embeddings
+             else:
+                 seq_len = gen_config.max_length or self.config.max_position_embeddings
+
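+         # Illustrative check of the bookkeeping above (made-up shapes): a
+         # 10-token prefix, a 16-token infill region, and a 5-token suffix
+         # give seq_len = 10 + 16 + 5 = 31 positions for the sampler to fill.
+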
+         num_diffusion_steps = getattr(gen_config, "num_diffusion_steps",
+                                       getattr(self.config, "num_diffusion_steps", 256))
+
+         temperature = float(getattr(gen_config, "temperature", 1.0))
+         top_k = getattr(gen_config, "top_k", None)
+         top_p = getattr(gen_config, "top_p", None)
+
+         greedy = getattr(gen_config, "greedy",
+                          not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
+
+         generator = model_kwargs.get("generator", None)
+         if generator is None:
+             seed = getattr(gen_config, "seed", None)
+             if seed is not None:
+                 generator = torch.Generator(device=device)
+                 generator.manual_seed(seed)
+
+         with torch.inference_mode():
+             sequences = diffusion_sample(
+                 model=self,
+                 seq_len=seq_len,
+                 num_steps=num_diffusion_steps,
+                 mask_token_id=mask_token_id,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 greedy=greedy,
+                 prefix_ids=prefix_ids,
+                 suffix_ids=suffix_ids,
+                 infill_length=infill_length,
+                 eos_token_id=eos_token_id,
+                 pad_token_id=pad_token_id,
+                 bos_token_id=bos_token_id,
+                 device=device,
+                 generator=generator,
+                 visualizer=model_kwargs.get("visualizer", None),  # Optional visualizer from kwargs
+             )
+
+         if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
+             from transformers.generation.utils import GenerateDecoderOnlyOutput
+             return GenerateDecoderOnlyOutput(sequences=sequences)
+
+         return sequences
+
+     def generate_with_visualization(
+         self,
+         tokenizer,
+         prefix_ids: Optional[torch.LongTensor] = None,
+         suffix_ids: Optional[torch.LongTensor] = None,
+         infill_length: Optional[int] = None,
+         seq_len: int = 256,
+         num_steps: int = 256,
+         mask_token_id: int = 151669,
+         temperature: float = 1.0,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         greedy: bool = True,
+         eos_token_id: int = 151645,
+         pad_token_id: Optional[int] = None,
+         bos_token_id: Optional[int] = None,
+         generator: Optional[torch.Generator] = None,
+     ) -> torch.LongTensor:
+         """
+         Generate with live visualization (for demos).
+
+         This method requires a tokenizer to display the generation process.
+         For production use, prefer `generate()`.
+
+         Args:
+             tokenizer: Tokenizer for decoding tokens to text
+             prefix_ids: Optional prefix token IDs
+             suffix_ids: Optional suffix token IDs
+             infill_length: Length of the infill region
+             seq_len: Target sequence length
+             num_steps: Number of diffusion steps
+             mask_token_id: Mask token ID
+             temperature: Sampling temperature
+             top_k: Top-k filtering
+             top_p: Top-p filtering
+             greedy: Whether to use greedy sampling
+             eos_token_id: End-of-sequence token ID
+             pad_token_id: Padding token ID
+             bos_token_id: Beginning-of-sequence token ID
+             generator: Random generator for reproducibility
+
+         Returns:
+             Generated token IDs as a LongTensor
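+
+         Example (illustrative; assumes `model` and `tokenizer` are already
+         loaded, e.g. as in the `generate()` example above):
+             >>> ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
+             >>> out = model.generate_with_visualization(
+             ...     tokenizer, prefix_ids=ids, seq_len=128, num_steps=64
+             ... )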
+         """
+         from .terminal_visualizer import TerminalVisualizer
+         visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
+
+         max_new_tokens = None
+         if seq_len is not None and prefix_ids is not None:
+             max_new_tokens = seq_len - prefix_ids.shape[1]
+
+         from .generation_config import RND1GenerationConfig
+         gen_config = RND1GenerationConfig(
+             max_length=seq_len,
+             max_new_tokens=max_new_tokens,
+             num_diffusion_steps=num_steps,
+             mask_token_id=mask_token_id,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             greedy=greedy,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             pad_token_id=pad_token_id,
+         )
+
+         return self.generate(
+             inputs=prefix_ids,
+             suffix_ids=suffix_ids,
+             infill_length=infill_length,
+             generation_config=gen_config,
+             generator=generator,
+             visualizer=visualizer,
+             return_dict_in_generate=False,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """
+         Prepare inputs for generation (required by HuggingFace).
+
+         RND1 does not use standard autoregressive decoding, so this simply
+         returns the input_ids.
+         """
+         return {"input_ids": input_ids}
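
Since `generate()` picks up an optional `generator` from its extra kwargs (and will build one itself when the generation config carries a `seed` attribute), sampled runs can be made reproducible. A minimal sketch, assuming `model` and `tokenizer` from the docstring examples above; the prompt and seed values are arbitrary:

    ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
    g = torch.Generator(device=model.device).manual_seed(1234)
    a = model.generate(inputs=ids, max_new_tokens=32, do_sample=True, generator=g)
    g = torch.Generator(device=model.device).manual_seed(1234)
    b = model.generate(inputs=ids, max_new_tokens=32, do_sample=True, generator=g)
    assert torch.equal(a, b)  # identical seeds should replay the same trajectory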