File size: 3,982 Bytes
56b0bc9
b1caa99
98632cb
6671403
 
56b0bc9
5a36ad5
b1caa99
1b20ee8
 
 
 
 
 
 
 
5a36ad5
1b20ee8
 
 
 
 
 
 
 
56b0bc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b20ee8
56b0bc9
1b20ee8
 
98632cb
568c509
1b20ee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23d6bd9
 
 
 
 
 
 
 
 
 
0b18f6e
23d6bd9
0b18f6e
 
 
568c509
1b20ee8
0b18f6e
1b20ee8
0b18f6e
 
98632cb
6671403
98632cb
 
 
 
 
 
 
 
6671403
 
98632cb
 
b1caa99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import gradio as gr
import glob
import time
import random
import requests
import numpy as np

# Import necessary libraries
from torchvision import models, transforms
from PIL import Image
import torch

# Load the pre-trained ResNet-50 once at import time and put it in evaluation
# mode for inference.
# torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
# ``weights`` enum; ``ResNet50_Weights.IMAGENET1K_V1`` is the checkpoint that
# ``pretrained=True`` loads, so predictions are unchanged. Older torchvision
# releases have no weights enum, so fall back to the legacy flag there.
if hasattr(models, "ResNet50_Weights"):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()

# Preprocessing for the ImageNet-trained model: resize the shorter side to 256,
# center-crop to 224x224, convert to a float tensor, then normalise each
# channel with the ImageNet mean / standard deviation.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to download imagenet_classes.txt
def download_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    dest="imagenet_classes.txt",
    timeout=30,
):
    """Download the ImageNet class-label list and save it to *dest*.

    This is best-effort: failures are reported with a printed message rather
    than raised, leaving the caller to decide what to do if the file is still
    missing afterwards.

    Args:
        url: Location of the label file to fetch.
        dest: Local path the downloaded bytes are written to (relative paths
            resolve against the current working directory).
        timeout: Seconds to wait for the server. Without a timeout
            ``requests`` can block indefinitely and hang application start-up.

    Returns:
        True if the file was downloaded and written, False otherwise.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors / timeouts are treated like a bad status code.
        print(f"Failed to download imagenet_classes.txt: {exc}")
        return False

    if response.status_code != 200:
        print("Failed to download imagenet_classes.txt")
        return False

    with open(dest, "wb") as f:
        f.write(response.content)
    print("imagenet_classes.txt downloaded successfully.")
    return True

# Fetch the label file on first run; later runs reuse the copy in the current
# working directory.
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()

# Load class labels, one per line. Position in this list is used as the model's
# output index, so the file is assumed to follow ImageNet-1k class order.
try:
    with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f]
except FileNotFoundError as err:
    # Reached when the download above failed; say so instead of a bare
    # "No such file" traceback.
    raise FileNotFoundError(
        "imagenet_classes.txt is missing and could not be downloaded; "
        "place it in the current working directory and restart."
    ) from err

# ImageNet-1k class indices that are birds:
#   7-24    cock .. great grey owl
#   80-100  black grouse .. black swan
#   127-146 white stork .. albatross
# The check uses the predicted index rather than the label text. Substring
# matching on names gave false positives ("cock" in "cocker spaniel" /
# "cockroach", "bird" in "birdhouse", "hen" in "hen-of-the-woods", and the
# construction-machine "crane", which shares its label with the bird). It also
# missed capitalised names such as "African grey" once the label was lowercased.
BIRD_CLASS_INDICES = frozenset([*range(7, 25), *range(80, 101), *range(127, 147)])


def classify_image(image):
    """Classify *image* with ResNet-50 and report whether it shows a bird.

    Args:
        image: The picture supplied by the Gradio ``"image"`` input, passed to
            ``Image.fromarray`` (presumably an H x W x C uint8 NumPy array --
            confirm against the Gradio version in use), or ``None`` when the
            user submits without uploading anything.

    Returns:
        A sentence naming the top-1 ImageNet label, whether it is a bird, and
        the model's softmax confidence formatted as a percentage.
    """
    if image is None:
        return "Please upload an image to classify."

    print("Classifying image...")

    # Force 3 channels (drops alpha / expands greyscale), apply the ImageNet
    # preprocessing, and add a batch dimension of 1.
    img = Image.fromarray(image).convert('RGB')
    batch_t = transform(img).unsqueeze(0)

    with torch.no_grad():
        output = model(batch_t)

    # Top-1 class from the logits, and its probability after softmax.
    index = int(output.argmax(dim=1).item())
    classification = labels[index]
    confidence_score = torch.nn.functional.softmax(output[0], dim=0)[index].item()
    confidence_percentage = f"{confidence_score:.2%}"

    if index in BIRD_CLASS_INDICES:
        return f"This is a bird! Specifically, it looks like a {classification}. Model confidence: {confidence_percentage}"
    return f"This is not a bird. It appears to be a {classification}. Model confidence: {confidence_percentage}"
# Build the example gallery from every PNG in ./examples, sorted so the order is
# stable between runs; Gradio expects one list of input values per example.
example_files = sorted(glob.glob("examples/*.png"))
examples = [[file] for file in example_files]

# Create the Gradio interface: one image in, one text verdict out.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    examples=examples,
    title="Is this a picture of a bird?",
    description="Uses the latest in machine learning LLM Diffusion models to analyze every pixel (twice) and to determine conclusively if it is a picture of a bird",
)

# Launch the app
demo.launch()