|
|
from flask import Flask, render_template, request, send_from_directory |
|
|
from PIL import Image |
|
|
import os, torch, cv2, mediapipe as mp |
|
|
from transformers import SamModel, SamProcessor, logging as hf_logging |
|
|
from torchvision import transforms |
|
|
from diffusers.utils import load_image |
|
|
from flask_cors import CORS |
|
|
|
|
|
# The WSGI application object; all routes below register on it.
app = Flask(__name__)

# Enable cross-origin requests using flask_cors's default configuration.
CORS(app)

# Raise the transformers library's log verbosity to INFO.
hf_logging.set_verbosity_info()
|
|
|
|
|
|
|
|
# Scratch directories for uploaded inputs and generated results.
UPLOAD_FOLDER = '/tmp/uploads'
OUTPUT_FOLDER = '/tmp/outputs'


def _ensure_folder(path):
    """Create *path* (including parents) if it is missing, warning on stdout first."""
    if os.path.exists(path):
        return
    print(f"[WARN] {path} does not exist. Creating...")
    # exist_ok guards against another process creating it between the check and here.
    os.makedirs(path, exist_ok=True)


_ensure_folder(UPLOAD_FOLDER)
_ensure_folder(OUTPUT_FOLDER)
|
|
|
|
|
|
|
|
|
|
|
# Lazily initialised SAM model/processor pair; populated by load_model() on first use.
model, processor = None, None

_SAM_CHECKPOINT = "Zigeng/SlimSAM-uniform-50"
_HF_CACHE_DIR = "/app/.cache"


def load_model():
    """Load the SlimSAM model and its processor into the module globals.

    Does nothing when both globals are already set. If either one is still
    ``None``, both are (re)loaded from the Hugging Face checkpoint, using
    ``/app/.cache`` as the download cache.
    """
    global model, processor
    if model is not None and processor is not None:
        return
    print("[INFO] Loading SAM model and processor...")
    model = SamModel.from_pretrained(_SAM_CHECKPOINT, cache_dir=_HF_CACHE_DIR)
    processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT, cache_dir=_HF_CACHE_DIR)
    print("[INFO] Model and processor loaded successfully!")
|
|
|
|
|
@app.before_request
def log_request_info():
    """Print the HTTP method and path of every incoming request."""
    method, path = request.method, request.path
    print(f"[INFO] Incoming request: {method} {path}")
|
|
|
|
|
@app.route('/health')
def health():
    """Liveness probe: always answer 200 with a plain ``OK`` body."""
    body, status_code = "OK", 200
    return body, status_code
|
|
|
|
|
|
|
|
@app.route('/outputs/<filename>')
def serve_output(filename):
    """Return the file called *filename* from OUTPUT_FOLDER (generated try-on images)."""
    directory = OUTPUT_FOLDER
    return send_from_directory(directory, filename)
|
|
|
|
|
class _BadInputError(Exception):
    """An uploaded image cannot be processed; the message is safe to show the user."""


def _detect_shoulders(image_path):
    """Return the pixel coordinates of the left and right shoulders in a photo.

    Args:
        image_path: Path of the person photo on disk.

    Returns:
        ``(left_shoulder, right_shoulder)``, each an ``(x, y)`` tuple of ints.

    Raises:
        _BadInputError: If OpenCV cannot decode the file or MediaPipe finds no pose.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise _BadInputError("No image detected.")
    height, width, _ = image.shape
    # OpenCV decodes to BGR; convert to RGB before handing the frame to MediaPipe.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mp_pose = mp.solutions.pose
    # Each upload is an independent still image, so use static-image mode. The
    # ``with`` block closes the MediaPipe graph instead of leaking one per request.
    with mp_pose.Pose(static_image_mode=True) as pose:
        results = pose.process(image_rgb)
        if not results.pose_landmarks:
            raise _BadInputError("No pose detected.")
        landmarks = results.pose_landmarks.landmark

        def to_pixels(landmark):
            # Landmark x/y are fractions of the image size; scale to pixels.
            return int(landmark.x * width), int(landmark.y * height)

        # LEFT_SHOULDER / RIGHT_SHOULDER are landmark indices 11 and 12.
        left_shoulder = to_pixels(landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER])
        right_shoulder = to_pixels(landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER])
    return left_shoulder, right_shoulder


def _replace_tshirt(person_path, tshirt_path, left_shoulder, right_shoulder):
    """Segment the region at the shoulder points with SAM and paste the new shirt there.

    Requires ``load_model()`` to have populated the ``model``/``processor`` globals.

    Returns:
        A PIL image the size of the person photo: the t-shirt image (resized to
        that size) wherever the mask is set, the original photo elsewhere.
    """
    img = load_image(person_path)
    new_tshirt = load_image(tshirt_path)

    # One image, one prompt made of the two shoulder points.
    input_points = [[list(left_shoulder), list(right_shoulder)]]
    inputs = processor(img, input_points=input_points, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph and keeping
    # intermediate activations alive for the duration of the request.
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[image][prompt] holds SAM's candidate masks for that prompt; index 2
    # keeps the third candidate, as the original implementation did.
    mask_tensor = masks[0][0][2].to(dtype=torch.uint8)
    mask = transforms.ToPILImage()(mask_tensor * 255)

    new_tshirt = new_tshirt.resize(img.size, Image.LANCZOS)
    return Image.composite(new_tshirt, img, mask)


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload form (GET) or run the virtual try-on pipeline (POST).

    POST expects multipart files ``person_image`` and ``tshirt_image``. On
    success the form is re-rendered with ``result_img`` pointing at the
    generated image under ``/outputs/``. Unusable uploads get a plain-text
    message with status 400; unexpected failures get ``"Error: ..."`` with
    status 500.
    """
    import traceback
    import uuid

    print(f"[INFO] Handling {request.method} on /")
    if request.method != 'POST':
        return render_template('index.html')

    person_file = request.files.get('person_image')
    tshirt_file = request.files.get('tshirt_image')
    if person_file is None or tshirt_file is None:
        return "Error: both 'person_image' and 'tshirt_image' files are required.", 400

    # Per-request file names. Fixed names let concurrent requests overwrite each
    # other's uploads and results, and a constant result URL can be served stale
    # from the browser cache.
    request_id = uuid.uuid4().hex
    person_path = os.path.join(UPLOAD_FOLDER, f'{request_id}_person.jpg')
    tshirt_path = os.path.join(UPLOAD_FOLDER, f'{request_id}_tshirt.png')
    result_name = f'{request_id}_result.jpg'

    try:
        load_model()

        person_file.save(person_path)
        tshirt_file.save(tshirt_path)
        print(f"[INFO] Saved files to {UPLOAD_FOLDER}")

        left_shoulder, right_shoulder = _detect_shoulders(person_path)
        print(f"[INFO] Shoulder coordinates: {left_shoulder}, {right_shoulder}")

        result = _replace_tshirt(person_path, tshirt_path, left_shoulder, right_shoulder)
        result_path = os.path.join(OUTPUT_FOLDER, result_name)
        result.save(result_path)
        print(f"[INFO] Result saved to {result_path}")
    except _BadInputError as e:
        return str(e), 400
    except Exception as e:
        # Top-level request boundary: report the failure and keep the full
        # traceback in the server log instead of only the message.
        print(f"[ERROR] {e}")
        traceback.print_exc()
        return f"Error: {e}", 500
    finally:
        # Uploads are only needed while processing this request. Removal is
        # best-effort: a file may not exist if saving it failed.
        for path in (person_path, tshirt_path):
            try:
                os.remove(path)
            except OSError:
                pass

    # NOTE(review): result images now accumulate in OUTPUT_FOLDER (one per
    # request); add a cleanup policy if disk usage matters.
    return render_template('index.html', result_img=f'/outputs/{result_name}')
|
|
|
|
|
if __name__ == '__main__':
    print("[INFO] Starting Flask server...")
    # The Werkzeug interactive debugger allows arbitrary code execution from the
    # browser, and the server binds to every interface (0.0.0.0). Debug mode is
    # therefore opt-in via FLASK_DEBUG (e.g. FLASK_DEBUG=1), mirroring Flask's
    # own parsing of that variable: unset, "0", "false" or "no" mean off.
    debug_flag = os.environ.get("FLASK_DEBUG", "")
    debug = debug_flag.lower() not in {"", "0", "false", "no"}
    app.run(debug=debug, host='0.0.0.0')
|
|
|