Kuberwastaken committed on
Commit
1331506
·
1 Parent(s): 8ea6312

Reverting to classic model

Browse files
Files changed (1) hide show
  1. model/analyzer.py +28 -80
model/analyzer.py CHANGED
@@ -73,10 +73,10 @@ class ContentAnalyzer:
73
  "mapped_name": "Sexual Abuse",
74
  "description": (
75
  "Any form of non-consensual sexual act, behavior, or interaction, involving coercion, manipulation, or physical force. "
76
- "This includes incidents of sexual assault, molestation, exploitation, harassment, and any acts where an individual is subjected to sexual acts against their will or without their consent. "
77
- "It also covers discussions or depictions of the aftermath of such abuse, such as trauma, emotional distress, legal proceedings, or therapy. "
78
- "References to inappropriate sexual advances, groping, or any other form of sexual misconduct are also included, as well as the psychological and emotional impact on survivors. "
79
- "Scenes where individuals are placed in sexually compromising situations, even if not directly acted upon, may also fall under this category."
80
  )
81
  },
82
  "Self-Harm": {
@@ -139,51 +139,9 @@ class ContentAnalyzer:
139
  logger.error(f"Error loading model: {str(e)}")
140
  raise
141
 
142
- def _chunk_text(self, text: str, chunk_size: int = 256, overlap: int = 32) -> List[str]:
143
  """Split text into overlapping chunks for processing."""
144
- chunks = []
145
- start = 0
146
- text_len = len(text)
147
-
148
- while start < text_len:
149
- end = min(start + chunk_size, text_len)
150
- # Find the last period or newline in the chunk to avoid cutting sentences
151
- if end < text_len:
152
- last_period = max(
153
- text.rfind('. ', start, end),
154
- text.rfind('\n', start, end)
155
- )
156
- if last_period > start:
157
- end = last_period + 1
158
-
159
- chunks.append(text[start:end])
160
- start = end - overlap
161
-
162
- return chunks
163
-
164
- def _process_model_response(self, response_text: str) -> float:
165
- """Process model response and return a confidence score."""
166
- response = response_text.strip().upper()
167
-
168
- # Check for explicit YES/NO/MAYBE
169
- if "YES" in response:
170
- evidence_words = ["CLEAR", "DEFINITELY", "EXPLICIT", "STRONG"]
171
- return 1.0 if any(word in response for word in evidence_words) else 0.8
172
- elif "MAYBE" in response or "POSSIBLE" in response:
173
- return 0.5
174
- elif "NO" in response:
175
- return 0.0
176
-
177
- # Fallback analysis for unclear responses
178
- positive_indicators = ["PRESENT", "FOUND", "CONTAINS", "SHOWS", "INDICATES"]
179
- negative_indicators = ["ABSENT", "NONE", "NOTHING", "LACKS"]
180
-
181
- if any(indicator in response for indicator in positive_indicators):
182
- return 0.7
183
- elif any(indicator in response for indicator in negative_indicators):
184
- return 0.0
185
-
186
- return 0.0
187
 
188
  async def analyze_chunk(
189
  self,
@@ -200,17 +158,10 @@ class ContentAnalyzer:
200
  description = info["description"]
201
 
202
  prompt = f"""
203
- Analyze this text carefully for any indication of {mapped_name}.
204
- Context: {description}
205
-
206
- Guidelines:
207
- - Consider both explicit and implicit references
208
- - Ignore metaphorical or figurative language
209
- - Look for concrete evidence in the text
210
-
211
- Text to analyze: {chunk}
212
-
213
- Is there evidence of {mapped_name}? Respond with YES, NO, or MAYBE and briefly explain why.
214
  Answer:
215
  """
216
 
@@ -221,21 +172,20 @@ class ContentAnalyzer:
221
  with torch.no_grad():
222
  outputs = self.model.generate(
223
  **inputs,
224
- max_new_tokens=32,
225
- num_return_sequences=1,
226
  do_sample=True,
227
- temperature=0.7,
228
- top_p=0.92,
229
- top_k=50,
230
- repetition_penalty=1.1,
231
  pad_token_id=self.tokenizer.eos_token_id
232
  )
233
 
234
- response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
235
- confidence = self._process_model_response(response_text)
236
-
237
- if confidence > 0.5:
238
- chunk_triggers[mapped_name] = chunk_triggers.get(mapped_name, 0) + confidence
 
 
239
 
240
  if progress:
241
  current_progress += progress_step
@@ -252,11 +202,11 @@ class ContentAnalyzer:
252
  await self.load_model(progress)
253
 
254
  chunks = self._chunk_text(script)
255
- trigger_scores = {}
256
  progress_step = 0.4 / (len(chunks) * len(self.trigger_categories))
257
- current_progress = 0.5
258
 
259
- for chunk in chunks:
260
  chunk_triggers = await self.analyze_chunk(
261
  chunk,
262
  progress,
@@ -264,17 +214,15 @@ class ContentAnalyzer:
264
  progress_step
265
  )
266
 
267
- for trigger, score in chunk_triggers.items():
268
- trigger_scores[trigger] = trigger_scores.get(trigger, 0) + score
269
 
270
  if progress:
271
  progress(0.95, "Finalizing results...")
272
 
273
- # Normalize scores by number of chunks and apply threshold
274
- chunk_count = len(chunks)
275
  final_triggers = [
276
- trigger for trigger, score in trigger_scores.items()
277
- if score / chunk_count > 0.3 # Adjusted threshold for better balance
278
  ]
279
 
280
  return final_triggers if final_triggers else ["None"]
@@ -312,7 +260,7 @@ async def analyze_content(
312
  }
313
 
314
  if __name__ == "__main__":
315
- # Gradio interface
316
  iface = gr.Interface(
317
  fn=analyze_content,
318
  inputs=gr.Textbox(lines=8, label="Input Text"),
 
73
  "mapped_name": "Sexual Abuse",
74
  "description": (
75
  "Any form of non-consensual sexual act, behavior, or interaction, involving coercion, manipulation, or physical force. "
76
+ "This includes incidents of sexual assault, molestation, exploitation, harassment, and any acts where an individual is subjected to sexual acts against their will or without their consent. "
77
+ "It also covers discussions or depictions of the aftermath of such abuse, such as trauma, emotional distress, legal proceedings, or therapy. "
78
+ "References to inappropriate sexual advances, groping, or any other form of sexual misconduct are also included, as well as the psychological and emotional impact on survivors. "
79
+ "Scenes where individuals are placed in sexually compromising situations, even if not directly acted upon, may also fall under this category."
80
  )
81
  },
82
  "Self-Harm": {
 
139
  logger.error(f"Error loading model: {str(e)}")
140
  raise
141
 
142
+ def _chunk_text(self, text: str, chunk_size: int = 256, overlap: int = 15) -> List[str]:
143
  """Split text into overlapping chunks for processing."""
144
+ return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  async def analyze_chunk(
147
  self,
 
158
  description = info["description"]
159
 
160
  prompt = f"""
161
+ Check this text for any indication of {mapped_name} ({description}).
162
+ Be sensitive to subtle references or implications, make sure the text is not metaphorical.
163
+ Respond concisely with: YES, NO, or MAYBE.
164
+ Text: {chunk}
 
 
 
 
 
 
 
165
  Answer:
166
  """
167
 
 
172
  with torch.no_grad():
173
  outputs = self.model.generate(
174
  **inputs,
175
+ max_new_tokens=3,
 
176
  do_sample=True,
177
+ temperature=0.5,
178
+ top_p=0.9,
 
 
179
  pad_token_id=self.tokenizer.eos_token_id
180
  )
181
 
182
+ response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip().upper()
183
+ first_word = response_text.split("\n")[-1].split()[0] if response_text else "NO"
184
+
185
+ if first_word == "YES":
186
+ chunk_triggers[mapped_name] = chunk_triggers.get(mapped_name, 0) + 1
187
+ elif first_word == "MAYBE":
188
+ chunk_triggers[mapped_name] = chunk_triggers.get(mapped_name, 0) + 0.5
189
 
190
  if progress:
191
  current_progress += progress_step
 
202
  await self.load_model(progress)
203
 
204
  chunks = self._chunk_text(script)
205
+ identified_triggers = {}
206
  progress_step = 0.4 / (len(chunks) * len(self.trigger_categories))
207
+ current_progress = 0.5 # Starting after model loading
208
 
209
+ for chunk_idx, chunk in enumerate(chunks, 1):
210
  chunk_triggers = await self.analyze_chunk(
211
  chunk,
212
  progress,
 
214
  progress_step
215
  )
216
 
217
+ for trigger, count in chunk_triggers.items():
218
+ identified_triggers[trigger] = identified_triggers.get(trigger, 0) + count
219
 
220
  if progress:
221
  progress(0.95, "Finalizing results...")
222
 
 
 
223
  final_triggers = [
224
+ trigger for trigger, count in identified_triggers.items()
225
+ if count > 0.5
226
  ]
227
 
228
  return final_triggers if final_triggers else ["None"]
 
260
  }
261
 
262
  if __name__ == "__main__":
263
+ # This section is mainly for testing the analyzer directly
264
  iface = gr.Interface(
265
  fn=analyze_content,
266
  inputs=gr.Textbox(lines=8, label="Input Text"),