khalednabawi11 committed (verified)
Commit dc80dbe · 1 Parent(s): c5ecd72

Update app.py

Files changed (1)
  1. app.py +36 -33
app.py CHANGED
@@ -223,9 +223,9 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # Load model and tokenizer
-model_name = "FreedomIntelligence/Apollo-7B"
+# model_name = "FreedomIntelligence/Apollo-7B"
 # model_name = "emilyalsentzer/Bio_ClinicalBERT"
-# model_name = "FreedomIntelligence/Apollo-2B"
+model_name = "FreedomIntelligence/Apollo-2B"
 
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -300,45 +300,48 @@ def read_root():
     return {"message": "Apollo Medical Chatbot API is running"}
 
 
-# @app.post("/ask")
-# async def chat_fn(query: Query):
-
-#     message = query.message
-#     logger.info(f"Received message: {message}")
-
-#     prompt = generate_prompt(message)
-
-#     # Run blocking inference in thread
-#     loop = asyncio.get_event_loop()
-#     response = await loop.run_in_executor(executor,
-#         lambda: pipe(prompt, max_new_tokens=512, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text'])
-
-#     # Parse answer
-#     answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
-#     return {"Answer": answer}
-
 @app.post("/ask")
 async def chat_fn(query: Query):
+
     message = query.message
     logger.info(f"Received message: {message}")
 
     prompt = generate_prompt(message)
 
-    try:
-        start_time = time.time()
+    # Run blocking inference in thread
+    loop = asyncio.get_event_loop()
+    response = await loop.run_in_executor(executor,
+        lambda: pipe(prompt, max_new_tokens=150, temperature=0.7, do_sample=True, top_p=0.9)[0]['generated_text'])
+
+    # Parse answer
+    answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
+    return {
+        "response": response,
+        "Answer": answer
+    }
+
+# @app.post("/ask")
+# async def chat_fn(query: Query):
+#     message = query.message
+#     logger.info(f"Received message: {message}")
+
+#     prompt = generate_prompt(message)
+
+#     try:
+#         start_time = time.time()
 
-        loop = asyncio.get_event_loop()
-        response = await loop.run_in_executor(
-            executor,
-            lambda: pipe(prompt, max_new_tokens=150, temperature=0.6, do_sample=True, top_p=0.8)[0]['generated_text']
-        )
+#         loop = asyncio.get_event_loop()
+#         response = await loop.run_in_executor(
+#             executor,
+#             lambda: pipe(prompt, max_new_tokens=150, temperature=0.6, do_sample=True, top_p=0.8)[0]['generated_text']
+#         )
 
-        duration = time.time() - start_time
-        logger.info(f"Model inference completed in {duration:.2f} seconds")
+#         duration = time.time() - start_time
+#         logger.info(f"Model inference completed in {duration:.2f} seconds")
 
-        logger.info(f"Generated answer: {answer}")
-        return {"Answer": answer}
+#         logger.info(f"Generated answer: {answer}")
+#         return {"Answer": answer}
 
-    except Exception as e:
-        logger.error(f"Inference failed: {str(e)}")
-        raise HTTPException(status_code=500, detail="Model inference TimeOut failed.")
+#     except Exception as e:
+#         logger.error(f"Inference failed: {str(e)}")
+#         raise HTTPException(status_code=500, detail="Model inference TimeOut failed.")
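
Note: the handler in this diff relies on several names defined elsewhere in app.py, outside the changed hunks (`pipe`, `executor`, `generate_prompt`, `Query`). A minimal sketch of how they presumably fit together with the model and tokenizer loaded above; the actual definitions in app.py may differ:

    from concurrent.futures import ThreadPoolExecutor

    from pydantic import BaseModel
    from transformers import pipeline

    # Text-generation pipeline wrapping the model/tokenizer loaded above
    # (assumed construction; the real app.py may pass additional arguments).
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

    # Thread pool handed to loop.run_in_executor so blocking generation
    # does not stall FastAPI's event loop.
    executor = ThreadPoolExecutor(max_workers=1)

    # Request body for POST /ask.
    class Query(BaseModel):
        message: str

    def generate_prompt(message: str) -> str:
        # Hypothetical template: the real prompt presumably ends with "Answer:"
        # (or the Arabic "الإجابة:"), since the handler splits on that marker.
        return f"Question: {message}\nAnswer:"

With the new return shape, a successful POST to /ask yields both the full generation and the parsed answer, e.g. {"response": "...", "Answer": "..."}, so a client can inspect the raw text when the marker-based split misfires.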