khalednabawi11 committed
Commit af539aa · verified · 1 Parent(s): dc80dbe

Update app.py

Files changed (1):
  1. app.py +43 -40
app.py CHANGED
@@ -222,6 +222,10 @@ from fastapi.middleware.cors import CORSMiddleware
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+COLLECTION_NAME = "arabic_rag_collection"
+QDRANT_URL = os.getenv("QDRANT_URL", "https://12efeef2-9f10-4402-9deb-f070977ddfc8.eu-central-1-0.aws.cloud.qdrant.io:6333")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.Jb39rYQW2rSE9RdXrjdzKY6T1RF44XjdQzCvzFkjat4")
+
 # Load model and tokenizer
 # model_name = "FreedomIntelligence/Apollo-7B"
 # model_name = "emilyalsentzer/Bio_ClinicalBERT"
@@ -243,24 +247,43 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
-
 generation_config = GenerationConfig(
     max_new_tokens=150,
-    temperature=0.2,
+    temperature=0.7,
     top_k=20,
     do_sample=True,
-    top_p=0.7,
-    repetition_penalty=1.3,
+    top_p=0.8,
 )
 
 # Create generation pipeline
 pipe = TextGenerationPipeline(
     model=model,
     tokenizer=tokenizer,
-    device=model.device.index if torch.cuda.is_available() else "cpu"
+    device=model.device.index if torch.cuda.is_available() else -1
+)
+
+llm = HuggingFacePipeline(pipeline=pipe)
+
+embedding = HuggingFaceEmbeddings(model_name="Omartificial-Intelligence-Space/GATE-AraBert-v1")
+
+qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
+
+vector_store = Qdrant(
+    client=qdrant_client,
+    collection_name=COLLECTION_NAME,
+    embeddings=embedding
 )
 
+retriever = vector_store.as_retriever(search_kwargs={"k": 3})
+
+# ----------------- RAG Chain ------------------
+qa_chain = RetrievalQA.from_chain_type(
+    llm=llm,
+    retriever=retriever,
+    chain_type="stuff"
+)
+
+
 # Prompt formatter based on language
 def generate_prompt(message):
     lang = detect(message)
@@ -277,17 +300,6 @@ def generate_prompt(message):
 Question: {message}
 Answer:"""
 
-# Chat function
-# @app.post("/ask")
-# def chat_fn(message):
-#     prompt = generate_prompt(message)
-#     response = pipe(prompt,
-#                     max_new_tokens=512,
-#                     temperature=0.7,
-#                     do_sample = True,
-#                     top_p=0.9)[0]['generated_text']
-#     answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
-#     return {"Answer": answer}
 
 executor = ThreadPoolExecutor()
 
@@ -320,28 +332,19 @@ async def chat_fn(query: Query):
         "Answer": answer
     }
 
-# @app.post("/ask")
-# async def chat_fn(query: Query):
-#     message = query.message
-#     logger.info(f"Received message: {message}")
-
-#     prompt = generate_prompt(message)
-
-#     try:
-#         start_time = time.time()
-
-#         loop = asyncio.get_event_loop()
-#         response = await loop.run_in_executor(
-#             executor,
-#             lambda: pipe(prompt, max_new_tokens=150, temperature=0.6, do_sample=True, top_p=0.8)[0]['generated_text']
-#         )
-
-#         duration = time.time() - start_time
-#         logger.info(f"Model inference completed in {duration:.2f} seconds")
+@app.post("/ask-rag")
+async def chat_fn(query: Query):
+    message = query.message
+    logger.info(f"Received message: {message}")
 
-#         logger.info(f"Generated answer: {answer}")
-#         return {"Answer": answer}
+    # Run RAG inference in thread
+    loop = asyncio.get_event_loop()
+    response = await loop.run_in_executor(executor, lambda: qa_chain.run(message))
 
-#     except Exception as e:
-#         logger.error(f"Inference failed: {str(e)}")
-#         raise HTTPException(status_code=500, detail="Model inference TimeOut failed.")
+    answer = response.split("Answer:")[-1].strip() if "Answer:" in response else response.split("الإجابة:")[-1].strip()
+
+
+    return {
+        "response": response,
+        "answer": answer
+    }
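
Note: the added hunks reference os, QdrantClient, Qdrant, HuggingFaceEmbeddings, HuggingFacePipeline and RetrievalQA, but the import block sits above line 222 and is outside this diff. A minimal sketch of what those imports would typically look like with the classic LangChain import paths (newer releases expose the same classes from langchain_community); the exact paths are an assumption, not part of the commit:

# Assumed imports for the added RAG code; not shown in the diff hunks above.
import os

from qdrant_client import QdrantClient                  # Qdrant HTTP client
from langchain.vectorstores import Qdrant               # vector-store wrapper over a Qdrant collection
from langchain.embeddings import HuggingFaceEmbeddings  # sentence-embedding model wrapper
from langchain.llms import HuggingFacePipeline          # wraps the transformers pipeline as a LangChain LLM
from langchain.chains import RetrievalQA                # retrieval-augmented QA chain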
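
For reference, the new /ask-rag route reads the same Query body as the existing /ask endpoint (a JSON object with a single message field) and returns both the raw generation and the extracted answer. A hedged client-side sketch; the host and port are assumptions (uvicorn's default), not taken from this commit:

# Hypothetical client call against the new /ask-rag endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/ask-rag",               # adjust to wherever app.py is served
    json={"message": "ما هي أعراض مرض السكري؟"},   # Arabic: "What are the symptoms of diabetes?"
    timeout=120,                                   # retrieval + generation can be slow on CPU
)
print(resp.json()["answer"])                       # response body: {"response": ..., "answer": ...}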