athms committed on
Commit
47bef2e
·
verified ·
1 Parent(s): 32b8af1

Upload terminal_visualizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. terminal_visualizer.py +246 -0
terminal_visualizer.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Terminal visualization for RND1 generation.
3
+
4
+ This module provides real-time visualization of the diffusion denoising process,
5
+ showing token evolution and generation progress in the terminal using rich
6
+ formatting when available.
7
+ """
8
+
9
+ import torch
10
+ from typing import Optional
11
+ from tqdm import tqdm
12
+
13
+ try:
14
+ from rich.console import Console
15
+ from rich.live import Live
16
+ from rich.text import Text
17
+ from rich.panel import Panel
18
+ from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
19
+ from rich.layout import Layout
20
+ RICH_AVAILABLE = True
21
+ except ImportError:
22
+ RICH_AVAILABLE = False
23
+
24
+
25
+ class TerminalVisualizer:
26
+ """
27
+ Rich-based visualization for diffusion process with live updates.
28
+
29
+ Provides real-time visualization of the token denoising process during
30
+ diffusion-based language generation, with colored highlighting of masked
31
+ positions and progress tracking.
32
+ """
33
+
34
+ def __init__(self, tokenizer, show_visualization: bool = True):
35
+ """
36
+ Initialize the terminal visualizer.
37
+
38
+ Args:
39
+ tokenizer: The tokenizer for decoding tokens to text
40
+ show_visualization: Whether to show visualization (requires rich)
41
+ """
42
+ self.tokenizer = tokenizer
43
+ self.show_visualization = show_visualization and RICH_AVAILABLE
44
+ if not RICH_AVAILABLE and show_visualization:
45
+ print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
46
+ self.show_visualization = False
47
+
48
+ if self.show_visualization:
49
+ self.console = Console()
50
+ self.live = None
51
+ self.progress = None
52
+ self.layout = None
53
+ else:
54
+ self.pbar = None
55
+
56
+ self.current_tokens = None
57
+ self.mask_positions = None
58
+ self.total_steps = 0
59
+ self.current_step = 0
60
+
61
+ def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
62
+ """
63
+ Start the visualization.
64
+
65
+ Args:
66
+ initial_tokens: Initial token IDs (possibly masked)
67
+ mask_positions: Boolean mask indicating which positions are masked
68
+ total_steps: Total number of diffusion steps
69
+ """
70
+ if not self.show_visualization:
71
+ self.pbar = tqdm(total=total_steps, desc="Diffusion")
72
+ return
73
+
74
+ self.current_tokens = initial_tokens.clone()
75
+ self.mask_positions = mask_positions
76
+ self.total_steps = total_steps
77
+ self.current_step = 0
78
+
79
+ self.layout = Layout()
80
+ self.layout.split_column(
81
+ Layout(name="header", size=3),
82
+ Layout(name="text", ratio=1),
83
+ Layout(name="progress", size=3)
84
+ )
85
+
86
+ self.progress = Progress(
87
+ TextColumn("[bold blue]Diffusion"),
88
+ BarColumn(),
89
+ MofNCompleteColumn(),
90
+ TextColumn("•"),
91
+ TextColumn("[cyan]Masks: {task.fields[masks]}"),
92
+ TimeRemainingColumn(),
93
+ )
94
+ self.progress_task = self.progress.add_task(
95
+ "Generating",
96
+ total=total_steps,
97
+ masks=mask_positions.sum().item()
98
+ )
99
+
100
+ self.live = Live(self.layout, console=self.console, refresh_per_second=4)
101
+ self.live.start()
102
+ self._update_display()
103
+
104
+ def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
105
+ entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
106
+ """
107
+ Update visualization for current step.
108
+
109
+ Args:
110
+ tokens: Current token IDs
111
+ maskable: Boolean mask of remaining masked positions
112
+ step: Current step number
113
+ entropy: Optional entropy scores for each position
114
+ confidence: Optional confidence scores for each position
115
+ """
116
+ if not self.show_visualization:
117
+ if self.pbar:
118
+ self.pbar.update(1)
119
+ masks = maskable.sum().item() if maskable is not None else 0
120
+ self.pbar.set_postfix({'masks': masks})
121
+ return
122
+
123
+ self.current_tokens = tokens.clone()
124
+ self.mask_positions = maskable
125
+ self.current_step = step
126
+
127
+ masks_remaining = maskable.sum().item() if maskable is not None else 0
128
+ self.progress.update(
129
+ self.progress_task,
130
+ advance=1,
131
+ masks=masks_remaining
132
+ )
133
+
134
+ self._update_display()
135
+
136
+ def _update_display(self):
137
+ """Update the live display."""
138
+ if not self.live:
139
+ return
140
+
141
+ header = Text("🎭 RND1-Base Generation", style="bold magenta", justify="center")
142
+ self.layout["header"].update(Panel(header, border_style="bright_blue"))
143
+
144
+ text_display = self._format_text_with_masks()
145
+ self.layout["text"].update(
146
+ Panel(
147
+ text_display,
148
+ title="[bold]Generated Text",
149
+ subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
150
+ border_style="cyan"
151
+ )
152
+ )
153
+
154
+ self.layout["progress"].update(Panel(self.progress))
155
+
156
+ def _format_text_with_masks(self) -> Text:
157
+ """
158
+ Format text with colored masks.
159
+
160
+ Returns:
161
+ Rich Text object with formatted tokens
162
+ """
163
+ text = Text()
164
+
165
+ if self.current_tokens is None:
166
+ return text
167
+
168
+ token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
169
+ mask_flags = self.mask_positions[0] if self.mask_positions is not None and self.mask_positions.dim() > 1 else self.mask_positions
170
+
171
+ for i, token_id in enumerate(token_ids):
172
+ if mask_flags is not None and i < len(mask_flags) and mask_flags[i]:
173
+ # Alternate colors for visual effect
174
+ text.append("[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red")
175
+ else:
176
+ try:
177
+ token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
178
+ # Skip special tokens in display
179
+ if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"]:
180
+ # Color based on position
181
+ text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan")
182
+ except:
183
+ continue
184
+
185
+ return text
186
+
187
+ def stop_visualization(self):
188
+ """Stop the visualization and display final result."""
189
+ if not self.show_visualization:
190
+ if self.pbar:
191
+ self.pbar.close()
192
+ print("\n✨ Generation complete!\n")
193
+ return
194
+
195
+ if self.live:
196
+ self.live.stop()
197
+
198
+ self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")
199
+
200
+ # Display final text
201
+ if self.current_tokens is not None:
202
+ try:
203
+ token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
204
+ final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
205
+
206
+ self.console.print(Panel(
207
+ final_text,
208
+ title="[bold]Final Generated Text",
209
+ border_style="green",
210
+ padding=(1, 2)
211
+ ))
212
+ except:
213
+ pass
214
+
215
+
216
class SimpleProgressBar:
    """
    Simple progress bar fallback when rich is not available.

    Provides basic progress tracking using tqdm when the rich library
    is not installed.
    """

    def __init__(self, total_steps: int):
        """
        Initialize simple progress bar.

        Args:
            total_steps: Total number of steps
        """
        self.pbar = tqdm(total=total_steps, desc="Diffusion")

    def update(self, masks_remaining: int = 0):
        """
        Advance the bar by one step and show the remaining mask count.

        Does nothing once the bar has been closed.

        Args:
            masks_remaining: Number of masks still remaining
        """
        if self.pbar is None:
            return
        self.pbar.update(1)
        self.pbar.set_postfix({'masks': masks_remaining})

    def close(self):
        """
        Close the progress bar and print the completion message.

        Safe to call more than once; only the first call closes the bar and
        prints.
        """
        if self.pbar is None:
            return
        self.pbar.close()
        # Drop the reference so repeated close() calls do not re-announce.
        self.pbar = None
        print("\n✨ Generation complete!\n")