ritz26 commited on
Commit
7ec1256
Β·
1 Parent(s): 965e09e

Improve model perfomance by stored body coordinates...

Browse files
Files changed (4) hide show
  1. Dockerfile +12 -5
  2. app.py +144 -35
  3. requirements.txt +2 -1
  4. templates/index.html +64 -3
Dockerfile CHANGED
@@ -2,12 +2,19 @@ FROM python:3.10.12-slim
2
 
3
  WORKDIR /app
4
 
5
- # Set cache directories
6
  ENV HF_HOME=/app/.cache
7
  ENV MPLCONFIGDIR=/app/.cache
 
 
8
 
9
- # Install dependencies
10
- RUN apt-get update && apt-get install -y libgl1-mesa-glx libglib2.0-0 && apt-get clean
 
 
 
 
 
11
 
12
  # Create cache directory
13
  RUN mkdir -p /app/.cache && chmod -R 777 /app/.cache
@@ -27,5 +34,5 @@ COPY . .
27
  # Set port
28
  ENV PORT=7860
29
 
30
- # Run with Gunicorn
31
- CMD ["gunicorn", "app:app", "--bind", "0.0.0.0:7860", "--workers", "2"]
 
2
 
3
  WORKDIR /app
4
 
5
+ # Set cache directories and performance optimizations
6
  ENV HF_HOME=/app/.cache
7
  ENV MPLCONFIGDIR=/app/.cache
8
+ ENV OMP_NUM_THREADS=4
9
+ ENV TOKENIZERS_PARALLELISM=false
10
 
11
+ # Install dependencies including CUDA support
12
+ RUN apt-get update && apt-get install -y \
13
+ libgl1-mesa-glx \
14
+ libglib2.0-0 \
15
+ libgomp1 \
16
+ && apt-get clean \
17
+ && rm -rf /var/lib/apt/lists/*
18
 
19
  # Create cache directory
20
  RUN mkdir -p /app/.cache && chmod -R 777 /app/.cache
 
34
  # Set port
35
  ENV PORT=7860
36
 
37
+ # Run with Gunicorn with optimized settings
38
+ CMD ["gunicorn", "app:app", "--bind", "0.0.0.0:7860", "--workers", "1", "--threads", "4", "--worker-class", "gthread", "--timeout", "300"]
app.py CHANGED
@@ -1,12 +1,15 @@
1
- from flask import Flask, render_template, request, send_from_directory
2
  from PIL import Image
3
  import os, torch, cv2, mediapipe as mp
4
  from transformers import SamModel, SamProcessor, logging as hf_logging
5
  from torchvision import transforms
6
  from diffusers.utils import load_image
7
  from flask_cors import CORS
 
 
8
 
9
  app= Flask(__name__)
 
10
  CORS(app)
11
 
12
  # Enable Hugging Face detailed logs (shows model download progress)
@@ -25,16 +28,52 @@ if not os.path.exists(OUTPUT_FOLDER):
25
  os.makedirs(OUTPUT_FOLDER, exist_ok=True)
26
 
27
 
28
- # Lazy-load model
29
  model, processor = None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  def load_model():
 
32
  global model, processor
33
  if model is None or processor is None:
34
- print("[INFO] Loading SAM model and processor...")
35
- model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-50", cache_dir="/app/.cache")
36
- processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-50", cache_dir="/app/.cache")
37
- print("[INFO] Model and processor loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  @app.before_request
40
  def log_request_info():
@@ -49,44 +88,92 @@ def health():
49
  def serve_output(filename):
50
  return send_from_directory(OUTPUT_FOLDER, filename)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  @app.route('/', methods=['GET', 'POST'])
53
  def index():
 
54
  print(f"[INFO] Handling {request.method} on /")
55
  if request.method == 'POST':
56
  try:
57
  load_model()
58
-
59
- # Save uploaded images
60
- person_file = request.files['person_image']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  tshirt_file = request.files['tshirt_image']
62
- person_path = os.path.join(UPLOAD_FOLDER, 'person.jpg')
63
  tshirt_path = os.path.join(UPLOAD_FOLDER, 'tshirt.png')
64
- person_file.save(person_path)
65
  tshirt_file.save(tshirt_path)
66
- print(f"[INFO] Saved files to {UPLOAD_FOLDER}")
67
-
68
- # Pose detection
69
- mp_pose = mp.solutions.pose
70
- pose = mp_pose.Pose()
71
- image = cv2.imread(person_path)
72
- if image is None:
73
- return "No image detected."
74
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
75
- results = pose.process(image_rgb)
76
- if not results.pose_landmarks:
77
- return "No pose detected."
78
- height, width, _ = image.shape
79
- landmarks = results.pose_landmarks.landmark
80
- left_shoulder = (int(landmarks[11].x * width), int(landmarks[11].y * height))
81
- right_shoulder = (int(landmarks[12].x * width), int(landmarks[12].y * height))
82
- print(f"[INFO] Shoulder coordinates: {left_shoulder}, {right_shoulder}")
83
-
84
- # SAM model inference
85
  img = load_image(person_path)
86
  new_tshirt = load_image(tshirt_path)
87
- input_points = [[[left_shoulder[0], left_shoulder[1]], [right_shoulder[0], right_shoulder[1]]]]
 
88
  inputs = processor(img, input_points=input_points, return_tensors="pt")
89
- outputs = model(**inputs)
 
 
 
 
 
 
 
90
  masks = processor.image_processor.post_process_masks(
91
  outputs.pred_masks.cpu(),
92
  inputs["original_sizes"].cpu(),
@@ -102,8 +189,16 @@ def index():
102
  img_with_new_tshirt.save(result_path)
103
  print(f"[INFO] Result saved to {result_path}")
104
 
105
- # Serve via dynamic route
106
- return render_template('index.html', result_img='/outputs/result.jpg')
 
 
 
 
 
 
 
 
107
 
108
  except Exception as e:
109
  print(f"[ERROR] {e}")
@@ -111,7 +206,21 @@ def index():
111
 
112
  return render_template('index.html')
113
 
 
 
 
 
 
 
 
 
114
  if __name__ == '__main__':
115
-
 
 
 
 
 
 
116
  print("[INFO] Starting Flask server...")
117
  app.run(debug=True, host='0.0.0.0')
 
1
+ from flask import Flask, render_template, request, send_from_directory, session
2
  from PIL import Image
3
  import os, torch, cv2, mediapipe as mp
4
  from transformers import SamModel, SamProcessor, logging as hf_logging
5
  from torchvision import transforms
6
  from diffusers.utils import load_image
7
  from flask_cors import CORS
8
+ import json
9
+ import time
10
 
11
  app= Flask(__name__)
12
+ app.secret_key = os.environ.get('SECRET_KEY', 'dev-secret-key-change-in-production') # Change this to a random secret key
13
  CORS(app)
14
 
15
  # Enable Hugging Face detailed logs (shows model download progress)
 
28
  os.makedirs(OUTPUT_FOLDER, exist_ok=True)
29
 
30
 
31
+ # Global model variables
32
  model, processor = None, None
33
+ device = None
34
+
35
+ def initialize_model():
36
+ """Initialize model once at startup"""
37
+ global model, processor, device
38
+
39
+ # Determine device
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ print(f"[INFO] Using device: {device}")
42
+
43
+ print("[INFO] Loading SAM model and processor...")
44
+ model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-50", cache_dir="/app/.cache")
45
+ processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-50", cache_dir="/app/.cache")
46
+
47
+ # Move model to device
48
+ model = model.to(device)
49
+ print(f"[INFO] Model and processor loaded successfully on {device}!")
50
 
51
  def load_model():
52
+ """Ensure model is loaded (should already be loaded at startup)"""
53
  global model, processor
54
  if model is None or processor is None:
55
+ print("[WARNING] Model not loaded, initializing now...")
56
+ initialize_model()
57
+
58
+ def warmup_model():
59
+ """Warm up the model with a dummy inference"""
60
+ global model, processor, device
61
+ if model is None or processor is None:
62
+ return
63
+
64
+ print("[INFO] Warming up model...")
65
+ try:
66
+ # Create a dummy image and points for warmup
67
+ dummy_img = Image.new('RGB', (512, 512), color='white')
68
+ dummy_points = [[[256, 256], [300, 300]]]
69
+ inputs = processor(dummy_img, input_points=dummy_points, return_tensors="pt")
70
+ inputs = {k: v.to(device) for k, v in inputs.items()}
71
+
72
+ with torch.no_grad():
73
+ _ = model(**inputs)
74
+ print("[INFO] Model warmup completed!")
75
+ except Exception as e:
76
+ print(f"[WARNING] Model warmup failed: {e}")
77
 
78
  @app.before_request
79
  def log_request_info():
 
88
  def serve_output(filename):
89
  return send_from_directory(OUTPUT_FOLDER, filename)
90
 
91
+ # Route to serve cached person images
92
+ @app.route('/uploads/<filename>')
93
+ def serve_upload(filename):
94
+ return send_from_directory(UPLOAD_FOLDER, filename)
95
+
96
+ def detect_pose_and_get_coordinates(person_path):
97
+ """Extract pose coordinates from person image"""
98
+ mp_pose = mp.solutions.pose
99
+ pose = mp_pose.Pose()
100
+ image = cv2.imread(person_path)
101
+ if image is None:
102
+ raise Exception("No image detected.")
103
+
104
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
105
+ results = pose.process(image_rgb)
106
+ if not results.pose_landmarks:
107
+ raise Exception("No pose detected.")
108
+
109
+ height, width, _ = image.shape
110
+ landmarks = results.pose_landmarks.landmark
111
+ left_shoulder = (int(landmarks[11].x * width), int(landmarks[11].y * height))
112
+ right_shoulder = (int(landmarks[12].x * width), int(landmarks[12].y * height))
113
+
114
+ return left_shoulder, right_shoulder
115
+
116
  @app.route('/', methods=['GET', 'POST'])
117
  def index():
118
+ start_time = time.time()
119
  print(f"[INFO] Handling {request.method} on /")
120
  if request.method == 'POST':
121
  try:
122
  load_model()
123
+
124
+ # Check if we have a cached person image and coordinates
125
+ use_cached_person = 'person_coordinates' in session and 'person_image_path' in session
126
+ person_coordinates = None
127
+ person_path = None
128
+
129
+ if use_cached_person:
130
+ # Use cached person image and coordinates
131
+ person_path = session['person_image_path']
132
+ person_coordinates = session['person_coordinates']
133
+ print(f"[INFO] Using cached person image: {person_path}")
134
+ print(f"[INFO] Using cached coordinates: {person_coordinates}")
135
+ else:
136
+ # Process new person image
137
+ person_file = request.files.get('person_image')
138
+ if not person_file or person_file.filename == '':
139
+ return "No person image provided. Please upload a person image first."
140
+
141
+ person_path = os.path.join(UPLOAD_FOLDER, 'person.jpg')
142
+ person_file.save(person_path)
143
+ print(f"[INFO] Saved new person image to {person_path}")
144
+
145
+ # Detect pose and get coordinates
146
+ left_shoulder, right_shoulder = detect_pose_and_get_coordinates(person_path)
147
+ person_coordinates = {
148
+ 'left_shoulder': left_shoulder,
149
+ 'right_shoulder': right_shoulder
150
+ }
151
+
152
+ # Cache the person image and coordinates
153
+ session['person_image_path'] = person_path
154
+ session['person_coordinates'] = person_coordinates
155
+ print(f"[INFO] Cached person coordinates: {person_coordinates}")
156
+
157
+ # Process garment image
158
  tshirt_file = request.files['tshirt_image']
 
159
  tshirt_path = os.path.join(UPLOAD_FOLDER, 'tshirt.png')
 
160
  tshirt_file.save(tshirt_path)
161
+ print(f"[INFO] Saved garment image to {tshirt_path}")
162
+
163
+ # SAM model inference using cached or new coordinates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  img = load_image(person_path)
165
  new_tshirt = load_image(tshirt_path)
166
+ input_points = [[[person_coordinates['left_shoulder'][0], person_coordinates['left_shoulder'][1]],
167
+ [person_coordinates['right_shoulder'][0], person_coordinates['right_shoulder'][1]]]]
168
  inputs = processor(img, input_points=input_points, return_tensors="pt")
169
+
170
+ # Move inputs to device
171
+ inputs = {k: v.to(device) for k, v in inputs.items()}
172
+
173
+ # Run inference
174
+ with torch.no_grad(): # Disable gradient computation for inference
175
+ outputs = model(**inputs)
176
+
177
  masks = processor.image_processor.post_process_masks(
178
  outputs.pred_masks.cpu(),
179
  inputs["original_sizes"].cpu(),
 
189
  img_with_new_tshirt.save(result_path)
190
  print(f"[INFO] Result saved to {result_path}")
191
 
192
+ # Calculate processing time
193
+ processing_time = time.time() - start_time
194
+ print(f"[PERF] Total processing time: {processing_time:.2f}s")
195
+
196
+ # Serve via dynamic route with cached person info
197
+ return render_template('index.html',
198
+ result_img='/outputs/result.jpg',
199
+ cached_person=use_cached_person,
200
+ person_image_path=person_path,
201
+ processing_time=f"{processing_time:.2f}s")
202
 
203
  except Exception as e:
204
  print(f"[ERROR] {e}")
 
206
 
207
  return render_template('index.html')
208
 
209
+ @app.route('/change_person', methods=['POST'])
210
+ def change_person():
211
+ """Clear cached person data to allow new person upload"""
212
+ session.pop('person_coordinates', None)
213
+ session.pop('person_image_path', None)
214
+ print("[INFO] Cleared cached person data")
215
+ return render_template('index.html')
216
+
217
  if __name__ == '__main__':
218
+ # Initialize model at startup
219
+ print("[INFO] Initializing model...")
220
+ initialize_model()
221
+
222
+ # Warm up the model
223
+ warmup_model()
224
+
225
  print("[INFO] Starting Flask server...")
226
  app.run(debug=True, host='0.0.0.0')
requirements.txt CHANGED
@@ -8,4 +8,5 @@ mediapipe
8
  transformers
9
  diffusers
10
  safetensors
11
- flask-cors
 
 
8
  transformers
9
  diffusers
10
  safetensors
11
+ flask-cors
12
+ accelerate
templates/index.html CHANGED
@@ -21,7 +21,31 @@
21
  <div class="grid grid-cols-1 gap-8">
22
  <!-- Person Image Upload -->
23
  <div>
24
- <h2 class="text-lg font-semibold mb-2">Upload your photo</h2>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  <label for="person_image" class="flex flex-col items-center justify-center border-2 border-dashed border-gray-600 rounded-xl p-6 hover:bg-gray-700 cursor-pointer">
26
  <svg xmlns="http://www.w3.org/2000/svg" class="h-10 w-10 text-gray-400 mb-2" fill="none"
27
  viewBox="0 0 24 24" stroke="currentColor">
@@ -29,18 +53,29 @@
29
  d="M7 16v4m0 0h10m-10 0v-4m0 0h10m-10 0V5m0 0h10m-10 0H5m14 0h-2" />
30
  </svg>
31
  <p class="text-gray-400">Drag & drop or click to upload</p>
32
- <input id="person_image" type="file" name="person_image" class="hidden" required
33
  onchange="showFileName('person_image', 'person_filename', 'person_preview')">
34
  </label>
35
  <p id="person_filename" class="text-green-400 text-sm mt-2 text-center"></p>
36
  <div class="mt-3 flex justify-center">
37
  <img id="person_preview" class="hidden max-h-32 rounded-lg border border-gray-600">
38
  </div>
 
39
  </div>
40
 
41
  <!-- Garment Image Upload with Cropper -->
42
  <div>
43
- <h2 class="text-lg font-semibold mb-2">Upload garment image</h2>
 
 
 
 
 
 
 
 
 
 
44
  <label for="tshirt_image" class="flex flex-col items-center justify-center border-2 border-dashed border-gray-600 rounded-xl p-6 hover:bg-gray-700 cursor-pointer">
45
  <svg xmlns="http://www.w3.org/2000/svg" class="h-10 w-10 text-gray-400 mb-2" fill="none"
46
  viewBox="0 0 24 24" stroke="currentColor">
@@ -73,6 +108,13 @@
73
  {% if result_img %}
74
  <div>
75
  <h2 class="text-2xl font-bold mb-6 text-center">πŸŽ‰ Your Virtual Try-On Result</h2>
 
 
 
 
 
 
 
76
  <div class="flex justify-center mb-6">
77
  <img id="result-image" src="{{ result_img }}" alt="Result Image" class="rounded-xl">
78
  </div>
@@ -139,6 +181,25 @@
139
  document.getElementById('loading-spinner').style.display = 'flex';
140
  });
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  // Show file name + preview (person only)
143
  function showFileName(inputId, filenameId, previewId) {
144
  const input = document.getElementById(inputId);
 
21
  <div class="grid grid-cols-1 gap-8">
22
  <!-- Person Image Upload -->
23
  <div>
24
+ <div class="flex justify-between items-center mb-2">
25
+ <h2 class="text-lg font-semibold">Upload your photo</h2>
26
+ {% if cached_person %}
27
+ <button type="button" onclick="changePerson()" class="bg-yellow-500 hover:bg-yellow-600 text-white text-sm px-3 py-1 rounded-lg transition">
28
+ Change Person
29
+ </button>
30
+ {% endif %}
31
+ </div>
32
+
33
+ {% if cached_person %}
34
+ <!-- Show cached person image -->
35
+ <div class="border-2 border-green-500 rounded-xl p-4 bg-green-900/20">
36
+ <div class="flex flex-col items-center">
37
+ <div class="flex items-center gap-2 mb-2">
38
+ <svg xmlns="http://www.w3.org/2000/svg" class="h-5 w-5 text-green-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
39
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7" />
40
+ </svg>
41
+ <p class="text-green-400 text-sm font-medium">Person image cached</p>
42
+ </div>
43
+ <img src="/uploads/person.jpg" alt="Cached Person" class="max-h-32 rounded-lg border border-gray-600">
44
+ <p class="text-gray-400 text-xs mt-2">Coordinates saved - no need to re-upload!</p>
45
+ </div>
46
+ </div>
47
+ {% else %}
48
+ <!-- Upload new person image -->
49
  <label for="person_image" class="flex flex-col items-center justify-center border-2 border-dashed border-gray-600 rounded-xl p-6 hover:bg-gray-700 cursor-pointer">
50
  <svg xmlns="http://www.w3.org/2000/svg" class="h-10 w-10 text-gray-400 mb-2" fill="none"
51
  viewBox="0 0 24 24" stroke="currentColor">
 
53
  d="M7 16v4m0 0h10m-10 0v-4m0 0h10m-10 0V5m0 0h10m-10 0H5m14 0h-2" />
54
  </svg>
55
  <p class="text-gray-400">Drag & drop or click to upload</p>
56
+ <input id="person_image" type="file" name="person_image" class="hidden"
57
  onchange="showFileName('person_image', 'person_filename', 'person_preview')">
58
  </label>
59
  <p id="person_filename" class="text-green-400 text-sm mt-2 text-center"></p>
60
  <div class="mt-3 flex justify-center">
61
  <img id="person_preview" class="hidden max-h-32 rounded-lg border border-gray-600">
62
  </div>
63
+ {% endif %}
64
  </div>
65
 
66
  <!-- Garment Image Upload with Cropper -->
67
  <div>
68
+ <div class="flex justify-between items-center mb-2">
69
+ <h2 class="text-lg font-semibold">Upload garment image</h2>
70
+ {% if cached_person %}
71
+ <div class="flex items-center gap-1 text-xs text-green-400">
72
+ <svg xmlns="http://www.w3.org/2000/svg" class="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
73
+ <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z" />
74
+ </svg>
75
+ <span>Fast mode</span>
76
+ </div>
77
+ {% endif %}
78
+ </div>
79
  <label for="tshirt_image" class="flex flex-col items-center justify-center border-2 border-dashed border-gray-600 rounded-xl p-6 hover:bg-gray-700 cursor-pointer">
80
  <svg xmlns="http://www.w3.org/2000/svg" class="h-10 w-10 text-gray-400 mb-2" fill="none"
81
  viewBox="0 0 24 24" stroke="currentColor">
 
108
  {% if result_img %}
109
  <div>
110
  <h2 class="text-2xl font-bold mb-6 text-center">πŸŽ‰ Your Virtual Try-On Result</h2>
111
+ {% if processing_time %}
112
+ <div class="text-center mb-4">
113
+ <span class="bg-blue-500 text-white px-3 py-1 rounded-full text-sm">
114
+ ⚑ Processed in {{ processing_time }}
115
+ </span>
116
+ </div>
117
+ {% endif %}
118
  <div class="flex justify-center mb-6">
119
  <img id="result-image" src="{{ result_img }}" alt="Result Image" class="rounded-xl">
120
  </div>
 
181
  document.getElementById('loading-spinner').style.display = 'flex';
182
  });
183
 
184
+ // Change Person function
185
+ function changePerson() {
186
+ fetch('/change_person', {
187
+ method: 'POST',
188
+ headers: {
189
+ 'Content-Type': 'application/x-www-form-urlencoded',
190
+ }
191
+ })
192
+ .then(response => response.text())
193
+ .then(html => {
194
+ // Replace the current page content
195
+ document.documentElement.innerHTML = html;
196
+ })
197
+ .catch(error => {
198
+ console.error('Error changing person:', error);
199
+ alert('Error changing person. Please refresh the page.');
200
+ });
201
+ }
202
+
203
  // Show file name + preview (person only)
204
  function showFileName(inputId, filenameId, previewId) {
205
  const input = document.getElementById(inputId);