Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Depends, UploadFile, File | |
| from fastapi.security import HTTPBearer | |
| from pydantic import BaseModel | |
| from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config | |
| import torch | |
| import os | |
| import asyncio | |
| from contextlib import asynccontextmanager | |
| import logging | |
| from io import BytesIO | |
| import docx | |
| import fitz # PyMuPDF | |
| # Load environment variables | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# Shared secret that clients must present as a Bearer token. It is read from
# the environment, so it is None when SECRET_TOKEN is not set.
SECRET_TOKEN = os.getenv("SECRET_TOKEN")
bearer_scheme = HTTPBearer()

# Ai-Text-Detector: location of the tokenizer/config directory and of the
# fine-tuned weights loaded by load_model().
MODEL_PATH = "./Ai-Text-Detector/model"
WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"

# NOTE: no FastAPI instance is created here. The application object is built
# once, further down, as `app = FastAPI(lifespan=lifespan)`; an extra
# `app = FastAPI()` at this point would be discarded by that reassignment.

# Global model and tokenizer variables; populated by load_model().
model, tokenizer = None, None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging setup
logging.basicConfig(level=logging.DEBUG)
# Load model and tokenizer function
def load_model():
    """Load the tokenizer and model weights into the module-level globals.

    The tokenizer and the GPT-2 config are read from ``MODEL_PATH``; a
    ``GPT2LMHeadModel`` is built from that config, filled with the state dict
    stored at ``WEIGHTS_PATH``, moved to ``device`` and put in eval mode.

    The globals ``model`` and ``tokenizer`` are only assigned once every step
    has succeeded, so a failed load never leaves one of them set without the
    other.

    Raises:
        RuntimeError: If any loading step fails. The original exception is
            chained as ``__cause__``.
    """
    global model, tokenizer
    try:
        loaded_tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
        config = GPT2Config.from_pretrained(MODEL_PATH)
        model_instance = GPT2LMHeadModel(config)
        # NOTE(review): torch.load unpickles the file, so WEIGHTS_PATH must
        # come from a trusted source. Consider weights_only=True if the
        # installed torch version supports it.
        state_dict = torch.load(WEIGHTS_PATH, map_location=device)
        model_instance.load_state_dict(state_dict)
        model_instance.to(device)
        model_instance.eval()
    except Exception as e:
        logging.error(f"Error loading model: {str(e)}")
        raise RuntimeError(f"Error loading model: {str(e)}") from e

    # Publish both objects together, after the whole load succeeded.
    model, tokenizer = model_instance, loaded_tokenizer
    logging.info("Model loaded successfully.")
# Load model on app startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: load the model before serving requests.

    Everything before ``yield`` runs at startup; nothing is done at shutdown.
    The function is wrapped with ``asynccontextmanager`` because the
    ``lifespan`` argument expects an async context manager factory.

    Args:
        app: The FastAPI application being started (unused here).
    """
    load_model()  # Load model when FastAPI app starts
    yield


# Attach the lifespan to the app instance
app = FastAPI(lifespan=lifespan)
# Request body schema used by the text-analysis handler
class TextInput(BaseModel):
    """Payload carrying the text that should be classified."""

    # Raw text supplied by the client.
    text: str
# Function to classify text using the model
def classify_text(text: str, *, ai_threshold: float = 60, probable_ai_threshold: float = 80):
    """Classify ``text`` by the perplexity the loaded model assigns to it.

    The text is tokenized (truncated to 512 tokens), scored by the model with
    the input ids as labels, and the exponential of the returned loss is used
    as the perplexity.

    Args:
        text: Text to score.
        ai_threshold: Perplexity below this value is labelled
            ``"AI-generated"``.
        probable_ai_threshold: Perplexity below this value (and not below
            ``ai_threshold``) is labelled ``"Probably AI-generated"``.

    Returns:
        A ``(label, perplexity)`` tuple where ``label`` is one of
        ``"AI-generated"``, ``"Probably AI-generated"`` or ``"Human-written"``.

    Raises:
        RuntimeError: If the model or tokenizer has not been loaded.
        ValueError: If the text yields fewer than two tokens, or the computed
            perplexity is NaN.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model or tokenizer not loaded.")

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # With fewer than two tokens there is no next-token prediction to score,
    # so the loss would not be a usable number. Fail loudly instead of letting
    # a NaN fall through the comparisons below as "Human-written".
    if input_ids.shape[-1] < 2:
        raise ValueError("Text is too short to compute perplexity.")

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    perplexity = torch.exp(outputs.loss).item()

    # NaN compares False against every threshold; reject it explicitly.
    if perplexity != perplexity:
        raise ValueError("Perplexity could not be computed for this text.")

    if perplexity < ai_threshold:
        return "AI-generated", perplexity
    if perplexity < probable_ai_threshold:
        return "Probably AI-generated", perplexity
    return "Human-written", perplexity
# POST route to analyze text with Bearer token
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
    """Validate the bearer token and the submitted text, then classify it.

    Args:
        data: Request body containing the text to analyze.
        token: Credentials object produced by ``bearer_scheme``; its
            ``credentials`` attribute is compared against ``SECRET_TOKEN``.

    Returns:
        ``{"result": <label>, "perplexity": <float rounded to 2 places>}``.

    Raises:
        HTTPException: 401 when the token does not match (or no secret is
            configured), 400 when the text is empty or has fewer than two
            words, 500 when classification fails.
    """
    import secrets  # local import: constant-time comparison helper

    # Verify token. compare_digest avoids leaking, through response timing,
    # how many leading characters of the secret were guessed correctly.
    expected = SECRET_TOKEN or ""
    supplied = token.credentials or ""
    if not expected or not secrets.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    text = data.text.strip()

    # Input validation
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(text.split()) < 2:
        raise HTTPException(status_code=400, detail="Text must contain at least two words")

    try:
        # Classify text in a worker thread so the event loop is not blocked.
        label, perplexity = await asyncio.to_thread(classify_text, text)
    except Exception as e:
        logging.error(f"Error processing text: {str(e)}")
        raise HTTPException(status_code=500, detail="Model processing error") from e
    return {"result": label, "perplexity": round(perplexity, 2)}
# Extract the text of a .docx document
def parse_docx(file: BytesIO):
    """Return the text of every paragraph in ``file``, each followed by a newline."""
    document = docx.Document(file)
    return "".join(paragraph.text + "\n" for paragraph in document.paragraphs)
# Function to parse .pdf files
def parse_pdf(file: BytesIO):
    """Return the concatenated text of every page of the PDF in ``file``.

    The document is opened as a context manager so it is closed even when
    text extraction fails part-way through.

    Args:
        file: In-memory buffer holding the PDF bytes.

    Returns:
        The text of all pages, in page order, joined without a separator.

    Raises:
        HTTPException: 500 if the PDF cannot be opened or read. The original
            exception is chained as the cause.
    """
    try:
        with fitz.open(stream=file, filetype="pdf") as doc:
            return "".join(
                doc.load_page(page_num).get_text()
                for page_num in range(doc.page_count)
            )
    except Exception as e:
        logging.error(f"Error while processing PDF: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing PDF file") from e
# Function to parse .txt files
def parse_txt(file: BytesIO) -> str:
    """Decode the bytes in ``file`` as UTF-8 text.

    ``utf-8-sig`` is used so that a leading byte-order mark, if present, is
    dropped instead of being returned as a U+FEFF character at the start of
    the text. Input without a BOM decodes exactly as plain UTF-8.

    Args:
        file: In-memory buffer holding the file bytes.

    Returns:
        The decoded text.

    Raises:
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    return file.read().decode("utf-8-sig")
# POST route to upload files and analyze content
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def upload_file(file: UploadFile = File(...), token: str = Depends(bearer_scheme)):
    """Extract text from an uploaded .docx, .pdf or .txt file and classify it.

    Args:
        file: The uploaded file; its ``content_type`` selects the parser.
        token: Credentials object produced by ``bearer_scheme``; its
            ``credentials`` attribute is compared against ``SECRET_TOKEN``.

    Returns:
        ``{"result": <label>, "perplexity": <float rounded to 2 places>}``, or
        ``{"message": ...}`` when the extracted text exceeds 10,000
        characters.

    Raises:
        HTTPException: 401 when the token does not match (or no secret is
            configured), 400 for an unsupported content type or for text that
            is empty or shorter than two words, 500 for any other failure.
    """
    import secrets  # local import: constant-time comparison helper

    # Verify token, the same check the text-analysis handler performs. Without
    # it, any bearer value would be accepted.
    expected = SECRET_TOKEN or ""
    supplied = token.credentials or ""
    if not expected or not secrets.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    parsers = {
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document': parse_docx,
        'application/pdf': parse_pdf,
        'text/plain': parse_txt,
    }

    try:
        parser = parsers.get(file.content_type)
        if parser is None:
            raise HTTPException(status_code=400, detail="Invalid file type. Only .docx, .pdf, and .txt are allowed.")

        file_contents = parser(BytesIO(await file.read()))

        # Log only the size: the document body is user content and does not
        # belong in the application log.
        logging.debug(f"Extracted {len(file_contents)} characters from {file.filename}")

        # Check if the text length exceeds 10,000 characters
        if len(file_contents) > 10000:
            return {"message": "File contains more than 10,000 characters."}

        # Collapse newlines, tabs and repeated spaces into single spaces.
        # Deleting them outright would glue the last word of one line to the
        # first word of the next.
        cleaned_text = " ".join(file_contents.split())

        # Same input validation as the text-analysis handler.
        if not cleaned_text:
            raise HTTPException(status_code=400, detail="Text cannot be empty")
        if len(cleaned_text.split()) < 2:
            raise HTTPException(status_code=400, detail="Text must contain at least two words")

        # Analyze the cleaned text in a worker thread.
        label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
        return {"result": label, "perplexity": round(perplexity, 2)}
    except HTTPException:
        # Deliberate HTTP errors (400/500 raised above or by a parser) must
        # reach the client unchanged rather than being rewritten as a 500.
        raise
    except Exception as e:
        logging.error(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing the file") from e
# Liveness probe handler
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
async def health_check():
    """Report that the service is up by returning a fixed status payload."""
    payload = dict(status="ok")
    return payload
# Landing handler describing the service
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `app`.
def index():
    """Return a short banner and a usage hint for the API."""
    banner = "FastAPI AI Text Detector is running."
    hint = "Use /docs or /analyze to test the API."
    return {"message": banner, "usage": hint}