athms committed on
Commit
139362b
·
verified ·
1 Parent(s): 47bef2e

Upload generation_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generation_utils.py +9 -43
generation_utils.py CHANGED
@@ -138,20 +138,12 @@ class RND1GenerationMixin(HFGenerationMixin):
138
  def generate_with_visualization(
139
  self,
140
  tokenizer,
141
- prefix_ids: Optional[torch.LongTensor] = None,
 
142
  suffix_ids: Optional[torch.LongTensor] = None,
143
  infill_length: Optional[int] = None,
144
- seq_len: int = 256,
145
- num_steps: int = 256,
146
- mask_token_id: int = 151669,
147
- temperature: float = 1.0,
148
- top_k: Optional[int] = None,
149
- top_p: Optional[float] = None,
150
- greedy: bool = True,
151
- eos_token_id: int = 151645,
152
- pad_token_id: Optional[int] = None,
153
- bos_token_id: Optional[int] = None,
154
  generator: Optional[torch.Generator] = None,
 
155
  ) -> torch.LongTensor:
156
  """
157
  Generate with live visualization (for demos).
@@ -161,20 +153,12 @@ class RND1GenerationMixin(HFGenerationMixin):
161
 
162
  Args:
163
  tokenizer: Tokenizer for decoding tokens to text
164
- prefix_ids: Optional prefix token IDs
 
165
  suffix_ids: Optional suffix token IDs
166
  infill_length: Length of infill region
167
- seq_len: Target sequence length
168
- num_steps: Number of diffusion steps
169
- mask_token_id: Mask token ID
170
- temperature: Sampling temperature
171
- top_k: Top-k filtering
172
- top_p: Top-p filtering
173
- greedy: Whether to use greedy sampling
174
- eos_token_id: End of sequence token ID
175
- pad_token_id: Padding token ID
176
- bos_token_id: Beginning of sequence token ID
177
  generator: Random generator for reproducibility
 
178
 
179
  Returns:
180
  Generated token IDs as LongTensor
@@ -182,33 +166,15 @@ class RND1GenerationMixin(HFGenerationMixin):
182
  from .terminal_visualizer import TerminalVisualizer
183
  visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
184
 
185
- max_new_tokens = None
186
- if seq_len is not None and prefix_ids is not None:
187
- max_new_tokens = seq_len - prefix_ids.shape[1]
188
-
189
- from .generation_config import RND1GenerationConfig
190
- gen_config = RND1GenerationConfig(
191
- max_length=seq_len,
192
- max_new_tokens=max_new_tokens,
193
- num_diffusion_steps=num_steps,
194
- mask_token_id=mask_token_id,
195
- temperature=temperature,
196
- top_k=top_k,
197
- top_p=top_p,
198
- greedy=greedy,
199
- bos_token_id=bos_token_id,
200
- eos_token_id=eos_token_id,
201
- pad_token_id=pad_token_id,
202
- )
203
-
204
  return self.generate(
205
- inputs=prefix_ids,
 
206
  suffix_ids=suffix_ids,
207
  infill_length=infill_length,
208
- generation_config=gen_config,
209
  generator=generator,
210
  visualizer=visualizer,
211
  return_dict_in_generate=False,
 
212
  )
213
 
214
  def prepare_inputs_for_generation(
 
138
  def generate_with_visualization(
139
  self,
140
  tokenizer,
141
+ inputs: Optional[torch.LongTensor] = None,
142
+ generation_config: Optional[GenerationConfig] = None,
143
  suffix_ids: Optional[torch.LongTensor] = None,
144
  infill_length: Optional[int] = None,
 
 
 
 
 
 
 
 
 
 
145
  generator: Optional[torch.Generator] = None,
146
+ **kwargs,
147
  ) -> torch.LongTensor:
148
  """
149
  Generate with live visualization (for demos).
 
153
 
154
  Args:
155
  tokenizer: Tokenizer for decoding tokens to text
156
+ inputs: Input token IDs to use as prefix
157
+ generation_config: Generation configuration object
158
  suffix_ids: Optional suffix token IDs
159
  infill_length: Length of infill region
 
 
 
 
 
 
 
 
 
 
160
  generator: Random generator for reproducibility
161
+ **kwargs: Additional arguments for backward compatibility
162
 
163
  Returns:
164
  Generated token IDs as LongTensor
 
166
  from .terminal_visualizer import TerminalVisualizer
167
  visualizer = TerminalVisualizer(tokenizer, show_visualization=True)
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  return self.generate(
170
+ inputs=inputs,
171
+ generation_config=generation_config,
172
  suffix_ids=suffix_ids,
173
  infill_length=infill_length,
 
174
  generator=generator,
175
  visualizer=visualizer,
176
  return_dict_in_generate=False,
177
+ **kwargs,
178
  )
179
 
180
  def prepare_inputs_for_generation(