Spaces:
Sleeping
Sleeping
File size: 8,830 Bytes
44a2e1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
import os
import Stemmer
import requests
from utils import get_and_chunk_documents, llm, embed_model, get_index
from utils import Settings
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.settings import Settings
from llama_index.core import VectorStoreIndex
from llama_index.core.llms import ChatMessage
from llama_index.core.retrievers import QueryFusionRetriever
import json
# --- Module-level retriever initialization -------------------------------
# Wire the project's LLM and embedding model into llama_index's global
# Settings before building any retrievers.
Settings.llm = llm
Settings.embed_model = embed_model

index = get_index()

# Module-level singletons consumed by the functions below.  They stay None
# when initialization fails so callers can detect an unusable state.
hybrid_retriever = None
vector_retriever = None
bm25_retriever = None

if index:
    try:
        vector_retriever = index.as_retriever(similarity_top_k=15)
        print("β Vector retriever initialized successfully")

        all_nodes = index.docstore.docs
        if len(all_nodes) == 0:
            # An empty docstore would make BM25 construction fail anyway.
            print("β οΈ Warning: No documents found in index, skipping BM25 retriever")
            hybrid_retriever = vector_retriever
        else:
            # BM25 needs real text content; scan until one non-empty node is found.
            has_text_content = False
            for node_id, node in all_nodes.items():
                if hasattr(node, 'text') and node.text and node.text.strip():
                    has_text_content = True
                    break
            if not has_text_content:
                print("β οΈ Warning: No text content found in documents, skipping BM25 retriever")
                hybrid_retriever = vector_retriever
            else:
                try:
                    print("π Creating BM25 retriever...")
                    bm25_retriever = BM25Retriever.from_defaults(
                        docstore=index.docstore,
                        similarity_top_k=15,
                        verbose=False,
                    )
                    print("β BM25 retriever initialized successfully")
                    # Fuse vector and BM25 rankings via reciprocal-rank fusion.
                    # num_queries=1 disables LLM-generated query expansion.
                    hybrid_retriever = QueryFusionRetriever(
                        retrievers=[vector_retriever, bm25_retriever],
                        similarity_top_k=20,
                        num_queries=1,
                        mode="reciprocal_rerank",
                        use_async=False,
                    )
                    print("β Hybrid retriever initialized successfully")
                except Exception as e:
                    # BM25 is optional: degrade to vector-only retrieval.
                    print(f"β Warning: Could not initialize BM25 retriever: {e}")
                    print("π Falling back to vector-only retrieval")
                    hybrid_retriever = vector_retriever
    except Exception as e:
        # Total failure: reset everything so downstream code sees None.
        print(f"β Warning: Could not initialize retrievers: {e}")
        hybrid_retriever = None
        vector_retriever = None
        bm25_retriever = None
else:
    print("β Warning: Could not initialize retrievers - index is None")
def call_groq_api(prompt):
    """Send a prompt to the configured LLM (Groq API instead of LM Studio).

    Args:
        prompt (str): Fully rendered prompt text.

    Returns:
        str: The completion text, stringified from the LLM response object.

    Raises:
        Exception: Re-raises whatever the underlying client raised, after
            logging it. Bare ``raise`` (not ``raise e``) preserves the
            original traceback for upstream handlers.
    """
    try:
        response = Settings.llm.complete(prompt)
        return str(response)
    except Exception as e:
        print(f"β Groq API call failed: {e}")
        raise
def get_direct_answer(question, symptom_summary, conversation_context="", max_context_nodes=8, is_risk_assessment=True):
    """Answer a question via hybrid retrieval, reranking, and the Groq LLM.

    Pipeline: retrieve with the module-level ``hybrid_retriever`` ->
    cross-encoder rerank -> prefer pregnancy-related chunks -> cap context
    at ~6000 characters -> prompt the LLM.

    Args:
        question (str): User question; used as the retrieval/rerank query.
        symptom_summary (str): Reported symptoms, interpolated into the prompt.
        conversation_context (str): Prior dialogue; used only when
            ``is_risk_assessment`` is False.
        max_context_nodes (int): Max chunks kept after reranking/filtering.
        is_risk_assessment (bool): True selects the fixed-format risk
            assessment prompt; False selects the free-form Q&A prompt.

    Returns:
        str: The LLM response, or a human-readable error message.  This
        function never raises; every failure path returns a string.
    """
    print(f"π― Processing question: {question}")
    # A retriever must have been initialized at import time; otherwise bail out.
    if not hybrid_retriever:
        return "Error: Retriever not available. Please check if documents are properly loaded in the index."
    try:
        print("π Retrieving with available retrieval method...")
        retrieved_nodes = hybrid_retriever.retrieve(question)
        print(f"π Retrieved {len(retrieved_nodes)} nodes")
    except Exception as e:
        print(f"β Retrieval failed: {e}")
        return f"Error during document retrieval: {e}. Please check your document index."
    if not retrieved_nodes:
        return "No relevant documents found for this question. Please ensure your medical knowledge base is properly loaded and consult your healthcare provider for medical advice."
    try:
        # Cross-encoder rerank; on failure fall back to the first
        # max_context_nodes retrieved nodes in their original order.
        reranker = SentenceTransformerRerank(
            model='cross-encoder/ms-marco-MiniLM-L-2-v2',
            top_n=max_context_nodes,
        )
        reranked_nodes = reranker.postprocess_nodes(retrieved_nodes, query_str=question)
        print(f"π― After reranking: {len(reranked_nodes)} nodes")
    except Exception as e:
        print(f"β Reranking failed: {e}, using original nodes")
        reranked_nodes = retrieved_nodes[:max_context_nodes]
    # Domain filter: keep chunks mentioning pregnancy-related terms.  If no
    # chunk matches, keep the reranked set rather than returning nothing.
    filtered_nodes = []
    pregnancy_keywords = ['pregnancy', 'preeclampsia', 'gestational', 'trimester', 'fetal', 'bleeding', 'contractions', 'prenatal']
    for node in reranked_nodes:
        node_text = node.get_text().lower()
        if any(keyword in node_text for keyword in pregnancy_keywords):
            filtered_nodes.append(node)
    if filtered_nodes:
        reranked_nodes = filtered_nodes[:max_context_nodes]
        print(f"π After pregnancy keyword filtering: {len(reranked_nodes)} nodes")
    else:
        print("β οΈ No pregnancy-related content found, using original nodes")
    # Assemble context under a character budget; the final chunk may be
    # truncated, but only if at least 100 chars of budget remain.
    context_chunks = []
    total_chars = 0
    max_context_chars = 6000
    for node in reranked_nodes:
        node_text = node.get_text()
        if total_chars + len(node_text) <= max_context_chars:
            context_chunks.append(node_text)
            total_chars += len(node_text)
        else:
            remaining_chars = max_context_chars - total_chars
            if remaining_chars > 100:
                context_chunks.append(node_text[:remaining_chars] + "...")
            break
    context_text = "\n\n---\n\n".join(context_chunks)
    if is_risk_assessment:
        # NOTE(review): this prompt omits `question` and `conversation_context`
        # — presumably intentional for the fixed-format assessment; confirm
        # with callers.
        prompt = f"""You are the GraviLog Pregnancy Risk Assessment Agent. Use ONLY the context belowβdo not invent or add any new medical facts.
SYMPTOM RESPONSES:
{symptom_summary}
MEDICAL KNOWLEDGE:
{context_text}
Respond ONLY in this exact format (no extra text):
π₯ Risk Assessment Complete
**Risk Level:** <Low/Medium/High>
**Recommended Action:** <from KB's Risk Output Labels>
π¬ Rationale:
<One or two sentences citing which bullet(s) from the KB triggered your risk level.>"""
    else:
        prompt = f"""You are a pregnancy health assistant. Based on the medical knowledge below, answer the user's question about pregnancy symptoms and conditions.
USER QUESTION: {question}
CONVERSATION CONTEXT:
{conversation_context}
CURRENT SYMPTOMS REPORTED:
{symptom_summary}
MEDICAL KNOWLEDGE:
{context_text}
Provide a clear, informative answer based on the medical knowledge. Always mention if symptoms require medical attention and provide risk level (Low/Medium/High) when relevant."""
    try:
        print("π€ Generating response with Groq API...")
        response_text = call_groq_api(prompt)
        return response_text
    except Exception as e:
        print(f"β LLM response failed: {e}")
        import traceback
        traceback.print_exc()
        return f"Error generating response: {e}"
def get_answer_with_query_engine(question):
    """Answer a question through a LlamaIndex query engine.

    Builds a RetrieverQueryEngine on top of the module-level hybrid
    retriever when one exists, otherwise falls back to the index's
    default query engine.  All failures are reported as strings; this
    function never raises.
    """
    try:
        print(f"π― Processing question with query engine: {question}")
        if index is None:
            return "Error: Could not load index"

        if not hybrid_retriever:
            # No hybrid retriever available: plain vector query engine.
            engine = index.as_query_engine(
                similarity_top_k=10,
                response_mode="compact",
            )
        else:
            # Hybrid retrieval + cross-encoder rerank of the top hits.
            synthesizer = get_response_synthesizer(
                response_mode="compact",
                use_async=False,
            )
            rerank_step = SentenceTransformerRerank(
                model='cross-encoder/ms-marco-MiniLM-L-2-v2',
                top_n=5,
            )
            engine = RetrieverQueryEngine.from_args(
                retriever=hybrid_retriever,
                response_synthesizer=synthesizer,
                node_postprocessors=[rerank_step],
            )

        print("π€ Querying with engine...")
        return str(engine.query(question))
    except Exception as e:
        print(f"β Query engine failed: {e}")
        import traceback
        traceback.print_exc()
        return f"Error with query engine: {e}. Please check your setup and try again."