# Header text copied along with the source listing (kept as comments so the file parses):
# soiz1's picture
# Update app.py
# ab0c110 verified
import os
import sys
import cv2
import numpy as np
import gradio as gr
import tempfile
import torch
from torchvision import transforms
import tensorflow as tf
import tensorflow_hub as hub
import mediapipe as mp
def setup_environment():
    """Install this app's Python dependencies into the running interpreter.

    Runs ``<sys.executable> -m pip install ...`` for the packages the script
    uses and prints whether pip succeeded.

    Note: the packages are imported at the top of this file before this
    function is called at module level, so it cannot provide a missing
    package for the current run.
    """
    import subprocess  # local import: only this function needs it

    print("環境セットアップを開始します...")
    # Packages required by this app.
    packages = [
        "torch", "torchvision", "opencv-python", "numpy", "gradio",
        "tensorflow", "tensorflow-hub", "mediapipe",
    ]
    # Use the current interpreter's pip (not whichever `pip` is first on PATH)
    # and pass an argument list so no shell is involved.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *packages],
        check=False,
    )
    if result.returncode == 0:
        print("環境セットアップが完了しました!")
    else:
        # Previously "completed" was printed even when pip failed.
        print(f"環境セットアップに失敗しました (pip 終了コード: {result.returncode})")
# Run dependency installation when the module is loaded.
# NOTE(review): this executes after the imports at the top of the file, so a
# missing package would already have raised ImportError before this line is
# reached; in practice it re-runs pip on every start-up. Consider declaring the
# packages in a requirements file instead.
setup_environment()
# Model initialization

# TF Hub handle of the MoveNet MultiPose (Lightning) model.
_MOVENET_URL = "https://tfhub.dev/google/movenet/multipose/lightning/1"
# MoveNet module loaded by hub.load(); filled on first use and reused afterwards.
_movenet_module = None


def _load_movenet():
    """Return the MoveNet module, loading it from TF Hub only on the first call."""
    global _movenet_module
    if _movenet_module is None:
        _movenet_module = hub.load(_MOVENET_URL)
    return _movenet_module


def initialize_models():
    """Create the pose models used by the app.

    Returns:
        dict with two entries:
          'mediapipe': a new ``mp.solutions.pose.Pose`` created with
              ``static_image_mode=False`` and ``min_detection_confidence=0.5``.
              A fresh instance is made on every call; in video mode MediaPipe
              tracks across consecutive frames, so sharing one instance between
              videos could carry state over (per MediaPipe Pose docs).
          'movenet': the ``serving_default`` signature of the MoveNet
              MultiPose model. The module is loaded once per process and
              reused instead of being fetched from TF Hub on every call.
    """
    mp_pose = mp.solutions.pose
    return {
        'mediapipe': mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5),
        # The cached module stays referenced for the lifetime of the process,
        # alongside the signature taken from it.
        'movenet': _load_movenet().signatures['serving_default'],
    }
# Shared settings

# Skeleton segments between MoveNet keypoint indices, drawn as lines.
# Index meaning follows MoveNet's COCO-17 order (see the MoveNet model card):
# 0 nose, 1/2 eyes, 3/4 ears, 5/6 shoulders, 7/8 elbows, 9/10 wrists,
# 11/12 hips, 13/14 knees, 15/16 ankles (left/right).
KEYPOINT_EDGES = [
    # face: nose -> eye -> ear, one side per line
    (0, 1), (1, 3),
    (0, 2), (2, 4),
    # arms
    (5, 7), (7, 9),
    (6, 8), (8, 10),
    # shoulder line and torso sides
    (5, 6), (5, 11), (6, 12),
    # legs
    (11, 13), (13, 15),
    (12, 14), (14, 16),
    # hip line
    (11, 12),
]
# Per-model pose detection functions

def detect_pose_mediapipe(frame, pose):
    """Run MediaPipe Pose on a single OpenCV frame.

    Args:
        frame: BGR image as returned by ``cv2.VideoCapture.read``.
        pose: a MediaPipe ``Pose`` solution instance.

    Returns:
        ``results.pose_landmarks`` when it is truthy, otherwise ``None``.
    """
    # OpenCV delivers BGR; convert to RGB before handing the frame to MediaPipe.
    results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    found = results.pose_landmarks
    if not found:
        return None
    return found
def remap_movenet_letterbox(people, frame_height, frame_width):
    """Map MoveNet coordinates from the padded square input back to the frame.

    ``tf.image.resize_with_pad`` letterboxes the frame into a square, so the
    model's normalized y/x values are relative to that square, not the frame.
    In original pixel units the square has side ``max(H, W)``, with the frame
    centred in it (up to rounding). So for an axis of length ``dim``::

        frame_coord = square_coord * side / dim - (side - dim) / (2 * dim)

    Args:
        people: array of shape (num_people, 56) laid out as 17 x (y, x, score)
            followed by (ymin, xmin, ymax, xmax, score). This is the MoveNet
            MultiPose per-person layout according to its model card.
        frame_height: original frame height in pixels.
        frame_width: original frame width in pixels.

    Returns:
        A new float32 array in which keypoint and box y/x values are
        normalized to the original frame. Scores are copied unchanged, and the
        input array is not modified.
    """
    out = np.array(people, dtype=np.float32, copy=True)
    side = float(max(frame_height, frame_width))
    y_scale, y_shift = side / frame_height, (side - frame_height) / (2.0 * frame_height)
    x_scale, x_shift = side / frame_width, (side - frame_width) / (2.0 * frame_width)
    out[:, 0:51:3] = out[:, 0:51:3] * y_scale - y_shift    # keypoint y
    out[:, 1:51:3] = out[:, 1:51:3] * x_scale - x_shift    # keypoint x
    out[:, 51:55:2] = out[:, 51:55:2] * y_scale - y_shift  # box ymin, ymax
    out[:, 52:55:2] = out[:, 52:55:2] * x_scale - x_shift  # box xmin, xmax
    return out


def detect_pose_movenet(frame, movenet, input_size=256):
    """Run MoveNet MultiPose on a single OpenCV frame.

    Args:
        frame: BGR image (H, W, 3) as returned by ``cv2.VideoCapture.read``.
        movenet: the MoveNet ``serving_default`` signature.
        input_size: side of the square, letterboxed model input. Defaults to
            256, the previously hard-coded value. The model card recommends
            256 for the larger dimension, as a multiple of 32.

    Returns:
        Per-person array from ``output_0`` with y/x values re-normalized to the
        ORIGINAL frame, so callers can scale them by the frame's width/height
        directly.
    """
    height, width = frame.shape[:2]
    # MoveNet expects RGB input; OpenCV frames are BGR. The MediaPipe path
    # already did this conversion, but this one did not.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = tf.image.resize_with_pad(tf.expand_dims(rgb, axis=0), input_size, input_size)
    input_image = tf.cast(image, dtype=tf.int32)
    outputs = movenet(input_image)
    # Undo the letterbox padding; otherwise non-square videos get shifted,
    # squashed skeletons.
    return remap_movenet_letterbox(outputs['output_0'].numpy()[0], height, width)
# Stick-figure drawing

# Stroke colour (BGR black) used for every shape.
_BLACK = (0, 0, 0)
# Torso length in pixels used when it cannot be measured (the original fallback).
_FALLBACK_TORSO_PX = 50
# Minimum MoveNet score for a person (last value) and for a single keypoint.
_MOVENET_MIN_SCORE = 0.2
# MediaPipe (shoulder, hip) landmark pairs, left and right. These are the
# torso-side segments (11, 23) / (12, 24) drawn in _MP_CONNECTIONS.
_MP_SHOULDER_HIP = ((11, 23), (12, 24))
# MediaPipe segments drawn as lines (same list as before).
_MP_CONNECTIONS = (
    (11, 12), (11, 13), (13, 15), (12, 14), (14, 16),
    (11, 23), (12, 24), (23, 24), (23, 25), (24, 26), (25, 27), (26, 28),
)
# MoveNet (shoulder, hip) keypoint pairs, left and right.
_MOVENET_SHOULDER_HIP = ((5, 11), (6, 12))


def _px_distance(p, q):
    """Euclidean distance between two (x, y) pixel points."""
    return float(np.hypot(p[0] - q[0], p[1] - q[1]))


def _draw_mediapipe_pose(canvas, landmarks, width, height, head_size_factor, thickness, fixed_size):
    """Draw one MediaPipe pose (head circle at landmark 0, plus limb lines) onto canvas."""
    def point(index):
        lm = landmarks.landmark[index]
        return int(lm.x * width), int(lm.y * height)

    if fixed_size is not None:
        torso = fixed_size
    else:
        # Average the two shoulder->hip lengths. The old code measured 11->13,
        # which is the shoulder->elbow arm segment, although the UI describes the
        # head size as a shoulder-to-hip ratio and the MoveNet path measures
        # shoulder-to-hip.
        torso = sum(_px_distance(point(s), point(h)) for s, h in _MP_SHOULDER_HIP) / len(_MP_SHOULDER_HIP)
    radius = int(torso * head_size_factor)
    cv2.circle(canvas, point(0), radius, _BLACK, thickness=thickness)
    for a, b in _MP_CONNECTIONS:
        cv2.line(canvas, point(a), point(b), _BLACK, thickness)


def _draw_movenet_people(canvas, people, width, height, head_size_factor, thickness):
    """Draw every MoveNet person whose overall score passes the threshold onto canvas."""
    for person in people:
        if person[-1] < _MOVENET_MIN_SCORE:  # last value: overall person score
            continue

        # 17 keypoints stored as (y, x, score); low-confidence ones become None.
        keypoints = []
        for i in range(17):
            y, x, score = person[i * 3:(i + 1) * 3]
            keypoints.append(None if score < _MOVENET_MIN_SCORE else (int(x * width), int(y * height)))

        # Head circle on keypoint 0, sized from the shoulder-hip length. If only
        # one side's hip is visible, use that side instead of the fixed fallback.
        if keypoints[0] and keypoints[5] and keypoints[6]:
            lengths = [
                _px_distance(keypoints[s], keypoints[h])
                for s, h in _MOVENET_SHOULDER_HIP
                if keypoints[h]
            ]
            lengths = [d for d in lengths if d > 0]
            torso = sum(lengths) / len(lengths) if lengths else _FALLBACK_TORSO_PX
            radius = int(torso * head_size_factor)
            cv2.circle(canvas, keypoints[0], radius, _BLACK, thickness=thickness)

        # Filled joint dots for body keypoints (face indices 0-4 are skipped).
        for pt in keypoints[5:]:
            if pt:
                cv2.circle(canvas, pt, thickness, _BLACK, -1)

        for a, b in KEYPOINT_EDGES:
            if keypoints[a] and keypoints[b]:
                cv2.line(canvas, keypoints[a], keypoints[b], _BLACK, thickness)


def draw_stick_figure(frame, landmarks, model_type, width, height, head_size_factor, line_thickness, fixed_size=None):
    """Render detected pose(s) as a black stick figure on a white canvas.

    Args:
        frame: source frame. Not used for drawing; kept for interface compatibility.
        landmarks: MediaPipe ``pose_landmarks`` when model_type is 'mediapipe',
            or the per-person MoveNet array from detect_pose_movenet when it
            is 'movenet'. ``None`` yields a blank canvas.
        model_type: 'mediapipe' or 'movenet'. Any other value yields a blank canvas.
        width, height: canvas size in pixels. Normalized coordinates are scaled by these.
        head_size_factor: head radius as a fraction of the shoulder-to-hip length.
        line_thickness: stroke width in pixels. It is coerced to an int >= 1,
            because OpenCV's drawing functions take an integer thickness and
            Gradio sliders may deliver floats.
        fixed_size: MediaPipe only. A shoulder-to-hip length to use instead of
            measuring it on this frame.

    Returns:
        uint8 array of shape (height, width, 3).
    """
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)
    if landmarks is None:
        return canvas
    thickness = max(1, int(round(line_thickness)))
    if model_type == 'mediapipe':
        _draw_mediapipe_pose(canvas, landmarks, width, height, head_size_factor, thickness, fixed_size)
    elif model_type == 'movenet':
        _draw_movenet_people(canvas, landmarks, width, height, head_size_factor, thickness)
    return canvas
# Main video-processing routine

# Frame rate used when the container reports none (CAP_PROP_FPS is 0 or NaN).
_FALLBACK_FPS = 30.0


def _mp_shoulder_hip_length_px(landmarks, width, height):
    """Mean left/right shoulder-to-hip distance, in pixels, of one MediaPipe pose.

    Uses landmark pairs (11, 23) and (12, 24), the torso sides that the
    stick figure connects. The previous pair (11, 13) is the shoulder->elbow
    arm segment, not shoulder->hip.
    """
    def point(index):
        lm = landmarks.landmark[index]
        return int(lm.x * width), int(lm.y * height)

    total = 0.0
    for shoulder, hip in ((11, 23), (12, 24)):
        (xa, ya), (xb, yb) = point(shoulder), point(hip)
        total += ((xa - xb) ** 2 + (ya - yb) ** 2) ** 0.5
    return total / 2


def _mean_mediapipe_torso_length(video_path, width, height, default=50):
    """Average MediaPipe shoulder-to-hip length (pixels) over every frame of a video.

    Opens its own capture and its own Pose tracker, using the same settings
    as initialize_models. The main rendering pass therefore neither depends
    on seeking back to frame 0 nor inherits tracking state from the last
    frame of this pass. Returns ``default`` when no pose is found in any frame.
    """
    distances = []
    cap = cv2.VideoCapture(video_path)
    try:
        with mp.solutions.pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                landmarks = detect_pose_mediapipe(frame, pose)
                if landmarks:
                    distances.append(_mp_shoulder_hip_length_px(landmarks, width, height))
    finally:
        cap.release()
    return float(np.mean(distances)) if distances else default


def process_video(video_path, model_type, head_size_factor, line_thickness, use_average_head_size):
    """Convert a video into a stick-figure video and return the output file path.

    Args:
        video_path: path of the input video (as delivered by gr.Video).
        model_type: 'mediapipe' or 'movenet', used as a key into initialize_models().
        head_size_factor: head radius as a fraction of the shoulder-to-hip length.
        line_thickness: stroke width passed to draw_stick_figure.
        use_average_head_size: MediaPipe only. Size every head from the
            whole-video average shoulder-to-hip length instead of per frame.

    Returns:
        Path of a newly created .mp4 file (mp4v codec) with the rendered frames.

    Raises:
        ValueError: if no path is given or the video cannot be opened.
        RuntimeError: if the output video writer cannot be opened.
        KeyError: if model_type is not a key of initialize_models().
    """
    if not video_path:
        raise ValueError("動画ファイルを開けません")

    models = initialize_models()
    model = models[model_type]
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError("動画ファイルを開けません")

        # Frame count may be 0 when the container does not report it.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # also catches NaN
            fps = _FALLBACK_FPS

        average_shoulder_hip_dist = None
        if use_average_head_size and model_type == 'mediapipe':
            average_shoulder_hip_dist = _mean_mediapipe_torso_length(video_path, width, height)

        # mkstemp creates the file atomically, unlike the deprecated, race-prone mktemp.
        fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise RuntimeError("出力動画を作成できません")

        # White frame written when MediaPipe finds nobody; never modified, so reuse it.
        blank = np.full((height, width, 3), 255, dtype=np.uint8)
        frame_idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame_idx += 1
            if total_frames > 0:
                print(f"処理中: フレーム {frame_idx}/{total_frames} ({(frame_idx/total_frames)*100:.1f}%)")
            else:
                print(f"処理中: フレーム {frame_idx}")

            if model_type == 'mediapipe':
                landmarks = detect_pose_mediapipe(frame, model)
                if landmarks:
                    drawn = draw_stick_figure(
                        frame, landmarks, model_type, width, height,
                        head_size_factor, line_thickness, average_shoulder_hip_dist
                    )
                else:
                    drawn = blank
            elif model_type == 'movenet':
                landmarks = detect_pose_movenet(frame, model)
                drawn = draw_stick_figure(
                    frame, landmarks, model_type, width, height,
                    head_size_factor, line_thickness, None
                )
            out.write(drawn)
    finally:
        # Release native resources even when an exception aborts processing.
        cap.release()
        if out is not None:
            out.release()
        # initialize_models() builds a new MediaPipe Pose per call; close it here.
        mediapipe_pose = models.get('mediapipe')
        if mediapipe_pose is not None and hasattr(mediapipe_pose, 'close'):
            mediapipe_pose.close()

    print("処理完了!")
    return output_path
# Gradio interface

def gradio_interface(video_file, model_type, head_size_factor, line_thickness, use_avg):
    """Adapter between the Gradio input widgets and process_video.

    The head-size radio button delivers its label text. Only the
    "fixed overall average" choice turns on whole-video averaging.
    """
    average_mode = use_avg == "全体平均で固定"
    return process_video(video_file, model_type, head_size_factor, line_thickness, average_mode)
# Markdown help text shown under the model selector.
model_info = """
- **MediaPipe Pose**: 単独人物向け、高精度
- **MoveNet MultiPose**: 複数人物検出可能
"""

# Input widgets, in the positional order gradio_interface expects:
# video, model, head-size factor, line thickness, head-size mode.
_input_components = [
    gr.Video(label="動画をアップロード"),
    gr.Radio(["mediapipe", "movenet"], label="モデル選択", value="mediapipe", info=model_info),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.4, label="頭の大きさ(肩〜腰の比率)"),
    gr.Slider(minimum=1, maximum=10, step=1, value=2, label="線の太さ"),
    gr.Radio(
        ["フレームごとに計算", "全体平均で固定"],
        value="フレームごとに計算",
        label="頭サイズの計算方法 (MediaPipeのみ有効)",
    ),
]

demo = gr.Interface(
    fn=gradio_interface,
    inputs=_input_components,
    outputs=gr.Video(label="棒人間動画"),
    title="統合版 棒人間モーショントラッキング",
    description="MediaPipe または MoveNet による棒人間動画生成ツール",
)
# Launch the Gradio app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()