import os

# Point the Hugging Face cache at a writable temp directory. This must happen
# before transformers/huggingface_hub are imported, or the setting is ignored.
os.environ.setdefault('HF_HOME', '/tmp/hf_cache')

import gradio as gr
import requests
import torch
from io import BytesIO
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Candidate labels for zero-shot classification
fashion_items = ['top', 'trousers', 'jumper']

# Model checkpoint; force CPU to avoid device-mapping issues
model_name = 'Marqo/marqo-fashionSigLIP'
device = torch.device('cpu')

# Targeted patching of open_clip to prevent meta tensor issues: modules built
# on the 'meta' device cannot be moved with .to(), so route them through
# to_empty() instead.
try:
    import open_clip
    import torch.nn as nn

    # Keep references to the original methods for fallback
    original_to = nn.Module.to
    original_set_model_device_and_precision = open_clip.factory._set_model_device_and_precision

    # Replacement for open_clip's private device/precision helper. Note that
    # to_empty() materializes meta tensors with *uninitialized* storage; the
    # real weights are filled in when the checkpoint is loaded afterwards.
    def patched_set_model_device_and_precision(model, device, precision, is_timm_model):
        cpu_device = torch.device('cpu')
        if any(p.device.type == 'meta' for p in model.parameters()):
            model.to_empty(device=cpu_device)
        else:
            # No meta tensors involved: a plain CPU move is safe
            original_to(model, cpu_device)

    # Apply the patch
    open_clip.factory._set_model_device_and_precision = patched_set_model_device_and_precision
    
    # Also patch nn.Module.to so that a module still holding meta tensors
    # falls back to to_empty() instead of raising. to_empty() exists on every
    # nn.Module in PyTorch versions that support the meta device, and it
    # recurses into child modules (a per-attribute rebuild would miss them).
    def patched_to(self, *args, **kwargs):
        if any(p.device.type == 'meta' for p in self.parameters()):
            return self.to_empty(device=torch.device('cpu'))
        # No meta tensors: defer to the original implementation
        return original_to(self, *args, **kwargs)

    # Apply the patch
    nn.Module.to = patched_to
    
except Exception as e:
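    # The private helper may not exist in every open_clip release; if the
    # patch fails, we fall through and load the model unpatched.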
    print(f"Could not patch open_clip: {e}")

# Load the model; the open_clip patches above are in effect during this call
try:
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float32
    )
    model = model.to(device)
except Exception as e:
    print(f"Model loading failed: {e}")
    # Fallback: retry without forcing the dtype
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    model = model.to(device)

model.eval()  # inference only: disable dropout and other training behavior

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
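# Note: the processor bundles the tokenizer and the image transforms that this
# checkpoint expects, so text and images below go through matching pipelines.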

# Precompute normalized text embeddings for the candidate labels
with torch.no_grad():
    processed_texts = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,  # truncate to the model's maximum input length
        padding=True      # pad shorter sequences to a common length
    )['input_ids']

    text_features = model.get_text_features(processed_texts)
    # L2-normalize so that image-text similarity is a plain dot product
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Prediction function: fetch an image by URL and score it against the labels
def predict_from_url(url):
    if not url:
        return {"Error": "Please input a URL"}

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert('RGB')
    except Exception as e:
        return {"Error": f"Failed to load image: {e}"}

    processed_image = processor(images=image, return_tensors="pt")['pixel_values']

    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Scaled cosine similarities -> softmax over the candidate labels
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    return {fashion_items[i]: float(text_probs[0, i]) for i in range(len(fashion_items))}
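
# Quick sanity check (hypothetical URL; the numbers are illustrative only):
#   predict_from_url("https://example.com/striped-jumper.jpg")
#   -> {'top': 0.04, 'trousers': 0.01, 'jumper': 0.95}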

# Gradio interface
demo = gr.Interface(
    fn=predict_from_url,
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=gr.Label(label="Classification Results"),
    title="Fashion Item Classifier"
)

# Launch the interface
demo.launch()
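
# For container or Spaces-style deployment, binding to all interfaces is a
# common option (7860 is Gradio's default port):
# demo.launch(server_name="0.0.0.0", server_port=7860)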