import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, BitsAndBytesConfig
from threading import Thread
import time
import logging
import gc
from pathlib import Path
import re
from huggingface_hub import HfApi, list_models
import os
import queue
import threading
from collections import deque
# Set PyTorch memory management environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('gradio-chat-ui.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Log memory management settings
logger.info(f"PyTorch CUDA allocation config: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")
logger.info(f"CUDA device count: {torch.cuda.device_count() if torch.cuda.is_available() else 'N/A'}")
# Model parameters
MODEL_NAME = "No Model Loaded"
MAX_LENGTH = 16384
DEFAULT_TEMPERATURE = 0.15
DEFAULT_TOP_P = 0.93
DEFAULT_TOP_K = 50
DEFAULT_REP_PENALTY = 1.15

# Base location for local models
LOCAL_MODELS_BASE = "/home/llm-models/"

# Global variables
model = None
tokenizer = None
hf_api = HfApi()

# Generation metadata storage with automatic cleanup
generation_metadata = deque(maxlen=100)  # Fixed-size deque to prevent unbounded growth
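# Each entry appended by calculate_generation_metrics() below is a plain dict, e.g.
# {"generation_time": 12.3, "total_tokens": 256, "tokens_per_second": 20.8, "model_name": "..."}
# (illustrative values); once 100 entries exist, deque(maxlen=100) silently drops the
# oldest one on every append.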

class RAMSavingIteratorStreamer:
    """
    Custom streamer that saves VRAM by moving tokens to CPU and provides an iteration interface for Gradio.
    Combines the benefits of TextStreamer (RAM saving) with TextIteratorStreamer (iteration).
    """

    def __init__(self, tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=None):
        self.tokenizer = tokenizer
        self.skip_special_tokens = skip_special_tokens
        self.skip_prompt = skip_prompt
        self.timeout = timeout

        # Token and text storage (CPU-based)
        self.generated_tokens = []
        self.generated_text = ""
        self.token_cache = ""

        # Queue for the streaming interface
        self.text_queue = queue.Queue()
        self.stop_signal = threading.Event()

        # Track prompt tokens so they can be skipped
        self.prompt_length = 0
        self.tokens_processed = 0

        # Decoding state
        self.print_len = 0
    def put(self, value):
        """
        Receive new token(s) and process them for streaming.
        This method is called by the model during generation.
        """
        try:
            # Normalize the incoming value to a flat list of token ids
            if isinstance(value, torch.Tensor):
                if value.dim() > 1:
                    value = value[0]  # Remove batch dimension if present
                token_ids = value.tolist()
                if not isinstance(token_ids, list):
                    token_ids = [token_ids]
            else:
                token_ids = value if isinstance(value, list) else [value]

            # Track tokens processed
            self.tokens_processed += len(token_ids)

            # Skip prompt tokens if requested. generate() feeds the prompt to the
            # streamer first, so skipping *before* storing keeps the prompt out of
            # the decoded output.
            if self.skip_prompt and self.tokens_processed <= self.prompt_length:
                return

            # Store a CPU copy of the generated tokens to save VRAM
            if isinstance(value, torch.Tensor):
                self.generated_tokens.append(value.detach().cpu())
            else:
                self.generated_tokens.append(torch.tensor(token_ids, dtype=torch.long))

            # Decode incrementally for real-time streaming
            try:
                # Collect all generated tokens so far
                if self.generated_tokens:
                    all_tokens = []
                    for tokens in self.generated_tokens:
                        if isinstance(tokens, torch.Tensor):
                            if tokens.dim() == 0:
                                all_tokens.append(tokens.item())
                            else:
                                all_tokens.extend(tokens.tolist())
                        elif isinstance(tokens, list):
                            all_tokens.extend(tokens)
                        else:
                            all_tokens.append(tokens)

                    # Decode the full sequence
                    full_text = self.tokenizer.decode(
                        all_tokens,
                        skip_special_tokens=self.skip_special_tokens
                    )

                    # Get the new text since the last update
                    if len(full_text) > self.print_len:
                        new_text = full_text[self.print_len:]
                        self.print_len = len(full_text)
                        self.generated_text = full_text

                        # Put new text in the queue for iteration
                        if new_text:
                            self.text_queue.put(new_text)
            except Exception as decode_error:
                logger.warning(f"Decoding error in streamer: {decode_error}")
        except Exception as e:
            logger.error(f"Error in RAMSavingIteratorStreamer.put: {e}")
    def end(self):
        """Signal end of generation."""
        self.text_queue.put(None)  # Sentinel value

    def __iter__(self):
        """Make this streamer iterable for Gradio compatibility."""
        return self

    def __next__(self):
        """Get the next chunk of text for streaming."""
        try:
            value = self.text_queue.get(timeout=self.timeout)
            if value is None:  # End signal
                raise StopIteration
            return value
        except queue.Empty:
            raise StopIteration

    def set_prompt_length(self, prompt_length):
        """Set the number of prompt tokens to skip."""
        self.prompt_length = prompt_length

    def get_generated_text(self):
        """Get the complete generated text."""
        return self.generated_text

    def get_generated_tokens(self):
        """Get all generated tokens as a single tensor."""
        if not self.generated_tokens:
            return torch.tensor([])

        # Combine all tokens
        all_tokens = []
        for tokens in self.generated_tokens:
            if isinstance(tokens, torch.Tensor):
                if tokens.dim() == 0:
                    all_tokens.append(tokens.item())
                else:
                    all_tokens.extend(tokens.tolist())
            elif isinstance(tokens, list):
                all_tokens.extend(tokens)
            else:
                all_tokens.append(tokens)
        return torch.tensor(all_tokens, dtype=torch.long)

    def cleanup(self):
        """Clean up resources."""
        self.generated_tokens.clear()
        self.generated_text = ""
        self.token_cache = ""

        # Drain the queue
        while not self.text_queue.empty():
            try:
                self.text_queue.get_nowait()
            except queue.Empty:
                break
        self.stop_signal.set()
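
# Minimal usage sketch (illustration only; chat_with_model() below is the real call
# site): the streamer is handed to model.generate() while another thread iterates it.
#
#     streamer = RAMSavingIteratorStreamer(tokenizer, skip_prompt=True, timeout=1.0)
#     streamer.set_prompt_length(input_ids.shape[1])
#     Thread(target=model.generate,
#            kwargs={"input_ids": input_ids, "streamer": streamer,
#                    "max_new_tokens": 128}).start()
#     for chunk in streamer:          # yields decoded text pieces as they arrive
#         print(chunk, end="", flush=True)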

def scan_local_models(base_path=LOCAL_MODELS_BASE):
    """Scan for valid models in the local models directory."""
    try:
        base_path = Path(base_path)
        if not base_path.exists():
            logger.warning(f"Base path does not exist: {base_path}")
            return []

        valid_models = []
        # Scan subdirectories (depth 1 only)
        for item in base_path.iterdir():
            if item.is_dir():
                # Check if the directory contains the required model files
                config_file = item / "config.json"
                # Look for model weight files (safetensors or bin)
                safetensors_files = list(item.glob("*.safetensors"))
                bin_files = list(item.glob("*.bin"))
                # Check if it's a valid model directory
                if config_file.exists() and (safetensors_files or bin_files):
                    valid_models.append(str(item))
                    logger.info(f"Found valid model: {item}")

        # Sort models for consistent ordering
        valid_models.sort()
        logger.info(f"Found {len(valid_models)} valid models in {base_path}")
        return valid_models
    except Exception as e:
        logger.error(f"Error scanning local models: {e}")
        return []
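
# Example of a directory layout this scan accepts (paths are illustrative only):
#
#     /home/llm-models/
#     ├── Mistral-7B-Instruct/    <- picked up: has config.json + *.safetensors
#     │   ├── config.json
#     │   └── model-00001-of-00002.safetensors
#     └── notes/                  <- ignored: no config.json or weight files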

def update_local_models_dropdown(base_path):
    """Update the local models dropdown based on the base path."""
    if not base_path or not base_path.strip():
        return gr.Dropdown(choices=[], value=None, interactive=True)

    models = scan_local_models(base_path)
    model_choices = [Path(model).name for model in models]  # Show just the model name
    model_paths = models  # Keep full paths for internal use

    # Pair display names with full paths
    if model_choices:
        return gr.Dropdown(
            choices=list(zip(model_choices, model_paths)),
            value=model_paths[0] if model_paths else None,
            label="📂 Available Local Models",
            interactive=True,
            allow_custom_value=False,  # Don't allow custom values for local models
            filterable=True
        )
    else:
        return gr.Dropdown(
            choices=[],
            value=None,
            label="📂 Available Local Models (None found)",
            interactive=True,
            allow_custom_value=False,
            filterable=True
        )
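
# Note: gr.Dropdown accepts (label, value) tuples for `choices`, which is what the
# zip() above produces -- the user sees the directory name while the component's value
# stays the full path that load_model_with_progress() needs.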

def search_hf_models(query, limit=20):
    """Enhanced search for models on the Hugging Face Hub with better coverage."""
    if not query or len(query.strip()) < 2:
        return []

    try:
        query = query.strip()
        model_choices = []

        # Strategy 1: Direct model ID lookup (if the query looks like a model ID)
        if '/' in query:
            try:
                # Try to fetch the specific model
                model_info = hf_api.model_info(query)
                if model_info and hasattr(model_info, 'id'):
                    model_choices.append(model_info.id)
                    logger.info(f"Found direct model: {model_info.id}")
            except Exception as direct_error:
                logger.debug(f"Direct model search failed: {direct_error}")

        # Strategy 2: Search with different parameters
        search_strategies = [
            # Exact search
            {"search": query, "sort": "downloads", "direction": -1, "limit": limit // 2},
            # Author search (only if the query contains /)
            {"author": query.split('/')[0], "sort": "downloads", "direction": -1, "limit": limit // 4} if '/' in query else None,
            # Broader search
            {"search": query, "sort": "trending", "direction": -1, "limit": limit // 4},
        ]

        for strategy in search_strategies:
            if strategy is None:
                continue
            try:
                models = list_models(
                    task="text-generation",
                    **strategy
                )
                for model in models:
                    if model.id not in model_choices:
                        model_choices.append(model.id)
            except Exception as strategy_error:
                logger.debug(f"Search strategy failed: {strategy_error}")

        # Remove duplicates while preserving order
        seen = set()
        unique_choices = []
        for choice in model_choices:
            if choice not in seen:
                seen.add(choice)
                unique_choices.append(choice)

        # Limit results
        final_choices = unique_choices[:limit]
        logger.info(f"HF search for '{query}' returned {len(final_choices)} models")
        return final_choices
    except Exception as e:
        logger.error(f"Error searching models: {str(e)}")
        return []

def update_model_dropdown(query):
    """Update the dropdown with enhanced search results."""
    if not query or len(query.strip()) < 2:
        return gr.Dropdown(choices=[], value=None, interactive=True)

    choices = search_hf_models(query, limit=20)
    return gr.Dropdown(
        choices=choices,
        value=choices[0] if choices else None,
        interactive=True,
        allow_custom_value=True,  # Allow manual typing
        filterable=True
    )

def load_model_with_progress(model_source, hf_model, local_path, local_model_selection, quantization, memory_optimization):
    """Load the model with progress tracking and memory optimization."""
    global model, tokenizer, MODEL_NAME

    # Determine the model path based on the source
    if model_source == "Hugging Face Model":
        if not hf_model:
            # This function is a generator, so report errors via yield
            yield "❌ Error: Please select a model from the dropdown"
            return
        model_path = hf_model
    else:
        # Use the selected local model if available, otherwise the manual path
        if local_model_selection:
            model_path = local_model_selection
        else:
            model_path = local_path
        if not Path(model_path).exists():
            logger.error(f"Local path does not exist: {model_path}")
            yield f"❌ Error: Local path does not exist: {model_path}"
            return

    MODEL_NAME = model_path.split("/")[-1] if "/" in model_path else model_path
    logger.info(f"Loading model from {model_path} with memory optimization: {memory_optimization}")

    try:
        # Yield progress updates
        yield "🔄 Initializing model loading..."

        # Setup memory configuration (GPU-only, generous allocation)
        if torch.cuda.is_available():
            device_properties = torch.cuda.get_device_properties(0)
            total_memory_gb = device_properties.total_memory / (1024**3)
            # Cap GPU memory at a fixed 11.5 GB (GPU-bound allocation)
            max_memory_val = 11.5
            max_memory = f"{max_memory_val}GB"
            logger.info(f"Setting max GPU memory to {max_memory} (Total available: {total_memory_gb:.2f}GB)")
        else:
            max_memory = "11GB"
            logger.info("CUDA not available. Using CPU fallback.")

        yield "🔄 Configuring quantization settings..."

        # Configure quantization (no CPU offloading)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=quantization == "4bit",
            load_in_8bit=quantization == "8bit",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
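        # Note: when quantization is "bf16" or "f16", both load_in_4bit and load_in_8bit
        # are False and this bnb_config is never passed to from_pretrained()
        # (quantization_config is set to None below), so it is effectively ignored.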
yield "๐ Loading tokenizer..." | |
# Load tokenizer | |
if model_source == "Local Path": | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_path, | |
trust_remote_code=True, | |
local_files_only=True | |
) | |
else: | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_path, | |
trust_remote_code=True | |
) | |
yield "๐ Cleaning memory cache..." | |
# Clean memory | |
gc.collect() | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
# Determine torch dtype | |
if quantization in ["4bit", "8bit"]: | |
torch_dtype = torch.bfloat16 | |
elif quantization == "f16": | |
torch_dtype = torch.float16 | |
else: # bf16 | |
torch_dtype = torch.bfloat16 | |
yield "๐ Loading model weights (this may take a while)..." | |
# Simple GPU-only model loading parameters | |
model_kwargs = { | |
"device_map": "auto", | |
"max_memory": {0: max_memory} if torch.cuda.is_available() else None, | |
"torch_dtype": torch_dtype, | |
"quantization_config": bnb_config if quantization in ["4bit", "8bit"] else None, | |
"trust_remote_code": True, | |
} | |
# Memory optimization specific settings (GPU-only) | |
if memory_optimization: | |
model_kwargs.update({ | |
"attn_implementation": "flash_attention_2" if torch.cuda.is_available() else "sdpa", | |
"use_cache": False, # Disable cache by default for memory optimization | |
}) | |
else: | |
model_kwargs.update({ | |
"attn_implementation": "flash_attention_2" if torch.cuda.is_available() else "sdpa", | |
#"use_cache": True, # Enable cache for performance | |
}) | |
# Add local files only for local models | |
if model_source == "Local Path": | |
model_kwargs["local_files_only"] = True | |
# Load model | |
model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs) | |
# Post-loading memory optimization | |
if memory_optimization: | |
yield "๐ Applying memory optimizations..." | |
# Additional memory cleanup after loading | |
gc.collect() | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
torch.cuda.synchronize() | |
logger.info("Model loaded successfully with memory optimization") | |
yield "โ Model loaded successfully with memory optimization!" if memory_optimization else "โ Model loaded successfully!" | |
except Exception as e: | |
logger.error(f"Error loading model: {str(e)}", exc_info=True) | |
yield f"โ Error loading model: {str(e)}" | |

def unload_model():
    """Unload the model and free memory with aggressive cleanup."""
    global model, tokenizer, MODEL_NAME

    if model is None:
        return "No model loaded"

    try:
        logger.info("Unloading model with aggressive memory cleanup...")

        # Step 1: Move the model to CPU first (if it was on GPU)
        if torch.cuda.is_available() and hasattr(model, 'device'):
            try:
                model.cpu()
                logger.info("Model moved to CPU")
            except Exception as cpu_error:
                logger.warning(f"Could not move model to CPU: {cpu_error}")

        # Step 2: Clear the model cache if available
        if hasattr(model, 'clear_cache'):
            model.clear_cache()

        # Step 3: Delete model and tokenizer references
        del model
        del tokenizer
        model = None
        tokenizer = None

        # Step 4: Reset the model name
        MODEL_NAME = "No Model Loaded"

        # Step 5: Clear the metadata deque
        generation_metadata.clear()

        # Step 6: Aggressive garbage collection (multiple rounds)
        for i in range(5):  # More aggressive - 5 rounds
            gc.collect()
            time.sleep(0.1)  # Small delay between rounds

        # Step 7: Aggressive CUDA cleanup
        if torch.cuda.is_available():
            logger.info("Performing aggressive CUDA cleanup...")
            # Multiple rounds of cache clearing
            for i in range(5):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                # Additional PyTorch CUDA cleanup
                if hasattr(torch.cuda, 'ipc_collect'):
                    torch.cuda.ipc_collect()
                # Reset memory stats
                if hasattr(torch.cuda, 'reset_peak_memory_stats'):
                    torch.cuda.reset_peak_memory_stats()
                if hasattr(torch.cuda, 'reset_accumulated_memory_stats'):
                    torch.cuda.reset_accumulated_memory_stats()
                time.sleep(0.1)

        # Step 8: Force PyTorch to release unused memory
        if torch.cuda.is_available():
            try:
                # Try to trigger the memory pool cleanup
                torch.cuda.empty_cache()
                # Force a small allocation and deallocation to trigger cleanup
                dummy_tensor = torch.zeros(1, device='cuda')
                del dummy_tensor
                torch.cuda.empty_cache()
                logger.info("Forced memory pool cleanup")
            except Exception as cleanup_error:
                logger.warning(f"Advanced cleanup failed: {cleanup_error}")

        # Step 9: Final garbage collection
        gc.collect()

        logger.info("Model unloaded successfully with aggressive cleanup")
        return "✅ Model unloaded with aggressive memory cleanup"
    except Exception as e:
        logger.error(f"Error unloading model: {str(e)}", exc_info=True)
        # Emergency cleanup even if the unload fails
        model = None
        tokenizer = None
        MODEL_NAME = "No Model Loaded"
        generation_metadata.clear()
        # Emergency memory cleanup
        for _ in range(3):
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return f"❌ Error unloading model: {str(e)} (Emergency cleanup performed)"

def cleanup_memory():
    """Enhanced memory cleanup function with PyTorch optimizations."""
    try:
        # Clear Python garbage
        gc.collect()

        # Clear the CUDA cache if available
        if torch.cuda.is_available():
            # Multiple aggressive cleanup rounds
            for i in range(3):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                if hasattr(torch.cuda, 'ipc_collect'):
                    torch.cuda.ipc_collect()
                # PyTorch-specific memory management
                if hasattr(torch.cuda, 'reset_peak_memory_stats'):
                    torch.cuda.reset_peak_memory_stats()
                if hasattr(torch.cuda, 'reset_accumulated_memory_stats'):
                    torch.cuda.reset_accumulated_memory_stats()
                # Brief pause between cleanup rounds
                time.sleep(0.1)

        # Clear the metadata deque
        generation_metadata.clear()

        # Force garbage collection again
        gc.collect()

        logger.info("Enhanced memory cleanup completed")
        return "🧹 Enhanced memory cleanup completed"
    except Exception as e:
        logger.error(f"Memory cleanup error: {e}")
        return f"Memory cleanup error: {e}"

def nuclear_memory_cleanup():
    """Nuclear option: complete VRAM reset (use if a normal unload doesn't work)."""
    global model, tokenizer, MODEL_NAME

    try:
        logger.info("Performing nuclear memory cleanup...")

        # Force-unload everything
        model = None
        tokenizer = None
        MODEL_NAME = "No Model Loaded"
        generation_metadata.clear()

        # Multiple aggressive cleanup rounds
        for round_num in range(10):  # Very aggressive - 10 rounds
            gc.collect()
            if torch.cuda.is_available():
                # Multiple types of CUDA cleanup
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                # Additional CUDA housekeeping and diagnostics
                try:
                    if hasattr(torch.cuda, 'ipc_collect'):
                        torch.cuda.ipc_collect()
                    if hasattr(torch.cuda, 'memory_summary'):
                        logger.info(f"Round {round_num + 1}: {torch.cuda.memory_summary()}")
                except Exception:
                    pass
                # Reset memory stats
                try:
                    if hasattr(torch.cuda, 'reset_peak_memory_stats'):
                        torch.cuda.reset_peak_memory_stats()
                    if hasattr(torch.cuda, 'reset_accumulated_memory_stats'):
                        torch.cuda.reset_accumulated_memory_stats()
                except Exception:
                    pass
            time.sleep(0.1)

        # Final attempt: allocate and free a small tensor to trigger cleanup
        if torch.cuda.is_available():
            try:
                for _ in range(5):
                    dummy = torch.zeros(1024, 1024, device='cuda')  # 4 MB tensor
                    del dummy
                    torch.cuda.empty_cache()
                torch.cuda.synchronize()
            except Exception as nuclear_error:
                logger.warning(f"Nuclear tensor cleanup failed: {nuclear_error}")

        logger.info("Nuclear memory cleanup completed")
        return "☢️ Nuclear memory cleanup completed! VRAM should be minimal now."
    except Exception as e:
        logger.error(f"Nuclear cleanup error: {e}")
        return f"☢️ Nuclear cleanup error: {e}"

def get_memory_stats():
    """Get comprehensive VRAM usage information."""
    if not torch.cuda.is_available():
        return """
        <div style="text-align: center; padding: 15px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white;">
            <h3 style="margin: 0; font-size: 16px;">💻 CPU Mode</h3>
            <p style="margin: 5px 0; opacity: 0.9;">GPU not available</p>
        </div>
        """

    try:
        torch.cuda.synchronize()
        total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        allocated = torch.cuda.memory_allocated(0) / (1024**3)
        reserved = torch.cuda.memory_reserved(0) / (1024**3)
        free = total - reserved
        usage_percent = (reserved / total) * 100

        # Get peak memory if available
        peak_allocated = 0
        if hasattr(torch.cuda, 'max_memory_allocated'):
            peak_allocated = torch.cuda.max_memory_allocated(0) / (1024**3)

        # Dynamic color based on usage
        if usage_percent < 50:
            color = "#10b981"  # Green
        elif usage_percent < 80:
            color = "#f59e0b"  # Orange
        else:
            color = "#ef4444"  # Red

        return f"""
        <div style="text-align: center; padding: 15px; background: linear-gradient(135deg, {color}22 0%, {color}44 100%); border: 2px solid {color}; border-radius: 10px;">
            <h3 style="margin: 0; font-size: 16px; color: {color};">🎮 VRAM Usage</h3>
            <div style="margin: 10px 0;">
                <div style="background: #f3f4f6; border-radius: 8px; height: 8px; overflow: hidden;">
                    <div style="width: {usage_percent}%; height: 100%; background: {color}; transition: width 0.3s ease;"></div>
                </div>
            </div>
            <p style="margin: 5px 0; font-weight: 600;">Total: {total:.2f} GB</p>
            <p style="margin: 5px 0;">Allocated: {allocated:.2f} GB</p>
            <p style="margin: 5px 0;">Reserved: {reserved:.2f} GB ({usage_percent:.1f}%)</p>
            <p style="margin: 5px 0;">Free: {free:.2f} GB</p>
            <p style="margin: 5px 0; font-size: 12px; opacity: 0.8;">Peak: {peak_allocated:.2f} GB</p>
            <p style="margin: 5px 0; font-size: 10px; opacity: 0.6;">RAM-Saving Streamer Active</p>
        </div>
        """
    except Exception as e:
        logger.error(f"Error getting memory stats: {str(e)}")
        return f"""
        <div style="text-align: center; padding: 15px; background: #fee2e2; border: 2px solid #ef4444; border-radius: 10px;">
            <h3 style="margin: 0; color: #ef4444;">❌ Error</h3>
            <p style="margin: 5px 0;">{str(e)}</p>
        </div>
        """

def process_latex_content(text):
    """LaTeX passthrough for streaming without UI glitches."""
    # Don't process LaTeX here - let Gradio handle it natively;
    # just return the text as-is.
    return text


def process_think_tags(text):
    """Process <think> tags with progressive streaming support."""
    # Check if we're in the middle of generating a think section
    if '<think>' in text and '</think>' not in text:
        # We're currently generating inside a think section
        parts = text.split('<think>')
        if len(parts) == 2:
            before_think = parts[0]
            thinking_content = parts[1]
            # Create a progressive thinking display
            formatted_thinking = f"""
            <div style="background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); border-left: 4px solid #6366f1; padding: 12px; margin: 8px 0; border-radius: 8px;">
                <div style="display: flex; align-items: center; margin-bottom: 8px;">
                    <span style="font-size: 16px; margin-right: 8px;">🤔</span>
                    <strong style="color: #4338ca;">Thinking...</strong>
                </div>
                <div style="color: #475569; font-style: italic;">{thinking_content}</div>
            </div>
            """
            return before_think + formatted_thinking

    # Handle completed think sections
    think_pattern = re.compile(r'<think>(.*?)</think>', re.DOTALL)

    def replace_think(match):
        think_content = match.group(1).strip()
        return f"""
        <div style="background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); border-left: 4px solid #6366f1; padding: 12px; margin: 8px 0; border-radius: 8px;">
            <div style="display: flex; align-items: center; margin-bottom: 8px;">
                <span style="font-size: 16px; margin-right: 8px;">🤔</span>
                <strong style="color: #4338ca;">Thinking...</strong>
            </div>
            <div style="color: #475569; font-style: italic;">{think_content}</div>
        </div>
        """

    # Replace completed <think> tags with the formatted version
    processed_text = think_pattern.sub(replace_think, text)
    return processed_text
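
# Example of the transformation (illustrative input, not executed here):
#     process_think_tags("<think>check the units</think>The answer is 42.")
# returns the styled "Thinking..." <div> containing "check the units" followed by
# "The answer is 42.", while a still-open "<think>..." span is rendered as an
# in-progress Thinking panel.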

def calculate_generation_metrics(start_time, total_tokens):
    """Calculate generation metrics."""
    end_time = time.time()
    generation_time = end_time - start_time
    tokens_per_second = total_tokens / generation_time if generation_time > 0 else 0
    return {
        "generation_time": generation_time,
        "total_tokens": total_tokens,
        "tokens_per_second": tokens_per_second,
        "model_name": MODEL_NAME
    }


def format_metadata_tooltip(metadata):
    """Format metadata for tooltip display."""
    return f"""Model: {metadata['model_name']}
Tokens: {metadata['total_tokens']}
Speed: {metadata['tokens_per_second']:.2f} tok/s
Time: {metadata['generation_time']:.2f}s"""
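
# For a run of 256 tokens in 12.80 s on a model named "Mistral-7B" (made-up numbers),
# the tooltip text would read:
#     Model: Mistral-7B
#     Tokens: 256
#     Speed: 20.00 tok/s
#     Time: 12.80s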

def add_metadata_to_response(response_text, metadata):
    """Append a metadata icon with a tooltip to the response."""
    tooltip_content = format_metadata_tooltip(metadata)
    # Create a metadata icon with a tooltip using HTML
    metadata_html = f"""
<div style="position: relative; display: inline-block; margin-left: 8px;">
    <span class="metadata-icon" style="cursor: help; opacity: 0.6; font-size: 14px;" title="{tooltip_content}">ℹ️</span>
</div>
"""
    # Add the metadata icon at the end of the response
    return response_text + "\n\n" + metadata_html

def chat_with_model(message, history, system_prompt, temp, top_p_val, top_k_val, rep_penalty_val, memory_opt):
    """
    Enhanced chat function with the RAM-saving streamer and improved memory management.
    Uses a direct generation approach for better memory control and VRAM efficiency.
    """
    global model, tokenizer, generation_metadata

    # Check if a model is loaded (yield, not return: this function is a generator)
    if model is None or tokenizer is None:
        yield "❌ Model not loaded. Please load the model first."
        return

    # Initialize variables for cleanup
    input_ids = None
    streamer = None

    try:
        # Record the start time for metrics
        start_time = time.time()
        token_count = 0

        # Format the conversation for the model
        messages = [{"role": "system", "content": system_prompt}]

        # Add chat history - handle both formats (tuples from the original API, dicts from the new one)
        for h in history:
            if isinstance(h, dict):
                # New dict format
                if h.get("role") == "user":
                    messages.append({"role": "user", "content": h["content"]})
                elif h.get("role") == "assistant":
                    messages.append({"role": "assistant", "content": h["content"]})
            else:
                # Original tuple format (user_msg, bot_msg)
                if len(h) >= 2:
                    messages.append({"role": "user", "content": h[0]})
                    if h[1] is not None:
                        messages.append({"role": "assistant", "content": h[1]})

        # Add the current message
        messages.append({"role": "user", "content": message})

        # Wrap generation in torch.no_grad() to prevent gradient accumulation
        with torch.no_grad():
            # Create the model input with a memory-efficient approach
            input_ids = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            )

            # Handle edge case
            if input_ids.ndim == 1:
                input_ids = input_ids.unsqueeze(0)

            # Move to device
            input_ids = input_ids.to(model.device)

            # Setup the RAM-saving streamer
            streamer = RAMSavingIteratorStreamer(
                tokenizer,
                skip_special_tokens=True,
                skip_prompt=True,
                timeout=1.0
            )
            # Set the prompt length for the streamer
            streamer.set_prompt_length(input_ids.shape[1])

            # Pre-generation memory cleanup (only if memory optimization is on)
            if memory_opt:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # Conditional generation parameters based on memory optimization
            gen_kwargs = {
                "input_ids": input_ids,
                "max_new_tokens": MAX_LENGTH,
                "temperature": temp,
                "top_p": top_p_val,
                "top_k": top_k_val,
                "repetition_penalty": rep_penalty_val,
                "do_sample": temp > 0,
                "streamer": streamer,
                "use_cache": not memory_opt,  # Disable the cache only if memory optimization is on
            }

            # Generate in a thread for real-time streaming
            thread = Thread(
                target=model.generate,
                kwargs=gen_kwargs,
                daemon=True
            )
            thread.start()

            # Stream the response with conditional memory management
            partial_text = ""
            try:
                for new_text in streamer:
                    partial_text += new_text
                    token_count += 1
                    # Process the text to handle think tags while preserving LaTeX
                    processed_text = process_think_tags(partial_text)
                    yield processed_text
                    # Conditional cleanup based on the memory optimization setting (kept infrequent)
                    if memory_opt and token_count % 150 == 0:  # Reduced frequency for performance
                        gc.collect()  # Only light cleanup if memory optimization is on
            except StopIteration:
                # Normal end of generation
                pass
            except Exception as stream_error:
                logger.error(f"Streaming error: {stream_error}")
                yield f"❌ Streaming error: {stream_error}"
                return
            finally:
                # Add metadata to the final response
                try:
                    metrics = calculate_generation_metrics(start_time, token_count)
                    partial_text = add_metadata_to_response(partial_text, metrics)
                except Exception as e:
                    logger.warning(f"Couldn't add metadata: {str(e)}")
                yield partial_text

            # Ensure thread completion
            if thread.is_alive():
                thread.join(timeout=5.0)
                if thread.is_alive():
                    logger.warning("Generation thread did not complete in time")

            # Calculate generation metrics
            try:
                metrics = calculate_generation_metrics(start_time, token_count)
                # Store metadata (a deque with a max size prevents memory leaks)
                generation_metadata.append(metrics)
                # Log the metrics
                logger.info(f"Generation metrics - Tokens: {metrics['total_tokens']}, Speed: {metrics['tokens_per_second']:.2f} tok/s, Time: {metrics['generation_time']:.2f}s")
            except Exception as metrics_error:
                logger.warning(f"Error calculating metrics: {metrics_error}")

            # Final cleanup
            try:
                # Clean up the streamer
                if streamer:
                    streamer.cleanup()
                    del streamer
                    streamer = None
                # Clean up input tensors
                if input_ids is not None:
                    del input_ids
                    input_ids = None

                # Conditional cleanup based on the memory optimization setting
                if memory_opt:
                    # Aggressive cleanup only if memory optimization is enabled
                    if torch.cuda.is_available():
                        for _ in range(2):  # Reduced rounds for performance
                            torch.cuda.empty_cache()
                            torch.cuda.synchronize()
                    # Force garbage collection
                    for _ in range(2):
                        gc.collect()
                else:
                    # Light cleanup for performance mode
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                logger.info(f"Generation completed, {token_count} tokens, memory_opt: {memory_opt}, VRAM saved with RAM-saving streamer")
            except Exception as cleanup_error:
                logger.warning(f"Final cleanup warning: {cleanup_error}")

    except Exception as e:
        logger.error(f"Error in chat_with_model: {str(e)}", exc_info=True)
        # Emergency cleanup
        try:
            if streamer:
                streamer.cleanup()
                del streamer
            if input_ids is not None:
                del input_ids
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as emergency_cleanup_error:
            logger.error(f"Emergency cleanup failed: {emergency_cleanup_error}")
        yield f"❌ Error: {str(e)}"

def update_model_name():
    """Update the displayed model name (kept as an H1 heading to match the default title)."""
    return f"# 🎮 AI Chat Assistant ({MODEL_NAME})"


def add_page_refresh_warning():
    """Return JavaScript that warns about page refresh when a model is loaded."""
    return """
    <script>
    window.addEventListener('beforeunload', function (e) {
        // Check if a model is loaded by looking for specific text in the page
        const statusElements = document.querySelectorAll('input[type="text"], textarea');
        let modelLoaded = false;
        statusElements.forEach(element => {
            if (element.value && element.value.includes('Model loaded successfully')) {
                modelLoaded = true;
            }
        });
        if (modelLoaded) {
            e.preventDefault();
            e.returnValue = 'A model is currently loaded. Are you sure you want to leave?';
            return 'A model is currently loaded. Are you sure you want to leave?';
        }
    });
    </script>
    """

# Custom CSS for elegant styling with fixed dropdown behavior
custom_css = """
/* Main container styling */
.gradio-container {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    min-height: 100vh;
}

/* Header styling */
.header-text {
    background: rgba(255, 255, 255, 0.95);
    backdrop-filter: blur(10px);
    border-radius: 15px;
    padding: 20px;
    margin: 20px 0;
    text-align: center;
    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
    border: 1px solid rgba(255, 255, 255, 0.2);
}

/* Chat interface styling */
.chat-container {
    background: rgba(255, 255, 255, 0.95) !important;
    border-radius: 20px !important;
    box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1) !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
    backdrop-filter: blur(10px) !important;
}

/* Control panel styling */
.control-panel {
    background: rgba(255, 255, 255, 0.9) !important;
    border-radius: 15px !important;
    padding: 20px !important;
    box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1) !important;
    border: 1px solid rgba(255, 255, 255, 0.3) !important;
    backdrop-filter: blur(10px) !important;
    overflow: visible !important; /* Allow dropdowns to overflow */
}

/* Button styling */
.btn-primary {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    border-radius: 10px !important;
    color: white !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
}
.btn-primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important;
}
.btn-secondary {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important;
    border: none !important;
    border-radius: 10px !important;
    color: white !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
}

/* Input field styling */
.input-field {
    border-radius: 10px !important;
    border: 2px solid rgba(102, 126, 234, 0.2) !important;
    transition: all 0.3s ease !important;
}
.input-field:focus {
    border-color: #667eea !important;
    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}

/* Dropdown fixes */
.dropdown-container {
    position: relative !important;
    z-index: 1000 !important;
    overflow: visible !important;
}

/* Fix dropdown menu positioning and styling */
.dropdown select,
.dropdown-menu,
.svelte-select,
.svelte-select-list {
    position: relative !important;
    z-index: 1001 !important;
    background: white !important;
    border: 2px solid rgba(102, 126, 234, 0.2) !important;
    border-radius: 10px !important;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.15) !important;
    max-height: 200px !important;
    overflow-y: auto !important;
}

/* Fix dropdown option styling */
.dropdown option,
.svelte-select-option {
    padding: 8px 12px !important;
    background: white !important;
    color: #333 !important;
    border: none !important;
}
.dropdown option:hover,
.svelte-select-option:hover {
    background: #f0f0f0 !important;
    color: #667eea !important;
}

/* Ensure the dropdown arrow is clickable */
.dropdown::after,
.dropdown-arrow {
    pointer-events: none !important;
    z-index: 1002 !important;
}

/* Fix any overflow issues in parent containers */
.gradio-group,
.gradio-column {
    overflow: visible !important;
}

/* Accordion styling */
.accordion {
    border-radius: 10px !important;
    border: 1px solid rgba(102, 126, 234, 0.2) !important;
    overflow: visible !important; /* Allow dropdowns to overflow the accordion */
}

/* Status indicators */
.status-success {
    color: #10b981 !important;
    font-weight: 600 !important;
}
.status-error {
    color: #ef4444 !important;
    font-weight: 600 !important;
}

/* Reduced transition frequency to avoid conflicts */
.gradio-container * {
    transition: background-color 0.3s ease, border-color 0.3s ease !important;
}

/* Chat bubble styling */
.message {
    border-radius: 18px !important;
    padding: 12px 16px !important;
    margin: 8px 0 !important;
    max-width: 80% !important;
}
.user-message {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
    margin-left: auto !important;
}
.bot-message {
    background: #f8fafc !important;
    border: 1px solid #e2e8f0 !important;
}

/* Metadata tooltip styling - Enhanced */
.metadata-icon {
    display: inline-block;
    margin-left: 8px;
    cursor: help;
    opacity: 0.6;
    transition: opacity 0.3s ease, transform 0.2s ease;
    font-size: 14px;
    user-select: none;
    vertical-align: middle;
}
.metadata-icon:hover {
    opacity: 1;
    transform: scale(1.1);
}

/* Enhanced tooltip styling */
.metadata-icon[title]:hover::after {
    content: attr(title);
    position: absolute;
    bottom: 100%;
    left: 50%;
    transform: translateX(-50%);
    background: rgba(0, 0, 0, 0.9);
    color: white;
    padding: 8px 12px;
    border-radius: 6px;
    font-size: 12px;
    white-space: pre-line;
    z-index: 1000;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
    margin-bottom: 5px;
    min-width: 200px;
    text-align: left;
}
.metadata-icon[title]:hover::before {
    content: '';
    position: absolute;
    bottom: 100%;
    left: 50%;
    transform: translateX(-50%);
    border: 5px solid transparent;
    border-top-color: rgba(0, 0, 0, 0.9);
    z-index: 1001;
}

/* Compact system prompt */
.compact-prompt {
    min-height: 40px !important;
    transition: min-height 0.3s ease !important;
}
.compact-prompt:focus {
    min-height: 80px !important;
}
"""

# Main application
with gr.Blocks(css=custom_css, title="🎮 AI Chat Assistant") as demo:
    # Add the page refresh warning script
    gr.HTML(add_page_refresh_warning())

    # Header
    with gr.Row():
        title = gr.Markdown("# 🎮 AI Chat Assistant (No Model Loaded)", elem_classes="header-text")

    with gr.Row(equal_height=True):
        # Main chat area (left side - 70% width)
        with gr.Column(scale=7, elem_classes="chat-container"):
            # Compact system prompt (reduced from 4 lines to 1)
            system_prompt = gr.Textbox(
                label="🎯 System Prompt",
                value="You are a helpful AI assistant.",
                lines=1,  # Reduced from 4 to 1
                elem_classes="input-field compact-prompt"
            )

            # Generation settings in an accordion
            with gr.Accordion("⚙️ Generation Settings", open=False, elem_classes="accordion"):
                with gr.Row():
                    temperature = gr.Slider(0.0, 2.0, DEFAULT_TEMPERATURE, step=0.05, label="🌡️ Temperature")
                    top_p = gr.Slider(0.0, 1.0, DEFAULT_TOP_P, step=0.01, label="🎯 Top-p")
                with gr.Row():
                    top_k = gr.Slider(1, 200, DEFAULT_TOP_K, step=1, label="🔝 Top-k")
                    rep_penalty = gr.Slider(1.0, 2.0, DEFAULT_REP_PENALTY, step=0.01, label="🔁 Repetition Penalty")

                # Memory optimization for chat (defined here so it exists before use)
                memory_opt_chat = gr.Checkbox(
                    label="🧠 Memory Optimization for Chat",
                    value=True,
                    info="Use memory optimization during chat generation (disables the KV cache)"
                )
            # Chat interface using gr.ChatInterface for fast streaming and a stop button
            chatbot = gr.Chatbot(
                height=500,
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False},
                    {"left": "\\(", "right": "\\)", "display": False},
                    {"left": "\\[", "right": "\\]", "display": True}
                ],
                show_copy_button=True,
                avatar_images=("👤", "🤖"),
                type="messages",
                render_markdown=True
            )
            chat_interface = gr.ChatInterface(
                fn=chat_with_model,
                chatbot=chatbot,
                additional_inputs=[system_prompt, temperature, top_p, top_k, rep_penalty, memory_opt_chat],
                type="messages",
                submit_btn="Send 📤",
                stop_btn="⏹️ Stop"
            )
        # Control panel (right side - 30% width)
        with gr.Column(scale=3, elem_classes="control-panel"):
            # Model status and controls
            with gr.Group():
                gr.Markdown("### 🔌 Model Controls")
                with gr.Row():
                    load_btn = gr.Button("🚀 Load Model", variant="primary", elem_classes="btn-primary")
                    unload_btn = gr.Button("🗑️ Unload", variant="secondary", elem_classes="btn-secondary")
                model_status = gr.Textbox(
                    label="📊 Status",
                    value="Model not loaded",
                    interactive=False,
                    elem_classes="input-field"
                )
                progress_display = gr.Textbox(
                    label="📈 Progress",
                    value="Ready to load model",
                    interactive=False,
                    elem_classes="input-field"
                )

            # Model selection
            with gr.Group():
                gr.Markdown("### 🎛️ Model Selection")
                model_source = gr.Radio(
                    choices=["Hugging Face Model", "Local Path"],
                    value="Local Path",  # Default changed to Local Path
                    label="📂 Model Source"
                )

                # HF model search and selection (initially hidden)
                with gr.Group(visible=False) as hf_group:
                    model_search = gr.Textbox(
                        label="🔍 Search Models",
                        placeholder="e.g., microsoft/Phi-3, meta-llama/Llama-3, ykarout/your-model",
                        elem_classes="input-field"
                    )
                    hf_model = gr.Dropdown(
                        label="📋 Select Model",
                        choices=[],
                        interactive=True,
                        elem_classes="input-field dropdown-container",
                        allow_custom_value=True,  # Allow typing custom model names
                        filterable=True  # Enable filtering
                    )

                # Local path group (visible by default)
                with gr.Group(visible=True) as local_group:
                    local_path = gr.Textbox(
                        value=LOCAL_MODELS_BASE,  # Default changed to the new base location
                        label="📁 Local Models Base Path",
                        elem_classes="input-field"
                    )
                    # Button to refresh local models
                    refresh_local_btn = gr.Button("🔄 Scan Local Models", elem_classes="btn-secondary")
                    # Dropdown for local models with better configuration
                    local_models_dropdown = gr.Dropdown(
                        label="📂 Available Local Models",
                        choices=[],
                        interactive=True,
                        elem_classes="input-field dropdown-container",
                        allow_custom_value=False,  # Don't allow custom values for local models
                        filterable=True  # Enable filtering
                    )

                quantization = gr.Radio(
                    choices=["4bit", "8bit", "bf16", "f16"],
                    value="4bit",
                    label="⚡ Quantization"
                )

                # Advanced memory optimization toggle
                memory_optimization = gr.Checkbox(
                    label="🧠 Advanced Memory Optimization",
                    value=True,
                    info="Reduces VRAM usage but may slightly impact speed"
                )
                # Note: memory optimization for chat now lives in Generation Settings

            # Memory stats with cleanup buttons
            with gr.Group():
                gr.Markdown("### 💾 System Status")
                memory_info = gr.HTML()
                with gr.Row():
                    refresh_btn = gr.Button("♻ Refresh Stats", elem_classes="btn-secondary")
                    cleanup_btn = gr.Button("🧹 Clean Memory", elem_classes="btn-secondary")
                with gr.Row():
                    nuclear_btn = gr.Button("☢️ Nuclear Cleanup", elem_classes="btn-secondary", variant="stop")

    # Event handlers
    # Model search functionality for HF
    model_search.change(
        update_model_dropdown,
        inputs=[model_search],
        outputs=[hf_model]
    )

    # Show/hide model selection based on the source
    def toggle_model_source(choice):
        return (
            gr.Group(visible=choice == "Hugging Face Model"),
            gr.Group(visible=choice == "Local Path")
        )

    model_source.change(
        toggle_model_source,
        inputs=[model_source],
        outputs=[hf_group, local_group]
    )

    # Local model scanning
    refresh_local_btn.click(
        update_local_models_dropdown,
        inputs=[local_path],
        outputs=[local_models_dropdown]
    )

    # Auto-scan on path change
    local_path.change(
        update_local_models_dropdown,
        inputs=[local_path],
        outputs=[local_models_dropdown]
    )

    # Model loading with progress
    load_btn.click(
        load_model_with_progress,
        inputs=[model_source, hf_model, local_path, local_models_dropdown, quantization, memory_optimization],
        outputs=[progress_display]
    ).then(
        lambda: "✅ Model loaded successfully!" if model is not None else "❌ Model loading failed",
        outputs=[model_status]
    ).then(
        get_memory_stats,
        outputs=[memory_info]
    ).then(
        update_model_name,
        outputs=[title]
    )

    # Model unloading
    unload_btn.click(
        unload_model,
        outputs=[model_status]
    ).then(
        lambda: "Ready to load model",
        outputs=[progress_display]
    ).then(
        get_memory_stats,
        outputs=[memory_info]
    ).then(
        lambda: "# 🎮 AI Chat Assistant (No Model Loaded)",
        outputs=[title]
    )

    # Refresh memory stats
    refresh_btn.click(get_memory_stats, outputs=[memory_info])

    # Manual memory cleanup
    cleanup_btn.click(cleanup_memory, outputs=[]).then(
        get_memory_stats, outputs=[memory_info]
    )

    # Nuclear memory cleanup
    nuclear_btn.click(nuclear_memory_cleanup, outputs=[]).then(
        get_memory_stats, outputs=[memory_info]
    )

    # Initialize on load
    demo.load(get_memory_stats, outputs=[memory_info])
    demo.load(
        lambda: update_local_models_dropdown(LOCAL_MODELS_BASE),
        outputs=[local_models_dropdown]
    )

# Enable the queue for streaming
demo.queue()
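
# demo.queue() turns on request queuing, which Gradio uses to schedule the
# generator-based handlers above (chat_with_model, load_model_with_progress) so their
# intermediate yields stream to the UI instead of blocking a worker.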