import gradio as gr
import os
import torch
import time
import uuid
from transformers import pipeline, StoppingCriteria, StoppingCriteriaList
import traceback
import sys
from config import MODEL_SETTINGS, CONVERSATION_SETTINGS
from chat_storage import initialize_chat_storage, save_chat_turn, get_current_session_info, ChatStorage

print("===== Application Startup - Authentic Alice =====")
print("🚀 Loading ChaiML model with Alice's authentic personality...")

# HYBRID CACHING STRATEGY:
# - Persistent storage (/data/): Model weights loaded once at startup
# - Fast ephemeral (/tmp/): Kernel cache accessed every conversation

# Create persistent cache directories (slow but keeps downloads)
persistent_cache_available = True
try:
    os.makedirs("/data/transformers_cache", exist_ok=True)
    os.makedirs("/data/hf_cache", exist_ok=True)
    os.makedirs("/data/matplotlib_cache", exist_ok=True)
    print("✅ Persistent storage (/data/) available")

    # Set environment variables to use persistent storage
    os.environ["TRANSFORMERS_CACHE"] = "/data/transformers_cache"
    os.environ["HF_HOME"] = "/data/hf_cache"
    os.environ["MPLCONFIGDIR"] = "/data/matplotlib_cache"
except (PermissionError, OSError) as e:
    print(f"⚠️ Cannot access /data/ directory: {e}")
    print("🔄 Falling back to ephemeral storage (/tmp/) for all caches")
    persistent_cache_available = False

    # Fallback: set environment variables to /tmp/
    os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
    os.environ["HF_HOME"] = "/tmp/hf_cache"
    os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib_cache"

    # Create /tmp/ directories instead
    os.makedirs("/tmp/transformers_cache", exist_ok=True)
    os.makedirs("/tmp/hf_cache", exist_ok=True)
    os.makedirs("/tmp/matplotlib_cache", exist_ok=True)

# Create fast cache directories (fast but temporary) and point torch at them,
# so the summary and diagnostics below report the paths that are actually used
os.makedirs("/tmp/torch_cache", exist_ok=True)
os.makedirs("/tmp/torch_kernels", exist_ok=True)
os.environ["TORCH_HOME"] = "/tmp/torch_cache"
os.environ["PYTORCH_KERNEL_CACHE_PATH"] = "/tmp/torch_kernels"

# Configure PyTorch CUDA memory allocation to reduce fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
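
# Both cache helpers below compute directory sizes with the same os.walk/getsize
# pattern. A minimal reusable sketch of that calculation (hypothetical helper,
# not called by the functions below):
def directory_size_gb(path):
    """Return the total size of all files under `path`, in GB."""
    total_bytes = 0
    for dirpath, _dirnames, filenames in os.walk(path):
        for filename in filenames:
            try:
                total_bytes += os.path.getsize(os.path.join(dirpath, filename))
            except OSError:
                pass  # file disappeared or is unreadable while walking
    return total_bytes / (1024 ** 3)
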
def investigate_cache_bloat():
    """Investigate what's taking up 115GB in the cache"""
    cache_path = "/data/transformers_cache"
    print(f"\n🔍 INVESTIGATING CACHE BLOAT at {cache_path}")

    if not os.path.exists(cache_path):
        print("❌ Cache directory doesn't exist")
        return

    # Find largest files and directories
    large_items = []
    try:
        for root, dirs, files in os.walk(cache_path):
            for file in files:
                file_path = os.path.join(root, file)
                try:
                    size = os.path.getsize(file_path)
                    if size > 100 * 1024 * 1024:  # Files larger than 100MB
                        size_gb = size / (1024 * 1024 * 1024)
                        rel_path = os.path.relpath(file_path, cache_path)
                        large_items.append((size_gb, rel_path))
                except OSError:
                    pass

        # Sort by size, largest first
        large_items.sort(reverse=True)

        print(f"📊 Found {len(large_items)} files larger than 100MB:")
        total_gb = 0
        for size_gb, path in large_items[:15]:  # Show top 15
            print(f"   {size_gb:.1f}GB - {path}")
            total_gb += size_gb

        if len(large_items) > 15:
            remaining = sum(size for size, _ in large_items[15:])
            print(f"   ... and {len(large_items) - 15} more files totaling {remaining:.1f}GB")
            total_gb += remaining

        print(f"📈 Total large files: {total_gb:.1f}GB")

        # Look for model-specific patterns
        model_dirs = []
        for item in os.listdir(cache_path):
            item_path = os.path.join(cache_path, item)
            if os.path.isdir(item_path):
                dir_size = sum(
                    os.path.getsize(os.path.join(dirpath, filename))
                    for dirpath, dirnames, filenames in os.walk(item_path)
                    for filename in filenames
                ) / (1024 * 1024 * 1024)
                if dir_size > 1:  # Directories larger than 1GB
                    model_dirs.append((dir_size, item))

        model_dirs.sort(reverse=True)
        print("\n📁 Large directories:")
        for size_gb, dirname in model_dirs:
            print(f"   {size_gb:.1f}GB - {dirname}")

    except Exception as e:
        print(f"❌ Error investigating cache: {e}")


def clear_transformers_cache():
    """Clear the bloated transformers cache to start fresh"""
    import shutil

    cache_path = "/data/transformers_cache"
    print(f"\n🧹 CLEARING TRANSFORMERS CACHE at {cache_path}")

    if os.path.exists(cache_path):
        try:
            # Get size before deletion
            total_size = sum(
                os.path.getsize(os.path.join(dirpath, filename))
                for dirpath, dirnames, filenames in os.walk(cache_path)
                for filename in filenames
            ) / (1024 * 1024 * 1024)
            print(f"📊 Deleting {total_size:.1f}GB of cached data...")

            shutil.rmtree(cache_path)
            os.makedirs(cache_path, exist_ok=True)
            print("✅ Cache cleared successfully!")
            print("🔄 Model will be re-downloaded fresh on next load")
        except Exception as e:
            print(f"❌ Error clearing cache: {e}")
    else:
        print("⚠️ Cache directory doesn't exist")


# Call investigation before model loading only when explicitly enabled
if os.getenv("ENABLE_CACHE_DIAGNOSTICS") == "1":
    investigate_cache_bloat()
else:
    print("⏭️ Skipping cache scan at startup. Set ENABLE_CACHE_DIAGNOSTICS=1 to enable.")

# Uncomment the next line if you want to clear the cache:
# clear_transformers_cache()
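
# Optional sketch: mirror the ENABLE_CACHE_DIAGNOSTICS switch above so the cache
# can be wiped once via an environment variable instead of editing code.
# ENABLE_CACHE_CLEAR is a hypothetical flag name; nothing else in the app reads it.
if os.getenv("ENABLE_CACHE_CLEAR") == "1":
    clear_transformers_cache()
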
def log_cache_diagnostics():
    """Diagnostic logging to verify cache effectiveness"""
    cache_dirs = {
        "TRANSFORMERS_CACHE (persistent)": "/data/transformers_cache",
        "HF_HOME (persistent)": "/data/hf_cache",
        "TORCH_HOME (fast)": "/tmp/torch_cache",
        "PYTORCH_KERNEL_CACHE (fast)": "/tmp/torch_kernels",
    }

    print("🔍 Cache Diagnostics:")
    for name, path in cache_dirs.items():
        if os.path.exists(path):
            try:
                # Count files and estimate size
                file_count = sum(len(files) for r, d, files in os.walk(path))

                # Quick size estimate by checking a few files
                total_size = 0
                file_sample = 0
                for root, dirs, files in os.walk(path):
                    for file in files[:10]:  # Sample first 10 files
                        try:
                            total_size += os.path.getsize(os.path.join(root, file))
                            file_sample += 1
                        except OSError:
                            pass
                    if file_sample >= 10:
                        break
                avg_size = total_size / max(file_sample, 1)
                estimated_total_mb = (avg_size * file_count) / (1024 * 1024)

                # Test I/O performance
                test_file = os.path.join(path, "io_test.tmp")
                start = time.time()
                try:
                    with open(test_file, "w") as f:
                        f.write("test" * 1000)  # 4KB write
                    write_time = (time.time() - start) * 1000

                    start = time.time()
                    with open(test_file, "r") as f:
                        f.read()
                    read_time = (time.time() - start) * 1000

                    os.remove(test_file)
                    io_status = f"W:{write_time:.1f}ms R:{read_time:.1f}ms"
                except Exception as e:
                    io_status = f"I/O ERROR: {e}"

                print(f"  ✅ {name}")
                print(f"     📁 {file_count} files (~{estimated_total_mb:.1f}MB)")
                print(f"     ⚡ I/O: {io_status}")
            except Exception as e:
                print(f"  ❌ {name}: Error reading cache - {e}")
        else:
            print(f"  ⚠️ {name}: Directory not found")


log_cache_diagnostics()

print("📁 Hybrid Cache Strategy Configured:")
if persistent_cache_available:
    print("  🏠 PERSISTENT (/data/) - Model weights (loaded once):")
    print(f"     TRANSFORMERS_CACHE: {os.getenv('TRANSFORMERS_CACHE', 'Not set')}")
    print(f"     HF_HOME: {os.getenv('HF_HOME', 'Not set')}")
else:
    print("  ⚠️ EPHEMERAL (/tmp/) - All caches (cleared on restart):")
    print(f"     TRANSFORMERS_CACHE: {os.getenv('TRANSFORMERS_CACHE', 'Not set')}")
    print(f"     HF_HOME: {os.getenv('HF_HOME', 'Not set')}")
print("  ⚡ FAST (/tmp/) - Tools & kernels (accessed every conversation):")
print(f"     TORCH_HOME: {os.getenv('TORCH_HOME', 'Not set')}")
print(f"     PYTORCH_KERNEL_CACHE_PATH: {os.getenv('PYTORCH_KERNEL_CACHE_PATH', 'Not set')}")
print("  🔧 Optimizations:")
print(f"     PYTORCH_CUDA_ALLOC_CONF: {os.getenv('PYTORCH_CUDA_ALLOC_CONF', 'Not set')}")

# Initialize model on startup
generator = None
model_load_start = time.time()

try:
    # Check GPU availability; device_map="auto" places the model on the GPU when
    # one is present, so the pipeline call itself is the same either way
    if torch.cuda.is_available():
        print(f"GPU detected: {torch.cuda.get_device_name(0)}")
    else:
        print("No GPU detected, falling back to CPU")

    print("📥 Loading model weights from cache...")
    generator = pipeline(
        "text-generation",
        model="ChaiML/gptj_ppo_retry_and_continue",
        tokenizer="EleutherAI/gpt-j-6b",
        torch_dtype=torch.float32,  # Use float32 for better compatibility
        device_map="auto",
        trust_remote_code=True,
    )

    print("✅ ChaiML model loaded successfully!")
    model_load_time = time.time() - model_load_start
    print(f"⏱️ Model loading took: {model_load_time:.2f} seconds")
    print("🎭 Ready to chat with the authentic Alice!")
except Exception as e:
    print(f"❌ Error loading ChaiML model: {e}")
    print("🔄 Trying with alternative configuration...")
    try:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
        generator = pipeline(
            "text-generation",
            model="ChaiML/gptj_ppo_retry_and_continue",
            tokenizer=tokenizer,
            torch_dtype=torch.float32,  # Use float32 for compatibility
            device_map="auto",
            trust_remote_code=True,
        )
        print("✅ ChaiML model loaded with fallback configuration!")
        model_load_time = time.time() - model_load_start
        print(f"⏱️ Model loading took: {model_load_time:.2f} seconds")
        print("🎭 Ready to chat with the authentic Alice!")
    except Exception as e2:
        print(f"❌ Fallback also failed: {e2}")
        print("🔄 Trying CPU-only mode...")
        try:
            generator = pipeline(
                "text-generation",
                model="ChaiML/gptj_ppo_retry_and_continue",
                tokenizer="EleutherAI/gpt-j-6b",  # repo id, in case the tokenizer object above failed to load
                device="cpu",
                trust_remote_code=True,
            )
            print("✅ ChaiML model loaded in CPU mode!")
            model_load_time = time.time() - model_load_start
            print(f"⏱️ Model loading took: {model_load_time:.2f} seconds")
            print("🎭 Ready to chat with the authentic Alice (CPU mode)!")
        except Exception as e3:
            print(f"❌ All fallbacks failed: {e3}")
            generator = None

# Initialize base chat storage system (without creating a session)
print("🗄️ Initializing chat storage system...")
from chat_storage import chat_storage

# Restore any existing chats from HuggingFace
local_files = chat_storage.list_chat_files()
if not local_files:
    print("📂 No local chat files found, attempting HuggingFace restore...")
    chat_storage.restore_from_huggingface()

chat_storage_initialized = True
print("✅ Chat storage system ready for per-user sessions")
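
# For reference, the two dicts imported from config.py are consumed with the keys
# below. The values shown are illustrative assumptions only; config.py is the
# source of truth.
#
#   MODEL_SETTINGS = {
#       "max_tokens": 64,             # passed as max_new_tokens
#       "temperature": 0.8,
#       "top_p": 0.9,
#       "top_k": 40,
#       "repetition_penalty": 1.1,
#   }
#   CONVERSATION_SETTINGS = {
#       "instruction": "...",         # persona prompt prepended to every request
#       "initial_message": "...",     # Alice's greeting shown in the chatbot
#   }
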
class NewlineStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        newline_ids = tokenizer.encode("\n", add_special_tokens=False)
        self.newline_token_id = newline_ids[0] if newline_ids else None

    def __call__(self, input_ids, scores, **kwargs):
        if self.newline_token_id is None:
            return False
        return input_ids[0][-1] == self.newline_token_id


def count_tokens(text, tokenizer):
    """Count the number of tokens in a text string"""
    if not text:
        return 0
    return len(tokenizer.encode(text))


def generate_alice_response(prompt):
    """Generates a response from Alice based on a fully-formed prompt."""
    if generator is None:
        return "Alice is currently unavailable. The ChaiML model failed to load."

    print(f"---PROMPT---\n{prompt}\n---END PROMPT---", flush=True)

    try:
        # Create stopping criteria for newlines
        stopping_criteria = StoppingCriteriaList([NewlineStoppingCriteria(generator.tokenizer)])

        # Time the actual inference
        inference_start = time.time()
        response = generator(
            prompt,
            max_new_tokens=MODEL_SETTINGS["max_tokens"],
            temperature=MODEL_SETTINGS["temperature"],
            top_p=MODEL_SETTINGS["top_p"],
            top_k=MODEL_SETTINGS["top_k"],
            repetition_penalty=MODEL_SETTINGS["repetition_penalty"],
            num_return_sequences=1,
            do_sample=True,
            return_full_text=False,
            pad_token_id=generator.tokenizer.eos_token_id,
            eos_token_id=generator.tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
        )
        inference_time = time.time() - inference_start
        print(f"⚡ Inference completed in: {inference_time:.2f} seconds", flush=True)

        # If inference is slow, run cache diagnostics to see what's happening
        if inference_time > 5.0:
            print(f"🚨 SLOW INFERENCE DETECTED ({inference_time:.2f}s) - Running cache diagnostics:")
            log_cache_diagnostics()

        print(f"Raw response from generator: {response}", flush=True)

        if response and len(response) > 0:
            alice_response = response[0]["generated_text"].strip()

            # Clean up response
            if "\nUser:" in alice_response:
                alice_response = alice_response.split("\nUser:")[0]
            if "\n\n" in alice_response:
                alice_response = alice_response.split("\n\n")[0]
            if "### " in alice_response:
                alice_response = alice_response.split("### ")[0]

            return alice_response.strip()
        else:
            return "I'm having trouble responding right now."
    except Exception as gen_error:
        print("!!! ERROR DURING GENERATION !!!", flush=True)
        traceback.print_exc(file=sys.stdout)
        return f"An error occurred during generation: {gen_error}"


# Global session manager - maps Gradio session to our chat sessions
user_sessions = {}
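
# Sketch (not called anywhere): user_sessions only grows, since each new
# conversation adds an entry keyed by its session hash. A cleanup helper based on
# the same "last_activity" timestamp maintained below would keep it bounded.
def prune_stale_sessions(max_idle_seconds=3600):
    """Drop sessions idle longer than max_idle_seconds (hypothetical helper)."""
    now = time.time()
    stale = [h for h, s in user_sessions.items()
             if now - s.get("last_activity", 0) > max_idle_seconds]
    for session_hash in stale:
        del user_sessions[session_hash]
    return len(stale)
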
def chat_with_alice(message, history):
    """Chat function with per-user session management - new session on page refresh"""
    if generator is None:
        error_msg = "Alice is currently unavailable. The ChaiML model failed to load."
        return error_msg, error_msg

    try:
        # Create a unique session identifier for this browser session.
        # Since ChatInterface doesn't give us easy access to session info, we
        # create a session on the first message and reuse it for the rest of the
        # conversation, but start a new session when the page is refreshed.
        # Heuristic: if history is empty, this is likely a new session.
        if not history:
            # This is a new conversation (page refresh or first visit)
            session_hash = f"session_{int(time.time())}_{uuid.uuid4().hex[:8]}"
            print(f"🆕 Detected new conversation - creating session: {session_hash}")
        else:
            # Try to find an existing session for this conversation by looking
            # for one with recent activity
            session_hash = None
            current_time = time.time()
            for s_hash, s_data in user_sessions.items():
                if current_time - s_data.get("last_activity", 0) < 300:  # 5 minutes
                    session_hash = s_hash
                    break

            if not session_hash:
                # No recent session found, create a new one
                session_hash = f"session_{int(time.time())}_{uuid.uuid4().hex[:8]}"
                print(f"🆕 No recent session found - creating new session: {session_hash}")

        # Create a new session for this user if they don't have one
        if session_hash not in user_sessions:
            print(f"🆕 Creating new session for user (session: {session_hash})")
            initial_settings = {
                "model_settings": MODEL_SETTINGS,
                "conversation_settings": CONVERSATION_SETTINGS,
                "user_agent": "Web Interface",
                "session_hash": session_hash,
            }
            session_id = chat_storage.start_new_session(initial_settings)
            if session_id:
                user_sessions[session_hash] = {
                    "session_id": session_id,
                    "last_activity": time.time(),
                }
                print(f"✅ New user session created: {session_id}")
            else:
                print("❌ Failed to create user session")
                user_sessions[session_hash] = {"last_activity": time.time()}

        user_session_state = user_sessions[session_hash]
        user_session_state["last_activity"] = time.time()  # Update activity time

        instruction = CONVERSATION_SETTINGS["instruction"]

        # Build conversation history with token counting
        conversation_history = ""
        total_tokens = count_tokens(instruction + "\n\n", generator.tokenizer)

        if history:
            # Convert history to "Speaker: text" lines, newest first
            messages = []
            for msg in reversed(history):
                if isinstance(msg, dict):
                    # Messages format: one dict per turn with "role"/"content"
                    speaker = "Alice" if msg.get("role") == "assistant" else "User"
                    messages.append(f"{speaker}: {msg.get('content', '')}\n")
                elif isinstance(msg, (list, tuple)) and len(msg) >= 2:
                    # Tuple format: (user_text, alice_text) pairs
                    user_text = f"User: {msg[0]}\n" if msg[0] else ""
                    alice_text = f"Alice: {msg[1]}\n" if msg[1] else ""
                    messages.extend([alice_text, user_text])

            # Add messages until we hit the token limit
            for msg in messages:
                msg_tokens = count_tokens(msg, generator.tokenizer)
                if total_tokens + msg_tokens > 1024:
                    break
                total_tokens += msg_tokens
                conversation_history = msg + conversation_history

        # Add current user message
        current_msg = f"User: {message}\nAlice:"
        total_tokens += count_tokens(current_msg, generator.tokenizer)
        conversation_history += current_msg

        # Simple format: instruction + conversation history
        full_prompt = instruction + "\n\n" + conversation_history

        generation_start = time.time()
        alice_response = generate_alice_response(full_prompt)
        generation_time = time.time() - generation_start

        # Save this conversation turn to the user's specific session
        if chat_storage_initialized and user_session_state.get("session_id"):
            # Temporarily set the current session to this user's session
            original_session_id = chat_storage.current_session_id
            original_chat_file = chat_storage.current_chat_file

            # Set to user's session
            chat_storage.current_session_id = user_session_state["session_id"]

            # Reconstruct the chat file path
            session_files = chat_storage.list_chat_files()
            user_chat_file = None
            for file_info in session_files:
                if file_info["session_id"] == user_session_state["session_id"]:
                    user_chat_file = file_info["filepath"]
                    break

            if user_chat_file:
                chat_storage.current_chat_file = user_chat_file

                additional_data = {
                    "inference_time": generation_time,  # wall-clock time of the generation call
                    "model_settings": MODEL_SETTINGS.copy(),
                    "total_tokens_used": total_tokens,
                    "prompt_length": len(full_prompt),
                }

                save_success = chat_storage.save_turn(message, alice_response, additional_data)
                if save_success:
                    turn_count = chat_storage._count_turns()
                    print(f"💾 Chat turn saved to user session {user_session_state['session_id']} (Turn #{turn_count})")
                else:
                    print("⚠️ Failed to save chat turn")

            # Restore original global session (if any)
            chat_storage.current_session_id = original_session_id
            chat_storage.current_chat_file = original_chat_file

        # Build the full transcript for the UI from the original, untruncated history
        transcript_parts = []
        for turn in history:
            if isinstance(turn, dict):
                if turn.get("content"):
                    speaker = "Alice" if turn.get("role") == "assistant" else "User"
                    transcript_parts.append(f"{speaker}: {turn['content']}")
            elif isinstance(turn, (list, tuple)) and len(turn) >= 2:
                user_turn, alice_turn = turn[0], turn[1]
                if user_turn:
                    transcript_parts.append(f"User: {user_turn}")
                if alice_turn:
                    transcript_parts.append(f"Alice: {alice_turn}")
        transcript_parts.append(f"User: {message}")
        transcript_parts.append(f"Alice: {alice_response}")
        full_transcript = "\n\n".join(transcript_parts)

        return alice_response, full_transcript

    except Exception as e:
        print(f"Error in chat_with_alice: {e}")
        traceback.print_exc()
        error_msg = "I'm having trouble responding right now."
        return error_msg, error_msg
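
# For illustration, the prompt handed to generate_alice_response looks like this
# (the instruction text comes from CONVERSATION_SETTINGS["instruction"]):
#
#   <instruction>
#
#   User: earlier message
#   Alice: earlier reply
#   User: current message
#   Alice:
#
# Generation stops at the first newline (NewlineStoppingCriteria), so only
# Alice's next single line is returned and appended to the chat.
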
# Create simple Gradio interface with new message format
print("🎯 Gradio interface ready!")

# Custom CSS to remove message background shading
custom_css = """
/* Remove background shading from chat messages */
.message-wrap .message {
    background: transparent !important;
    box-shadow: none !important;
    border: none !important;
}

.user .message, .bot .message {
    background: transparent !important;
}

/* Target additional potential chat message selectors */
.chat-message, .chatbot .message {
    background: transparent !important;
    box-shadow: none !important;
    border: none !important;
}

/* Remove any gradient backgrounds */
.message-wrap {
    background: transparent !important;
}
"""

demo = gr.ChatInterface(
    fn=chat_with_alice,
    title="Chat with Alice",
    css=custom_css,
    chatbot=gr.Chatbot(
        value=[(None, CONVERSATION_SETTINGS["initial_message"])]
    ),
    additional_outputs=[gr.Textbox(label="Full Transcript", interactive=False, lines=20)],
)

if __name__ == "__main__":
    print("🌐 Launching server...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
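
# Deployment note (assumption, not part of the original app): with a single
# shared `generator`, simultaneous users all hit the same pipeline. Depending on
# the Gradio version, enabling the request queue before launching, e.g.
#
#   demo.queue()
#   demo.launch(server_name="0.0.0.0", server_port=7860)
#
# limits how many generations run at once instead of letting them overlap.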