File size: 2,902 Bytes
48ec4db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from app.settings import app
import redis.asyncio as redis
from app.settings import logger, settings
from app.core.response_parser import add_links
import json
import asyncio
from app.backend.controllers.messages import register_message
from app.core.utils import initialize_rag
from celery import Task


class AsyncTask(Task):
    """Celery ``Task`` base class whose ``run`` is a coroutine function.

    ``__call__`` drives the coroutine returned by ``run`` to completion on the
    calling thread's event loop and returns its result. Tasks built with
    ``base=AsyncTask`` can therefore be declared with ``async def``.
    """

    abstract = True

    @staticmethod
    def _loop_for_this_thread():
        """Return the thread's current event loop, or install a fresh one.

        A new loop is created and set as current when none can be obtained
        (``RuntimeError``) or when the current one has been closed.
        """
        # NOTE(review): calling ``asyncio.get_event_loop()`` with no running
        # loop is deprecated in recent Python versions (DeprecationWarning when
        # no loop is set; RuntimeError in newer releases). The RuntimeError
        # branch below handles the latter case.
        try:
            current = asyncio.get_event_loop()
        except RuntimeError:
            current = None
        if current is not None and not current.is_closed():
            return current
        fresh = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh)
        return fresh

    def __call__(self, *args, **kwargs):
        """Run ``self.run(*args, **kwargs)`` to completion and return its result."""
        loop = self._loop_for_this_thread()
        return loop.run_until_complete(self.run(*args, **kwargs))

    async def run(self, *args, **kwargs):
        """Placeholder coroutine that returns ``None``.

        Presumably replaced by the decorated coroutine when a task is declared
        with ``app.task(base=AsyncTask, ...)``.
        """


# Module-wide async Redis client shared by the tasks in this module.
# ``settings.redis`` is dumped to a plain dict (presumably a pydantic model,
# given ``model_dump()``) and passed as keyword arguments to ``redis.Redis``.
# NOTE(review): asyncio Redis connections are generally bound to the event loop
# that opened them. AsyncTask installs a brand-new loop when the previous one is
# closed, which would leave pooled connections tied to the dead loop. Confirm
# that this cannot happen in the worker, or create the client per loop.
redis_settings = settings.redis.model_dump()
redis_client = redis.Redis(**redis_settings)


@app.task(base=AsyncTask, queue='high_priority', bind=True, max_retries=3)
async def process_documents(self, collection_name: str, files: list[str], chat_id: str):
    """Upload ``files`` into the RAG collection ``collection_name``.

    Runs on the ``high_priority`` queue. On any error the failure is logged and
    the task is retried with exponential backoff (``2 ** retries`` seconds), at
    most ``max_retries`` (3) times.

    Args:
        self: The bound Celery task (``bind=True``), used for ``request`` and ``retry``.
        collection_name: Collection the documents are uploaded to.
        files: Documents handed to ``upload_documents``. Presumably file paths
            or contents; TODO confirm with the caller.
        chat_id: Echoed back in the result.

    Returns:
        ``{"status": "success", "collection_name": collection_name, "chat_id": chat_id}``.

    Raises:
        The exception from ``self.retry``. Per Celery's docs this is a ``Retry``
        that reschedules the task, or the original error once retries are exhausted.
    """
    await logger.info("Start background task")
    try:
        # Initialised inside ``try`` so a failing RAG setup is logged and
        # retried like any other processing error.
        RAG = initialize_rag()
        await RAG.upload_documents(collection_name=collection_name, documents=files)
    except Exception as e:
        await logger.error(f"Error processing the documents at process_documents: {e}")
        # ``retry`` always raises. The explicit ``raise`` is Celery's documented
        # idiom and makes it clear that this path never returns ``None``.
        raise self.retry(countdown=2**self.request.retries, exc=e)
    return {"status": "success", "collection_name": collection_name, "chat_id": chat_id}


@app.task(base=AsyncTask, queue='default', bind=True, max_retries=3)
async def generate_response(self, collection_name: str, prompt: str, chat_id: str, task_id: str):
    """Stream a RAG answer for ``prompt`` through Redis and store it as a chat message.

    Redis keys written (all namespaced by ``task_id``):

    * ``response:{task_id}:chunks``: a list with one JSON ``{"chunk": ...}``
      entry per generated chunk. It gets a 300 s TTL once the answer completes.
    * ``response:{task_id}:status``: ``"streaming"`` while chunks arrive,
      then ``"completed"`` or ``"failed"``.
    * ``response:{task_id}:error``: ``str(exc)`` of the most recent failure.

    On success the full text, passed through ``add_links``, is registered as an
    ``"assistant"`` message in ``chat_id``.

    Args:
        self: The bound Celery task (``bind=True``), used for ``request`` and ``retry``.
        collection_name: Collection the RAG answers from.
        prompt: The user's prompt.
        chat_id: Chat the assistant message is registered under; echoed in the result.
        task_id: Namespace for the Redis keys above.

    Returns:
        ``{"status": "success", "response": <full text>, "chat_id": chat_id}``.
        ``response`` is the raw text, not the ``add_links`` version stored as
        the message.

    Raises:
        The exception from ``self.retry``, after the failure is recorded in Redis.
        Backoff is ``2 ** retries`` seconds, with at most ``max_retries`` (3)
        retries. Per Celery's docs, the original error propagates once retries
        run out.
    """
    status_key = f"response:{task_id}:status"
    chunks_key = f"response:{task_id}:chunks"
    await logger.info(f"Task id -----> {task_id}")
    try:
        # Created inside ``try`` so a setup failure is reported through the
        # status key (and retried) instead of leaving the consumer with no status.
        RAG = initialize_rag()
        # NOTE(review): chunks from a failed attempt are not cleared, so a retry
        # appends a second copy of the answer to the same list. Confirm how the
        # consumer reads this list.
        parts: list[str] = []
        async for chunk in RAG.generate_response_stream(collection_name=collection_name, user_prompt=prompt):
            parts.append(chunk)
            await redis_client.rpush(chunks_key, json.dumps({"chunk": chunk}))
            await redis_client.set(status_key, "streaming")
            # Short pause between chunks, kept from the original pacing.
            await asyncio.sleep(0.01)
        # Joined once at the end. Repeated ``+=`` was quadratic in answer length.
        full_response = "".join(parts)
        await logger.info(f"Full response length: {len(full_response)}, preview: {full_response[:200]}...")
        await register_message(content=await add_links(full_response), sender="assistant", chat_id=chat_id)
        await redis_client.set(status_key, "completed")
        await redis_client.expire(chunks_key, 300)
        return {"status": "success", "response": full_response, "chat_id": chat_id}
    except Exception as e:
        await logger.error(f"Error at generate_response: {e}")
        try:
            await redis_client.set(status_key, "failed")
            await redis_client.set(f"response:{task_id}:error", str(e))
        except Exception as redis_error:
            # Redis itself may be what failed. Don't let that replace the retry below.
            await logger.error(f"Error at generate_response: could not record failure: {redis_error}")
        # ``retry`` always raises. The explicit ``raise`` is Celery's documented idiom.
        raise self.retry(countdown=2**self.request.retries, exc=e)