# CP_Project / app.py
# Author: ritz26
# Last change: update the backend to fully clear session and uploaded files on Change Person.
# Commit: d721dd6
from flask import Flask, render_template, request, send_from_directory, session, redirect, url_for
from PIL import Image
import os, torch, cv2, mediapipe as mp
from transformers import SamModel, SamProcessor, logging as hf_logging
from torchvision import transforms
from diffusers.utils import load_image
from flask_cors import CORS
import json
import time
app = Flask(__name__)

# Flask signs the session cookie with this key. The fallback below is visible in
# the source, so sessions can be forged whenever SECRET_KEY is not configured.
_DEV_SECRET_KEY = 'dev-secret-key-change-in-production'
app.secret_key = os.environ.get('SECRET_KEY', _DEV_SECRET_KEY)
if app.secret_key == _DEV_SECRET_KEY:
    print("[WARN] SECRET_KEY is not set; using the insecure development key. "
          "Set SECRET_KEY to a random value in production.")
CORS(app)

# Enable Hugging Face detailed logs (shows model download progress)
hf_logging.set_verbosity_info()

# Scratch directories for uploaded inputs and generated results.
UPLOAD_FOLDER = '/tmp/uploads'
OUTPUT_FOLDER = '/tmp/outputs'
for _folder in (UPLOAD_FOLDER, OUTPUT_FOLDER):
    if not os.path.exists(_folder):
        print(f"[WARN] {_folder} does not exist. Creating...")
        os.makedirs(_folder, exist_ok=True)

# Global model state, populated lazily by load_model().
model, processor = None, None
device = None
def load_model():
    """Load the SlimSAM model and processor into module globals, once.

    The model is kept on CPU only, to avoid meta-tensor/device errors on Spaces.
    Once both ``model`` and ``processor`` are set, later calls return
    immediately, so request handlers may call this unconditionally.
    """
    global model, processor, device
    # Force CPU on Spaces to avoid meta tensor errors when moving devices
    device = "cpu"
    if model is not None and processor is not None:
        # Already loaded: reloading per request would re-read (or, after the
        # cache is cleaned, re-download) the weights for no benefit.
        return
    print(f"[INFO] Using device: {device}")
    print("[INFO] Loading SAM model and processor...")
    model = SamModel.from_pretrained(
        "Zigeng/SlimSAM-uniform-50",
        cache_dir="/tmp/.cache",
        torch_dtype=torch.float32,
    )
    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-50", cache_dir="/tmp/.cache")
    # Do NOT move model with .to(); keep it on CPU to prevent meta tensor errors
    print("[INFO] Model and processor loaded successfully on CPU!")
def cleanup_temp_files():
    """Remove the /tmp/.cache directory (the cache_dir used by load_model).

    Best effort: any failure is printed as a warning and never raised.
    """
    cache_dir = "/tmp/.cache"
    try:
        import shutil
        if not os.path.exists(cache_dir):
            return
        shutil.rmtree(cache_dir)
        print("[INFO] Cleaned up temporary cache files")
    except Exception as err:
        print(f"[WARNING] Could not clean up temp files: {err}")
def cleanup_old_outputs(max_age_seconds=3600):
    """Delete regular files in OUTPUT_FOLDER whose ctime is older than a threshold.

    Subdirectories are left untouched. Errors are handled per file, so one file
    that vanished (e.g. removed by a concurrent request) or cannot be deleted
    does not stop the rest of the sweep. Nothing is raised.

    Args:
        max_age_seconds: files older than this many seconds are removed
            (default: one hour).
    """
    if not os.path.exists(OUTPUT_FOLDER):
        return
    try:
        names = os.listdir(OUTPUT_FOLDER)
    except OSError as e:
        print(f"[WARNING] Could not clean up old outputs: {e}")
        return
    now = time.time()
    for file in names:
        file_path = os.path.join(OUTPUT_FOLDER, file)
        try:
            if os.path.isfile(file_path) and now - os.path.getctime(file_path) > max_age_seconds:
                os.remove(file_path)
                print(f"[INFO] Removed old output file: {file}")
        except FileNotFoundError:
            # Removed by someone else after listdir(); nothing left to do.
            continue
        except OSError as e:
            print(f"[WARNING] Could not clean up old outputs: {e}")
@app.before_request
def log_request_info():
    """Print the HTTP method and path of each request before it is dispatched."""
    method, path = request.method, request.path
    print(f"[INFO] Incoming request: {method} {path}")
@app.route('/health')
def health():
    """Liveness probe: always answers with body "OK" and status 200."""
    status_code = 200
    return "OK", status_code
# Route to serve outputs dynamically
@app.route('/outputs/<filename>')
def serve_output(filename):
    """Send a generated file from OUTPUT_FOLDER.

    Returns a plain-text 404 if the folder or the file is missing. ``.jpg`` /
    ``.jpeg`` and ``.png`` names get image MIME types; everything else is sent
    as ``application/octet-stream``.
    """
    print(f"[DEBUG] Serving file: {filename} from {OUTPUT_FOLDER}")
    if not os.path.exists(OUTPUT_FOLDER):
        print(f"[ERROR] Output folder does not exist: {OUTPUT_FOLDER}")
        return "Output folder not found", 404

    candidate = os.path.join(OUTPUT_FOLDER, filename)
    if not os.path.exists(candidate):
        print(f"[ERROR] File does not exist: {candidate}")
        return "File not found", 404
    print(f"[DEBUG] File exists, serving: {candidate}")

    lowered = filename.lower()
    if lowered.endswith(('.jpg', '.jpeg')):
        content_type = 'image/jpeg'
    elif lowered.endswith('.png'):
        content_type = 'image/png'
    else:
        content_type = 'application/octet-stream'
    return send_from_directory(OUTPUT_FOLDER, filename, mimetype=content_type)
# Route to serve cached person images
@app.route('/uploads/<filename>')
def serve_upload(filename):
    """Send an uploaded file (person or garment image) from UPLOAD_FOLDER."""
    source_dir = UPLOAD_FOLDER
    return send_from_directory(source_dir, filename)
def detect_pose_and_get_coordinates(person_path):
    """Locate the left and right shoulders of the person in an image.

    Args:
        person_path: path of an image file readable by ``cv2.imread``.

    Returns:
        ``(left_shoulder, right_shoulder)``, each an ``(x, y)`` tuple of ints
        in pixel coordinates of the loaded image.

    Raises:
        ValueError: if the image cannot be read or no pose is detected.
    """
    mp_pose = mp.solutions.pose
    image = cv2.imread(person_path)
    if image is None:
        raise ValueError("No image detected.")
    # OpenCV decodes to BGR; MediaPipe is fed RGB.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Each call handles one unrelated photo, so use static-image mode; the
    # context manager closes the MediaPipe graph instead of leaking it.
    with mp_pose.Pose(static_image_mode=True) as pose:
        results = pose.process(image_rgb)
    if not results.pose_landmarks:
        raise ValueError("No pose detected.")

    height, width = image.shape[:2]
    landmarks = results.pose_landmarks.landmark

    def _to_pixels(landmark_id):
        # Landmark x/y are scaled by the image width/height to get pixels.
        point = landmarks[landmark_id]
        return int(point.x * width), int(point.y * height)

    left_shoulder = _to_pixels(mp_pose.PoseLandmark.LEFT_SHOULDER)
    right_shoulder = _to_pixels(mp_pose.PoseLandmark.RIGHT_SHOULDER)
    return left_shoulder, right_shoulder
def _cached_person_is_usable(path):
    """Return True if *path* is an existing file directly inside UPLOAD_FOLDER.

    The session cache can outlive its file (e.g. after /change_person or a
    restart), and session contents are only as trustworthy as SECRET_KEY, so
    the path is re-checked before it is handed to the image loader.
    """
    if not isinstance(path, str) or not os.path.isfile(path):
        return False
    return os.path.dirname(os.path.realpath(path)) == os.path.realpath(UPLOAD_FOLDER)


def _resolve_person():
    """Return ``(person_path, person_coordinates)`` for the current POST.

    Reuses the session cache while it is still usable. Otherwise it takes a
    freshly uploaded ``person_image`` (or the person image still on disk), runs
    pose detection and stores the result in the session. Returns
    ``(None, None)`` when no person image is available.
    """
    if 'person_coordinates' in session and 'person_image_path' in session:
        cached_path = session['person_image_path']
        if _cached_person_is_usable(cached_path):
            person_coordinates = session['person_coordinates']
            print(f"[INFO] Using cached person image: {cached_path}")
            print(f"[INFO] Using cached coordinates: {person_coordinates}")
            return cached_path, person_coordinates
        # Stale cache: forget it and fall back to an upload / the on-disk image.
        print(f"[WARN] Cached person image is unavailable, re-detecting: {cached_path}")
        session.pop('person_coordinates', None)
        session.pop('person_image_path', None)

    person_disk_path = os.path.join(UPLOAD_FOLDER, 'person.jpg')
    person_file = request.files.get('person_image')
    if person_file and person_file.filename != '':
        person_path = person_disk_path
        person_file.save(person_path)
        print(f"[INFO] Saved new person image to {person_path}")
    elif os.path.exists(person_disk_path):
        # No upload this time, but a previous person image is still on disk.
        person_path = person_disk_path
        print(f"[INFO] Reusing existing person image on disk: {person_path}")
    else:
        return None, None

    left_shoulder, right_shoulder = detect_pose_and_get_coordinates(person_path)
    person_coordinates = {
        'left_shoulder': left_shoulder,
        'right_shoulder': right_shoulder,
    }
    session['person_image_path'] = person_path
    session['person_coordinates'] = person_coordinates
    print(f"[INFO] Cached person coordinates: {person_coordinates}")
    return person_path, person_coordinates


def _composite_garment(person_path, tshirt_path, person_coordinates):
    """Segment the person image with SAM prompted at both shoulders and paste
    the garment image (resized to the person image) into that mask.

    Returns the composited PIL image.
    """
    img = load_image(person_path)
    new_tshirt = load_image(tshirt_path)
    left = person_coordinates['left_shoulder']
    right = person_coordinates['right_shoulder']
    # One image, prompted with two (x, y) points.
    input_points = [[[left[0], left[1]], [right[0], right[1]]]]
    inputs = processor(img, input_points=input_points, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # NOTE(review): [0][0][2] picks the third candidate mask for the single
    # image/prompt; confirm this is the intended candidate.
    mask_tensor = masks[0][0][2].to(dtype=torch.uint8)
    mask = transforms.ToPILImage()(mask_tensor * 255)
    new_tshirt = new_tshirt.resize(img.size, Image.LANCZOS)
    return Image.composite(new_tshirt, img, mask)


def _save_result(image):
    """Save *image* as OUTPUT_FOLDER/result.jpg, log the outcome, return the path."""
    result_path = os.path.join(OUTPUT_FOLDER, 'result.jpg')
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    image.save(result_path)
    print(f"[INFO] Result saved to {result_path}")
    if os.path.exists(result_path):
        file_size = os.path.getsize(result_path)
        print(f"[DEBUG] File saved successfully, size: {file_size} bytes")
    else:
        print(f"[ERROR] File was not saved to {result_path}")
    return result_path


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the try-on page; on POST, composite the uploaded garment onto the person.

    POST form fields: ``person_image`` (optional once a person is cached in the
    session or still on disk) and ``tshirt_image`` (required). Failures are
    returned as a plain string body.
    """
    start_time = time.time()
    print(f"[INFO] Handling {request.method} on /")
    if request.method == 'POST':
        try:
            load_model()
            person_path, person_coordinates = _resolve_person()
            if person_path is None:
                return "No person image provided. Please upload a person image first."

            # request.files[...] would raise on a missing field, and an empty
            # selection would save a 0-byte file that fails later when decoded.
            tshirt_file = request.files.get('tshirt_image')
            if not tshirt_file or tshirt_file.filename == '':
                return "No garment image provided. Please upload a garment image."
            tshirt_path = os.path.join(UPLOAD_FOLDER, 'tshirt.png')
            tshirt_file.save(tshirt_path)
            print(f"[INFO] Saved garment image to {tshirt_path}")

            result_image = _composite_garment(person_path, tshirt_path, person_coordinates)
            result_path = _save_result(result_image)

            processing_time = time.time() - start_time
            print(f"[PERF] Total processing time: {processing_time:.2f}s")
            cleanup_old_outputs()

            # Publish under a unique name so browsers never show a cached earlier result.
            import shutil
            import uuid
            unique_filename = f"result_{uuid.uuid4().hex[:8]}.jpg"
            shutil.copy2(result_path, os.path.join(OUTPUT_FOLDER, unique_filename))
            # _resolve_person always leaves the person cached in the session on success.
            return render_template('index.html',
                                   result_img=f'/outputs/{unique_filename}',
                                   cached_person=True,
                                   person_image_path=person_path,
                                   processing_time=f"{processing_time:.2f}s")
        except Exception as e:
            import traceback
            print(f"[ERROR] {e}")
            traceback.print_exc()
            return f"Error: {e}"

    # GET request: keep person image visible if available in session
    has_cached = 'person_coordinates' in session and 'person_image_path' in session
    return render_template(
        'index.html',
        cached_person=has_cached,
        person_image_path=session.get('person_image_path') if has_cached else None,
    )
@app.route('/change_person', methods=['POST'])
def change_person():
    """Forget the cached person and delete uploaded inputs and generated outputs.

    Removes the person entries from the session, deletes ``person.jpg`` and
    ``tshirt.png`` from UPLOAD_FOLDER and every regular file in OUTPUT_FOLDER,
    then redirects to ``/``. File errors are logged per file, never raised.
    """
    session.pop('person_coordinates', None)
    session.pop('person_image_path', None)

    all_cleared = True
    candidates = [
        os.path.join(UPLOAD_FOLDER, 'person.jpg'),
        os.path.join(UPLOAD_FOLDER, 'tshirt.png'),
    ]
    try:
        if os.path.exists(OUTPUT_FOLDER):
            candidates += [os.path.join(OUTPUT_FOLDER, name) for name in os.listdir(OUTPUT_FOLDER)]
    except OSError as e:
        all_cleared = False
        print(f"[WARNING] Failed to clear files: {e}")

    for path in candidates:
        if not os.path.isfile(path):
            continue  # absent, or a subdirectory (left in place)
        try:
            os.remove(path)
        except FileNotFoundError:
            # Deleted concurrently between the check and the remove; fine.
            pass
        except OSError as e:
            all_cleared = False
            print(f"[WARNING] Failed to clear files: {e}")
    if all_cleared:
        print("[INFO] Cleared cached person data and temp files")

    # Redirect to GET / so the app reloads fresh
    return redirect(url_for('index'))
@app.route('/cleanup', methods=['POST'])
def cleanup():
    """Run the cache cleanup, then the old-output sweep, and report success."""
    for sweep in (cleanup_temp_files, cleanup_old_outputs):
        sweep()
    return "Cleanup completed", 200
@app.route('/test-image')
def test_image():
    """Write a red 200x200 "TEST IMAGE" JPEG to OUTPUT_FOLDER and return an <img> tag for it.

    Used to check that /outputs/<filename> serving works end to end.
    """
    from PIL import Image, ImageDraw
    canvas = Image.new('RGB', (200, 200), color='red')
    ImageDraw.Draw(canvas).text((50, 100), "TEST IMAGE", fill='white')
    target = os.path.join(OUTPUT_FOLDER, 'test.jpg')
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    canvas.save(target)
    return '<img src="/outputs/test.jpg" alt="Test Image">'
if __name__ == '__main__':
    print("[INFO] Starting Flask server...")
    print("[INFO] Model will be loaded on first request to save memory...")
    # The Werkzeug interactive debugger can execute arbitrary code from the
    # browser, so it must not be on by default while bound to 0.0.0.0.
    # Opt in for local development with FLASK_DEBUG=1.
    debug_flag = os.environ.get('FLASK_DEBUG', '')
    debug_mode = bool(debug_flag) and debug_flag.lower() not in ('0', 'false', 'no')
    app.run(debug=debug_mode, host='0.0.0.0')