Spaces:
Sleeping
Sleeping
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face Hub checkpoints offered in the model dropdown; the first entry
# is the dropdown's default selection.
model_names = [
    "0-ma/beit-geometric-shapes-base",
    "0-ma/vit-geometric-shapes-tiny",
    "0-ma/vit-geometric-shapes-base",
    "0-ma/swin-geometric-shapes-tiny",
    "0-ma/mobilenet-v2-geometric-shapes",
    "0-ma/focalnet-geometric-shapes-tiny",
    "0-ma/efficientnet-b2-geometric-shapes",
    "0-ma/mit-b0-geometric-shapes",
    "0-ma/resnet-geometric-shapes",
]

# Display names indexed by class position: logit i is reported as labels[i].
# NOTE(review): this order must match every checkpoint's class-index order
# (its id2label config) — confirm against the model cards.
labels = [
    'None',
    'Circle',
    'Triangle',
    'Square',
    'Pentagone',
    'Hexagone'
]

example_dir = "./example"
# os.listdir() order is arbitrary, so sort for a stable example gallery; skip
# hidden entries (e.g. .DS_Store) and sub-directories, which are not usable
# example images.
example_images = [
    os.path.join(example_dir, example_image)
    for example_image in sorted(os.listdir(example_dir))
    if not example_image.startswith(".")
    and os.path.isfile(os.path.join(example_dir, example_image))
]

# Load every preprocessor and model once at startup so switching models in the
# UI needs no reload; keyed by the Hub identifier shown in the dropdown.
feature_extractors = {model_name: AutoImageProcessor.from_pretrained(model_name) for model_name in model_names}
classification_models = {model_name: AutoModelForImageClassification.from_pretrained(model_name) for model_name in model_names}
def predict(image, selected_model):
    """Classify *image* with the model currently selected in the dropdown.

    Args:
        image: Image from the Gradio image component (``type="pil"``), or
            ``None`` when the component is empty.
        selected_model: Hub identifier of the model; must be a key of
            ``feature_extractors`` and ``classification_models``.

    Returns:
        ``None`` when there is no image, which clears the label output.
        Otherwise, a dict mapping label name -> score for every class whose
        raw logit is strictly positive. Scores are the positive logits
        divided by their sum (not a softmax). The dict is empty when no
        logit is positive.
    """
    if image is None:
        return None

    feature_extractor = feature_extractors[selected_model]
    model = classification_models[selected_model]

    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs['logits'].cpu().numpy()[0]

    # Clip negatives into a new array so `logits` itself stays untouched.
    positive = np.clip(logits, 0, None)
    total = positive.sum()
    if total <= 0:
        # No positive logit: nothing to report, and dividing would be 0/0.
        return {}
    scores = positive / total

    # zip() stops at the shorter sequence, so a model with fewer outputs than
    # `labels` cannot raise IndexError.
    return {
        label: float(score)
        for label, logit, score in zip(labels, logits, scores)
        if logit > 0
    }
title = "Geometric Shape Classifier"
description = "Select a model and upload an image to classify geometric shapes."

# Page layout: heading, model picker, image input with clickable examples,
# and the label component that shows predict()'s result.
with gr.Blocks() as demo:
    gr.Markdown("# " + title)
    gr.Markdown(description)

    model_dropdown = gr.Dropdown(
        choices=model_names,
        label="Select Model",
        value=model_names[0],
    )
    image_input = gr.Image(type="pil")
    gr.Examples(
        examples=example_images,
        inputs=image_input,
        label="Click on an example image to test",
    )
    output = gr.Label(label="Classification Result")

    # Re-classify whenever either the image or the chosen model changes;
    # listeners are registered image first, then dropdown.
    for trigger in (image_input, model_dropdown):
        trigger.change(
            fn=predict,
            inputs=[image_input, model_dropdown],
            outputs=output,
        )

demo.launch()