import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
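
# Assumed dependencies (package names are the usual ones, adjust to your setup):
#   pip install gradio numpy faiss-cpu sentence-transformers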

# --- minimal core (in-memory only) ---
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_model = SentenceTransformer(MODEL_NAME)
_dim = int(_model.encode(["_probe_"], convert_to_numpy=True).shape[1])  # 384 for MiniLM-L6-v2
_index = faiss.IndexFlatIP(_dim)  # cosine similarity via L2-normalized inner product
_ids, _texts, _metas = [], [], []

def _normalize(v: np.ndarray) -> np.ndarray:
    # L2-normalize row vectors so inner product equals cosine similarity
    n = np.linalg.norm(v, axis=1, keepdims=True) + 1e-12
    return (v / n).astype("float32")

def _chunk(text: str, size: int, overlap: int):
    # Character-based sliding window: collapse whitespace, then emit
    # (chunk_text, start, end) tuples with `overlap` characters of overlap.
    t = " ".join((text or "").split())
    n = len(t); s = 0; out = []
    if overlap >= size:
        overlap = max(size - 1, 0)
    while s < n:
        e = min(s + size, n)
        out.append((t[s:e], s, e))
        if e == n:
            break
        s = max(e - overlap, 0)
    return out
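
# Illustrative chunking example (character windows, not tokens):
#   _chunk("a b c d e f", size=4, overlap=2)
#   -> [("a b ", 0, 4), ("b c ", 2, 6), ("c d ", 4, 8), ("d e ", 6, 10), ("e f", 8, 11)]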

def reset():
    global _index, _ids, _texts, _metas
    _index = faiss.IndexFlatIP(_dim)
    _ids, _texts, _metas = [], [], []
    return gr.update(value="Index reset."), gr.update(value=0)

def load_sample():
    # Sample documents; only the text is placed in the textbox (one per line).
    docs = [
        ("a", "PySpark scales ETL across clusters.", {"tag": "spark"}),
        ("b", "FAISS powers fast vector similarity search used in retrieval.", {"tag": "faiss"}),
    ]
    return "\n".join(d[1] for d in docs)

def ingest(docs_text, size, overlap):
    if not docs_text.strip():
        return "Provide at least one line of text.", len(_ids)
    size, overlap = int(size), int(overlap)  # sliders may pass floats
    # one document per line
    lines = [ln.strip() for ln in docs_text.splitlines() if ln.strip()]
    rows = []
    for i, ln in enumerate(lines):
        pid = f"doc-{len(_ids)}-{i}"
        for ctext, s, e in _chunk(ln, size, overlap):
            rows.append((f"{pid}::offset:{s}-{e}", ctext, {"parent_id": pid, "start": s, "end": e}))
    if not rows:
        return "No chunks produced.", len(_ids)
    vecs = _normalize(_model.encode([r[1] for r in rows], convert_to_numpy=True))
    _index.add(vecs)
    for rid, txt, meta in rows:
        _ids.append(rid); _texts.append(txt); _metas.append(meta)
    return f"Ingested docs={len(lines)} chunks={len(rows)}", len(_ids)

def answer(q, k, max_context_chars):
    if _index.ntotal == 0:
        return {"answer": "Index is empty. Ingest first.", "matches": []}
    qv = _normalize(_model.encode([q], convert_to_numpy=True))
    D, I = _index.search(qv, int(k))
    matches = []
    for i, s in zip(I[0].tolist(), D[0].tolist()):
        if i < 0:
            continue
        matches.append({
            "id": _ids[i],
            "score": float(s),
            "text": _texts[i],
            "meta": _metas[i]
        })
    if not matches:
        out = "No relevant context."
    else:
        # use only the top match for the answer, truncated to the context budget
        top = matches[0]["text"][:int(max_context_chars)]
        out = f"Based on retrieved context:\n- {top}"
    return {"answer": out, "matches": matches}

with gr.Blocks(title="RAG-as-a-Service") as demo:
    gr.Markdown(
        "### RAG-as-a-Service - Gradio\n"
        "In-memory FAISS + MiniLM; one-line-per-doc ingest; quick answers."
    )
    with gr.Row():
        with gr.Column():
            docs = gr.Textbox(label="Documents (one per line)", lines=6, placeholder="One document per line…")
            with gr.Row():
                chunk_size = gr.Slider(64, 1024, value=256, step=16, label="Chunk size")
                overlap = gr.Slider(0, 256, value=32, step=8, label="Overlap")
            with gr.Row():
                ingest_btn = gr.Button("Ingest")
                sample_btn = gr.Button("Load sample")
                reset_btn = gr.Button("Reset")
            ingest_status = gr.Textbox(label="Ingest status", interactive=False)
            index_size = gr.Number(label="Index size", interactive=False, value=0)
        with gr.Column():
            q = gr.Textbox(label="Query", placeholder="Ask something...")
            k = gr.Slider(1, 10, value=5, step=1, label="Top-K")
            max_chars = gr.Slider(200, 4000, value=1000, step=100, label="Max context chars")
            run = gr.Button("Answer")
            out = gr.JSON(label="Answer + matches")

    ingest_btn.click(
        ingest,
        [docs, chunk_size, overlap],
        [ingest_status, index_size],
        api_name="ingest"  # exposes POST /api/ingest
    )
    sample_btn.click(load_sample, None, docs)
    reset_btn.click(
        reset,
        None,
        [ingest_status, index_size],
        api_name="reset"  # exposes POST /api/reset (optional)
    )
    run.click(
        answer,
        [q, k, max_chars],
        out,
        api_name="answer"  # exposes POST /api/answer
    )

if __name__ == "__main__":
    demo.launch(share=True)
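
# Optional client-side usage (a minimal sketch; assumes the `gradio_client`
# package is installed and the app is reachable at the URL below, adjust as needed):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   client.predict("FAISS powers fast vector similarity search.", 256, 32, api_name="/ingest")
#   print(client.predict("What does FAISS do?", 5, 1000, api_name="/answer"))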