import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from transformers import pipeline
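# Typical dependencies for the imports above (usually pinned in requirements.txt):
# langchain, langchain-community, langchain-huggingface, transformers,
# sentence-transformers, faiss-cpu, gradio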
# Constants
DB_FAISS_PATH = "vectorstore/db_faiss" # Pre-generated FAISS directory
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
MODEL_NAME = "MBZUAI/LaMini-Flan-T5-783M" # Lightweight CPU-friendly model
# Step 1: Load FAISS vectorstore (already created offline)
def load_vector_store():
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
return FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
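# NOTE: the index at DB_FAISS_PATH must have been built offline with the SAME
# embedding model, or retrieval quality degrades silently. A hypothetical build step:
#   # chunks = your pre-split textile documents
#   # FAISS.from_documents(chunks, HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)).save_local(DB_FAISS_PATH)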
# Step 2: Load lightweight HuggingFace model (no token needed)
def load_llm():
pipe = pipeline("text2text-generation", model=MODEL_NAME)
return HuggingFacePipeline(pipeline=pipe)
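# The pipeline runs with the model's default generation settings; if answers come
# back truncated, a larger output budget can be passed through (assumed value, tune as needed):
#   pipe = pipeline("text2text-generation", model=MODEL_NAME, max_length=256)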
# Step 3: Setup QA chain
def setup_chain():
prompt_template = """
Use the following context to answer the question.
If the answer is not in the context, just say you don't know.
Context: {context}
Question: {question}
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
retriever = load_vector_store().as_retriever(search_kwargs={"k": 3})
llm = load_llm()
return RetrievalQA.from_chain_type(
llm=llm,
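        # "stuff" concatenates all k retrieved chunks into one prompt; keep k small
        # so the combined context fits the model's input limit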
chain_type="stuff",
retriever=retriever,
return_source_documents=True,
chain_type_kwargs={"prompt": prompt}
)
qa_chain = setup_chain()
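# Optional smoke test (hypothetical query; uncomment to verify the chain end-to-end):
# print(qa_chain.invoke({"query": "What factors affect loom speed?"})["result"])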
# Step 4: Gradio Interface
def rag_bot(query):
result = qa_chain.invoke({"query": query})
return result["result"]
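# Note: rag_bot drops result["source_documents"] (populated because
# return_source_documents=True); those retrieved chunks could be surfaced in the UI.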
# Step 5: Launch Interface
demo = gr.Interface(
fn=rag_bot,
inputs="text",
outputs="text",
title="TextileVision: AI Chatbot",
    description="Ask questions about loom speed, yarn mixing, knitting prediction, and other textile operations."
)
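# On Hugging Face Spaces this launch is picked up automatically; for local runs,
# demo.launch(share=True) would also expose a temporary public link.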
demo.launch()