File size: 2,815 Bytes
b59d3a6 e9f0d54 b59d3a6 e9f0d54 b59d3a6 e9f0d54 b59d3a6 e9f0d54 b59d3a6 e9f0d54 b59d3a6 e9f0d54 b59d3a6 e9f0d54 b59d3a6 e9f0d54 b59d3a6 e9f0d54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import HTTPBearer
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
import torch
import asyncio
from contextlib import asynccontextmanager
# The FastAPI application object is created further down, once `lifespan`
# has been defined (`app = FastAPI(lifespan=lifespan)`); building a second,
# immediately discarded instance here served no purpose.

# Model and tokenizer shared by the request handlers. Both stay None until
# the lifespan hook rebinds them to the objects returned by load_model().
model, tokenizer = None, None

# HTTPBearer instance used as the security dependency of the /analyze route
bearer_scheme = HTTPBearer()
# Function to load model and tokenizer
def load_model(
    model_path: str = "./Ai-Text-Detector/model",
    weights_path: str = "./Ai-Text-Detector/model_weights.pth",
):
    """Build the GPT-2 language model and its tokenizer from files on disk.

    The tokenizer and the model config are read from ``model_path``. A
    ``GPT2LMHeadModel`` is then constructed from that config, filled with the
    state dict stored at ``weights_path`` (tensors mapped to the CPU) and
    switched to eval mode.

    Args:
        model_path: Location handed to ``from_pretrained`` for both the
            tokenizer and the config. Defaults to the bundled model directory.
        weights_path: File handed to ``torch.load`` to obtain the state dict.
            Defaults to the bundled weights file.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: If any loading step fails. The underlying exception is
            attached as ``__cause__`` so the real failure is not lost.
    """
    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        # NOTE(review): torch.load unpickles the file, so weights_path must
        # only ever point at a trusted checkpoint. Consider passing
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:
        # Chain explicitly so the traceback shows the original error as the cause.
        raise RuntimeError(f"Error loading model: {e}") from e
    return model, tokenizer
# Load model on app startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the module-level ``model`` and ``tokenizer`` before yielding.

    Calls ``load_model()`` once, rebinds the two globals to its result and
    then yields control. Nothing is executed after the ``yield``.
    """
    global model, tokenizer
    loaded_model, loaded_tokenizer = load_model()
    model = loaded_model
    tokenizer = loaded_tokenizer
    yield
# Create the application and register `lifespan` as its lifespan handler,
# which is where load_model() gets called.
app = FastAPI(lifespan=lifespan)
# Input schema
class TextInput(BaseModel):
    """Request body accepted by the /analyze route."""

    # Raw text to classify; the route strips surrounding whitespace itself.
    text: str
# Sync text classification
def classify_text(
    sentence: str,
    *,
    ai_threshold: float = 60,
    likely_ai_threshold: float = 80,
):
    """Label ``sentence`` by the perplexity the loaded model assigns to it.

    The sentence is tokenized and run through the module-level ``model`` with
    the input ids also supplied as labels. Perplexity is the exponential of
    the loss the model reports for that call.

    Args:
        sentence: Text to score.
        ai_threshold: Perplexities strictly below this value are labelled
            "AI-generated".
        likely_ai_threshold: Perplexities strictly below this value (and not
            below ``ai_threshold``) are labelled "Probably AI-generated".

    Returns:
        A ``(label, perplexity)`` tuple, where ``label`` is one of
        "AI-generated", "Probably AI-generated" or "Human-written" and
        ``perplexity`` is a Python float.

    Raises:
        RuntimeError: If the module-level model or tokenizer is still None,
            i.e. nothing has loaded them yet.
    """
    # Fail with a clear message instead of "'NoneType' object is not callable".
    if model is None or tokenizer is None:
        raise RuntimeError("Model and tokenizer are not loaded")

    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Inference only: no gradients are needed to read the loss.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)

    perplexity = torch.exp(outputs.loss).item()

    if perplexity < ai_threshold:
        result = "AI-generated"
    elif perplexity < likely_ai_threshold:
        result = "Probably AI-generated"
    else:
        result = "Human-written"
    return result, perplexity
# POST route to analyze text with Bearer token
@app.post("/analyze")
async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
    """Validate the submitted text and return its classification.

    Responds with HTTP 400 when the stripped text is empty or contains fewer
    than two whitespace-separated words. Otherwise ``classify_text`` is run
    in a worker thread and its label is returned together with the
    perplexity rounded to two decimals.

    NOTE(review): ``token`` is required by the ``bearer_scheme`` dependency
    but its value is never inspected in this function.
    """
    cleaned = data.text.strip()
    if cleaned == "":
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # A single word is rejected; at least two are needed.
    words = cleaned.split()
    if len(words) <= 1:
        raise HTTPException(status_code=400, detail="Text must contain at least two words")

    # classify_text is synchronous, so it is handed to a thread.
    label, score = await asyncio.to_thread(classify_text, cleaned)
    return {"result": label, "perplexity": round(score, 2)}
# Health check route
@app.get("/health")
async def health_check():
    """Return a fixed payload reporting status "ok"."""
    payload = dict(status="ok")
    return payload
# Simple index route
@app.get("/")
def index():
    """Return a fixed greeting payload that points the caller at /docs."""
    # Keys are inserted in the same order as the original literal.
    payload = {"message": "FastAPI API is up."}
    payload["try"] = "/docs to test the API."
    payload["status"] = "OK"
    return payload
|