Upload app.py
Browse files
app.py
CHANGED
@@ -195,21 +195,13 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
|
|
195 |
for tag, prob in predictions.get("meta", []):
|
196 |
if not any(pattern in tag.lower() for pattern in excluded_meta_patterns):
|
197 |
filtered_meta.append((tag, prob))
|
198 |
-
predictions["meta"] = filtered_meta
|
199 |
|
200 |
# --- Plotting Setup ---
|
201 |
-
plt.rcParams['font.family'] = 'DejaVu Sans'
|
202 |
-
fig = plt.figure(figsize=(
|
203 |
-
|
204 |
-
|
205 |
-
# Left side: Image
|
206 |
-
# ax_img = fig.add_subplot(gs[0, 0])
|
207 |
-
# ax_img.imshow(image)
|
208 |
-
# ax_img.set_title("Original Image")
|
209 |
-
# ax_img.axis('off')
|
210 |
-
|
211 |
-
# Right side: Tags
|
212 |
-
ax_tags = fig.add_subplot(gs[0, 1])
|
213 |
all_tags, all_probs, all_colors = [], [], []
|
214 |
color_map = {
|
215 |
'rating': 'red', 'character': 'blue', 'copyright': 'purple',
|
@@ -223,22 +215,18 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
|
|
223 |
('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
|
224 |
('meta', 'M', color_map['meta'])
|
225 |
]:
|
226 |
-
# Sort within category by probability before adding
|
227 |
sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
|
228 |
for tag, prob in sorted_tags:
|
229 |
-
|
230 |
-
all_tags.append(f"[{prefix}] {tag.replace('_', ' ')}") # Replace underscores for display
|
231 |
all_probs.append(prob)
|
232 |
all_colors.append(color)
|
233 |
|
234 |
if not all_tags:
|
235 |
ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center')
|
236 |
-
ax_tags.set_title(f"Tags (
|
237 |
ax_tags.axis('off')
|
238 |
else:
|
239 |
-
|
240 |
-
# Plotting from bottom up, so we want highest probability at the top
|
241 |
-
sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i]) # Sort ascending for barh
|
242 |
all_tags = [all_tags[i] for i in sorted_indices]
|
243 |
all_probs = [all_probs[i] for i in sorted_indices]
|
244 |
all_colors = [all_colors[i] for i in sorted_indices]
|
@@ -252,37 +240,31 @@ def visualize_predictions(image: Image.Image, predictions: Dict, threshold: floa
|
|
252 |
ax_tags.set_yticklabels(all_tags)
|
253 |
|
254 |
fontsize = 10 if num_tags <= 40 else 8 if num_tags <= 60 else 6
|
255 |
-
for
|
256 |
-
|
257 |
|
258 |
-
# Add probability text next to bars
|
259 |
for i, (bar, prob) in enumerate(zip(bars, all_probs)):
|
260 |
-
|
261 |
-
|
262 |
-
ax_tags.text(text_x, y_positions[i], f"{prob:.3f}", va='center', fontsize=fontsize)
|
263 |
|
264 |
ax_tags.set_xlim(0, 1)
|
265 |
-
ax_tags.set_title(f"Tags (
|
266 |
|
267 |
-
# Add legend
|
268 |
from matplotlib.patches import Patch
|
269 |
legend_elements = [
|
270 |
-
|
271 |
-
|
|
|
272 |
]
|
273 |
if legend_elements:
|
274 |
-
|
275 |
|
276 |
plt.tight_layout()
|
277 |
-
plt.subplots_adjust(bottom=0.05)
|
278 |
-
|
279 |
-
# Save plot to buffer
|
280 |
buf = io.BytesIO()
|
281 |
plt.savefig(buf, format='png', dpi=100)
|
282 |
plt.close(fig)
|
283 |
buf.seek(0)
|
284 |
-
|
285 |
-
return viz_image
|
286 |
|
287 |
# --- Constants ---
|
288 |
REPO_ID = "cella110n/cl_tagger"
|
|
|
195 |
for tag, prob in predictions.get("meta", []):
|
196 |
if not any(pattern in tag.lower() for pattern in excluded_meta_patterns):
|
197 |
filtered_meta.append((tag, prob))
|
198 |
+
predictions["meta"] = filtered_meta # Use filtered list for visualization
|
199 |
|
200 |
# --- Plotting Setup ---
|
201 |
+
plt.rcParams['font.family'] = 'DejaVu Sans'
|
202 |
+
fig = plt.figure(figsize=(8, 20), dpi=100)
|
203 |
+
ax_tags = fig.add_subplot(1, 1, 1)
|
204 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
all_tags, all_probs, all_colors = [], [], []
|
206 |
color_map = {
|
207 |
'rating': 'red', 'character': 'blue', 'copyright': 'purple',
|
|
|
215 |
('artist', 'A', color_map['artist']), ('general', 'G', color_map['general']),
|
216 |
('meta', 'M', color_map['meta'])
|
217 |
]:
|
|
|
218 |
sorted_tags = sorted(predictions.get(cat, []), key=lambda x: x[1], reverse=True)
|
219 |
for tag, prob in sorted_tags:
|
220 |
+
all_tags.append(f"[{prefix}] {tag.replace('_', ' ')}")
|
|
|
221 |
all_probs.append(prob)
|
222 |
all_colors.append(color)
|
223 |
|
224 |
if not all_tags:
|
225 |
ax_tags.text(0.5, 0.5, "No tags found above threshold", ha='center', va='center')
|
226 |
+
ax_tags.set_title(f"Tags (Threshold ≳ {threshold:.2f})")
|
227 |
ax_tags.axis('off')
|
228 |
else:
|
229 |
+
sorted_indices = sorted(range(len(all_probs)), key=lambda i: all_probs[i])
|
|
|
|
|
230 |
all_tags = [all_tags[i] for i in sorted_indices]
|
231 |
all_probs = [all_probs[i] for i in sorted_indices]
|
232 |
all_colors = [all_colors[i] for i in sorted_indices]
|
|
|
240 |
ax_tags.set_yticklabels(all_tags)
|
241 |
|
242 |
fontsize = 10 if num_tags <= 40 else 8 if num_tags <= 60 else 6
|
243 |
+
for lbl in ax_tags.get_yticklabels():
|
244 |
+
lbl.set_fontsize(fontsize)
|
245 |
|
|
|
246 |
for i, (bar, prob) in enumerate(zip(bars, all_probs)):
|
247 |
+
text_x = min(prob + 0.02, 0.98)
|
248 |
+
ax_tags.text(text_x, y_positions[i], f"{prob:.3f}", va='center', fontsize=fontsize)
|
|
|
249 |
|
250 |
ax_tags.set_xlim(0, 1)
|
251 |
+
ax_tags.set_title(f"Tags (Threshold ≳ {threshold:.2f})")
|
252 |
|
|
|
253 |
from matplotlib.patches import Patch
|
254 |
legend_elements = [
|
255 |
+
Patch(facecolor=color, label=cat.capitalize())
|
256 |
+
for cat, color in color_map.items()
|
257 |
+
if any(t.startswith(f"[{cat[0].upper() if cat!='copyright' else '©'}]") for t in all_tags)
|
258 |
]
|
259 |
if legend_elements:
|
260 |
+
ax_tags.legend(handles=legend_elements, loc='lower right', fontsize=8)
|
261 |
|
262 |
plt.tight_layout()
|
|
|
|
|
|
|
263 |
buf = io.BytesIO()
|
264 |
plt.savefig(buf, format='png', dpi=100)
|
265 |
plt.close(fig)
|
266 |
buf.seek(0)
|
267 |
+
return Image.open(buf)
|
|
|
268 |
|
269 |
# --- Constants ---
|
270 |
REPO_ID = "cella110n/cl_tagger"
|