cella110n commited on
Commit
424eab0
·
verified ·
1 Parent(s): e6e216d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -36
app.py CHANGED
@@ -195,21 +195,13 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
195
  for tag, prob in predictions.get("meta", []):
196
  if not any(pattern in tag.lower() for pattern in excluded_meta_patterns):
197
  filtered_meta.append((tag, prob))
198
- predictions["meta"] = filtered_meta # Use filtered list for visualization
199
 
200
  # --- Plotting Setup ---
201
- plt.rcParams['font.family'] = 'DejaVu Sans' # Ensure font compatibility
202
- fig = plt.figure(figsize=(12, 20), dpi=100)
203
- gs = fig.add_gridspec(1, 2, width_ratios=[1.2, 1])
204
-
205
- # Left side: Image
206
- # ax_img = fig.add_subplot(gs[0, 0])
207
- # ax_img.imshow(image)
208
- # ax_img.set_title("Original Image")
209
- # ax_img.axis('off')
210
-
211
- # Right side: Tags
212
- ax_tags = fig.add_subplot(gs[0, 1])
213
  all_tags, all_probs, all_colors = [], [], []
214
  color_map = {
215
  'rating': 'red', 'character': 'blue', 'copyright': 'purple',
@@ -223,22 +215,18 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
223
  ('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
224
  ('meta', 'M', color_map['meta'])
225
  ]:
226
- # Sort within category by probability before adding
227
  sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
228
  for tag, prob in sorted_tags:
229
- # Add prefix to tag name for display
230
- all_tags.append(f"[{prefix}] {tag.replace('_', ' ')}") # Replace underscores for display
231
  all_probs.append(prob)
232
  all_colors.append(color)
233
 
234
  if not all_tags:
235
  ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center')
236
- ax_tags.set_title(f"Tags (Thresholds: Gen/Meta={threshold:.2f}, Char/Art/Copy={threshold:.2f})") # Assuming same threshold for now
237
  ax_tags.axis('off')
238
  else:
239
- # Sort all aggregated tags by probability (descending) for plotting order
240
- # Plotting from bottom up, so we want highest probability at the top
241
- sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i]) # Sort ascending for barh
242
  all_tags = [all_tags[i] for i in sorted_indices]
243
  all_probs = [all_probs[i] for i in sorted_indices]
244
  all_colors = [all_colors[i] for i in sorted_indices]
@@ -252,37 +240,31 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
252
  ax_tags.set_yticklabels(all_tags)
253
 
254
  fontsize = 10 if num_tags <= 40 else 8 if num_tags <= 60 else 6
255
- for label in ax_tags.get_yticklabels():
256
- label.set_fontsize(fontsize)
257
 
258
- # Add probability text next to bars
259
  for i, (bar, prob) in enumerate(zip(bars, all_probs)):
260
- # Position text slightly outside the bar, ensuring it stays within plot bounds
261
- text_x = min(prob + 0.02, 0.98) # Adjust x position
262
- ax_tags.text(text_x, y_positions[i], f"{prob:.3f}", va='center', fontsize=fontsize)
263
 
264
  ax_tags.set_xlim(0, 1)
265
- ax_tags.set_title(f"Tags (Thresholds approx: {threshold:.2f})") # Indicate threshold used
266
 
267
- # Add legend
268
  from matplotlib.patches import Patch
269
  legend_elements = [
270
- Patch(facecolor=color, label=cat.capitalize()) for cat, color in color_map.items()
271
- if any(t.startswith(f"[{cat[0].upper() if cat != 'copyright' else '©'}]") for t in all_tags)
 
272
  ]
273
  if legend_elements:
274
- ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)
275
 
276
  plt.tight_layout()
277
- plt.subplots_adjust(bottom=0.05)
278
-
279
- # Save plot to buffer
280
  buf = io.BytesIO()
281
  plt.savefig(buf, format='png', dpi=100)
282
  plt.close(fig)
283
  buf.seek(0)
284
- viz_image = Image.open(buf)
285
- return viz_image
286
 
287
  # --- Constants ---
288
  REPO_ID = "cella110n/cl_tagger"
 
195
  for tag, prob in predictions.get("meta", []):
196
  if not any(pattern in tag.lower() for pattern in excluded_meta_patterns):
197
  filtered_meta.append((tag, prob))
198
+ predictions["meta"] = filtered_meta # Use filtered list for visualization
199
 
200
  # --- Plotting Setup ---
201
+ plt.rcParams['font.family'] = 'DejaVu Sans'
202
+ fig = plt.figure(figsize=(8, 20), dpi=100)
203
+ ax_tags = fig.add_subplot(1, 1, 1)
204
+
 
 
 
 
 
 
 
 
205
  all_tags, all_probs, all_colors = [], [], []
206
  color_map = {
207
  'rating': 'red', 'character': 'blue', 'copyright': 'purple',
 
215
  ('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
216
  ('meta', 'M', color_map['meta'])
217
  ]:
 
218
  sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
219
  for tag, prob in sorted_tags:
220
+ all_tags.append(f"[{prefix}] {tag.replace('_', ' ')}")
 
221
  all_probs.append(prob)
222
  all_colors.append(color)
223
 
224
  if not all_tags:
225
  ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center')
226
+ ax_tags.set_title(f"Tags (Threshold {threshold:.2f})")
227
  ax_tags.axis('off')
228
  else:
229
+ sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i])
 
 
230
  all_tags = [all_tags[i] for i in sorted_indices]
231
  all_probs = [all_probs[i] for i in sorted_indices]
232
  all_colors = [all_colors[i] for i in sorted_indices]
 
240
  ax_tags.set_yticklabels(all_tags)
241
 
242
  fontsize = 10 if num_tags <= 40 else 8 if num_tags <= 60 else 6
243
+ for lbl in ax_tags.get_yticklabels():
244
+ lbl.set_fontsize(fontsize)
245
 
 
246
  for i, (bar, prob) in enumerate(zip(bars, all_probs)):
247
+ text_x = min(prob + 0.02, 0.98)
248
+ ax_tags.text(text_x, y_positions[i], f"{prob:.3f}", va='center', fontsize=fontsize)
 
249
 
250
  ax_tags.set_xlim(0, 1)
251
+ ax_tags.set_title(f"Tags (Threshold {threshold:.2f})")
252
 
 
253
  from matplotlib.patches import Patch
254
  legend_elements = [
255
+ Patch(facecolor=color, label=cat.capitalize())
256
+ for cat, color in color_map.items()
257
+ if any(t.startswith(f"[{cat[0].upper() if cat!='copyright' else '©'}]") for t in all_tags)
258
  ]
259
  if legend_elements:
260
+ ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)
261
 
262
  plt.tight_layout()
 
 
 
263
  buf = io.BytesIO()
264
  plt.savefig(buf, format='png', dpi=100)
265
  plt.close(fig)
266
  buf.seek(0)
267
+ return Image.open(buf)
 
268
 
269
  # --- Constants ---
270
  REPO_ID = "cella110n/cl_tagger"