"""
Groq Medical RAG System v2.0

FREE Groq Cloud API integration for advanced medical reasoning.
"""
import os
import time
import logging
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from dotenv import load_dotenv
from pathlib import Path
import argparse
import shutil
import re

# Langchain for document loading and splitting
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Sentence Transformers for re-ranking
from sentence_transformers import CrossEncoder

# Groq API integration
from groq import Groq
from tenacity import retry, stop_after_attempt, wait_fixed, before_sleep_log

# Load environment variables from .env file
load_dotenv()

# Import our simplified components
from .simple_vector_store import SimpleVectorStore, SearchResult


@dataclass
class MedicalResponse:
    """Result of a single RAG query."""

    answer: str  # LLM answer, or the raw retrieved context when use_llm=False
    confidence: float  # heuristic in [0, 1] derived from cross-encoder scores
    sources: List[str]  # unique citations of the re-ranked chunks, in rank order
    query_time: float  # wall-clock seconds spent answering the query


class GroqMedicalRAG:
    """Groq-powered Medical RAG System v2.0 - FREE LLM integration.

    Pipeline: dense retrieval from ``SimpleVectorStore`` -> cross-encoder
    re-ranking -> (optionally) answer generation with a Groq-hosted LLM.
    """

    # Default Groq chat model.
    # NOTE(review): Groq has been retiring the original ``llama3-*`` model IDs;
    # confirm this one is still served, or pass ``model_name=`` explicitly.
    DEFAULT_MODEL = "llama3-70b-8192"

    # Persona-driven, structured response instructions sent as the system message.
    SYSTEM_PROMPT = (
        "You are 'VedaMD', a world-class medical expert and a compassionate assistant for healthcare professionals in Sri Lanka. "
        "Your primary goal is to provide accurate, evidence-based clinical information based ONLY on the provided context, which is sourced from official Sri Lankan maternal health guidelines. "
        "Your tone should be professional, clear, and supportive.\n\n"
        "**CRITICAL INSTRUCTIONS:**\n"
        "1. **Strictly Context-Bound:** Your answer MUST be based exclusively on the 'Content' provided for each source. Do not use any external knowledge or provide information not present in the context.\n"
        "2. **Markdown Formatting:** Structure your answers for maximum clarity. Use markdown for formatting:\n"
        "   - Use headings (`##`) for main topics.\n"
        "   - Use bullet points (`-` or `*`) for lists of symptoms, recommendations, or steps.\n"
        "   - Use bold (`**text**`) to emphasize key terms, dosages, or critical warnings.\n"
        "3. **Synthesize, Don't Just Copy:** Read all context pieces, synthesize the information, and provide a comprehensive answer. Do not repeat information.\n"
        "4. **Scientific Citations:** Use numbered citations [1], [2], etc. in your answer text to reference specific information. At the end, list all sources under a 'References:' heading in scientific format:\n"
        "   [1] Title of Guideline/Document\n"
        "   [2] Title of Another Guideline/Document\n"
        "5. **Disclaimer:** At the end of EVERY response, include the following disclaimer: '_This information is for clinical reference based on Sri Lankan guidelines and does not replace professional medical judgment._'"
    )

    def __init__(self,
                 vector_store_dir: str = "simple_vector_store",
                 processed_docs_dir: str = "src/processed_markdown",
                 groq_api_key: Optional[str] = None,
                 *,
                 model_name: str = DEFAULT_MODEL,
                 force_recreate: bool = False):
        """Initialize the Groq medical RAG system.

        Args:
            vector_store_dir: Vector store directory, relative to the project
                root (the parent of this file's directory).
            processed_docs_dir: Directory of processed markdown documents,
                relative to the project root.
            groq_api_key: Groq API key; falls back to the ``GROQ_API_KEY``
                environment variable.
            model_name: Groq chat model ID used for the startup check and for
                answer generation.
            force_recreate: If True, delete any existing vector store and
                rebuild it from ``processed_docs_dir`` instead of loading it.

        Raises:
            ValueError: If no Groq API key is available.
            Exception: Anything raised while loading/building the vector
                store or checking the Groq connection is logged and re-raised.
        """
        # Resolve paths against the project root so the CWD does not matter.
        project_root = Path(__file__).parent.parent.resolve()
        self.vector_store_dir = project_root / vector_store_dir
        self.processed_docs_dir = project_root / processed_docs_dir

        # Initialize Groq client
        self.groq_api_key = groq_api_key or os.getenv("GROQ_API_KEY")
        if not self.groq_api_key:
            raise ValueError(
                "GROQ_API_KEY environment variable not set. "
                "Get your free API key from https://console.groq.com/keys"
            )
        self.groq_client = Groq(api_key=self.groq_api_key)
        self.model_name = model_name

        # Cross-encoder used to re-rank the dense-retrieval candidates.
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        # Initialize components
        self.vector_store = None
        self.setup_logging()
        self._initialize_system(force_recreate=force_recreate)

    def setup_logging(self):
        """Configure root logging at INFO level and bind ``self.logger``."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _initialize_system(self, force_recreate: bool = False):
        """Load (or build) the vector store and verify the Groq connection.

        Args:
            force_recreate: Delete ``vector_store_dir`` first so the store is
                rebuilt from the processed documents.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            if force_recreate and self.vector_store_dir.exists():
                self.logger.warning(f"Recreating index as requested. Deleting {self.vector_store_dir}...")
                shutil.rmtree(self.vector_store_dir)

            self.vector_store = SimpleVectorStore(vector_store_dir=self.vector_store_dir)

            # load_vector_store() returns a falsy value when nothing usable is on disk.
            if not self.vector_store.load_vector_store():
                self.logger.info("Creating new vector store from documents...")
                self._create_vector_store()
            else:
                self.logger.info("Loaded existing vector store")

            self._test_groq_connection()

            self.logger.info("Groq Medical RAG system initialized successfully")

        except Exception as e:
            self.logger.error(f"Error initializing RAG system: {e}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_fixed(2),
        before_sleep=before_sleep_log(logging.getLogger(__name__), logging.INFO),
        # Surface the real API error after the last attempt, not tenacity.RetryError.
        reraise=True,
    )
    def _test_groq_connection(self):
        """Send a tiny chat request to verify the Groq API (up to 3 attempts, 2 s apart)."""
        try:
            self.groq_client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": "Test"}],
                max_tokens=10,
            )
            self.logger.info("✅ Groq API connection successful")
        except Exception as e:
            self.logger.error(f"❌ Groq API connection failed: {e}")
            raise

    def _create_vector_store(self):
        """Build, index and save the vector store from processed markdown documents.

        Raises:
            FileNotFoundError: If no ``*.md`` files exist under ``processed_docs_dir``.
            RuntimeError: If markdown files exist but none of them could be loaded.
        """
        self.logger.info(f"Checking for documents in {self.processed_docs_dir}...")
        # Sorted so the chunk/embedding order (and thus the index) is reproducible.
        doc_files = sorted(self.processed_docs_dir.glob("**/*.md"))

        if not doc_files:
            self.logger.error(f"No markdown files found in {self.processed_docs_dir}. Please run the enhanced_pdf_processor.py script first.")
            raise FileNotFoundError(f"No markdown files found in {self.processed_docs_dir}")

        self.logger.info(f"Found {len(doc_files)} markdown documents to process.")

        all_docs = []
        for doc_path in doc_files:
            try:
                loaded_docs = UnstructuredMarkdownLoader(str(doc_path)).load()
            except Exception as e:
                # Best-effort: one unreadable file should not abort the whole build.
                self.logger.error(f"Error loading {doc_path}: {e}")
                continue
            # 'source' is needed later to derive a citation for each chunk.
            for doc in loaded_docs:
                doc.metadata.setdefault('source', str(doc_path))
            all_docs.extend(loaded_docs)

        if not all_docs:
            self.logger.error("Failed to load any documents. Vector store not created.")
            # Previously this returned silently and initialization then reported
            # success with an empty index; fail loudly instead.
            raise RuntimeError(f"Failed to load any markdown documents from {self.processed_docs_dir}")

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1024,
            chunk_overlap=128,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_documents(all_docs)
        self.logger.info(f"Created {len(chunks)} chunks from {len(all_docs)} documents.")

        embeddings, count = self.vector_store.create_embeddings(chunks)
        self.vector_store.build_index(embeddings)
        self.vector_store.save_vector_store()

        self.logger.info(f"Created vector store with {count} embeddings.")

    @staticmethod
    def _citation_for(doc) -> str:
        """Return a display citation for a retrieved chunk.

        Uses ``metadata['citation']`` when non-empty; otherwise the name of the
        directory containing ``metadata['source']``, falling back to the source
        string itself (e.g. ``'Unknown'``) when that name is empty.
        ``doc`` is presumably a ``SearchResult`` (needs ``.metadata``) — confirm.
        """
        citation = doc.metadata.get('citation')
        if citation:
            return citation
        source_path = doc.metadata.get('source', 'Unknown')
        return Path(source_path).parent.name or str(source_path)

    def query(self,
              query: str,
              history: Optional[List[Dict[str, str]]] = None,
              k: int = 15,
              top_n_rerank: int = 3,
              use_llm: bool = True) -> MedicalResponse:
        """Answer a question using retrieval, cross-encoder re-ranking and (optionally) the LLM.

        Args:
            query: The user's question.
            history: Prior chat messages (``{"role": ..., "content": ...}``)
                forwarded to the LLM between the system prompt and this query.
            k: Number of candidates fetched from the vector store.
            top_n_rerank: Number of candidates kept after re-ranking.
            use_llm: If False, return the assembled context instead of an LLM answer.

        Returns:
            A ``MedicalResponse``; ``sources`` lists unique citations in rank order.
        """
        start_time = time.time()

        # Stage 1: initial dense retrieval.
        docs = self.vector_store.search(query=query, k=k)
        if not docs:
            return self._create_no_results_response(query, start_time)

        # Stage 2: re-rank with the cross-encoder (higher score = more relevant).
        scores = self.reranker.predict([[query, doc.content] for doc in docs])
        ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)[:top_n_rerank]
        reranked_docs = [doc for doc, _ in ranked]
        reranked_scores = [float(score) for _, score in ranked]

        # One citation per chunk, shared by the context and the sources list so
        # the two can never disagree.
        citations = [self._citation_for(doc) for doc in reranked_docs]
        context = "\n\n---\n\n".join(
            f"[{i}] Citation: {citation}\n\nContent: {doc.content}"
            for i, (doc, citation) in enumerate(zip(reranked_docs, citations), 1)
        )

        confidence = self._calculate_confidence(reranked_scores, use_llm)

        # De-duplicate while keeping rank order (a set would scramble it).
        sources = list(dict.fromkeys(citations))

        if use_llm:
            return self._create_llm_response(self.SYSTEM_PROMPT, context, query, confidence,
                                             sources, start_time, history)

        # Without the LLM, the retrieved context itself is the answer.
        return MedicalResponse(
            answer=context,
            confidence=confidence,
            sources=sources,
            query_time=time.time() - start_time
        )

    def _create_llm_response(self, system_prompt: str, context: str, query: str,
                             confidence: float, sources: List[str], start_time: float,
                             history: Optional[List[Dict[str, str]]] = None) -> MedicalResponse:
        """Generate the final answer with the Groq chat model.

        Messages are: system prompt, then ``history`` (if any), then a user
        message containing the context and the question. On any API error the
        error is logged and a ``MedicalResponse`` with an apology answer and
        confidence 0 is returned instead of raising.
        """
        try:
            messages = [{"role": "system", "content": system_prompt}]
            if history:
                messages.extend(history)
            messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})

            chat_completion = self.groq_client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                temperature=0.7,
                max_tokens=2048,
                top_p=1,
                stream=False
            )
            # content may be None for an empty completion; keep answer a str.
            response_content = chat_completion.choices[0].message.content or ""

            return MedicalResponse(
                answer=response_content,
                confidence=confidence,
                sources=sources,
                query_time=time.time() - start_time,
            )
        except Exception as e:
            self.logger.exception(f"Error during Groq API call: {e}")
            return MedicalResponse(
                answer=f"Sorry, I encountered an error while generating the response: {e}",
                confidence=0.0,
                sources=sources,
                query_time=time.time() - start_time
            )

    def _calculate_confidence(self, scores: List[float], use_llm: bool) -> float:
        """Map re-ranker scores to a heuristic confidence in [0, 1].

        Applies a logistic sigmoid to the mean score; when ``use_llm`` is True
        the result is boosted by 20% and capped at 1.0 (an arbitrary heuristic,
        not a calibrated probability).
        """
        if not scores:
            return 0.0

        avg_score = sum(scores) / len(scores)
        confidence = float(1 / (1 + np.exp(-avg_score)))

        if use_llm:
            return min(confidence * 1.2, 1.0)
        return confidence

    def _create_no_results_response(self, query: str,
                                    start_time: Optional[float] = None) -> MedicalResponse:
        """Build the response returned when retrieval finds nothing.

        ``query_time`` is measured from ``start_time`` when given, else 0.0.
        """
        elapsed = time.time() - start_time if start_time is not None else 0.0
        return MedicalResponse(
            answer="No relevant documents found for your query. Please try rephrasing your question.",
            confidence=0.0,
            sources=[],
            query_time=elapsed
        )


def main(recreate_index: bool = False):
    """Initialize the RAG system and run an example query, or only rebuild the index.

    Args:
        recreate_index: If True, rebuild the vector store from scratch and exit
            without running the example query.

    Errors are printed rather than raised.
    """
    print("Initializing Groq Medical RAG system...")
    try:
        if recreate_index:
            print("Recreating index as requested...")
            # Rebuild during construction, instead of first loading (or building)
            # the index and then throwing it away.
            GroqMedicalRAG(force_recreate=True)
            print("✅ Index recreated successfully.")
            return

        rag_system = GroqMedicalRAG()
        print("✅ System initialized successfully.")

        print("\n--- Testing with an example query ---")
        query = "What is the management for puerperal sepsis?"
        print(f"Query: {query}")

        response = rag_system.query(query)

        print("\n--- Response ---")
        print(f"Answer: {response.answer}")
        print(f"Confidence: {response.confidence:.2f}")
        print(f"Sources: {response.sources}")
        print(f"Query Time: {response.query_time:.2f}s")
        print("--------------------\n")

    except Exception as e:
        print(f"An error occurred: {e}")


async def main_async(recreate_index: bool = False):
    """Placeholder async entry point; currently a no-op kept for compatibility."""
    pass


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Groq Medical RAG System CLI")
    parser.add_argument(
        "--recreate-index",
        action="store_true",
        help="If set, deletes the existing vector store and creates a new one."
    )
    args = parser.parse_args()
    main(recreate_index=args.recreate_index)