File size: 3,982 Bytes
56b0bc9 b1caa99 98632cb 6671403 56b0bc9 5a36ad5 b1caa99 1b20ee8 5a36ad5 1b20ee8 56b0bc9 1b20ee8 56b0bc9 1b20ee8 98632cb 568c509 1b20ee8 23d6bd9 0b18f6e 23d6bd9 0b18f6e 568c509 1b20ee8 0b18f6e 1b20ee8 0b18f6e 98632cb 6671403 98632cb 6671403 98632cb b1caa99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import gradio as gr
import glob
import time
import random
import requests
import numpy as np
# Import necessary libraries
from torchvision import models, transforms
from PIL import Image
import torch
# Load the pre-trained ResNet-50 once, at import time, and put it in
# evaluation mode for inference.
# torchvision >= 0.13 selects checkpoints through the `weights=` enum;
# `pretrained=True` is the deprecated spelling of the same IMAGENET1K_V1
# checkpoint and is kept only as a fallback for older torchvision releases
# that do not provide `ResNet50_Weights`.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()

# Image preprocessing expected by the model: resize the shorter side to 256,
# center-crop to 224x224, convert to a tensor, then normalize each channel
# with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def download_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    path="imagenet_classes.txt",
    *,
    timeout=30,
):
    """Download the ImageNet class-label list and save it to ``path``.

    Failures are reported with a printed message rather than raised, so the
    caller decides what to do when the file is still missing.

    Args:
        url: Location of the label file to fetch.
        path: Local file the downloaded bytes are written to.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection, which would hang application start-up.

    Returns:
        True if the file was downloaded and written, False otherwise.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors / timeouts are handled like HTTP failures:
        # report and let the caller decide, instead of crashing here.
        print(f"Failed to download {path}: {exc}")
        return False

    if response.status_code != 200:
        print(f"Failed to download {path} (HTTP {response.status_code})")
        return False

    with open(path, "wb") as f:
        f.write(response.content)
    print(f"{path} downloaded successfully.")
    return True
# Make sure the class-label file is available locally, downloading it on first run.
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()
if not os.path.exists("imagenet_classes.txt"):
    # The downloader only prints on failure; stop here with an actionable
    # message rather than a bare "No such file" error from open() below.
    raise FileNotFoundError(
        "imagenet_classes.txt is missing and could not be downloaded; "
        "place the ImageNet class list in the working directory and restart."
    )

# Load class labels: one name per line, where line i is the label for
# model output index i.
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
# ImageNet-1k output indices that are birds. The class index is the reliable
# identity. Label strings are ambiguous: "crane" names both the bird and the
# construction machine. Substring matching on labels also misfires: "cock"
# matches "cocker spaniel", and "bird" matches "birdhouse".
#   7-24:    cock ... great grey owl
#   80-100:  black grouse ... black swan
#   127-146: white stork ... albatross
BIRD_CLASS_INDICES = frozenset([*range(7, 25), *range(80, 101), *range(127, 147)])


def classify_image(image):
    """Classify an image with the pre-trained model and report whether it is a bird.

    Args:
        image: The uploaded image as an array accepted by
            ``PIL.Image.fromarray`` (what Gradio's ``"image"`` input passes),
            or ``None`` when the form is submitted without an image.

    Returns:
        A sentence naming the top-1 ImageNet label, whether that class is a
        bird, and the model's softmax confidence for it. If no image was
        given, returns a prompt asking the user to upload one.
    """
    if image is None:
        return "Please upload an image to classify."

    print("Classifying image...")
    # Force 3-channel RGB (drops alpha, expands greyscale), preprocess,
    # and add a leading batch dimension of 1.
    img = Image.fromarray(image).convert("RGB")
    batch_t = transform(img).unsqueeze(0)

    with torch.no_grad():
        output = model(batch_t)

    # Top-1 class and its probability; argmax of the softmax equals argmax
    # of the raw logits, so one pass gives both.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    confidence_score, predicted = torch.max(probabilities, dim=0)
    class_index = predicted.item()
    classification = labels[class_index]
    confidence_percentage = f"{confidence_score.item():.2%}"

    if class_index in BIRD_CLASS_INDICES:
        return f"This is a bird! Specifically, it looks like a {classification}. Model confidence: {confidence_percentage}"
    return f"This is not a bird. It appears to be a {classification}. Model confidence: {confidence_percentage}"
# Gather the bundled example images. The glob output is sorted in place so
# the gallery order is stable, and each path is wrapped in its own list
# because Gradio expects one value per input component for every example row.
example_files = glob.glob("examples/*.png")
example_files.sort()
examples = [[png_path] for png_path in example_files]
# Create the Gradio interface: one image in, one line of text out.
# NOTE(review): the description's "LLM Diffusion models" wording does not match
# the ResNet-50 classifier this app loads; presumably intentional humor — confirm.
demo = gr.Interface(
    fn=classify_image,  # The function to run
    inputs="image",  # The input type is an image
    outputs="text",  # The output type is text
    examples=examples,  # Example images shown under the input
    title="Is this a picture of a bird?",  # Title of the app
    description=(  # Description of the app
        "Uses the latest in machine learning LLM Diffusion models to analyze "
        "every pixel (twice) and to determine conclusively if it is a picture "
        "of a bird"
    ),
)

# Launch the app
demo.launch()