|
import asyncio
from contextlib import asynccontextmanager

import torch
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
|
|
|
|
|
# Model and tokenizer are assigned by the `lifespan` startup hook; they stay
# None until the application has started.
# NOTE: the FastAPI `app` is intentionally not created here. It is built once,
# with the lifespan handler attached, so that every route registers on that
# single instance (a throwaway instance here would just be discarded).
model, tokenizer = None, None

# Dependency that requires an "Authorization: Bearer <token>" header on the
# routes that use it.
bearer_scheme = HTTPBearer()
|
|
|
|
|
def load_model(
    model_path: str = "./Ai-Text-Detector/model",
    weights_path: str = "./Ai-Text-Detector/model_weights.pth",
):
    """Build the GPT-2 language model and tokenizer used for perplexity scoring.

    The tokenizer and model config are read from ``model_path``. A fresh
    ``GPT2LMHeadModel`` is built from that config, its weights are loaded from
    ``weights_path`` onto the CPU, and the model is switched to eval mode.

    Args:
        model_path: Directory passed to ``from_pretrained`` for the tokenizer
            and the model config.
        weights_path: File holding the state dict read with ``torch.load``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: If any loading step fails. The underlying exception is
            attached as ``__cause__`` so the original traceback is preserved.
    """
    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only point weights_path at trusted files; consider
        # weights_only=True if the deployed torch version supports it.
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        # Inference only: disables dropout and other training-time behaviour.
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e

    return model, tokenizer
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the detector before the app serves requests and publish it globally.

    Everything before ``yield`` runs at startup; nothing runs at shutdown.
    The loaded objects are stored in the module-level ``model`` and
    ``tokenizer`` names, which ``classify_text`` reads.
    """
    global model, tokenizer
    loaded_model, loaded_tokenizer = load_model()
    model = loaded_model
    tokenizer = loaded_tokenizer
    yield


# The single application instance, with the startup hook attached; all routes
# below register on it.
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
class TextInput(BaseModel):
    """JSON request body for ``POST /analyze``."""

    # Raw text to classify; the endpoint strips it and requires at least two words.
    text: str
|
|
|
|
|
def classify_text(
    sentence: str,
    *,
    ai_threshold: float = 60,
    likely_ai_threshold: float = 80,
):
    """Score ``sentence`` with the loaded GPT-2 model and label it by perplexity.

    Perplexity is ``exp`` of the mean next-token cross-entropy the model
    assigns to the text. Lower perplexity (text the model finds predictable)
    is treated as evidence of machine generation.

    Args:
        sentence: Text to score. It must tokenize to at least two tokens.
        ai_threshold: Perplexity strictly below this is labelled
            "AI-generated".
        likely_ai_threshold: Perplexity below this, but not below
            ``ai_threshold``, is labelled "Probably AI-generated". Anything
            else is "Human-written".

    Returns:
        A ``(label, perplexity)`` tuple, where ``perplexity`` is a Python float.

    Raises:
        RuntimeError: If the model/tokenizer globals have not been loaded yet.
        ValueError: If the text yields fewer than two tokens, in which case
            there is no next-token prediction to score.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model is not loaded; call load_model() first")

    # A single sequence never needs padding. The stock GPT-2 tokenizer also
    # defines no pad token, and asking it to pad raises a ValueError.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # The causal-LM loss compares each position with the *next* token. With
    # fewer than two tokens nothing is predicted and the mean loss is NaN,
    # which cannot be returned as JSON.
    if input_ids.shape[-1] < 2:
        raise ValueError("Text must contain at least two tokens to compute perplexity")

    with torch.no_grad():
        # Passing the inputs as labels makes the model return the LM loss.
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)

    perplexity = torch.exp(outputs.loss).item()

    if perplexity < ai_threshold:
        result = "AI-generated"
    elif perplexity < likely_ai_threshold:
        result = "Probably AI-generated"
    else:
        result = "Human-written"

    return result, perplexity
|
|
|
|
|
@app.post("/analyze")
async def analyze_text(
    data: TextInput,
    token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    """Classify submitted text as AI-generated or human-written.

    Args:
        data: Request body carrying the text to analyse.
        token: Bearer credentials produced by the ``HTTPBearer`` dependency.

    Returns:
        ``{"result": <label>, "perplexity": <value rounded to 2 decimals>}``.

    Raises:
        HTTPException: 400 if the text is blank or has fewer than two words.
    """
    # NOTE(review): HTTPBearer only checks that an "Authorization: Bearer ..."
    # header is present. token.credentials is never compared against anything,
    # so any token is accepted. Add real verification before relying on this
    # for access control.
    user_input = data.text.strip()

    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    word_count = len(user_input.split())
    if word_count < 2:
        raise HTTPException(status_code=400, detail="Text must contain at least two words")

    # Model inference is blocking; run it in a worker thread so the event
    # loop keeps serving other requests meanwhile.
    result, perplexity = await asyncio.to_thread(classify_text, user_input)

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report that the server process is up; the model is not inspected."""
    status = {"status": "ok"}
    return status
|
|
|
|
|
@app.get("/")
def index():
    """Root route: confirm the API is running and point callers at /docs."""
    payload = {"message": "FastAPI API is up."}
    payload["try"] = "/docs to test the API."
    payload["status"] = "OK"
    return payload