# yolox / app.py
# tidalove's picture
# Update app.py
# 336d245 verified
import gradio as gr
import os
import tempfile
import json
import zipfile
from tools.demo_api import build_predictor, run_detection
import shutil
from square_crop import run_square_crop
# Module-level predictor: constructed a single time when this file is loaded
# and then passed to run_detection() by the request handler.
PREDICTOR = build_predictor(
    **{
        # Model definition and checkpoint (relative path strings).
        "exp_file": "exps/yolox_s.py",
        "model_name": "yolox_s",
        "ckpt_path": "best_ckpt.pth",
        # Runtime options: CPU device with fp16 / fuse / trt all switched off.
        "device": "cpu",
        "fp16": False,
        "fuse": False,
        "trt": False,
        # NOTE(review): presumably confidence threshold, NMS threshold and
        # test image size -- confirm against tools.demo_api.build_predictor.
        "conf": 0.3,
        "nms": 0.3,
        "tsize": 640,
    }
)
def process_yolox_api(files):
    """API endpoint for YOLOX processing.

    Copies the uploaded files into a fresh temporary directory, runs
    ``run_detection`` on that directory, hands the resulting COCO JSON path
    to ``run_square_crop`` and bundles the crops it reports into a zip file.

    Args:
        files: Uploaded files. Each entry may be an object exposing its
            on-disk path as ``.name`` or a plain path string; ``None``
            entries are skipped.

    Returns:
        A 3-tuple ``(zip_path, cropped_paths, status)``:

        * success: path of the zip archive, the list returned by
          ``run_square_crop`` and a summary message;
        * no input or any failure: ``(None, [], message)``.

        Every exit path returns exactly three values because the function is
        wired to three output components.
    """
    if not files:
        # Previously returned only two values here, which does not match the
        # three outputs the click handler expects.
        return None, [], "No files uploaded"

    # Create temporary directories
    temp_dir = tempfile.mkdtemp()
    input_dir = os.path.join(temp_dir, "input")
    # output_dir is not referenced again in this function; it is still
    # created so the on-disk layout stays as before.
    output_dir = os.path.join(temp_dir, "output")
    cropped_dir = os.path.join(temp_dir, "cropped")
    for directory in (input_dir, output_dir, cropped_dir):
        os.makedirs(directory, exist_ok=True)

    try:
        # Save uploaded files. Done inside the try block so a failed copy is
        # reported through the status message like any other error.
        for file in files:
            if file is not None:
                # Accept both file wrappers (path in ``.name``) and raw paths.
                shutil.copy(getattr(file, "name", file), input_dir)

        coco_json_path = run_detection(PREDICTOR, input_dir)
        print("Beginning to crop")
        cropped_paths = run_square_crop(input_dir, coco_json_path, cropped_dir)

        # Create zip file
        zip_path = os.path.join(temp_dir, "cropped_results.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for img_path in cropped_paths:
                if os.path.exists(img_path):
                    zipf.write(img_path, os.path.basename(img_path))

        # temp_dir is intentionally kept on success: the returned zip path
        # (and presumably the cropped image paths) live inside it.
        return zip_path, cropped_paths, f"Processed {len(cropped_paths)} images"
    except Exception as e:
        # Nothing returned on this path points into temp_dir, so remove it
        # instead of leaking one directory per failed request.
        shutil.rmtree(temp_dir, ignore_errors=True)
        return None, [], f"Error: {str(e)}"
# Build the Blocks UI; the button's click handler is also exposed as a named
# API endpoint.
# NOTE(review): nesting below was reconstructed from a source whose
# indentation had been stripped -- confirm which components belong to which
# column.
with gr.Blocks() as yolox_demo:
    gr.Markdown("# YOLOX Auto-Cropping Service")

    with gr.Row():
        # Left column: upload widget and trigger button.
        with gr.Column():
            files_input = gr.File(
                label="Upload Images",
                file_count="multiple",
                file_types=["image"],
            )
            process_btn = gr.Button("Process", variant="primary")

        # Right column: results (gallery, downloadable file, status text).
        with gr.Column():
            gallery_output = gr.Gallery(
                label="Cropped Images",
                columns=3,
                rows=2,
            )
            download_output = gr.File(label="Download Results")
            status_output = gr.Textbox(label="Status", interactive=False)

    # Output order must line up with the (zip, gallery, status) values
    # returned by process_yolox_api.
    process_btn.click(
        fn=process_yolox_api,
        inputs=[files_input],
        outputs=[download_output, gallery_output, status_output],
        api_name="yolox_process",  # This creates the API endpoint
    )

yolox_demo.launch()