import json
from typing import List, Tuple

import spaces  # Hugging Face Spaces SDK (ZeroGPU helper); imported for the Spaces runtime
import gradio as gr
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline, ChatHuggingFace
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStoreRetriever
# Load SmolLM2 as a local text-generation pipeline and wrap it as a chat model.
model = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
    task="text-generation",
    pipeline_kwargs=dict(
        max_new_tokens=512,
        do_sample=False,          # greedy decoding for reproducible answers
        repetition_penalty=1.03,
        return_full_text=False,   # return only the newly generated tokens
    ),
)
llm = ChatHuggingFace(llm=model)
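
# Quick sanity check for the wrapped chat model (illustrative only; not run by the app):
#
#   print(llm.invoke([HumanMessage(content="Say hello.")]).content)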
def create_embeddings_model() -> HuggingFaceEmbeddings:
    """Build the BGE-M3 embedding model used for dense retrieval."""
    model_name = "BAAI/bge-m3"
    model_kwargs = {
        "device": "cpu",
        "trust_remote_code": True,
    }
    encode_kwargs = {"normalize_embeddings": True}  # unit-norm vectors, so inner product == cosine similarity
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
        show_progress=True,
    )

embeddings = create_embeddings_model()
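
# Illustrative use of the embedding model (not executed by the app); BGE-M3
# produces normalized 1024-dimensional dense vectors:
#
#   vec = embeddings.embed_query("autism early intervention")
#   print(len(vec))  # 1024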
def load_faiss_retriever(path: str) -> VectorStoreRetriever:
    """Load a persisted FAISS index and expose it as a top-10 dense retriever."""
    vector_store = FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
    return vector_store.as_retriever(search_kwargs={"k": 10})
def load_bm25_retriever(load_path: str) -> BM25Retriever:
    """Rebuild a BM25 (sparse, keyword-based) retriever from documents serialized as JSON."""
    with open(load_path, "r", encoding="utf-8") as f:
        docs_json = json.load(f)
    documents = [
        Document(page_content=doc["page_content"], metadata=doc["metadata"])
        for doc in docs_json
    ]
    return BM25Retriever.from_documents(documents)
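
# Assumed on-disk shape of bm25_index.json, inferred from the loader above
# (the metadata field names are hypothetical):
#
#   [
#     {"page_content": "Abstract text ...", "metadata": {"source": "...", "pmid": "..."}},
#     ...
#   ]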
class EmbeddingBM25RerankerRetriever:
    """Hybrid retriever: union of dense (FAISS) and sparse (BM25) hits, reranked by a cross-encoder."""

    def __init__(self, vector_retriever, bm25_retriever, reranker):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker

    def invoke(self, query: str) -> List[Document]:
        vector_docs = self.vector_retriever.invoke(query)
        bm25_docs = self.bm25_retriever.invoke(query)
        # Keep a BM25 hit only if the dense retriever has not already returned it.
        combined_docs = vector_docs + [doc for doc in bm25_docs if doc not in vector_docs]
        return self.reranker.compress_documents(combined_docs, query)
faiss_path = "VectorDB/faiss_index"
bm25_path = "VectorDB/bm25_index.json"

faiss_retriever = load_faiss_retriever(faiss_path)
bm25_retriever = load_bm25_retriever(bm25_path)

reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
reranker = CrossEncoderReranker(top_n=4, model=reranker_model)  # keep the 4 best documents

retriever = EmbeddingBM25RerankerRetriever(faiss_retriever, bm25_retriever, reranker)
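
# Illustrative standalone use of the hybrid retriever (not executed by the app):
#
#   for doc in retriever.invoke("sensory processing in autism"):
#       print(doc.page_content[:80])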
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", """You are an AI research assistant specializing in Autism research, powered by a retrieval system of curated PubMed documents.

Response Guidelines:
- Provide precise, evidence-based answers drawing directly from retrieved medical research
- Synthesize information from multiple documents when possible
- Clearly distinguish between established findings and emerging research
- Maintain scientific rigor and objectivity

Query Handling:
- Prioritize direct, informative responses
- When document evidence is incomplete, explain the current state of research
- Highlight areas where more research is needed
- Never introduce speculation or unsupported claims

Contextual Integrity:
- Ensure all statements are traceable to specific research documents
- Preserve the nuance and complexity of scientific findings
- Communicate with clarity, avoiding unnecessary medical jargon

Knowledge Limitations:
- If no relevant information is found, state: "Current research documents do not provide a comprehensive answer to this specific query."
"""),
    MessagesPlaceholder("chat_history"),
    ("human", "Context:\n{context}\n\nQuestion: {input}"),
])
def format_context(docs: List[Document]) -> str:
    """Number the retrieved documents so the model can refer to them as Doc 1, Doc 2, ..."""
    return "\n\n".join(f"Doc {i + 1}: {doc.page_content}" for i, doc in enumerate(docs))
def chat_with_rag(query: str, history: List[Tuple[str, str]]) -> str:
    """Answer a query with retrieval-augmented generation, carrying over the chat history."""
    # Convert Gradio's (user, assistant) tuples into LangChain chat messages.
    chat_history = []
    for human, ai in history:
        chat_history.append(HumanMessage(content=human))
        chat_history.append(AIMessage(content=ai))
    docs = retriever.invoke(query)
    context = format_context(docs)
    # format_messages preserves the system/history/human message structure;
    # plain .format() would flatten the template into a single string.
    messages = qa_prompt.format_messages(
        chat_history=chat_history,
        context=context,
        input=query,
    )
    response = llm.invoke(messages)
    return response.content
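
# Illustrative direct call with an empty chat history (not executed by the app):
#
#   print(chat_with_rag("What causes Autism?", []))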
chat_interface = gr.ChatInterface(
    fn=chat_with_rag,
    title="Autism RAG Chatbot",
    description="Ask questions about Autism.",
    examples=["What causes Autism?", "How is Autism treated?", "What is Autism?"],
)

if __name__ == "__main__":
    chat_interface.launch(share=True)