Spaces:
Sleeping
Sleeping
| import subprocess | |
| import sys | |
| import os | |
def install_requirements():
    """Install the pinned runtime dependencies with pip.

    All pins are passed to a single ``pip install`` invocation so pip
    resolves them together. Installing the packages one at a time with
    ``--force-reinstall`` also reinstalls each package's dependencies, so a
    later install can replace a version pinned by an earlier one.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    packages = [
        "numpy==1.24.3",
        "torch==2.0.1",
        "torchvision==0.15.2",
        "Pillow==9.5.0",
        "gradio==3.50.2",
    ]
    # Use the running interpreter's pip so the packages land in the same
    # environment that executes this script.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--force-reinstall", *packages]
    )
# Install the pinned dependencies at import time, before the third-party
# imports that follow are executed.
install_requirements()
| import traceback | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import gradio as gr | |
class ModifiedLargeNet(nn.Module):
    """Small CNN that maps a batch of 3-channel images to 3 class logits.

    Layout: two 5x5 convolutions (3 -> 5 -> 10 channels), each followed by
    ReLU and 2x2 max pooling, then two fully connected layers (-> 32 -> 3).
    ``fc1`` expects 10 * 29 * 29 features per sample, which is what the
    convolution/pooling stack produces for 128x128 inputs
    (128 -> 124 -> 62 -> 58 -> 29).
    """

    def __init__(self):
        super().__init__()
        self.name = "modified_large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 3)

    def forward(self, x):
        """Return raw (un-normalised) class scores of shape (batch, 3).

        Args:
            x: Float tensor of shape (batch, 3, 128, 128).

        Raises:
            RuntimeError: If the spatial size does not yield 10 * 29 * 29
                features per sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Reshaping with
        # view(-1, 10 * 29 * 29) would silently fold a wrong-sized input into
        # extra batch rows whenever the element count happens to divide
        # evenly; keeping the batch dimension makes fc1 raise instead.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Build the network and load its trained weights once, at import time.
# If anything fails, ``model`` is bound to None instead of being left
# undefined or pointing at a randomly initialised network. predict() then
# hits its own error handling and returns the all-zero result rather than
# silently scoring images with untrained weights.
try:
    model = ModifiedLargeNet()
    # map_location forces the checkpoint tensors onto the CPU.
    state_dict = torch.load("modified_large_net.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    print("Model loaded successfully")
    model.eval()
except Exception as e:
    model = None
    print(f"Error loading model: {str(e)}")
    traceback.print_exc()
# torchvision preprocessing pipeline: resize to 128x128, convert to a tensor,
# then normalise each channel with the given mean/std (presumably the
# standard ImageNet channel statistics).
# NOTE(review): nothing in the visible code references ``transform``;
# predict() uses custom_transform() instead. Confirm whether it is still needed.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def custom_transform(pil_image):
    """Convert an image to a normalised float32 tensor in (C, H, W) layout.

    Pixel values are scaled from [0, 255] to [0, 1] and then normalised per
    channel as ``(value - mean) / std``.

    Args:
        pil_image: A PIL image, or any array-like of shape (H, W, 3) holding
            0-255 pixel values. PIL images whose mode is not "RGB" are
            converted first, so grayscale, palette and RGBA inputs no longer
            fail in the transpose or in the 3-channel normalisation.

    Returns:
        torch.Tensor of dtype float32 and shape (3, H, W).
    """
    # Objects without a ``mode`` attribute (e.g. plain arrays) are used as-is.
    if getattr(pil_image, "mode", "RGB") != "RGB":
        pil_image = pil_image.convert("RGB")

    pixels = np.array(pil_image)
    # (H, W, C) -> (C, H, W), then scale to [0, 1].
    chw = torch.from_numpy(pixels.transpose((2, 0, 1))).float() / 255.0

    # Per-channel statistics, shaped (3, 1, 1) so they broadcast over H and W.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return (chw - mean) / std
def process_image(image):
    """Bring an input image into the 128x128 RGB form used for prediction.

    Args:
        image: A PIL image, a NumPy array (cast to uint8 and wrapped with
            ``Image.fromarray``), or None.

    Returns:
        The resized RGB PIL image, or None when ``image`` is None or any
        step raises (the error and its traceback are printed).
    """
    if image is None:
        return None

    try:
        # Arrays are wrapped into a PIL image; anything else is used as given.
        source = (
            Image.fromarray(image.astype('uint8'))
            if isinstance(image, np.ndarray)
            else image
        )
        rgb = source if source.mode == 'RGB' else source.convert('RGB')
        resized = rgb.resize((128, 128), Image.Resampling.LANCZOS)

        print(f"Processed image size: {resized.size}")
        print(f"Processed image mode: {resized.mode}")
        print(f"Image type: {type(resized)}")
        return resized
    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        traceback.print_exc()
        return None
def predict(image):
    """Classify an image and return one score per tool class.

    Args:
        image: The image to classify, or None. It is handed to
            process_image() and then custom_transform().

    Returns:
        dict mapping "Rope", "Hammer" and "Other" to floats. On success the
        values are the softmax of the model's outputs. Every value is 0.0
        when ``image`` is None, preprocessing yields None, or any step
        raises (errors are printed together with their traceback).
    """
    labels = ["Rope", "Hammer", "Other"]

    def all_zero():
        # A fresh dict per call, matching the fallback returned on every
        # failure path.
        return dict.fromkeys(labels, 0.0)

    if image is None:
        return all_zero()

    try:
        prepared = process_image(image)
        if prepared is None:
            return all_zero()

        # Stage 1: image -> normalised tensor with a leading batch dimension.
        try:
            batch = custom_transform(prepared).unsqueeze(0)
            print(f"Input tensor shape: {batch.shape}")
            print(f"Tensor dtype: {batch.dtype}")
            print(f"Tensor device: {batch.device}")
        except Exception as e:
            print(f"Error in tensor conversion: {str(e)}")
            traceback.print_exc()
            return all_zero()

        # Stage 2: forward pass without gradient tracking, then softmax.
        try:
            with torch.no_grad():
                logits = model(batch)
                print(f"Raw outputs: {logits}")
                probs = F.softmax(logits, dim=1)[0].cpu().numpy()
                print(f"Probabilities: {probs}")
                scores = dict(zip(labels, map(float, probs)))
                print(f"Final results: {scores}")
                return scores
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            traceback.print_exc()
            return all_zero()
    except Exception as e:
        print(f"Prediction error: {str(e)}")
        traceback.print_exc()
        return all_zero()
# Gradio interface: a single image upload wired to predict(); the returned
# {class name: score} dict is rendered by a Label component.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),  # type="pil": ask Gradio to pass the upload as a PIL image
    outputs=gr.Label(num_top_classes=3),  # show up to three classes
    title="Mechanical Tools Classifier",
    description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
)

# Launch the interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    interface.launch()