import os
import asyncio
import signal
import logging

from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, Field
from langdetect import detect
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
from langchain.vectorstores import Qdrant
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from langchain.callbacks.base import BaseCallbackHandler
from qdrant_client import QdrantClient

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Get environment variables
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL")
COLLECTION_NAME = "arabic_rag_collection"
# Load model and tokenizer
model_name = "FreedomIntelligence/Apollo-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Generation settings
generation_config = GenerationConfig(
    max_new_tokens=150,
    temperature=0.2,
    top_k=20,
    do_sample=True,
    top_p=0.7,
    repetition_penalty=1.3,
)
# Text generation pipeline
llm_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    generation_config=generation_config,
    device=model.device.index if model.device.type == "cuda" else -1,
)
llm = HuggingFacePipeline(pipeline=llm_pipeline)
# Connect to Qdrant + embedding
embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
vector_store = Qdrant(
    client=qdrant_client,
    collection_name=COLLECTION_NAME,
    embeddings=embedding,
)
retriever = vector_store.as_retriever(search_kwargs={"k": 3})
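
# The code above assumes the "arabic_rag_collection" collection already exists in
# Qdrant. A minimal, hypothetical indexing sketch (run once, offline) could look
# like the commented example below; the sample texts are placeholders, not the
# real corpus.
#
# docs = ["نص طبي ١", "نص طبي ٢"]  # replace with the actual document chunks
# Qdrant.from_texts(
#     docs,
#     embedding,
#     url=QDRANT_URL,
#     api_key=QDRANT_API_KEY,
#     collection_name=COLLECTION_NAME,
# )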
# Set up RAG QA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
)
# FastAPI setup
app = FastAPI(title="Apollo RAG Medical Chatbot")

class Query(BaseModel):
    question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
class TimeoutCallback(BaseCallbackHandler):
    def __init__(self, timeout_seconds: int = 60):
        self.timeout_seconds = timeout_seconds
        self.start_time = None

    async def on_llm_start(self, *args, **kwargs):
        self.start_time = asyncio.get_event_loop().time()

    async def on_llm_new_token(self, *args, **kwargs):
        if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
            raise TimeoutError("LLM processing timeout")
# Prompt template (the Arabic branch asks for a precise, detailed answer in
# Modern Standard Arabic without repetition or filler, falling back to prior
# medical knowledge when the retrieved context is insufficient)
def generate_prompt(question: str) -> str:
    lang = detect(question)
    if lang == "ar":
        return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
وتأكد من ان:
- عدم تكرار أي نقطة أو عبارة أو كلمة
- وضوح وسلاسة كل نقطة
- تجنب الحشو والعبارات الزائدة

السؤال: {question}
الإجابة:"""
    else:
        return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If the context lacks information, rely on prior medical knowledge.

Question: {question}
Answer:"""
# Input schema
# class ChatRequest(BaseModel):
#     message: str

# # Output endpoint
# @app.post("/chat")
# def chat_rag(req: ChatRequest):
#     prompt = generate_prompt(req.message)
#     response = qa_chain.run(prompt)
#     return {"response": response}
# === ROUTES === #
@app.get("/")
async def root():
    return {"message": "Medical QA API is running!"}
@app.post("/ask")
async def ask(query: Query):
    try:
        logger.debug(f"Received question: {query.question}")
        prompt = generate_prompt(query.question)
        timeout_callback = TimeoutCallback(timeout_seconds=60)
        loop = asyncio.get_event_loop()
        answer = await asyncio.wait_for(
            # qa_chain.run(prompt, callbacks=[timeout_callback]),
            loop.run_in_executor(None, qa_chain.run, prompt),
            timeout=360,
        )

        if not answer:
            raise ValueError("Empty answer returned from model")

        # Strip the prompt echo and keep only the generated answer
        if 'Answer:' in answer:
            response_text = answer.split('Answer:')[-1].strip()
        elif 'الإجابة:' in answer:
            response_text = answer.split('الإجابة:')[-1].strip()
        else:
            response_text = answer.strip()

        return {
            "status": "success",
            "response": response_text,
            "language": detect(query.question)
        }
    except TimeoutError as te:
        logger.error("Request timed out", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail={"status": "error", "message": "Request timed out", "error": str(te)},
        )
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"status": "error", "message": "Internal server error", "error": str(e)},
        )
# === ENTRYPOINT === #
if __name__ == "__main__":
    def handle_exit(signum, frame):
        print("Shutting down gracefully...")
        exit(0)

    signal.signal(signal.SIGINT, handle_exit)

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
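
# Example client call (a minimal sketch; assumes the server runs locally on
# port 8000 and uses the /ask route defined above):
#
# import requests
# resp = requests.post(
#     "http://localhost:8000/ask",
#     json={"question": "ما هي اسباب تساقط الشعر ؟"},
# )
# print(resp.json()["response"])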