# dream_app.py
import torch
import gradio as gr
import spaces
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoConfig
import time

# Determine device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# --- Model and Tokenizer Loading ---
model_path = "Dream-org/Dream-v0-Instruct-7B"
print(f"Loading tokenizer from {model_path}...")

# Load the configuration first to get the special token IDs
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

print(f"Loading model from {model_path}...")
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(device).eval()
print("Model loaded successfully.")

# --- Constants from the Dream Model ---
# Prefer the IDs from the config; fall back to the tokenizer.
MASK_TOKEN = tokenizer.mask_token
MASK_ID = config.mask_token_id if hasattr(config, 'mask_token_id') else tokenizer.mask_token_id
EOS_ID = config.eos_token_id if hasattr(config, 'eos_token_id') else tokenizer.eos_token_id
PAD_ID = config.pad_token_id if hasattr(config, 'pad_token_id') else tokenizer.pad_token_id  # Often the same as EOS

print(f"MASK_TOKEN: '{MASK_TOKEN}', MASK_ID: {MASK_ID}")
print(f"EOS_ID: {EOS_ID}, PAD_ID: {PAD_ID}")

if MASK_ID is None:
    raise ValueError("Could not determine MASK_ID from model config or tokenizer.")
if EOS_ID is None:
    raise ValueError("Could not determine EOS_ID from model config or tokenizer.")
if PAD_ID is None:
    raise ValueError("Could not determine PAD_ID from model config or tokenizer.")

# --- Helper Functions ---

def parse_constraints(constraints_text, tokenizer):
    """Parse constraints in the format: 'position:word, position:word, ...'"""
    constraints = {}
    processed_constraints_tokens = {}
    if not constraints_text:
        return constraints, processed_constraints_tokens

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                # Store the original word constraint for display/debugging if needed
                constraints[pos] = word
                # Tokenize the word (prepend a space for consistency unless it starts the response).
                # Note: the Dream tokenizer may handle spaces differently; adjust if needed.
                prefix = " " if pos > 0 else ""
                tokens = tokenizer.encode(prefix + word, add_special_tokens=False)
                for i, token_id in enumerate(tokens):
                    # Avoid partially overwriting earlier multi-token constraints
                    if pos + i not in processed_constraints_tokens:
                        processed_constraints_tokens[pos + i] = token_id
        except ValueError:
            continue
        except Exception as e:
            print(f"Error tokenizing constraint word '{word}': {e}")
            continue

    # Sort by position for consistent application
    processed_constraints_tokens = dict(sorted(processed_constraints_tokens.items()))
    print(f"Parsed Constraints (Word): {constraints}")
    print(f"Parsed Constraints (Tokens): {processed_constraints_tokens}")
    return constraints, processed_constraints_tokens
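
# Illustrative example (the exact token IDs depend on the Dream tokenizer):
#   parse_constraints("0:Once, 5:upon", tokenizer)
#   -> ({0: "Once", 5: "upon"}, {0: <id of "Once">, 5: <id of " upon">})
# A word that tokenizes into several pieces occupies consecutive positions,
# starting at the position given for it.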

def format_chat_history(history):
    """
    Format the chat history for the Dream model using its chat template convention.

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        A list of message dictionaries formatted for the model
    """
    messages = []
    # Add a system prompt if one is not already present (standard practice)
    if not history or history[0][0] is None or history[0][0].lower() != "system":
        messages.append({"role": "system", "content": "You are a helpful assistant."})

    for user_msg, assistant_msg in history:
        if user_msg is not None:  # Handle the potential system-message case
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # Skip if None (i.e. the latest user message has no reply yet)
            messages.append({"role": "assistant", "content": assistant_msg})

    return messages
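
# Illustrative example of the structure format_chat_history produces:
#   format_chat_history([["Hi there", None]])
#   -> [{"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hi there"}]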

# --- Core Generation Logic with Visualization Hook ---

@spaces.GPU
def generate_response_with_visualization(
    messages,                 # List of message dictionaries
    gen_length=64,
    steps=64,
    constraints_text="",      # Raw constraint text
    temperature=0.2,
    top_p=0.95,
    top_k=None,               # Added for Dream
    alg="entropy",            # Replaces the 'remasking' option used by earlier demos
    alg_temp=0.0,             # Added for Dream
    visualization_delay=0.05,
    tokenizer=tokenizer,
    model=model,
    device=device,
    MASK_ID=MASK_ID,
    EOS_ID=EOS_ID,
    PAD_ID=PAD_ID
):
    """
    Generate text with the Dream model, capturing per-step visualization states via a hook.
    """
    visualization_states = []
    final_text = ""

    # Use a mutable dict so the hook (a closure) can update previous_x between calls.
    # It starts as None and is set on the first hook invocation.
    shared_state = {'previous_x': None}

    try:
        # --- 1. Prepare Inputs ---
        _, parsed_constraints_tokens = parse_constraints(constraints_text, tokenizer)

        # Apply the chat template.
        # Important: keep tokenize=False here so the prompt length can be measured correctly.
        # The template adds roles and special tokens such as <|im_start|>.
        chat_input_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,  # Adds the prompt for the assistant's turn
            tokenize=False
        )

        # Tokenize the full templated chat string
        inputs = tokenizer(chat_input_text, return_tensors="pt", return_dict=True)
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)  # Use the mask from the tokenizer

        prompt_length = input_ids.shape[1]
        total_length = prompt_length + gen_length

        # --- 2. Initialize the Generation Sequence ---
        # Start with the prompt and fill the rest with MASK_ID
        x = torch.full((1, total_length), MASK_ID, dtype=torch.long, device=device)
        x[:, :prompt_length] = input_ids.clone()
        attention_mask = F.pad(attention_mask, (0, gen_length), value=1)  # Extend the attention mask

        # Apply initial constraints to the masked sequence `x`
        for pos, token_id in parsed_constraints_tokens.items():
            absolute_pos = prompt_length + pos
            if absolute_pos < total_length:
                print(f"Applying initial constraint at pos {absolute_pos}: token {token_id}")
                x[:, absolute_pos] = token_id

        # Store the initial state (prompt + all masked) for visualization
        initial_state_vis = []
        # Prompt tokens could also be visualized (e.g. grayed out), but are skipped here:
        # for i in range(prompt_length):
        #     token_str = tokenizer.decode([x[0, i].item()], skip_special_tokens=True)
        #     initial_state_vis.append((token_str if token_str else " ", "#AAAAAA"))  # Gray for prompt
        # Add masked tokens for the generation part
        for _ in range(gen_length):
            initial_state_vis.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
        visualization_states.append(initial_state_vis)

        shared_state['previous_x'] = x.clone()  # Initialize previous_x
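
        # The hook below is assumed to be called by diffusion_generate after every
        # denoising step as hook(step, x, logits), and it must return the (possibly
        # modified) sequence. Each visualization state it appends is a list of
        # (token_str, color) pairs, one per generated position, matching the
        # (text, label) tuple format gr.HighlightedText accepts, e.g.
        #   [("Once", "#66CC66"), (" upon", "#6699CC"), ("[MASK]", "#444444"), ...]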
        # --- 3. Define the Visualization Hook ---
        def generation_tokens_hook_func(step, current_x_hook, logits):
            current_x_hook = current_x_hook.clone()  # Work on a copy

            # --- Apply constraints within the hook ---
            # This ensures constraints are respected even if the model tries to overwrite them
            for pos, token_id in parsed_constraints_tokens.items():
                absolute_pos = prompt_length + pos
                if absolute_pos < total_length:
                    current_x_hook[:, absolute_pos] = token_id
            # --- End constraint application ---

            if shared_state['previous_x'] is None:  # First call
                shared_state['previous_x'] = current_x_hook.clone()
                return current_x_hook  # Must return the (potentially modified) sequence

            # Build the visualization state for this step
            current_state_vis = []
            prev_x_step = shared_state['previous_x']

            for i in range(gen_length):
                pos = prompt_length + i  # Absolute position in the sequence
                current_token_id = current_x_hook[0, pos].item()
                prev_token_id = prev_x_step[0, pos].item()

                # Decode the token, hiding the special tokens we do not want to display
                token_str = ""
                color = "#444444"  # Default: dark gray (mask)
                # Keep special tokens here so the raw string can serve as a display fallback
                token_str_raw = tokenizer.decode([current_token_id], skip_special_tokens=False)

                if current_token_id == MASK_ID:
                    token_str = MASK_TOKEN
                    color = "#444444"  # Dark gray
                elif current_token_id == EOS_ID or current_token_id == PAD_ID:
                    token_str = ""     # Hide EOS/PAD visually
                    color = "#DDDDDD"  # Light gray, close to invisible
                else:
                    # Decode without special tokens for display if it is not MASK/EOS/PAD
                    token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
                    if not token_str:
                        token_str = token_str_raw  # Fallback if stripping removes everything (e.g. a space)

                    if prev_token_id == MASK_ID:
                        # Newly revealed in this step
                        color = "#66CC66"  # Light green (simplified from per-token confidence levels)
                    else:
                        # Previously revealed
                        color = "#6699CC"  # Light blue

                current_state_vis.append((token_str if token_str else " ", color))  # Keep the tuple element non-empty

            visualization_states.append(current_state_vis)
            shared_state['previous_x'] = current_x_hook.clone()  # Update previous_x for the next step
            return current_x_hook  # Return the sequence (with constraints applied)

        # --- 4. Run Diffusion Generation ---
        print("Starting diffusion generation...")
        start_time = time.time()

        output = model.diffusion_generate(
            input_ids=x[:, :prompt_length],  # Pass only the initial prompt; diffusion_generate
                                             # handles the masking internally based on MASK_ID
            attention_mask=attention_mask,   # Provide the full attention mask
            max_new_tokens=gen_length,
            output_history=False,            # History is captured via the hook instead
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            alg=alg,
            alg_temp=alg_temp if alg != 'origin' else None,  # alg_temp only applies to confidence-based algorithms
            generation_tokens_hook_func=generation_tokens_hook_func,
            # If the internal logic needs the pre-masked sequence, passing `x` directly may be required.
            # Check Dream's generation_utils if issues arise; for now assume input_ids + max_new_tokens suffice.
        )

        end_time = time.time()
        print(f"Diffusion generation finished in {end_time - start_time:.2f} seconds.")

        # --- 5. Process the Final Output ---
        # The hook has already built visualization_states
        final_sequence = output.sequences[0]

        # Decode the generated part, skipping special tokens for the final text output
        response_tokens = final_sequence[prompt_length:]
        # Filter out PAD tokens before the final decode; EOS and other specials are dropped by the decoder
        response_tokens_cleaned = [tok for tok in response_tokens.tolist() if tok != PAD_ID]

        final_text = tokenizer.decode(
            response_tokens_cleaned,
            skip_special_tokens=True,          # Skip EOS, BOS, etc.
            clean_up_tokenization_spaces=True  # Recommended for cleaner output
        ).strip()

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Add the error message to the visualization and show it in the chatbot too
        error_msg = f"Error: {str(e)}"
        visualization_states.append([(error_msg, "red")])
        final_text = error_msg

    # Make sure at least one state is present
    if not visualization_states:
        visualization_states.append([("Error: No states generated.", "red")])

    return visualization_states, final_text
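
# Illustrative direct call (outside Gradio), assuming the model and tokenizer above
# loaded successfully; handy for debugging the generation loop in isolation:
#   states, text = generate_response_with_visualization(
#       [{"role": "user", "content": "Tell me a short story."}],
#       gen_length=64, steps=64, constraints_text="0:Once",
#   )
#   print(text)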

# --- Gradio UI Definition ---
css = '''
.category-legend{display:none}
button{height: 60px}
.token-text { white-space: pre; } /* Preserve spaces in tokens */
footer { display: none !important; visibility: hidden !important; }
'''

def create_chatbot_demo():
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown(
            "[[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)] "
            "[[Blog Post](https://hkunlp.github.io/blog/2025/dream/)] "
            "(Note: the visualization shows token-reveal steps; colors indicate status: "
            "Gray = Masked, Green = Newly Revealed, Blue = Previously Revealed.)"
        )

        # STATE MANAGEMENT
        chat_history = gr.State([])
        # Store constraints parsed into token IDs
        parsed_constraints_state = gr.State({})

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    bubble_full_width=False  # Make bubbles wrap their content
                )

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                            scale=7
                        )
                        send_btn = gr.Button("Send", scale=1)

                constraints_input = gr.Textbox(
                    label="Word Constraints (Experimental)",
                    info="Place specific words at 0-indexed positions. Format: 'pos:word, pos:word'. "
                         "Example: '0:Once, 5:upon, 10:time'. Multi-token words are supported.",
                    placeholder="0:The, 10:story",
                    value=""
                )

            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Denoising Process Visualization",
                    combine_adjacent=False,
                    show_legend=False,  # The legend is not very informative here
                    height=560,  # Roughly match the chatbot height plus the input box
                    elem_classes=["token-text"]  # Apply the custom class for styling
                )

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(
                    minimum=16, maximum=512, value=128, step=8,
                    label="Max New Tokens"
                )
                steps = gr.Slider(
                    minimum=8, maximum=512, value=128, step=4,
                    label="Denoising Steps"
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.2, step=0.05,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                    label="Top-P"
                )
                top_k = gr.Slider(
                    minimum=0, maximum=200, value=0, step=5,
                    label="Top-K (0 = disabled)"
                )
            with gr.Row():
                alg = gr.Radio(
                    choices=['origin', 'maskgit_plus', 'topk_margin', 'entropy'],
                    value='entropy',
                    label="Sampling Algorithm (`alg`)"
                )
                alg_temp = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.0, step=0.05,
                    label="Algorithm Temp (`alg_temp`, adds randomness to confidence-based `alg`)"
                )
            with gr.Row():
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=0.5, value=0.02, step=0.01,
                    label="Visualization Delay (seconds)"
                )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        # --- Event Handlers ---

        def add_message(history, message, response):
            """Append a [message, response] pair to the history and return it."""
            if not isinstance(history, list):
                history = []
            history.append([message, response])
            return history

        def user_message_submitted(message, history):
            """Process a submitted user message."""
            if not message.strip():
                return history, history, "", []  # No change if the message is empty
            # Add the user message (the response is None for now)
            history = add_message(history, message, None)
            # Return the updated history for display and clear the input box
            return history, history, "", []  # history, chatbot_ui, user_input, output_vis
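
        # The handler below is a generator: each `yield history, state` pair is
        # streamed by Gradio to the two outputs wired in the .then() calls further
        # down (chatbot_ui and output_vis), producing the step-by-step animation.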
        def bot_response_stream(
            history,                                    # Current chat history (list of [user, assistant] pairs)
            gen_length, steps, constraints,             # Generation settings
            temperature, top_p, top_k, alg, alg_temp,   # Sampling settings
            delay                                       # Visualization delay
        ):
            """Generate the bot response and stream visualization states."""
            # Only generate if there is a new user message whose response is still unset
            if not history or history[-1][1] is not None:
                print("Skipping bot response generation: No new user message.")
                # Yield an empty state so the outputs stay consistent downstream
                yield history, []
                return

            # Format messages for the model. The last history entry has no assistant
            # reply yet; format_chat_history keeps its user turn and skips the None reply.
            messages_for_model = format_chat_history(history)  # Already includes the system prompt

            print("\n--- Generating Bot Response ---")
            print(f"History: {history}")
            print(f"Messages for model: {messages_for_model}")
            print(f"Constraints text: '{constraints}'")
            print(f"Gen length: {gen_length}, Steps: {steps}, Temp: {temperature}, "
                  f"Top-P: {top_p}, Top-K: {top_k}, Alg: {alg}, Alg Temp: {alg_temp}")

            # Call the generation function
            vis_states, response_text = generate_response_with_visualization(
                messages_for_model,
                gen_length=gen_length,
                steps=steps,
                constraints_text=constraints,
                temperature=temperature,
                top_p=top_p if top_p < 1.0 else None,  # None disables top-p
                top_k=top_k if top_k > 0 else None,    # None disables top-k
                alg=alg,
                alg_temp=alg_temp,
                visualization_delay=delay,
            )
            print(f"Generated response text: '{response_text}'")
            print(f"Number of visualization states: {len(vis_states)}")

            # Update the history with the final response
            if history:
                history[-1][1] = response_text
            else:
                print("Warning: History was empty when trying to update response.")

            # Stream the visualization states
            if not vis_states:
                print("Warning: No visualization states were generated.")
                yield history, [("Error: No visualization.", "red")]
                return

            for state in vis_states:
                yield history, state  # Updated history (with final text) and the current vis state
                time.sleep(delay)     # Pause between steps

            # Final yield to make sure the last state stays displayed
            yield history, vis_states[-1]

        def clear_conversation():
            """Clear the conversation history and visualization."""
            return [], [], "", []  # history, chatbot, user_input, output_vis

        # --- Event Wiring ---

        # Clear button
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )

        # User message submission flow (two steps, chained with .then)
        # 1. The user submits a message -> update the history and chatbot UI immediately
        submit_action = user_input.submit(
            fn=user_message_submitted,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]  # Update chatbot, clear input
        )
        # Connect the send button to the same function
        send_action = send_btn.click(
            fn=user_message_submitted,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )

        # 2. After the UI update -> trigger bot response generation and streaming,
        #    using the chat_history updated in the first step
        submit_action.then(
            fn=bot_response_stream,
            inputs=[
                chat_history, gen_length, steps, constraints_input,
                temperature, top_p, top_k, alg, alg_temp,
                visualization_delay
            ],
            outputs=[chatbot_ui, output_vis]  # Each yield updates the chatbot and the visualization
        )
        send_action.then(
            fn=bot_response_stream,
            inputs=[
                chat_history, gen_length, steps, constraints_input,
                temperature, top_p, top_k, alg, alg_temp,
                visualization_delay
            ],
            outputs=[chatbot_ui, output_vis]  # Update chatbot and visualization
        )

        # Clearing the input after send/submit is already handled in user_message_submitted

    return demo


# --- Launch the Gradio App ---
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # Use the queue for streaming and for handling multiple users
    demo.queue().launch(debug=True, share=True)