File size: 4,759 Bytes
b3b7a20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# rag_system.py
import logging
from typing import Dict, List, Optional, TypedDict

from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain.prompts import ChatPromptTemplate
from langchain_core.tools import tool

from langgraph.graph import StateGraph, START, END

# Logging configuration: INFO level with timestamped, leveled messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger, keyed by module name per stdlib convention.
logger = logging.getLogger(__name__)

# RAG prompt for puppy-related questions.
# {question} and {context} are filled by ChatPromptTemplate.format_messages
# at generation time; the model is instructed to answer strictly from the
# retrieved context and to say "I don't know" otherwise (grounding guard).
# NOTE: the leading whitespace inside the triple-quoted string is part of
# the prompt sent to the model.
RAG_PROMPT =    """
                You are an assistant specialized in puppy education and care.
                Your role is to help new puppy owners by answering their questions with accuracy and kindness.
                Use only the information provided in the context to formulate your answers.
                If you cannot find the information in the context, just say "I don't know".

                ### Question
                {question}

                ### Context
                {context}
                """

class State(TypedDict):
    """Shared state flowing through the LangGraph RAG pipeline."""
    question: str            # user's question (pipeline input)
    context: List[Document]  # chunks produced by the retrieve node
    response: str            # final LLM answer produced by the generate node

class RAGSystem:
    """RAG system for puppy-related questions.

    Wraps a retriever and an OpenAI chat model in a two-node LangGraph
    pipeline (retrieve -> generate) and exposes the result both as a
    direct query method (`process_query`) and as a LangChain tool for
    agent use (`create_rag_tool`).
    """

    def __init__(self, retriever, model_name: str = "gpt-4o-mini"):
        """
        Args:
            retriever: Object exposing ``invoke(question) -> list[Document]``
                (e.g. a LangChain vector-store retriever).
            model_name: OpenAI chat model identifier passed to ChatOpenAI.
        """
        self.retriever = retriever
        self.llm = ChatOpenAI(model=model_name)
        self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
        self.graph_rag = self._build_graph()

    def _build_graph(self):
        """Builds and compiles the two-node RAG graph (retrieve -> generate)."""

        def retrieve(state):
            # Fetch the documents relevant to the user's question.
            retrieved_docs = self.retriever.invoke(state["question"])
            return {"context": retrieved_docs}

        def generate(state):
            # Concatenate chunk texts, render the prompt, and call the LLM.
            docs_content = "\n\n".join(doc.page_content for doc in state["context"])
            messages = self.rag_prompt.format_messages(
                question=state["question"],
                context=docs_content,
            )
            response = self.llm.invoke(messages)
            return {"response": response.content}

        # add_sequence wires retrieve -> generate; START must still be
        # connected explicitly to the first node before compiling.
        graph_builder = StateGraph(State).add_sequence([retrieve, generate])
        graph_builder.add_edge(START, "retrieve")
        return graph_builder.compile()

    @staticmethod
    def _describe_source(index: int, doc) -> Dict:
        """Builds the source-metadata summary dict for one retrieved chunk.

        Args:
            index: 1-based position of the chunk in the retrieved context.
            doc: A Document-like object with ``metadata`` and ``page_content``.

        Returns:
            Dict with chunk_number, description, source, page, chapter and
            a content preview capped at 100 characters.
        """
        metadata = doc.metadata
        # Extract useful metadata; fall back to placeholders when absent.
        source_name = metadata.get('source', 'Unknown')
        page = metadata.get('page', 'N/A')
        chapter = metadata.get('chapter', '')

        # Include the chapter in the description only when one is present.
        if chapter:
            source_desc = f"Chunk {index} - {source_name} (Chapter: {chapter}, Page: {page})"
        else:
            source_desc = f"Chunk {index} - {source_name} (Page: {page})"

        content = doc.page_content
        return {
            'chunk_number': index,
            'description': source_desc,
            'source': source_name,
            'page': page,
            'chapter': chapter,
            # Truncate long chunks so previews stay readable in logs/UIs.
            'content_preview': content[:100] + "..." if len(content) > 100 else content,
        }

    def process_query(self, question: str) -> Dict:
        """Processes a query through the RAG graph.

        Args:
            question: The user's question.

        Returns:
            Dict with keys ``response`` (LLM answer), ``context`` (retrieved
            Documents), ``sources_info`` (per-chunk metadata summaries) and
            ``total_chunks`` (number of retrieved Documents).
        """
        result = self.graph_rag.invoke({"question": question})

        # One summary entry per retrieved chunk, numbered from 1.
        sources_info = [
            self._describe_source(i, doc)
            for i, doc in enumerate(result["context"], 1)
        ]

        return {
            "response": result["response"],
            "context": result["context"],
            "sources_info": sources_info,
            "total_chunks": len(result["context"]),
        }

    def create_rag_tool(self):
        """Creates a LangChain tool wrapping this RAG system for agent use."""

        # Capture the instance so the decorated function can reach it; the
        # @tool-decorated function cannot take `self`.
        rag_system = self

        @tool
        def ai_rag_tool(question: str) -> Dict:
            """MANDATORY for all questions about puppies, their behavior, education or training.
            This tool accesses a specialized knowledge base on puppies with expert and reliable information.
            Any question regarding puppy care, education, behavior or health MUST be processed by this tool.
            The input must be a complete question."""

            # Invoke the RAG graph via the captured instance.
            result = rag_system.process_query(question)

            # The answer is wrapped as a message so agent frameworks that
            # expect a "messages" key can consume it directly.
            return {
                "messages": [HumanMessage(content=result["response"])],
                "context": result["context"],
                "sources_info": result["sources_info"],
                "total_chunks": result["total_chunks"],
            }

        return ai_rag_tool