import gradio as gr
import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# =====================================
# 1. LOAD DOCUMENTS
# =====================================
def load_documents(path="documents.txt"):
    with open(path, "r", encoding="utf-8") as f:
        docs = f.readlines()
    return [doc.strip() for doc in docs if doc.strip()]
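# Assumes a documents.txt sits next to app.py with one passage per line;
# the file name is this demo's convention, not a library requirement.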
documents = load_documents()
# =====================================
# 2. LOAD EMBEDDING MODEL (HF Open Source)
# =====================================
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
doc_embeddings = embedding_model.encode(documents, convert_to_numpy=True)
dimension = doc_embeddings.shape[1]
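# all-MiniLM-L6-v2 produces 384-dimensional vectors; reading the width from
# the array keeps the FAISS index in sync if the embedding model is swapped.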
# =====================================
# 3. BUILD FAISS INDEX
# =====================================
index = faiss.IndexFlatL2(dimension)
index.add(doc_embeddings)
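# IndexFlatL2 ranks by Euclidean distance. A common alternative (a sketch,
# not required here) is cosine similarity via normalized vectors and an
# inner-product index:
#   faiss.normalize_L2(doc_embeddings)
#   index = faiss.IndexFlatIP(dimension)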
# =====================================
# 4. LOAD OPEN-SOURCE LLM (HF)
# =====================================
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # change if needed
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
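# Note: device_map="auto" requires the accelerate package; float16 is chosen
# only on GPU because CPU float16 support in PyTorch is limited.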
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)
# =====================================
# 5. RETRIEVAL FUNCTION
# =====================================
def retrieve(query, top_k=3):
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, top_k)
    retrieved_docs = [documents[i] for i in indices[0]]
    return retrieved_docs
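# Illustrative call: retrieve("what does the corpus say about X?", top_k=3)
# returns the three stored lines whose embeddings are nearest the query's.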
# =====================================
# 6. GENERIC LLM CALL
# =====================================
def call_llm(prompt):
    response = generator(
        prompt,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        return_full_text=False  # return only the completion, not the echoed prompt
    )
    return response[0]["generated_text"].strip()
# =====================================
# 7. RAG PIPELINE
# =====================================
def rag_pipeline(query):
    retrieved_docs = retrieve(query)
    context = "\n".join(retrieved_docs)
    prompt = f"""You are a helpful AI assistant.
Answer ONLY from the provided context.

Context:
{context}

Question:
{query}

Answer:"""
    answer = call_llm(prompt)
    return answer
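# Illustrative flow: rag_pipeline("What is covered in the documents?") embeds
# the question, pulls the 3 nearest passages from the index, and asks the LLM
# to answer from that context only.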
# =====================================
# 8. GRADIO UI
# =====================================
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Open Source RAG (HF Only)")
    query = gr.Textbox(label="Ask your question")
    output = gr.Textbox(label="Answer")
    query.submit(rag_pipeline, query, output)

demo.launch()