Kuberwastaken committed on
Commit
474b075
·
1 Parent(s): 2194060

Improved model and functioning loading bar

Browse files
Files changed (2) hide show
  1. gradio_app.py +14 -29
  2. model/analyzer.py +100 -35
gradio_app.py CHANGED
@@ -131,38 +131,24 @@ label {
131
  }
132
  """
133
 
134
- def analyze_with_loading(text, progress=gr.Progress()):
135
  """
136
- Synchronous wrapper for the async analyze_content function
137
  """
138
- # Initialize progress
139
- progress(0, desc="Starting analysis...")
140
-
141
- # Initial setup phase
142
- for i in range(30):
143
- time.sleep(0.02) # Reduced sleep time
144
- progress((i + 1) / 100)
145
-
146
- # Perform analysis
147
- progress(0.3, desc="Processing text...")
148
  try:
149
- # Use asyncio.run to handle the async function call
150
- result = asyncio.run(analyze_content(text))
 
 
 
 
 
 
 
 
 
151
  except Exception as e:
152
  return f"Error during analysis: {str(e)}"
153
-
154
- # Final processing
155
- for i in range(70, 100):
156
- time.sleep(0.02) # Reduced sleep time
157
- progress((i + 1) / 100)
158
-
159
- # Format the results
160
- triggers = result["detected_triggers"]
161
- if triggers == ["None"]:
162
- return "✓ No triggers detected in the content."
163
- else:
164
- trigger_list = "\n".join([f"• {trigger}" for trigger in triggers])
165
- return f"⚠ Triggers Detected:\n{trigger_list}"
166
 
167
  # Create the Gradio interface
168
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
@@ -220,9 +206,8 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
220
  """)
221
 
222
  if __name__ == "__main__":
223
- # Launch without the 'ssr' argument
224
  iface.launch(
225
  share=False,
226
  debug=True,
227
  show_error=True
228
- )
 
131
  }
132
  """
133
 
134
+ async def analyze_with_loading(text, progress=gr.Progress()):
135
  """
136
+ Asynchronous wrapper for analyze_content that properly tracks progress
137
  """
 
 
 
 
 
 
 
 
 
 
138
  try:
139
+ # Call analyze_content directly with the progress object
140
+ result = await analyze_content(text, progress)
141
+
142
+ # Format the results
143
+ triggers = result["detected_triggers"]
144
+ if triggers == ["None"]:
145
+ return "✓ No concerns detected in the content."
146
+ else:
147
+ trigger_list = "\n".join([f"• {trigger}" for trigger in triggers])
148
+ return f"⚠ Triggers Detected:\n{trigger_list}"
149
+
150
  except Exception as e:
151
  return f"Error during analysis: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  # Create the Gradio interface
154
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as iface:
 
206
  """)
207
 
208
  if __name__ == "__main__":
 
209
  iface.launch(
210
  share=False,
211
  debug=True,
212
  show_error=True
213
+ )
model/analyzer.py CHANGED
@@ -73,10 +73,10 @@ class ContentAnalyzer:
73
  "mapped_name": "Sexual Abuse",
74
  "description": (
75
  "Any form of non-consensual sexual act, behavior, or interaction, involving coercion, manipulation, or physical force. "
76
- "This includes incidents of sexual assault, molestation, exploitation, harassment, and any acts where an individual is subjected to sexual acts against their will or without their consent. "
77
- "It also covers discussions or depictions of the aftermath of such abuse, such as trauma, emotional distress, legal proceedings, or therapy. "
78
- "References to inappropriate sexual advances, groping, or any other form of sexual misconduct are also included, as well as the psychological and emotional impact on survivors. "
79
- "Scenes where individuals are placed in sexually compromising situations, even if not directly acted upon, may also fall under this category."
80
  )
81
  },
82
  "Self-Harm": {
@@ -122,7 +122,7 @@ class ContentAnalyzer:
122
  )
123
 
124
  if progress:
125
- progress(0.3, "Loading model...")
126
 
127
  self.model = AutoModelForCausalLM.from_pretrained(
128
  "meta-llama/Llama-3.2-1B",
@@ -132,16 +132,55 @@ class ContentAnalyzer:
132
  )
133
 
134
  if progress:
135
- progress(0.5, "Model loaded successfully")
136
 
137
  logger.info(f"Model loaded successfully on {self.device}")
138
  except Exception as e:
139
  logger.error(f"Error loading model: {str(e)}")
140
  raise
141
 
142
- def _chunk_text(self, text: str, chunk_size: int = 128, overlap: int = 5) -> List[str]:
143
  """Split text into overlapping chunks for processing."""
144
- return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  async def analyze_chunk(
147
  self,
@@ -152,16 +191,24 @@ class ContentAnalyzer:
152
  ) -> Dict[str, float]:
153
  """Analyze a single chunk of text for triggers."""
154
  chunk_triggers = {}
 
155
 
156
  for category, info in self.trigger_categories.items():
157
  mapped_name = info["mapped_name"]
158
  description = info["description"]
159
 
160
  prompt = f"""
161
- Check this text for any indication of {mapped_name} ({description}).
162
- Be sensitive to subtle references or implications, make sure the text is not metaphorical.
163
- Respond concisely with: YES, NO, or MAYBE.
164
- Text: {chunk}
 
 
 
 
 
 
 
165
  Answer:
166
  """
167
 
@@ -172,24 +219,25 @@ class ContentAnalyzer:
172
  with torch.no_grad():
173
  outputs = self.model.generate(
174
  **inputs,
175
- max_new_tokens=5,
 
176
  do_sample=True,
177
- temperature=0.5,
178
- top_p=0.9,
 
 
179
  pad_token_id=self.tokenizer.eos_token_id
180
  )
181
 
182
- response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip().upper()
183
- first_word = response_text.split("\n")[-1].split()[0] if response_text else "NO"
184
-
185
- if first_word == "YES":
186
- chunk_triggers[mapped_name] = chunk_triggers.get(mapped_name, 0) + 1
187
- elif first_word == "MAYBE":
188
- chunk_triggers[mapped_name] = chunk_triggers.get(mapped_name, 0) + 0.5
189
 
190
  if progress:
191
- current_progress += progress_step
192
- progress(min(current_progress, 0.9), f"Analyzing {mapped_name}...")
193
 
194
  except Exception as e:
195
  logger.error(f"Error analyzing chunk for {mapped_name}: {str(e)}")
@@ -202,27 +250,41 @@ class ContentAnalyzer:
202
  await self.load_model(progress)
203
 
204
  chunks = self._chunk_text(script)
205
- identified_triggers = {}
206
- progress_step = 0.4 / (len(chunks) * len(self.trigger_categories))
207
- current_progress = 0.5 # Starting after model loading
 
 
 
 
 
 
208
 
209
- for chunk_idx, chunk in enumerate(chunks, 1):
210
  chunk_triggers = await self.analyze_chunk(
211
  chunk,
212
  progress,
213
  current_progress,
214
- progress_step
215
  )
216
 
217
- for trigger, count in chunk_triggers.items():
218
- identified_triggers[trigger] = identified_triggers.get(trigger, 0) + count
 
 
 
 
 
 
219
 
220
  if progress:
221
- progress(0.95, "Finalizing results...")
222
 
 
 
223
  final_triggers = [
224
- trigger for trigger, count in identified_triggers.items()
225
- if count > 0.5
226
  ]
227
 
228
  return final_triggers if final_triggers else ["None"]
@@ -235,6 +297,9 @@ async def analyze_content(
235
  analyzer = ContentAnalyzer()
236
 
237
  try:
 
 
 
238
  triggers = await analyzer.analyze_script(script, progress)
239
 
240
  if progress:
@@ -260,7 +325,7 @@ async def analyze_content(
260
  }
261
 
262
  if __name__ == "__main__":
263
- # This section is mainly for testing the analyzer directly
264
  iface = gr.Interface(
265
  fn=analyze_content,
266
  inputs=gr.Textbox(lines=8, label="Input Text"),
 
73
  "mapped_name": "Sexual Abuse",
74
  "description": (
75
  "Any form of non-consensual sexual act, behavior, or interaction, involving coercion, manipulation, or physical force. "
76
+ "This includes incidents of sexual assault, molestation, exploitation, harassment, and any acts where an individual is subjected to sexual acts against their will or without their consent. "
77
+ "It also covers discussions or depictions of the aftermath of such abuse, such as trauma, emotional distress, legal proceedings, or therapy. "
78
+ "References to inappropriate sexual advances, groping, or any other form of sexual misconduct are also included, as well as the psychological and emotional impact on survivors. "
79
+ "Scenes where individuals are placed in sexually compromising situations, even if not directly acted upon, may also fall under this category."
80
  )
81
  },
82
  "Self-Harm": {
 
122
  )
123
 
124
  if progress:
125
+ progress(0.15, "Loading model...")
126
 
127
  self.model = AutoModelForCausalLM.from_pretrained(
128
  "meta-llama/Llama-3.2-1B",
 
132
  )
133
 
134
  if progress:
135
+ progress(0.2, "Model loaded successfully")
136
 
137
  logger.info(f"Model loaded successfully on {self.device}")
138
  except Exception as e:
139
  logger.error(f"Error loading model: {str(e)}")
140
  raise
141
 
142
+ def _chunk_text(self, text: str, chunk_size: int = 256, overlap: int = 32) -> List[str]:
143
  """Split text into overlapping chunks for processing."""
144
+ chunks = []
145
+ start = 0
146
+ text_len = len(text)
147
+
148
+ while start < text_len:
149
+ end = min(start + chunk_size, text_len)
150
+ if end < text_len:
151
+ last_period = max(
152
+ text.rfind('. ', start, end),
153
+ text.rfind('\n', start, end)
154
+ )
155
+ if last_period > start:
156
+ end = last_period + 1
157
+
158
+ chunks.append(text[start:end])
159
+ start = end - overlap
160
+
161
+ return chunks
162
+
163
+ def _process_model_response(self, response_text: str) -> float:
164
+ """Process model response and return a confidence score."""
165
+ response = response_text.strip().upper()
166
+
167
+ if "YES" in response:
168
+ evidence_words = ["CLEAR", "DEFINITELY", "EXPLICIT", "STRONG"]
169
+ return 1.0 if any(word in response for word in evidence_words) else 0.8
170
+ elif "MAYBE" in response or "POSSIBLE" in response:
171
+ return 0.5
172
+ elif "NO" in response:
173
+ return 0.0
174
+
175
+ positive_indicators = ["PRESENT", "FOUND", "CONTAINS", "SHOWS", "INDICATES"]
176
+ negative_indicators = ["ABSENT", "NONE", "NOTHING", "LACKS"]
177
+
178
+ if any(indicator in response for indicator in positive_indicators):
179
+ return 0.7
180
+ elif any(indicator in response for indicator in negative_indicators):
181
+ return 0.0
182
+
183
+ return 0.0
184
 
185
  async def analyze_chunk(
186
  self,
 
191
  ) -> Dict[str, float]:
192
  """Analyze a single chunk of text for triggers."""
193
  chunk_triggers = {}
194
+ progress_increment = progress_step / len(self.trigger_categories)
195
 
196
  for category, info in self.trigger_categories.items():
197
  mapped_name = info["mapped_name"]
198
  description = info["description"]
199
 
200
  prompt = f"""
201
+ Analyze this text carefully for any indication of {mapped_name}.
202
+ Context: {description}
203
+
204
+ Guidelines:
205
+ - Consider both explicit and implicit references
206
+ - Ignore metaphorical or figurative language
207
+ - Look for concrete evidence in the text
208
+
209
+ Text to analyze: {chunk}
210
+
211
+ Is there evidence of {mapped_name}? Respond with YES, NO, or MAYBE and briefly explain why.
212
  Answer:
213
  """
214
 
 
219
  with torch.no_grad():
220
  outputs = self.model.generate(
221
  **inputs,
222
+ max_new_tokens=32,
223
+ num_return_sequences=1,
224
  do_sample=True,
225
+ temperature=0.7,
226
+ top_p=0.92,
227
+ top_k=50,
228
+ repetition_penalty=1.1,
229
  pad_token_id=self.tokenizer.eos_token_id
230
  )
231
 
232
+ response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
233
+ confidence = self._process_model_response(response_text)
234
+
235
+ if confidence > 0.5:
236
+ chunk_triggers[mapped_name] = chunk_triggers.get(mapped_name, 0) + confidence
 
 
237
 
238
  if progress:
239
+ current_progress += progress_increment
240
+ progress(min(current_progress, 0.9), f"Analyzing for {mapped_name}...")
241
 
242
  except Exception as e:
243
  logger.error(f"Error analyzing chunk for {mapped_name}: {str(e)}")
 
250
  await self.load_model(progress)
251
 
252
  chunks = self._chunk_text(script)
253
+ trigger_scores = {}
254
+
255
+ # Calculate progress allocation
256
+ analysis_progress = 0.7 # 70% of progress for analysis
257
+ progress_per_chunk = analysis_progress / len(chunks)
258
+ current_progress = 0.2 # Starting after model loading
259
+
260
+ if progress:
261
+ progress(current_progress, "Beginning content analysis...")
262
 
263
+ for i, chunk in enumerate(chunks):
264
  chunk_triggers = await self.analyze_chunk(
265
  chunk,
266
  progress,
267
  current_progress,
268
+ progress_per_chunk
269
  )
270
 
271
+ for trigger, score in chunk_triggers.items():
272
+ trigger_scores[trigger] = trigger_scores.get(trigger, 0) + score
273
+
274
+ current_progress += progress_per_chunk
275
+ if progress:
276
+ chunk_number = i + 1
277
+ progress(min(0.9, current_progress),
278
+ f"Processing chunk {chunk_number}/{len(chunks)}...")
279
 
280
  if progress:
281
+ progress(0.95, "Finalizing analysis...")
282
 
283
+ # Normalize scores by number of chunks and apply threshold
284
+ chunk_count = len(chunks)
285
  final_triggers = [
286
+ trigger for trigger, score in trigger_scores.items()
287
+ if score / chunk_count > 0.3
288
  ]
289
 
290
  return final_triggers if final_triggers else ["None"]
 
297
  analyzer = ContentAnalyzer()
298
 
299
  try:
300
+ if progress:
301
+ progress(0.0, "Initializing analyzer...")
302
+
303
  triggers = await analyzer.analyze_script(script, progress)
304
 
305
  if progress:
 
325
  }
326
 
327
  if __name__ == "__main__":
328
+ # Gradio interface
329
  iface = gr.Interface(
330
  fn=analyze_content,
331
  inputs=gr.Textbox(lines=8, label="Input Text"),