"""
Groq Medical RAG System v2.0
FREE Groq Cloud API integration for advanced medical reasoning
"""
import os
import time
import logging
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from dotenv import load_dotenv
from pathlib import Path
import argparse
import shutil
import re
# Langchain for document loading and splitting
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Sentence Transformers for re-ranking
from sentence_transformers import CrossEncoder
# Groq API integration
from groq import Groq
from tenacity import retry, stop_after_attempt, wait_fixed, before_sleep_log
# Load environment variables from .env file
load_dotenv()
# Import our simplified components
from .simple_vector_store import SimpleVectorStore, SearchResult
@dataclass
class MedicalResponse:
"""Enhanced medical response structure"""
answer: str
confidence: float
sources: List[str]
query_time: float
class GroqMedicalRAG:
"""Groq-powered Medical RAG System v2.0 - FREE LLM integration"""
def __init__(self,
vector_store_dir: str = "simple_vector_store",
processed_docs_dir: str = "src/processed_markdown",
groq_api_key: Optional[str] = None):
"""Initialize the Groq medical RAG system"""
# Get the absolute path to the project root directory
project_root = Path(__file__).parent.parent.resolve()
self.vector_store_dir = project_root / vector_store_dir
self.processed_docs_dir = project_root / processed_docs_dir
# Initialize Groq client
self.groq_api_key = groq_api_key or os.getenv("GROQ_API_KEY")
if not self.groq_api_key:
raise ValueError("GROQ_API_KEY environment variable not set. Get your free API key from https://console.groq.com/keys")
self.groq_client = Groq(api_key=self.groq_api_key)
self.model_name = "llama3-70b-8192"
# Initialize Cross-Encoder for re-ranking
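        # The cross-encoder scores each (query, passage) pair jointly, which is
        # slower than bi-encoder retrieval but more accurate, so it is applied
        # only to the small candidate set returned by the vector store.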
self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# Initialize components
self.vector_store = None
self.setup_logging()
self._initialize_system()
def setup_logging(self):
"""Setup logging for the RAG system"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
def _initialize_system(self, force_recreate: bool = False):
"""Initialize the RAG system components"""
try:
# If forcing recreation, delete the old vector store
if force_recreate and self.vector_store_dir.exists():
self.logger.warning(f"Recreating index as requested. Deleting {self.vector_store_dir}...")
shutil.rmtree(self.vector_store_dir)
# Initialize vector store
self.vector_store = SimpleVectorStore(vector_store_dir=self.vector_store_dir)
# Try to load existing vector store
if not self.vector_store.load_vector_store():
self.logger.info("Creating new vector store from documents...")
self._create_vector_store()
else:
self.logger.info("Loaded existing vector store")
# Test Groq connection
self._test_groq_connection()
self.logger.info("Groq Medical RAG system initialized successfully")
except Exception as e:
self.logger.error(f"Error initializing RAG system: {e}")
raise
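    # Retry the connection test up to 3 times with a fixed 2-second pause between
    # attempts; transient failures are common with hosted LLM endpoints.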
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(2),
before_sleep=before_sleep_log(logging.getLogger(__name__), logging.INFO)
)
def _test_groq_connection(self):
"""Test Groq API connection with retry logic."""
try:
self.groq_client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": "Test"}],
max_tokens=10,
)
self.logger.info("✅ Groq API connection successful")
except Exception as e:
self.logger.error(f"❌ Groq API connection failed: {e}")
raise
def _create_vector_store(self):
"""Create vector store from processed markdown documents."""
self.logger.info(f"Checking for documents in {self.processed_docs_dir}...")
doc_files = list(self.processed_docs_dir.glob("**/*.md"))
if not doc_files:
self.logger.error(f"No markdown files found in {self.processed_docs_dir}. Please run the enhanced_pdf_processor.py script first.")
raise FileNotFoundError(f"No markdown files found in {self.processed_docs_dir}")
self.logger.info(f"Found {len(doc_files)} markdown documents to process.")
# Load documents using UnstructuredMarkdownLoader
all_docs = []
for doc_path in doc_files:
try:
loader = UnstructuredMarkdownLoader(str(doc_path))
loaded_docs = loader.load()
                # Ensure each document carries 'source' metadata, which is needed later to build citations.
for doc in loaded_docs:
if 'source' not in doc.metadata:
doc.metadata['source'] = str(doc_path)
all_docs.extend(loaded_docs)
except Exception as e:
self.logger.error(f"Error loading {doc_path}: {e}")
if not all_docs:
self.logger.error("Failed to load any documents. Vector store not created.")
return
# Split documents into chunks with smaller size and overlap
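        # 1024-character chunks with a 128-character overlap keep each chunk well
        # within the input windows of the embedding and re-ranking models, while
        # the overlap preserves context across chunk boundaries.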
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1024, # Reduced from 2048
chunk_overlap=128, # Reduced from 256
separators=["\n\n", "\n", " ", ""]
)
chunks = text_splitter.split_documents(all_docs)
self.logger.info(f"Created {len(chunks)} chunks from {len(all_docs)} documents.")
# Create embeddings and build index
embeddings, count = self.vector_store.create_embeddings(chunks)
self.vector_store.build_index(embeddings)
self.vector_store.save_vector_store()
self.logger.info(f"Created vector store with {count} embeddings.")
def query(self,
query: str,
history: Optional[List[Dict[str, str]]] = None,
k: int = 15, # Reduced from 30
top_n_rerank: int = 3, # Reduced from 5
use_llm: bool = True) -> MedicalResponse:
"""Query the Groq medical RAG system with re-ranking."""
start_time = time.time()
# Stage 1: Initial retrieval from vector store
docs = self.vector_store.search(query=query, k=k)
if not docs:
return self._create_no_results_response(query)
# Stage 2: Re-ranking with Cross-Encoder
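        # predict() returns one raw relevance score per (query, passage) pair;
        # higher means more relevant. The scores drive both the final ranking
        # and the confidence estimate below.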
sentence_pairs = [[query, doc.content] for doc in docs]
scores = self.reranker.predict(sentence_pairs)
# Combine docs with scores and sort
doc_score_pairs = list(zip(docs, scores))
doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
# Select top N results after re-ranking
reranked_docs = [pair[0] for pair in doc_score_pairs[:top_n_rerank]]
reranked_scores = [pair[1] for pair in doc_score_pairs[:top_n_rerank]]
# Prepare context with rich metadata for the LLM
context_parts = []
for i, doc in enumerate(reranked_docs, 1):
citation = doc.metadata.get('citation')
if not citation:
source_path = doc.metadata.get('source', 'Unknown')
citation = Path(source_path).parent.name
# Add reference number to citation
context_parts.append(f"[{i}] Citation: {citation}\\n\\nContent: {doc.content}")
context = "\\n\\n---\\n\\n".join(context_parts)
confidence = self._calculate_confidence(reranked_scores, use_llm)
# Use a set to get unique citations for display
sources = list(set([
doc.metadata.get('citation', Path(doc.metadata.get('source', 'Unknown')).parent.name)
for doc in reranked_docs
]))
if use_llm:
# Phase 4: Persona-driven, structured response generation
            system_prompt = (
                "You are 'VedaMD', a world-class medical expert and a compassionate assistant for healthcare professionals in Sri Lanka. "
                "Your primary goal is to provide accurate, evidence-based clinical information based ONLY on the provided context, which is sourced from official Sri Lankan maternal health guidelines. "
                "Your tone should be professional, clear, and supportive.\n\n"
                "**CRITICAL INSTRUCTIONS:**\n"
                "1. **Strictly Context-Bound:** Your answer MUST be based exclusively on the 'Content' provided for each source. Do not use any external knowledge or provide information not present in the context.\n"
                "2. **Markdown Formatting:** Structure your answers for maximum clarity. Use markdown for formatting:\n"
                " - Use headings (`##`) for main topics.\n"
                " - Use bullet points (`-` or `*`) for lists of symptoms, recommendations, or steps.\n"
                " - Use bold (`**text**`) to emphasize key terms, dosages, or critical warnings.\n"
                "3. **Synthesize, Don't Just Copy:** Read all context pieces, synthesize the information, and provide a comprehensive answer. Do not repeat information.\n"
                "4. **Scientific Citations:** Use numbered citations [1], [2], etc. in your answer text to reference specific information. At the end, list all sources under a 'References:' heading in scientific format:\n"
                " [1] Title of Guideline/Document\n"
                " [2] Title of Another Guideline/Document\n"
                "5. **Disclaimer:** At the end of EVERY response, include the following disclaimer: '_This information is for clinical reference based on Sri Lankan guidelines and does not replace professional medical judgment._'"
            )
return self._create_llm_response(system_prompt, context, query, confidence, sources, start_time, history)
else:
# If not using LLM, return context directly
return MedicalResponse(
answer=context,
confidence=confidence,
sources=sources,
query_time=time.time() - start_time
)
def _create_llm_response(self, system_prompt: str, context: str, query: str, confidence: float, sources: List[str], start_time: float, history: Optional[List[Dict[str, str]]] = None) -> MedicalResponse:
"""Helper to generate response from LLM."""
try:
messages = [
{
"role": "system",
"content": system_prompt,
}
]
# Add conversation history to the messages
if history:
messages.extend(history)
# Add the current query
messages.append({"role": "user", "content": f"Context:\\n{context}\\n\\nQuestion: {query}"})
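            # temperature=0.7 trades determinism for more natural phrasing;
            # lower values would give stricter, more reproducible answers.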
chat_completion = self.groq_client.chat.completions.create(
messages=messages,
model=self.model_name,
temperature=0.7,
max_tokens=2048,
top_p=1,
stream=False
)
response_content = chat_completion.choices[0].message.content
return MedicalResponse(
answer=response_content,
confidence=confidence,
sources=sources,
query_time=time.time() - start_time,
)
except Exception as e:
self.logger.error(f"Error during Groq API call: {e}")
return MedicalResponse(
answer=f"Sorry, I encountered an error while generating the response: {e}",
confidence=0,
sources=sources,
query_time=time.time() - start_time
)
def _calculate_confidence(self, scores: List[float], use_llm: bool) -> float:
"""
Calculate confidence score based on re-ranked results.
For LLM responses, we can be more optimistic.
"""
if not scores:
return 0.0
# Simple average of scores, scaled
avg_score = sum(scores) / len(scores)
# Sigmoid-like scaling for better confidence representation
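        # (Cross-encoder scores are unbounded logits, so the logistic function
        # squashes the average into the (0, 1) range expected of a confidence.)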
confidence = 1 / (1 + np.exp(-avg_score))
if use_llm:
return min(confidence * 1.2, 1.0) # Boost confidence for LLM
return confidence
def _create_no_results_response(self, query: str) -> MedicalResponse:
"""Helper for no results response"""
return MedicalResponse(
answer="No relevant documents found for your query. Please try rephrasing your question.",
confidence=0,
sources=[],
query_time=0
)
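# CLI usage (the module name below is hypothetical; substitute this file's actual
# import path):
#   python -m src.groq_medical_rag --recreate-index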
def main(recreate_index: bool = False):
"""Main function to initialize and test the RAG system."""
print("Initializing Groq Medical RAG system...")
try:
rag_system = GroqMedicalRAG()
if recreate_index:
print("Recreating index as requested...")
# Re-initialize with force_recreate=True
rag_system._initialize_system(force_recreate=True)
print("✅ Index recreated successfully.")
return # Exit after recreating index
print("✅ System initialized successfully.")
# Example query for testing
print("\\n--- Testing with an example query ---")
query = "What is the management for puerperal sepsis?"
print(f"Query: {query}")
response = rag_system.query(query)
print("\\n--- Response ---")
print(f"Answer: {response.answer}")
print(f"Confidence: {response.confidence:.2f}")
print(f"Sources: {response.sources}")
print(f"Query Time: {response.query_time:.2f}s")
print("--------------------\\n")
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Groq Medical RAG System CLI")
parser.add_argument(
"--recreate-index",
action="store_true",
help="If set, deletes the existing vector store and creates a new one."
)
args = parser.parse_args()
main(recreate_index=args.recreate_index)