import os

# Point the Hugging Face cache at a writable temporary directory before importing
# transformers, since the cache location is read when the library is imported
os.environ['HF_HOME'] = '/tmp/hf_cache'

import gradio as gr
from transformers import AutoModel, AutoProcessor
import torch
import requests
from PIL import Image
from io import BytesIO

fashion_items = ['top', 'trousers', 'jumper']

# Load model and processor with proper meta tensor handling
model_name = 'Marqo/marqo-fashionSigLIP'

# Force CPU usage to avoid device mapping issues
device = torch.device('cpu')

# Targeted patching of open_clip to prevent meta tensor issues
try:
    import open_clip
    import torch.nn as nn

    # Store the original methods so they can still be called from the patches
    original_to = nn.Module.to
    original_set_model_device_and_precision = open_clip.factory._set_model_device_and_precision

    # Patch the problematic _set_model_device_and_precision function
    def patched_set_model_device_and_precision(model, device, precision, is_timm_model):
        # Force the device to CPU and use to_empty instead of to
        cpu_device = torch.device('cpu')
        if hasattr(model, 'to_empty'):
            model.to_empty(device=cpu_device)
        else:
            # Fall back to the original method, but with the CPU device
            try:
                original_to(model, device=cpu_device)
            except Exception:
                # If that fails, move the parameters individually
                for param in model.parameters():
                    if param.device != cpu_device:
                        param.data = param.data.to(cpu_device)
                        if param.grad is not None:
                            param.grad.data = param.grad.data.to(cpu_device)

    # Apply the patch
    open_clip.factory._set_model_device_and_precision = patched_set_model_device_and_precision

    # Also patch Module.to so it handles meta tensors
    def patched_to(self, *args, **kwargs):
        # Check whether we are moving off the meta device
        if hasattr(self, 'parameters'):
            for param in self.parameters():
                if param.device.type == 'meta':
                    # Use to_empty instead of to for meta tensors
                    if hasattr(self, 'to_empty'):
                        return self.to_empty(device=torch.device('cpu'))
                    else:
                        # Recreate the tensors with the same shapes on CPU
                        cpu_device = torch.device('cpu')
                        for name, param in self.named_parameters(recurse=False):
                            if param.device.type == 'meta':
                                new_param = torch.empty_like(param, device=cpu_device)
                                setattr(self, name, torch.nn.Parameter(new_param))
                        for name, buffer in self.named_buffers(recurse=False):
                            if buffer.device.type == 'meta':
                                new_buffer = torch.empty_like(buffer, device=cpu_device)
                                setattr(self, name, new_buffer)
                        return self
        # Fall back to the original method
        return original_to(self, *args, **kwargs)

    # Apply the patch
    nn.Module.to = patched_to

except Exception as e:
    print(f"Could not patch open_clip: {e}")

# Load the model with the patched open_clip to prevent meta tensor issues
try:
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float32
    )
    model = model.to(device)
except Exception as e:
    print(f"Model loading failed: {e}")
    # Fallback: retry loading without forcing the dtype
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    model = model.to(device)

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

# Preprocess and normalize the text labels once at startup
with torch.no_grad():
    processed_texts = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,  # Truncate text to fit the model's input size
        padding=True      # Pad shorter sequences so all are the same length
    )['input_ids']

    text_features = model.get_text_features(processed_texts)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Prediction function
def predict_from_url(url):
    # Check whether the URL is empty
    if not url:
        return {"Error": "Please input a URL"}
    try:
        image = Image.open(BytesIO(requests.get(url, timeout=10).content))
    except Exception as e:
        return {"Error": f"Failed to load image: {str(e)}"}
    processed_image = processor(images=image, return_tensors="pt")['pixel_values']
    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
    return {fashion_items[i]: float(text_probs[0, i]) for i in range(len(fashion_items))}

# Gradio interface
demo = gr.Interface(
    fn=predict_from_url,
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=gr.Label(label="Classification Results"),
    title="Fashion Item Classifier"
)

# Launch the interface
demo.launch()