Spaces:
Sleeping
Sleeping
File size: 7,537 Bytes
db897db a72abd5 db897db 5a21ae7 db897db 3c6385c db897db d6c4fbd 5a21ae7 db897db d6c4fbd a72abd5 7ce0483 d6c4fbd 5a21ae7 a72abd5 5a21ae7 a72abd5 7ce0483 5a21ae7 db897db d6c4fbd 5a21ae7 db897db a72abd5 5a21ae7 db897db a72abd5 db897db a72abd5 db897db a72abd5 db897db a72abd5 d6c4fbd a72abd5 5a21ae7 a72abd5 d6c4fbd 5a21ae7 a72abd5 d6c4fbd a72abd5 d6c4fbd 5a21ae7 d6c4fbd 5a21ae7 7ce0483 d6c4fbd a72abd5 3c6385c db897db 7ce0483 a72abd5 db897db a72abd5 7ce0483 a72abd5 5a21ae7 d6c4fbd 5a21ae7 a72abd5 5a21ae7 d6c4fbd a72abd5 5a21ae7 a72abd5 5a21ae7 a72abd5 db897db 5a21ae7 a72abd5 5a21ae7 a72abd5 d6c4fbd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from __future__ import annotations
import os, re, html
import threading
from datetime import datetime, timedelta, timezone
from typing import List

import torch
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, HTTPException, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from jose import jwt, JWTError
from pydantic import BaseModel, Field
# ───────────────────────── torch shim ───────────────────────────────
# Replace torch.compile with a pass-through so any model code that calls it
# (directly, or as a decorator factory) runs eagerly instead of compiling.
if hasattr(torch, "compile"):
    def _compile_passthrough(model=None, *_args, **_kwargs):
        """Stand-in for ``torch.compile`` that never compiles.

        The first parameter is named ``model`` like torch's own, so both
        ``torch.compile(fn)`` and ``torch.compile(model=fn)`` return ``fn``.
        Called without a callable (e.g. ``torch.compile(dynamic=True)`` used as
        a decorator factory) it returns an identity decorator.
        """
        if callable(model):
            return model
        return lambda f: f

    torch.compile = _compile_passthrough  # type: ignore
    # Only sets the variable when the environment has not already chosen a value.
    # NOTE(review): presumably read by TorchInductor to stay disabled — confirm.
    os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")
# ─────────────────────── remote-code flag ───────────────────────────
# Opt in to running code shipped with model repositories: exported as an
# environment variable (unless already set) and forwarded as keyword
# arguments to the Hugging Face ``from_pretrained`` loaders.
os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
TOKEN_KW = dict(trust_remote_code=True)
# ─────────────────────────── config ─────────────────────────────────
# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hub repository holding the fine-tuned ensemble checkpoints, and the files to
# download from it (key -> filename inside the repo, spelled as in the repo).
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
FILE_MAP = {
    "ensamble_1": "ensamble_1",
    "ensamble_2.bin": "ensamble_2.bin",
    "ensamble_3": "ensamble_3",
}

# Backbone architecture and size of the classification head.
BASE_MODEL = "answerdotai/ModernBERT-base"
NUM_LABELS = 41

# Class index -> source label. "human" (index 24) is the only non-AI class.
LABELS = dict(enumerate((
    "13B", "30B", "65B", "7B", "GLM130B", "bloom_7b", "bloomz", "cohere", "davinci", "dolly", "dolly-v2-12b",
    "flan_t5_base", "flan_t5_large", "flan_t5_small", "flan_t5_xl", "flan_t5_xxl", "gemma-7b-it", "gemma2-9b-it",
    "gpt-3.5-turbo", "gpt-35", "gpt-4", "gpt-4o", "gpt-j", "gpt-neox", "human", "llama3-70b", "llama3-8b",
    "mixtral-8x7b", "opt-1.3b", "opt-125m", "opt-13b", "opt-2.7b", "opt-30b", "opt-350m", "opt-6.7b",
    "opt-iml-30b", "opt-iml-max-1.3b", "t0-11b", "t0-3b", "text-davinci-002", "text-davinci-003",
)))
# ──────────────────────── JWT helpers ──────────────────────────────
# Shared HMAC key for signing and verifying access tokens; the service refuses
# to start without it.
SECRET_KEY = os.getenv("SECRET_KEY")
if SECRET_KEY is None or SECRET_KEY == "":
    raise RuntimeError("SECRET_KEY env‑var not set – add it in Space settings → Secrets")

ALG = "HS256"  # JWT signing algorithm
EXP = 24       # token lifetime, in hours

# Extracts the bearer token from the Authorization header; clients obtain one
# from the "token" endpoint.
oauth2 = OAuth2PasswordBearer(tokenUrl="token")
def _make_jwt(sub: str) -> str:
    """Return a signed JWT for subject *sub* that expires ``EXP`` hours from now.

    Args:
        sub: value stored in the token's ``sub`` claim (the username).

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALG``.
    """
    # Timezone-aware "now": datetime.utcnow() returns a naive datetime and is
    # deprecated since Python 3.12. The encoded ``exp`` timestamp is the same.
    expires = datetime.now(timezone.utc) + timedelta(hours=EXP)
    payload = {"sub": sub, "exp": expires}
    return jwt.encode(payload, SECRET_KEY, algorithm=ALG)
def _verify_jwt(tok: str = Depends(oauth2)):
    """FastAPI dependency: validate the bearer token and return its ``sub`` claim.

    Args:
        tok: the raw bearer token, extracted by ``oauth2``.

    Returns:
        The token subject (the username it was issued for).

    Raises:
        HTTPException: 401 if the token cannot be decoded or verified with
            ``SECRET_KEY``/``ALG`` (``JWTError``), or if it carries no ``sub``
            claim.
    """
    # The Bearer challenge header is what OAuth2 clients expect on a 401.
    credentials_error = HTTPException(
        401, "Invalid or expired token", headers={"WWW-Authenticate": "Bearer"}
    )
    try:
        claims = jwt.decode(tok, SECRET_KEY, algorithms=[ALG])
    except JWTError as err:
        raise credentials_error from err
    sub = claims.get("sub")
    if sub is None:
        # A correctly signed token without a subject used to escape as a
        # KeyError (HTTP 500); treat it as an invalid credential instead.
        raise credentials_error
    return sub
# ─────────────────────── model bootstrap ───────────────────────────
# Runs once at import time: download every ensemble checkpoint, then build one
# ModernBERT sequence classifier per checkpoint on DEVICE.
print("🔄 Fetching ensemble weights…", flush=True)
paths = {
    key: hf_hub_download(WEIGHT_REPO, filename, resume_download=True)
    for key, filename in FILE_MAP.items()
}

print("🧩 Building ModernBERT backbone…", flush=True)
_cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW)
_cfg.num_labels = NUM_LABELS
_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)


def _load_member(path: str) -> AutoModelForSequenceClassification:
    """Build a fresh NUM_LABELS-way classifier and load one ensemble checkpoint into it."""
    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        config=_cfg,
        # Presumably tolerates head-shape mismatches with the base checkpoint;
        # load_state_dict below (strict by default) overwrites every weight anyway.
        ignore_mismatched_sizes=True,
        **TOKEN_KW,
    )
    # Load the tensors on CPU: the model is only moved to DEVICE afterwards, so
    # mapping them straight to DEVICE just added a GPU->CPU copy.
    # NOTE(review): torch.load unpickles a file downloaded from the Hub.
    # Consider weights_only=True (the default from torch 2.6) if these
    # checkpoints are plain state dicts.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    return model.to(DEVICE).eval()


_models: List[AutoModelForSequenceClassification] = [_load_member(p) for p in paths.values()]
print(f"✅ Ensemble ready on {DEVICE}", flush=True)
# ───────────────────────── helpers ─────────────────────────────────
def _tidy(t:str)->str:
t=t.replace("\r\n","\n").replace("\r", "\n")
t=re.sub(r"\n\s*\n+","\n\n",t)
t=re.sub(r"[ \t]+"," ",t)
t=re.sub(r"(\w+)-\n(\w+)",r"\1\2",t)
t=re.sub(r"(?<!\n)\n(?!\n)"," ",t)
return t.strip()
# Index of the single "human" class, derived from LABELS rather than
# hard-coded so the two cannot drift apart.
_HUMAN_IDX = next(i for i, name in LABELS.items() if name == "human")


def _infer(seg: str):
    """Score one text segment with every model in the ensemble.

    The per-model softmax distributions are averaged; every class except
    "human" counts towards the AI probability.

    Args:
        seg: a single line/segment of text (truncated by the tokenizer if long).

    Returns:
        ``(human, ai, top3)``: the human and AI likelihoods in percent (they sum
        to 100), and the three most probable AI source labels, most likely first.
    """
    inp = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        # One sequence in the batch -> stack to (n_models, 1, NUM_LABELS),
        # average over models, keep the single row.
        probs = torch.stack([torch.softmax(m(**inp).logits, 1) for m in _models]).mean(0)[0]
    ai_probs = probs.clone()
    ai_probs[_HUMAN_IDX] = 0  # mask the human class so only AI classes remain
    ai = ai_probs.sum().item() * 100
    human = 100 - ai
    top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
    return human, ai, top3
# ───────────────────────── schemas ─────────────────────────────────
# (Comments rather than docstrings: pydantic would copy class docstrings into
# the published OpenAPI schema.)


# Response of the token endpoint.
class TokenOut(BaseModel):
    access_token: str
    token_type: str = "bearer"


# Request body of the analyse endpoint; the text must be non-empty.
class AnalyseIn(BaseModel):
    text: str = Field(..., min_length=1)


# Per-line result: AI/human percentages, top-3 AI labels and a short reason.
class Line(BaseModel):
    text: str
    ai: float
    human: float
    top3: List[str]
    reason: str


# Whole-document result: verdict, averages, per-line details and rendered HTML.
class AnalyseOut(BaseModel):
    verdict: str
    confidence: float
    ai_avg: float
    human_avg: float
    per_line: List[Line]
    highlight_html: str
# ───────────────────────── FastAPI app ─────────────────────────────
app = FastAPI(title="Orify Text Detector API", version="1.2.0")

# Accept cross-origin requests from any origin, with any method and header
# (allow_credentials is left at its default).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/token", response_model=TokenOut)
async def token(form: OAuth2PasswordRequestForm = Depends()):
    # Issue a bearer token whose subject is the submitted username.
    # NOTE(review): the password field is never checked, so any caller can mint
    # a token for any username and /analyse is only nominally protected.
    access = _make_jwt(form.username)
    return TokenOut(access_token=access)
# Serialises tokenizer/model use. Scoring now runs in worker threads, and the
# shared tokenizer/models are not known to be safe for concurrent calls.
_INFER_LOCK = threading.Lock()


def _score_lines(text: str) -> AnalyseOut:
    """Score every non-blank line of *text* and assemble the analyse response.

    Blocking: runs tokenisation and ensemble forward passes, so it is meant to
    be called off the event loop.
    """
    lines = _tidy(text).split("\n")
    html_parts: List[str] = []
    per: List[Line] = []
    h_sum = ai_sum = n = 0.0
    for ln in lines:
        if not ln.strip():
            html_parts.append("<br>")  # blank lines become an extra <br> in the HTML
            continue
        n += 1
        with _INFER_LOCK:
            human, ai, top3 = _infer(ln)
        h_sum += human
        ai_sum += ai
        is_ai = ai > human
        cls = "ai-line" if is_ai else "human-line"
        tip = f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if is_ai else f"Human {human:.2f}%"
        # The tooltip sits inside a single-quoted attribute: escape it like the text.
        html_parts.append(f"<span class='{cls} prob-tooltip' title='{html.escape(tip)}'>{html.escape(ln)}</span>")
        reason = (f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if is_ai else f"Lexical variety suggests human ({human:.1f}%)")
        per.append(Line(text=ln, ai=ai, human=human, top3=top3, reason=reason))
    # Averages fall back to 0 when the text contained no non-blank lines.
    human_avg = h_sum / n if n else 0
    ai_avg = ai_sum / n if n else 0
    verdict = "AI-generated" if ai_avg > human_avg else "Human-written"
    conf = max(human_avg, ai_avg)
    badge = (f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict == "AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
    html_out = f"<h3>{badge}</h3><hr>" + "<br>".join(html_parts)
    return AnalyseOut(verdict=verdict, confidence=conf, ai_avg=ai_avg, human_avg=human_avg, per_line=per, highlight_html=html_out)


@app.post("/analyse", response_model=AnalyseOut)
async def analyse(body: AnalyseIn, _=Depends(_verify_jwt)):
    # Requires a valid bearer token (checked by _verify_jwt). Scoring is
    # blocking CPU/GPU work; running it in the thread pool keeps the event loop
    # (and every other request) responsive during the forward passes.
    return await run_in_threadpool(_score_lines, body.text)
# ─────────────────────────── entrypoint ────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Listen on $PORT when provided, otherwise on 7860.
    port = int(os.environ.get("PORT", "7860"))
    # Serve the already-built ``app`` object. The import string "app:app" made
    # uvicorn import this file a second time as module "app", repeating the whole
    # model bootstrap (and failing outright if the file is not named app.py).
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info", reload=False)
|