GenAIDevTOProd's picture
Update app.py
9facab9 verified
import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
# --- minimal core (in-memory only) ---
# Identifier of the sentence-embedding model used for both documents and queries.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_model = SentenceTransformer(MODEL_NAME)

# Encode a throwaway string once to read the embedding width off the result (384).
_dim = int(_model.encode(["_probe_"], convert_to_numpy=True).shape[1])

# Inner-product index; fed L2-normalized vectors so scores act as cosine similarity.
_index = faiss.IndexFlatIP(_dim)

# Parallel lists: entry i holds the id, text and metadata of the i-th vector added.
_ids = []
_texts = []
_metas = []
def _normalize(v: np.ndarray) -> np.ndarray:
    """Scale every row of *v* to (approximately) unit L2 length, as float32.

    A tiny epsilon is added to each row norm before dividing, so an all-zero
    row comes back as zeros instead of NaNs.
    """
    row_norms = np.linalg.norm(v, axis=1, keepdims=True)
    safe_norms = row_norms + 1e-12
    unit_rows = v / safe_norms
    return unit_rows.astype("float32")
def _chunk(text: str, size: int, overlap: int):
    """Split *text* into overlapping, fixed-width character windows.

    Runs of whitespace are collapsed to single spaces first, so the returned
    offsets index the normalized string, not the raw input.

    Args:
        text: Raw document text; ``None`` or empty text yields no chunks.
        size: Maximum number of characters per chunk. Coerced with ``int()``
            so float values (e.g. from a UI slider) are accepted.
        overlap: Characters shared by consecutive chunks. Coerced with
            ``int()``; negative values are treated as 0 and values >= size
            are reduced to ``size - 1``.

    Returns:
        A list of ``(chunk_text, start, end)`` tuples, ``end`` exclusive.

    Raises:
        ValueError: If *size* is smaller than 1 (the window could never
            advance, which previously caused an endless loop).
    """
    size = int(size)
    overlap = int(overlap)
    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")
    # Keep 0 <= overlap < size so every window starts strictly after the last
    # one and no text is skipped between consecutive windows.
    overlap = min(max(overlap, 0), size - 1)

    t = " ".join((text or "").split())
    n = len(t)
    step = size - overlap  # guaranteed >= 1 by the clamp above
    out = []
    for s in range(0, n, step):
        e = min(s + size, n)
        out.append((t[s:e], s, e))
        if e == n:
            # The tail has been emitted; later starts would only repeat it.
            break
    return out
def reset():
    """Replace the index and its bookkeeping lists with fresh, empty ones.

    Returns:
        Two ``gr.update`` objects: one carrying the status text
        "Index reset." and one carrying the value 0 for the index size.
    """
    global _index, _ids, _texts, _metas
    _index = faiss.IndexFlatIP(_dim)
    _ids = []
    _texts = []
    _metas = []
    status_update = gr.update(value="Index reset.")
    size_update = gr.update(value=0)
    return status_update, size_update
def load_sample():
    """Return the two built-in demo documents as newline-separated text."""
    docs = [
        ("a", "PySpark scales ETL across clusters.", {"tag": "spark"}),
        ("b", "FAISS powers fast vector similarity search used in retrieval.", {"tag": "faiss"}),
    ]
    # Only the text field is surfaced; the ids and tags are not used here.
    return "\n".join(body for _doc_id, body, _meta in docs)
def ingest(docs_text, size, overlap):
    """Chunk, embed and index the documents in *docs_text* (one per line).

    Args:
        docs_text: Newline-separated documents; blank lines are ignored and
            ``None`` is treated as empty input.
        size: Chunk width in characters, forwarded to ``_chunk`` as an int.
        overlap: Overlap between consecutive chunks, forwarded to ``_chunk``
            as an int.

    Returns:
        A ``(status_message, index_size)`` tuple, where ``index_size`` is the
        number of chunk ids stored so far.
    """
    # Guard against None so a programmatic call without text gets the same
    # friendly message as an empty textbox instead of an AttributeError.
    docs_text = docs_text or ""
    if not docs_text.strip():
        return "Provide at least one line of text.", len(_ids)

    # Slider values may arrive as floats; the chunker slices strings and needs
    # ints. This mirrors the int(k) coercion done in answer().
    size = int(size)
    overlap = int(overlap)

    # one document per line
    lines = [ln.strip() for ln in docs_text.splitlines() if ln.strip()]

    # Nothing is appended to _ids until all rows are built, so the prefix used
    # in the parent ids is constant for this call.
    base = len(_ids)
    rows = []
    for i, ln in enumerate(lines):
        pid = f"doc-{base}-{i}"
        rows.extend(
            (f"{pid}::offset:{s}-{e}", ctext, {"parent_id": pid, "start": s, "end": e})
            for ctext, s, e in _chunk(ln, size, overlap)
        )
    if not rows:
        return "No chunks produced.", len(_ids)

    vecs = _normalize(_model.encode([r[1] for r in rows], convert_to_numpy=True))
    _index.add(vecs)
    # Keep the bookkeeping lists aligned with the rows just added to the index.
    for rid, txt, meta in rows:
        _ids.append(rid)
        _texts.append(txt)
        _metas.append(meta)
    return f"Ingested docs={len(lines)} chunks={len(rows)}", len(_ids)
def answer(q, k, max_context_chars):
    """Retrieve the nearest chunks for *q* and build an answer from the best one.

    Args:
        q: Query text.
        k: Number of neighbours requested; coerced to int and clamped to the
            range ``[1, number of indexed vectors]``.
        max_context_chars: Upper bound on how many characters of the top
            match are quoted in the answer. ``None`` means no limit.

    Returns:
        A dict with ``"answer"`` (str) and ``"matches"`` (list of dicts with
        ``id``, ``score``, ``text`` and ``meta`` keys, in search-result order).
    """
    if _index.ntotal == 0:
        return {"answer": "Index is empty. Ingest first.", "matches": []}
    # An empty or missing query carries no information to retrieve with.
    if not (q or "").strip():
        return {"answer": "Query is empty. Ask something.", "matches": []}

    # Asking for more neighbours than exist only yields -1 padding, and a
    # non-positive k is not a valid search size.
    top_k = max(1, min(int(k), _index.ntotal))
    qv = _normalize(_model.encode([q], convert_to_numpy=True))
    D, I = _index.search(qv, top_k)

    matches = [
        {
            "id": _ids[i],
            "score": float(s),
            "text": _texts[i],
            "meta": _metas[i],
        }
        for i, s in zip(I[0].tolist(), D[0].tolist())
        if i >= 0  # negative ids mark empty result slots
    ]

    if not matches:
        out = "No relevant context."
    else:
        # Only the top match is used for the answer.
        top = matches[0]["text"]
        if max_context_chars is not None:
            # Honour the "max context chars" setting, which was previously
            # accepted but ignored.
            top = top[: max(int(max_context_chars), 0)]
        out = f"Based on retrieved context:\n- {top}"
    return {"answer": out, "matches": matches}
# UI definition: the first column holds the ingestion controls, the second the
# query controls. Event wiring follows inside the same Blocks context.
with gr.Blocks(title="RAG-as-a-Service") as demo:
    gr.Markdown("### RAG-as-a-Service - Gradio\nIn-memory FAISS + MiniLM\n; one-line-per-doc ingest\n; quick answers.")

    with gr.Row():
        # Ingestion column.
        with gr.Column():
            docs = gr.Textbox(
                label="Documents (one per line)",
                lines=6,
                placeholder="One document per line…",
            )
            with gr.Row():
                chunk_size = gr.Slider(minimum=64, maximum=1024, value=256, step=16, label="Chunk size")
                overlap = gr.Slider(minimum=0, maximum=256, value=32, step=8, label="Overlap")
            with gr.Row():
                ingest_btn = gr.Button("Ingest")
                sample_btn = gr.Button("Load sample")
                reset_btn = gr.Button("Reset")
            ingest_status = gr.Textbox(label="Ingest status", interactive=False)
            index_size = gr.Number(label="Index size", interactive=False, value=0)

        # Query column.
        with gr.Column():
            q = gr.Textbox(label="Query", placeholder="Ask something...")
            k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top-K")
            max_chars = gr.Slider(minimum=200, maximum=4000, value=1000, step=100, label="Max context chars")
            run = gr.Button("Answer")
            out = gr.JSON(label="Answer + matches")

    # Ingest: textbox + chunking sliders in, status + index size out.
    ingest_btn.click(
        fn=ingest,
        inputs=[docs, chunk_size, overlap],
        outputs=[ingest_status, index_size],
        api_name="ingest",  # exposes POST /api/ingest
    )

    # Fill the documents box with the built-in sample text.
    sample_btn.click(fn=load_sample, inputs=None, outputs=docs)

    # Reset: no inputs, refreshes the same two status widgets as ingest.
    reset_btn.click(
        fn=reset,
        inputs=None,
        outputs=[ingest_status, index_size],
        api_name="reset",  # exposes POST /api/reset (optional)
    )

    # Answer: query + retrieval settings in, JSON result out.
    run.click(
        fn=answer,
        inputs=[q, k, max_chars],
        outputs=out,
        api_name="answer",  # exposes POST /api/answer
    )

# Launch only when executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)