rajkumarrawal commited on
Commit
cd112fe
·
1 Parent(s): 7decba4

Initial commit

Browse files
Files changed (5) hide show
  1. .gitattributes +0 -35
  2. .gitignore +1 -0
  3. app.py +58 -0
  4. memory.json +4 -0
  5. requirements.txt +9 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModel, AutoProcessor
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+
8
+ fashion_items = ['top', 'trousers', 'jumper']
9
+
10
+ # Load model and processor
11
+ model_name = 'Marqo/marqo-fashionSigLIP'
12
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
13
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
14
+
15
+ # Preprocess and normalize text data
16
+ with torch.no_grad():
17
+ # Ensure truncation and padding are activated
18
+ processed_texts = processor(
19
+ text=fashion_items,
20
+ return_tensors="pt",
21
+ truncation=True, # Ensure text is truncated to fit model input size
22
+ padding=True # Pad shorter sequences so that all are the same length
23
+ )['input_ids']
24
+
25
+ text_features = model.get_text_features(processed_texts)
26
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
27
+
28
+ # Prediction function
29
+ def predict_from_url(url):
30
+ # Check if the URL is empty
31
+ if not url:
32
+ return {"Error": "Please input a URL"}
33
+
34
+ try:
35
+ image = Image.open(BytesIO(requests.get(url).content))
36
+ except Exception as e:
37
+ return {"Error": f"Failed to load image: {str(e)}"}
38
+
39
+ processed_image = processor(images=image, return_tensors="pt")['pixel_values']
40
+
41
+ with torch.no_grad():
42
+ image_features = model.get_image_features(processed_image)
43
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
44
+ text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
45
+
46
+ return {fashion_items[i]: float(text_probs[0, i]) for i in range(len(fashion_items))}
47
+
48
+ # Gradio interface
49
+ demo = gr.Interface(
50
+ fn=predict_from_url,
51
+ inputs=gr.Textbox(label="Enter Image URL"),
52
+ outputs=gr.Label(label="Classification Results"),
53
+ title="Fashion Item Classifier",
54
+ allow_flagging="never"
55
+ )
56
+
57
+ # Launch the interface
58
+ demo.launch()
memory.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "entities": [],
3
+ "relations": []
4
+ }
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ requests
4
+ Pillow
5
+ open_clip_torch
6
+ ftfy
7
+
8
+ # This is only needed for local deployment
9
+ gradio