khalednabawi11 committed on
Commit bc676ba · verified · 1 Parent(s): f805e9e

Create app.py

Files changed (1)
  1. app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
+ import os
+ import asyncio
+ import logging
+ import signal
+
+ from fastapi import FastAPI, HTTPException, status
+ from pydantic import BaseModel, Field
+ from langdetect import detect
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
+ from langchain.vectorstores import Qdrant
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.chains import RetrievalQA
+ from langchain.llms import HuggingFacePipeline
+ from langchain.callbacks.base import BaseCallbackHandler
+ from qdrant_client import QdrantClient
+
+ # Module-level logger, used in the routes below
+ logger = logging.getLogger(__name__)
+
+ # Get environment variables
+ QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+ QDRANT_URL = os.getenv("QDRANT_URL")
+ COLLECTION_NAME = "arabic_rag_collection"
+
+ # Load model and tokenizer
+ model_name = "FreedomIntelligence/Apollo-7B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ tokenizer.pad_token = tokenizer.eos_token
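+ # NOTE: mapping the pad token to EOS assumes the base tokenizer defines no
+ # dedicated padding token.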
+
+ # Generation settings
+ generation_config = GenerationConfig(
+     max_new_tokens=150,
+     temperature=0.2,
+     top_k=20,
+     do_sample=True,
+     top_p=0.7,
+     repetition_penalty=1.3,
+ )
+
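+ # NOTE: the low temperature/top_p combined with a repetition penalty bias the
+ # model toward short, focused, non-repetitive answers; the exact values above
+ # are tuning choices.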
+ # Text generation pipeline
+ llm_pipeline = pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     task="text-generation",
+     generation_config=generation_config,
+     device=model.device.index if model.device.type == "cuda" else -1
+ )
+ llm = HuggingFacePipeline(pipeline=llm_pipeline)
+
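+ # NOTE: device resolves to the CUDA device index when the model sits on GPU
+ # and to -1 (CPU) otherwise, following the transformers pipeline convention.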
+ # Connect to Qdrant + embedding
+ embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
+ qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+
+ vector_store = Qdrant(
+     client=qdrant_client,
+     collection_name=COLLECTION_NAME,
+     embeddings=embedding
+ )
+
+ retriever = vector_store.as_retriever(search_kwargs={"k": 3})
+
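+ # NOTE: k=3 limits retrieval to the three nearest chunks in the collection.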
+ # Set up RAG QA chain
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     retriever=retriever,
+     chain_type="stuff"
+ )
+
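+ # NOTE: the "stuff" chain type concatenates all retrieved documents into a
+ # single prompt before calling the LLM.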
+ # FastAPI setup
+ app = FastAPI(title="Apollo RAG Medical Chatbot")
+
+
+ class Query(BaseModel):
+     question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
+
+ class TimeoutCallback(BaseCallbackHandler):
+     def __init__(self, timeout_seconds: int = 60):
+         self.timeout_seconds = timeout_seconds
+         self.start_time = None
+
+     async def on_llm_start(self, *args, **kwargs):
+         self.start_time = asyncio.get_event_loop().time()
+
+     async def on_llm_new_token(self, *args, **kwargs):
+         if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
+             raise TimeoutError("LLM processing timeout")
+
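+ # NOTE: TimeoutCallback is currently unused; the qa_chain.run(..., callbacks=...)
+ # call that would exercise it is commented out in the /ask route below.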
+ # Prompt template
+ def generate_prompt(question: str) -> str:
+     lang = detect(question)
+     if lang == "ar":
+         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
+ وتأكد من ان:
+ - عدم تكرار أي نقطة أو عبارة أو كلمة
+ - وضوح وسلاسة كل نقطة
+ - تجنب الحشو والعبارات الزائدة
+ السؤال: {question}
+ الإجابة:"""
+     else:
+         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If the context lacks information, rely on prior medical knowledge.
+ Question: {question}
+ Answer:"""
+
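+ # The Arabic branch above asks for a precise, detailed answer in Modern Standard
+ # Arabic, allows falling back on prior medical knowledge when the context is
+ # insufficient, and instructs the model to avoid repeating any point, keep every
+ # point clear and fluent, and skip filler phrases.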
+ # Input schema
+ # class ChatRequest(BaseModel):
+ #     message: str
+
+ # # Output endpoint
+ # @app.post("/chat")
+ # def chat_rag(req: ChatRequest):
+ #     prompt = generate_prompt(req.message)
+ #     response = qa_chain.run(prompt)
+ #     return {"response": response}
+
+
+ # === ROUTES === #
+ @app.get("/")
+ async def root():
+     return {"message": "Medical QA API is running!"}
+
+ @app.post("/ask")
+ async def ask(query: Query):
+     try:
+         logger.debug(f"Received question: {query.question}")
+         prompt = generate_prompt(query.question)
+         timeout_callback = TimeoutCallback(timeout_seconds=60)
+
+         loop = asyncio.get_event_loop()
+
+         answer = await asyncio.wait_for(
+             # qa_chain.run(prompt, callbacks=[timeout_callback]),
+             loop.run_in_executor(None, qa_chain.run, prompt),
+             timeout=360
+         )
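+         # NOTE: qa_chain.run is synchronous, so it runs in the default
+         # thread-pool executor, bounded by an outer 360-second timeout.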
+
+         if not answer:
+             raise ValueError("Empty answer returned from model")
+
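+         # The raw generation typically echoes the prompt, so keep only the
+         # text after the final "Answer:" / "الإجابة:" marker.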
+         if 'Answer:' in answer:
+             response_text = answer.split('Answer:')[-1].strip()
+         elif 'الإجابة:' in answer:
+             response_text = answer.split('الإجابة:')[-1].strip()
+         else:
+             response_text = answer.strip()
+
+         return {
+             "status": "success",
+             "response": response_text,
+             "language": detect(query.question)
+         }
+
+     except TimeoutError as te:
+         logger.error("Request timed out", exc_info=True)
+         raise HTTPException(
+             status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+             detail={"status": "error", "message": "Request timed out", "error": str(te)}
+         )
+
+     except Exception as e:
+         logger.error(f"Unexpected error: {e}", exc_info=True)
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail={"status": "error", "message": "Internal server error", "error": str(e)}
+         )
+
+ # === ENTRYPOINT === #
+ if __name__ == "__main__":
+     def handle_exit(signum, frame):
+         print("Shutting down gracefully...")
+         exit(0)
+
+     signal.signal(signal.SIGINT, handle_exit)
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
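+
+ # Example request once the server is up (using the default host/port above):
+ #   curl -X POST http://localhost:8000/ask \
+ #     -H "Content-Type: application/json" \
+ #     -d '{"question": "ما هي اسباب تساقط الشعر ؟"}'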