khalednabawi11 committed
Commit c5ecd72 · verified · 1 Parent(s): 51265e9

Update app.py

Files changed (1)
  1. app.py +277 -277
app.py CHANGED
@@ -1,344 +1,344 @@
1
- import torch
2
- import asyncio
3
- import logging
4
- import signal
5
- import uvicorn
6
- import os
7
-
8
- from fastapi import FastAPI, Request, HTTPException, status
9
- from pydantic import BaseModel, Field
10
- from langdetect import detect
11
-
12
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
13
- from langchain.vectorstores import Qdrant
14
- from langchain.embeddings import HuggingFaceEmbeddings
15
- from langchain.chains import RetrievalQA
16
- from langchain.llms import HuggingFacePipeline
17
- from qdrant_client import QdrantClient
18
- from langchain.callbacks.base import BaseCallbackHandler
19
- from huggingface_hub import hf_hub_download
20
- from contextlib import asynccontextmanager
21
-
22
- # Get environment variables
23
- COLLECTION_NAME = "arabic_rag_collection"
24
- QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
25
- QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4")
26
-
27
- # === LOGGING === #
28
- logging.basicConfig(level=logging.DEBUG)
29
- logger = logging.getLogger(__name__)
30
-
31
- # Load model and tokenizer
32
- model_name = "FreedomIntelligence/Apollo-2B"
33
- tokenizer = AutoTokenizer.from_pretrained(model_name)
34
- model = AutoModelForCausalLM.from_pretrained(model_name)
35
- tokenizer.pad_token = tokenizer.eos_token
36
-
37
-
38
- # FastAPI setup
39
- app = FastAPI(title="Apollo RAG Medical Chatbot")
40
 
 
 
 
41
 
42
- # Generation settings
43
- generation_config = GenerationConfig(
44
- max_new_tokens=150,
45
- temperature=0.2,
46
- top_k=20,
47
- do_sample=True,
48
- top_p=0.7,
49
- repetition_penalty=1.3,
50
- )
51
 
52
- # Text generation pipeline
53
- llm_pipeline = pipeline(
54
- model=model,
55
- tokenizer=tokenizer,
56
- task="text-generation",
57
- generation_config=generation_config,
58
- device=model.device.index if model.device.type == "cuda" else -1
59
- )
60
 
61
- llm = HuggingFacePipeline(pipeline=llm_pipeline)
62
 
63
- # Connect to Qdrant + embedding
64
- embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
65
- qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
66
 
67
- vector_store = Qdrant(
68
- client=qdrant_client,
69
- collection_name=COLLECTION_NAME,
70
- embeddings=embedding
71
- )
72
 
73
- retriever = vector_store.as_retriever(search_kwargs={"k": 3})
 
 
 
 
 
 
 
 
74
 
75
- # Set up RAG QA chain
76
- qa_chain = RetrievalQA.from_chain_type(
77
- llm=llm,
78
- retriever=retriever,
79
- chain_type="stuff"
80
- )
 
 
81
 
82
- class Query(BaseModel):
83
- question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
84
 
85
- class TimeoutCallback(BaseCallbackHandler):
86
- def __init__(self, timeout_seconds: int = 60):
87
- self.timeout_seconds = timeout_seconds
88
- self.start_time = None
89
 
90
- async def on_llm_start(self, *args, **kwargs):
91
- self.start_time = asyncio.get_event_loop().time()
 
 
 
92
 
93
- async def on_llm_new_token(self, *args, **kwargs):
94
- if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
95
- raise TimeoutError("LLM processing timeout")
96
 
 
 
 
 
 
 
97
 
98
- # def generate_prompt(question: str) -> str:
99
  # lang = detect(question)
100
  # if lang == "ar":
101
- # return (
102
- # "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة. \n"
103
- # "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
104
- # "- وضوح وسلاسة كل نقطة\n"
105
- # "- تجنب الحشو والعبارات الزائدة\n"
106
- # f"\nالسؤال: {question}\nالإجابة:"
107
- # )
108
- # else:
109
- # return (
110
- # "Answer the following medical question in clear English with a detailed, non-redundant response. "
111
- # "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
112
- # "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
113
- # "in concise and distinct bullet points:\n"
114
- # f"Question: {question}\nAnswer:"
115
- # )
116
-
117
- def generate_prompt(question):
118
- lang = detect(question)
119
- if lang == "ar":
120
- return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
121
- وتأكد من ان:
122
- - عدم تكرار أي نقطة أو عبارة أو كلمة
123
- - وضوح وسلاسة كل نقطة
124
- - تجنب الحشو والعبارات الزائدة-
125
 
126
- السؤال: {question}
127
- الإجابة:
128
- """
129
 
130
- else:
131
- return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
132
- Question: {question}
133
- Answer:"""
134
 
135
- # === ROUTES === #
136
- @app.get("/")
137
- async def root():
138
- return {"message": "Medical QA API is running!"}
139
 
140
- @app.post("/ask")
141
- async def ask(query: Query):
142
- try:
143
- logger.debug(f"Received question: {query.question}")
144
- prompt = generate_prompt(query.question)
145
- timeout_callback = TimeoutCallback(timeout_seconds=360)
146
- loop = asyncio.get_event_loop()
147
 
148
- response = await asyncio.wait_for(
149
- # qa_chain.run(prompt, callbacks=[timeout_callback]),
150
- loop.run_in_executor(None, qa_chain.run, prompt),
151
- timeout=360
152
- )
153
 
154
- if not response:
155
- raise ValueError("Empty answer returned from model")
156
 
157
- answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
158
 
159
- return {
160
- "status": "success",
161
- "response": response,
162
- "answer": answer,
163
- "language": detect(query.question)
164
- }
165
-
166
- except TimeoutError as te:
167
- logger.error("Request timed out", exc_info=True)
168
- raise HTTPException(
169
- status_code=status.HTTP_504_GATEWAY_TIMEOUT,
170
- detail={"status": "error", "message": "Request timed out", "error": str(te)}
171
- )
172
 
173
- except Exception as e:
174
- logger.error(f"Unexpected error: {e}", exc_info=True)
175
- raise HTTPException(
176
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
177
- detail={"status": "error", "message": "Internal server error", "error": str(e)}
178
- )
179
 
180
- @app.post("/chat")
181
- def chat(query: Query):
182
 
183
- logger.debug(f"Received question: {query.question}")
184
 
185
- prompt = generate_prompt(query.question)
186
 
187
- response = qa_chain.run(prompt)
188
 
189
- answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
190
 
191
 
192
- return {
193
- "response": response,
194
- "answer": answer
195
- }
196
 
197
 
198
 
199
- # === ENTRYPOINT === #
200
- if __name__ == "__main__":
201
- def handle_exit(signum, frame):
202
- print("Shutting down gracefully...")
203
- exit(0)
204
 
205
- signal.signal(signal.SIGINT, handle_exit)
206
- import uvicorn
207
- uvicorn.run(app, host="0.0.0.0", port=8000)
208
 
209
 
210
 
211
- # from langdetect import detect
212
- # from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline, GenerationConfig
213
- # import torch
214
- # import logging
215
- # from fastapi import FastAPI, Request, HTTPException, status
216
- # from pydantic import BaseModel, Field
217
- # import time
218
- # import asyncio
219
- # from concurrent.futures import ThreadPoolExecutor
220
- # from fastapi.middleware.cors import CORSMiddleware
221
-
222
- # logging.basicConfig(level=logging.INFO)
223
- # logger = logging.getLogger(__name__)
224
 
225
- # # Load model and tokenizer
226
- # model_name = "FreedomIntelligence/Apollo-7B"
227
- # # model_name = "emilyalsentzer/Bio_ClinicalBERT"
228
- # # model_name = "FreedomIntelligence/Apollo-2B"
229
 
230
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
231
- # model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
232
 
233
- # tokenizer.pad_token = tokenizer.eos_token
 
234
 
235
- # app = FastAPI(title="Apollo RAG Medical Chatbot")
236
 
237
- # # Add this after creating the `app`
238
- # app.add_middleware(
239
- # CORSMiddleware,
240
- # allow_origins=["*"], # Allow all origins
241
- # allow_credentials=True,
242
- # allow_methods=["*"],
243
- # allow_headers=["*"],
244
- # )
245
 
 
 
 
 
 
 
 
 
246
 
247
 
248
- # generation_config = GenerationConfig(
249
- # max_new_tokens=150,
250
- # temperature=0.2,
251
- # top_k=20,
252
- # do_sample=True,
253
- # top_p=0.7,
254
- # repetition_penalty=1.3,
255
- # )
256
 
257
- # # Create generation pipeline
258
- # pipe = TextGenerationPipeline(
259
- # model=model,
260
- # tokenizer=tokenizer,
261
- # device=model.device.index if torch.cuda.is_available() else "cpu"
262
- # )
 
 
263
 
264
- # # Prompt formatter based on language
265
- # def generate_prompt(message):
266
- # lang = detect(message)
267
- # if lang == "ar":
268
- # return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
269
- # وتأكد من ان:
270
- # - عدم تكرار أي نقطة أو عبارة أو كلمة
271
- # - وضوح وسلاسة كل نقطة
272
- # - تجنب الحشو والعبارات الزائدة
273
- # السؤال: {message}
274
- # الإجابة:"""
275
- # else:
276
- # return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If information is missing, rely on your prior medical knowledge:
277
- # Question: {message}
278
- # Answer:"""
279
 
280
- # # Chat function
281
- # # @app.post("/ask")
282
- # # def chat_fn(message):
283
- # # prompt = generate_prompt(message)
284
- # # response = pipe(prompt,
285
- # # max_new_tokens=512,
286
- # # temperature=0.7,
287
- # # do_sample = True,
288
- # # top_p=0.9)[0]['generated_text']
289
- # # answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
290
- # # return {"Answer": answer}
291
-
292
- # executor = ThreadPoolExecutor()
293
-
294
- # # Define request model
295
- # class Query(BaseModel):
296
- # message: str
297
 
298
- # @app.get("/")
299
- # def read_root():
300
- # return {"message": "Apollo Medical Chatbot API is running"}
301
 
 
302
 
303
- # # @app.post("/ask")
304
- # # async def chat_fn(query: Query):
305
-
306
- # # message = query.message
307
- # # logger.info(f"Received message: {message}")
308
-
309
- # # prompt = generate_prompt(message)
310
 
311
- # # # Run blocking inference in thread
312
- # # loop = asyncio.get_event_loop()
313
- # # response = await loop.run_in_executor(executor,
314
- # # lambda: pipe(prompt, max_new_tokens=512, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text'])
315
 
316
- # # # Parse answer
317
- # # answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
318
- # # return {"Answer": answer}
319
 
320
  # @app.post("/ask")
321
  # async def chat_fn(query: Query):
 
322
  # message = query.message
323
  # logger.info(f"Received message: {message}")
324
 
325
  # prompt = generate_prompt(message)
326
 
327
- # try:
328
- # start_time = time.time()
329
 
330
- # loop = asyncio.get_event_loop()
331
- # response = await loop.run_in_executor(
332
- # executor,
333
- # lambda: pipe(prompt, max_new_tokens=150, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text']
334
- # )
335
 
336
- # duration = time.time() - start_time
337
- # logger.info(f"Model inference completed in {duration:.2f} seconds")
338
 
339
- # logger.info(f"Generated answer: {answer}")
340
- # return {"Answer": answer}
341
 
342
- # except Exception as e:
343
- # logger.error(f"Inference failed: {str(e)}")
344
- # raise HTTPException(status_code=500, detail="Model inference TimeOut failed.")
 
1
+ # import torch
2
+ # import asyncio
3
+ # import logging
4
+ # import signal
5
+ # import uvicorn
6
+ # import os
7
 
8
+ # from fastapi import FastAPI, Request, HTTPException, status
9
+ # from pydantic import BaseModel, Field
10
+ # from langdetect import detect
11
 
12
+ # from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
13
+ # from langchain.vectorstores import Qdrant
14
+ # from langchain.embeddings import HuggingFaceEmbeddings
15
+ # from langchain.chains import RetrievalQA
16
+ # from langchain.llms import HuggingFacePipeline
17
+ # from qdrant_client import QdrantClient
18
+ # from langchain.callbacks.base import BaseCallbackHandler
19
+ # from huggingface_hub import hf_hub_download
20
+ # from contextlib import asynccontextmanager
21
+
22
+ # # Get environment variables
23
+ # COLLECTION_NAME = "arabic_rag_collection"
24
+ # QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
25
+ # QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4")
26
+
27
+ # # === LOGGING === #
28
+ # logging.basicConfig(level=logging.DEBUG)
29
+ # logger = logging.getLogger(__name__)
30
 
31
+ # # Load model and tokenizer
32
+ # model_name = "FreedomIntelligence/Apollo-2B"
33
+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
34
+ # model = AutoModelForCausalLM.from_pretrained(model_name)
35
+ # tokenizer.pad_token = tokenizer.eos_token
 
 
 
36
 
 
37
 
38
+ # # FastAPI setup
39
+ # app = FastAPI(title="Apollo RAG Medical Chatbot")
 
40
 
 
 
 
 
 
41
 
42
+ # # Generation settings
43
+ # generation_config = GenerationConfig(
44
+ # max_new_tokens=150,
45
+ # temperature=0.2,
46
+ # top_k=20,
47
+ # do_sample=True,
48
+ # top_p=0.7,
49
+ # repetition_penalty=1.3,
50
+ # )
51
 
52
+ # # Text generation pipeline
53
+ # llm_pipeline = pipeline(
54
+ # model=model,
55
+ # tokenizer=tokenizer,
56
+ # task="text-generation",
57
+ # generation_config=generation_config,
58
+ # device=model.device.index if model.device.type == "cuda" else -1
59
+ # )
60
 
61
+ # llm = HuggingFacePipeline(pipeline=llm_pipeline)
 
62
 
63
+ # # Connect to Qdrant + embedding
64
+ # embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
65
+ # qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
 
66
 
67
+ # vector_store = Qdrant(
68
+ # client=qdrant_client,
69
+ # collection_name=COLLECTION_NAME,
70
+ # embeddings=embedding
71
+ # )
72
 
73
+ # retriever = vector_store.as_retriever(search_kwargs={"k": 3})
 
 
74
 
75
+ # # Set up RAG QA chain
76
+ # qa_chain = RetrievalQA.from_chain_type(
77
+ # llm=llm,
78
+ # retriever=retriever,
79
+ # chain_type="stuff"
80
+ # )
81
 
82
+ # class Query(BaseModel):
83
+ # question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
84
+
85
+ # class TimeoutCallback(BaseCallbackHandler):
86
+ # def __init__(self, timeout_seconds: int = 60):
87
+ # self.timeout_seconds = timeout_seconds
88
+ # self.start_time = None
89
+
90
+ # async def on_llm_start(self, *args, **kwargs):
91
+ # self.start_time = asyncio.get_event_loop().time()
92
+
93
+ # async def on_llm_new_token(self, *args, **kwargs):
94
+ # if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
95
+ # raise TimeoutError("LLM processing timeout")
96
+
97
+
98
+ # # def generate_prompt(question: str) -> str:
99
+ # # lang = detect(question)
100
+ # # if lang == "ar":
101
+ # # return (
102
+ # # "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة. \n"
103
+ # # "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
104
+ # # "- وضوح وسلاسة كل نقطة\n"
105
+ # # "- تجنب الحشو والعبارات الزائدة\n"
106
+ # # f"\nالسؤال: {question}\nالإجابة:"
107
+ # # )
108
+ # # else:
109
+ # # return (
110
+ # # "Answer the following medical question in clear English with a detailed, non-redundant response. "
111
+ # # "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
112
+ # # "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
113
+ # # "in concise and distinct bullet points:\n"
114
+ # # f"Question: {question}\nAnswer:"
115
+ # # )
116
+
117
+ # def generate_prompt(question):
118
  # lang = detect(question)
119
  # if lang == "ar":
120
+ # return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
121
+ # وتأكد من ان:
122
+ # - عدم تكرار أي نقطة أو عبارة أو كلمة
123
+ # - وضوح وسلاسة كل نقطة
124
+ # - تجنب الحشو والعبارات الزائدة-
125
 
126
+ # السؤال: {question}
127
+ # الإجابة:
128
+ # """
129
 
130
+ # else:
131
+ # return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
132
+ # Question: {question}
133
+ # Answer:"""
134
 
135
+ # # === ROUTES === #
136
+ # @app.get("/")
137
+ # async def root():
138
+ # return {"message": "Medical QA API is running!"}
139
 
140
+ # @app.post("/ask")
141
+ # async def ask(query: Query):
142
+ # try:
143
+ # logger.debug(f"Received question: {query.question}")
144
+ # prompt = generate_prompt(query.question)
145
+ # timeout_callback = TimeoutCallback(timeout_seconds=360)
146
+ # loop = asyncio.get_event_loop()
147
 
148
+ # response = await asyncio.wait_for(
149
+ # # qa_chain.run(prompt, callbacks=[timeout_callback]),
150
+ # loop.run_in_executor(None, qa_chain.run, prompt),
151
+ # timeout=360
152
+ # )
153
 
154
+ # if not response:
155
+ # raise ValueError("Empty answer returned from model")
156
 
157
+ # answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
158
 
159
+ # return {
160
+ # "status": "success",
161
+ # "response": response,
162
+ # "answer": answer,
163
+ # "language": detect(query.question)
164
+ # }
165
+
166
+ # except TimeoutError as te:
167
+ # logger.error("Request timed out", exc_info=True)
168
+ # raise HTTPException(
169
+ # status_code=status.HTTP_504_GATEWAY_TIMEOUT,
170
+ # detail={"status": "error", "message": "Request timed out", "error": str(te)}
171
+ # )
172
 
173
+ # except Exception as e:
174
+ # logger.error(f"Unexpected error: {e}", exc_info=True)
175
+ # raise HTTPException(
176
+ # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
177
+ # detail={"status": "error", "message": "Internal server error", "error": str(e)}
178
+ # )
179
 
180
+ # @app.post("/chat")
181
+ # def chat(query: Query):
182
 
183
+ # logger.debug(f"Received question: {query.question}")
184
 
185
+ # prompt = generate_prompt(query.question)
186
 
187
+ # response = qa_chain.run(prompt)
188
 
189
+ # answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
190
 
191
 
192
+ # return {
193
+ # "response": response,
194
+ # "answer": answer
195
+ # }
196
 
197
 
198
 
199
+ # # === ENTRYPOINT === #
200
+ # if __name__ == "__main__":
201
+ # def handle_exit(signum, frame):
202
+ # print("Shutting down gracefully...")
203
+ # exit(0)
204
 
205
+ # signal.signal(signal.SIGINT, handle_exit)
206
+ # import uvicorn
207
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
208
 
209
 
210
 
211
+ from langdetect import detect
212
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline, GenerationConfig
213
+ import torch
214
+ import logging
215
+ from fastapi import FastAPI, Request, HTTPException, status
216
+ from pydantic import BaseModel, Field
217
+ import time
218
+ import asyncio
219
+ from concurrent.futures import ThreadPoolExecutor
220
+ from fastapi.middleware.cors import CORSMiddleware
 
 
 
221
 
222
+ logging.basicConfig(level=logging.INFO)
223
+ logger = logging.getLogger(__name__)
 
 
224
 
225
+ # Load model and tokenizer
226
+ model_name = "FreedomIntelligence/Apollo-7B"
227
+ # model_name = "emilyalsentzer/Bio_ClinicalBERT"
228
+ # model_name = "FreedomIntelligence/Apollo-2B"
229
 
230
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
231
+ model = AutoModelForCausalLM.from_pretrained(model_name)
232
 
233
+ tokenizer.pad_token = tokenizer.eos_token
234
 
235
+ app = FastAPI(title="Apollo RAG Medical Chatbot")
236
 
237
+ # Add CORS middleware after creating the `app`
238
+ app.add_middleware(
239
+ CORSMiddleware,
240
+ allow_origins=["*"], # Allow all origins
241
+ allow_credentials=True,
242
+ allow_methods=["*"],
243
+ allow_headers=["*"],
244
+ )
245
 
246
 
 
 
 
 
 
 
 
 
247
 
248
+ generation_config = GenerationConfig(
249
+ max_new_tokens=150,
250
+ temperature=0.2,
251
+ top_k=20,
252
+ do_sample=True,
253
+ top_p=0.7,
254
+ repetition_penalty=1.3,
255
+ )
256
 
257
+ # Create generation pipeline
258
+ pipe = TextGenerationPipeline(
259
+ model=model,
260
+ tokenizer=tokenizer,
261
+ device=0 if torch.cuda.is_available() else -1  # use GPU 0 when available, otherwise CPU
262
+ )
263
 
264
+ # Prompt formatter based on language
265
+ def generate_prompt(message):
266
+ lang = detect(message)
267
+ if lang == "ar":
268
+ return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
269
+ وتأكد من ان:
270
+ - عدم تكرار أي نقطة أو عبارة أو كلمة
271
+ - وضوح وسلاسة كل نقطة
272
+ - تجنب الحشو والعبارات الزائدة
273
+ السؤال: {message}
274
+ الإجابة:"""
275
+ else:
276
+ return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If information is missing, rely on your prior medical knowledge:
277
+ Question: {message}
278
+ Answer:"""
 
 
279
 
280
+ # Chat function
281
+ # @app.post("/ask")
282
+ # def chat_fn(message):
283
+ # prompt = generate_prompt(message)
284
+ # response = pipe(prompt,
285
+ # max_new_tokens=512,
286
+ # temperature=0.7,
287
+ # do_sample = True,
288
+ # top_p=0.9)[0]['generated_text']
289
+ # answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
290
+ # return {"Answer": answer}
291
 
292
+ executor = ThreadPoolExecutor()
293
 
294
+ # Define request model
295
+ class Query(BaseModel):
296
+ message: str
 
 
 
 
297
 
298
+ @app.get("/")
299
+ def read_root():
300
+ return {"message": "Apollo Medical Chatbot API is running"}
 
301
 
 
 
 
302
 
303
  # @app.post("/ask")
304
  # async def chat_fn(query: Query):
305
+
306
  # message = query.message
307
  # logger.info(f"Received message: {message}")
308
 
309
  # prompt = generate_prompt(message)
310
 
311
+ # # Run blocking inference in thread
312
+ # loop = asyncio.get_event_loop()
313
+ # response = await loop.run_in_executor(executor,
314
+ # lambda: pipe(prompt, max_new_tokens=512, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text'])
315
+
316
+ # # Parse answer
317
+ # answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
318
+ # return {"Answer": answer}
319
+
320
+ @app.post("/ask")
321
+ async def chat_fn(query: Query):
322
+ message = query.message
323
+ logger.info(f"Received message: {message}")
324
+
325
+ prompt = generate_prompt(message)
326
+
327
+ try:
328
+ start_time = time.time()
329
 
330
+ loop = asyncio.get_event_loop()
331
+ response = await loop.run_in_executor(
332
+ executor,
333
+ lambda: pipe(prompt, max_new_tokens=150, temperature=0.6, do_sample=True, top_p=0.8)[0]['generated_text']
334
+ )
335
 
336
+ duration = time.time() - start_time
337
+ logger.info(f"Model inference completed in {duration:.2f} seconds")
338
 
339
+ # Parse the answer from the generated text
+ answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
+ logger.info(f"Generated answer: {answer}")
340
+ return {"Answer": answer}
341
 
342
+ except Exception as e:
343
+ logger.error(f"Inference failed: {str(e)}")
344
+ raise HTTPException(status_code=500, detail="Model inference failed or timed out.")
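
For reference, a minimal smoke test for the updated /ask endpoint (a sketch that assumes the app is served locally on port 8000; the URL and sample question are illustrative):

import requests

# The /ask route expects a JSON body matching the Query model: {"message": <question>}
payload = {"message": "What are the common causes of hair loss?"}
resp = requests.post("http://localhost:8000/ask", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["Answer"])  # the handler returns {"Answer": <parsed answer>}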