"""
Groq Medical RAG System v2.0
FREE Groq Cloud API integration for advanced medical reasoning
"""

import os
import time
import logging
import numpy as np
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from dotenv import load_dotenv
from pathlib import Path
import argparse
import shutil

# Langchain for document loading and splitting
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Sentence Transformers for re-ranking
from sentence_transformers import CrossEncoder

# Groq API integration
from groq import Groq
from tenacity import retry, stop_after_attempt, wait_fixed, before_sleep_log

# Load environment variables from .env file
load_dotenv()

# Import our simplified components
from .simple_vector_store import SimpleVectorStore, SearchResult

@dataclass
class MedicalResponse:
    """Enhanced medical response structure"""
    answer: str
    confidence: float
    sources: List[str]
    query_time: float

class GroqMedicalRAG:
    """Groq-powered Medical RAG System v2.0 - FREE LLM integration"""
    
    def __init__(self, 
                 vector_store_dir: str = "simple_vector_store",
                 processed_docs_dir: str = "src/processed_markdown",
                 groq_api_key: Optional[str] = None):
        """Initialize the Groq medical RAG system"""
        # Get the absolute path to the project root directory
        project_root = Path(__file__).parent.parent.resolve()

        self.vector_store_dir = project_root / vector_store_dir
        self.processed_docs_dir = project_root / processed_docs_dir
        
        # Initialize Groq client
        self.groq_api_key = groq_api_key or os.getenv("GROQ_API_KEY")
        if not self.groq_api_key:
            raise ValueError("GROQ_API_KEY environment variable not set. Get your free API key from https://console.groq.com/keys")
        
        self.groq_client = Groq(api_key=self.groq_api_key)
        self.model_name = "llama3-70b-8192"
        
        # Initialize Cross-Encoder for re-ranking
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        
        # Initialize components
        self.vector_store = None
        
        self.setup_logging()
        self._initialize_system()
    
    def setup_logging(self):
        """Setup logging for the RAG system"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)
    
    def _initialize_system(self, force_recreate: bool = False):
        """Initialize the RAG system components"""
        try:
            # If forcing recreation, delete the old vector store
            if force_recreate and self.vector_store_dir.exists():
                self.logger.warning(f"Recreating index as requested. Deleting {self.vector_store_dir}...")
                shutil.rmtree(self.vector_store_dir)

            # Initialize vector store
            self.vector_store = SimpleVectorStore(vector_store_dir=self.vector_store_dir)
            
            # Try to load existing vector store
            if not self.vector_store.load_vector_store():
                self.logger.info("Creating new vector store from documents...")
                self._create_vector_store()
            else:
                self.logger.info("Loaded existing vector store")
            
            # Test Groq connection
            self._test_groq_connection()
                
            self.logger.info("Groq Medical RAG system initialized successfully")
            
        except Exception as e:
            self.logger.error(f"Error initializing RAG system: {e}")
            raise
    
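    # tenacity: retry the connection test up to 3 attempts, waiting 2 seconds
    # between tries, and log via before_sleep_log before each retry.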
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_fixed(2),
        before_sleep=before_sleep_log(logging.getLogger(__name__), logging.INFO)
    )
    def _test_groq_connection(self):
        """Test Groq API connection with retry logic."""
        try:
            self.groq_client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": "Test"}],
                max_tokens=10,
            )
            self.logger.info("✅ Groq API connection successful")
        except Exception as e:
            self.logger.error(f"❌ Groq API connection failed: {e}")
            raise
    
    def _create_vector_store(self):
        """Create vector store from processed markdown documents."""
        self.logger.info(f"Checking for documents in {self.processed_docs_dir}...")
        
        doc_files = list(self.processed_docs_dir.glob("**/*.md"))
        if not doc_files:
            self.logger.error(f"No markdown files found in {self.processed_docs_dir}. Please run the enhanced_pdf_processor.py script first.")
            raise FileNotFoundError(f"No markdown files found in {self.processed_docs_dir}")

        self.logger.info(f"Found {len(doc_files)} markdown documents to process.")

        # Load documents using UnstructuredMarkdownLoader
        all_docs = []
        for doc_path in doc_files:
            try:
                loader = UnstructuredMarkdownLoader(str(doc_path))
                loaded_docs = loader.load()
                
                # Ensure 'source' metadata is present so citations can fall back
                # to the originating file path.
                for doc in loaded_docs:
                    if 'source' not in doc.metadata:
                        doc.metadata['source'] = str(doc_path)
                
                all_docs.extend(loaded_docs)
            except Exception as e:
                self.logger.error(f"Error loading {doc_path}: {e}")
        
        if not all_docs:
            self.logger.error("Failed to load any documents. Vector store not created.")
            return

        # Split documents into chunks with smaller size and overlap
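        # (a 1024-char chunk is roughly 200-300 English tokens, comfortably
        # within the ms-marco-MiniLM cross-encoder's 512-token input window)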
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1024,  # Reduced from 2048
            chunk_overlap=128,  # Reduced from 256
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_documents(all_docs)
        
        self.logger.info(f"Created {len(chunks)} chunks from {len(all_docs)} documents.")

        # Create embeddings and build index
        embeddings, count = self.vector_store.create_embeddings(chunks)
        self.vector_store.build_index(embeddings)
        self.vector_store.save_vector_store()
        
        self.logger.info(f"Created vector store with {count} embeddings.")
    
    def query(self, 
              query: str, 
              history: Optional[List[Dict[str, str]]] = None,
              k: int = 15,  # Reduced from 30
              top_n_rerank: int = 3,  # Reduced from 5
              use_llm: bool = True) -> MedicalResponse:
        """Query the Groq medical RAG system with re-ranking."""
        start_time = time.time()
        
        # Stage 1: Initial retrieval from vector store
        docs = self.vector_store.search(query=query, k=k)
        
        if not docs:
            return self._create_no_results_response(query, start_time)

        # Stage 2: Re-ranking with Cross-Encoder
        sentence_pairs = [[query, doc.content] for doc in docs]
        scores = self.reranker.predict(sentence_pairs)
        
        # Combine docs with scores and sort
        doc_score_pairs = list(zip(docs, scores))
        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
        
        # Select top N results after re-ranking
        reranked_docs = [pair[0] for pair in doc_score_pairs[:top_n_rerank]]
        reranked_scores = [pair[1] for pair in doc_score_pairs[:top_n_rerank]]

        # Prepare context with rich metadata for the LLM
        context_parts = []
        for i, doc in enumerate(reranked_docs, 1):
            citation = doc.metadata.get('citation')
            if not citation:
                source_path = doc.metadata.get('source', 'Unknown')
                citation = Path(source_path).parent.name
            
            # Add reference number to citation
            context_parts.append(f"[{i}] Citation: {citation}\n\nContent: {doc.content}")
        context = "\n\n---\n\n".join(context_parts)
        
        confidence = self._calculate_confidence(reranked_scores, use_llm)
        
        # Use a set to get unique citations for display
        sources = list({
            doc.metadata.get('citation', Path(doc.metadata.get('source', 'Unknown')).parent.name)
            for doc in reranked_docs
        })

        if use_llm:
            # Phase 4: Persona-driven, structured response generation
            system_prompt = (
                "You are 'VedaMD', a world-class medical expert and a compassionate assistant for healthcare professionals in Sri Lanka. "
                "Your primary goal is to provide accurate, evidence-based clinical information based ONLY on the provided context, which is sourced from official Sri Lankan maternal health guidelines. "
                "Your tone should be professional, clear, and supportive.\\n\\n"
                "**CRITICAL INSTRUCTIONS:**\\n"
                "1.  **Strictly Context-Bound:** Your answer MUST be based exclusively on the 'Content' provided for each source. Do not use any external knowledge or provide information not present in the context.\\n"
                "2.  **Markdown Formatting:** Structure your answers for maximum clarity. Use markdown for formatting:\\n"
                "    - Use headings (`##`) for main topics.\\n"
                "    - Use bullet points (`-` or `*`) for lists of symptoms, recommendations, or steps.\\n"
                "    - Use bold (`**text**`) to emphasize key terms, dosages, or critical warnings.\\n"
                "3.  **Synthesize, Don't Just Copy:** Read all context pieces, synthesize the information, and provide a comprehensive answer. Do not repeat information.\\n"
                "4.  **Scientific Citations:** Use numbered citations [1], [2], etc. in your answer text to reference specific information. At the end, list all sources under a 'References:' heading in scientific format:\\n"
                "    [1] Title of Guideline/Document\\n"
                "    [2] Title of Another Guideline/Document\\n"
                "5.  **Disclaimer:** At the end of EVERY response, include the following disclaimer: '_This information is for clinical reference based on Sri Lankan guidelines and does not replace professional medical judgment._'"
            )
            return self._create_llm_response(system_prompt, context, query, confidence, sources, start_time, history)

        else:
            # If not using LLM, return context directly
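            # (useful for inspecting retrieval quality without an API call)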
            return MedicalResponse(
                answer=context,
                confidence=confidence,
                sources=sources,
                query_time=time.time() - start_time
            )

    def _create_llm_response(self, system_prompt: str, context: str, query: str, confidence: float, sources: List[str], start_time: float, history: Optional[List[Dict[str, str]]] = None) -> MedicalResponse:
        """Helper to generate response from LLM."""
        try:
            messages = [
                {
                    "role": "system",
                    "content": system_prompt,
                }
            ]
            
            # Add conversation history to the messages
            if history:
                messages.extend(history)
            
            # Add the current query
            messages.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"})
            
            chat_completion = self.groq_client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                temperature=0.7,
                max_tokens=2048,
                top_p=1,
                stream=False
            )
            
            response_content = chat_completion.choices[0].message.content
            
            return MedicalResponse(
                answer=response_content,
                confidence=confidence,
                sources=sources,
                query_time=time.time() - start_time,
            )
        except Exception as e:
            self.logger.error(f"Error during Groq API call: {e}")
            return MedicalResponse(
                answer=f"Sorry, I encountered an error while generating the response: {e}",
                confidence=0,
                sources=sources,
                query_time=time.time() - start_time
            )

    def _calculate_confidence(self, scores: List[float], use_llm: bool) -> float:
        """
        Calculate confidence score based on re-ranked results.
        For LLM responses, we can be more optimistic.
        """
        if not scores:
            return 0.0

        # Simple average of scores, scaled
        avg_score = sum(scores) / len(scores)
        
        # Sigmoid-like scaling for better confidence representation
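        # e.g. an average raw score of 2.0 maps to 1 / (1 + e^-2) ≈ 0.88,
        # while a score of 0.0 maps to 0.5.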
        confidence = 1 / (1 + np.exp(-avg_score))
        
        if use_llm:
            return min(confidence * 1.2, 1.0)  # Boost confidence for LLM
        return confidence
    
    def _create_no_results_response(self, query: str, start_time: float) -> MedicalResponse:
        """Helper for the no-results response."""
        return MedicalResponse(
            answer="No relevant documents found for your query. Please try rephrasing your question.",
            confidence=0.0,
            sources=[],
            query_time=time.time() - start_time
        )

def main(recreate_index: bool = False):
    """Main function to initialize and test the RAG system."""
    print("Initializing Groq Medical RAG system...")
    try:
        rag_system = GroqMedicalRAG()
        
        if recreate_index:
            print("Recreating index as requested...")
            # Re-initialize with force_recreate=True. Note: the constructor above
            # already loaded or built an index, so this deletes and rebuilds it.
            rag_system._initialize_system(force_recreate=True)
            print("✅ Index recreated successfully.")
            return # Exit after recreating index
            
        print("✅ System initialized successfully.")
        
        # Example query for testing
        print("\\n--- Testing with an example query ---")
        query = "What is the management for puerperal sepsis?"
        print(f"Query: {query}")
        
        response = rag_system.query(query)
        
        print("\\n--- Response ---")
        print(f"Answer: {response.answer}")
        print(f"Confidence: {response.confidence:.2f}")
        print(f"Sources: {response.sources}")
        print(f"Query Time: {response.query_time:.2f}s")
        print("--------------------\\n")

    except Exception as e:
        print(f"An error occurred: {e}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Groq Medical RAG System CLI")
    parser.add_argument(
        "--recreate-index",
        action="store_true",
        help="If set, deletes the existing vector store and creates a new one."
    )
    args = parser.parse_args()
    
    main(recreate_index=args.recreate_index)
