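"""Orify Text Detector API.

FastAPI service that scores text line-by-line with an ensemble of three
ModernBERT-based classifiers (41 classes: 40 AI-model fingerprints plus
"human") and returns per-line verdicts together with highlighted HTML.
"""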
from __future__ import annotations
import os, re, html
from datetime import datetime, timedelta
from typing import List
import torch
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from jose import jwt, JWTError
from pydantic import BaseModel, Field
# ───────────────────────── torch shim ───────────────────────────────
# No-op shim: torch.compile returns the module (or a pass-through
# decorator) unchanged, and TorchInductor is disabled, so nothing in
# this app ever triggers compilation.
if hasattr(torch, "compile"):
    torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f))  # type: ignore
    os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")
# ─────────────────────── remote-code flag ───────────────────────────
os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
TOKEN_KW = {"trust_remote_code": True}
# ─────────────────────────── config ─────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
# (sic – keys/filenames spelled exactly as published in the weights repo)
FILE_MAP = {"ensamble_1": "ensamble_1", "ensamble_2.bin": "ensamble_2.bin", "ensamble_3": "ensamble_3"}
BASE_MODEL = "answerdotai/ModernBERT-base"
NUM_LABELS = 41
LABELS = {i: n for i, n in enumerate([
    "13B","30B","65B","7B","GLM130B","bloom_7b","bloomz","cohere","davinci","dolly","dolly-v2-12b",
    "flan_t5_base","flan_t5_large","flan_t5_small","flan_t5_xl","flan_t5_xxl","gemma-7b-it","gemma2-9b-it",
    "gpt-3.5-turbo","gpt-35","gpt-4","gpt-4o","gpt-j","gpt-neox","human","llama3-70b","llama3-8b",
    "mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b","opt-30b","opt-350m","opt-6.7b",
    "opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"])}
# ──────────────────────── JWT helpers ──────────────────────────────
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    raise RuntimeError("SECRET_KEY env-var not set – add it in Space settings → Secrets")
ALG = "HS256"
EXP = 24  # token lifetime in hours
oauth2 = OAuth2PasswordBearer(tokenUrl="token")
def _make_jwt(sub: str) -> str:
    payload = {"sub": sub, "exp": datetime.utcnow() + timedelta(hours=EXP)}
    return jwt.encode(payload, SECRET_KEY, algorithm=ALG)

def _verify_jwt(tok: str = Depends(oauth2)) -> str:
    try:
        return jwt.decode(tok, SECRET_KEY, algorithms=[ALG])["sub"]
    except JWTError:
        raise HTTPException(401, "Invalid or expired token")
# ─────────────────────── model bootstrap ───────────────────────────
print("🔄 Fetching ensemble weights…", flush=True)
# (resume_download was dropped: recent huggingface_hub deprecates the
# argument and always resumes interrupted downloads)
paths = {k: hf_hub_download(WEIGHT_REPO, f) for k, f in FILE_MAP.items()}
print("🧩 Building ModernBERT backbone…", flush=True)
_cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
_cfg.num_labels = NUM_LABELS
_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
_models: List[AutoModelForSequenceClassification] = []
for p in paths.values():
    # Build the backbone with a 41-way classification head, then overwrite
    # it with one set of fine-tuned ensemble weights.
    m = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        config=_cfg,
        ignore_mismatched_sizes=True,
        **TOKEN_KW,
    )
    m.load_state_dict(torch.load(p, map_location=DEVICE))
    m.to(DEVICE).eval()
    _models.append(m)
print(f"✅ Ensemble ready on {DEVICE}")
# ───────────────────────── helpers ─────────────────────────────────
def _tidy(t: str) -> str:
    t = t.replace("\r\n", "\n").replace("\r", "\n")   # normalise newlines
    t = re.sub(r"\n\s*\n+", "\n\n", t)                # collapse blank runs to one blank line
    t = re.sub(r"[ \t]+", " ", t)                     # collapse spaces/tabs
    t = re.sub(r"(\w+)-\n(\w+)", r"\1\2", t)          # re-join hyphenated line breaks
    t = re.sub(r"(?<!\n)\n(?!\n)", " ", t)            # unwrap single newlines inside paragraphs
    return t.strip()
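# Worked example of the clean-up above:
#   _tidy("foo-\nbar\n\n\nbaz")  ->  "foobar\n\nbaz"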
def _infer(seg: str):
    """Score one line; returns (human %, AI %, top-3 AI-model guesses)."""
    inp = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        # Average the softmax distributions of the three ensemble members.
        probs = torch.stack([torch.softmax(m(**inp).logits, 1) for m in _models]).mean(0)[0]
    ai_probs = probs.clone()
    ai_probs[24] = 0                      # zero the "human" class, leaving only AI mass
    ai = ai_probs.sum().item() * 100
    human = 100 - ai
    top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
    return human, ai, top3
# ───────────────────────── schemas ─────────────────────────────────
class TokenOut(BaseModel): access_token: str; token_type: str = "bearer"
class AnalyseIn(BaseModel): text: str = Field(..., min_length=1)
class Line(BaseModel): text: str; ai: float; human: float; top3: List[str]; reason: str
class AnalyseOut(BaseModel): verdict: str; confidence: float; ai_avg: float; human_avg: float; per_line: List[Line]; highlight_html: str
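# Illustrative /analyse response shape (values made up):
#   {"verdict": "Human-written", "confidence": 87.3, "ai_avg": 12.7,
#    "human_avg": 87.3,
#    "per_line": [{"text": "…", "ai": 12.7, "human": 87.3,
#                  "top3": ["gpt-4o", "llama3-8b", "gpt-4"], "reason": "…"}],
#    "highlight_html": "<h3>…</h3><hr>…"}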
# ───────────────────────── FastAPI app ─────────────────────────────
app = FastAPI(title="Orify Text Detector API", version="1.2.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

@app.post("/token", response_model=TokenOut)
async def token(form: OAuth2PasswordRequestForm = Depends()):
    # NOTE: no credential check is wired up – any username/password pair
    # yields a token. Plug in a real user store before exposing this publicly.
    return TokenOut(access_token=_make_jwt(form.username))
@app.post("/analyse", response_model=AnalyseOut)
async def analyse(body: AnalyseIn, _=Depends(_verify_jwt)):
    lines = _tidy(body.text).split("\n")
    html_parts = []; per = []; h_sum = ai_sum = n = 0.0
    for ln in lines:
        if not ln.strip():
            html_parts.append("<br>"); continue
        n += 1; human, ai, top3 = _infer(ln); h_sum += human; ai_sum += ai
        cls = "ai-line" if ai > human else "human-line"
        tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai > human else f"Human {human:.2f}%"
        html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
        reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai > human
                  else f"Lexical variety suggests human ({human:.1f}%)")
        per.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
    human_avg = h_sum / n if n else 0; ai_avg = ai_sum / n if n else 0
    verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
    conf = max(human_avg, ai_avg)
    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>"
             if verdict == "AI-generated"
             else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
    html_out = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
    return AnalyseOut(verdict=verdict, confidence=conf, ai_avg=ai_avg,
                      human_avg=human_avg, per_line=per, highlight_html=html_out)
# ─────────────────────────── entrypoint ────────────────────────────
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info", reload=False)
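# ── Quick smoke test ────────────────────────────────────────────────
# A minimal client sketch, assuming the server is running locally on
# port 7860 (and remembering that /token currently accepts any credentials):
#
#   import requests
#   tok = requests.post("http://localhost:7860/token",
#                       data={"username": "demo", "password": "demo"}).json()
#   r = requests.post("http://localhost:7860/analyse",
#                     json={"text": "Hello world"},
#                     headers={"Authorization": f"Bearer {tok['access_token']}"})
#   print(r.json()["verdict"], r.json()["confidence"])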