# Hugging Face cache paths must be set before `transformers` is imported;
# /tmp is used because it is typically writable in containerized deployments.
import os
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface/transformers"
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.makedirs("/tmp/.cache/huggingface/transformers", exist_ok=True)

import asyncio
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

class AttributeExtractor:
    def __init__(self):
        self.model_name = "Salesforce/blip-image-captioning-base"
        self.processor = None
        self.model = None
        self._load_model()
        
        # Define attribute patterns for text analysis
        self.style_patterns = {
            "formal": ["suit", "blazer", "dress shirt", "tie", "formal", "business", "elegant"],
            "casual": ["t-shirt", "jeans", "sneakers", "hoodie", "casual", "relaxed", "comfortable", "leggings"],
            "sports": ["athletic", "sports", "gym", "workout", "running", "training"]
        }
        
        self.texture_patterns = {
            "cotton": ["cotton", "soft", "comfortable"],
            "denim": ["denim", "jeans", "rugged"],
            "silk": ["silk", "smooth", "shiny", "lustrous", "leggings", "velvet"],
            "wool": ["wool", "warm", "thick"],
            "leather": ["leather", "tough", "durable"],
            "synthetic": ["polyester", "synthetic", "artificial"]
        }
    
    def _load_model(self):
        """Load the BLIP model for image captioning"""
        try:
            print("Loading BLIP model for attribute extraction...")
            self.processor = BlipProcessor.from_pretrained(self.model_name)
            self.model = BlipForConditionalGeneration.from_pretrained(self.model_name)
            self.model.eval()
            print("BLIP model loaded successfully!")
        except Exception as e:
            print(f"Error loading BLIP model: {e}")
            raise
    
    async def extract_attributes(self, image: Image.Image) -> Dict[str, Any]:
        """Extract clothing attributes from image"""
        try:
            loop = asyncio.get_running_loop()
            
            # Generate multiple captions with different prompts
            tasks = [
                loop.run_in_executor(None, self._generate_caption, image, "a photo of"),
                loop.run_in_executor(None, self._generate_caption, image, "clothing style:"),
                loop.run_in_executor(None, self._generate_caption, image, "fabric texture:")
            ]
            
            captions = await asyncio.gather(*tasks)
            
            # Analyze captions to extract attributes
            attributes = self._analyze_captions(captions)
            return attributes
            
        except Exception as e:
            print(f"Attribute extraction error: {e}")
            return {"style": "unknown", "formality": "unknown", "texture": "unknown"}
    
    def _generate_caption(self, image: Image.Image, prompt: str = "") -> str:
        """Generate caption for the image"""
        try:
            if prompt:
                inputs = self.processor(image, prompt, return_tensors="pt")
            else:
                inputs = self.processor(image, return_tensors="pt")
            
            with torch.no_grad():
                out = self.model.generate(**inputs, max_length=50, num_beams=4)
                caption = self.processor.decode(out[0], skip_special_tokens=True)
            
            return caption.lower()
        except Exception as e:
            print(f"Caption generation error: {e}")
            return ""
    
    def _analyze_captions(self, captions: List[str]) -> Dict[str, Any]:
        """Analyze captions to extract structured attributes"""
        combined_text = " ".join(captions).lower()
        
        # Determine style/formality
        formal_score = sum(1 for word in self.style_patterns["formal"] if word in combined_text)
        casual_score = sum(1 for word in self.style_patterns["casual"] if word in combined_text)
        sports_score = sum(1 for word in self.style_patterns["sports"] if word in combined_text)
        
        if formal_score > casual_score and formal_score > sports_score:
            style = "formal"
            formality = "formal"
        elif sports_score > casual_score:
            style = "athletic"
            formality = "casual"
        else:
            style = "casual"
            formality = "casual"
        
        # Determine texture
        texture_scores = {}
        for texture, patterns in self.texture_patterns.items():
            texture_scores[texture] = sum(1 for word in patterns if word in combined_text)
        
        if max(texture_scores.values()) > 0:
            detected_texture = max(texture_scores, key=texture_scores.get)
        else:
            detected_texture = "unknown"
        
        return {
            "style": style,
            "formality": formality,
            "texture": detected_texture,
            "confidence": 0.8,  # fixed heuristic value; keyword matching yields no model-based score
            "raw_captions": captions,
            "detected_keywords": [
                word for word in combined_text.split()
                if any(word in patterns for patterns in self.style_patterns.values())
            ]
        }
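

# Minimal usage sketch (assumption: the image path and this script entry point
# are illustrative and not part of the original module).
if __name__ == "__main__":
    async def _demo() -> None:
        extractor = AttributeExtractor()  # loads BLIP weights on first run
        image = Image.open("sample.jpg").convert("RGB")  # hypothetical local file
        attributes = await extractor.extract_attributes(image)
        # Result shape (values depend on the image), e.g.:
        # {"style": "casual", "formality": "casual", "texture": "denim", ...}
        print(attributes)

    asyncio.run(_demo())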