Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
# fashionclip_segmented_app.py
|
2 |
import gradio as gr
|
3 |
from PIL import Image
|
4 |
import torch
|
@@ -16,17 +15,9 @@ color_prompts = ["a red garment", "a blue garment", "a black garment", "a white
|
|
16 |
pattern_prompts = ["a plain shirt", "a striped shirt", "a floral shirt", "a checked shirt", "a dotted shirt", "an abstract patterned shirt"]
|
17 |
fit_prompts = ["a slim fit shirt", "an oversized top", "a regular fit shirt", "a cropped shirt", "a shirt with a crew neck", "a shirt with a v-neck", "a shirt with a round neckline"]
|
18 |
|
19 |
-
# Funktion zur Hintergrundentfernung
|
20 |
-
def remove_background(image: Image.Image) -> Image.Image:
|
21 |
-
image_np = np.array(image)
|
22 |
-
image_no_bg_np = remove(image_np)
|
23 |
-
return Image.fromarray(image_no_bg_np)
|
24 |
-
|
25 |
# Hilfsfunktion: finde das passendste Prompt für eine Gruppe
|
26 |
def predict_best_prompt(image, prompts):
|
27 |
-
|
28 |
-
image_clean = remove_background(image)
|
29 |
-
inputs = processor(text=prompts, images=[image_clean], return_tensors="pt", padding=True)
|
30 |
with torch.no_grad():
|
31 |
outputs = model(**inputs)
|
32 |
logits_per_image = outputs.logits_per_image
|
@@ -37,30 +28,36 @@ def predict_best_prompt(image, prompts):
|
|
37 |
# Hauptfunktion für die App
|
38 |
def analyze_image(image):
|
39 |
if image is None:
|
40 |
-
return "⚠️ Please upload or take a picture first."
|
|
|
|
|
|
|
|
|
41 |
|
|
|
42 |
results = {}
|
43 |
-
results["Category"], cat_score = predict_best_prompt(
|
44 |
-
results["Color"], color_score = predict_best_prompt(
|
45 |
-
results["Pattern"], pattern_score = predict_best_prompt(
|
46 |
-
results["Fit"], fit_score = predict_best_prompt(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
return
|
49 |
-
Category: {results['Category']} ({cat_score:.2f})\n
|
50 |
-
Color: {results['Color']} ({color_score:.2f})\n
|
51 |
-
Pattern: {results['Pattern']} ({pattern_score:.2f})\n
|
52 |
-
Fit: {results['Fit']} ({fit_score:.2f})
|
53 |
-
"""
|
54 |
|
55 |
# Gradio UI erstellen
|
56 |
iface = gr.Interface(
|
57 |
fn=analyze_image,
|
58 |
inputs=gr.Image(type="pil", label="Upload or take a picture", sources=["upload", "webcam"]),
|
59 |
-
outputs="
|
60 |
title="Fashion Attribute Predictor (Prototype 2 + Segmentation)",
|
61 |
-
description="
|
62 |
)
|
63 |
|
64 |
-
# App starten
|
65 |
if __name__ == "__main__":
|
66 |
iface.launch()
|
|
|
|
|
1 |
import gradio as gr
|
2 |
from PIL import Image
|
3 |
import torch
|
|
|
15 |
pattern_prompts = ["a plain shirt", "a striped shirt", "a floral shirt", "a checked shirt", "a dotted shirt", "an abstract patterned shirt"]
|
16 |
fit_prompts = ["a slim fit shirt", "an oversized top", "a regular fit shirt", "a cropped shirt", "a shirt with a crew neck", "a shirt with a v-neck", "a shirt with a round neckline"]
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
# Hilfsfunktion: finde das passendste Prompt für eine Gruppe
|
19 |
def predict_best_prompt(image, prompts):
|
20 |
+
inputs = processor(text=prompts, images=[image], return_tensors="pt", padding=True)
|
|
|
|
|
21 |
with torch.no_grad():
|
22 |
outputs = model(**inputs)
|
23 |
logits_per_image = outputs.logits_per_image
|
|
|
28 |
# Main entry point for the app: segment the garment, then classify it.
def analyze_image(image):
    """Remove the background from *image* once, then predict fashion
    attributes (category, color, pattern, fit) with FashionCLIP.

    Returns a (segmented_image, text_result) tuple; both are None-safe
    for the empty-input case.
    """
    # Guard clause: nothing to analyze without an uploaded picture.
    if image is None:
        return None, "⚠️ Please upload or take a picture first."

    # Strip the background exactly once; every prediction reuses the cutout.
    segmented_image = Image.fromarray(remove(np.array(image)))

    # One CLIP prediction per attribute group, all on the segmented image.
    attribute_groups = [
        ("Category", category_prompts),
        ("Color", color_prompts),
        ("Pattern", pattern_prompts),
        ("Fit", fit_prompts),
    ]
    lines = []
    for label, prompts in attribute_groups:
        best_prompt, score = predict_best_prompt(segmented_image, prompts)
        lines.append(f"{label}: {best_prompt} ({score:.2f})")

    return segmented_image, "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
# Build the Gradio UI: one image in, segmented image plus prediction text out.
interface_input = gr.Image(
    type="pil",
    label="Upload or take a picture",
    sources=["upload", "webcam"],
)
interface_outputs = [
    gr.Image(label="Segmentiertes Bild"),
    gr.Textbox(label="Vorhersage"),
]
iface = gr.Interface(
    fn=analyze_image,
    inputs=interface_input,
    outputs=interface_outputs,
    title="Fashion Attribute Predictor (Prototype 2 + Segmentation)",
    description="Das Modell entfernt zuerst den Hintergrund (Segmentierung) und erkennt dann Attribute mit FashionCLIP.",
)
|
61 |
|
|
|
62 |
# Entry point: launch the Gradio web server when run as a script.
if __name__ == "__main__":
    iface.launch()
|