import os

# Use a writable temporary cache directory for downloads; set this before
# importing transformers so the cache location actually takes effect
os.environ['HF_HOME'] = '/tmp/hf_cache'

import gradio as gr
import requests
import torch
from io import BytesIO
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Candidate labels for zero-shot classification
fashion_items = ['top', 'trousers', 'jumper']

# Load model and processor with explicit meta tensor handling;
# force CPU usage to avoid device mapping issues
model_name = 'Marqo/marqo-fashionSigLIP'
device = torch.device('cpu')
# Targeted patching of open_clip to prevent meta tensor issues
try:
    import open_clip
    import torch.nn as nn

    # Store the original methods so the patches can fall back to them
    original_to = nn.Module.to
    original_set_model_device_and_precision = open_clip.factory._set_model_device_and_precision

    # Patch the problematic _set_model_device_and_precision function
    def patched_set_model_device_and_precision(model, device, precision, is_timm_model):
        # Force the device to CPU and prefer to_empty over to, since
        # .to() cannot materialize tensors that live on the meta device
        cpu_device = torch.device('cpu')
        if hasattr(model, 'to_empty'):
            model.to_empty(device=cpu_device)
        else:
            # Fall back to the original method, pinned to CPU
            try:
                original_to(model, device=cpu_device)
            except Exception:
                # If that fails, move parameters individually
                for param in model.parameters():
                    if param.device != cpu_device:
                        param.data = param.data.to(cpu_device)
                        if param.grad is not None:
                            param.grad.data = param.grad.data.to(cpu_device)

    # Apply the patch
    open_clip.factory._set_model_device_and_precision = patched_set_model_device_and_precision

    # Also patch Module.to so meta tensors are materialized via to_empty
    def patched_to(self, *args, **kwargs):
        # Check whether any parameter still lives on the meta device
        if hasattr(self, 'parameters'):
            for param in self.parameters():
                if param.device.type == 'meta':
                    # Use to_empty instead of to for meta tensors
                    if hasattr(self, 'to_empty'):
                        return self.to_empty(device=torch.device('cpu'))
                    # Otherwise recreate parameters and buffers with the
                    # same shapes on the CPU
                    cpu_device = torch.device('cpu')
                    for name, p in self.named_parameters(recurse=False):
                        if p.device.type == 'meta':
                            new_param = torch.empty_like(p, device=cpu_device)
                            setattr(self, name, torch.nn.Parameter(new_param))
                    for name, buf in self.named_buffers(recurse=False):
                        if buf.device.type == 'meta':
                            new_buffer = torch.empty_like(buf, device=cpu_device)
                            setattr(self, name, new_buffer)
                    return self
        # No meta tensors: defer to the original method
        return original_to(self, *args, **kwargs)

    # Apply the patch
    nn.Module.to = patched_to
except Exception as e:
    print(f"Could not patch open_clip: {e}")
# Load the model with the patched open_clip in place
try:
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float32
    )
    model = model.to(device)
except Exception as e:
    print(f"Model loading failed: {e}")
    # Fallback: retry without an explicit dtype
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    model = model.to(device)

model.eval()  # inference only; disables dropout and similar layers
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
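# Optional cleanup (an assumption, not part of the original app): once the
# model is safely on CPU the global patches are no longer needed, and
# restoring the originals keeps .to() behavior intact for any other modules
# created later in the process.
try:
    nn.Module.to = original_to
    open_clip.factory._set_model_device_and_precision = original_set_model_device_and_precision
except NameError:
    # The patch block above failed, so there is nothing to restore
    pass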
# Precompute normalized text embeddings once at startup; they are reused
# for every incoming image request
with torch.no_grad():
    processed_texts = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,  # truncate text to fit the model's input size
        padding=True      # pad shorter sequences to a common length
    )['input_ids']
    text_features = model.get_text_features(processed_texts)
    # L2-normalize so the dot product below is a cosine similarity
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
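# Sanity check (a hedged assumption about the embedding layout rather than
# part of the original app): one embedding row per candidate label
assert text_features.shape[0] == len(fashion_items)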
# Prediction function: fetch the image, embed it, and score it against
# the precomputed text embeddings
def predict_from_url(url):
    if not url:
        return {"Error": "Please input a URL"}
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    except Exception as e:
        return {"Error": f"Failed to load image: {str(e)}"}
    processed_image = processor(images=image, return_tensors="pt")['pixel_values']
    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Scaled cosine similarities -> probabilities over the labels
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
    return {fashion_items[i]: float(text_probs[0, i]) for i in range(len(fashion_items))}
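# Example usage (the URL is hypothetical, shown for illustration only):
#   predict_from_url("https://example.com/blue-jumper.jpg")
# returns a {label: probability} dict over fashion_items.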
# Gradio interface
demo = gr.Interface(
    fn=predict_from_url,
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=gr.Label(label="Classification Results"),
    title="Fashion Item Classifier"
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()