feat: add file parsing support and enforce 10,000-character limit
- Added support for analyzing .docx, .pdf, and .txt files via the /upload/ endpoint (see the example request below)
- Implemented file parsing logic using python-docx, PyMuPDF (fitz), and basic text decoding
- Enforced a 10,000-character limit to prevent model overload on large files
- Cleaned and sanitized file contents before analysis to improve model accuracy
- Added logging to debug file parsing and analysis results
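
For reference, the new endpoint can be exercised as follows. This is a minimal sketch, not part of the commit: it assumes the API runs locally on port 8000, that the requests library is installed, and that the token matches SECRET_TOKEN from .env; the sample file name is hypothetical.

import requests

API_URL = "http://localhost:8000"  # assumed local dev server
TOKEN = "your-secret-token"        # must match SECRET_TOKEN in .env

# Upload a .docx file to the new /upload/ endpoint for analysis
with open("essay.docx", "rb") as f:  # hypothetical sample file
    response = requests.post(
        f"{API_URL}/upload/",
        headers={"Authorization": f"Bearer {TOKEN}"},
        files={"file": ("essay.docx", f,
                        "application/vnd.openxmlformats-officedocument.wordprocessingml.document")},
    )
print(response.json())
# A file within the 10,000-character limit yields {"result": ..., "perplexity": ...};
# a longer file yields the "more than 10,000 characters" message instead.

The explicit MIME type in the files tuple matters: upload_file branches on file.content_type, so a missing or wrong type triggers the 400 "Invalid file type" response.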
app.py
CHANGED
from fastapi import FastAPI, HTTPException, Depends, UploadFile, File
from fastapi.security import HTTPBearer
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
import torch
import os
import asyncio
from contextlib import asynccontextmanager
import logging
from io import BytesIO
import docx
import fitz  # PyMuPDF

# Load environment variables
from dotenv import load_dotenv
load_dotenv()

SECRET_TOKEN = os.getenv("SECRET_TOKEN")
bearer_scheme = HTTPBearer()

# Ai-Text-Detector model and weights paths
MODEL_PATH = "./Ai-Text-Detector/model"
WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"

# FastAPI app instance
app = FastAPI()

# Global model and tokenizer variables
model, tokenizer = None, None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging setup
logging.basicConfig(level=logging.DEBUG)

# Load model and tokenizer function
def load_model():
    global model, tokenizer
    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
        config = GPT2Config.from_pretrained(MODEL_PATH)
        model_instance = GPT2LMHeadModel(config)
        model_instance.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
        model_instance.to(device)
        model_instance.eval()
        model = model_instance
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {str(e)}")
        raise RuntimeError(f"Error loading model: {str(e)}")

# Load model on app startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    load_model()  # Load model when FastAPI app starts
    yield

# Attach the lifespan to the app instance
app = FastAPI(lifespan=lifespan)

# Input schema for text analysis
class TextInput(BaseModel):
    text: str

# Function to classify text using the model
def classify_text(text: str):
    if not model or not tokenizer:
        raise RuntimeError("Model or tokenizer not loaded.")

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    loss = outputs.loss
    perplexity = torch.exp(loss).item()

    if perplexity < 60:
        return "AI-generated", perplexity
    elif perplexity < 80:
        return "Probably AI-generated", perplexity
    else:
        return "Human-written", perplexity

# POST route to analyze text with Bearer token
@app.post("/analyze")
async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
    # Verify token
    if token.credentials != SECRET_TOKEN:
        raise HTTPException(status_code=401, detail="Invalid token")

    text = data.text.strip()

    # Input validation
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if len(text.split()) < 2:
        raise HTTPException(status_code=400, detail="Text must contain at least two words")

    try:
        # Classify text
        label, perplexity = await asyncio.to_thread(classify_text, text)
        return {"result": label, "perplexity": round(perplexity, 2)}
    except Exception as e:
        logging.error(f"Error processing text: {str(e)}")
        raise HTTPException(status_code=500, detail="Model processing error")

# Function to parse .docx files
def parse_docx(file: BytesIO):
    doc = docx.Document(file)
    text = ""
    for para in doc.paragraphs:
        text += para.text + "\n"
    return text

# Function to parse .pdf files
def parse_pdf(file: BytesIO):
    try:
        doc = fitz.open(stream=file, filetype="pdf")
        text = ""
        for page_num in range(doc.page_count):
            page = doc.load_page(page_num)
            text += page.get_text()
        return text
    except Exception as e:
        logging.error(f"Error while processing PDF: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing PDF file")

# Function to parse .txt files
def parse_txt(file: BytesIO):
    return file.read().decode("utf-8")

# POST route to upload files and analyze content
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...), token: str = Depends(bearer_scheme)):
    file_contents = None
    try:
        if file.content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
            file_contents = parse_docx(BytesIO(await file.read()))
        elif file.content_type == 'application/pdf':
            file_contents = parse_pdf(BytesIO(await file.read()))
        elif file.content_type == 'text/plain':
            file_contents = parse_txt(BytesIO(await file.read()))
        else:
            raise HTTPException(status_code=400, detail="Invalid file type. Only .docx, .pdf, and .txt are allowed.")

        logging.debug(f"Extracted Text from {file.filename}:\n{file_contents}")

        # Check if the text length exceeds 10,000 characters
        if len(file_contents) > 10000:
            return {"message": "File contains more than 10,000 characters."}

        # Clean the text by replacing newline and tab characters with spaces
        cleaned_text = file_contents.replace("\n", " ").replace("\t", " ")

        # Analyze the cleaned text
        label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
        return {"result": label, "perplexity": round(perplexity, 2)}

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g., the 400 for unsupported file types)
        # so the generic handler below does not convert them into 500s
        raise
    except Exception as e:
        logging.error(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing the file")

# Health check route
@app.get("/health")
async def health_check():
    ...

@app.get("/")
def index():
    return {
        "message": "FastAPI AI Text Detector is running.",
        "usage": "Use /docs or /analyze to test the API."
    }
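
The existing /analyze route accepts raw text with the same Bearer scheme; a sketch under the same assumptions as the example above:

import requests

response = requests.post(
    "http://localhost:8000/analyze",
    headers={"Authorization": "Bearer your-secret-token"},
    json={"text": "This is a short sample passage to classify."},
)
print(response.json())  # e.g. {"result": "Human-written", "perplexity": 93.47}

Since perplexity is exp(cross-entropy loss), the thresholds in classify_text correspond to mean token losses of roughly ln 60 ≈ 4.09 and ln 80 ≈ 4.38.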