"""FastAPI service that labels text as AI-generated or human-written.

Input text is scored with a GPT-2 language model. The resulting perplexity
is mapped onto three labels; lower perplexity (text the model finds more
predictable) is treated as a sign of machine generation.
"""

import asyncio
import os
import secrets
from contextlib import asynccontextmanager

import torch
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast

# Default on-disk locations of the tokenizer/config directory and the
# fine-tuned state dict. Relative paths resolve against the process CWD.
MODEL_PATH = "./Ai-Text-Detector/model"
WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"

# Perplexity cut-offs: below AI_THRESHOLD -> "AI-generated", below
# PROBABLY_AI_THRESHOLD -> "Probably AI-generated", else "Human-written".
AI_THRESHOLD = 60
PROBABLY_AI_THRESHOLD = 80

# Environment variable holding the expected bearer token for /analyze.
# When it is unset or empty, any bearer token is accepted.
API_TOKEN_ENV_VAR = "API_BEARER_TOKEN"

# Populated by `lifespan` at startup; None until then.
model, tokenizer = None, None

# Requires an "Authorization: Bearer <token>" header and rejects requests
# that lack one. It does NOT check the token's value; see `_check_token`.
bearer_scheme = HTTPBearer()


def load_model(model_path: str = MODEL_PATH, weights_path: str = WEIGHTS_PATH):
    """Load the GPT-2 tokenizer and language model from disk.

    Args:
        model_path: Directory containing the tokenizer files and model config.
        weights_path: Path to a state dict saved with ``torch.save``.

    Returns:
        ``(model, tokenizer)``, with the model mapped to CPU and in eval mode.

    Raises:
        RuntimeError: If anything fails to load; the underlying exception is
            chained as ``__cause__``.
    """
    try:
        loaded_tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        loaded_model = GPT2LMHeadModel(config)
        # weights_only=True limits unpickling to tensors and plain
        # containers, so a tampered weights file cannot run arbitrary code.
        state_dict = torch.load(
            weights_path, map_location=torch.device("cpu"), weights_only=True
        )
        loaded_model.load_state_dict(state_dict)
        loaded_model.eval()
    except Exception as e:
        raise RuntimeError(f"Error loading model: {str(e)}") from e
    return loaded_model, loaded_tokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once at startup and publish it through module globals."""
    global model, tokenizer
    model, tokenizer = load_model()
    yield


app = FastAPI(lifespan=lifespan)


class TextInput(BaseModel):
    """Request body for /analyze."""

    text: str


def classify_text(sentence: str):
    """Score ``sentence`` with the language model and label it by perplexity.

    This is synchronous, blocking work; callers on the event loop should run
    it in a worker thread.

    Returns:
        ``(label, perplexity)``, where label is "AI-generated",
        "Probably AI-generated" or "Human-written".
    """
    # A single sequence never needs padding. Requesting it raises on a GPT-2
    # tokenizer without a pad token, and pad positions would otherwise be
    # scored as labels below and skew the loss.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        # Passing labels=input_ids makes the model return the language-model
        # loss over the sequence; exp(loss) is the perplexity.
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    perplexity = torch.exp(outputs.loss).item()

    if perplexity < AI_THRESHOLD:
        result = "AI-generated"
    elif perplexity < PROBABLY_AI_THRESHOLD:
        result = "Probably AI-generated"
    else:
        result = "Human-written"
    return result, perplexity


def _check_token(credentials: HTTPAuthorizationCredentials) -> None:
    """Reject the request if a token is configured and the caller's differs.

    Raises:
        HTTPException: 401 when API_BEARER_TOKEN is set and does not match.
    """
    expected = os.environ.get(API_TOKEN_ENV_VAR)
    if not expected:
        # No token configured: keep accepting any bearer token.
        return
    # Constant-time comparison; compare bytes so non-ASCII tokens are allowed.
    if not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid bearer token",
            headers={"WWW-Authenticate": "Bearer"},
        )


@app.post("/analyze")
async def analyze_text(
    data: TextInput,
    token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
    """Classify the posted text. Requires a Bearer token.

    Raises:
        HTTPException: 401 for a wrong token (when one is configured), and
            400 for empty text or text with fewer than two words.
    """
    _check_token(token)

    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Require at least two whitespace-separated words.
    if len(user_input.split()) < 2:
        raise HTTPException(
            status_code=400, detail="Text must contain at least two words"
        )

    # Model inference blocks; run it in a worker thread so the event loop
    # stays free for other requests.
    result, perplexity = await asyncio.to_thread(classify_text, user_input)
    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }


@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "ok"}


@app.get("/")
def index():
    """Landing route pointing users at the interactive docs."""
    return {
        "message": "FastAPI API is up.",
        "try": "/docs to test the API.",
        "status": "OK",
    }