File size: 5,482 Bytes
84b73c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b54001
84b73c1
 
 
 
 
 
 
 
 
 
 
 
9b54001
84b73c1
 
 
15176fe
84b73c1
 
 
 
 
 
 
 
 
 
 
 
7366c55
84b73c1
 
 
7366c55
84b73c1
7366c55
84b73c1
9b54001
84b73c1
 
 
 
 
 
 
 
 
7366c55
84b73c1
7366c55
84b73c1
 
 
7366c55
84b73c1
 
7366c55
84b73c1
 
 
 
 
 
 
 
 
 
7366c55
84b73c1
 
 
7366c55
84b73c1
 
 
 
7366c55
84b73c1
7366c55
 
 
 
 
 
 
 
 
 
 
 
 
 
84b73c1
 
 
 
 
 
 
7366c55
84b73c1
896c4b6
 
 
 
 
84b73c1
9b54001
 
896c4b6
 
 
84b73c1
15176fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
import requests
from bs4 import BeautifulSoup
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS

# --- Import updated RAG-specific libraries ---
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.schema import Document
from langchain.chains import RetrievalQA

# --- Flask application setup ---
# The '.' directory doubles as the static folder, mounted at the URL root
# (static_url_path=''), so the frontend's assets are reachable directly.
# NOTE(review): this makes every file in that directory downloadable over HTTP
# (presumably including this module's source and the PDFs) — consider moving
# the frontend into a dedicated static folder.
app = Flask(
    __name__,
    static_url_path='',
    static_folder='.',
)
CORS(app)  # flask-cors applied to the whole app with its default options

# Populated by initialize_rag_pipeline(); remains None if initialization fails.
rag_chain = None

# --- Homepage route ---
@app.route('/')
def serve_index():
    """Respond to the site root with the frontend's index.html from the '.' directory."""
    folder, page = '.', 'index.html'
    return send_from_directory(folder, page)

# --- API Route for queries ---
@app.route("/query", methods=["POST"])
def handle_query():
    """Answer a question posted by the frontend using the RAG chain.

    Expects a JSON object body of the form ``{"query": "<question>"}``.

    Returns:
        200 with ``{"answer": str, "sources": [str, ...]}`` on success;
        400 if the body is not a JSON object or lacks a non-blank string ``query``;
        500 if the RAG pipeline is not initialized or the chain raises.
    """
    # Read-only access to the module-level chain; no `global` needed.
    if not rag_chain:
        return jsonify({"error": "Failed to process the query. Details: RAG pipeline not initialized. Check server logs."}), 500

    # silent=True yields None (instead of raising) for a missing, non-JSON or
    # malformed body. Requiring a dict also rejects `null`, lists and scalars,
    # which previously crashed on `.get` with an unhandled AttributeError.
    data = request.get_json(silent=True)
    query_text = data.get("query") if isinstance(data, dict) else None
    if not isinstance(query_text, str) or not query_text.strip():
        return jsonify({"error": "No query provided."}), 400
    query_text = query_text.strip()

    try:
        print(f"Received query: {query_text}")
        result = rag_chain.invoke(query_text)

        answer = result.get('result', 'No answer found.')
        sources = [doc.metadata.get('source', 'Unknown source') for doc in result.get('source_documents', [])]
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_sources = list(dict.fromkeys(sources))

        print(f"Generated response: {answer}")
        return jsonify({
            "answer": answer,
            "sources": unique_sources
        })
    except Exception as e:
        # Top-level request boundary: report the failure to the client as JSON.
        print(f"Error during query processing: {e}")
        return jsonify({"error": f"An error occurred: {str(e)}"}), 500

# --- RAG Pipeline Initialization ---
def initialize_rag_pipeline():
    """
    Load the source documents, build an in-memory vector store, and create the RAG chain.

    Sources are the hard-coded local PDFs and web pages below. On success the
    module-level ``rag_chain`` is set to a ``RetrievalQA`` chain that also
    returns its source documents. On any fatal problem (missing
    GOOGLE_API_KEY, a PDF that fails to load, no documents, no text chunks,
    or a failure while building the models, vector store or chain) an error
    is printed and ``rag_chain`` is left unchanged.
    """
    global rag_chain
    print("--- Starting RAG pipeline initialization ---")

    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        print("ERROR: GOOGLE_API_KEY environment variable not set. Halting.")
        return
    print("Step 1: Google API Key found.")

    pdf_files = ["Augusta rule 101 CPE Webinar.pdf", "Augusta rule workshop.pdf"]
    pdf_docs = []
    try:
        for file in pdf_files:
            if os.path.exists(file):
                loader = PyPDFLoader(file)
                pdf_docs.extend(loader.load())
            else:
                print(f"Warning: PDF file not found at {file}")
    except Exception as e:
        print(f"ERROR loading PDFs: {e}. Halting.")
        return
    print(f"Step 2: Loaded {len(pdf_docs)} pages from PDF files.")

    def scrape_url(url):
        """Fetch *url* and return its visible text as a Document, or None if the request fails."""
        try:
            response = requests.get(url, timeout=15)
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"Warning: Could not scrape URL {url}. Error: {e}")
            return None
        # Parse the raw bytes rather than response.text: for text/* responses
        # whose Content-Type lacks a charset, requests decodes as ISO-8859-1,
        # garbling UTF-8 pages. From bytes, BeautifulSoup detects the encoding
        # itself (e.g. from the page's <meta charset>).
        soup = BeautifulSoup(response.content, "html.parser")
        return Document(page_content=soup.get_text(separator=" ", strip=True), metadata={"source": url})

    urls = [
        "https://www.instead.com/blog/the-augusta-rule-a-tax-strategy-for-business-owners",
        "https://www.instead.com/blog/s-corp-reasonable-salary-guide",
        "https://www.instead.com/blog/how-to-start-an-s-corp"
    ]
    web_docs = [doc for doc in (scrape_url(url) for url in urls) if doc is not None]
    print(f"Step 3: Scraped {len(web_docs)} web pages.")

    all_docs = pdf_docs + web_docs
    if not all_docs:
        print("ERROR: No documents were loaded. Halting.")
        return

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(all_docs)
    print(f"Step 4: Split documents into {len(chunks)} chunks.")
    if not chunks:
        # Documents with empty text produce no chunks; an empty store cannot answer anything.
        print("ERROR: Loaded documents contained no text to index. Halting.")
        return

    try:
        embedding = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
        print("Step 5: Gemini embedding model loaded successfully.")

        # Create vector store in-memory for better stability on free servers
        vectorstore = Chroma.from_documents(chunks, embedding)
        print("Step 6: In-memory vector store created successfully.")

        llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", temperature=0.3)
        print("Step 7: Gemini LLM loaded successfully.")

        # Built inside the try: this function is called at module import time,
        # so an exception escaping here would prevent the module from loading
        # instead of leaving the server up with rag_chain unset.
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=vectorstore.as_retriever(),
            return_source_documents=True
        )
    except Exception as e:
        print(f"ERROR during AI model initialization: {e}. Halting.")
        return

    # Publish only a fully constructed chain.
    rag_chain = chain
    print("--- RAG pipeline initialized successfully! ---")

# --- Build the RAG pipeline at import time ---
# Called at module level rather than inside the __main__ guard, so the chain is
# also built when a WSGI server (e.g. Gunicorn on Hugging Face) imports the module.
initialize_rag_pipeline()

# --- Local development entry point ---
if __name__ == "__main__":
    if rag_chain:
        # Flask's built-in development server; production hosting imports the
        # module and never reaches this line.
        app.run(host='0.0.0.0', port=5000)
    else:
        print("\nCould not start the server because the RAG pipeline failed to initialize.")