Spaces:
Running
Running
# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import matplotlib.pyplot as plt | |
from flask import Flask, request, jsonify, render_template | |
import os | |
import io | |
import numpy as np | |
import torch | |
import yaml | |
import matplotlib | |
import argparse | |
# Non-interactive backend: the server only renders figures to image files.
matplotlib.use('Agg')
app = Flask(__name__, static_folder='static', template_folder='templates')

# --- Arguments -------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--save_dir', type=str, default='videos_example')
# parse_known_args() instead of parse_args(): this runs at import time, and
# when the module is loaded by another command-line tool (e.g. `flask run`)
# sys.argv holds that tool's arguments, which parse_args() would reject with
# "unrecognized arguments" and exit. Unknown arguments are ignored here;
# --save_dir is still honoured.
args = parser.parse_known_args()[0]

# --- Configuration ---------------------------------------------------------
BASE_DIR = args.save_dir
# Browser-facing copies live under the Flask static folder ('static').
STATIC_BASE = os.path.join('static', BASE_DIR)
IMAGES_DIR = os.path.join(STATIC_BASE, 'images')          # uploaded images (served)
OVERLAY_DIR = os.path.join(STATIC_BASE, 'images_tracks')  # track-overlay PNGs (served)
TRACKS_DIR = os.path.join(BASE_DIR, 'tracks')             # saved .pth track files
YAML_PATH = os.path.join(BASE_DIR, 'test.yaml')           # image/track manifest
IMAGES_DIR_OUT = os.path.join(BASE_DIR, 'images')         # uploaded images next to tracks
FIXED_LENGTH = 121  # every track is padded/truncated to this many points
COLOR_CYCLE = ['r', 'g', 'b', 'c', 'm', 'y', 'k']  # matplotlib colour codes
QUANT_MULTI = 8  # multiplier applied to track arrays before they are written
for d in (IMAGES_DIR, TRACKS_DIR, OVERLAY_DIR, IMAGES_DIR_OUT):
    os.makedirs(d, exist_ok=True)
# --- Helpers ---------------------------------------------------------------
def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MULTI):
    """Scale ``arr`` by ``quant_multi`` and save it to ``path`` as an npz blob.

    The scaled values are cast to float32 (no integer rounding is applied),
    packed into an in-memory ``.npz`` archive under the key ``'array'``, and
    the archive's raw bytes are written to ``path`` with ``torch.save``.

    Args:
        arr: array-like of numbers (ndarray or nested Python lists).
        path: destination file path.
        compressed: use ``np.savez_compressed`` instead of ``np.savez``.
        quant_multi: multiplier applied to every value before saving.

    Returns:
        The serialized ``.npz`` bytes that were written to ``path``.
    """
    # np.asarray first: with a plain Python list, ``quant_multi * arr`` would
    # repeat the list (int multiplier) or raise (float multiplier) instead of
    # scaling its elements.
    arr_q = (quant_multi * np.asarray(arr)).astype(np.float32)
    bio = io.BytesIO()
    save = np.savez_compressed if compressed else np.savez
    save(bio, array=arr_q)
    payload = bio.getvalue()
    torch.save(payload, path)
    return payload
def load_existing_tracks(path):
    """Return the array stored in a torch-saved ``.npz`` track file.

    ``path`` must contain a ``torch.save``-serialized bytes object holding an
    ``.npz`` archive with an ``'array'`` entry (the layout written by
    ``array_to_npz_bytes``). The entry is returned exactly as stored: no
    rescaling is applied, so any save-time multiplier is still present.
    """
    blob = torch.load(path)
    with np.load(io.BytesIO(blob)) as archive:
        stored = archive['array']
    return stored
# --- Routes ----------------------------------------------------------------
@app.route('/')
def index():
    """Render ``index.html`` from the app's ``templates`` folder.

    The ``@app.route`` registration was missing, so Flask had no URL bound to
    this view and every request returned 404.
    """
    return render_template('index.html')
def _safe_extension(filename, default='png'):
    """Return the extension of a client-supplied file name, or ``default``.

    Only the final path component is inspected, and the result must be purely
    alphanumeric. A crafted name such as ``'a./../../x'`` therefore cannot
    inject path separators into the file name built from it, and a name
    without an extension falls back to ``default`` instead of becoming one.
    """
    ext = os.path.splitext(os.path.basename(filename or ''))[1][1:]
    return ext if ext.isalnum() else default


# NOTE(review): the URL must match the path the frontend (templates/index.html)
# posts to; '/upload_image' follows the view name -- confirm.
@app.route('/upload_image', methods=['POST'])
def upload_image():
    """Store an uploaded image under the next numeric id.

    Reads the multipart form field ``image`` and writes the image as
    ``<idx:02d>.<ext>`` into both ``IMAGES_DIR`` (under the static folder) and
    ``IMAGES_DIR_OUT``. ``idx`` is one more than the number of entries
    currently in ``IMAGES_DIR``.

    Returns:
        JSON with ``image_url``, ``image_id`` (``idx``), ``ext``, and the
        image's original ``orig_width``/``orig_height`` in pixels.
    """
    from PIL import Image

    f = request.files['image']
    img = Image.open(f.stream)
    orig_w, orig_h = img.size
    idx = len(os.listdir(IMAGES_DIR)) + 1
    # The client controls f.filename; never let it shape the saved path.
    ext = _safe_extension(f.filename)
    fname = f"{idx:02d}.{ext}"
    img.save(os.path.join(IMAGES_DIR, fname))
    img.save(os.path.join(IMAGES_DIR_OUT, fname))
    return jsonify({
        'image_url': f"{STATIC_BASE}/images/{fname}",
        'image_id': idx,
        'ext': ext,
        'orig_width': orig_w,
        'orig_height': orig_h,
    })
def _pad_track(track, length=FIXED_LENGTH):
    """Convert a list of ``{'x', 'y'}`` points into a ``(length, 1, 3)`` array.

    Each point becomes ``[x, y, 1]``. The trailing 1 is the flag the overlay
    uses to decide which points to draw. Shorter tracks are zero-padded, so
    padded rows have flag 0; longer tracks are truncated.
    """
    pts = np.array([[p['x'], p['y'], 1] for p in track], dtype=np.float32)
    # reshape keeps an empty track 2-D, i.e. (0, 3), so vstack cannot fail.
    pts = pts.reshape(-1, 3)[:length]
    if pts.shape[0] < length:
        pad = np.zeros((length - pts.shape[0], 3), dtype=np.float32)
        pts = np.vstack((pts, pad))
    return pts.reshape(length, 1, 3)


def _pad_channels(arr, channels):
    """Zero-pad the last axis of ``arr`` up to ``channels`` entries."""
    extra = channels - arr.shape[-1]
    if extra <= 0:
        return arr
    pad = np.zeros(arr.shape[:-1] + (extra,), dtype=arr.dtype)
    return np.concatenate((arr, pad), axis=-1)


def _merge_with_existing(track_path, new_tracks):
    """Return the tracks already stored at ``track_path`` followed by ``new_tracks``.

    The result is un-quantized, ready to be passed to ``array_to_npz_bytes``.
    """
    if not os.path.exists(track_path):
        return new_tracks
    # The file stores values multiplied by QUANT_MULTI (array_to_npz_bytes
    # applies it on save). Undo it so re-saving does not scale old tracks twice.
    old = load_existing_tracks(track_path).astype(np.float32) / QUANT_MULTI
    # Existing files may carry 3 or 4 channels, while freshly padded tracks
    # carry 3. Pad whichever side is narrower so the concatenation is valid.
    channels = max(old.shape[-1], new_tracks.shape[-1])
    return np.concatenate(
        [_pad_channels(old, channels), _pad_channels(new_tracks, channels)],
        axis=0)


def _render_overlay(img_path, tracks, out_path):
    """Plot the flagged points of every track over the image and save a PNG."""
    img = plt.imread(img_path)
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.imshow(img)
        for track in tracks:
            coords = track[:, 0, :]
            visible = coords[:, 2] > 0.5
            ax.plot(coords[visible, 0], coords[visible, 1],
                    marker='o', color=COLOR_CYCLE[0])
        ax.axis('off')
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if drawing or saving fails. Otherwise figures
        # accumulate in the long-running server process.
        plt.close(fig)


def _update_yaml_manifest(image_id, ext):
    """Insert or refresh the manifest entry for this image in ``YAML_PATH``."""
    fname = f"{image_id:02d}.{ext}"
    entry = {
        "image": f"tools/trajectory_editor/{BASE_DIR}/images/{fname}",
        "text": None,
        "track": f"tools/trajectory_editor/{BASE_DIR}/tracks/{image_id:02d}.pth",
    }
    if os.path.exists(YAML_PATH):
        with open(YAML_PATH) as yf:
            docs = yaml.safe_load(yf) or []
    else:
        docs = []
    for e in docs:
        image = (e.get("image") or "") if isinstance(e, dict) else ""
        # Match the whole file name. A bare endswith(fname) would let image 1
        # ("01.png") match an entry for image 101 ("101.png").
        if image == fname or image.endswith("/" + fname):
            e.update(entry)
            break
    else:
        docs.append(entry)
    with open(YAML_PATH, 'w') as yf:
        yaml.dump(docs, yf, default_flow_style=False)


# NOTE(review): the URL must match the path the frontend (templates/index.html)
# posts to; '/store_tracks' follows the view name -- confirm.
@app.route('/store_tracks', methods=['POST'])
def store_tracks():
    """Append posted trajectories to an image's track file and redraw its overlay.

    Expects a JSON body with ``image_id``, ``ext`` (the image's extension), and
    optional lists ``tracks`` (freehand) and ``circle_trajectories``. Each list
    holds trajectories, and each trajectory is a list of ``{'x': .., 'y': ..}``
    points.

    Every trajectory is padded or truncated to FIXED_LENGTH points and appended
    to ``TRACKS_DIR/<id:02d>.pth``. The overlay PNG in OVERLAY_DIR is redrawn
    from all stored tracks, and YAML_PATH gets an entry for the image.

    Returns:
        JSON ``{'status': 'ok', 'overlay_url': ...}``. If no trajectories were
        posted, nothing is written and only the URL is returned.
    """
    data = request.get_json()
    image_id = int(data['image_id'])
    ext = data['ext']
    free_tracks = data.get('tracks') or []
    circ_trajs = data.get('circle_trajectories') or []

    # Debug lengths
    for i, tr in enumerate(free_tracks, 1):
        print(f"Freehand Track {i}: {len(tr)} points")
    for i, tr in enumerate(circ_trajs, 1):
        print(f"Circle/Static Traj {i}: {len(tr)} points")

    # Freehand tracks first, then circle/static trajectories.
    arrs = [_pad_track(tr) for tr in free_tracks]
    arrs += [_pad_track(tr) for tr in circ_trajs]
    print(arrs)

    overlay_file = f"{image_id:02d}.png"
    overlay_url = f"{STATIC_BASE}/images_tracks/{overlay_file}"
    if not arrs:
        return jsonify({'status': 'ok', 'overlay_url': overlay_url})

    new_tracks = np.stack(arrs, axis=0)  # (T_new, FIXED_LENGTH, 1, 3)
    track_path = os.path.join(TRACKS_DIR, f"{image_id:02d}.pth")
    all_tracks = _merge_with_existing(track_path, new_tracks)
    array_to_npz_bytes(all_tracks, track_path, compressed=True)

    _render_overlay(os.path.join(IMAGES_DIR, f"{image_id:02d}.{ext}"),
                    all_tracks,
                    os.path.join(OVERLAY_DIR, overlay_file))
    _update_yaml_manifest(image_id, ext)
    return jsonify({'status': 'ok', 'overlay_url': overlay_url})
def _run_dev_server():
    """Start Flask's built-in development server with debug mode enabled."""
    app.run(debug=True)


if __name__ == '__main__':
    _run_dev_server()