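# Minimal open-source RAG demo: embeds each line of documents.txt with
# all-MiniLM-L6-v2, indexes the vectors in FAISS, and answers questions
# with TinyLlama through a Gradio UI.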
import gradio as gr
import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
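
# Assumed dependencies (the source pins none of these):
#   pip install gradio faiss-cpu sentence-transformers transformers torch accelerate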

# =====================================
# 1. LOAD DOCUMENTS
# =====================================
def load_documents(path="documents.txt"):
    with open(path, "r", encoding="utf-8") as f:
        docs = f.readlines()
    return [doc.strip() for doc in docs if doc.strip()]

documents = load_documents()
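# NOTE: documents.txt is expected to hold one passage per line; blank lines are skipped.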

# =====================================
# 2. LOAD EMBEDDING MODEL (HF Open Source)
# =====================================
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
doc_embeddings = embedding_model.encode(documents, convert_to_numpy=True)
dimension = doc_embeddings.shape[1]
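# all-MiniLM-L6-v2 emits 384-dimensional float32 vectors, so dimension == 384.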

# =====================================
# 3. BUILD FAISS INDEX
# =====================================
index = faiss.IndexFlatL2(dimension)
index.add(doc_embeddings)
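# IndexFlatL2 performs exact brute-force L2 search, which is fine for small
# corpora. A possible variant (not in the original) is cosine similarity via
# normalized vectors and an inner-product index:
#   faiss.normalize_L2(doc_embeddings)
#   index = faiss.IndexFlatIP(dimension)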

# =====================================
# 4. LOAD OPEN-SOURCE LLM (HF)
# =====================================
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # change if needed

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)

generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
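# device_map="auto" requires the accelerate package; the pipeline then reuses
# the model's existing device placement, so no device= argument is needed here.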

# =====================================
# 5. RETRIEVAL FUNCTION
# =====================================
def retrieve(query, top_k=3):
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, top_k)
    # FAISS pads indices with -1 when the corpus holds fewer than top_k entries
    retrieved_docs = [documents[i] for i in indices[0] if i != -1]
    return retrieved_docs
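# Example (hypothetical query, assuming documents.txt is populated):
#   retrieve("What is FAISS?")  # -> the 3 corpus lines nearest in embedding space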

# =====================================
# 6. GENERIC LLM CALL
# =====================================
def call_llm(prompt):
    response = generator(
        prompt,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        return_full_text=False,  # return only the completion, not prompt + completion
    )
    return response[0]["generated_text"].strip()
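# temperature/top_p trade determinism for diversity; set do_sample=False for
# greedy decoding if reproducible answers matter more.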

# =====================================
# 7. RAG PIPELINE
# =====================================
def rag_pipeline(query):
    retrieved_docs = retrieve(query)
    context = "\n".join(retrieved_docs)
    prompt = f"""You are a helpful AI assistant.
Answer ONLY from the provided context.

Context:
{context}

Question:
{query}

Answer:"""
    answer = call_llm(prompt)
    return answer
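# NOTE: the prompt above is plain text. Since TinyLlama-1.1B-Chat is tuned on a
# chat template, wrapping it with tokenizer.apply_chat_template() would likely
# improve answers; it is left out here to keep the original flow.
# Quick smoke test before launching the UI (hypothetical question):
#   print(rag_pipeline("Summarize the main topic of the documents."))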

# =====================================
# 8. GRADIO UI
# =====================================
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Open Source RAG (HF Only)")
    query = gr.Textbox(label="Ask your question")
    output = gr.Textbox(label="Answer")
    query.submit(rag_pipeline, query, output)

demo.launch()