khalednabawi11 committed (verified) · Commit 0c82f40 · 1 Parent(s): 0b8a0e0

Update app.py

Files changed (1):
  1. app.py +29 -22
app.py CHANGED
@@ -223,6 +223,9 @@ import torch
 from fastapi import FastAPI, Request, HTTPException, status
 from pydantic import BaseModel, Field
 
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
 # Load model and tokenizer
 model_name = "FreedomIntelligence/Apollo-7B"
 # model_name = "emilyalsentzer/Bio_ClinicalBERT"
@@ -269,29 +272,33 @@ Question: {message}
 Answer:"""
 
 # Chat function
+# @app.post("/ask")
+# def chat_fn(message):
+#     prompt = generate_prompt(message)
+#     response = pipe(prompt,
+#                     max_new_tokens=512,
+#                     temperature=0.7,
+#                     do_sample = True,
+#                     top_p=0.9)[0]['generated_text']
+#     answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
+#     return {"Answer": answer}
+
+executor = ThreadPoolExecutor()
+
+# Define request model
+class Query(BaseModel):
+    message: str
+
 @app.post("/ask")
-def chat_fn(message):
+async def chat_fn(query: Query):
+    message = query.message
     prompt = generate_prompt(message)
-    response = pipe(prompt,
-                    max_new_tokens=512,
-                    temperature=0.7,
-                    do_sample = True,
-                    top_p=0.9)[0]['generated_text']
-    answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
-    return {"Answer": answer}
-
-# Gradio ChatInterface
-# demo = gr.ChatInterface(
-#     fn=chat_fn,
-#     title="🩺 Apollo Medical Chatbot",
-#     description="Multilingual (Arabic & English) medical Q&A chatbot powered by Apollo-7B model inference.",
-#     theme=gr.themes.Soft()
-# )
 
+    # Run blocking inference in thread
+    loop = asyncio.get_event_loop()
+    response = await loop.run_in_executor(executor,
+        lambda: pipe(prompt, max_new_tokens=512, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text'])
 
-# if __name__ == "__main__":
-#     demo.launch(share=True)
-
-
-
-
+    # Parse answer
+    answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
+    return {"Answer": answer}
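
For reference, a minimal client-side sketch of how the updated endpoint might be called. The host and port are assumptions (Hugging Face Spaces typically serve on 7860; adjust to the actual deployment), and the sample question is hypothetical:

import requests

# Assumed base URL; depends on where the app is actually deployed.
URL = "http://localhost:7860/ask"

# The endpoint now expects a JSON body matching the Query model,
# i.e. {"message": "..."}, rather than a bare function argument.
payload = {"message": "What are the common symptoms of anemia?"}

resp = requests.post(URL, json=payload, timeout=120)
resp.raise_for_status()

# The handler returns {"Answer": "..."}, already stripped of the prompt
# text before the "Answer:" / "الإجابة:" marker.
print(resp.json()["Answer"])

One design note on the new handler: asyncio.get_event_loop() works here, but inside a coroutine asyncio.get_running_loop() is the recommended call on recent Python versions, and asyncio.to_thread(...) (Python 3.9+) would offload the blocking pipe(...) call without managing a ThreadPoolExecutor by hand. Either way, moving inference off the event loop keeps concurrent /ask requests from blocking FastAPI's async server.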