import logging
import os
import sys

# Configure the Hugging Face cache before importing transformers so the
# environment variables are picked up when the library initializes.
hf_cache = "/tmp/huggingface"
os.environ["HF_HOME"] = hf_cache
os.environ["TRANSFORMERS_CACHE"] = hf_cache  # kept for older transformers releases
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache
os.makedirs(hf_cache, exist_ok=True)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Make the module directory importable before loading local modules
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

from config import LLM_MODEL, CONFIDENCE_THRESHOLD, VECTORSTORE_DIR

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Load BioMistral once
class BioMistralModel:
    def __init__(self, model_name=LLM_MODEL, device=None):
        logger.info(f"Loading model: {model_name}")
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name, cache_dir=hf_cache
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                cache_dir=hf_cache,
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            # Fall back to a text-generation pipeline
            self.pipeline = pipeline(
                "text-generation",
                model=model_name,
                device=0 if self.device == "cuda" else -1,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            )
            self.use_pipeline = True
        else:
            self.use_pipeline = False

    def generate_answer(self, query: str) -> str:
        prompt = f"""You are a helpful bioinformatics tutor. Answer clearly and concisely.

Question: {query}

Answer:"""
        try:
            if self.use_pipeline:
                # Pipeline fallback path
                result = self.pipeline(
                    prompt,
                    max_new_tokens=256,
                    do_sample=True,
                    top_p=0.9,
                    temperature=0.7,
                    pad_token_id=self.pipeline.tokenizer.eos_token_id,
                )
                full_text = result[0]["generated_text"]
            else:
                # Direct model path
                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=256,
                        do_sample=True,
                        top_p=0.9,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )
                full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the answer part
            if "Answer:" in full_text:
                return full_text.split("Answer:", 1)[-1].strip()
            return full_text.replace(prompt, "").strip()
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return (
                "I apologize, but I encountered an error while processing "
                f"your question: {e}"
            )


# Formatting utility
class TextFormatter:
    @staticmethod
    def format_text(text: str) -> str:
        """Clean and format text output."""
        if not text:
            return (
                "I don't have an answer for that question. "
                "Could you please rephrase or ask something else?"
            )
        # Collapse whitespace, capitalize the first letter,
        # and ensure closing punctuation
        cleaned = " ".join(text.split())
        if cleaned:
            cleaned = cleaned[0].upper() + cleaned[1:]
            if cleaned[-1] not in {".", "!", "?"}:
                cleaned += "."
        return cleaned
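# Illustrative behavior of the formatter above (follows directly from
# format_text; no new behavior assumed):
#
#     >>> TextFormatter.format_text("  what is   a fastq file ")
#     'What is a fastq file.'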
# Tutor Agent
class TutorAgent:
    def __init__(self):
        logger.info("Initializing TutorAgent")
        self.model = BioMistralModel()
        self.formatter = TextFormatter()
        # Initialize RAG if the optional rag module is available
        self.rag_agent = None
        try:
            from rag import RAGAgent
            self.rag_agent = RAGAgent(vectorstore_dir=str(VECTORSTORE_DIR))
            logger.info("RAG agent initialized")
        except ImportError as e:
            logger.warning(f"RAG not available: {e}")
        except Exception as e:
            logger.warning(f"Failed to initialize RAG: {e}")

    def process_query(self, query: str) -> str:
        logger.info(f"Processing query: {query}")
        if not query or len(query.strip()) < 2:
            return "Please ask a meaningful question about bioinformatics."
        # Generate an answer and estimate its confidence
        answer = self.model.generate_answer(query)
        confidence = self.estimate_confidence(answer)
        logger.info(f"Confidence: {confidence:.2f}")
        # If confidence is low and RAG is available, try to enhance
        if confidence < CONFIDENCE_THRESHOLD and self.rag_agent:
            logger.info("Low confidence, attempting RAG enhancement")
            try:
                rag_answer = self._enhance_with_rag(query)
                if rag_answer and len(rag_answer) > len(answer):
                    answer = rag_answer
            except Exception as e:
                logger.warning(f"RAG enhancement failed: {e}")
        return self.formatter.format_text(answer)

    def _enhance_with_rag(self, query: str) -> str:
        """Enhance the answer using RAG if available."""
        if not self.rag_agent:
            return ""
        try:
            # Assuming RAGAgent exposes an answer() method
            if hasattr(self.rag_agent, "answer"):
                result = self.rag_agent.answer(query)
                return result.get("answer", "") if isinstance(result, dict) else str(result)
            return ""
        except Exception as e:
            logger.error(f"RAG error: {e}")
            return ""

    def estimate_confidence(self, answer: str) -> float:
        """Simple length-based confidence estimation."""
        answer = answer.strip()
        if not answer:
            return 0.0
        length = len(answer)
        if length > 150:
            return 0.85
        elif length > 80:
            return 0.7
        elif length > 30:
            return 0.5
        return 0.3


# User class
class BioUser:
    def __init__(self, name="BioUser"):
        self.name = name

    def ask_question(self, question: str, tutor: TutorAgent) -> str:
        return tutor.process_query(question)
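# Minimal usage sketch. It assumes config.py (providing LLM_MODEL,
# CONFIDENCE_THRESHOLD, and VECTORSTORE_DIR) sits next to this file and that
# the model weights are downloadable; the first run can take several minutes.
if __name__ == "__main__":
    tutor = TutorAgent()
    user = BioUser(name="Student")
    print(user.ask_question("What is the difference between FASTA and FASTQ?", tutor))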