from typing import Dict, List, Any, Optional, Tuple
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from pathlib import Path
import numpy as np, json, os, time, uuid, pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
# optional engines
try:
    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.types import StringType
    SPARK_AVAILABLE = True
except Exception:
    SPARK_AVAILABLE = False
try:
    from sentence_transformers import CrossEncoder
    RERANK_AVAILABLE = True
except Exception:
    RERANK_AVAILABLE = False
APP_VERSION = "1.0.0"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DATA_DIR = Path("./data"); DATA_DIR.mkdir(parents=True, exist_ok=True)
INDEX_FP = DATA_DIR / "index.faiss"
META_FP = DATA_DIR / "meta.jsonl"
PARQ_FP = DATA_DIR / "meta.parquet"
CFG_FP = DATA_DIR / "store.json"
# --------- Schemas ----------
class EchoRequest(BaseModel):
    message: str
class HealthResponse(BaseModel):
    status: str; version: str; index_size: int = 0; model: str = ""; spark: bool = False
    persisted: bool = False; rerank: bool = False; index_type: str = "flat"
class EmbedRequest(BaseModel):
    texts: List[str] = Field(..., min_length=1); preview_n: int = Field(default=6, ge=0, le=32); normalize: bool = True  # min_length is the pydantic v2 spelling ("pattern" below already requires v2)
class EmbedResponse(BaseModel):
    dim: int; count: int; preview: List[List[float]]
class Doc(BaseModel):
    id: Optional[str] = None; text: str; meta: Dict[str, Any] = Field(default_factory=dict)
class ChunkConfig(BaseModel):
    size: int = Field(default=800, gt=0); overlap: int = Field(default=120, ge=0)
class IngestRequest(BaseModel):
    docs: List[Doc]; chunk: ChunkConfig = Field(default_factory=ChunkConfig); normalize: bool = True; use_spark: Optional[bool] = None
class Match(BaseModel):
    id: str; score: float; text: Optional[str] = None; meta: Dict[str, Any] = Field(default_factory=dict)
class QueryRequest(BaseModel):
    q: str; k: int = Field(default=5, ge=1, le=50); return_text: bool = True
class QueryResponse(BaseModel):
    matches: List[Match]
class ExplainMatch(Match):
    start: int; end: int; token_overlap: float
class ExplainRequest(QueryRequest): pass
class ExplainResponse(BaseModel):
    matches: List[ExplainMatch]
class AnswerRequest(BaseModel):
    q: str; k: int = Field(default=5, ge=1, le=50); model: str = Field(default="mock")
    max_context_chars: int = Field(default=1600, ge=200, le=20000)
    return_contexts: bool = True; rerank: bool = False
    rerank_model: str = Field(default="cross-encoder/ms-marco-MiniLM-L-6-v2")
class AnswerResponse(BaseModel):
    answer: str; contexts: List[Match] = []
class ReindexParams(BaseModel):
    index_type: str = Field(default="flat", pattern="^(flat|ivf|hnsw)$")
    nlist: int = Field(default=64, ge=1, le=65536); M: int = Field(default=32, ge=4, le=128)
# --------- Embeddings ----------
class LazyEmbedder:
    def __init__(self, model_name: str = EMBED_MODEL_NAME):
        self.model_name = model_name; self._model: Optional[SentenceTransformer] = None; self._dim: Optional[int] = None
    def _ensure(self):
        if self._model is None:
            self._model = SentenceTransformer(self.model_name)
            self._dim = int(self._model.encode(["_probe_"], convert_to_numpy=True).shape[1])  # type: ignore
    def dim(self) -> int:
        self._ensure(); return int(self._dim)  # type: ignore
    def encode(self, texts: List[str], normalize: bool = True) -> np.ndarray:
        self._ensure()
        vecs = self._model.encode(texts, batch_size=32, show_progress_bar=False, convert_to_numpy=True)  # type: ignore
        if normalize:
            norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
            vecs = vecs / norms
        return vecs.astype("float32")
_embedder = LazyEmbedder()
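# Usage sketch (illustrative; the first call downloads the model weights):
#   vecs = _embedder.encode(["hello world", "goodbye"])  # lazy-loads the model on first use
#   vecs.shape == (2, 384) for all-MiniLM-L6-v2; rows are unit-norm when normalize=True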
# --------- Reranker ----------
class LazyReranker:
    def __init__(self): self._model=None; self._name=None
    def ensure(self, name: str):
        if not RERANK_AVAILABLE: return
        if self._model is None or self._name != name:
            self._model = CrossEncoder(name); self._name = name
    def score(self, q: str, texts: List[str]) -> List[float]:
        if not RERANK_AVAILABLE or self._model is None: return [0.0]*len(texts)
        return [float(s) for s in self._model.predict([(q,t) for t in texts])]  # type: ignore
_reranker = LazyReranker()
# --------- Chunking ----------
def chunk_text_py(text: str, size: int, overlap: int):
    t = " ".join((text or "").split()); n=len(t); out=[]; s=0
    if overlap >= size: overlap = max(size - 1, 0)
    while s<n:
        e=min(s+size,n); out.append((t[s:e],(s,e)))
        if e==n: break
        s=max(e-overlap,0)
    return out
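# Worked example: chunk_text_py("abcdefghij", size=5, overlap=2) yields
#   ("abcde", (0, 5)), ("defgh", (3, 8)), ("ghij", (6, 10))
# so each window restarts `overlap` characters before the previous end and
# neighbouring chunks share context.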
def spark_clean_and_chunk(docs: List[Doc], size: int, overlap: int):
    if not SPARK_AVAILABLE: raise RuntimeError("Spark not available")
    spark = SparkSession.builder.appName("RAG-ETL").getOrCreate()
    import json as _j
    rows=[{"id":d.id or f"doc-{i}","text":d.text,"meta_json":_j.dumps(d.meta)} for i,d in enumerate(docs)]
    df=spark.createDataFrame(rows).withColumn("text",F.regexp_replace(F.col("text"),r"\s+"," ")).withColumn("text",F.trim(F.col("text"))).filter(F.length("text")>0)
    sz,ov=int(size),int(overlap)
    if ov>=sz: ov=max(sz-1,0)
    def chunk_fn(text: str, pid: str, meta_json: str) -> str:
        t=" ".join((text or "").split()); n=len(t); s=0; base=_j.loads(meta_json) if meta_json else {}; out=[]
        while s<n:
            e=min(s+sz,n); cid=f"{pid}::offset:{s}-{e}"; m=dict(base); m.update({"parent_id":pid,"start":s,"end":e})
            # serialize meta as a JSON string so each element fits the map<string,string> schema below
            out.append({"id":cid,"text":t[s:e],"meta":_j.dumps(m)})
            if e==n: break
            s=max(e-ov,0)
        return _j.dumps(out)
    chunk_udf=F.udf(chunk_fn, StringType())  # plain functions must be wrapped as a UDF before use on Columns
    df=df.withColumn("chunks_json",chunk_udf(F.col("text"),F.col("id"),F.col("meta_json")))
    exploded=df.select(F.explode(F.from_json("chunks_json","array<map<string,string>>")).alias("c"))
    out=exploded.select(F.col("c")["id"].alias("id"),F.col("c")["text"].alias("text"),F.col("c")["meta"].alias("meta_json")).collect()
    return [{"id":r["id"],"text":r["text"],"meta":_j.loads(r["meta_json"]) if r["meta_json"] else {}} for r in out]
# --------- Vector index ----------
class VectorIndex:
    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.dim=dim; self.type=index_type; self.metric="ip"; self.nlist=nlist; self.M=M
        if index_type=="flat":
            self.index = faiss.IndexFlatIP(dim)
        elif index_type=="ivf":
            quant = faiss.IndexFlatIP(dim)
            self.index = faiss.IndexIVFFlat(quant, dim, max(1,nlist), faiss.METRIC_INNER_PRODUCT)
        elif index_type=="hnsw":
            self.index = faiss.IndexHNSWFlat(dim, max(4,M)); self.metric="l2"
        else:
            raise ValueError("bad index_type")
    def train(self, vecs: np.ndarray):
        if hasattr(self.index,"is_trained") and not self.index.is_trained:
            self.index.train(vecs)
    def add(self, vecs: np.ndarray):
        self.train(vecs); self.index.add(vecs)
    def search(self, qvec: np.ndarray, k: int):
        D,I = self.index.search(qvec,k)
        # for unit vectors, ||a-b||^2 = 2 - 2*cos(a,b), so cosine = 1 - 0.5 * squared-L2 distance
        scores = (1.0 - 0.5*D[0]).tolist() if self.metric=="l2" else D[0].tolist()
        return I[0].tolist(), scores
    def save(self, fp: Path): faiss.write_index(self.index, str(fp))
    @staticmethod
    def load(fp: Path) -> "VectorIndex":
        idx = faiss.read_index(str(fp))
        vi = VectorIndex(idx.d, "flat"); vi.index = idx
        vi.metric = "ip" if idx.metric_type == faiss.METRIC_INNER_PRODUCT else "l2"
        return vi
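# Minimal usage sketch (assumes 384-dim vectors such as those from _embedder above):
#   vi = VectorIndex(dim=384, index_type="flat")
#   vi.add(_embedder.encode(["a doc", "another doc"]))
#   ids, scores = vi.search(_embedder.encode(["query"]), k=2)
# "ivf" trains automatically on the first add(); "hnsw" is built as L2 and its
# distances are mapped back to cosine in search(), so scores stay comparable.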
# --------- Store ----------
class MemoryIndex:
    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.ids: List[str]=[]; self.texts: List[str]=[]; self.metas: List[Dict[str,Any]]=[]
        self.vindex = VectorIndex(dim, index_type=index_type, nlist=nlist, M=M)
    def add(self, vecs: np.ndarray, rows: List[Dict[str, Any]]):
        if vecs.shape[0]!=len(rows): raise ValueError("Vector count != row count")
        self.vindex.add(vecs)
        for r in rows: self.ids.append(r["id"]); self.texts.append(r["text"]); self.metas.append(r["meta"])
    def size(self)->int: return self.vindex.index.ntotal
    def search(self, qvec: np.ndarray, k: int): return self.vindex.search(qvec,k)
    def save(self):
        self.vindex.save(INDEX_FP)
        with META_FP.open("w",encoding="utf-8") as f:
            for i in range(len(self.ids)):
                f.write(json.dumps({"id":self.ids[i],"text":self.texts[i],"meta":self.metas[i]})+"\n")
        try:
            df = pd.DataFrame({"id":self.ids,"text":self.texts,"meta_json":[json.dumps(m) for m in self.metas]})
            df.to_parquet(PARQ_FP, index=False)
        except Exception:
            pass
        CFG_FP.write_text(json.dumps({"model":EMBED_MODEL_NAME,"dim":_embedder.dim(),"index_type":self.vindex.type,"nlist":self.vindex.nlist,"M":self.vindex.M}),encoding="utf-8")
    @staticmethod
    def load_if_exists() -> Optional["MemoryIndex"]:
        if not INDEX_FP.exists() or not META_FP.exists(): return None
        cfg={"index_type":"flat","nlist":64,"M":32}
        if CFG_FP.exists():
            try: cfg.update(json.loads(CFG_FP.read_text()))
            except Exception: pass
        vi = VectorIndex.load(INDEX_FP)
        store = MemoryIndex(dim=vi.dim, index_type=cfg.get("index_type","flat"), nlist=cfg.get("nlist",64), M=cfg.get("M",32))
        store.vindex = vi
        ids,texts,metas=[],[],[]
        with META_FP.open("r",encoding="utf-8") as f:
            for line in f:
                rec=json.loads(line); ids.append(rec["id"]); texts.append(rec["text"]); metas.append(rec.get("meta",{}))
        store.ids,store.texts,store.metas=ids,texts,metas
        return store
    @staticmethod
    def reset_files():
        for p in [INDEX_FP, META_FP, PARQ_FP, CFG_FP]:
            try:
                if p.exists(): p.unlink()
            except Exception:
                pass
_mem_store: Optional[MemoryIndex] = MemoryIndex.load_if_exists()
def require_store() -> MemoryIndex:
    if _mem_store is None or _mem_store.size()==0:
        raise HTTPException(status_code=400, detail="Index empty. Ingest documents first.")
    return _mem_store
# --------- Helpers ----------
def _token_overlap(q: str, txt: str) -> float:
    qt={t for t in q.lower().split() if t}; tt={t for t in (txt or "").lower().split() if t}
    if not qt: return 0.0
    return float(len(qt & tt))/float(len(qt))
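# e.g. _token_overlap("what is rag", "rag is retrieval") == 2/3 ≈ 0.6667:
# the query tokens "is" and "rag" appear in the text, "what" does not.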
def _topk(q: str, k: int) -> List[Match]:
    store=require_store(); qvec=_embedder.encode([q], normalize=True)
    idxs,scores=store.search(qvec,k); out=[]
    for i,s in zip(idxs,scores):
        if i==-1: continue
        out.append(Match(id=store.ids[i], score=float(s), text=store.texts[i], meta=store.metas[i]))
    return out
def _compose_contexts(matches: List[Match], max_chars: int) -> str:
    buf=[]; total=0
    for m in matches:
        t=m.text or ""; cut=min(len(t), max_chars-total)
        if cut<=0: break
        buf.append(t[:cut]); total+=cut
        if total>=max_chars: break
    return "\n\n".join(buf).strip()
def _answer_with_mock(q: str, contexts: str) -> str:
    if not contexts: return "No indexed context available to answer the question."
    lines=[ln.strip() for ln in contexts.split("\n") if ln.strip()]
    hits=[ln for ln in lines if any(t in ln.lower() for t in q.lower().split())]
    if not hits: hits=lines[:2]
    return "Based on retrieved context, here’s a concise answer:\n- " + "\n- ".join(hits[:4])
def _maybe_rerank(q: str, matches: List[Match], enabled: bool, model_name: str) -> List[Match]:
    if not enabled: return matches
    try:
        _reranker.ensure(model_name)
        scores=_reranker.score(q, [m.text or "" for m in matches])
        order=sorted(range(len(matches)), key=lambda i: scores[i], reverse=True)
        return [matches[i] for i in order]
    except Exception:
        return matches
def _write_parquet_if_missing():
    if not PARQ_FP.exists() and META_FP.exists():
        try:
            rows=[json.loads(line) for line in META_FP.open("r",encoding="utf-8")]
            if rows:
                pd.DataFrame({"id":[r["id"] for r in rows],
                              "text":[r["text"] for r in rows],
                              "meta_json":[json.dumps(r.get("meta",{})) for r in rows]}).to_parquet(PARQ_FP,index=False)
        except Exception:
            pass
# --------- Auth/limits/metrics ----------
API_KEY = os.getenv("API_KEY","")
_rate = {"capacity":60,"refill_per_sec":1.0}
_buckets: Dict[str, Dict[str, float]] = {}
_metrics = {"requests":0,"by_endpoint":{}, "started": time.time()}
def _allow(ip: str) -> bool:
    now=time.time(); b=_buckets.get(ip,{"tokens":_rate["capacity"],"ts":now})
    tokens=min(b["tokens"]+(now-b["ts"])*_rate["refill_per_sec"], _rate["capacity"])
    if tokens<1.0:
        _buckets[ip]={"tokens":tokens,"ts":now}; return False
    _buckets[ip]={"tokens":tokens-1.0,"ts":now}; return True
async def guard(request: Request):
    if API_KEY and request.headers.get("x-api-key","")!=API_KEY:
        raise HTTPException(status_code=401, detail="invalid api key")
    ip=request.client.host if request.client else "local"
    if not _allow(ip):
        raise HTTPException(status_code=429, detail="rate limited")
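# Example authenticated call against a guarded route (assumes API_KEY=secret in the
# environment and the default uvicorn port 8000; the key check is only enforced
# when API_KEY is non-empty):
#   curl -s -X POST localhost:8000/reset -H 'x-api-key: secret'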
app = FastAPI(title="RAG-as-a-Service", version=APP_VERSION, description="Steps 10–13")
@app.middleware("http")  # registers req_meta so it runs on every request
async def req_meta(request: Request, call_next):
    rid=str(uuid.uuid4()); _metrics["requests"]+=1
    ep=f"{request.method} {request.url.path}"; _metrics["by_endpoint"][ep]=_metrics["by_endpoint"].get(ep,0)+1
    resp=await call_next(request)
    try: resp.headers["x-request-id"]=rid
    except Exception: pass
    return resp
# --------- API ----------
# NOTE: the route decorators below are inferred from the handler names; POST /answer
# is the one path confirmed by the fetch() call in the UI, and guard is applied to
# the mutating endpoints via Depends.
@app.get("/", response_class=HTMLResponse)
def root():
    return """<!doctype html><html><head><meta charset="utf-8"><title>RAG-as-a-Service</title></head>
<body style="font-family:system-ui;margin:2rem;max-width:900px">
<h2>RAG-as-a-Service</h2>
<input id="q" style="width:70%" placeholder="Ask a question"><button onclick="ask()">Ask</button>
<pre id="out" style="background:#111;color:#eee;padding:1rem;border-radius:8px;white-space:pre-wrap"></pre>
<script>
async function ask(){
  const q=document.getElementById('q').value;
  const res=await fetch('/answer',{method:'POST',headers:{'content-type':'application/json'},body:JSON.stringify({q, k:5, return_contexts:true})});
  document.getElementById('out').textContent=JSON.stringify(await res.json(),null,2);
}
</script></body></html>"""
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    size=_mem_store.size() if _mem_store is not None else 0
    persisted=INDEX_FP.exists() and META_FP.exists()
    idx_type="flat"
    if CFG_FP.exists():
        try: idx_type=json.loads(CFG_FP.read_text()).get("index_type","flat")
        except Exception: pass
    return HealthResponse(status="ok", version=APP_VERSION, index_size=size, model=EMBED_MODEL_NAME, spark=SPARK_AVAILABLE, persisted=persisted, rerank=RERANK_AVAILABLE, index_type=idx_type)
@app.get("/metrics")
def metrics():
    up=time.time()-_metrics["started"]
    return {"requests":_metrics["requests"],"by_endpoint":_metrics["by_endpoint"],"uptime_sec":round(up,2)}
@app.post("/echo")
def echo(payload: EchoRequest) -> Dict[str, str]:
    return {"echo": payload.message, "length": str(len(payload.message))}
@app.post("/embed")
def embed(payload: EmbedRequest) -> EmbedResponse:
    vecs=_embedder.encode(payload.texts, normalize=payload.normalize)
    preview=[[float(round(v,5)) for v in row[:payload.preview_n]] for row in vecs] if payload.preview_n>0 else []
    return EmbedResponse(dim=int(vecs.shape[1]), count=int(vecs.shape[0]), preview=preview)
@app.post("/ingest", dependencies=[Depends(guard)])
def ingest(req: IngestRequest) -> Dict[str, Any]:
    global _mem_store
    if _mem_store is None:
        cfg={"index_type":"flat","nlist":64,"M":32}
        if CFG_FP.exists():
            try: cfg.update(json.loads(CFG_FP.read_text()))
            except Exception: pass
        _mem_store=MemoryIndex(dim=_embedder.dim(), index_type=cfg["index_type"], nlist=cfg["nlist"], M=cfg["M"])
    use_spark=SPARK_AVAILABLE if req.use_spark is None else bool(req.use_spark)
    rows=[]
    if use_spark:
        try: rows=spark_clean_and_chunk(req.docs, size=req.chunk.size, overlap=req.chunk.overlap)
        except Exception: rows=[]
        use_spark=bool(rows)  # report "python" if Spark failed and we fall back below
    if not rows:
        for d in req.docs:
            pid=d.id or "doc"
            for ctext,(s,e) in chunk_text_py(d.text, size=req.chunk.size, overlap=req.chunk.overlap):
                meta=dict(d.meta); meta.update({"parent_id":pid,"start":s,"end":e})
                rows.append({"id":f"{pid}::offset:{s}-{e}","text":ctext,"meta":meta})
    if not rows: raise HTTPException(status_code=400, detail="No non-empty chunks produced")
    vecs=_embedder.encode([r["text"] for r in rows], normalize=req.normalize)
    _mem_store.add(vecs, rows); _mem_store.save()
    _write_parquet_if_missing()  # rebuild parquet from the full metadata file, not just this batch
    return {"docs": len(req.docs), "chunks": len(rows), "index_size": _mem_store.size(), "engine": "spark" if use_spark else "python", "persisted": True}
@app.post("/query")
def query(req: QueryRequest) -> QueryResponse:
    matches=_topk(req.q, req.k)
    if not req.return_text: matches=[Match(id=m.id, score=m.score, text=None, meta=m.meta) for m in matches]
    return QueryResponse(matches=matches)
@app.post("/explain")
def explain(req: ExplainRequest) -> ExplainResponse:
    matches=_topk(req.q, req.k); out=[]
    for m in matches:
        meta=m.meta; start=int(meta.get("start",0)); end=int(meta.get("end",0))
        out.append(ExplainMatch(id=m.id, score=m.score, text=m.text if req.return_text else None, meta=meta, start=start, end=end, token_overlap=float(round(_token_overlap(req.q, m.text or ""),4))))
    return ExplainResponse(matches=out)
@app.post("/answer")
def answer(req: AnswerRequest) -> AnswerResponse:
    matches=_topk(req.q, req.k)
    matches=_maybe_rerank(req.q, matches, enabled=req.rerank, model_name=req.rerank_model)
    ctx=_compose_contexts(matches, req.max_context_chars)
    out=_answer_with_mock(req.q, ctx)  # only the "mock" model is implemented; req.model is accepted for forward compatibility
    return AnswerResponse(answer=out, contexts=matches if req.return_contexts else [])
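# Example question, mirroring the fetch('/answer') call issued by the root page:
#   curl -s -X POST localhost:8000/answer -H 'content-type: application/json' \
#        -d '{"q": "what is FAISS?", "k": 5, "rerank": false}'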
@app.post("/reindex", dependencies=[Depends(guard)])
def reindex(params: ReindexParams) -> Dict[str, Any]:
    global _mem_store
    if not META_FP.exists():
        raise HTTPException(status_code=400, detail="no metadata on disk")
    rows = [json.loads(line) for line in META_FP.open("r", encoding="utf-8")]
    if not rows:
        raise HTTPException(status_code=400, detail="empty metadata")
    texts = [r["text"] for r in rows]
    vecs = _embedder.encode(texts, normalize=True)
    # Cap nlist to dataset size for IVF
    idx_type = params.index_type
    eff_nlist = params.nlist
    if idx_type == "ivf":
        eff_nlist = max(1, min(eff_nlist, len(rows)))
    try:
        _mem_store = MemoryIndex(dim=_embedder.dim(), index_type=idx_type, nlist=eff_nlist, M=params.M)
        _mem_store.add(vecs, [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows])
        _mem_store.save()
        return {
            "reindexed": True,
            "index_type": idx_type,
            "index_size": _mem_store.size(),
            "nlist": eff_nlist,
            "M": params.M
        }
    except Exception as e:
        # Fall back to flat if IVF/HNSW training/add fails for any reason
        _mem_store = MemoryIndex(dim=_embedder.dim(), index_type="flat")
        _mem_store.add(vecs, [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows])
        _mem_store.save()
        return {
            "reindexed": True,
            "index_type": "flat",
            "index_size": _mem_store.size(),
            "note": f"fallback due to: {str(e)[:120]}"
        }
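# Example rebuild as IVF with 128 lists (nlist is capped at the number of stored rows):
#   curl -s -X POST localhost:8000/reindex -H 'x-api-key: secret' \
#        -H 'content-type: application/json' -d '{"index_type": "ivf", "nlist": 128}'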
@app.post("/reset", dependencies=[Depends(guard)])
def reset() -> Dict[str, Any]:
    global _mem_store; _mem_store=None; MemoryIndex.reset_files(); return {"reset": True}
@app.post("/bulk_load_hf", dependencies=[Depends(guard)])
def bulk_load_hf(repo: str, split: str = "train", text_field: str = "text", id_field: Optional[str]=None, meta_fields: Optional[List[str]]=None, chunk_size:int=800, overlap:int=120):
    try:
        from datasets import load_dataset
        ds = load_dataset(repo, split=split)
        docs=[]
        for rec in ds:
            rid = str(rec[id_field]) if id_field and id_field in rec else None
            meta = {k: rec[k] for k in (meta_fields or []) if k in rec}
            docs.append(Doc(id=rid, text=str(rec[text_field]), meta=meta))
        return ingest(IngestRequest(docs=docs, chunk=ChunkConfig(size=chunk_size, overlap=overlap), normalize=True))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"bulk_load_hf failed: {e}")