khalednabawi11 committed on
Commit d940f83 · verified · 1 Parent(s): ae2daab

Update app.py

Files changed (1):
  1. app.py +256 -174

app.py CHANGED
@@ -1,45 +1,242 @@
  import torch
- import asyncio
- import logging
- import signal
- import uvicorn
- import os

  from fastapi import FastAPI, Request, HTTPException, status
  from pydantic import BaseModel, Field
- from langdetect import detect
-
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
- from langchain.vectorstores import Qdrant
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.chains import RetrievalQA
- from langchain.llms import HuggingFacePipeline
- from qdrant_client import QdrantClient
- from langchain.callbacks.base import BaseCallbackHandler
- from huggingface_hub import hf_hub_download
- from contextlib import asynccontextmanager
-
- # Get environment variables
- COLLECTION_NAME = "arabic_rag_collection"
- QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
- QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4")
-
- # === LOGGING === #
- logging.basicConfig(level=logging.DEBUG)
- logger = logging.getLogger(__name__)

  # Load model and tokenizer
  model_name = "FreedomIntelligence/Apollo-7B"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(model_name)
  tokenizer.pad_token = tokenizer.eos_token


- # FastAPI setup
  app = FastAPI(title="Apollo RAG Medical Chatbot")

-
- # Generation settings
  generation_config = GenerationConfig(
      max_new_tokens=150,
      temperature=0.2,
@@ -49,168 +246,53 @@ generation_config = GenerationConfig(
      repetition_penalty=1.3,
  )

- # Text generation pipeline
- llm_pipeline = pipeline(
      model=model,
      tokenizer=tokenizer,
-     task="text-generation",
-     generation_config=generation_config,
-     device=model.device.index if model.device.type == "cuda" else -1
- )
-
- llm = HuggingFacePipeline(pipeline=llm_pipeline)
-
- # Connect to Qdrant + embedding
- embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
- qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
-
- vector_store = Qdrant(
-     client=qdrant_client,
-     collection_name=COLLECTION_NAME,
-     embeddings=embedding
- )
-
- retriever = vector_store.as_retriever(search_kwargs={"k": 3})
-
- # Set up RAG QA chain
- qa_chain = RetrievalQA.from_chain_type(
-     llm=llm,
-     retriever=retriever,
-     chain_type="stuff"
  )

- class Query(BaseModel):
-     question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
-
- class TimeoutCallback(BaseCallbackHandler):
-     def __init__(self, timeout_seconds: int = 60):
-         self.timeout_seconds = timeout_seconds
-         self.start_time = None
-
-     async def on_llm_start(self, *args, **kwargs):
-         self.start_time = asyncio.get_event_loop().time()
-
-     async def on_llm_new_token(self, *args, **kwargs):
-         if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
-             raise TimeoutError("LLM processing timeout")
-
-
- # def generate_prompt(question: str) -> str:
- #     lang = detect(question)
- #     if lang == "ar":
- #         return (
- #             "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة. \n"
- #             "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
- #             "- وضوح وسلاسة كل نقطة\n"
- #             "- تجنب الحشو والعبارات الزائدة\n"
- #             f"\nالسؤال: {question}\nالإجابة:"
- #         )
- #     else:
- #         return (
- #             "Answer the following medical question in clear English with a detailed, non-redundant response. "
- #             "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
- #             "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
- #             "in concise and distinct bullet points:\n"
- #             f"Question: {question}\nAnswer:"
- #         )
-
- def generate_prompt(question):
-     lang = detect(question)
      if lang == "ar":
-         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
  وتأكد من ان:
  - عدم تكرار أي نقطة أو عبارة أو كلمة
  - وضوح وسلاسة كل نقطة
- - تجنب الحشو والعبارات الزائدة-
-
- السؤال: {question}
- الإجابة:
- """
-
      else:
-         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
- Question: {question}
  Answer:"""

- # === ROUTES === #
- @app.get("/")
- async def root():
-     return {"message": "Medical QA API is running!"}
-
  @app.post("/ask")
- async def ask(query: Query):
-     try:
-         logger.debug(f"Received question: {query.question}")
-         prompt = generate_prompt(query.question)
-         timeout_callback = TimeoutCallback(timeout_seconds=60)
-
-         # docs = retriever.get_relevant_documents(query.question)
-         # if not docs:
-         #     logger.warning("No documents retrieved from Qdrant for the question.")
-         # else:
-         #     logger.debug(f"Retrieved documents: {[doc.page_content for doc in docs[:1]]}")
-
-         loop = asyncio.get_event_loop()
-
-         answer = await asyncio.wait_for(
-             # qa_chain.run(prompt, callbacks=[timeout_callback]),
-             loop.run_in_executor(None, qa_chain.run, prompt),
-             timeout=360
-         )
-
-         if not answer:
-             raise ValueError("Empty answer returned from model")
-
-         if 'Answer:' in answer:
-             response_text = answer.split('Answer:')[-1].strip()
-         elif 'الإجابة:' in answer:
-             response_text = answer.split('الإجابة:')[-1].strip()
-         else:
-             response_text = answer.strip()

-
-         return {
-             "status": "success",
-             "answer": answer,
-             "response": response_text,
-             "language": detect(query.question)
-         }
-
-     except TimeoutError as te:
-         logger.error("Request timed out", exc_info=True)
-         raise HTTPException(
-             status_code=status.HTTP_504_GATEWAY_TIMEOUT,
-             detail={"status": "error", "message": "Request timed out", "error": str(te)}
-         )
-
-     except Exception as e:
-         logger.error(f"Unexpected error: {e}", exc_info=True)
-         raise HTTPException(
-             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-             detail={"status": "error", "message": "Internal server error", "error": str(e)}
-         )
-
- @app.post("/chat")
- def chat(query: Query):

-     prompt = generate_prompt(query.question)

-     answer = qa_chain.run(prompt)

-     return {
-         "answer": answer
-     }



- # === ENTRYPOINT === #
- if __name__ == "__main__":
-     def handle_exit(signum, frame):
-         print("Shutting down gracefully...")
-         exit(0)
-
-     signal.signal(signal.SIGINT, handle_exit)
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
-
 
+ # import torch
+ # import asyncio
+ # import logging
+ # import signal
+ # import uvicorn
+ # import os
+
+ # from fastapi import FastAPI, Request, HTTPException, status
+ # from pydantic import BaseModel, Field
+ # from langdetect import detect
+
+ # from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
+ # from langchain.vectorstores import Qdrant
+ # from langchain.embeddings import HuggingFaceEmbeddings
+ # from langchain.chains import RetrievalQA
+ # from langchain.llms import HuggingFacePipeline
+ # from qdrant_client import QdrantClient
+ # from langchain.callbacks.base import BaseCallbackHandler
+ # from huggingface_hub import hf_hub_download
+ # from contextlib import asynccontextmanager
+
+ # # Get environment variables
+ # COLLECTION_NAME = "arabic_rag_collection"
+ # QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
+ # QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4")
+
+ # # === LOGGING === #
+ # logging.basicConfig(level=logging.DEBUG)
+ # logger = logging.getLogger(__name__)
+
+ # # Load model and tokenizer
+ # model_name = "FreedomIntelligence/Apollo-7B"
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
+ # model = AutoModelForCausalLM.from_pretrained(model_name)
+ # tokenizer.pad_token = tokenizer.eos_token
+
+
+ # # FastAPI setup
+ # app = FastAPI(title="Apollo RAG Medical Chatbot")
+
+
+ # # Generation settings
+ # generation_config = GenerationConfig(
+ #     max_new_tokens=150,
+ #     temperature=0.2,
+ #     top_k=20,
+ #     do_sample=True,
+ #     top_p=0.7,
+ #     repetition_penalty=1.3,
+ # )
+
+ # # Text generation pipeline
+ # llm_pipeline = pipeline(
+ #     model=model,
+ #     tokenizer=tokenizer,
+ #     task="text-generation",
+ #     generation_config=generation_config,
+ #     device=model.device.index if model.device.type == "cuda" else -1
+ # )
+
+ # llm = HuggingFacePipeline(pipeline=llm_pipeline)
+
+ # # Connect to Qdrant + embedding
+ # embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
+ # qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+
+ # vector_store = Qdrant(
+ #     client=qdrant_client,
+ #     collection_name=COLLECTION_NAME,
+ #     embeddings=embedding
+ # )
+
+ # retriever = vector_store.as_retriever(search_kwargs={"k": 3})
+
+ # # Set up RAG QA chain
+ # qa_chain = RetrievalQA.from_chain_type(
+ #     llm=llm,
+ #     retriever=retriever,
+ #     chain_type="stuff"
+ # )
+
+ # class Query(BaseModel):
+ #     question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
+
+ # class TimeoutCallback(BaseCallbackHandler):
+ #     def __init__(self, timeout_seconds: int = 60):
+ #         self.timeout_seconds = timeout_seconds
+ #         self.start_time = None
+
+ #     async def on_llm_start(self, *args, **kwargs):
+ #         self.start_time = asyncio.get_event_loop().time()
+
+ #     async def on_llm_new_token(self, *args, **kwargs):
+ #         if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
+ #             raise TimeoutError("LLM processing timeout")
+
+
+ # # def generate_prompt(question: str) -> str:
+ # #     lang = detect(question)
+ # #     if lang == "ar":
+ # #         return (
+ # #             "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة. \n"
+ # #             "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
+ # #             "- وضوح وسلاسة كل نقطة\n"
+ # #             "- تجنب الحشو والعبارات الزائدة\n"
+ # #             f"\nالسؤال: {question}\nالإجابة:"
+ # #         )
+ # #     else:
+ # #         return (
+ # #             "Answer the following medical question in clear English with a detailed, non-redundant response. "
+ # #             "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
+ # #             "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
+ # #             "in concise and distinct bullet points:\n"
+ # #             f"Question: {question}\nAnswer:"
+ # #         )
+
+ # def generate_prompt(question):
+ #     lang = detect(question)
+ #     if lang == "ar":
+ #         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
+ # وتأكد من ان:
+ # - عدم تكرار أي نقطة أو عبارة أو كلمة
+ # - وضوح وسلاسة كل نقطة
+ # - تجنب الحشو والعبارات الزائدة-
+
+ # السؤال: {question}
+ # الإجابة:
+ # """
+
+ #     else:
+ #         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
+ # Question: {question}
+ # Answer:"""
+
+ # # === ROUTES === #
+ # @app.get("/")
+ # async def root():
+ #     return {"message": "Medical QA API is running!"}
+
+ # @app.post("/ask")
+ # async def ask(query: Query):
+ #     try:
+ #         logger.debug(f"Received question: {query.question}")
+ #         prompt = generate_prompt(query.question)
+ #         timeout_callback = TimeoutCallback(timeout_seconds=60)
+
+ #         # docs = retriever.get_relevant_documents(query.question)
+ #         # if not docs:
+ #         #     logger.warning("No documents retrieved from Qdrant for the question.")
+ #         # else:
+ #         #     logger.debug(f"Retrieved documents: {[doc.page_content for doc in docs[:1]]}")
+
+ #         loop = asyncio.get_event_loop()
+
+ #         answer = await asyncio.wait_for(
+ #             # qa_chain.run(prompt, callbacks=[timeout_callback]),
+ #             loop.run_in_executor(None, qa_chain.run, prompt),
+ #             timeout=360
+ #         )
+
+ #         if not answer:
+ #             raise ValueError("Empty answer returned from model")
+
+ #         if 'Answer:' in answer:
+ #             response_text = answer.split('Answer:')[-1].strip()
+ #         elif 'الإجابة:' in answer:
+ #             response_text = answer.split('الإجابة:')[-1].strip()
+ #         else:
+ #             response_text = answer.strip()
+
+
+ #         return {
+ #             "status": "success",
+ #             "answer": answer,
+ #             "response": response_text,
+ #             "language": detect(query.question)
+ #         }
+
+ #     except TimeoutError as te:
+ #         logger.error("Request timed out", exc_info=True)
+ #         raise HTTPException(
+ #             status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+ #             detail={"status": "error", "message": "Request timed out", "error": str(te)}
+ #         )
+
+ #     except Exception as e:
+ #         logger.error(f"Unexpected error: {e}", exc_info=True)
+ #         raise HTTPException(
+ #             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ #             detail={"status": "error", "message": "Internal server error", "error": str(e)}
+ #         )
+
+ # @app.post("/chat")
+ # def chat(query: Query):
+
+ #     prompt = generate_prompt(query.question)
+
+ #     answer = qa_chain.run(prompt)
+
+ #     return {
+
+ #         "answer": answer
+ #     }
+


+ # # === ENTRYPOINT === #
+ # if __name__ == "__main__":
+ #     def handle_exit(signum, frame):
+ #         print("Shutting down gracefully...")
+ #         exit(0)
+
+ #     signal.signal(signal.SIGINT, handle_exit)
+ #     import uvicorn
+ #     uvicorn.run(app, host="0.0.0.0", port=8000)
+
+
+
+ import gradio as gr
+ from langdetect import detect
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline, GenerationConfig
  import torch

  from fastapi import FastAPI, Request, HTTPException, status
  from pydantic import BaseModel, Field

  # Load model and tokenizer
  model_name = "FreedomIntelligence/Apollo-7B"
+ # model_name = "emilyalsentzer/Bio_ClinicalBERT"
+ # model_name = "FreedomIntelligence/Apollo-2B"
+
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(model_name)
+
  tokenizer.pad_token = tokenizer.eos_token


  app = FastAPI(title="Apollo RAG Medical Chatbot")

  generation_config = GenerationConfig(
      max_new_tokens=150,
      temperature=0.2,
      repetition_penalty=1.3,
  )

+ # Create generation pipeline
+ pipe = TextGenerationPipeline(
      model=model,
      tokenizer=tokenizer,
+     device=model.device.index if torch.cuda.is_available() else "cpu"
  )

+ # Prompt formatter based on language
+ def generate_prompt(message, history):
+     lang = detect(message)
      if lang == "ar":
+         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
  وتأكد من ان:
  - عدم تكرار أي نقطة أو عبارة أو كلمة
  - وضوح وسلاسة كل نقطة
+ - تجنب الحشو والعبارات الزائدة
+ السؤال: {message}
+ الإجابة:"""
      else:
+         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If information is missing, rely on your prior medical knowledge:
+ Question: {message}
  Answer:"""

+ # Chat function
  @app.post("/ask")
+ def chat_fn(message, history):
+     prompt = generate_prompt(message, history)
+     response = pipe(prompt,
+                     max_new_tokens=512,
+                     temperature=0.7,
+                     do_sample=True,
+                     top_p=0.9)[0]['generated_text']
+     answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
+     return {"Answer": answer}

+ # Gradio ChatInterface
+ # demo = gr.ChatInterface(
+ #     fn=chat_fn,
+ #     title="🩺 Apollo Medical Chatbot",
+ #     description="Multilingual (Arabic & English) medical Q&A chatbot powered by Apollo-7B model inference.",
+ #     theme=gr.themes.Soft()
+ # )


+ # if __name__ == "__main__":
+ #     demo.launch(share=True)
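
The new `/ask` path replaces the LangChain RetrievalQA chain with a direct call: shape a prompt from the detected language, run it through the TextGenerationPipeline, and split the generation on the answer marker. A minimal smoke test of that flow, for reviewers only and not part of this commit; "sshleifer/tiny-gpt2" is a stand-in checkpoint assumed here so the logic can be exercised without downloading Apollo-7B:

    # Illustrative smoke test (not part of the commit). Assumes transformers and
    # langdetect are installed; swap the tiny checkpoint for
    # "FreedomIntelligence/Apollo-7B" to match app.py's actual behavior.
    from langdetect import detect
    from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline

    model_name = "sshleifer/tiny-gpt2"  # stand-in for FreedomIntelligence/Apollo-7B
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    pipe = TextGenerationPipeline(model=model, tokenizer=tokenizer)

    message = "What are the causes of hair loss?"
    # Same shape as generate_prompt: pick the template by detected language.
    prompt = (f"السؤال: {message}\nالإجابة:" if detect(message) == "ar"
              else f"Question: {message}\nAnswer:")

    response = pipe(prompt, max_new_tokens=32, do_sample=True, top_p=0.9)[0]["generated_text"]
    # chat_fn drops the echoed prompt by splitting on the answer marker.
    answer = (response.split("Answer:")[-1].strip() if "Answer:" in response
              else response.split("الإجابة:")[-1].strip())
    print(answer)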