khalednabawi11 committed
Commit 8981e66 · verified · 1 Parent(s): bde5081

Update app.py

Files changed (1)
  1. app.py +284 -284
app.py CHANGED
@@ -1,352 +1,352 @@
- # import torch
- # import asyncio
- # import logging
- # import signal
- # import uvicorn
- # import os

- # from fastapi import FastAPI, Request, HTTPException, status
- # from pydantic import BaseModel, Field
- # from langdetect import detect

- # from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
- # from langchain.vectorstores import Qdrant
- # from langchain.embeddings import HuggingFaceEmbeddings
- # from langchain.chains import RetrievalQA
- # from langchain.llms import HuggingFacePipeline
- # from qdrant_client import QdrantClient
- # from langchain.callbacks.base import BaseCallbackHandler
- # from huggingface_hub import hf_hub_download
- # from contextlib import asynccontextmanager
-
- # # Get environment variables
- # COLLECTION_NAME = "arabic_rag_collection"
- # QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
- # QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4")
-
- # # === LOGGING === #
- # logging.basicConfig(level=logging.DEBUG)
- # logger = logging.getLogger(__name__)

- # # Load model and tokenizer
- # model_name = "FreedomIntelligence/Apollo-7B"
- # tokenizer = AutoTokenizer.from_pretrained(model_name)
- # model = AutoModelForCausalLM.from_pretrained(model_name)
- # tokenizer.pad_token = tokenizer.eos_token


- # # FastAPI setup
- # app = FastAPI(title="Apollo RAG Medical Chatbot")


- # # Generation settings
- # generation_config = GenerationConfig(
- #     max_new_tokens=150,
- #     temperature=0.2,
- #     top_k=20,
- #     do_sample=True,
- #     top_p=0.7,
- #     repetition_penalty=1.3,
- # )

- # # Text generation pipeline
- # llm_pipeline = pipeline(
- #     model=model,
- #     tokenizer=tokenizer,
- #     task="text-generation",
- #     generation_config=generation_config,
- #     device=model.device.index if model.device.type == "cuda" else -1
- # )

- # llm = HuggingFacePipeline(pipeline=llm_pipeline)

- # # Connect to Qdrant + embedding
- # embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
- # qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

- # vector_store = Qdrant(
- #     client=qdrant_client,
- #     collection_name=COLLECTION_NAME,
- #     embeddings=embedding
- # )

- # retriever = vector_store.as_retriever(search_kwargs={"k": 3})

- # # Set up RAG QA chain
- # qa_chain = RetrievalQA.from_chain_type(
- #     llm=llm,
- #     retriever=retriever,
- #     chain_type="stuff"
- # )

- # class Query(BaseModel):
- #     question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
-
- # class TimeoutCallback(BaseCallbackHandler):
- #     def __init__(self, timeout_seconds: int = 60):
- #         self.timeout_seconds = timeout_seconds
- #         self.start_time = None
-
- #     async def on_llm_start(self, *args, **kwargs):
- #         self.start_time = asyncio.get_event_loop().time()
-
- #     async def on_llm_new_token(self, *args, **kwargs):
- #         if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
- #             raise TimeoutError("LLM processing timeout")
-
-
- # # def generate_prompt(question: str) -> str:
- # #     lang = detect(question)
- # #     if lang == "ar":
- # #         return (
- # #             "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة. \n"
- # #             "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
- # #             "- وضوح وسلاسة كل نقطة\n"
- # #             "- تجنب الحشو والعبارات الزائدة\n"
- # #             f"\nالسؤال: {question}\nالإجابة:"
- # #         )
- # #     else:
- # #         return (
- # #             "Answer the following medical question in clear English with a detailed, non-redundant response. "
- # #             "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
- # #             "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
- # #             "in concise and distinct bullet points:\n"
- # #             f"Question: {question}\nAnswer:"
- # #         )
-
- # def generate_prompt(question):
  #     lang = detect(question)
  #     if lang == "ar":
- #         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
- # وتأكد من ان:
- # - عدم تكرار أي نقطة أو عبارة أو كلمة
- # - وضوح وسلاسة كل نقطة
- # - تجنب الحشو والعبارات الزائدة-

- # السؤال: {question}
- # الإجابة:
- # """

- #     else:
- #         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
- # Question: {question}
- # Answer:"""

- # # === ROUTES === #
- # @app.get("/")
- # async def root():
- #     return {"message": "Medical QA API is running!"}

- # @app.post("/ask")
- # async def ask(query: Query):
- #     try:
- #         logger.debug(f"Received question: {query.question}")
- #         prompt = generate_prompt(query.question)
- #         timeout_callback = TimeoutCallback(timeout_seconds=60)

- #         # docs = retriever.get_relevant_documents(query.question)
- #         # if not docs:
- #         #     logger.warning("No documents retrieved from Qdrant for the question.")
- #         # else:
- #         #     logger.debug(f"Retrieved documents: {[doc.page_content for doc in docs[:1]]}")

- #         loop = asyncio.get_event_loop()

- #         answer = await asyncio.wait_for(
- #             # qa_chain.run(prompt, callbacks=[timeout_callback]),
- #             loop.run_in_executor(None, qa_chain.run, prompt),
- #             timeout=360
- #         )

- #         if not answer:
- #             raise ValueError("Empty answer returned from model")

- #         if 'Answer:' in answer:
- #             response_text = answer.split('Answer:')[-1].strip()
- #         elif 'الإجابة:' in answer:
- #             response_text = answer.split('الإجابة:')[-1].strip()
- #         else:
- #             response_text = answer.strip()


- #         return {
- #             "status": "success",
- #             "answer": answer,
- #             "response": response_text,
- #             "language": detect(query.question)
- #         }
-
- #     except TimeoutError as te:
- #         logger.error("Request timed out", exc_info=True)
- #         raise HTTPException(
- #             status_code=status.HTTP_504_GATEWAY_TIMEOUT,
- #             detail={"status": "error", "message": "Request timed out", "error": str(te)}
- #         )

- #     except Exception as e:
- #         logger.error(f"Unexpected error: {e}", exc_info=True)
- #         raise HTTPException(
- #             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- #             detail={"status": "error", "message": "Internal server error", "error": str(e)}
- #         )

- # @app.post("/chat")
- # def chat(query: Query):

- #     prompt = generate_prompt(query.question)

- #     answer = qa_chain.run(prompt)

- #     return {

- #         "answer": answer
- #     }



- # # === ENTRYPOINT === #
- # if __name__ == "__main__":
- #     def handle_exit(signum, frame):
- #         print("Shutting down gracefully...")
- #         exit(0)

- #     signal.signal(signal.SIGINT, handle_exit)
- #     import uvicorn
- #     uvicorn.run(app, host="0.0.0.0", port=8000)

- from langdetect import detect
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline, GenerationConfig
- import torch
- import logging
- from fastapi import FastAPI, Request, HTTPException, status
- from pydantic import BaseModel, Field
- import time
- import asyncio
- from concurrent.futures import ThreadPoolExecutor
- from fastapi.middleware.cors import CORSMiddleware

- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)

- # Load model and tokenizer
- model_name = "FreedomIntelligence/Apollo-7B"
- # model_name = "emilyalsentzer/Bio_ClinicalBERT"
- # model_name = "FreedomIntelligence/Apollo-2B"

- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name)

- tokenizer.pad_token = tokenizer.eos_token

- app = FastAPI(title="Apollo RAG Medical Chatbot")

- # Add this after creating the `app`
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],  # Allow all origins
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )



- generation_config = GenerationConfig(
-     max_new_tokens=150,
-     temperature=0.2,
-     top_k=20,
-     do_sample=True,
-     top_p=0.7,
-     repetition_penalty=1.3,
- )

- # Create generation pipeline
- pipe = TextGenerationPipeline(
-     model=model,
-     tokenizer=tokenizer,
-     device=model.device.index if torch.cuda.is_available() else "cpu"
- )

- # Prompt formatter based on language
- def generate_prompt(message):
-     lang = detect(message)
-     if lang == "ar":
-         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
- وتأكد من ان:
- - عدم تكرار أي نقطة أو عبارة أو كلمة
- - وضوح وسلاسة كل نقطة
- - تجنب الحشو والعبارات الزائدة
- السؤال: {message}
- الإجابة:"""
-     else:
-         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If information is missing, rely on your prior medical knowledge:
- Question: {message}
- Answer:"""

- # Chat function
- # @app.post("/ask")
- # def chat_fn(message):
- #     prompt = generate_prompt(message)
- #     response = pipe(prompt,
- #                     max_new_tokens=512,
- #                     temperature=0.7,
- #                     do_sample = True,
- #                     top_p=0.9)[0]['generated_text']
- #     answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
- #     return {"Answer": answer}

- executor = ThreadPoolExecutor()

- # Define request model
- class Query(BaseModel):
-     message: str

- @app.get("/")
- def read_root():
-     return {"message": "Apollo Medical Chatbot API is running"}


  # @app.post("/ask")
  # async def chat_fn(query: Query):
-
  #     message = query.message
  #     logger.info(f"Received message: {message}")

  #     prompt = generate_prompt(message)

- #     # Run blocking inference in thread
- #     loop = asyncio.get_event_loop()
- #     response = await loop.run_in_executor(executor,
- #         lambda: pipe(prompt, max_new_tokens=512, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text'])
-
- #     # Parse answer
- #     answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
- #     return {"Answer": answer}
-
- @app.post("/ask")
- async def chat_fn(query: Query):
-     message = query.message
-     logger.info(f"Received message: {message}")
-
-     prompt = generate_prompt(message)
-
-     try:
-         start_time = time.time()

-         loop = asyncio.get_event_loop()
-         response = await loop.run_in_executor(
-             executor,
-             lambda: pipe(prompt, max_new_tokens=150, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text']
-         )

-         duration = time.time() - start_time
-         logger.info(f"Model inference completed in {duration:.2f} seconds")

-         logger.info(f"Generated answer: {answer}")
-         return {"Answer": answer}

-     except Exception as e:
-         logger.error(f"Inference failed: {str(e)}")
-         raise HTTPException(status_code=500, detail="Model inference TimeOut failed.")
+ import torch
+ import asyncio
+ import logging
+ import signal
+ import uvicorn
+ import os

+ from fastapi import FastAPI, Request, HTTPException, status
+ from pydantic import BaseModel, Field
+ from langdetect import detect

+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, GenerationConfig
+ from langchain.vectorstores import Qdrant
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.chains import RetrievalQA
+ from langchain.llms import HuggingFacePipeline
+ from qdrant_client import QdrantClient
+ from langchain.callbacks.base import BaseCallbackHandler
+ from huggingface_hub import hf_hub_download
+ from contextlib import asynccontextmanager
+
+ # Get environment variables
+ COLLECTION_NAME = "arabic_rag_collection"
+ QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
+ QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4")
+
+ # === LOGGING === #
+ logging.basicConfig(level=logging.DEBUG)
+ logger = logging.getLogger(__name__)

+ # Load model and tokenizer
+ model_name = "FreedomIntelligence/Apollo-2B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ tokenizer.pad_token = tokenizer.eos_token


+ # FastAPI setup
+ app = FastAPI(title="Apollo RAG Medical Chatbot")


+ # Generation settings
+ generation_config = GenerationConfig(
+     max_new_tokens=150,
+     temperature=0.2,
+     top_k=20,
+     do_sample=True,
+     top_p=0.7,
+     repetition_penalty=1.3,
+ )

+ # Text generation pipeline
+ llm_pipeline = pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     task="text-generation",
+     generation_config=generation_config,
+     device=model.device.index if model.device.type == "cuda" else -1
+ )

+ llm = HuggingFacePipeline(pipeline=llm_pipeline)

+ # Connect to Qdrant + embedding
+ embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
+ qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

+ vector_store = Qdrant(
+     client=qdrant_client,
+     collection_name=COLLECTION_NAME,
+     embeddings=embedding
+ )

+ retriever = vector_store.as_retriever(search_kwargs={"k": 3})

+ # Set up RAG QA chain
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     retriever=retriever,
+     chain_type="stuff"
+ )

+ class Query(BaseModel):
+     question: str = Field(..., example="ما هي اسباب تساقط الشعر ؟", min_length=3)
+
+ class TimeoutCallback(BaseCallbackHandler):
+     def __init__(self, timeout_seconds: int = 60):
+         self.timeout_seconds = timeout_seconds
+         self.start_time = None
+
+     async def on_llm_start(self, *args, **kwargs):
+         self.start_time = asyncio.get_event_loop().time()
+
+     async def on_llm_new_token(self, *args, **kwargs):
+         if asyncio.get_event_loop().time() - self.start_time > self.timeout_seconds:
+             raise TimeoutError("LLM processing timeout")
+
+
+ # def generate_prompt(question: str) -> str:
  #     lang = detect(question)
  #     if lang == "ar":
+ #         return (
+ #             "أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة. \n"
+ #             "- عدم تكرار أي نقطة أو عبارة أو كلمة\n"
+ #             "- وضوح وسلاسة كل نقطة\n"
+ #             "- تجنب الحشو والعبارات الزائدة\n"
+ #             f"\nالسؤال: {question}\nالإجابة:"
+ #         )
+ #     else:
+ #         return (
+ #             "Answer the following medical question in clear English with a detailed, non-redundant response. "
+ #             "Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant "
+ #             "information, rely on your prior medical knowledge. If the answer involves multiple points, list them "
+ #             "in concise and distinct bullet points:\n"
+ #             f"Question: {question}\nAnswer:"
+ #         )

+ def generate_prompt(question):
+     lang = detect(question)
+     if lang == "ar":
+         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
+ وتأكد من ان:
+ - عدم تكرار أي نقطة أو عبارة أو كلمة
+ - وضوح وسلاسة كل نقطة
+ - تجنب الحشو والعبارات الزائدة-
+
+ السؤال: {question}
+ الإجابة:
+ """

+     else:
+         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas, phrases, or restate the question in the answer. If the context lacks relevant information, rely on your prior medical knowledge. If the answer involves multiple points, list them in concise and distinct bullet points:
+ Question: {question}
+ Answer:"""

+ # === ROUTES === #
+ @app.get("/")
+ async def root():
+     return {"message": "Medical QA API is running!"}

+ @app.post("/ask")
+ async def ask(query: Query):
+     try:
+         logger.debug(f"Received question: {query.question}")
+         prompt = generate_prompt(query.question)
+         timeout_callback = TimeoutCallback(timeout_seconds=60)

+         # docs = retriever.get_relevant_documents(query.question)
+         # if not docs:
+         #     logger.warning("No documents retrieved from Qdrant for the question.")
+         # else:
+         #     logger.debug(f"Retrieved documents: {[doc.page_content for doc in docs[:1]]}")

+         loop = asyncio.get_event_loop()

+         answer = await asyncio.wait_for(
+             # qa_chain.run(prompt, callbacks=[timeout_callback]),
+             loop.run_in_executor(None, qa_chain.run, prompt),
+             timeout=360
+         )

+         if not answer:
+             raise ValueError("Empty answer returned from model")

+         if 'Answer:' in answer:
+             response_text = answer.split('Answer:')[-1].strip()
+         elif 'الإجابة:' in answer:
+             response_text = answer.split('الإجابة:')[-1].strip()
+         else:
+             response_text = answer.strip()


+         return {
+             "status": "success",
+             "answer": answer,
+             "response": response_text,
+             "language": detect(query.question)
+         }
+
+     except TimeoutError as te:
+         logger.error("Request timed out", exc_info=True)
+         raise HTTPException(
+             status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+             detail={"status": "error", "message": "Request timed out", "error": str(te)}
+         )

+     except Exception as e:
+         logger.error(f"Unexpected error: {e}", exc_info=True)
+         raise HTTPException(
+             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+             detail={"status": "error", "message": "Internal server error", "error": str(e)}
+         )

+ @app.post("/chat")
+ def chat(query: Query):

+     prompt = generate_prompt(query.question)

+     answer = qa_chain.run(prompt)

+     return {

+         "answer": answer
+     }



+ # === ENTRYPOINT === #
+ if __name__ == "__main__":
+     def handle_exit(signum, frame):
+         print("Shutting down gracefully...")
+         exit(0)

+     signal.signal(signal.SIGINT, handle_exit)
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)

+ # from langdetect import detect
+ # from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline, GenerationConfig
+ # import torch
+ # import logging
+ # from fastapi import FastAPI, Request, HTTPException, status
+ # from pydantic import BaseModel, Field
+ # import time
+ # import asyncio
+ # from concurrent.futures import ThreadPoolExecutor
+ # from fastapi.middleware.cors import CORSMiddleware

+ # logging.basicConfig(level=logging.INFO)
+ # logger = logging.getLogger(__name__)

+ # # Load model and tokenizer
+ # model_name = "FreedomIntelligence/Apollo-7B"
+ # # model_name = "emilyalsentzer/Bio_ClinicalBERT"
+ # # model_name = "FreedomIntelligence/Apollo-2B"

+ # tokenizer = AutoTokenizer.from_pretrained(model_name)
+ # model = AutoModelForCausalLM.from_pretrained(model_name)

+ # tokenizer.pad_token = tokenizer.eos_token

+ # app = FastAPI(title="Apollo RAG Medical Chatbot")

+ # # Add this after creating the `app`
+ # app.add_middleware(
+ #     CORSMiddleware,
+ #     allow_origins=["*"],  # Allow all origins
+ #     allow_credentials=True,
+ #     allow_methods=["*"],
+ #     allow_headers=["*"],
+ # )



+ # generation_config = GenerationConfig(
+ #     max_new_tokens=150,
+ #     temperature=0.2,
+ #     top_k=20,
+ #     do_sample=True,
+ #     top_p=0.7,
+ #     repetition_penalty=1.3,
+ # )

+ # # Create generation pipeline
+ # pipe = TextGenerationPipeline(
+ #     model=model,
+ #     tokenizer=tokenizer,
+ #     device=model.device.index if torch.cuda.is_available() else "cpu"
+ # )

+ # # Prompt formatter based on language
+ # def generate_prompt(message):
+ #     lang = detect(message)
+ #     if lang == "ar":
+ #         return f"""أجب على السؤال الطبي التالي بلغة عربية فصحى، بإجابة دقيقة ومفصلة. إذا لم تجد معلومات كافية في السياق، استخدم معرفتك الطبية السابقة.
+ # وتأكد من ان:
+ # - عدم تكرار أي نقطة أو عبارة أو كلمة
+ # - وضوح وسلاسة كل نقطة
+ # - تجنب الحشو والعبارات الزائدة
+ # السؤال: {message}
+ # الإجابة:"""
+ #     else:
+ #         return f"""Answer the following medical question in clear English with a detailed, non-redundant response. Do not repeat ideas or restate the question. If information is missing, rely on your prior medical knowledge:
+ # Question: {message}
+ # Answer:"""

+ # # Chat function
+ # # @app.post("/ask")
+ # # def chat_fn(message):
+ # #     prompt = generate_prompt(message)
+ # #     response = pipe(prompt,
+ # #                     max_new_tokens=512,
+ # #                     temperature=0.7,
+ # #                     do_sample = True,
+ # #                     top_p=0.9)[0]['generated_text']
+ # #     answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
+ # #     return {"Answer": answer}
+
+ # executor = ThreadPoolExecutor()
+
+ # # Define request model
+ # class Query(BaseModel):
+ #     message: str

+ # @app.get("/")
+ # def read_root():
+ #     return {"message": "Apollo Medical Chatbot API is running"}


+ # # @app.post("/ask")
+ # # async def chat_fn(query: Query):
+
+ # #     message = query.message
+ # #     logger.info(f"Received message: {message}")
+
+ # #     prompt = generate_prompt(message)

+ # #     # Run blocking inference in thread
+ # #     loop = asyncio.get_event_loop()
+ # #     response = await loop.run_in_executor(executor,
+ # #         lambda: pipe(prompt, max_new_tokens=512, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text'])
+
+ # #     # Parse answer
+ # #     answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
+ # #     return {"Answer": answer}

  # @app.post("/ask")
  # async def chat_fn(query: Query):
  #     message = query.message
  #     logger.info(f"Received message: {message}")

  #     prompt = generate_prompt(message)

+ #     try:
+ #         start_time = time.time()

+ #         loop = asyncio.get_event_loop()
+ #         response = await loop.run_in_executor(
+ #             executor,
+ #             lambda: pipe(prompt, max_new_tokens=150, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text']
+ #         )

+ #         duration = time.time() - start_time
+ #         logger.info(f"Model inference completed in {duration:.2f} seconds")

+ #         logger.info(f"Generated answer: {answer}")
+ #         return {"Answer": answer}

+ #     except Exception as e:
+ #         logger.error(f"Inference failed: {str(e)}")
+ #         raise HTTPException(status_code=500, detail="Model inference TimeOut failed.")
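
For reference, a minimal client sketch for the newly activated POST /ask route. This is an illustration only and not part of app.py: it assumes the server was started via the __main__ block above (uvicorn listening on 0.0.0.0:8000) and uses the third-party requests package; the 400-second client timeout is a generous guess sized to the server-side 360-second wait_for.

    # Hypothetical client: POST a question to the running RAG service.
    import requests

    payload = {"question": "ما هي اسباب تساقط الشعر ؟"}  # example value from the Query model
    resp = requests.post("http://localhost:8000/ask", json=payload, timeout=400)
    resp.raise_for_status()

    data = resp.json()
    # /ask returns: status, answer (raw model output), response (text after
    # "Answer:" / "الإجابة:"), and the detected language of the question.
    print(data["language"], data["response"])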