dejanseo committed
Commit 68d1553 · verified · 1 Parent(s): 4f473c9

Update app.py

Files changed (1)
  1. app.py +75 -150
app.py CHANGED
@@ -3,201 +3,126 @@ import torch
  import torch.nn.functional as F
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
  import re
- import logging  # Optional: Add logging for better debugging

- # Set up logging (optional but helpful)
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- # Set the page configuration
  st.set_page_config(
      page_title="AI Article Detection by DEJAN",
      page_icon="🧠",
      layout="wide"
  )

- # Logo as provided
- st.logo(
-     image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
-     link="https://dejan.ai/",
-     # size="large"  # 'size' is not a valid argument for st.logo as of Streamlit 1.34 - remove or adjust if needed
  )

- # Font styling
  st.markdown("""
  <link href="https://fonts.googleapis.com/css2?family=Roboto&display=swap" rel="stylesheet">
  <style>
- html, body, [class*="css"] {
-     font-family: 'Roboto', sans-serif;
- }
  </style>
  """, unsafe_allow_html=True)

- @st.cache_resource  # Cache the model and tokenizer to avoid reloading on every interaction
  def load_model_and_tokenizer(model_name):
-     """Loads the model and tokenizer."""
-     logger.info(f"Loading tokenizer: {model_name}")
      tokenizer = AutoTokenizer.from_pretrained(model_name)

-     # Determine device
-     device_type = "cuda" if torch.cuda.is_available() else "cpu"
-     # Use bfloat16 if available on CUDA for potential speedup/memory saving, else float32
-     dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
-     logger.info(f"Using device: {device_type} with dtype: {dtype}")
-
-     logger.info(f"Loading model: {model_name}")
-     # Load model onto CPU first, then move to target device
-     model = AutoModelForSequenceClassification.from_pretrained(
-         model_name,
-         torch_dtype=dtype  # Use the determined dtype
-         # Removed device_map="auto"
-     )
-     logger.info("Moving model to target device...")
-     model.to(torch.device(device_type))  # Move the entire model to the target device
-     model.eval()  # Set model to evaluation mode
-     logger.info("Model loaded successfully.")
-     return tokenizer, model, torch.device(device_type)
-
- # Load model and tokenizer using the cached function
  MODEL_NAME = "dejanseo/ai-detection-small"
  try:
      tokenizer, model, device = load_model_and_tokenizer(MODEL_NAME)
  except Exception as e:
      st.error(f"Error loading model: {e}")
-     logger.error(f"Failed to load model or tokenizer: {e}", exc_info=True)
-     st.stop()  # Stop execution if model loading fails

-
- # Static settings
  LABELS = ["AI Content", "Human Content"]
- COLORS = ["#ffe5e5", "#e6ffe6"]  # light red, light green

- # Regex-based sentence splitter (improved slightly for robustness)
  def sent_tokenize(text):
-     # Split by '.', '!', '?' followed by space(s) or end of string
-     sentences = re.split(r'(?<=[.!?])\s+', text.strip())
-     # Filter out empty strings that might result from splitting
      return [s for s in sentences if s]

- def split_into_chunks(text, tokenizer, max_length=512):
-     sentences = sent_tokenize(text)
-     if not sentences:
-         return []  # Handle empty input after tokenization
-
-     chunks, current_chunk_sentences, current_len = [], [], 0
-     max_tokens = max_length - 2  # Account for [CLS] and [SEP] tokens
-
-     for sent in sentences:
-         # Use tokenizer.encode to get accurate token count (more reliable than tokenize)
-         token_ids = tokenizer.encode(sent, add_special_tokens=False)
-         token_len = len(token_ids)
-
-         if token_len > max_tokens:
-             # Sentence is too long even by itself, handle appropriately
-             # Option 1: Truncate the sentence (simplest)
-             logger.warning(f"Sentence truncated as it exceeds max_length: '{sent[:100]}...'")
-             truncated_sent = tokenizer.decode(token_ids[:max_tokens])
-             # If there was a previous chunk, add it first
-             if current_chunk_sentences:
-                 chunks.append(" ".join(current_chunk_sentences))
-             chunks.append(truncated_sent)  # Add the single truncated sentence as its own chunk
-             current_chunk_sentences, current_len = [], 0  # Reset chunk
-             continue  # Move to the next sentence
-
-         if current_len + token_len <= max_tokens:
-             current_chunk_sentences.append(sent)
-             current_len += token_len
-         else:
-             # Current chunk is full, finalize it
-             if current_chunk_sentences:
-                 chunks.append(" ".join(current_chunk_sentences))
-             # Start a new chunk with the current sentence
-             current_chunk_sentences = [sent]
-             current_len = token_len
-
-     # Add the last remaining chunk
-     if current_chunk_sentences:
-         chunks.append(" ".join(current_chunk_sentences))
-
-     return chunks
-
- # --- UI ---
  st.title("AI Article Detection")
- text = st.text_area("Enter text to classify", height=150, placeholder="Paste your text here...")

  if st.button("Classify", type="primary"):
-     if not text or not text.strip():
          st.warning("Please enter some text.")
      else:
-         with st.spinner("Analyzing... Please wait."):
              try:
-                 # Split text using the tokenizer reference
-                 chunks = split_into_chunks(text, tokenizer, max_length=model.config.max_position_embeddings)
-                 logger.info(f"Split text into {len(chunks)} chunks.")
-
-                 if not chunks:
-                     st.warning("Could not process the input text (perhaps it's too short or contains only delimiters?).")
-                     st.stop()

-                 # Tokenize chunks and move tensors to the correct device
                  inputs = tokenizer(
-                     chunks,
                      return_tensors="pt",
-                     padding=True,  # Pad sequences to the longest in the batch
-                     truncation=True,  # Truncate sequences longer than max_length
-                     max_length=model.config.max_position_embeddings  # Use model's max length
-                 ).to(device)  # Move inputs to the same device as the model

-                 # Perform inference
                  with torch.no_grad():
                      outputs = model(**inputs)
                      logits = outputs.logits
-                 # Ensure probabilities are calculated on CPU if needed for aggregation later
-                 probs = F.softmax(logits, dim=-1).cpu()  # Move probs to CPU
-                 preds = torch.argmax(probs, dim=-1)  # Argmax on CPU probabilities
-
-                 # Process results
-                 chunk_results = []
-                 for i, chunk in enumerate(chunks):
-                     pred_index = preds[i].item()  # Get prediction index for this chunk
-                     chunk_results.append({
-                         "text": chunk,
-                         "label": LABELS[pred_index],
-                         "color": COLORS[pred_index],
-                         "conf": probs[i, pred_index].item() * 100,  # Get confidence for the predicted class
-                     })
-
-                 # Calculate overall prediction based on average probability across chunks
-                 if probs.numel() > 0:  # Check if probs tensor is not empty
-                     avg_probs = torch.mean(probs, dim=0)  # Average probabilities across the batch dimension
-                     final_class_index = torch.argmax(avg_probs).item()
-                     final_label = LABELS[final_class_index]
-                     final_conf = avg_probs[final_class_index].item() * 100
-
-                     # Display final prediction
-                     st.subheader("📊 Final Prediction")
-                     st.markdown(
-                         f"<div style='background-color:{COLORS[final_class_index]}; padding:1rem; border-radius:0.5rem; border: 1px solid #ccc;'>"
-                         f"Based on the analysis, the text is most likely: <b>{final_label}</b> (Confidence: {final_conf:.1f}%)</div>",
-                         unsafe_allow_html=True
-                     )
-                 else:
-                     st.warning("Could not generate predictions for the provided text.")
-
-                 # Display per-chunk predictions in an expander
-                 with st.expander("See per-chunk predictions and confidence"):
-                     if chunk_results:
-                         for result in chunk_results:
-                             st.markdown(
-                                 f"<div title='Confidence: {result['conf']:.1f}%' "
-                                 f"style='background-color:{result['color']}; padding:0.75rem; margin-bottom:0.5rem; border-radius:0.5rem; border: 1px solid #ddd;'>"
-                                 f"<i>({result['label']} - {result['conf']:.1f}%)</i><br>{result['text']}</div>",
-                                 unsafe_allow_html=True
-                             )
                      else:
-                         st.write("No chunk predictions were generated.")

              except Exception as e:
-                 st.error(f"An error occurred during analysis: {e}")
-                 logger.error(f"Analysis failed: {e}", exc_info=True)

  import torch.nn.functional as F
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
  import re
+ import logging

+ # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ # Streamlit page config
  st.set_page_config(
      page_title="AI Article Detection by DEJAN",
      page_icon="🧠",
      layout="wide"
  )

+ # Logo
+ st.markdown(
+     """
+     <a href="https://dejan.ai/" target="_blank">
+         <img src="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png" alt="DEJAN logo">
+     </a>
+     """,
+     unsafe_allow_html=True
  )

+ # Custom font
  st.markdown("""
  <link href="https://fonts.googleapis.com/css2?family=Roboto&display=swap" rel="stylesheet">
  <style>
+ html, body, [class*="css"] {
+     font-family: 'Roboto', sans-serif;
+ }
  </style>
  """, unsafe_allow_html=True)

+ @st.cache_resource
  def load_model_and_tokenizer(model_name):
      tokenizer = AutoTokenizer.from_pretrained(model_name)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     dtype = torch.bfloat16 if (device.type == "cuda" and torch.cuda.is_bf16_supported()) else torch.float32
+     model = AutoModelForSequenceClassification.from_pretrained(model_name, torch_dtype=dtype)
+     model.to(device)
+     model.eval()
+     return tokenizer, model, device

  MODEL_NAME = "dejanseo/ai-detection-small"
  try:
      tokenizer, model, device = load_model_and_tokenizer(MODEL_NAME)
  except Exception as e:
      st.error(f"Error loading model: {e}")
+     logger.error("Failed to load model or tokenizer", exc_info=True)
+     st.stop()

+ # Labels
  LABELS = ["AI Content", "Human Content"]

+ # Sentence splitter
  def sent_tokenize(text):
+     sentences = re.split(r'(?<=[\.!?])\s+', text.strip())
      return [s for s in sentences if s]

+ # UI
  st.title("AI Article Detection")
+ text = st.text_area("Enter text to classify", height=200)

  if st.button("Classify", type="primary"):
+     if not text.strip():
          st.warning("Please enter some text.")
      else:
+         with st.spinner("Analyzing..."):
              try:
+                 sentences = sent_tokenize(text)
+                 if not sentences:
+                     st.warning("No sentences detected.")
+                     st.stop()

+                 # Tokenize each sentence
                  inputs = tokenizer(
+                     sentences,
                      return_tensors="pt",
+                     padding=True,
+                     truncation=True,
+                     max_length=model.config.max_position_embeddings
+                 ).to(device)

+                 # Inference
                  with torch.no_grad():
                      outputs = model(**inputs)
                      logits = outputs.logits
+                 probs = F.softmax(logits, dim=-1).cpu()  # shape [n_sentences, 2]
+                 preds = torch.argmax(probs, dim=-1).cpu()
+
+                 # Build inline styled text
+                 styled_chunks = []
+                 for i, sent in enumerate(sentences):
+                     pred = preds[i].item()
+                     # select color channel
+                     if pred == 0:
+                         r, g = 255, 0  # red for AI
                      else:
+                         r, g = 0, 255  # green for Human
+                     confidence = probs[i, pred].item()  # between 0 and 1
+                     alpha = confidence  # drive opacity directly
+                     # wrap sentence in span
+                     span = (
+                         f"<span "
+                         f"style='background-color: rgba({r},{g},0,{alpha:.2f}); "
+                         f"padding:2px; margin:0 2px; border-radius:3px;'>"
+                         f"{sent}"
+                         f"</span>"
+                     )
+                     styled_chunks.append(span)
+
+                 # join all sentences inline
+                 full_text_html = "".join(styled_chunks)
+                 st.markdown(full_text_html, unsafe_allow_html=True)
+
+                 # Overall AI likelihood
+                 avg_probs = torch.mean(probs, dim=0)
+                 ai_likelihood = avg_probs[0].item() * 100  # class 0 is AI
+                 st.subheader(f"🤖 AI Likelihood: {ai_likelihood:.1f}%")

              except Exception as e:
+                 st.error(f"Analysis error: {e}")
+                 logger.error("Classification failed", exc_info=True)
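
To sanity-check the new per-sentence scoring outside Streamlit, here is a minimal sketch of the same flow. It reuses the checkpoint name and regex splitter from the app; device placement and bfloat16 handling are omitted for brevity, and the `text` value is placeholder input.

# Minimal offline sketch of the app's per-sentence scoring (not part of the commit).
import re
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "dejanseo/ai-detection-small"  # same checkpoint the app loads
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()

text = "First sentence here. Second sentence follows!"  # placeholder input
# Same regex splitter as the app's sent_tokenize
sentences = [s for s in re.split(r'(?<=[\.!?])\s+', text.strip()) if s]

inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    probs = F.softmax(model(**inputs).logits, dim=-1)  # shape [n_sentences, 2]

# Class 0 is "AI Content"; the app mean-pools sentence probabilities for the headline score.
print(f"AI likelihood: {probs.mean(dim=0)[0].item() * 100:.1f}%")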