"""RAG-as-a-Service: FastAPI endpoints for embedding, chunked ingestion,
FAISS vector retrieval and a mock extractive answerer.

State lives in a module-level ``MemoryIndex`` persisted under ``./data``.
"""

import hmac
import json
import logging
import os
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss

logger = logging.getLogger(__name__)

# --------- Optional engines ----------
try:
    from pyspark.sql import SparkSession, functions as F
    from pyspark.sql.types import StringType

    SPARK_AVAILABLE = True
except Exception:
    SPARK_AVAILABLE = False

try:
    from sentence_transformers import CrossEncoder

    RERANK_AVAILABLE = True
except Exception:
    RERANK_AVAILABLE = False

APP_VERSION = "1.0.0"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DATA_DIR = Path("./data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
INDEX_FP = DATA_DIR / "index.faiss"   # serialized FAISS index
META_FP = DATA_DIR / "meta.jsonl"     # one {"id","text","meta"} record per vector, in index order
PARQ_FP = DATA_DIR / "meta.parquet"   # best-effort columnar copy of META_FP
CFG_FP = DATA_DIR / "store.json"      # model/dim/index_type/nlist/M the index was built with


# --------- Schemas ----------
class EchoRequest(BaseModel):
    message: str


class HealthResponse(BaseModel):
    status: str
    version: str
    index_size: int = 0
    model: str = ""
    spark: bool = False
    persisted: bool = False
    rerank: bool = False
    index_type: str = "flat"


class EmbedRequest(BaseModel):
    texts: List[str] = Field(..., min_items=1)
    preview_n: int = Field(default=6, ge=0, le=32)
    normalize: bool = True


class EmbedResponse(BaseModel):
    dim: int
    count: int
    preview: List[List[float]]


class Doc(BaseModel):
    id: Optional[str] = None
    text: str
    meta: Dict[str, Any] = Field(default_factory=dict)


class ChunkConfig(BaseModel):
    size: int = Field(default=800, gt=0)
    overlap: int = Field(default=120, ge=0)


class IngestRequest(BaseModel):
    docs: List[Doc]
    chunk: ChunkConfig = Field(default_factory=ChunkConfig)
    normalize: bool = True
    use_spark: Optional[bool] = None


class Match(BaseModel):
    id: str
    score: float
    text: Optional[str] = None
    meta: Dict[str, Any] = Field(default_factory=dict)


class QueryRequest(BaseModel):
    q: str
    k: int = Field(default=5, ge=1, le=50)
    return_text: bool = True


class QueryResponse(BaseModel):
    matches: List[Match]


class ExplainMatch(Match):
    start: int
    end: int
    token_overlap: float


class ExplainRequest(QueryRequest):
    pass


class ExplainResponse(BaseModel):
    matches: List[ExplainMatch]


class AnswerRequest(BaseModel):
    q: str
    k: int = Field(default=5, ge=1, le=50)
    model: str = Field(default="mock")
    max_context_chars: int = Field(default=1600, ge=200, le=20000)
    return_contexts: bool = True
    rerank: bool = False
    rerank_model: str = Field(default="cross-encoder/ms-marco-MiniLM-L-6-v2")


class AnswerResponse(BaseModel):
    answer: str
    contexts: List[Match] = []


class ReindexParams(BaseModel):
    index_type: str = Field(default="flat", pattern="^(flat|ivf|hnsw)$")
    nlist: int = Field(default=64, ge=1, le=65536)
    M: int = Field(default=32, ge=4, le=128)


# --------- Embeddings ----------
class LazyEmbedder:
    """SentenceTransformer wrapper that loads the model on first use."""

    def __init__(self, model_name: str = EMBED_MODEL_NAME):
        self.model_name = model_name
        self._model: Optional[SentenceTransformer] = None
        self._dim: Optional[int] = None

    def _ensure(self) -> None:
        if self._model is None:
            self._model = SentenceTransformer(self.model_name)
            # Encode a probe string once to learn the embedding width.
            self._dim = int(self._model.encode(["_probe_"], convert_to_numpy=True).shape[1])

    @property
    def dim(self) -> int:
        self._ensure()
        return int(self._dim)  # type: ignore[arg-type]

    def encode(self, texts: List[str], normalize: bool = True) -> np.ndarray:
        """Embed *texts* as a float32 ``(len(texts), dim)`` array, L2-normalized if requested."""
        self._ensure()
        if not texts:
            return np.zeros((0, self.dim), dtype="float32")
        vecs = self._model.encode(  # type: ignore[union-attr]
            texts, batch_size=32, show_progress_bar=False, convert_to_numpy=True
        )
        if normalize:
            # The epsilon guards against division by zero for all-zero vectors.
            vecs = vecs / (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12)
        return vecs.astype("float32")


_embedder = LazyEmbedder()


# --------- Reranker ----------
class LazyReranker:
    """CrossEncoder holder; reloads only when a different model name is requested."""

    def __init__(self):
        self._model = None
        self._name: Optional[str] = None

    def ensure(self, name: str) -> None:
        if not RERANK_AVAILABLE:
            return
        if self._model is None or self._name != name:
            self._model = CrossEncoder(name)
            self._name = name

    def score(self, q: str, texts: List[str]) -> List[float]:
        """Relevance score per text; all zeros when no reranker is loaded."""
        if not RERANK_AVAILABLE or self._model is None:
            return [0.0] * len(texts)
        if not texts:
            return []
        return [float(s) for s in self._model.predict([(q, t) for t in texts])]


_reranker = LazyReranker()


# --------- Chunking ----------
# NOTE(review): the bodies of chunk_text_py and spark_clean_and_chunk were partly
# unreadable in the reviewed copy; they are reconstructed to produce exactly the
# shapes ingest() consumes. Verify against the intended implementation.
def chunk_text_py(text: str, size: int, overlap: int) -> List[Tuple[str, Tuple[int, int]]]:
    """Split whitespace-normalized *text* into overlapping character windows.

    Returns ``(chunk, (start, end))`` pairs. Offsets index into the
    whitespace-collapsed text, not the raw input. ``overlap`` is clamped below
    ``size`` so every step advances.

    Raises:
        ValueError: if ``size`` is not positive (the loop could never advance).
    """
    if size <= 0:
        raise ValueError("chunk size must be > 0")
    t = " ".join((text or "").split())
    n = len(t)
    if overlap >= size:
        overlap = max(size - 1, 0)
    out: List[Tuple[str, Tuple[int, int]]] = []
    s = 0
    while s < n:
        e = min(n, s + size)
        piece = t[s:e]
        if piece.strip():
            out.append((piece, (s, e)))
        if e >= n:
            break
        s = e - overlap
    return out


def spark_clean_and_chunk(docs: List[Doc], size: int, overlap: int) -> List[Dict[str, Any]]:
    """Chunk *docs* with a PySpark UDF; rows match ingest()'s Python path.

    Each row is ``{"id": "<pid>::offset:<s>-<e>", "text": ..., "meta": {...}}``.
    ``meta`` extends the doc's meta with ``parent_id``/``start``/``end``.

    Raises:
        RuntimeError: pyspark is not installed (the caller falls back to Python).
        ValueError: ``size`` is not positive.
    """
    if not SPARK_AVAILABLE:
        raise RuntimeError("pyspark is not available")
    if size <= 0:
        raise ValueError("chunk size must be > 0")
    if not docs:
        return []
    sz, ov = int(size), int(overlap)
    if ov >= sz:
        ov = max(sz - 1, 0)

    @F.udf(returnType=StringType())
    def chunk_udf(text: str, pid: str, meta_json: str) -> str:
        # Self-contained (no module helpers) so it serializes cleanly to executors.
        import json as _j

        t = " ".join((text or "").split())
        n = len(t)
        base = _j.loads(meta_json) if meta_json else {}
        out = []
        s = 0
        while s < n:
            e = min(n, s + sz)
            piece = t[s:e]
            if piece.strip():
                meta = dict(base)
                meta.update({"parent_id": pid, "start": s, "end": e})
                out.append({"id": f"{pid}::offset:{s}-{e}", "text": piece, "meta": _j.dumps(meta)})
            if e >= n:
                break
            s = e - ov
        return _j.dumps(out)

    spark = SparkSession.builder.appName("rag-ingest").getOrCreate()
    df = spark.createDataFrame(
        [(d.text, d.id or "doc", json.dumps(d.meta)) for d in docs],
        schema="text string, pid string, meta_json string",
    )
    chunked = df.select(chunk_udf(F.col("text"), F.col("pid"), F.col("meta_json")).alias("chunks"))
    exploded = chunked.select(
        F.explode(
            F.from_json(F.col("chunks"), "array<struct<id:string,text:string,meta:string>>")
        ).alias("c")
    )
    collected = exploded.select(
        F.col("c")["id"].alias("id"),
        F.col("c")["text"].alias("text"),
        F.col("c")["meta"].alias("meta_json"),
    ).collect()
    return [
        {"id": r["id"], "text": r["text"], "meta": json.loads(r["meta_json"]) if r["meta_json"] else {}}
        for r in collected
    ]


# --------- Vector index ----------
class VectorIndex:
    """Thin FAISS wrapper whose search() returns similarity scores (higher = closer)."""

    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.dim = dim
        self.type = index_type
        self.metric = "ip"
        self.nlist = nlist
        self.M = M
        if index_type == "flat":
            self.index = faiss.IndexFlatIP(dim)
        elif index_type == "ivf":
            self._quantizer = faiss.IndexFlatIP(dim)  # keep a Python ref alongside the IVF index
            self.index = faiss.IndexIVFFlat(self._quantizer, dim, max(1, nlist), faiss.METRIC_INNER_PRODUCT)
        elif index_type == "hnsw":
            # IndexHNSWFlat uses L2; search() converts distances back to cosine.
            self.index = faiss.IndexHNSWFlat(dim, max(4, M))
            self.metric = "l2"
        else:
            raise ValueError("bad index_type")

    def train(self, vecs: np.ndarray) -> None:
        """Train the index on *vecs* if it requires training (IVF) and is untrained."""
        if hasattr(self.index, "is_trained") and not self.index.is_trained:
            self.index.train(vecs)

    def add(self, vecs: np.ndarray) -> None:
        self.train(vecs)
        self.index.add(vecs)

    def search(self, qvec: np.ndarray, k: int) -> Tuple[List[int], List[float]]:
        """Return ``(labels, scores)`` for the first query row; label -1 means no hit."""
        dists, labels = self.index.search(qvec, k)
        if self.metric == "l2":
            # FAISS L2 returns squared distances; for unit vectors d = 2 - 2*cos.
            scores = (1.0 - 0.5 * dists[0]).tolist()
        else:
            scores = dists[0].tolist()
        return labels[0].tolist(), scores

    def save(self, fp: Path) -> None:
        faiss.write_index(self.index, str(fp))

    @staticmethod
    def load(fp: Path) -> "VectorIndex":
        """Read an index from *fp*, inferring metric and (where possible) type."""
        idx = faiss.read_index(str(fp))
        vi = VectorIndex(idx.d, "flat")
        vi.index = idx
        # Ask FAISS for the metric instead of guessing from the class name.
        vi.metric = "ip" if idx.metric_type == faiss.METRIC_INNER_PRODUCT else "l2"
        if isinstance(idx, faiss.IndexIVF):
            vi.type, vi.nlist = "ivf", int(idx.nlist)
        elif isinstance(idx, faiss.IndexHNSW):
            vi.type = "hnsw"
        return vi


# --------- Store ----------
class MemoryIndex:
    """Vector index plus parallel id/text/meta lists; row i corresponds to FAISS label i."""

    def __init__(self, dim: int, index_type: str = "flat", nlist: int = 64, M: int = 32):
        self.ids: List[str] = []
        self.texts: List[str] = []
        self.metas: List[Dict[str, Any]] = []
        self.vindex = VectorIndex(dim, index_type=index_type, nlist=nlist, M=M)

    def add(self, vecs: np.ndarray, rows: List[Dict[str, Any]]) -> None:
        """Append vectors and their ``{"id","text","meta"}`` rows (same order)."""
        if vecs.shape[0] != len(rows):
            raise ValueError("Vector count != row count")
        self.vindex.add(vecs)
        for r in rows:
            self.ids.append(r["id"])
            self.texts.append(r["text"])
            self.metas.append(r["meta"])

    def size(self) -> int:
        return self.vindex.index.ntotal

    def search(self, qvec: np.ndarray, k: int) -> Tuple[List[int], List[float]]:
        return self.vindex.search(qvec, k)

    def save(self) -> None:
        """Persist the index, JSONL metadata, a best-effort parquet copy and the build config."""
        self.vindex.save(INDEX_FP)
        with META_FP.open("w", encoding="utf-8") as f:
            for rid, text, meta in zip(self.ids, self.texts, self.metas):
                f.write(json.dumps({"id": rid, "text": text, "meta": meta}) + "\n")
        try:
            pd.DataFrame(
                {"id": self.ids, "text": self.texts, "meta_json": [json.dumps(m) for m in self.metas]}
            ).to_parquet(PARQ_FP, index=False)
        except Exception:
            # Parquet needs an optional engine (pyarrow/fastparquet); JSONL is what load reads.
            logger.warning("parquet export skipped", exc_info=True)
        CFG_FP.write_text(
            json.dumps(
                {
                    "model": EMBED_MODEL_NAME,
                    "dim": self.vindex.dim,
                    "index_type": self.vindex.type,
                    "nlist": self.vindex.nlist,
                    "M": self.vindex.M,
                }
            ),
            encoding="utf-8",
        )

    @staticmethod
    def load_if_exists() -> Optional["MemoryIndex"]:
        """Load the persisted store, or return None if the index or metadata file is missing."""
        if not INDEX_FP.exists() or not META_FP.exists():
            return None
        vi = VectorIndex.load(INDEX_FP)
        if CFG_FP.exists():
            try:
                cfg = json.loads(CFG_FP.read_text(encoding="utf-8"))
                # Restore build parameters so a later save() keeps describing this
                # index correctly instead of relabelling it "flat".
                vi.type = str(cfg.get("index_type", vi.type))
                vi.nlist = int(cfg.get("nlist", vi.nlist))
                vi.M = int(cfg.get("M", vi.M))
            except (OSError, ValueError, TypeError, AttributeError):
                logger.warning("could not read %s; using inferred index settings", CFG_FP, exc_info=True)
        store = MemoryIndex(dim=vi.dim)
        store.vindex = vi
        with META_FP.open("r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank/trailing lines
                rec = json.loads(line)
                store.ids.append(rec["id"])
                store.texts.append(rec["text"])
                store.metas.append(rec.get("meta", {}))
        if len(store.ids) != store.size():
            logger.warning(
                "metadata rows (%d) != index vectors (%d); results may be inconsistent",
                len(store.ids),
                store.size(),
            )
        return store

    @staticmethod
    def reset_files() -> None:
        """Delete every persisted store file, ignoring ones that cannot be removed."""
        for p in (INDEX_FP, META_FP, PARQ_FP, CFG_FP):
            try:
                if p.exists():
                    p.unlink()
            except OSError:
                logger.warning("could not delete %s", p, exc_info=True)


_mem_store: Optional[MemoryIndex] = MemoryIndex.load_if_exists()


def require_store() -> MemoryIndex:
    """Return the live store, or raise HTTP 400 when nothing has been ingested."""
    if _mem_store is None or _mem_store.size() == 0:
        raise HTTPException(status_code=400, detail="Index empty. Ingest documents first.")
    return _mem_store


# --------- Helpers ----------
def _token_overlap(q: str, txt: str) -> float:
    """Fraction of distinct lower-cased query tokens that also appear in *txt*."""
    qt = {t for t in q.lower().split() if t}
    tt = {t for t in (txt or "").lower().split() if t}
    if not qt:
        return 0.0
    return float(len(qt & tt)) / float(len(qt))


def _topk(q: str, k: int) -> List[Match]:
    """Embed *q* and return up to *k* matches, skipping FAISS's -1 "no result" labels."""
    store = require_store()
    qvec = _embedder.encode([q], normalize=True)
    idxs, scores = store.search(qvec, k)
    return [
        Match(id=store.ids[i], score=float(s), text=store.texts[i], meta=store.metas[i])
        for i, s in zip(idxs, scores)
        if i != -1
    ]


def _compose_contexts(matches: List[Match], max_chars: int) -> str:
    """Concatenate match texts (blank-line separated), truncating to *max_chars* of text."""
    buf: List[str] = []
    total = 0
    for m in matches:
        t = m.text or ""
        cut = min(len(t), max_chars - total)
        if cut <= 0:
            break
        buf.append(t[:cut])
        total += cut
        if total >= max_chars:
            break
    return "\n\n".join(buf).strip()


def _answer_with_mock(q: str, contexts: str) -> str:
    """Build a bullet-list 'answer' from context lines that mention any query token."""
    if not contexts:
        return "No indexed context available to answer the question."
    lines = [ln.strip() for ln in contexts.split("\n") if ln.strip()]
    q_tokens = q.lower().split()
    hits = [ln for ln in lines if any(t in ln.lower() for t in q_tokens)]
    if not hits:
        hits = lines[:2]
    return "Based on retrieved context, here’s a concise answer:\n- " + "\n- ".join(hits[:4])


def _maybe_rerank(q: str, matches: List[Match], enabled: bool, model_name: str) -> List[Match]:
    """Reorder *matches* by cross-encoder score; return them unchanged on any failure."""
    if not enabled:
        return matches
    try:
        _reranker.ensure(model_name)
        scores = _reranker.score(q, [m.text or "" for m in matches])
        order = sorted(range(len(matches)), key=lambda i: scores[i], reverse=True)
        return [matches[i] for i in order]
    except Exception:
        logger.warning("rerank failed; keeping retrieval order", exc_info=True)
        return matches


def _write_parquet_if_missing() -> None:
    """Best-effort: build the parquet copy from the full JSONL metadata if it is absent."""
    if PARQ_FP.exists() or not META_FP.exists():
        return
    try:
        with META_FP.open("r", encoding="utf-8") as f:
            rows = [json.loads(line) for line in f if line.strip()]
        if rows:
            pd.DataFrame(
                {
                    "id": [r["id"] for r in rows],
                    "text": [r["text"] for r in rows],
                    "meta_json": [json.dumps(r.get("meta", {})) for r in rows],
                }
            ).to_parquet(PARQ_FP, index=False)
    except Exception:
        logger.warning("parquet export skipped", exc_info=True)


# --------- Auth/limits/metrics ----------
API_KEY = os.getenv("API_KEY", "")
_rate = {"capacity": 60, "refill_per_sec": 1.0}
_buckets: Dict[str, Dict[str, float]] = {}
_metrics = {"requests": 0, "by_endpoint": {}, "started": time.time()}


def _allow(ip: str) -> bool:
    """Token-bucket check for *ip*: refill by elapsed time, then spend one token if available."""
    now = time.monotonic()  # monotonic: immune to wall-clock jumps
    b = _buckets.get(ip, {"tokens": _rate["capacity"], "ts": now})
    tokens = min(b["tokens"] + (now - b["ts"]) * _rate["refill_per_sec"], _rate["capacity"])
    if tokens < 1.0:
        _buckets[ip] = {"tokens": tokens, "ts": now}
        return False
    _buckets[ip] = {"tokens": tokens - 1.0, "ts": now}
    return True


async def guard(request: Request) -> None:
    """Dependency enforcing the optional ``x-api-key`` header and the per-client rate limit."""
    if API_KEY:
        supplied = request.headers.get("x-api-key", "")
        # Constant-time comparison so response timing does not leak key prefixes.
        if not hmac.compare_digest(supplied.encode("utf-8"), API_KEY.encode("utf-8")):
            raise HTTPException(status_code=401, detail="invalid api key")
    ip = request.client.host if request.client else "local"
    if not _allow(ip):
        raise HTTPException(status_code=429, detail="rate limited")


app = FastAPI(title="RAG-as-a-Service", version=APP_VERSION, description="Steps 10–13")


@app.middleware("http")
async def req_meta(request: Request, call_next):
    """Count requests per endpoint and tag each response with an ``x-request-id``."""
    rid = str(uuid.uuid4())
    _metrics["requests"] += 1
    ep = f"{request.method} {request.url.path}"
    _metrics["by_endpoint"][ep] = _metrics["by_endpoint"].get(ep, 0) + 1
    resp = await call_next(request)
    try:
        resp.headers["x-request-id"] = rid
    except Exception:
        pass  # header is cosmetic; never fail the request over it
    return resp


# --------- API ----------
@app.get("/", response_class=HTMLResponse)
def root():
    return """RAG-as-a-Service

RAG-as-a-Service


"""

@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Liveness probe with a snapshot of index and optional-engine state.

    ``index_type`` is read from the persisted store config and defaults to
    ``"flat"`` when that file is missing or unreadable.
    """
    index_type = "flat"
    if CFG_FP.exists():
        try:
            index_type = json.loads(CFG_FP.read_text()).get("index_type", "flat")
        except Exception:
            pass  # unreadable config: report the default
    return HealthResponse(
        status="ok",
        version=APP_VERSION,
        index_size=0 if _mem_store is None else _mem_store.size(),
        model=EMBED_MODEL_NAME,
        spark=SPARK_AVAILABLE,
        persisted=INDEX_FP.exists() and META_FP.exists(),
        rerank=RERANK_AVAILABLE,
        index_type=index_type,
    )

@app.get("/metrics")
def metrics():
    """Return request counters and seconds elapsed since the metrics store was created."""
    uptime = round(time.time() - _metrics["started"], 2)
    return {
        "requests": _metrics["requests"],
        "by_endpoint": _metrics["by_endpoint"],
        "uptime_sec": uptime,
    }

@app.post("/echo", dependencies=[Depends(guard)])
def echo(payload: EchoRequest) -> Dict[str, str]:
    """Return the message unchanged together with its character count (as a string)."""
    msg = payload.message
    return dict(echo=msg, length=f"{len(msg)}")

@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(guard)])
def embed(payload: EmbedRequest) -> EmbedResponse:
    """Embed the texts and return their dimensions plus a short rounded preview of each vector."""
    vecs = _embedder.encode(payload.texts, normalize=payload.normalize)
    n = payload.preview_n
    preview = []
    if n > 0:
        # First n components of every vector, rounded to 5 decimals.
        preview = [[float(round(x, 5)) for x in row[:n]] for row in vecs]
    count, dim = vecs.shape
    return EmbedResponse(dim=int(dim), count=int(count), preview=preview)

@app.post("/ingest", dependencies=[Depends(guard)])
def ingest(req: IngestRequest) -> Dict[str, Any]:
    """Chunk, embed and append ``req.docs`` to the store, then persist it.

    Spark chunking is used when requested (or, by default, when available). The
    pure-Python chunker runs if Spark raises or yields nothing. The returned
    ``engine`` names the chunker that actually produced the rows.

    Raises:
        HTTPException(400): no non-empty chunks were produced.
    """
    global _mem_store
    if _mem_store is None:
        # A fresh store reuses previously persisted index settings when present.
        cfg = {"index_type": "flat", "nlist": 64, "M": 32}
        if CFG_FP.exists():
            try:
                cfg.update(json.loads(CFG_FP.read_text()))
            except Exception:
                pass  # unreadable config: keep the defaults
        _mem_store = MemoryIndex(dim=_embedder.dim, index_type=cfg["index_type"], nlist=cfg["nlist"], M=cfg["M"])

    use_spark = SPARK_AVAILABLE if req.use_spark is None else bool(req.use_spark)
    rows: List[Dict[str, Any]] = []
    engine = "python"
    if use_spark:
        try:
            rows = spark_clean_and_chunk(req.docs, size=req.chunk.size, overlap=req.chunk.overlap)
        except Exception:
            import logging

            logging.getLogger(__name__).warning("spark chunking failed; falling back to python", exc_info=True)
            rows = []
        if rows:
            engine = "spark"
    if not rows:
        for d in req.docs:
            # NOTE(review): docs without an id all use the "doc" prefix, so their
            # chunk ids collide whenever offsets coincide.
            pid = d.id or "doc"
            for ctext, (s, e) in chunk_text_py(d.text, size=req.chunk.size, overlap=req.chunk.overlap):
                meta = dict(d.meta)
                meta.update({"parent_id": pid, "start": s, "end": e})
                rows.append({"id": f"{pid}::offset:{s}-{e}", "text": ctext, "meta": meta})
    if not rows:
        raise HTTPException(status_code=400, detail="No non-empty chunks produced")

    vecs = _embedder.encode([r["text"] for r in rows], normalize=req.normalize)
    _mem_store.add(vecs, rows)
    _mem_store.save()
    # save() already attempts the parquet export. If that failed, rebuild from the
    # complete JSONL metadata rather than writing only this batch's rows.
    _write_parquet_if_missing()
    return {
        "docs": len(req.docs),
        "chunks": len(rows),
        "index_size": _mem_store.size(),
        "engine": engine,
        "persisted": True,
    }

@app.post("/query", response_model=QueryResponse, dependencies=[Depends(guard)])
def query(req: QueryRequest) -> QueryResponse:
    """Return the top-k matches for ``req.q``, with chunk text omitted unless requested."""
    hits = _topk(req.q, req.k)
    if req.return_text:
        return QueryResponse(matches=hits)
    stripped = [Match(id=h.id, score=h.score, text=None, meta=h.meta) for h in hits]
    return QueryResponse(matches=stripped)

@app.post("/explain", response_model=ExplainResponse, dependencies=[Depends(guard)])
def explain(req: ExplainRequest) -> ExplainResponse:
    """Top-k matches annotated with chunk offsets and query-token overlap.

    ``start``/``end`` come from each match's ``meta`` (0 when absent).
    ``token_overlap`` is rounded to 4 decimals.
    """

    def _annotate(hit: Match) -> ExplainMatch:
        meta = hit.meta
        lo = int(meta.get("start", 0))
        hi = int(meta.get("end", 0))
        overlap = round(_token_overlap(req.q, hit.text or ""), 4)
        return ExplainMatch(
            id=hit.id,
            score=hit.score,
            text=hit.text if req.return_text else None,
            meta=meta,
            start=lo,
            end=hi,
            token_overlap=float(overlap),
        )

    return ExplainResponse(matches=[_annotate(h) for h in _topk(req.q, req.k)])

@app.post("/answer", response_model=AnswerResponse, dependencies=[Depends(guard)])
def answer(req: AnswerRequest) -> AnswerResponse:
    """Retrieve top-k chunks, optionally rerank them, and compose an extractive answer.

    Only the built-in mock generator exists. Every ``req.model`` value, including
    unknown ones, is currently answered by it.
    """
    matches = _topk(req.q, req.k)
    matches = _maybe_rerank(req.q, matches, enabled=req.rerank, model_name=req.rerank_model)
    ctx = _compose_contexts(matches, req.max_context_chars)
    # TODO: dispatch on req.model once a real generator is wired in.
    out = _answer_with_mock(req.q, ctx)
    return AnswerResponse(answer=out, contexts=matches if req.return_contexts else [])

@app.post("/reindex", dependencies=[Depends(guard)])
def reindex(params: ReindexParams) -> Dict[str, Any]:
    """Rebuild the vector index from the persisted JSONL metadata.

    All texts are re-embedded with ``normalize=True``, regardless of how they
    were originally ingested. For IVF, ``nlist`` is capped at the row count
    because FAISS IVF training needs at least one point per list. If building
    the requested index type fails, a flat index is built instead. The live
    store is replaced only after a build succeeds.

    Raises:
        HTTPException(400): metadata file missing or empty.
    """
    global _mem_store
    if not META_FP.exists():
        raise HTTPException(status_code=400, detail="no metadata on disk")

    with META_FP.open("r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    if not records:
        raise HTTPException(status_code=400, detail="empty metadata")

    rows = [{"id": r["id"], "text": r["text"], "meta": r.get("meta", {})} for r in records]
    vecs = _embedder.encode([r["text"] for r in rows], normalize=True)
    dim = _embedder.dim

    idx_type = params.index_type
    eff_nlist = max(1, min(params.nlist, len(rows))) if idx_type == "ivf" else params.nlist

    def _build(index_type: str, **kwargs: Any) -> MemoryIndex:
        # Build and persist into a local store; the caller swaps it in only on success.
        store = MemoryIndex(dim=dim, index_type=index_type, **kwargs)
        store.add(vecs, rows)
        store.save()
        return store

    try:
        _mem_store = _build(idx_type, nlist=eff_nlist, M=params.M)
        return {
            "reindexed": True,
            "index_type": idx_type,
            "index_size": _mem_store.size(),
            "nlist": eff_nlist,
            "M": params.M,
        }
    except Exception as e:
        # Fallback to flat if IVF/HNSW training/add fails for any reason.
        _mem_store = _build("flat")
        return {
            "reindexed": True,
            "index_type": "flat",
            "index_size": _mem_store.size(),
            "note": f"fallback due to: {str(e)[:120]}",
        }
@app.post("/reset", dependencies=[Depends(guard)])
def reset() -> Dict[str, Any]:
    """Drop the in-memory store and delete all persisted store files."""
    global _mem_store
    _mem_store = None
    MemoryIndex.reset_files()
    return {"reset": True}

@app.post("/bulk_load_hf", dependencies=[Depends(guard)])
def bulk_load_hf(repo: str, split: str = "train", text_field: str = "text", id_field: Optional[str]=None, meta_fields: Optional[List[str]]=None, chunk_size:int=800, overlap:int=120):
    """Load a Hugging Face dataset split and ingest it via ``ingest``.

    Records whose ``text_field`` is None are skipped. When ``id_field`` is not
    given (or absent from a record), the doc id defaults to
    ``"<repo>:<split>:<row>"``, so chunk ids stay unique across records.

    Raises:
        HTTPException: ``ingest``'s own errors pass through unchanged; any other
            failure (missing ``datasets`` package, bad repo/field, ...) becomes 400.
    """
    try:
        from datasets import load_dataset

        ds = load_dataset(repo, split=split)
        docs = []
        for row_no, rec in enumerate(ds):
            text = rec[text_field]
            if text is None:
                continue  # str(None) would index the literal string "None"
            rid = str(rec[id_field]) if id_field and id_field in rec else f"{repo}:{split}:{row_no}"
            meta = {k: rec[k] for k in (meta_fields or []) if k in rec}
            docs.append(Doc(id=rid, text=str(text), meta=meta))
        return ingest(IngestRequest(docs=docs, chunk=ChunkConfig(size=chunk_size, overlap=overlap), normalize=True))
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"bulk_load_hf failed: {e}") from e