Upload terminal_visualizer.py with huggingface_hub
Browse files- terminal_visualizer.py +246 -0
terminal_visualizer.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Terminal visualization for RND1 generation.
|
3 |
+
|
4 |
+
This module provides real-time visualization of the diffusion denoising process,
|
5 |
+
showing token evolution and generation progress in the terminal using rich
|
6 |
+
formatting when available.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from typing import Optional
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
try:
|
14 |
+
from rich.console import Console
|
15 |
+
from rich.live import Live
|
16 |
+
from rich.text import Text
|
17 |
+
from rich.panel import Panel
|
18 |
+
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
19 |
+
from rich.layout import Layout
|
20 |
+
RICH_AVAILABLE = True
|
21 |
+
except ImportError:
|
22 |
+
RICH_AVAILABLE = False
|
23 |
+
|
24 |
+
|
25 |
+
class TerminalVisualizer:
|
26 |
+
"""
|
27 |
+
Rich-based visualization for diffusion process with live updates.
|
28 |
+
|
29 |
+
Provides real-time visualization of the token denoising process during
|
30 |
+
diffusion-based language generation, with colored highlighting of masked
|
31 |
+
positions and progress tracking.
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, tokenizer, show_visualization: bool = True):
|
35 |
+
"""
|
36 |
+
Initialize the terminal visualizer.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
tokenizer: The tokenizer for decoding tokens to text
|
40 |
+
show_visualization: Whether to show visualization (requires rich)
|
41 |
+
"""
|
42 |
+
self.tokenizer = tokenizer
|
43 |
+
self.show_visualization = show_visualization and RICH_AVAILABLE
|
44 |
+
if not RICH_AVAILABLE and show_visualization:
|
45 |
+
print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
|
46 |
+
self.show_visualization = False
|
47 |
+
|
48 |
+
if self.show_visualization:
|
49 |
+
self.console = Console()
|
50 |
+
self.live = None
|
51 |
+
self.progress = None
|
52 |
+
self.layout = None
|
53 |
+
else:
|
54 |
+
self.pbar = None
|
55 |
+
|
56 |
+
self.current_tokens = None
|
57 |
+
self.mask_positions = None
|
58 |
+
self.total_steps = 0
|
59 |
+
self.current_step = 0
|
60 |
+
|
61 |
+
def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
|
62 |
+
"""
|
63 |
+
Start the visualization.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
initial_tokens: Initial token IDs (possibly masked)
|
67 |
+
mask_positions: Boolean mask indicating which positions are masked
|
68 |
+
total_steps: Total number of diffusion steps
|
69 |
+
"""
|
70 |
+
if not self.show_visualization:
|
71 |
+
self.pbar = tqdm(total=total_steps, desc="Diffusion")
|
72 |
+
return
|
73 |
+
|
74 |
+
self.current_tokens = initial_tokens.clone()
|
75 |
+
self.mask_positions = mask_positions
|
76 |
+
self.total_steps = total_steps
|
77 |
+
self.current_step = 0
|
78 |
+
|
79 |
+
self.layout = Layout()
|
80 |
+
self.layout.split_column(
|
81 |
+
Layout(name="header", size=3),
|
82 |
+
Layout(name="text", ratio=1),
|
83 |
+
Layout(name="progress", size=3)
|
84 |
+
)
|
85 |
+
|
86 |
+
self.progress = Progress(
|
87 |
+
TextColumn("[bold blue]Diffusion"),
|
88 |
+
BarColumn(),
|
89 |
+
MofNCompleteColumn(),
|
90 |
+
TextColumn("•"),
|
91 |
+
TextColumn("[cyan]Masks: {task.fields[masks]}"),
|
92 |
+
TimeRemainingColumn(),
|
93 |
+
)
|
94 |
+
self.progress_task = self.progress.add_task(
|
95 |
+
"Generating",
|
96 |
+
total=total_steps,
|
97 |
+
masks=mask_positions.sum().item()
|
98 |
+
)
|
99 |
+
|
100 |
+
self.live = Live(self.layout, console=self.console, refresh_per_second=4)
|
101 |
+
self.live.start()
|
102 |
+
self._update_display()
|
103 |
+
|
104 |
+
def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
|
105 |
+
entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
|
106 |
+
"""
|
107 |
+
Update visualization for current step.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
tokens: Current token IDs
|
111 |
+
maskable: Boolean mask of remaining masked positions
|
112 |
+
step: Current step number
|
113 |
+
entropy: Optional entropy scores for each position
|
114 |
+
confidence: Optional confidence scores for each position
|
115 |
+
"""
|
116 |
+
if not self.show_visualization:
|
117 |
+
if self.pbar:
|
118 |
+
self.pbar.update(1)
|
119 |
+
masks = maskable.sum().item() if maskable is not None else 0
|
120 |
+
self.pbar.set_postfix({'masks': masks})
|
121 |
+
return
|
122 |
+
|
123 |
+
self.current_tokens = tokens.clone()
|
124 |
+
self.mask_positions = maskable
|
125 |
+
self.current_step = step
|
126 |
+
|
127 |
+
masks_remaining = maskable.sum().item() if maskable is not None else 0
|
128 |
+
self.progress.update(
|
129 |
+
self.progress_task,
|
130 |
+
advance=1,
|
131 |
+
masks=masks_remaining
|
132 |
+
)
|
133 |
+
|
134 |
+
self._update_display()
|
135 |
+
|
136 |
+
def _update_display(self):
|
137 |
+
"""Update the live display."""
|
138 |
+
if not self.live:
|
139 |
+
return
|
140 |
+
|
141 |
+
header = Text("🎭 RND1-Base Generation", style="bold magenta", justify="center")
|
142 |
+
self.layout["header"].update(Panel(header, border_style="bright_blue"))
|
143 |
+
|
144 |
+
text_display = self._format_text_with_masks()
|
145 |
+
self.layout["text"].update(
|
146 |
+
Panel(
|
147 |
+
text_display,
|
148 |
+
title="[bold]Generated Text",
|
149 |
+
subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
|
150 |
+
border_style="cyan"
|
151 |
+
)
|
152 |
+
)
|
153 |
+
|
154 |
+
self.layout["progress"].update(Panel(self.progress))
|
155 |
+
|
156 |
+
def _format_text_with_masks(self) -> Text:
|
157 |
+
"""
|
158 |
+
Format text with colored masks.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
Rich Text object with formatted tokens
|
162 |
+
"""
|
163 |
+
text = Text()
|
164 |
+
|
165 |
+
if self.current_tokens is None:
|
166 |
+
return text
|
167 |
+
|
168 |
+
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
|
169 |
+
mask_flags = self.mask_positions[0] if self.mask_positions is not None and self.mask_positions.dim() > 1 else self.mask_positions
|
170 |
+
|
171 |
+
for i, token_id in enumerate(token_ids):
|
172 |
+
if mask_flags is not None and i < len(mask_flags) and mask_flags[i]:
|
173 |
+
# Alternate colors for visual effect
|
174 |
+
text.append("[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red")
|
175 |
+
else:
|
176 |
+
try:
|
177 |
+
token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
|
178 |
+
# Skip special tokens in display
|
179 |
+
if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"]:
|
180 |
+
# Color based on position
|
181 |
+
text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan")
|
182 |
+
except:
|
183 |
+
continue
|
184 |
+
|
185 |
+
return text
|
186 |
+
|
187 |
+
def stop_visualization(self):
|
188 |
+
"""Stop the visualization and display final result."""
|
189 |
+
if not self.show_visualization:
|
190 |
+
if self.pbar:
|
191 |
+
self.pbar.close()
|
192 |
+
print("\n✨ Generation complete!\n")
|
193 |
+
return
|
194 |
+
|
195 |
+
if self.live:
|
196 |
+
self.live.stop()
|
197 |
+
|
198 |
+
self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")
|
199 |
+
|
200 |
+
# Display final text
|
201 |
+
if self.current_tokens is not None:
|
202 |
+
try:
|
203 |
+
token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
|
204 |
+
final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
|
205 |
+
|
206 |
+
self.console.print(Panel(
|
207 |
+
final_text,
|
208 |
+
title="[bold]Final Generated Text",
|
209 |
+
border_style="green",
|
210 |
+
padding=(1, 2)
|
211 |
+
))
|
212 |
+
except:
|
213 |
+
pass
|
214 |
+
|
215 |
+
|
216 |
+
class SimpleProgressBar:
|
217 |
+
"""
|
218 |
+
Simple progress bar fallback when rich is not available.
|
219 |
+
|
220 |
+
Provides basic progress tracking using tqdm when the rich library
|
221 |
+
is not installed.
|
222 |
+
"""
|
223 |
+
|
224 |
+
def __init__(self, total_steps: int):
|
225 |
+
"""
|
226 |
+
Initialize simple progress bar.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
total_steps: Total number of steps
|
230 |
+
"""
|
231 |
+
self.pbar = tqdm(total=total_steps, desc="Diffusion")
|
232 |
+
|
233 |
+
def update(self, masks_remaining: int = 0):
|
234 |
+
"""
|
235 |
+
Update progress bar.
|
236 |
+
|
237 |
+
Args:
|
238 |
+
masks_remaining: Number of masks still remaining
|
239 |
+
"""
|
240 |
+
self.pbar.update(1)
|
241 |
+
self.pbar.set_postfix({'masks': masks_remaining})
|
242 |
+
|
243 |
+
def close(self):
|
244 |
+
"""Close the progress bar."""
|
245 |
+
self.pbar.close()
|
246 |
+
print("\n✨ Generation complete!\n")
|