# app_fast.py - Vintern-1B Fast Version
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
import time
import traceback

# Setup
device = "cpu"
model = None
tokenizer = None
transform = None

def build_transform(input_size=448):
    """Optimized transform"""
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    
    return T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if hasattr(img, 'mode') and img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
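
# Usage sketch (illustrative, not part of the app's flow): build_transform()
# maps any PIL image to a normalized float tensor of shape (3, 448, 448):
#
#   t = build_transform()
#   print(t(Image.new("RGB", (640, 480))).shape)  # torch.Size([3, 448, 448])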

def load_model():
    """Load Vintern-1B (faster version)"""
    global model, tokenizer, transform
    try:
        print("🚀 Loading Vintern-1B (Fast Version)...")
        
        # Use the lighter v2 checkpoint instead of v3.5 for faster CPU inference
        model_name = "5CD-AI/Vintern-1B-v2"
        
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        
        model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        
        # Inference mode (torch.jit.optimize_for_inference requires a scripted
        # module and would raise on this remote-code model, so eager eval is used)
        model.eval()
        
        transform = build_transform()
        
        print("✅ Fast model loaded!")
        return True
        
    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        return False

def fast_analyze(image):
    """Optimized analysis function"""
    if model is None:
        return "❌ Model chưa sẵn sàng"
    
    try:
        start_time = time.time()
        
        # Quick image processing
        if image is None:
            return "❌ Không có ảnh"
            
        if hasattr(image, 'mode') and image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Single 448×448 patch → pixel_values of shape (1, 3, 448, 448)
        image_tensor = transform(image).unsqueeze(0).to(device)
        
        with torch.no_grad():
            # Short Vietnamese prompt ("Describe briefly:") -- kept in Vietnamese
            # because the model is tuned for Vietnamese input/output
            query = "Mô tả ngắn gọn:"
            
            try:
                result = model.chat(
                    tokenizer,
                    image_tensor,
                    query,
                    generation_config=dict(
                        max_new_tokens=100,  # shorter output → faster
                        do_sample=False,     # greedy decoding → faster
                        num_beams=1          # no beam search → faster
                        # temperature omitted: it is ignored when do_sample=False
                    )
                )
            except Exception:
                # Text-only fallback (note: this path ignores the image)
                inputs = tokenizer(query, return_tensors="pt").to(device)
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=80,
                    do_sample=False,
                    num_beams=1
                )
                result = tokenizer.decode(outputs[0], skip_special_tokens=True)
                result = result.replace(query, "").strip()
            
            processing_time = time.time() - start_time
            
            return f"""**📝 Mô tả nhanh:**
{result}

**⚡ Thời gian:** {processing_time:.1f}s
**🤖 Model:** Vintern-1B-v2 (Optimized)
**💨 Tốc độ:** {1/processing_time:.1f} FPS

---
*Model được tối ưu cho tốc độ - phù hợp real-time*
"""
    
    except Exception as e:
        return f"❌ Lỗi: {str(e)}"

# Load model
print("🚀 Starting Fast Vintern Server...")
model_loaded = load_model()
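
# Warm-up sketch (an assumption, not in the original flow): one dummy inference
# right after loading can absorb first-request latency on CPU:
#
#   if model_loaded:
#       _ = fast_analyze(Image.new("RGB", (448, 448)))  # discard the output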

# Lightweight Gradio interface
with gr.Blocks(
    title="Vintern-1B Fast",
    theme=gr.themes.Base(),
) as demo:
    
    gr.Markdown("# ⚡ Vintern-1B Fast - Tốc Độ Cao")
    
    if model_loaded:
        gr.Markdown("✅ **Model ready!** Optimized for speed and real-time use.")
    else:
        gr.Markdown("❌ **Model failed to load.** Check the server logs for details.")
    
    with gr.Row():
        image_input = gr.Image(type="pil", label="📤 Upload Image")
        result_output = gr.Textbox(
            label="📋 Kết Quả", 
            lines=8,
            show_copy_button=True
        )
    
    # Auto-analyze on upload
    image_input.change(
        fn=fast_analyze,
        inputs=image_input,
        outputs=result_output
    )
    
    gr.Markdown("""
    ### ⚡ Tối ưu cho tốc độ:
    - **Model nhẹ**: Vintern-1B-v2 (~1.5GB)
    - **Fast generation**: Greedy decode, short output
    - **Optimized**: JIT compilation, no beam search
    - **Real-time ready**: ~2-5 giây/ảnh
    """)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
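
# Run: python app_fast.py, then open http://localhost:7860 in a browser.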