Pujan-Dev committed on
Commit
5c8f5dc
·
verified ·
1 Parent(s): 3d128ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -118
app.py CHANGED
@@ -1,73 +1,55 @@
1
- from fastapi import FastAPI, HTTPException, Depends, UploadFile, File
2
  from fastapi.security import HTTPBearer
3
  from pydantic import BaseModel
4
  from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
5
  import torch
6
- import os
7
  import asyncio
8
  from contextlib import asynccontextmanager
9
- import logging
10
- from io import BytesIO
11
- import docx
12
- import fitz # PyMuPDF
13
-
14
- # Load environment variables
15
- from dotenv import load_dotenv
16
- load_dotenv()
17
-
18
- SECRET_TOKEN = os.getenv("SECRET_TOKEN")
19
- bearer_scheme = HTTPBearer()
20
- # Ai-Text-Detector
21
- MODEL_PATH = "./Ai-Text-Detector/model"
22
- WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"
23
 
24
  # FastAPI app instance
25
  app = FastAPI()
26
 
27
  # Global model and tokenizer variables
28
  model, tokenizer = None, None
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
- # Logging setup
32
- logging.basicConfig(level=logging.DEBUG)
33
 
34
- # Load model and tokenizer function
35
  def load_model():
36
- global model, tokenizer
 
 
37
  try:
38
- tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
39
- config = GPT2Config.from_pretrained(MODEL_PATH)
40
- model_instance = GPT2LMHeadModel(config)
41
- model_instance.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
42
- model_instance.to(device)
43
- model_instance.eval()
44
- model, tokenizer = model_instance, tokenizer
45
- logging.info("Model loaded successfully.")
46
  except Exception as e:
47
- logging.error(f"Error loading model: {str(e)}")
48
  raise RuntimeError(f"Error loading model: {str(e)}")
49
 
 
 
50
  # Load model on app startup
51
  @asynccontextmanager
52
  async def lifespan(app: FastAPI):
53
- load_model() # Load model when FastAPI app starts
 
54
  yield
55
 
56
- # Attach the lifespan to the app instance
57
  app = FastAPI(lifespan=lifespan)
58
 
59
- # Input schema for text analysis
60
  class TextInput(BaseModel):
61
  text: str
62
 
63
- # Function to classify text using the model
64
- def classify_text(text: str):
65
- if not model or not tokenizer:
66
- raise RuntimeError("Model or tokenizer not loaded.")
67
-
68
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
69
- input_ids = inputs["input_ids"].to(device)
70
- attention_mask = inputs["attention_mask"].to(device)
71
 
72
  with torch.no_grad():
73
  outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
@@ -75,91 +57,34 @@ def classify_text(text: str):
75
  perplexity = torch.exp(loss).item()
76
 
77
  if perplexity < 60:
78
- return "AI-generated", perplexity
79
  elif perplexity < 80:
80
- return "Probably AI-generated", perplexity
81
  else:
82
- return "Human-written", perplexity
 
 
83
 
84
  # POST route to analyze text with Bearer token
85
  @app.post("/analyze")
86
  async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
87
- # Verify token
88
- if token.credentials != SECRET_TOKEN:
89
- raise HTTPException(status_code=401, detail="Invalid token")
90
-
91
- text = data.text.strip()
92
 
93
- # Input validation
94
- if not text:
95
  raise HTTPException(status_code=400, detail="Text cannot be empty")
96
 
97
- if len(text.split()) < 2:
 
 
98
  raise HTTPException(status_code=400, detail="Text must contain at least two words")
99
 
100
- try:
101
- # Classify text
102
- label, perplexity = await asyncio.to_thread(classify_text, text)
103
- return {"result": label, "perplexity": round(perplexity, 2)}
104
- except Exception as e:
105
- logging.error(f"Error processing text: {str(e)}")
106
- raise HTTPException(status_code=500, detail="Model processing error")
107
-
108
- # Function to parse .docx files
109
- def parse_docx(file: BytesIO):
110
- doc = docx.Document(file)
111
- text = ""
112
- for para in doc.paragraphs:
113
- text += para.text + "\n"
114
- return text
115
-
116
- # Function to parse .pdf files
117
- def parse_pdf(file: BytesIO):
118
- try:
119
- doc = fitz.open(stream=file, filetype="pdf")
120
- text = ""
121
- for page_num in range(doc.page_count):
122
- page = doc.load_page(page_num)
123
- text += page.get_text()
124
- return text
125
- except Exception as e:
126
- logging.error(f"Error while processing PDF: {str(e)}")
127
- raise HTTPException(status_code=500, detail="Error processing PDF file")
128
-
129
- # Function to parse .txt files
130
- def parse_txt(file: BytesIO):
131
- return file.read().decode("utf-8")
132
-
133
- # POST route to upload files and analyze content
134
- @app.post("/upload/")
135
- async def upload_file(file: UploadFile = File(...), token: str = Depends(bearer_scheme)):
136
- file_contents = None
137
- try:
138
- if file.content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
139
- file_contents = parse_docx(BytesIO(await file.read()))
140
- elif file.content_type == 'application/pdf':
141
- file_contents = parse_pdf(BytesIO(await file.read()))
142
- elif file.content_type == 'text/plain':
143
- file_contents = parse_txt(BytesIO(await file.read()))
144
- else:
145
- raise HTTPException(status_code=400, detail="Invalid file type. Only .docx, .pdf, and .txt are allowed.")
146
-
147
- logging.debug(f"Extracted Text from {file.filename}:\n{file_contents}")
148
-
149
- # Check if the text length exceeds 10,000 characters
150
- if len(file_contents) > 10000:
151
- return {"message": "File contains more than 10,000 characters."}
152
-
153
- # Clean the text by removing newline and tab characters
154
- cleaned_text = file_contents.replace("\n", "").replace("\t", "")
155
-
156
- # Analyze the cleaned text
157
- label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
158
- return {"result": label, "perplexity": round(perplexity, 2)}
159
 
160
- except Exception as e:
161
- logging.error(f"Error processing file: {str(e)}")
162
- raise HTTPException(status_code=500, detail="Error processing the file")
 
 
 
163
 
164
  # Health check route
165
  @app.get("/health")
@@ -170,7 +95,7 @@ async def health_check():
170
  @app.get("/")
171
  def index():
172
  return {
173
- "message": "FastAPI AI Text Detector is running.",
174
- "usage": "Use /docs or /analyze to test the API."
175
- }
176
-
 
1
+ from fastapi import FastAPI, HTTPException, Depends
2
  from fastapi.security import HTTPBearer
3
  from pydantic import BaseModel
4
  from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
5
  import torch
 
6
  import asyncio
7
  from contextlib import asynccontextmanager
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# NOTE: the application object is created exactly once, further down, as
# `app = FastAPI(lifespan=lifespan)`. The bare `FastAPI()` instance that used
# to be built here was overwritten before any route was registered on it, so
# it has been dropped.

# Global model and tokenizer variables. Both stay None until the lifespan
# startup hook assigns the objects returned by load_model().
model, tokenizer = None, None

# HTTPBearer instance for security
bearer_scheme = HTTPBearer()
17
 
18
+ # Function to load model and tokenizer
19
def load_model():
    """Load the GPT-2 tokenizer and model from the local detector directory.

    The tokenizer and config are read from ``./Ai-Text-Detector/model``; the
    weights are read from ``./Ai-Text-Detector/model_weights.pth`` and mapped
    onto the CPU. The model is switched to eval mode before it is returned.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: if any loading step fails. The original exception is
            attached as ``__cause__`` so the real failure stays visible in
            the traceback.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"

    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        model.load_state_dict(torch.load(weights_path, map_location=torch.device("cpu")))
        model.eval()
    except Exception as e:
        # Chain the cause explicitly instead of relying on implicit context.
        raise RuntimeError(f"Error loading model: {str(e)}") from e

    return model, tokenizer
33
+
34
  # Load model on app startup
35
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Fill the module-level model and tokenizer before yielding to the app.

    Nothing runs after the yield, so there is no shutdown step.
    """
    global model, tokenizer
    loaded = load_model()
    model, tokenizer = loaded
    yield
40
 
41
+ # Attach startup loader
42
  app = FastAPI(lifespan=lifespan)
43
 
44
+ # Input schema
45
# Request body accepted by the /analyze route.
class TextInput(BaseModel):
    # Raw text submitted for classification.
    text: str
47
 
48
+ # Sync text classification
49
+ def classify_text(sentence: str):
50
+ inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
51
+ input_ids = inputs["input_ids"]
52
+ attention_mask = inputs["attention_mask"]
 
 
 
53
 
54
  with torch.no_grad():
55
  outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
 
57
  perplexity = torch.exp(loss).item()
58
 
59
  if perplexity < 60:
60
+ result = "AI-generated"
61
  elif perplexity < 80:
62
+ result = "Probably AI-generated"
63
  else:
64
+ result = "Human-written"
65
+
66
+ return result, perplexity
67
 
68
  # POST route to analyze text with Bearer token
69
@app.post("/analyze")
async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
    """Classify the submitted text and report its perplexity.

    Args:
        data: Request body carrying the text to analyse.
        token: Bearer credential resolved by ``bearer_scheme``; its
            ``credentials`` attribute holds the raw token string.

    Returns:
        A dict with ``result`` (the label from ``classify_text``) and
        ``perplexity`` rounded to two decimals.

    Raises:
        HTTPException: 401 when ``SECRET_TOKEN`` is configured and the
            presented token does not match it; 400 when the text is empty
            after stripping or has fewer than two words.
    """
    # Local imports: the module-level import block does not provide these.
    import os
    import secrets

    # The dependency only extracts the credential; nothing in this module
    # compares it with a secret, so without this check any Bearer value is
    # accepted. Enforcement is opt-in through the SECRET_TOKEN environment
    # variable, so deployments that never set it keep their current behaviour.
    expected_token = os.getenv("SECRET_TOKEN")
    if expected_token and not secrets.compare_digest(
        token.credentials.encode("utf-8"), expected_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    user_input = data.text.strip()

    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Check if there are at least two words
    word_count = len(user_input.split())
    if word_count < 2:
        raise HTTPException(status_code=400, detail="Text must contain at least two words")

    # classify_text is synchronous; run it in a worker thread so the event
    # loop is not blocked while the model runs.
    result, perplexity = await asyncio.to_thread(classify_text, user_input)

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
88
 
89
  # Health check route
90
  @app.get("/health")
 
95
@app.get("/")
def index():
    """Return a static landing payload that points callers at /docs."""
    landing = {
        "message": "FastAPI API is up.",
        "try": "/docs to test the API.",
        "status": "OK",
    }
    return landing