Pujan-Dev committed
Commit bc07cfe · verified · Parent: 751cbd9

feat: add file parsing support and enforce 10,000-character limit


- Added support for analyzing .docx, .pdf, and .txt files via the /upload endpoint (see the client usage sketch after this list)
- Implemented file parsing using python-docx for Word documents, PyMuPDF (fitz) for PDFs, and plain UTF-8 decoding for text files
- Enforced a 10,000-character limit to prevent model overload on large files
- Cleaned file contents (stripping newlines and tabs) before analysis to improve model accuracy
- Added logging to aid debugging of file parsing and analysis results
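For context, a minimal client-side sketch of how the two endpoints could be exercised once this commit is deployed. It is illustrative only: the base URL, the token value, and sample.pdf are hypothetical, and the Bearer token must match the SECRET_TOKEN the server loads from its environment.

import requests

BASE_URL = "http://localhost:8000"  # hypothetical local deployment
HEADERS = {"Authorization": "Bearer my-secret-token"}  # must equal the server's SECRET_TOKEN

# Score raw text via /analyze (requires at least two words)
resp = requests.post(
    f"{BASE_URL}/analyze",
    json={"text": "Two or more words are required here."},
    headers=HEADERS,
)
print(resp.json())  # e.g. {"result": "Human-written", "perplexity": 91.3}

# Upload a file via /upload/ (the request's content type selects the parser)
with open("sample.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/upload/",
        files={"file": ("sample.pdf", f, "application/pdf")},
        headers=HEADERS,
    )
print(resp.json())

Files whose extracted text exceeds 10,000 characters are not analyzed; the endpoint returns a message instead.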

Files changed (1)
  1. app.py +116 -42
app.py CHANGED
@@ -1,55 +1,73 @@
-from fastapi import FastAPI, HTTPException, Depends
+from fastapi import FastAPI, HTTPException, Depends, UploadFile, File
 from fastapi.security import HTTPBearer
 from pydantic import BaseModel
 from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
 import torch
+import os
 import asyncio
 from contextlib import asynccontextmanager
+import logging
+from io import BytesIO
+import docx
+import fitz  # PyMuPDF
+
+# Load environment variables
+from dotenv import load_dotenv
+load_dotenv()
+
+SECRET_TOKEN = os.getenv("SECRET_TOKEN")
+bearer_scheme = HTTPBearer()
+# Ai-Text-Detector
+MODEL_PATH = "./Ai-Text-Detector/model"
+WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"
 
 # FastAPI app instance
 app = FastAPI()
 
 # Global model and tokenizer variables
 model, tokenizer = None, None
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# HTTPBearer instance for security
-bearer_scheme = HTTPBearer()
+# Logging setup
+logging.basicConfig(level=logging.DEBUG)
 
-# Function to load model and tokenizer
+# Load model and tokenizer function
 def load_model():
-    model_path = "./Ai-Text-Detector/model"
-    weights_path = "./Ai-Text-Detector/model_weights.pth"
-
+    global model, tokenizer
     try:
-        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
-        config = GPT2Config.from_pretrained(model_path)
-        model = GPT2LMHeadModel(config)
-        model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
-        model.eval()
+        tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
+        config = GPT2Config.from_pretrained(MODEL_PATH)
+        model_instance = GPT2LMHeadModel(config)
+        model_instance.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
+        model_instance.to(device)
+        model_instance.eval()
+        model, tokenizer = model_instance, tokenizer
+        logging.info("Model loaded successfully.")
     except Exception as e:
+        logging.error(f"Error loading model: {str(e)}")
         raise RuntimeError(f"Error loading model: {str(e)}")
 
-    return model, tokenizer
-
 # Load model on app startup
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    global model, tokenizer
-    model, tokenizer = load_model()
+    load_model()  # Load model when FastAPI app starts
    yield
 
-# Attach startup loader
+# Attach the lifespan to the app instance
 app = FastAPI(lifespan=lifespan)
 
-# Input schema
+# Input schema for text analysis
 class TextInput(BaseModel):
     text: str
 
-# Sync text classification
-def classify_text(sentence: str):
-    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
+# Function to classify text using the model
+def classify_text(text: str):
+    if not model or not tokenizer:
+        raise RuntimeError("Model or tokenizer not loaded.")
+
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
+    input_ids = inputs["input_ids"].to(device)
+    attention_mask = inputs["attention_mask"].to(device)
 
     with torch.no_grad():
         outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
@@ -57,34 +75,91 @@ def classify_text(sentence: str):
     perplexity = torch.exp(loss).item()
 
     if perplexity < 60:
-        result = "AI-generated"
+        return "AI-generated", perplexity
     elif perplexity < 80:
-        result = "Probably AI-generated"
+        return "Probably AI-generated", perplexity
     else:
-        result = "Human-written"
-
-    return result, perplexity
+        return "Human-written", perplexity
 
 # POST route to analyze text with Bearer token
 @app.post("/analyze")
 async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
-    user_input = data.text.strip()
+    # Verify token
+    if token.credentials != SECRET_TOKEN:
+        raise HTTPException(status_code=401, detail="Invalid token")
+
+    text = data.text.strip()
 
-    if not user_input:
+    # Input validation
+    if not text:
         raise HTTPException(status_code=400, detail="Text cannot be empty")
 
-    # Check if there are at least two words
-    word_count = len(user_input.split())
-    if word_count < 2:
+    if len(text.split()) < 2:
         raise HTTPException(status_code=400, detail="Text must contain at least two words")
 
+    try:
+        # Classify text
+        label, perplexity = await asyncio.to_thread(classify_text, text)
+        return {"result": label, "perplexity": round(perplexity, 2)}
+    except Exception as e:
+        logging.error(f"Error processing text: {str(e)}")
+        raise HTTPException(status_code=500, detail="Model processing error")
+
+# Function to parse .docx files
+def parse_docx(file: BytesIO):
+    doc = docx.Document(file)
+    text = ""
+    for para in doc.paragraphs:
+        text += para.text + "\n"
+    return text
+
+# Function to parse .pdf files
+def parse_pdf(file: BytesIO):
+    try:
+        doc = fitz.open(stream=file, filetype="pdf")
+        text = ""
+        for page_num in range(doc.page_count):
+            page = doc.load_page(page_num)
+            text += page.get_text()
+        return text
+    except Exception as e:
+        logging.error(f"Error while processing PDF: {str(e)}")
+        raise HTTPException(status_code=500, detail="Error processing PDF file")
 
-    result, perplexity = await asyncio.to_thread(classify_text, user_input)
-
-    return {
-        "result": result,
-        "perplexity": round(perplexity, 2),
-    }
+# Function to parse .txt files
+def parse_txt(file: BytesIO):
+    return file.read().decode("utf-8")
+
+# POST route to upload files and analyze content
+@app.post("/upload/")
+async def upload_file(file: UploadFile = File(...), token: str = Depends(bearer_scheme)):
+    file_contents = None
+    try:
+        if file.content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
+            file_contents = parse_docx(BytesIO(await file.read()))
+        elif file.content_type == 'application/pdf':
+            file_contents = parse_pdf(BytesIO(await file.read()))
+        elif file.content_type == 'text/plain':
+            file_contents = parse_txt(BytesIO(await file.read()))
+        else:
+            raise HTTPException(status_code=400, detail="Invalid file type. Only .docx, .pdf, and .txt are allowed.")
+
+        logging.debug(f"Extracted Text from {file.filename}:\n{file_contents}")
+
+        # Check if the text length exceeds 10,000 characters
+        if len(file_contents) > 10000:
+            return {"message": "File contains more than 10,000 characters."}
+
+        # Clean the text by removing newline and tab characters
+        cleaned_text = file_contents.replace("\n", "").replace("\t", "")
+
+        # Analyze the cleaned text
+        label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
+        return {"result": label, "perplexity": round(perplexity, 2)}
+
+    except Exception as e:
+        logging.error(f"Error processing file: {str(e)}")
+        raise HTTPException(status_code=500, detail="Error processing the file")
 
 # Health check route
 @app.get("/health")
@@ -95,7 +170,6 @@ async def health_check():
 @app.get("/")
 def index():
     return {
-        "message": "FastAPI API is up.",
-        "try": "/docs to test the API.",
-        "status": "OK"
+        "message": "FastAPI AI Text Detector is running.",
+        "usage": "Use /docs or /analyze to test the API."
     }
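A note on the classification rule the diff leaves unchanged: the model's cross-entropy loss is turned into perplexity with torch.exp(loss), and fixed thresholds at 60 and 80 pick the label. A standalone sketch of that mapping, with made-up loss values for illustration (not actual model outputs):

import math

def label_from_loss(loss):
    # Perplexity is the exponential of the mean cross-entropy loss,
    # mirroring torch.exp(loss).item() in classify_text.
    perplexity = math.exp(loss)
    if perplexity < 60:
        return "AI-generated", perplexity
    elif perplexity < 80:
        return "Probably AI-generated", perplexity
    return "Human-written", perplexity

# Hypothetical loss values chosen to land in each band:
for loss in (3.9, 4.3, 4.6):
    label, ppl = label_from_loss(loss)
    print(f"loss={loss:.1f} -> perplexity={ppl:.2f} -> {label}")

Lower perplexity means GPT-2 found the text more predictable, which this heuristic reads as evidence of AI generation.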