import torch
import time
import gc
import json
import re
import logging
import traceback
import sys
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


# Configure logging
def setup_logging(log_level=logging.INFO, log_file="model_inference.log"):
    """Setup comprehensive logging configuration"""
    # Create logs directory if it doesn't exist
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)

    # Create formatter
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
    )

    # Setup file handler
    file_handler = logging.FileHandler(log_dir / log_file)
    file_handler.setLevel(log_level)
    file_handler.setFormatter(formatter)

    # Setup console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)

    # Setup logger
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level)
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # Prevent duplicate logs
    logger.propagate = False

    return logger


# Initialize logger
logger = setup_logging()

# Performance optimizations
try:
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    logger.info("PyTorch optimizations enabled successfully")
except Exception as e:
    logger.warning(f"Failed to enable some PyTorch optimizations: {e}")

# Global model and tokenizer variables
model = None
tokenizer = None
MODEL_ID = "kshitijthakkar/loggenix-moe-0.3B-A0.1B-e3-lr7e5-b16-4090-v7-sft-v1"

# Inference configurations
INFERENCE_CONFIGS = {
    "Optimized for Speed": {
        "max_new_tokens_base": 512,
        "max_new_tokens_cap": 512,
        "min_tokens": 50,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "use_cache": False,
        "description": "Fast responses with limited output length"
    },
    "Middle-ground": {
        "max_new_tokens_base": 4096,
        "max_new_tokens_cap": 4096,
        "min_tokens": 50,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "use_cache": False,
        "description": "Balanced performance and output quality"
    },
    "Full Capacity": {
        "max_new_tokens_base": 8192,
        "max_new_tokens_cap": 8192,
        "min_tokens": 1,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "use_cache": False,
        "description": "Maximum output length with dynamic allocation"
    }
}


def validate_config(config_name: str) -> bool:
    """Validate inference configuration"""
    try:
        if config_name not in INFERENCE_CONFIGS:
            logger.error(f"Invalid config name: {config_name}. Available: {list(INFERENCE_CONFIGS.keys())}")
            return False

        config = INFERENCE_CONFIGS[config_name]
        required_fields = ["max_new_tokens_base", "max_new_tokens_cap", "min_tokens", "temperature", "top_p"]

        for field in required_fields:
            if field not in config:
                logger.error(f"Missing required field '{field}' in config '{config_name}'")
                return False

        logger.debug(f"Configuration '{config_name}' validated successfully")
        return True

    except Exception as e:
        logger.error(f"Error validating config '{config_name}': {e}")
        return False


def get_inference_configs():
    """Get available inference configurations"""
    try:
        logger.debug("Retrieving inference configurations")
        return INFERENCE_CONFIGS
    except Exception as e:
        logger.error(f"Error retrieving inference configurations: {e}")
        return {}


def check_system_requirements() -> bool:
    """Check if system meets requirements for model loading"""
    try:
        # Check CUDA availability
        if not torch.cuda.is_available():
            logger.warning("CUDA is not available. Model will run on CPU (much slower)")
            return True  # Still allow CPU execution

        # Check GPU memory
        gpu_count = torch.cuda.device_count()
        logger.info(f"Found {gpu_count} GPU(s)")

        for i in range(gpu_count):
            gpu_props = torch.cuda.get_device_properties(i)
            total_memory = gpu_props.total_memory / 1e9
            logger.info(f"GPU {i}: {gpu_props.name}, Memory: {total_memory:.1f}GB")

            if total_memory < 4.0:  # Minimum 4GB for quantized model
                logger.warning(f"GPU {i} has insufficient memory ({total_memory:.1f}GB < 4.0GB)")

        return True

    except Exception as e:
        logger.error(f"Error checking system requirements: {e}")
        return False
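
# Illustrative sketch (not called anywhere in this module): how a caller might pick a
# configuration preset by name and fall back to "Middle-ground" if validation fails.
# The helper name `_select_config_or_default` is hypothetical and only exists for this example.
def _select_config_or_default(config_name: str) -> Dict[str, Any]:
    # Fall back to a known-good preset when the requested one is missing or malformed
    if validate_config(config_name):
        return INFERENCE_CONFIGS[config_name]
    logger.warning(f"Falling back to 'Middle-ground' config (requested: {config_name})")
    return INFERENCE_CONFIGS["Middle-ground"]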
def load_model() -> Tuple[Optional[Any], Optional[Any]]:
    """Load model and tokenizer with comprehensive error handling"""
    global model, tokenizer

    try:
        if model is not None and tokenizer is not None:
            logger.debug("Model and tokenizer already loaded")
            return model, tokenizer

        logger.info("Starting model loading process...")

        # Check system requirements
        if not check_system_requirements():
            logger.error("System requirements check failed")
            return None, None

        # Load tokenizer with error handling
        logger.info(f"Loading tokenizer from {MODEL_ID}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,  # Add this for custom tokenizers
                # cache_dir="./model_cache"  # Use local cache
            )
            logger.info("Tokenizer loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return None, None

        # Configure quantization
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
            )
            logger.info("8-bit quantization configuration created")
        except Exception as e:
            logger.error(f"Failed to create quantization config: {e}")
            quantization_config = None

        # Load model with extensive error handling
        logger.info(f"Loading model from {MODEL_ID}...")
        try:
            model_kwargs = {
                "device_map": "auto",
                # "dtype": torch.float16,
                "use_cache": False,
                "trust_remote_code": True,
                # "cache_dir": "./model_cache"
            }

            # Add quantization if available
            if quantization_config:
                model_kwargs["quantization_config"] = quantization_config

            # Try to use flash attention if available (requires the flash-attn package)
            try:
                import importlib.util
                if importlib.util.find_spec("flash_attn") is not None:
                    model_kwargs["attn_implementation"] = "flash_attention_2"
                    logger.info("Using Flash Attention 2")
            except Exception as e:
                logger.warning(f"Flash Attention 2 not available: {e}")

            model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)
            model = model.eval()
            logger.info("Model loaded successfully")

            # Log device placement for debugging
            logger.debug(f"First parameter device: {next(model.parameters()).device}")
            from accelerate import infer_auto_device_map
            logger.debug(f"Device map: {infer_auto_device_map(model)}")  # Should show "cuda" for all layers

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory. Try reducing batch size or using CPU")
            return None, None
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return None, None

        # Configure model settings with error handling
        try:
            # Enable gradient checkpointing if available
            if hasattr(model, 'gradient_checkpointing_enable'):
                model.gradient_checkpointing_enable()
                logger.debug("Gradient checkpointing enabled")

            # Set pad_token_id
            if model.config.pad_token_id is None:
                if tokenizer.pad_token_id is not None:
                    model.config.pad_token_id = tokenizer.pad_token_id
                    logger.debug("Set model pad_token_id from tokenizer")
                else:
                    # Fallback to eos_token_id
                    model.config.pad_token_id = tokenizer.eos_token_id
                    tokenizer.pad_token_id = tokenizer.eos_token_id
                    logger.debug("Set pad_token_id to eos_token_id")

            # Set padding side to left for better batching
            tokenizer.padding_side = "left"
            logger.debug("Set tokenizer padding side to left")

        except Exception as e:
            logger.warning(f"Error configuring model settings: {e}")

        # Log memory usage
        try:
            if hasattr(model, 'get_memory_footprint'):
                memory = model.get_memory_footprint() / 1e6
                logger.info(f"Model memory footprint: {memory:,.1f} MB")
        except Exception as e:
            logger.warning(f"Could not get memory footprint: {e}")

        logger.info("Model loading completed successfully")
        return model, tokenizer

    except Exception as e:
        logger.error(f"Unexpected error in load_model: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return None, None
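
# Illustrative sketch: load_model() is idempotent, so repeated calls reuse the cached
# module globals instead of reloading weights. The helper name is hypothetical and the
# function is never invoked by this module.
def _example_lazy_loading():
    m1, t1 = load_model()  # loads (or fails and returns (None, None))
    m2, t2 = load_model()  # returns the already-loaded objects immediately
    return m1 is m2 and t1 is t2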
# ===== TOOL DEFINITIONS =====

def calculate_numbers(operation: str, num1: float, num2: float) -> Dict[str, Any]:
    """
    Sample tool to perform basic mathematical operations on two numbers.

    Args:
        operation: The operation to perform ('add', 'subtract', 'multiply', 'divide')
        num1: First number
        num2: Second number

    Returns:
        Dictionary with result and operation details
    """
    try:
        logger.debug(f"Calculating: {num1} {operation} {num2}")

        # Validate inputs
        if not isinstance(operation, str):
            raise ValueError("Operation must be a string")

        try:
            num1, num2 = float(num1), float(num2)
        except (ValueError, TypeError) as e:
            logger.error(f"Invalid number format: num1={num1}, num2={num2}")
            return {"error": f"Invalid number format: {str(e)}"}

        operation = operation.lower().strip()

        # Perform operation
        if operation == 'add':
            result = num1 + num2
        elif operation == 'subtract':
            result = num1 - num2
        elif operation == 'multiply':
            result = num1 * num2
        elif operation == 'divide':
            if num2 == 0:
                logger.error("Division by zero attempted")
                return {"error": "Division by zero is not allowed"}
            result = num1 / num2
        else:
            logger.error(f"Unknown operation: {operation}")
            return {"error": f"Unknown operation: {operation}. Supported: add, subtract, multiply, divide"}

        response = {
            "result": result,
            "operation": operation,
            "operands": [num1, num2],
            "formatted": f"{num1} {operation} {num2} = {result}"
        }

        logger.debug(f"Calculation successful: {response['formatted']}")
        return response

    except Exception as e:
        logger.error(f"Unexpected error in calculate_numbers: {e}")
        return {"error": f"Calculation error: {str(e)}"}
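
# Illustrative sketch of the dict shapes calculate_numbers returns; the helper below is
# hypothetical and never invoked by the module itself.
def _example_calculate_numbers():
    ok = calculate_numbers("add", 15, 25)
    # -> {"result": 40.0, "operation": "add", "operands": [15.0, 25.0],
    #     "formatted": "15.0 add 25.0 = 40.0"}
    err = calculate_numbers("divide", 1, 0)
    # -> {"error": "Division by zero is not allowed"}
    return ok, err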
# Tool registry
AVAILABLE_TOOLS = {
    "calculate_numbers": {
        "function": calculate_numbers,
        "description": "Perform basic mathematical operations (add, subtract, multiply, divide) on two numbers",
        "parameters": {
            "operation": "The mathematical operation to perform",
            "num1": "First number",
            "num2": "Second number"
        }
    }
}


def execute_tool_call(tool_name: str, **kwargs) -> Dict[str, Any]:
    """Execute a tool call with given parameters"""
    try:
        logger.info(f"Executing tool: {tool_name} with parameters: {kwargs}")

        if not tool_name or not isinstance(tool_name, str):
            logger.error(f"Invalid tool name: {tool_name}")
            return {"error": "Invalid tool name"}

        if tool_name not in AVAILABLE_TOOLS:
            logger.error(f"Unknown tool: {tool_name}. Available: {list(AVAILABLE_TOOLS.keys())}")
            return {"error": f"Unknown tool: {tool_name}"}

        if not isinstance(kwargs, dict):
            logger.error(f"Invalid parameters type: {type(kwargs)}")
            return {"error": "Parameters must be a dictionary"}

        tool_function = AVAILABLE_TOOLS[tool_name]["function"]
        result = tool_function(**kwargs)

        response = {
            "tool_name": tool_name,
            "parameters": kwargs,
            "result": result
        }

        if "error" not in result:
            logger.info(f"Tool execution successful: {tool_name}")
        else:
            logger.warning(f"Tool execution returned error: {result['error']}")

        return response

    except TypeError as e:
        logger.error(f"Parameter error for tool '{tool_name}': {e}")
        return {
            "tool_name": tool_name,
            "parameters": kwargs,
            "error": f"Invalid parameters: {str(e)}"
        }
    except Exception as e:
        logger.error(f"Tool execution failed: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return {
            "tool_name": tool_name,
            "parameters": kwargs,
            "error": f"Tool execution error: {str(e)}"
        }
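
# Illustrative sketch of the envelope execute_tool_call wraps around a tool result; the
# helper is hypothetical and not used elsewhere in this module.
def _example_execute_tool_call():
    response = execute_tool_call("calculate_numbers", operation="multiply", num1=6, num2=7)
    # -> {"tool_name": "calculate_numbers",
    #     "parameters": {"operation": "multiply", "num1": 6, "num2": 7},
    #     "result": {"result": 42.0, "operation": "multiply", "operands": [6.0, 7.0],
    #                "formatted": "6.0 multiply 7.0 = 42.0"}}
    return response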
def parse_tool_calls(text: str) -> list:
    """
    Parse tool calls from model output with comprehensive error handling.

    Supports both formats:
    - [TOOL_CALL:tool_name(param1=value1, param2=value2)]
    - {"name": "tool_name", "parameters": {"param1": "value1", "param2": "value2"}}
    """
    try:
        if not text or not isinstance(text, str):
            logger.warning("Invalid text input for tool call parsing")
            return []

        tool_calls = []
        logger.debug(f"Parsing tool calls from text: {text[:200]}...")

        # Pattern for both formats
        pattern = r'(\[TOOL_CALL:(\w+)\((.*?)\)\]|\s*{"name":\s*"(\w+)",\s*"parameters":\s*{([^}]*)}\s*}\s*)'
        matches = re.findall(pattern, text)
        logger.debug(f"Found {len(matches)} potential tool call matches")

        for i, match in enumerate(matches):
            try:
                full_match, old_tool_name, old_params, json_tool_name, json_params = match

                # Determine which format was matched
                if old_tool_name:
                    # Old format: [TOOL_CALL:tool_name(params)]
                    tool_name = old_tool_name
                    params_str = old_params
                    original_call = f"[TOOL_CALL:{tool_name}({params_str})]"

                    params = {}
                    if params_str.strip():
                        param_pairs = params_str.split(',')
                        for pair in param_pairs:
                            try:
                                if '=' in pair:
                                    key, value = pair.split('=', 1)
                                    key = key.strip()
                                    value = value.strip().strip('"\'')  # Remove quotes
                                    params[key] = value
                            except Exception as e:
                                logger.warning(f"Error parsing parameter pair '{pair}': {e}")

                    logger.debug(f"Parsed old format tool call: {tool_name} with params: {params}")

                elif json_tool_name:
                    # JSON format: ...
                    tool_name = json_tool_name
                    params_str = json_params
                    original_call = full_match

                    params = {}
                    if params_str.strip():
                        # Parse JSON-like parameters
                        param_pairs = params_str.split(',')
                        for pair in param_pairs:
                            try:
                                if ':' in pair:
                                    key, value = pair.split(':', 1)
                                    key = key.strip().strip('"\'')  # Remove quotes and whitespace
                                    value = value.strip().strip('"\'')  # Remove quotes and whitespace
                                    params[key] = value
                            except Exception as e:
                                logger.warning(f"Error parsing JSON parameter pair '{pair}': {e}")

                    logger.debug(f"Parsed JSON format tool call: {tool_name} with params: {params}")

                else:
                    logger.warning(f"Could not determine tool call format for match {i}")
                    continue

                # Validate tool call
                if tool_name and isinstance(params, dict):
                    tool_calls.append({
                        "tool_name": tool_name,
                        "parameters": params,
                        "original_call": original_call
                    })
                else:
                    logger.warning(f"Invalid tool call data: tool_name='{tool_name}', params={params}")

            except Exception as e:
                logger.error(f"Error parsing tool call match {i}: {e}")
                continue

        logger.info(f"Successfully parsed {len(tool_calls)} tool calls")
        return tool_calls

    except Exception as e:
        logger.error(f"Unexpected error in parse_tool_calls: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return []
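
# Illustrative sketch of the two tool-call formats parse_tool_calls understands; the
# sample text and helper name are hypothetical.
def _example_parse_tool_calls():
    text = (
        'Let me compute that. [TOOL_CALL:calculate_numbers(operation=add, num1=15, num2=25)] '
        'or equivalently {"name": "calculate_numbers", "parameters": {"operation": "add", "num1": "15", "num2": "25"}}'
    )
    calls = parse_tool_calls(text)
    # Each entry looks like:
    # {"tool_name": "calculate_numbers",
    #  "parameters": {"operation": "add", "num1": "15", "num2": "25"},
    #  "original_call": "[TOOL_CALL:calculate_numbers(operation=add, num1=15, num2=25)]"}
    return calls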
def process_tool_calls(text: str) -> str:
    """Process tool calls in the generated text and replace with results"""
    try:
        if not text:
            logger.warning("Empty text provided to process_tool_calls")
            return text

        logger.debug("Processing tool calls in generated text")
        tool_calls = parse_tool_calls(text)

        if not tool_calls:
            logger.debug("No tool calls found in text")
            return text

        processed_text = text
        successful_calls = 0

        for i, tool_call in enumerate(tool_calls):
            try:
                tool_name = tool_call["tool_name"]
                parameters = tool_call["parameters"]
                original_call = tool_call["original_call"]

                logger.debug(f"Processing tool call {i + 1}/{len(tool_calls)}: {tool_name}")

                # Validate parameters before execution
                if not isinstance(parameters, dict):
                    logger.error(f"Invalid parameters for tool {tool_name}: {parameters}")
                    replacement = f"[TOOL_ERROR: Invalid parameters for tool {tool_name}]"
                else:
                    # Execute tool
                    result = execute_tool_call(tool_name, **parameters)

                    # Create replacement text
                    tool_result = result.get("result")
                    if "error" in result:
                        replacement = f"[TOOL_ERROR: {result['error']}]"
                        logger.warning(f"Tool call failed: {result['error']}")
                    elif isinstance(tool_result, dict) and "error" in tool_result:
                        # Surface errors returned by the tool itself (e.g. division by zero)
                        replacement = f"[TOOL_ERROR: {tool_result['error']}]"
                        logger.warning(f"Tool call failed: {tool_result['error']}")
                    else:
                        if isinstance(tool_result, dict) and "formatted" in tool_result:
                            replacement = f"[TOOL_RESULT: {tool_result['formatted']}]"
                        elif "result" in result:
                            replacement = f"[TOOL_RESULT: {tool_result}]"
                        else:
                            replacement = "[TOOL_RESULT: Success]"
                        successful_calls += 1
                        logger.debug(f"Tool call successful: {replacement}")

                # Replace tool call with result
                processed_text = processed_text.replace(original_call, replacement)

            except Exception as e:
                logger.error(f"Error processing tool call {i + 1}: {e}")
                tool_name = tool_call.get("tool_name", "unknown")
                original_call = tool_call.get("original_call", "")
                replacement = f"[TOOL_ERROR: Failed to process tool call: {str(e)}]"
                if original_call:
                    processed_text = processed_text.replace(original_call, replacement)

        logger.info(f"Processed {len(tool_calls)} tool calls ({successful_calls} successful)")
        return processed_text

    except Exception as e:
        logger.error(f"Unexpected error in process_tool_calls: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return text  # Return original text if processing fails


def monitor_memory():
    """Monitor and log memory usage"""
    try:
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            cached = torch.cuda.memory_reserved() / 1e9
            max_allocated = torch.cuda.max_memory_allocated() / 1e9
            logger.info(
                f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB, Max: {max_allocated:.2f}GB")

            # Log warning if memory usage is high
            total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            if allocated / total_memory > 0.9:
                logger.warning(f"High GPU memory usage: {allocated / total_memory * 100:.1f}%")
                # Clean up cache if needed
                torch.cuda.empty_cache()
        else:
            logger.debug("CUDA not available, skipping GPU memory monitoring")

        # Clean up Python memory
        gc.collect()
        logger.debug("Resources cleaned up successfully")

    except Exception as e:
        logger.error(f"Error monitoring memory: {e}")


def get_model_info() -> Dict[str, Any]:
    """Get information about the loaded model"""
    try:
        if model is None:
            return {"status": "not_loaded"}

        info = {
            "status": "loaded",
            "model_id": MODEL_ID,
            "device": str(model.device) if hasattr(model, 'device') else "unknown",
            "dtype": str(model.dtype) if hasattr(model, 'dtype') else "unknown"
        }

        # Add memory info if available
        if hasattr(model, 'get_memory_footprint'):
            try:
                info["memory_footprint_mb"] = model.get_memory_footprint() / 1e6
            except Exception:
                pass

        # Add GPU info if available
        if torch.cuda.is_available():
            info["gpu_count"] = torch.cuda.device_count()
            info["current_gpu"] = torch.cuda.current_device()
            info["gpu_memory_allocated"] = torch.cuda.memory_allocated() / 1e9
            info["gpu_memory_cached"] = torch.cuda.memory_reserved() / 1e9

        return info

    except Exception as e:
        logger.error(f"Error getting model info: {e}")
        return {"status": "error", "error": str(e)}
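
# Illustrative sketch of a lightweight status probe built from monitor_memory and
# get_model_info; the helper name is hypothetical and the function is never called here.
def _example_status_probe() -> Dict[str, Any]:
    monitor_memory()  # logs GPU/CPU memory usage and clears caches under high pressure
    info = get_model_info()
    # info["status"] is "not_loaded" until load_model() has run, then "loaded"
    return info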
def health_check() -> Dict[str, Any]:
    """Perform a health check of the system"""
    try:
        health_status = {
            "timestamp": time.time(),
            "torch_version": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None,
        }

        if torch.cuda.is_available():
            health_status.update({
                "cuda_version": torch.version.cuda,
                "gpu_count": torch.cuda.device_count(),
                "gpu_memory_total": torch.cuda.get_device_properties(0).total_memory / 1e9,
                "gpu_memory_available": (torch.cuda.get_device_properties(0).total_memory -
                                         torch.cuda.memory_allocated()) / 1e9
            })

        # Test a simple generation if model is loaded
        if model is not None and tokenizer is not None:
            try:
                test_response = generate_response(
                    "You are a helpful assistant.",
                    "Say hello",
                    "Optimized for Speed"
                )
                health_status["test_generation"] = "success" if test_response else "failed"
            except Exception as e:
                health_status["test_generation"] = f"error: {str(e)}"

        logger.info(f"Health check completed: {health_status}")
        return health_status

    except Exception as e:
        logger.error(f"Error during health check: {e}")
        return {"status": "error", "error": str(e)}


def validate_inputs(system_prompt: str, user_input: str, config_name: str) -> bool:
    """Validate inputs for generate_response"""
    try:
        if not isinstance(system_prompt, str) or not system_prompt.strip():
            logger.error("System prompt must be a non-empty string")
            return False

        if not isinstance(user_input, str) or not user_input.strip():
            logger.error("User input must be a non-empty string")
            return False

        if not validate_config(config_name):
            return False

        # Check input length
        total_length = len(system_prompt) + len(user_input)
        if total_length > 50000:  # Reasonable limit
            logger.warning(f"Input length is very long: {total_length} characters")

        return True

    except Exception as e:
        logger.error(f"Error validating inputs: {e}")
        return False
def generate_response(system_prompt: str, user_input: str, config_name: str = "Middle-ground") -> Optional[str]:
    """
    Run inference with comprehensive error handling and logging.

    Args:
        system_prompt: System message/prompt
        user_input: User's input message
        config_name: Name of the inference configuration to use

    Returns:
        Generated response text, or None if generation failed
    """
    try:
        logger.info(f"Starting response generation with config: {config_name}")

        # Validate inputs
        if not validate_inputs(system_prompt, user_input, config_name):
            logger.error("Input validation failed")
            return None

        # Load model
        model, tokenizer = load_model()
        if model is None or tokenizer is None:
            logger.error("Failed to load model or tokenizer")
            return None

        # Get configuration
        config = INFERENCE_CONFIGS[config_name]
        logger.debug(f"Using config: {config}")

        # Prepare messages
        input_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]

        # Apply chat template
        try:
            prompt_text = tokenizer.apply_chat_template(
                input_messages,
                tokenize=False,
                add_generation_prompt=True
            )
            logger.debug("Chat template applied successfully")
        except Exception as e:
            logger.error(f"Failed to apply chat template: {e}")
            # Fallback to simple concatenation
            prompt_text = f"System: {system_prompt}\nUser: {user_input}\nAssistant:"
            logger.info("Using fallback prompt format")

        # Tokenize input
        try:
            input_length = len(tokenizer.encode(prompt_text))
            context_length = min(input_length, 3584)  # Leave room for generation

            inputs = tokenizer(
                prompt_text,
                return_tensors="pt",
                truncation=True,
                max_length=context_length,
                padding=False
            ).to(model.device)

            logger.debug(f"Input tokenized: {inputs['input_ids'].shape[1]} tokens")
        except Exception as e:
            logger.error(f"Failed to tokenize input: {e}")
            return None

        # Calculate generation parameters
        actual_input_length = inputs['input_ids'].shape[1]
        max_new_tokens = min(config["max_new_tokens_cap"], 4096 - actual_input_length - 10)
        max_new_tokens = max(config["min_tokens"], max_new_tokens)

        logger.debug(f"Generation params - Input length: {actual_input_length}, Max new tokens: {max_new_tokens}")

        # Monitor memory before generation
        monitor_memory()

        # Generate response
        try:
            with torch.no_grad():
                start_time = time.time()

                generation_kwargs = {
                    "do_sample": config["do_sample"],
                    "temperature": config["temperature"],
                    "top_p": config["top_p"],
                    "use_cache": config["use_cache"],
                    "max_new_tokens": max_new_tokens,
                    "pad_token_id": tokenizer.pad_token_id,
                    "eos_token_id": tokenizer.eos_token_id,
                    "output_attentions": False,
                    "output_hidden_states": False,
                    "return_dict_in_generate": False,
                }

                outputs = model.generate(**inputs, **generation_kwargs)

                inference_time = time.time() - start_time
                logger.info(f"Generation completed in {inference_time:.2f} seconds")

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory during generation")
            # Try to free memory
            gc.collect()
            torch.cuda.empty_cache()
            return None
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            return None

        # Monitor memory after generation
        monitor_memory()

        # Clean up GPU memory
        try:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")

        # Decode response
        try:
            full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract generated response
            if prompt_text in full_text:
                response_start = full_text.find(prompt_text) + len(prompt_text)
                generated_response = full_text[response_start:].strip()
            else:
                # More robust fallback
                generated_response = full_text.strip()
                try:
                    # Look for common assistant/response indicators
                    response_indicators = ["Assistant:", "<|assistant|>", "[/INST]", "Response:"]
                    for indicator in response_indicators:
                        if indicator in full_text:
                            parts = full_text.split(indicator)
                            if len(parts) > 1:
                                generated_response = parts[-1].strip()
                                break

                    # If no indicator found, try to remove the input part
                    if user_input in full_text:
                        parts = full_text.split(user_input)
                        if len(parts) > 1:
                            generated_response = parts[-1].strip()
                except Exception as extract_error:
                    logger.warning(f"Error extracting response: {extract_error}")
                    generated_response = full_text.strip()

            logger.debug(f"Extracted response: {generated_response[:100]}...")
        except Exception as e:
            logger.error(f"Failed to decode response: {e}")
            return None

        # Process tool calls
        try:
            processed_response = process_tool_calls(generated_response)
            logger.debug("Tool call processing completed")
        except Exception as e:
            logger.error(f"Error processing tool calls: {e}")
            processed_response = generated_response  # Use original if tool processing fails

        # Log final statistics
        input_tokens = inputs['input_ids'].shape[1]
        output_tokens = outputs.shape[1] - input_tokens
        logger.info(
            f"Generation stats - Input tokens: {input_tokens}, Output tokens: {output_tokens}, Time: {inference_time:.2f}s")

        logger.info("Response generation completed successfully")
        return processed_response

    except Exception as e:
        logger.error(f"Unexpected error in generate_response: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return None


def safe_generate_response(system_prompt: str, user_input: str, config_name: str = "Middle-ground",
                           max_retries: int = 2) -> Optional[str]:
    """
    Generate response with retry logic and fallback options

    Args:
        system_prompt: System message/prompt
        user_input: User's input message
        config_name: Name of the inference configuration to use
        max_retries: Maximum number of retry attempts

    Returns:
        Generated response text, or None if all attempts failed
    """
    for attempt in range(max_retries + 1):
        try:
            logger.info(f"Generation attempt {attempt + 1}/{max_retries + 1}")

            response = generate_response(system_prompt, user_input, config_name)

            if response is not None:
                logger.info(f"Generation successful on attempt {attempt + 1}")
                return response

            if attempt < max_retries:
                logger.warning(f"Generation failed on attempt {attempt + 1}, retrying...")
                # Clean up before retry
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                time.sleep(1)  # Brief pause before retry

        except Exception as e:
            logger.error(f"Error on generation attempt {attempt + 1}: {e}")
            if attempt < max_retries:
                logger.info("Cleaning up and retrying...")
                try:
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                except Exception:
                    pass
                time.sleep(2)  # Longer pause after error

    logger.error(f"All {max_retries + 1} generation attempts failed")
    return None
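
# Illustrative sketch of a typical call into this module; the prompt strings are examples
# only and the helper is never invoked here.
def _example_generation():
    response = safe_generate_response(
        "You are a helpful mathematical assistant.",
        "What is 15 + 25? Use the calculate_numbers tool.",
        config_name="Optimized for Speed",
        max_retries=1,
    )
    # Returns the decoded text with any [TOOL_CALL:...] markers already replaced by
    # [TOOL_RESULT: ...] / [TOOL_ERROR: ...] markers, or None if every attempt failed.
    return response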
# Context manager for safe model operations
class ModelContext:
    """Context manager for safe model operations with automatic cleanup"""

    def __init__(self, auto_cleanup: bool = True):
        self.auto_cleanup = auto_cleanup
        self.original_model = None
        self.original_tokenizer = None

    def __enter__(self):
        global model, tokenizer
        self.original_model = model
        self.original_tokenizer = tokenizer
        logger.debug("Entered model context")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            logger.error(f"Exception in model context: {exc_type.__name__}: {exc_val}")

        if self.auto_cleanup:
            try:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                logger.debug("Model context cleanup completed")
            except Exception as e:
                logger.warning(f"Error during model context cleanup: {e}")

        logger.debug("Exited model context")


def cleanup_resources():
    """Clean up model resources"""
    global model, tokenizer

    try:
        if model is not None:
            del model
            model = None
            logger.info("Model removed from memory")

        if tokenizer is not None:
            del tokenizer
            tokenizer = None
            logger.info("Tokenizer removed from memory")

        # Clean up GPU memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            logger.info("GPU memory cleaned up")

        logger.info("Resource cleanup completed")

    except Exception as e:
        logger.error(f"Error during resource cleanup: {e}")


def unload_model():
    """Explicitly unload the model and tokenizer"""
    try:
        logger.info("Unloading model and tokenizer...")
        cleanup_resources()
        logger.info("Model and tokenizer unloaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error unloading model: {e}")
        return False


def reload_model():
    """Reload the model and tokenizer"""
    try:
        logger.info("Reloading model and tokenizer...")

        # First clean up existing resources
        cleanup_resources()
        time.sleep(1)  # Brief pause

        # Load fresh model and tokenizer
        model, tokenizer = load_model()

        if model is not None and tokenizer is not None:
            logger.info("Model and tokenizer reloaded successfully")
            return True
        else:
            logger.error("Failed to reload model and tokenizer")
            return False

    except Exception as e:
        logger.error(f"Error reloading model: {e}")
        return False


def get_available_tools() -> Dict[str, Any]:
    """Get information about available tools"""
    try:
        return {
            "tools": AVAILABLE_TOOLS,
            "count": len(AVAILABLE_TOOLS),
            "tool_names": list(AVAILABLE_TOOLS.keys())
        }
    except Exception as e:
        logger.error(f"Error getting available tools: {e}")
        return {"error": str(e)}
def add_tool(tool_name: str, tool_function, description: str, parameters: Dict[str, str]):
    """Add a new tool to the registry"""
    try:
        if not tool_name or not isinstance(tool_name, str):
            raise ValueError("Tool name must be a non-empty string")

        if not callable(tool_function):
            raise ValueError("Tool function must be callable")

        if tool_name in AVAILABLE_TOOLS:
            logger.warning(f"Tool '{tool_name}' already exists, replacing...")

        AVAILABLE_TOOLS[tool_name] = {
            "function": tool_function,
            "description": description,
            "parameters": parameters or {}
        }

        logger.info(f"Tool '{tool_name}' added successfully")
        return True

    except Exception as e:
        logger.error(f"Error adding tool '{tool_name}': {e}")
        return False


def remove_tool(tool_name: str):
    """Remove a tool from the registry"""
    try:
        if tool_name not in AVAILABLE_TOOLS:
            logger.warning(f"Tool '{tool_name}' not found")
            return False

        del AVAILABLE_TOOLS[tool_name]
        logger.info(f"Tool '{tool_name}' removed successfully")
        return True

    except Exception as e:
        logger.error(f"Error removing tool '{tool_name}': {e}")
        return False


# Example usage and testing functions
def run_example():
    """Run an example to test the system"""
    try:
        logger.info("Running example test")

        # Test health check
        health = health_check()
        logger.info(f"System health: {health}")

        # Test model loading
        model_obj, tokenizer_obj = load_model()
        if model_obj is None or tokenizer_obj is None:
            logger.error("Failed to load model for example")
            return False

        # Test generation
        with ModelContext():
            response = safe_generate_response(
                "You are a helpful mathematical assistant.",
                "What is 15 + 25? Use the calculate_numbers tool.",
                "Optimized for Speed"
            )

        if response:
            logger.info(f"Example response: {response}")
            return True
        else:
            logger.error("Example generation failed")
            return False

    except Exception as e:
        logger.error(f"Error in example: {e}")
        return False


def run_batch_test():
    """Run batch test with multiple inputs"""
    try:
        logger.info("Running batch test")

        test_cases = [
            {
                "system": "You are a helpful assistant.",
                "user": "Hello, how are you?",
                "config": "Optimized for Speed"
            },
            {
                "system": "You are a mathematical assistant.",
                "user": "Calculate 10 * 5 using the calculate_numbers tool.",
                "config": "Middle-ground"
            },
            {
                "system": "You are a helpful assistant.",
                "user": "Explain the concept of machine learning in simple terms.",
                "config": "Full Capacity"
            }
        ]

        results = []
        for i, test_case in enumerate(test_cases):
            logger.info(f"Running test case {i + 1}/{len(test_cases)}")

            with ModelContext():
                response = safe_generate_response(
                    test_case["system"],
                    test_case["user"],
                    test_case["config"]
                )

            results.append({
                "test_case": i + 1,
                "success": response is not None,
                "response": response[:100] + "..." if response and len(response) > 100 else response
            })

        success_count = sum(1 for r in results if r["success"])
        logger.info(f"Batch test completed: {success_count}/{len(test_cases)} successful")
        return results

    except Exception as e:
        logger.error(f"Error in batch test: {e}")
        return []
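
# Illustrative sketch of registering a custom tool at runtime via add_tool; the echo tool
# below is hypothetical and not part of the shipped registry.
def _example_register_custom_tool():
    def echo_text(message: str) -> Dict[str, Any]:
        # Minimal tool: returns the message it was given
        return {"result": message, "formatted": f"echo: {message}"}

    add_tool(
        "echo_text",
        echo_text,
        description="Echo the provided message back to the caller",
        parameters={"message": "Text to echo"},
    )
    return execute_tool_call("echo_text", message="hello")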
def benchmark_generation(num_runs: int = 5):
    """Benchmark generation performance"""
    try:
        logger.info(f"Running benchmark with {num_runs} iterations")

        # Load model first
        model_obj, tokenizer_obj = load_model()
        if model_obj is None or tokenizer_obj is None:
            logger.error("Failed to load model for benchmark")
            return None

        system_prompt = "You are a helpful assistant."
        user_input = "Explain the importance of renewable energy in 2-3 sentences."

        times = []
        token_counts = []

        for i in range(num_runs):
            logger.info(f"Benchmark run {i + 1}/{num_runs}")

            start_time = time.time()
            response = generate_response(system_prompt, user_input, "Middle-ground")
            end_time = time.time()

            if response:
                generation_time = end_time - start_time
                times.append(generation_time)

                # Estimate token count (rough approximation)
                token_count = len(response.split()) * 1.3  # Rough tokens-to-words ratio
                token_counts.append(token_count)

                logger.info(f"Run {i + 1}: {generation_time:.2f}s, ~{token_count:.0f} tokens")
            else:
                logger.warning(f"Run {i + 1} failed")

        if times:
            avg_time = sum(times) / len(times)
            avg_tokens = sum(token_counts) / len(token_counts)
            tokens_per_sec = avg_tokens / avg_time if avg_time > 0 else 0

            benchmark_results = {
                "runs": num_runs,
                "successful_runs": len(times),
                "avg_time": avg_time,
                "avg_tokens": avg_tokens,
                "tokens_per_second": tokens_per_sec,
                "min_time": min(times),
                "max_time": max(times)
            }

            logger.info(f"Benchmark results: {benchmark_results}")
            return benchmark_results
        else:
            logger.error("All benchmark runs failed")
            return None

    except Exception as e:
        logger.error(f"Error in benchmark: {e}")
        return None


# API-like interface functions
def initialize_system():
    """Initialize the inference system"""
    try:
        logger.info("Initializing inference system...")

        # Check system requirements
        if not check_system_requirements():
            return {"status": "error", "message": "System requirements not met"}

        # Load model and tokenizer
        model_obj, tokenizer_obj = load_model()
        if model_obj is None or tokenizer_obj is None:
            return {"status": "error", "message": "Failed to load model"}

        # Run health check
        health = health_check()
        if "error" in health:
            return {"status": "warning", "message": "System initialized with warnings", "health": health}

        logger.info("Inference system initialized successfully")
        return {"status": "success", "message": "System initialized successfully", "health": health}

    except Exception as e:
        logger.error(f"Error initializing system: {e}")
        return {"status": "error", "message": str(e)}


def shutdown_system():
    """Shutdown the inference system cleanly"""
    try:
        logger.info("Shutting down inference system...")
        cleanup_resources()
        logger.info("Inference system shutdown complete")
        return {"status": "success", "message": "System shutdown successfully"}
    except Exception as e:
        logger.error(f"Error during shutdown: {e}")
        return {"status": "error", "message": str(e)}
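
# Illustrative sketch of the initialize -> generate -> shutdown lifecycle these API-like
# functions are meant to support; the helper is hypothetical and never invoked here.
def _example_lifecycle():
    status = initialize_system()
    if status["status"] == "error":
        return status
    try:
        return safe_generate_response("You are a helpful assistant.", "Say hello")
    finally:
        shutdown_system()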
if __name__ == "__main__":
    """Main entry point for testing"""
    try:
        logger.info("Starting model inference system")

        # Initialize system
        init_result = initialize_system()
        logger.info(f"Initialization result: {init_result}")

        if init_result["status"] != "error":
            # Run example
            success = run_example()

            if success:
                logger.info("System test completed successfully")

                # Optionally run additional tests
                print("\nWould you like to run additional tests? (y/n)")
                try:
                    choice = input().lower().strip()
                    if choice == 'y':
                        logger.info("Running batch test...")
                        batch_results = run_batch_test()
                        logger.info(f"Batch test results: {batch_results}")

                        logger.info("Running benchmark...")
                        benchmark_results = benchmark_generation(3)
                        logger.info(f"Benchmark results: {benchmark_results}")
                except (EOFError, KeyboardInterrupt):
                    logger.info("Skipping additional tests")
            else:
                logger.error("System test failed")

        # Shutdown
        shutdown_result = shutdown_system()
        logger.info(f"Shutdown result: {shutdown_result}")

    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        cleanup_resources()
    except Exception as e:
        logger.error(f"Unexpected error in main: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        cleanup_resources()
    finally:
        logger.info("Program terminated")