# rag-as-a-service / main.py
# GenAIDevTOProd's picture
# Upload 2 files
# 35c5459 verified
from typing import Dict, List, Any, Optional, Tuple
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from pathlib import Path
import numpy as np, json, os, time, uuid, pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
# optional engines
try:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import StringType
SPARK_AVAILABLE = True
except Exception:
SPARK_AVAILABLE = False
try:
from sentence_transformers import CrossEncoder
RERANK_AVAILABLE = True
except Exception:
RERANK_AVAILABLE = False
# Service version reported by /health and passed to the FastAPI app.
APP_VERSION = "1.0.0"
# Sentence-transformers model used for every embedding.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# On-disk location of the persisted store; created at import time.
DATA_DIR = Path("./data")
DATA_DIR.mkdir(parents=True, exist_ok=True)

INDEX_FP = DATA_DIR / "index.faiss"   # serialized FAISS index
META_FP = DATA_DIR / "meta.jsonl"     # one JSON record (id/text/meta) per chunk
PARQ_FP = DATA_DIR / "meta.parquet"   # parquet copy of the chunk metadata
CFG_FP = DATA_DIR / "store.json"      # model / dim / index parameters
# --------- Schemas ----------
class EchoRequest(BaseModel):
    """Request body for ``POST /echo``."""

    # Text echoed back to the caller together with its length.
    message: str
class HealthResponse(BaseModel):
    """Payload returned by ``GET /health``."""

    status: str
    version: str
    # Number of vectors currently held by the in-memory index.
    index_size: int = 0
    # Name of the embedding model.
    model: str = ""
    # Whether the optional PySpark import succeeded.
    spark: bool = False
    # Whether both the index file and the metadata file exist on disk.
    persisted: bool = False
    # Whether the optional CrossEncoder import succeeded.
    rerank: bool = False
    index_type: str = "flat"
class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # At least one text is required.
    # NOTE(review): ``min_items`` is the pydantic v1 spelling (v2 calls it
    # ``min_length``) — confirm the targeted pydantic version before renaming.
    texts: List[str] = Field(..., min_items=1)
    # How many leading components of each vector are echoed back (0 = none).
    preview_n: int = Field(default=6, ge=0, le=32)
    # L2-normalize each embedding when true.
    normalize: bool = True
class EmbedResponse(BaseModel):
    """Payload returned by ``POST /embed``."""

    # Width of each embedding vector.
    dim: int
    # Number of texts that were embedded.
    count: int
    # Truncated, rounded view of each vector (empty when no preview was requested).
    preview: List[List[float]]
class Doc(BaseModel):
    """A source document submitted for ingestion."""

    # Optional caller-supplied identifier; ingestion substitutes one when absent.
    id: Optional[str] = None
    text: str
    # Free-form metadata copied onto every chunk derived from this document.
    meta: Dict[str, Any] = Field(default_factory=dict)
class ChunkConfig(BaseModel):
    """Chunking parameters, both measured in characters."""

    # Maximum chunk length; must be positive.
    size: int = Field(default=800, gt=0)
    # Characters shared between consecutive chunks; must be non-negative.
    overlap: int = Field(default=120, ge=0)
class IngestRequest(BaseModel):
    """Request body for ``POST /ingest``."""

    docs: List[Doc]
    chunk: ChunkConfig = Field(default_factory=ChunkConfig)
    # L2-normalize chunk embeddings before indexing.
    normalize: bool = True
    # None lets the server use Spark whenever it is importable.
    use_spark: Optional[bool] = None
class Match(BaseModel):
    """A single retrieved chunk."""

    id: str
    # Similarity score reported by the vector index.
    score: float
    # Chunk text; None when the caller asked for ids and scores only.
    text: Optional[str] = None
    meta: Dict[str, Any] = Field(default_factory=dict)
class QueryRequest(BaseModel):
    """Request body for ``POST /query``."""

    # Query text to embed and search for.
    q: str
    # Number of neighbours to retrieve (1..50).
    k: int = Field(default=5, ge=1, le=50)
    # Include the chunk text in each match.
    return_text: bool = True
class QueryResponse(BaseModel):
    """Payload returned by ``POST /query``."""

    matches: List[Match]
class ExplainMatch(Match):
    """A match enriched with chunk offsets and a lexical-overlap score."""

    # Chunk offsets taken from the chunk's ``start`` / ``end`` metadata.
    start: int
    end: int
    # Fraction of distinct query tokens that also occur in the chunk text.
    token_overlap: float
class ExplainRequest(QueryRequest):
    """Request body for ``POST /explain``; same fields as :class:`QueryRequest`."""
    pass
class ExplainResponse(BaseModel):
    """Payload returned by ``POST /explain``."""

    matches: List[ExplainMatch]
class AnswerRequest(BaseModel):
    """Request body for ``POST /answer``."""

    # Question to answer.
    q: str
    # Number of chunks to retrieve (1..50).
    k: int = Field(default=5, ge=1, le=50)
    # Answering backend name.
    model: str = Field(default="mock")
    # Character budget for the retrieved text handed to the answerer.
    max_context_chars: int = Field(default=1600, ge=200, le=20000)
    # Include the retrieved matches in the response.
    return_contexts: bool = True
    # Re-order the retrieved matches with a cross-encoder before answering.
    rerank: bool = False
    rerank_model: str = Field(default="cross-encoder/ms-marco-MiniLM-L-6-v2")
class AnswerResponse(BaseModel):
    """Payload returned by ``POST /answer``."""

    answer: str
    # Matches the answer was built from; empty when the caller opted out.
    contexts: List[Match] = []
class ReindexParams(BaseModel):
    """Request body for ``POST /reindex``."""

    # One of "flat", "ivf" or "hnsw".
    # NOTE(review): ``pattern`` is the pydantic v2 keyword (v1 used ``regex``) —
    # confirm the constraint is actually enforced on the deployed version.
    index_type: str = Field(default="flat", pattern="^(flat|ivf|hnsw)$")
    # Number of IVF lists; only used when index_type is "ivf".
    nlist: int = Field(default=64, ge=1, le=65536)
    # HNSW connectivity parameter; only used when index_type is "hnsw".
    M: int = Field(default=32, ge=4, le=128)
# --------- Embeddings ----------
class LazyEmbedder:
    """Sentence-embedding wrapper that defers model loading until first use.

    The SentenceTransformer named by ``model_name`` is only constructed on the
    first call to :meth:`encode` or the first read of :attr:`dim`.
    """

    def __init__(self, model_name: str = EMBED_MODEL_NAME):
        self.model_name = model_name
        self._model: Optional[SentenceTransformer] = None
        self._dim: Optional[int] = None

    def _ensure(self):
        """Load the model once and record the width of its embeddings."""
        if self._model is not None:
            return
        model = SentenceTransformer(self.model_name)
        # Encode a throwaway string to discover the embedding width.
        probe = model.encode(["_probe_"], convert_to_numpy=True)
        dim = int(probe.shape[1])
        # Publish both attributes together so a failed probe never leaves a
        # loaded model paired with an unknown dimension.
        self._model = model
        self._dim = dim

    @property
    def dim(self) -> int:
        """Embedding width; reading it loads the model if necessary."""
        self._ensure()
        return int(self._dim)  # type: ignore

    def encode(self, texts: List[str], normalize: bool = True) -> np.ndarray:
        """Embed ``texts`` and return a float32 array of shape (len(texts), dim).

        Args:
            texts: Strings to embed; an empty list yields a (0, dim) array.
            normalize: Scale every row to unit L2 norm when true.
        """
        self._ensure()
        if not texts:
            # The per-row norm below needs a 2-D array; give an empty batch an
            # explicit (0, dim) shape instead of relying on the model's output.
            return np.zeros((0, self.dim), dtype="float32")
        vecs = self._model.encode(  # type: ignore
            texts, batch_size=32, show_progress_bar=False, convert_to_numpy=True
        )
        if normalize:
            # The epsilon keeps an all-zero embedding from dividing by zero.
            norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
            vecs = vecs / norms
        return vecs.astype("float32")
# Process-wide embedder shared by the endpoints; the model itself is loaded on first use.
_embedder = LazyEmbedder()
# --------- Reranker ----------
class LazyReranker:
    """Cross-encoder wrapper that loads (or swaps) its model on demand."""

    def __init__(self):
        self._model = None
        self._name = None

    def ensure(self, name: str):
        """Make sure the cross-encoder called ``name`` is the one loaded.

        Does nothing when the CrossEncoder import was unavailable.
        """
        if not RERANK_AVAILABLE:
            return
        already_loaded = self._model is not None and self._name == name
        if already_loaded:
            return
        self._model = CrossEncoder(name)
        self._name = name

    def score(self, q: str, texts: List[str]) -> List[float]:
        """Return one relevance score per text for query ``q``.

        Every score is 0.0 when reranking is unavailable or no model is loaded.
        """
        if not RERANK_AVAILABLE or self._model is None:
            return [0.0] * len(texts)
        pairs = [(q, t) for t in texts]
        return [float(s) for s in self._model.predict(pairs)]  # type: ignore
# Process-wide reranker; holds no model until ensure() is called.
_reranker = LazyReranker()
# --------- Chunking ----------
def chunk_text_py(text: str, size: int, overlap: int):
    """Split whitespace-normalized ``text`` into overlapping character windows.

    Runs of whitespace are collapsed to single spaces first, so the returned
    offsets refer to the normalized text rather than to the raw input.

    Args:
        text: Source text; ``None`` or empty yields no chunks.
        size: Maximum chunk length in characters; must be positive.
        overlap: Characters shared between consecutive chunks. Values outside
            ``[0, size - 1]`` are clamped into that range.

    Returns:
        A list of ``(chunk_text, (start, end))`` tuples with ``end`` exclusive.

    Raises:
        ValueError: If ``size`` is not positive (the window could never advance).
    """
    if size <= 0:
        raise ValueError("size must be a positive integer")
    # Clamping guarantees each window starts at least one character after the
    # previous one, and that a negative overlap cannot skip over text.
    overlap = min(max(overlap, 0), size - 1)

    normalized = " ".join((text or "").split())
    total = len(normalized)
    chunks = []
    start = 0
    while start < total:
        end = min(start + size, total)
        chunks.append((normalized[start:end], (start, end)))
        if end == total:
            break
        start = end - overlap
    return chunks
def spark_clean_and_chunk(docs: List[Doc], size: int, overlap: int):
    """Clean and chunk ``docs`` with PySpark.

    Whitespace is collapsed and empty documents are dropped inside Spark; a UDF
    then cuts each document into overlapping character windows and returns them
    as a JSON string that is exploded back into one row per chunk.

    Returns:
        A list of ``{"id", "text", "meta"}`` dicts, one per chunk.

    Raises:
        RuntimeError: If PySpark could not be imported.
    """
    if not SPARK_AVAILABLE:
        raise RuntimeError("Spark not available")
    import json as _j

    spark = SparkSession.builder.appName("RAG-ETL").getOrCreate()

    # Documents without an id get a positional one; metadata travels as JSON text.
    records = [
        {"id": d.id or f"doc-{i}", "text": d.text, "meta_json": _j.dumps(d.meta)}
        for i, d in enumerate(docs)
    ]
    df = spark.createDataFrame(records)
    df = df.withColumn("text", F.regexp_replace(F.col("text"), r"\s+", " "))
    df = df.withColumn("text", F.trim(F.col("text")))
    df = df.filter(F.length("text") > 0)

    sz, ov = int(size), int(overlap)
    if ov >= sz:
        ov = max(sz - 1, 0)

    @F.udf(returnType=StringType())
    def chunk_udf(text: str, pid: str, meta_json: str) -> str:
        normalized = " ".join((text or "").split())
        total = len(normalized)
        base = _j.loads(meta_json) if meta_json else {}
        pieces = []
        start = 0
        while start < total:
            end = min(start + sz, total)
            meta = dict(base)
            meta.update({"parent_id": pid, "start": start, "end": end})
            pieces.append({
                "id": f"{pid}::offset:{start}-{end}",
                "text": normalized[start:end],
                "meta": meta,
            })
            if end == total:
                break
            start = max(end - ov, 0)
        return _j.dumps(pieces)

    df = df.withColumn(
        "chunks_json", chunk_udf(F.col("text"), F.col("id"), F.col("meta_json"))
    )
    # NOTE(review): the schema declares every value as a string, so the nested
    # "meta" object is presumably handed back as JSON text and re-parsed below —
    # confirm against the Spark version in use.
    exploded = df.select(
        F.explode(F.from_json("chunks_json", "array<map<string,string>>")).alias("c")
    )
    collected = exploded.select(
        F.col("c")["id"].alias("id"),
        F.col("c")["text"].alias("text"),
        F.col("c")["meta"].alias("meta_json"),
    ).collect()

    result = []
    for row in collected:
        meta = _j.loads(row["meta_json"]) if row["meta_json"] else {}
        result.append({"id": row["id"], "text": row["text"], "meta": meta})
    return result
# --------- Vector index ----------
class VectorIndex:
    """Thin wrapper around a FAISS index of type "flat", "ivf" or "hnsw".

    Attributes:
        dim: Vector width.
        type: Index type label ("flat", "ivf" or "hnsw").
        metric: "ip" for inner-product indexes, "l2" for L2 indexes.
        nlist: Number of IVF lists (meaningful for "ivf" only).
        M: HNSW connectivity parameter (meaningful for "hnsw" only).
    """

    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.dim = dim
        self.type = index_type
        self.metric = "ip"
        self.nlist = nlist
        self.M = M
        if index_type == "flat":
            self.index = faiss.IndexFlatIP(dim)
        elif index_type == "ivf":
            quantizer = faiss.IndexFlatIP(dim)
            self.index = faiss.IndexIVFFlat(
                quantizer, dim, max(1, nlist), faiss.METRIC_INNER_PRODUCT
            )
        elif index_type == "hnsw":
            # Built without an explicit metric, so it is treated as an L2 index.
            self.index = faiss.IndexHNSWFlat(dim, max(4, M))
            self.metric = "l2"
        else:
            raise ValueError("bad index_type")

    def train(self, vecs: np.ndarray):
        """Train the underlying index on ``vecs`` if it still needs training."""
        if hasattr(self.index, "is_trained") and not self.index.is_trained:
            self.index.train(vecs)

    def add(self, vecs: np.ndarray):
        """Train if necessary, then append ``vecs`` to the index."""
        self.train(vecs)
        self.index.add(vecs)

    def search(self, qvec: np.ndarray, k: int):
        """Return ``(row_numbers, scores)`` for the first query row of ``qvec``.

        For L2 indexes the distance is mapped to ``1 - d / 2``.
        NOTE(review): that equals cosine similarity only if FAISS reports the
        squared distance and all vectors are unit-length — confirm for data
        ingested with ``normalize=False``.
        """
        D, I = self.index.search(qvec, k)
        if self.metric == "l2":
            scores = (1.0 - 0.5 * D[0]).tolist()
        else:
            scores = D[0].tolist()
        return I[0].tolist(), scores

    def save(self, fp: Path):
        """Serialize the FAISS index to ``fp``."""
        faiss.write_index(self.index, str(fp))

    @staticmethod
    def load(fp: Path) -> "VectorIndex":
        """Read a FAISS index from ``fp`` and wrap it.

        The type label, IVF list count and metric are recovered from the loaded
        index itself, so a wrapper around a persisted IVF/HNSW index is no
        longer mislabelled as "flat". ``M`` is not recovered here and keeps the
        constructor default.
        """
        idx = faiss.read_index(str(fp))
        # Start from a cheap flat shell and swap in the index read from disk.
        vi = VectorIndex(idx.d, "flat")
        vi.index = idx
        if isinstance(idx, faiss.IndexIVF):
            vi.type = "ivf"
            vi.nlist = int(idx.nlist)
        elif isinstance(idx, faiss.IndexHNSW):
            vi.type = "hnsw"
        # Ask the index for its metric rather than guessing from the class name.
        vi.metric = "ip" if idx.metric_type == faiss.METRIC_INNER_PRODUCT else "l2"
        return vi
# --------- Store ----------
class MemoryIndex:
    """In-memory chunk store: a vector index plus parallel id/text/meta lists.

    Row ``i`` of the vector index corresponds to ``ids[i]``, ``texts[i]`` and
    ``metas[i]``.
    """

    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.ids: List[str] = []
        self.texts: List[str] = []
        self.metas: List[Dict[str, Any]] = []
        self.vindex = VectorIndex(dim, index_type=index_type, nlist=nlist, M=M)

    def add(self, vecs: np.ndarray, rows: List[Dict[str, Any]]):
        """Append ``vecs`` and their ``rows`` (dicts with id/text/meta keys).

        Raises:
            ValueError: If the number of vectors and rows differ.
            KeyError: If a row lacks one of the required keys.
        """
        if vecs.shape[0] != len(rows):
            raise ValueError("Vector count != row count")
        # Pull the fields out first so a malformed row fails before the vector
        # index is modified and the parallel lists fall out of step with it.
        new_ids = [r["id"] for r in rows]
        new_texts = [r["text"] for r in rows]
        new_metas = [r["meta"] for r in rows]
        self.vindex.add(vecs)
        self.ids.extend(new_ids)
        self.texts.extend(new_texts)
        self.metas.extend(new_metas)

    def size(self) -> int:
        """Number of vectors held by the index."""
        return self.vindex.index.ntotal

    def search(self, qvec: np.ndarray, k: int):
        """Delegate to :meth:`VectorIndex.search`."""
        return self.vindex.search(qvec, k)

    def save(self):
        """Persist the index, the chunk metadata and the store configuration."""
        self.vindex.save(INDEX_FP)
        with META_FP.open("w", encoding="utf-8") as f:
            for cid, text, meta in zip(self.ids, self.texts, self.metas):
                f.write(json.dumps({"id": cid, "text": text, "meta": meta}) + "\n")
        try:
            df = pd.DataFrame({
                "id": self.ids,
                "text": self.texts,
                "meta_json": [json.dumps(m) for m in self.metas],
            })
            df.to_parquet(PARQ_FP, index=False)
        except Exception:
            # Best-effort: the parquet copy is optional and writing it can fail
            # (for instance when no parquet engine is installed).
            pass
        # Record the index's own dimension rather than asking the embedder,
        # which would load the embedding model just to write a number.
        CFG_FP.write_text(
            json.dumps({
                "model": EMBED_MODEL_NAME,
                "dim": self.vindex.dim,
                "index_type": self.vindex.type,
                "nlist": self.vindex.nlist,
                "M": self.vindex.M,
            }),
            encoding="utf-8",
        )

    @staticmethod
    def load_if_exists() -> Optional["MemoryIndex"]:
        """Rebuild a store from disk, or return None when nothing is persisted.

        NOTE(review): the metadata row count is not checked against the number
        of vectors in the index; a truncated metadata file goes unnoticed here.
        """
        if not INDEX_FP.exists() or not META_FP.exists():
            return None
        cfg = {"index_type": "flat", "nlist": 64, "M": 32}
        if CFG_FP.exists():
            try:
                cfg.update(json.loads(CFG_FP.read_text()))
            except Exception:
                # An unreadable config falls back to the defaults above.
                pass

        vi = VectorIndex.load(INDEX_FP)
        # A flat shell is enough: its index is replaced by the loaded one, so
        # there is no point building (or failing to build) the configured type.
        store = MemoryIndex(dim=vi.dim)
        store.vindex = vi

        # Carry the persisted parameters over to the loaded wrapper so the next
        # save() does not overwrite them with defaults. The type label is only
        # adopted when the wrapper still says "flat" although the loaded FAISS
        # object is not a flat index.
        cfg_type = cfg.get("index_type", "flat")
        if (
            vi.type == "flat"
            and cfg_type in ("ivf", "hnsw")
            and not isinstance(vi.index, faiss.IndexFlat)
        ):
            vi.type = cfg_type
        for key in ("nlist", "M"):
            try:
                setattr(vi, key, int(cfg[key]))
            except (KeyError, TypeError, ValueError):
                pass

        ids, texts, metas = [], [], []
        with META_FP.open("r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines in the metadata file
                rec = json.loads(line)
                ids.append(rec["id"])
                texts.append(rec["text"])
                metas.append(rec.get("meta", {}))
        store.ids, store.texts, store.metas = ids, texts, metas
        return store

    @staticmethod
    def reset_files():
        """Delete every persisted store file, ignoring files that cannot be removed."""
        for p in [INDEX_FP, META_FP, PARQ_FP, CFG_FP]:
            try:
                if p.exists():
                    p.unlink()
            except OSError:
                pass
# Restore a previously persisted store at import time; None when no index/metadata files exist.
_mem_store: Optional[MemoryIndex] = MemoryIndex.load_if_exists()
def require_store() -> MemoryIndex:
    """Return the global store, or answer HTTP 400 when nothing is indexed.

    Raises:
        HTTPException: 400 if the store is missing or holds no vectors.
    """
    store = _mem_store
    if store is not None and store.size() != 0:
        return store
    raise HTTPException(status_code=400, detail="Index empty. Ingest documents first.")
# --------- Helpers ----------
def _token_overlap(q: str, txt: str) -> float:
    """Return the fraction of distinct query tokens that also appear in ``txt``.

    Tokens are lowercase, whitespace-separated words. An empty query scores 0.0
    and a falsy ``txt`` is treated as empty text.
    """
    query_tokens = set(q.lower().split())
    text_tokens = set((txt or "").lower().split())
    if not query_tokens:
        return 0.0
    shared = query_tokens & text_tokens
    return float(len(shared)) / float(len(query_tokens))
def _topk(q: str, k: int) -> List[Match]:
    """Embed ``q`` and return up to ``k`` nearest chunks as Match objects.

    Raises:
        HTTPException: 400 (via require_store) when nothing is indexed.
    """
    store = require_store()
    qvec = _embedder.encode([q], normalize=True)
    idxs, scores = store.search(qvec, k)
    known_rows = len(store.ids)
    out = []
    for i, s in zip(idxs, scores):
        # Skip the -1 placeholders returned when fewer than k neighbours exist,
        # and any row number without loaded metadata (index and metadata out of
        # step), rather than failing the whole request with an IndexError.
        if not 0 <= i < known_rows:
            continue
        out.append(Match(id=store.ids[i], score=float(s), text=store.texts[i], meta=store.metas[i]))
    return out
def _compose_contexts(matches: List[Match], max_chars: int) -> str:
    """Concatenate match texts, in order, within a budget of ``max_chars``.

    The last text that fits only partially is truncated. Texts are joined with
    a blank line; those separator characters are not counted against the budget.

    Args:
        matches: Retrieved matches; a missing or empty ``text`` is skipped.
        max_chars: Total number of text characters to keep.

    Returns:
        The joined, stripped context string ("" when nothing fits).
    """
    parts = []
    used = 0
    for m in matches:
        remaining = max_chars - used
        if remaining <= 0:
            break
        text = m.text or ""
        if not text:
            # An empty match must not end the loop: later matches may still
            # have text to contribute.
            continue
        piece = text[:remaining]
        parts.append(piece)
        used += len(piece)
    return "\n\n".join(parts).strip()
def _answer_with_mock(q: str, contexts: str) -> str:
    """Build a bullet-list answer from context lines that mention a query term.

    A line qualifies when any lowercase query word occurs in it as a substring.
    If no line qualifies, the first two lines are used. At most four bullets are
    emitted.
    """
    if not contexts:
        return "No indexed context available to answer the question."

    terms = q.lower().split()
    lines = []
    for raw in contexts.split("\n"):
        stripped = raw.strip()
        if stripped:
            lines.append(stripped)

    hits = []
    for ln in lines:
        lowered = ln.lower()
        if any(term in lowered for term in terms):
            hits.append(ln)
    if not hits:
        hits = lines[:2]

    bullets = "\n- ".join(hits[:4])
    return "Based on retrieved context, here’s a concise answer:\n- " + bullets
def _maybe_rerank(q: str, matches: List[Match], enabled: bool, model_name: str) -> List[Match]:
    """Re-order ``matches`` by cross-encoder score when ``enabled``.

    The matches keep their original scores; only their order changes. Any
    failure while loading or running the reranker returns the input unchanged.
    """
    if not enabled:
        return matches
    try:
        _reranker.ensure(model_name)
        texts = [m.text or "" for m in matches]
        scores = _reranker.score(q, texts)
        # Sorting (position, match) pairs by score keeps ties in input order.
        ranked = sorted(enumerate(matches), key=lambda pair: scores[pair[0]], reverse=True)
        return [match for _, match in ranked]
    except Exception:
        return matches
def _write_parquet_if_missing():
    """Create the parquet metadata copy from the JSONL file when it is absent.

    Does nothing if the parquet file already exists or the JSONL file is
    missing. Failures are ignored: the parquet copy is optional.
    """
    if PARQ_FP.exists() or not META_FP.exists():
        return
    try:
        # The context manager closes the handle the comprehension reads from.
        with META_FP.open("r", encoding="utf-8") as f:
            rows = [json.loads(line) for line in f if line.strip()]
        if not rows:
            return
        frame = pd.DataFrame({
            "id": [r["id"] for r in rows],
            "text": [r["text"] for r in rows],
            "meta_json": [json.dumps(r.get("meta", {})) for r in rows],
        })
        frame.to_parquet(PARQ_FP, index=False)
    except Exception:
        # Best-effort by design (for instance when no parquet engine is installed).
        pass
# --------- Auth/limits/metrics ----------
# Shared secret expected in the x-api-key header; an empty value disables the check.
API_KEY = os.getenv("API_KEY","")
# Token-bucket settings: bucket size and tokens regained per second.
_rate = {"capacity":60,"refill_per_sec":1.0}
# Per-client bucket state, keyed by client host: {"tokens": ..., "ts": last update time}.
_buckets: Dict[str, Dict[str, float]] = {}
# In-process request counters and the start time used for the uptime figure.
_metrics = {"requests":0,"by_endpoint":{}, "started": time.time()}
def _allow(ip: str) -> bool:
    """Token-bucket check for ``ip``; returns True when the request may proceed.

    The bucket is refilled for the time elapsed since its last update, capped at
    the configured capacity, and one token is spent on every allowed request.
    """
    now = time.time()
    capacity = _rate["capacity"]
    bucket = _buckets.get(ip)
    if bucket is None:
        # First request from this client: start with a full bucket.
        bucket = {"tokens": capacity, "ts": now}

    elapsed = now - bucket["ts"]
    tokens = min(bucket["tokens"] + elapsed * _rate["refill_per_sec"], capacity)

    allowed = not tokens < 1.0
    remaining = tokens - 1.0 if allowed else tokens
    _buckets[ip] = {"tokens": remaining, "ts": now}
    return allowed
async def guard(request: Request):
    """FastAPI dependency enforcing the optional API key and the rate limit.

    Raises:
        HTTPException: 401 when an API key is configured and the ``x-api-key``
            header does not match it; 429 when the client's bucket is empty.
    """
    import hmac

    if API_KEY:
        supplied = request.headers.get("x-api-key", "")
        # Constant-time comparison so response timing does not reveal how much
        # of a guessed key is correct. Bytes are compared because
        # compare_digest rejects non-ASCII str arguments.
        if not hmac.compare_digest(supplied.encode("utf-8"), API_KEY.encode("utf-8")):
            raise HTTPException(status_code=401, detail="invalid api key")
    ip = request.client.host if request.client else "local"
    if not _allow(ip):
        raise HTTPException(status_code=429, detail="rate limited")
# The ASGI application; the middleware and routes below register themselves on it.
app = FastAPI(title="RAG-as-a-Service", version=APP_VERSION, description="Steps 10–13")
@app.middleware("http")
async def req_meta(request: Request, call_next):
    """Count the request per endpoint and tag the response with a request id."""
    request_id = str(uuid.uuid4())
    _metrics["requests"] += 1

    endpoint = f"{request.method} {request.url.path}"
    per_endpoint = _metrics["by_endpoint"]
    per_endpoint[endpoint] = per_endpoint.get(endpoint, 0) + 1

    response = await call_next(request)
    try:
        response.headers["x-request-id"] = request_id
    except Exception:
        # Best-effort: a response that refuses the header is returned as-is.
        pass
    return response
# --------- API ----------
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve a minimal HTML page that posts the typed question to /answer."""
    return """<!doctype html><html><head><meta charset="utf-8"><title>RAG-as-a-Service</title></head>
<body style="font-family:system-ui;margin:2rem;max-width:900px">
<h2>RAG-as-a-Service</h2>
<input id="q" style="width:70%" placeholder="Ask a question"><button onclick="ask()">Ask</button>
<pre id="out" style="background:#111;color:#eee;padding:1rem;border-radius:8px;white-space:pre-wrap"></pre>
<script>
async function ask(){
const q=document.getElementById('q').value;
const res=await fetch('/answer',{method:'POST',headers:{'content-type':'application/json'},body:JSON.stringify({q, k:5, return_contexts:true})});
document.getElementById('out').textContent=JSON.stringify(await res.json(),null,2);
}
</script></body></html>"""
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Report service status, index size and which optional features are available."""
    store = _mem_store
    index_size = 0 if store is None else store.size()
    on_disk = INDEX_FP.exists() and META_FP.exists()

    # The index type is read from the persisted config, defaulting to "flat".
    index_type = "flat"
    if CFG_FP.exists():
        try:
            index_type = json.loads(CFG_FP.read_text()).get("index_type", "flat")
        except Exception:
            pass

    return HealthResponse(
        status="ok",
        version=APP_VERSION,
        index_size=index_size,
        model=EMBED_MODEL_NAME,
        spark=SPARK_AVAILABLE,
        persisted=on_disk,
        rerank=RERANK_AVAILABLE,
        index_type=index_type,
    )
@app.get("/metrics")
def metrics():
    """Return the request counters and the uptime in seconds."""
    uptime = time.time() - _metrics["started"]
    return {
        "requests": _metrics["requests"],
        "by_endpoint": _metrics["by_endpoint"],
        "uptime_sec": round(uptime, 2),
    }
@app.post("/echo", dependencies=[Depends(guard)])
def echo(payload: EchoRequest) -> Dict[str, str]:
    """Return the message unchanged together with its length as a string."""
    message = payload.message
    return {"echo": message, "length": str(len(message))}
@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(guard)])
def embed(payload: EmbedRequest) -> EmbedResponse:
    """Embed the given texts and return their shape plus a rounded preview."""
    vecs = _embedder.encode(payload.texts, normalize=payload.normalize)

    preview = []
    if payload.preview_n > 0:
        # Only the first preview_n components of each vector, rounded to 5 places.
        preview = [
            [float(round(v, 5)) for v in row[:payload.preview_n]]
            for row in vecs
        ]

    return EmbedResponse(dim=int(vecs.shape[1]), count=int(vecs.shape[0]), preview=preview)
@app.post("/ingest", dependencies=[Depends(guard)])
def ingest(req: IngestRequest) -> Dict[str, Any]:
    """Chunk, embed and index the submitted documents, then persist the store.

    Chunking runs on Spark when requested (or, by default, whenever Spark is
    importable) and falls back to the pure-Python chunker if Spark fails or
    yields nothing. Documents without an id get the positional id ``doc-<i>``
    on both paths, so chunks of different id-less documents in one request do
    not share ids.

    Returns:
        Counts of documents and chunks, the resulting index size, the engine
        that actually produced the chunks, and a persisted flag.

    Raises:
        HTTPException: 400 when no non-empty chunk could be produced.
    """
    global _mem_store
    if _mem_store is None:
        cfg = {"index_type": "flat", "nlist": 64, "M": 32}
        if CFG_FP.exists():
            try:
                cfg.update(json.loads(CFG_FP.read_text()))
            except Exception:
                # An unreadable config falls back to the defaults above.
                pass
        _mem_store = MemoryIndex(
            dim=_embedder.dim, index_type=cfg["index_type"], nlist=cfg["nlist"], M=cfg["M"]
        )

    use_spark = SPARK_AVAILABLE if req.use_spark is None else bool(req.use_spark)
    rows = []
    engine = "python"
    if use_spark:
        try:
            rows = spark_clean_and_chunk(req.docs, size=req.chunk.size, overlap=req.chunk.overlap)
        except Exception:
            # Spark is optional; any failure falls through to the Python chunker.
            rows = []
        if rows:
            engine = "spark"

    if not rows:
        for i, d in enumerate(req.docs):
            pid = d.id or f"doc-{i}"
            for ctext, (s, e) in chunk_text_py(d.text, size=req.chunk.size, overlap=req.chunk.overlap):
                meta = dict(d.meta)
                meta.update({"parent_id": pid, "start": s, "end": e})
                rows.append({"id": f"{pid}::offset:{s}-{e}", "text": ctext, "meta": meta})
    if not rows:
        raise HTTPException(status_code=400, detail="No non-empty chunks produced")

    vecs = _embedder.encode([r["text"] for r in rows], normalize=req.normalize)
    _mem_store.add(vecs, rows)
    _mem_store.save()
    # If save() could not write the parquet copy, retry from the full metadata
    # file so the copy covers the whole store rather than only this batch.
    _write_parquet_if_missing()

    return {
        "docs": len(req.docs),
        "chunks": len(rows),
        "index_size": _mem_store.size(),
        "engine": engine,
        "persisted": True,
    }
@app.post("/query", response_model=QueryResponse, dependencies=[Depends(guard)])
def query(req: QueryRequest) -> QueryResponse:
    """Return the top-k matches for the query, optionally without their text."""
    found = _topk(req.q, req.k)
    if req.return_text:
        return QueryResponse(matches=found)
    # Rebuild each match with its text blanked out.
    without_text = [
        Match(id=m.id, score=m.score, text=None, meta=m.meta) for m in found
    ]
    return QueryResponse(matches=without_text)
@app.post("/explain", response_model=ExplainResponse, dependencies=[Depends(guard)])
def explain(req: ExplainRequest) -> ExplainResponse:
    """Return top-k matches annotated with chunk offsets and token overlap."""
    explained = []
    for m in _topk(req.q, req.k):
        meta = m.meta
        # The overlap is always computed on the chunk text, even when the text
        # itself is withheld from the response.
        explained.append(
            ExplainMatch(
                id=m.id,
                score=m.score,
                text=m.text if req.return_text else None,
                meta=meta,
                start=int(meta.get("start", 0)),
                end=int(meta.get("end", 0)),
                token_overlap=float(round(_token_overlap(req.q, m.text or ""), 4)),
            )
        )
    return ExplainResponse(matches=explained)
@app.post("/answer", response_model=AnswerResponse, dependencies=[Depends(guard)])
def answer(req: AnswerRequest) -> AnswerResponse:
    """Retrieve, optionally rerank, and answer from the retrieved context."""
    retrieved = _topk(req.q, req.k)
    ordered = _maybe_rerank(req.q, retrieved, enabled=req.rerank, model_name=req.rerank_model)
    context_text = _compose_contexts(ordered, req.max_context_chars)
    # Every value of req.model is currently served by the mock answerer.
    reply = _answer_with_mock(req.q, context_text)
    contexts = ordered if req.return_contexts else []
    return AnswerResponse(answer=reply, contexts=contexts)
@app.post("/reindex", dependencies=[Depends(guard)])
def reindex(params: ReindexParams) -> Dict[str, Any]:
    """Rebuild the vector index from the persisted metadata with a new index type.

    All chunk texts are re-embedded. If building the requested index fails for
    any reason, a flat index is built instead and the response carries a note.
    The global store is only replaced once a rebuild has succeeded.

    Raises:
        HTTPException: 400 when the metadata file is missing or empty.
    """
    global _mem_store
    if not META_FP.exists():
        raise HTTPException(status_code=400, detail="no metadata on disk")
    # The context manager closes the handle; blank lines are skipped.
    with META_FP.open("r", encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    if not rows:
        raise HTTPException(status_code=400, detail="empty metadata")

    records = [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in rows]
    vecs = _embedder.encode([r["text"] for r in records], normalize=True)

    # Cap nlist to dataset size for IVF
    idx_type = params.index_type
    eff_nlist = params.nlist
    if idx_type == "ivf":
        eff_nlist = max(1, min(eff_nlist, len(records)))

    try:
        rebuilt = MemoryIndex(dim=_embedder.dim, index_type=idx_type, nlist=eff_nlist, M=params.M)
        rebuilt.add(vecs, records)
        rebuilt.save()
    except Exception as e:
        # Fallback to flat if IVF/HNSW training/add fails for any reason
        fallback = MemoryIndex(dim=_embedder.dim, index_type="flat")
        fallback.add(vecs, records)
        fallback.save()
        _mem_store = fallback
        return {
            "reindexed": True,
            "index_type": "flat",
            "index_size": fallback.size(),
            "note": f"fallback due to: {str(e)[:120]}",
        }

    _mem_store = rebuilt
    return {
        "reindexed": True,
        "index_type": idx_type,
        "index_size": rebuilt.size(),
        "nlist": eff_nlist,
        "M": params.M,
    }
@app.post("/reset", dependencies=[Depends(guard)])
def reset() -> Dict[str, Any]:
    """Drop the in-memory store and delete every persisted store file."""
    global _mem_store
    _mem_store = None
    MemoryIndex.reset_files()
    return {"reset": True}
@app.post("/bulk_load_hf", dependencies=[Depends(guard)])
def bulk_load_hf(repo: str, split: str = "train", text_field: str = "text", id_field: Optional[str]=None, meta_fields: Optional[List[str]]=None, chunk_size:int=800, overlap:int=120):
    """Load a Hugging Face dataset split and ingest its records as documents.

    Args:
        repo: Dataset repository name passed to ``datasets.load_dataset``.
        split: Dataset split to load.
        text_field: Record field holding the document text.
        id_field: Optional record field used as the document id.
        meta_fields: Record fields copied into each document's metadata.
        chunk_size: Chunk length in characters.
        overlap: Overlap between consecutive chunks in characters.

    Returns:
        The summary dict produced by :func:`ingest`.

    Raises:
        HTTPException: Whatever ``ingest`` raises, unchanged; 400 with a
            ``bulk_load_hf failed`` detail for any other failure.
    """
    try:
        from datasets import load_dataset

        # NOTE(review): ``repo`` is caller-supplied; confirm that loading
        # arbitrary dataset repositories is acceptable for this deployment.
        ds = load_dataset(repo, split=split)
        docs = []
        for rec in ds:
            rid = str(rec[id_field]) if id_field and id_field in rec else None
            meta = {k: rec[k] for k in (meta_fields or []) if k in rec}
            docs.append(Doc(id=rid, text=str(rec[text_field]), meta=meta))
        return ingest(IngestRequest(docs=docs, chunk=ChunkConfig(size=chunk_size, overlap=overlap), normalize=True))
    except HTTPException:
        # Errors already shaped by ingest() keep their own status and detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"bulk_load_hf failed: {e}") from e