# Spaces: Sleeping
# app.py
import logging
import os
import re

import faiss
import gradio as gr
import numpy as np
import requests
from bs4 import BeautifulSoup
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import Together
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from sentence_transformers import SentenceTransformer
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load Embedding Model
logger.info("π Loading sentence transformer...")
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Load LLM. The API key is taken from the TOGETHER_API_KEY environment
# variable so that a real secret never has to be written into this file; the
# original placeholder literal is kept as the fallback value.
llm = Together(
    model="togethercomputer/llama-3-70b-chat",
    temperature=0.7,
    max_tokens=512,
    together_api_key=os.environ.get("TOGETHER_API_KEY", "your_together_api_key"),
)
def fetch_webpage_text(url, timeout=10):
    """Download *url* and return the text of its main content.

    The element ``<div id="mw-content-text">`` is used when the page has one;
    otherwise the whole ``<body>`` is used. Text nodes are stripped and joined
    with newlines.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.

    Returns:
        The extracted text, or ``""`` if the request, the HTTP status check,
        or the parsing fails for any reason (the error is logged).
    """
    # Best-effort boundary: callers only test for an empty string, so every
    # failure is logged and mapped to "" rather than propagated.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        content_div = soup.find("div", {"id": "mw-content-text"}) or soup.body
        if content_div is None:
            # Neither the content div nor a <body> exists; report that
            # plainly instead of failing on None.get_text below.
            logger.error("Error fetching content: no <body> element in response")
            return ""
        return content_div.get_text(separator="\n", strip=True)
    except Exception as e:
        logger.error(f"Error fetching content: {e}")
        return ""
def clean_text(text):
    """Strip bracketed markers and normalise whitespace in *text*.

    The substitutions run in a fixed order: bracketed numbers such as
    ``[12]``, bracketed alphabetic words such as ``[edit]``, any line that
    consists solely of a bracketed number, runs of two or more newlines
    (collapsed to one), and runs of spaces/tabs (collapsed to one space).
    The result is stripped of leading and trailing whitespace.
    """
    # (pattern, replacement, flags) applied strictly in this order; the later
    # passes operate on the output of the earlier ones.
    substitutions = (
        (r'\[\s*\d+\s*\]', '', 0),
        (r'\[\s*[a-zA-Z]+\s*\]', '', 0),
        (r'^\[\s*\d+\s*\]$', '', re.MULTILINE),
        (r'\n{2,}', '\n', 0),
        (r'[ \t]+', ' ', 0),
    )
    cleaned = text
    for pattern, replacement, flags in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned, flags=flags)
    return cleaned.strip()
def chunk_text(text, chunk_size=500, overlap=50):
    """Split *text* into pieces with ``RecursiveCharacterTextSplitter``.

    Args:
        text: The text to split.
        chunk_size: Passed to the splitter as ``chunk_size``.
        overlap: Passed to the splitter as ``chunk_overlap``.

    Returns:
        Whatever ``split_text`` returns for *text*.
    """
    splitter_options = {"chunk_size": chunk_size, "chunk_overlap": overlap}
    return RecursiveCharacterTextSplitter(**splitter_options).split_text(text)
def create_vectorstore(chunks):
    """Embed *chunks* and load the vectors into a flat L2 FAISS index.

    Args:
        chunks: Iterable of text chunks to index.

    Returns:
        A tuple ``(index, texts, embeddings)``: the populated
        ``faiss.IndexFlatL2``, the chunks as a list, and a list holding one
        float32 vector per chunk in the same order as ``texts``.

    Raises:
        ValueError: If *chunks* is empty, since the index dimension is taken
            from the embeddings and cannot be determined without any.
    """
    texts = list(chunks)
    if not texts:
        raise ValueError("create_vectorstore() needs at least one text chunk")

    # Encode all chunks in one call instead of one call per chunk; FAISS wants
    # a contiguous float32 matrix, so normalise the result once here.
    matrix = np.ascontiguousarray(embed_model.encode(texts), dtype=np.float32)
    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)

    # Hand the vectors back as a list of rows, matching the per-chunk list
    # that callers index with embeddings[0].
    return index, texts, list(matrix)
def _group_chunks(chunks, max_chars):
    """Join consecutive chunks with blank lines into groups of at most *max_chars*."""
    groups = []
    current = []
    current_len = 0
    for chunk in chunks:
        # Two characters account for the "\n\n" separator between chunks.
        added = len(chunk) + (2 if current else 0)
        if current and current_len + added > max_chars:
            groups.append("\n\n".join(current))
            current = []
            current_len = 0
            added = len(chunk)
        current.append(chunk)
        current_len += added
    if current:
        groups.append("\n\n".join(current))
    return groups


def get_summary(chunks, max_chars=8000):
    """Summarise *chunks* with a map-reduce summarisation chain.

    Consecutive chunks are joined with blank lines into documents of at most
    *max_chars* characters, and each document is one input to the chain's map
    step. Input that fits within *max_chars* therefore still becomes a single
    document, while longer input is spread over several documents instead of
    being sent to the model as one oversized text.

    Args:
        chunks: Sequence of text chunks, in reading order.
        max_chars: Upper bound on the size of each document handed to the
            chain. A single chunk longer than this becomes its own document.

    Returns:
        The summary produced by the chain, or ``""`` when *chunks* is empty.
    """
    if not chunks:
        # Nothing to summarise; avoid a pointless model call.
        return ""
    docs = [Document(page_content=group) for group in _group_chunks(chunks, max_chars)]
    summarize_chain = load_summarize_chain(llm, chain_type="map_reduce")
    return summarize_chain.run(docs)
def retrieve_answer(query, chunks, embeddings, texts):
    """Answer *query* from the stored chunks closest to it in embedding space.

    A flat L2 FAISS index is built from *embeddings*, the query is embedded
    with the same model, and up to five nearest chunks are joined and passed
    as context to a "stuff" question-answering chain.

    Args:
        query: The user's question.
        chunks: Unused; kept so existing callers keep working.
        embeddings: One vector per entry of *texts*, in the same order.
        texts: The chunk texts that *embeddings* were computed from.

    Returns:
        The answer text produced by the question-answering chain.

    Raises:
        ValueError: If there are no texts or embeddings to search.
    """
    if len(texts) == 0 or len(embeddings) == 0:
        raise ValueError("retrieve_answer() needs at least one indexed chunk")

    matrix = np.ascontiguousarray(np.asarray(embeddings), dtype=np.float32)
    index = faiss.IndexFlatL2(matrix.shape[1])
    index.add(matrix)

    query_vector = np.asarray(embed_model.encode(query), dtype=np.float32).reshape(1, -1)

    # Never ask for more neighbours than there are vectors: FAISS pads the
    # result with -1, which would otherwise silently select texts[-1].
    k = min(5, len(texts))
    _, neighbours = index.search(query_vector, k=k)
    top_chunks = [texts[i] for i in neighbours[0] if i >= 0]
    rag_doc = "\n\n".join(top_chunks)

    # The retrieval step is done by hand above, so use a plain document QA
    # chain; it takes `input_documents` and `question` directly, whereas
    # RetrievalQA needs a real retriever object and a `query` input.
    qa_chain = load_qa_chain(llm, chain_type="stuff")
    return qa_chain.run(input_documents=[Document(page_content=rag_doc)], question=query)
# Gradio Interface
def run_chatbot(url, query):
    """Fetch a page, summarise it, and answer *query* about it.

    Returns a ``(summary, answer)`` pair. When the page text cannot be
    fetched, or nothing is left after cleaning and chunking, the first
    element is an error message and the second is an empty string.
    """
    page_text = fetch_webpage_text(url)
    if not page_text:
        return "β Failed to fetch content.", ""

    pieces = chunk_text(clean_text(page_text))
    if not pieces:
        return "β No valid content to process.", ""

    # The summary is produced before the vectors are built, as before.
    summary = get_summary(pieces)
    _index, texts, vectors = create_vectorstore(pieces)
    answer = retrieve_answer(query, pieces, vectors, texts)
    return summary, answer
# Web UI: two text inputs (page URL, question) feed run_chatbot, whose two
# return values fill the summary and answer boxes. Flagging is turned off.
demo = gr.Interface(
    run_chatbot,
    [
        gr.Textbox(label="Webpage URL", placeholder="Enter a Wikipedia link"),
        gr.Textbox(label="Your Question", placeholder="Ask a question about the webpage"),
    ],
    [
        gr.Textbox(label="Webpage Summary"),
        gr.Textbox(label="Answer"),
    ],
    title="π¦ LLaMA RAG Chatbot",
    description="Enter a Wikipedia article URL and ask a question. Powered by Together AI and LangChain.",
    allow_flagging="never",
)

# Start the interface only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()