from __future__ import annotations
import os, re, html
from datetime import datetime, timedelta, timezone
from typing import List

import torch
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from huggingface_hub import hf_hub_download
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from jose import jwt, JWTError
from pydantic import BaseModel, Field

# ─────────────────────────  torch shim  ───────────────────────────────
# No-op torch.compile (return the model unchanged, or act as an identity
# decorator) so the app also starts on hosts where TorchInductor is
# unavailable or broken.
if hasattr(torch, "compile"):
    torch.compile = (lambda m=None, *_, **__: m if callable(m) else (lambda f: f))  # type: ignore
    os.environ.setdefault("TORCHINDUCTOR_DISABLED", "1")

# ───────────────────────  remote‑code flag  ───────────────────────────
os.environ.setdefault("HF_ALLOW_CODE_IMPORT", "1")
TOKEN_KW = {"trust_remote_code": True}  # shared kwargs for every from_pretrained call

# ───────────────────────────  config  ─────────────────────────────────
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WEIGHT_REPO = "Sleepyriizi/Orify-Text-Detection-Weights"
FILE_MAP    = {  # NB: "ensamble" spelling matches the file names in the weights repo
    "ensamble_1": "ensamble_1",
    "ensamble_2.bin": "ensamble_2.bin",
    "ensamble_3": "ensamble_3",
}
BASE_MODEL  = "answerdotai/ModernBERT-base"
NUM_LABELS  = 41
LABELS      = {i:n for i,n in enumerate([
    "13B","30B","65B","7B","GLM130B","bloom_7b","bloomz","cohere","davinci","dolly","dolly-v2-12b",
    "flan_t5_base","flan_t5_large","flan_t5_small","flan_t5_xl","flan_t5_xxl","gemma-7b-it","gemma2-9b-it",
    "gpt-3.5-turbo","gpt-35","gpt-4","gpt-4o","gpt-j","gpt-neox","human","llama3-70b","llama3-8b",
    "mixtral-8x7b","opt-1.3b","opt-125m","opt-13b","opt-2.7b","opt-30b","opt-350m","opt-6.7b",
    "opt-iml-30b","opt-iml-max-1.3b","t0-11b","t0-3b","text-davinci-002","text-davinci-003"])}

# ────────────────────────  JWT helpers  ──────────────────────────────
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    raise RuntimeError("SECRET_KEY env‑var not set – add it in Space settings → Secrets")
ALG="HS256"; EXP=24
oauth2 = OAuth2PasswordBearer(tokenUrl="token")

def _make_jwt(sub: str) -> str:
    payload = {"sub": sub, "exp": datetime.now(timezone.utc) + timedelta(hours=EXP)}
    return jwt.encode(payload, SECRET_KEY, algorithm=ALG)

def _verify_jwt(tok: str = Depends(oauth2)) -> str:
    try:
        return jwt.decode(tok, SECRET_KEY, algorithms=[ALG])["sub"]
    except JWTError:
        raise HTTPException(status_code=401, detail="Invalid or expired token")

# ───────────────────────  model bootstrap  ───────────────────────────
print("🔄 Fetching ensemble weights…", flush=True)
paths={k:hf_hub_download(WEIGHT_REPO,f,resume_download=True) for k,f in FILE_MAP.items()}

print("🧩 Building ModernBERT backbone…", flush=True)
_cfg = AutoConfig.from_pretrained(BASE_MODEL, **TOKEN_KW); _cfg.num_labels = NUM_LABELS
_tok = AutoTokenizer.from_pretrained(BASE_MODEL, **TOKEN_KW)
_models: List[AutoModelForSequenceClassification] = []
for p in paths.values():
    m = AutoModelForSequenceClassification.from_pretrained(
            BASE_MODEL,
            config=_cfg,
            ignore_mismatched_sizes=True,
            **TOKEN_KW,
        )
    m.load_state_dict(torch.load(p, map_location=DEVICE))
    m.to(DEVICE).eval()
    _models.append(m)
print(f"✅ Ensemble ready on {DEVICE}")

# ─────────────────────────  helpers  ─────────────────────────────────

def _tidy(t: str) -> str:
    """Normalise whitespace so the text splits cleanly into paragraphs."""
    t = t.replace("\r\n", "\n").replace("\r", "\n")   # unify line endings
    t = re.sub(r"\n\s*\n+", "\n\n", t)                # collapse runs of blank lines
    t = re.sub(r"[ \t]+", " ", t)                     # collapse spaces/tabs
    t = re.sub(r"(\w+)-\n(\w+)", r"\1\2", t)          # re-join words hyphenated across lines
    t = re.sub(r"(?<!\n)\n(?!\n)", " ", t)            # lone newline -> space
    return t.strip()
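
# For instance, _tidy("co-\noperate with\nothers\n\nnext para") yields
# "cooperate with others\n\nnext para" – one line per paragraph.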

def _infer(seg: str):
    """Score one segment; returns (human %, AI %, top-3 suspected AI models)."""
    inp = _tok(seg, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    with torch.no_grad():
        # average the per-model softmax distributions across the ensemble
        probs = torch.stack([torch.softmax(m(**inp).logits, 1) for m in _models]).mean(0)[0]
    ai_probs = probs.clone()
    ai_probs[24] = 0                                   # zero out the "human" class (index 24)
    ai = ai_probs.sum().item() * 100                   # remaining mass = AI probability
    human = 100 - ai
    top3 = [LABELS[i] for i in torch.topk(ai_probs, 3).indices.tolist()]
    return human, ai, top3
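
# Illustrative call (hypothetical numbers, not real model output):
#   human, ai, top3 = _infer("The mitochondria is the powerhouse of the cell.")
#   # human ≈ 88.4, ai ≈ 11.6, top3 like ["gpt-4o", "llama3-8b", "opt-13b"]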

# ─────────────────────────  schemas  ─────────────────────────────────
class TokenOut(BaseModel): access_token:str; token_type:str="bearer"
class AnalyseIn(BaseModel): text:str=Field(...,min_length=1)
class Line(BaseModel): text:str; ai:float; human:float; top3:List[str]; reason:str
class AnalyseOut(BaseModel): verdict:str; confidence:float; ai_avg:float; human_avg:float; per_line:List[Line]; highlight_html:str

# ─────────────────────────  FastAPI app  ─────────────────────────────
app = FastAPI(title="Orify Text Detector API", version="1.2.0")
# Wide-open CORS is convenient for a demo Space; restrict origins in production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

@app.post("/token",response_model=TokenOut)
async def token(form:OAuth2PasswordRequestForm=Depends()):
    return TokenOut(access_token=_make_jwt(form.username))

@app.post("/analyse",response_model=AnalyseOut)
async def analyse(body:AnalyseIn,_=Depends(_verify_jwt)):
    lines=_tidy(body.text).split("\n"); html_parts=[]; per=[]; h_sum=ai_sum=n=0.0
    for ln in lines:
        if not ln.strip():
            html_parts.append("<br>"); continue
        n+=1; human,ai,top3=_infer(ln); h_sum+=human; ai_sum+=ai
        cls="ai-line" if ai>human else "human-line"
        tip=f"AI {ai:.2f}% – Top-3: {', '.join(top3)}" if ai>human else f"Human {human:.2f}%"
        html_parts.append(f"<span class='{cls} prob-tooltip' title='{tip}'>{html.escape(ln)}</span>")
        reason=(f"High AI likelihood ({ai:.1f}%) – fingerprint ≈ {top3[0]}" if ai>human else f"Lexical variety suggests human ({human:.1f}%)")
        per.append(Line(text=ln,ai=ai,human=human,top3=top3,reason=reason))
    human_avg=h_sum/n if n else 0; ai_avg=ai_sum/n if n else 0
    verdict="AI-generated" if ai_avg>human_avg else "Human-written"; conf=max(human_avg,ai_avg)
    badge=(f"<span class='ai-line' style='padding:6px 10px;font-weight:bold'>AI-generated {ai_avg:.2f}%</span>" if verdict=="AI-generated" else f"<span class='human-line' style='padding:6px 10px;font-weight:bold'>Human-written {human_avg:.2f}%</span>")
    html_out=f"<h3>{badge}</h3><hr>"+"<br>".join(html_parts)
    return AnalyseOut(verdict=verdict,confidence=conf,ai_avg=ai_avg,human_avg=human_avg,per_line=per,highlight_html=html_out)
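
# Response shape (values illustrative):
#   {"verdict": "Human-written", "confidence": 91.3, "ai_avg": 8.7, "human_avg": 91.3,
#    "per_line": [{"text": "...", "ai": 8.7, "human": 91.3, "top3": [...], "reason": "..."}],
#    "highlight_html": "<h3>...</h3><hr>..."}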

# ───────────────────────────  entrypoint  ────────────────────────────
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", "7860"))   # HF Spaces expose 7860 by default
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info", reload=False)
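
# ───────────────────────────  usage sketch  ──────────────────────────
# A minimal client flow, assuming the server runs locally on port 7860
# (any username/password works – see the /token stub above):
#
#   curl -s -X POST http://localhost:7860/token \
#        -d "username=demo&password=demo"
#   # -> {"access_token": "<jwt>", "token_type": "bearer"}
#
#   curl -s -X POST http://localhost:7860/analyse \
#        -H "Authorization: Bearer <jwt>" \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Paste the passage to analyse here."}'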