import torch import time import gc import json import re from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from typing import Dict, Any, Optional # Performance optimizations torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # Global model and tokenizer variables model = None tokenizer = None MODEL_ID = "kshitijthakkar/loggenix-moe-0.3B-A0.1B-e3-lr7e5-b16-4090-v7-sft-v1" # Inference configurations INFERENCE_CONFIGS = { "Optimized for Speed": { "max_new_tokens_base": 512, "max_new_tokens_cap": 512, "min_tokens": 50, "temperature": 0.7, "top_p": 0.9, "do_sample": True, "use_cache": False, "description": "Fast responses with limited output length" }, "Middle-ground": { "max_new_tokens_base": 4096, "max_new_tokens_cap": 4096, "min_tokens": 50, "temperature": 0.7, "top_p": 0.9, "do_sample": True, "use_cache": False, "description": "Balanced performance and output quality" }, "Full Capacity": { "max_new_tokens_base": 8192, "max_new_tokens_cap": 8192, "min_tokens": 1, "temperature": 0.7, "top_p": 0.9, "do_sample": True, "use_cache": False, "description": "Maximum output length with dynamic allocation" } } def get_inference_configs(): """Get available inference configurations""" return INFERENCE_CONFIGS def load_model(): """Load model and tokenizer with optimizations""" global model, tokenizer if model is not None and tokenizer is not None: return model, tokenizer print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) ## load 8 bit quants quantization_config = BitsAndBytesConfig( load_in_8bit=True, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, ) # # Or 4-bit for even more memory savings # quantization_config = BitsAndBytesConfig( # load_in_4bit=True, # bnb_4bit_compute_dtype=torch.float16, # bnb_4bit_quant_type="nf4", # bnb_4bit_use_double_quant=True, # ) print("Loading model...") model = AutoModelForCausalLM.from_pretrained( MODEL_ID, device_map="auto", dtype=torch.float16, # Use half precision for speed attn_implementation="flash_attention_2" if hasattr(torch.nn, 'scaled_dot_product_attention') else None, use_cache=True, #quantization_config=quantization_config, ).eval() # Enable gradient checkpointing if available if hasattr(model, 'gradient_checkpointing_enable'): model.gradient_checkpointing_enable() # Set pad_token_id if model.config.pad_token_id is None and tokenizer.pad_token_id is not None: model.config.pad_token_id = tokenizer.pad_token_id # Set padding side to left for better batching tokenizer.padding_side = "left" memory = model.get_memory_footprint() / 1e6 print(f"Memory footprint: {memory:,.1f} MB") return model, tokenizer # ===== TOOL DEFINITIONS ===== def calculate_numbers(operation: str, num1: float, num2: float) -> Dict[str, Any]: """ Sample tool to perform basic mathematical operations on two numbers. Args: operation: The operation to perform ('add', 'subtract', 'multiply', 'divide') num1: First number num2: Second number Returns: Dictionary with result and operation details """ try: num1, num2 = float(num1), float(num2) if operation.lower() == 'add': result = num1 + num2 elif operation.lower() == 'subtract': result = num1 - num2 elif operation.lower() == 'multiply': result = num1 * num2 elif operation.lower() == 'divide': if num2 == 0: return {"error": "Division by zero is not allowed"} result = num1 / num2 else: return {"error": f"Unknown operation: {operation}"} return { "result": result, "operation": operation, "operands": [num1, num2], "formatted": f"{num1} {operation} {num2} = {result}" } except ValueError as e: return {"error": f"Invalid number format: {str(e)}"} except Exception as e: return {"error": f"Calculation error: {str(e)}"} # Tool registry AVAILABLE_TOOLS = { "calculate_numbers": { "function": calculate_numbers, "description": "Perform basic mathematical operations (add, subtract, multiply, divide) on two numbers", "parameters": { "operation": "The mathematical operation to perform", "num1": "First number", "num2": "Second number" } } } def execute_tool_call(tool_name: str, **kwargs) -> Dict[str, Any]: """Execute a tool call with given parameters""" print(f"Executing tool: {tool_name} with parameters: {kwargs}") if tool_name not in AVAILABLE_TOOLS: return {"error": f"Unknown tool: {tool_name}"} try: tool_function = AVAILABLE_TOOLS[tool_name]["function"] result = tool_function(**kwargs) return { "tool_name": tool_name, "parameters": kwargs, "result": result } except Exception as e: print(f"Tool execution failed: {str(e)}") return { "tool_name": tool_name, "parameters": kwargs, "error": f"Tool execution error: {str(e)}" } # def parse_tool_calls(text: str) -> list: # """ # Parse tool calls from model output. # Expected format: [TOOL_CALL:tool_name(param1=value1, param2=value2)] # """ # tool_calls = [] # #pattern = r'\[TOOL_CALL:(\w+)\((.*?)\)\]' # pattern = r'(\[TOOL_CALL:(\w+)\((.*?)\)\]|\s*{"name":\s*"(\w+)",\s*"parameters":\s*{([^}]*)}\s*}\s*)' # matches = re.findall(pattern, text) # print(matches) # # for tool_name, params_str in matches: # try: # params = {} # if params_str.strip(): # param_pairs = params_str.split(',') # for pair in param_pairs: # if '=' in pair: # key, value = pair.split('=', 1) # key = key.strip() # value = value.strip().strip('"\'') # Remove quotes # params[key] = value # tool_calls.append({ # "tool_name": tool_name, # "parameters": params, # "original_call": f"[TOOL_CALL:{tool_name}({params_str})]" # Store original call for replacement # }) # except Exception as e: # print(f"Error parsing tool call '{tool_name}({params_str})': {e}") # continue # # return tool_calls def parse_tool_calls(text: str) -> list: """ Parse tool calls from model output. Supports both formats: - [TOOL_CALL:tool_name(param1=value1, param2=value2)] - {"name": "tool_name", "parameters": {"param1": "value1", "param2": "value2"}} """ tool_calls = [] # Pattern for both formats pattern = r'(\[TOOL_CALL:(\w+)\((.*?)\)\]|\s*{"name":\s*"(\w+)",\s*"parameters":\s*{([^}]*)}\s*}\s*)' matches = re.findall(pattern, text) print("Raw matches:", matches) for match in matches: full_match, old_tool_name, old_params, json_tool_name, json_params = match # Determine which format was matched if old_tool_name: # Old format: [TOOL_CALL:tool_name(params)] tool_name = old_tool_name params_str = old_params original_call = f"[TOOL_CALL:{tool_name}({params_str})]" try: params = {} if params_str.strip(): param_pairs = params_str.split(',') for pair in param_pairs: if '=' in pair: key, value = pair.split('=', 1) key = key.strip() value = value.strip().strip('"\'') # Remove quotes params[key] = value tool_calls.append({ "tool_name": tool_name, "parameters": params, "original_call": original_call }) except Exception as e: print(f"Error parsing old format tool call '{tool_name}({params_str})': {e}") continue elif json_tool_name: # JSON format: ... tool_name = json_tool_name params_str = json_params original_call = full_match try: params = {} if params_str.strip(): # Parse JSON-like parameters # Handle the format: "operation": "add", "num1": "125", "num2": "675" param_pairs = params_str.split(',') for pair in param_pairs: if ':' in pair: key, value = pair.split(':', 1) key = key.strip().strip('"\'') # Remove quotes and whitespace value = value.strip().strip('"\'') # Remove quotes and whitespace params[key] = value tool_calls.append({ "tool_name": tool_name, "parameters": params, "original_call": original_call }) except Exception as e: print(f"Error parsing JSON format tool call '{tool_name}': {e}") continue return tool_calls def process_tool_calls(text: str) -> str: """Process tool calls in the generated text and replace with results""" tool_calls = parse_tool_calls(text) if not tool_calls: return text processed_text = text for tool_call in tool_calls: tool_name = tool_call["tool_name"] parameters = tool_call["parameters"] original_call = tool_call["original_call"] try: # Validate parameters before execution if not isinstance(parameters, dict): raise ValueError(f"Invalid parameters for tool {tool_name}: {parameters}") # Execute tool result = execute_tool_call(tool_name, **parameters) # Create replacement text if "error" in result: replacement = f"[TOOL_ERROR: {result['error']}]" else: if "result" in result["result"]: replacement = f"[TOOL_RESULT: {result['result']['formatted']}]" else: replacement = f"[TOOL_RESULT: {result['result']}]" # Replace tool call with result processed_text = processed_text.replace(original_call, replacement) except Exception as e: print(f"Error processing tool call '{tool_name}': {e}") replacement = f"[TOOL_ERROR: Failed to process tool call: {str(e)}]" processed_text = processed_text.replace(original_call, replacement) return processed_text def monitor_memory(): if torch.cuda.is_available(): allocated = torch.cuda.memory_allocated() / 1e9 cached = torch.cuda.memory_reserved() / 1e9 print(f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB") def generate_response(system_prompt: str, user_input: str, config_name: str = "Middle-ground") -> str: """ Run inference with the given task (system prompt) and user input using the specified config. """ load_model() config = INFERENCE_CONFIGS[config_name] input_messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_input} ] prompt_text = tokenizer.apply_chat_template( input_messages, tokenize=False, add_generation_prompt=True ) input_length = len(tokenizer.encode(prompt_text)) context_length = min(input_length, 3584) # Leave room for generation inputs = tokenizer( prompt_text, return_tensors="pt", truncation=True, max_length=context_length, padding=False ).to(model.device) actual_input_length = inputs['input_ids'].shape[1] max_new_tokens = min(config["max_new_tokens_cap"], 4096 - actual_input_length - 10) max_new_tokens = max(config["min_tokens"], max_new_tokens) with torch.no_grad(): start_time = time.time() outputs = model.generate( **inputs, do_sample=config["do_sample"], temperature=config["temperature"], top_p=config["top_p"], use_cache=config["use_cache"], max_new_tokens=max_new_tokens, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, # Memory optimizations output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ) inference_time = time.time() - start_time print(f"Inference time: {inference_time:.2f} seconds") memory = model.get_memory_footprint() / 1e6 monitor_memory() print(f"Memory footprint: {memory:,.1f} MB") # Clean up gc.collect() full_text = tokenizer.decode(outputs[0], skip_special_tokens=True) if prompt_text in full_text: response_start = full_text.find(prompt_text) + len(prompt_text) generated_response = full_text[response_start:].strip() else: # More robust fallback: try to extract response after the last user message generated_response = full_text.strip() try: # Look for common assistant/response indicators response_indicators = ["Assistant:", "<|assistant|>", "[/INST]", "Response:"] for indicator in response_indicators: if indicator in full_text: parts = full_text.split(indicator) if len(parts) > 1: generated_response = parts[-1].strip() break # If no indicator found, try to remove the input part user_message = user_input if user_message in full_text: parts = full_text.split(user_message) if len(parts) > 1: generated_response = parts[-1].strip() except Exception: generated_response = full_text.strip() # Process any tool calls in the generated response generated_response = process_tool_calls(generated_response) # print('Input tokens:', inputs.input_ids.numel()) #print('Output tokens:', outputs.input_ids.numel()) # print('Output tokens:', outputs['input_ids'].numel()) return generated_response