import os
import sys
import cv2
import numpy as np
import gradio as gr
import tempfile
import torch
from torchvision import transforms
import tensorflow as tf
import tensorflow_hub as hub
import mediapipe as mp


def setup_environment():
    """Environment setup function."""
    print("Starting environment setup...")
    # Install required packages (convenience for notebook-style environments;
    # the imports above already assume these packages are available)
    os.system("pip install torch torchvision opencv-python numpy gradio tensorflow tensorflow-hub mediapipe")
    print("Environment setup complete!")


# Run setup
setup_environment()


# Model initialization
def initialize_models():
    """Initialize each pose-estimation model."""
    models = {}

    # MediaPipe Pose
    mp_pose = mp.solutions.pose
    models['mediapipe'] = mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5)

    # MoveNet MultiPose
    movenet_model = hub.load("https://tfhub.dev/google/movenet/multipose/lightning/1")
    models['movenet'] = movenet_model.signatures['serving_default']

    return models


# Shared settings: COCO-style skeleton edges used by MoveNet (17 keypoints)
KEYPOINT_EDGES = [
    (0, 1), (1, 3), (0, 2), (2, 4), (5, 7), (7, 9), (6, 8), (8, 10),
    (5, 6), (5, 11), (6, 12), (11, 13), (13, 15), (12, 14), (14, 16), (11, 12)
]


# Pose-detection functions for each model
def detect_pose_mediapipe(frame, pose):
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = pose.process(rgb)
    return results.pose_landmarks if results.pose_landmarks else None


def detect_pose_movenet(frame, movenet):
    # MoveNet expects RGB input; OpenCV frames are BGR
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = tf.image.resize_with_pad(tf.expand_dims(rgb, axis=0), 256, 256)
    input_image = tf.cast(image, dtype=tf.int32)
    outputs = movenet(input_image)
    # Shape (6, 56): up to 6 people, each 17 keypoints * 3 values + 5 bounding-box values
    return outputs['output_0'].numpy()[0]
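# Note (a sketch, not wired into the pipeline below): detect_pose_movenet() runs the frame
# through resize_with_pad(), so MoveNet's normalized keypoints refer to the letterboxed
# 256x256 input. Scaling them by the raw frame width/height, as draw_stick_figure() does,
# is exact only for square frames. A padding-aware mapping could look like the helper below
# (the function name and signature are illustrative, not part of the original script).
def movenet_point_to_frame_px(y_n, x_n, frame_width, frame_height, input_size=256):
    """Map a MoveNet keypoint, normalized to the letterboxed square input,
    back to pixel coordinates in the original frame."""
    # resize_with_pad scales by this factor, then pads the shorter side evenly
    scale = min(input_size / frame_height, input_size / frame_width)
    pad_x = (input_size - frame_width * scale) / 2.0
    pad_y = (input_size - frame_height * scale) / 2.0
    # Undo the padding, then undo the scaling
    px = int((x_n * input_size - pad_x) / scale)
    py = int((y_n * input_size - pad_y) / scale)
    return px, py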
# Stick-figure drawing function
def draw_stick_figure(frame, landmarks, model_type, width, height, head_size_factor, line_thickness, fixed_size=None):
    # Draw on a white canvas the same size as the source frame
    blank = np.ones((height, width, 3), dtype=np.uint8) * 255
    black = (0, 0, 0)

    if model_type == 'mediapipe':
        def get_point(index):
            lm = landmarks.landmark[index]
            return int(lm.x * width), int(lm.y * height)

        # Head position: landmark 0 (nose)
        head_x, head_y = get_point(0)

        if fixed_size is not None:
            shoulder_hip_dist = fixed_size
        else:
            def get_distance(a, b):
                xa, ya = get_point(a)
                xb, yb = get_point(b)
                return ((xa - xb) ** 2 + (ya - yb) ** 2) ** 0.5

            # Left shoulder (11) to left hip (23), matching the MoveNet branch and the UI label
            shoulder_hip_dist = get_distance(11, 23)

        radius = int(shoulder_hip_dist * head_size_factor)
        cv2.circle(blank, (head_x, head_y), radius, black, thickness=line_thickness)

        # Torso, arms, and legs (MediaPipe Pose landmark indices)
        connections = [
            (11, 12), (11, 13), (13, 15), (12, 14), (14, 16),
            (11, 23), (12, 24), (23, 24), (23, 25), (24, 26), (25, 27), (26, 28)
        ]
        for a, b in connections:
            pt1 = get_point(a)
            pt2 = get_point(b)
            cv2.line(blank, pt1, pt2, black, line_thickness)

    elif model_type == 'movenet':
        for person in landmarks:
            # Last value of each row is the bounding-box confidence score
            overall_score = person[-1]
            if overall_score < 0.2:
                continue

            keypoints = []
            for i in range(17):
                y, x, score = person[i * 3:(i + 1) * 3]
                if score < 0.2:
                    keypoints.append(None)
                else:
                    px, py = int(x * width), int(y * height)
                    keypoints.append((px, py))

            # Head circle sized from the average shoulder-to-hip distance
            if keypoints[0] and keypoints[5] and keypoints[6]:
                head_x, head_y = keypoints[0]
                d1 = np.linalg.norm(np.array(keypoints[5]) - np.array(keypoints[11])) if keypoints[11] else 0
                d2 = np.linalg.norm(np.array(keypoints[6]) - np.array(keypoints[12])) if keypoints[12] else 0
                shoulder_hip_dist = (d1 + d2) / 2 if d1 and d2 else 50
                radius = int(shoulder_hip_dist * head_size_factor)
                cv2.circle(blank, (head_x, head_y), radius, black, thickness=line_thickness)

            # Joints (skip face keypoints 0-4, which are covered by the head circle)
            for idx, pt in enumerate(keypoints):
                if pt and idx not in [0, 1, 2, 3, 4]:
                    cv2.circle(blank, pt, line_thickness, black, -1)

            # Bones
            for a, b in KEYPOINT_EDGES:
                if keypoints[a] and keypoints[b]:
                    cv2.line(blank, keypoints[a], keypoints[b], black, line_thickness)

    return blank


# Main video-processing function
def process_video(video_path, model_type, head_size_factor, line_thickness, use_average_head_size):
    models = initialize_models()
    model = models[model_type]

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Could not open the video file")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Optional first pass (MediaPipe only): average shoulder-to-hip distance over the whole video
    average_shoulder_hip_dist = None
    if use_average_head_size and model_type == 'mediapipe':
        distances = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            landmarks = detect_pose_mediapipe(frame, model)
            if landmarks:
                def get_point(index):
                    lm = landmarks.landmark[index]
                    return int(lm.x * width), int(lm.y * height)

                def get_distance(a, b):
                    xa, ya = get_point(a)
                    xb, yb = get_point(b)
                    return ((xa - xb) ** 2 + (ya - yb) ** 2) ** 0.5

                # Left shoulder (11) to left hip (23), consistent with draw_stick_figure()
                dist = get_distance(11, 23)
                distances.append(dist)
        average_shoulder_hip_dist = np.mean(distances) if distances else 50
        # Rewind for the drawing pass
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

    # tempfile.mktemp is deprecated; NamedTemporaryFile provides a unique writable path
    output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_idx += 1
        print(f"Processing frame {frame_idx}/{total_frames} ({(frame_idx / total_frames) * 100:.1f}%)")

        if model_type == 'mediapipe':
            landmarks = detect_pose_mediapipe(frame, model)
            if landmarks:
                drawn = draw_stick_figure(
                    frame, landmarks, model_type, width, height,
                    head_size_factor, line_thickness, average_shoulder_hip_dist
                )
            else:
                # No person detected: write a blank white frame
                drawn = np.ones((height, width, 3), dtype=np.uint8) * 255
        elif model_type == 'movenet':
            landmarks = detect_pose_movenet(frame, model)
            drawn = draw_stick_figure(
                frame, landmarks, model_type, width, height,
                head_size_factor, line_thickness, None
            )

        out.write(drawn)

    cap.release()
    out.release()
    print("Processing complete!")
    return output_path


# Gradio interface
def gradio_interface(video_file, model_type, head_size_factor, line_thickness, use_avg):
    return process_video(
        video_file, model_type, head_size_factor, line_thickness,
        use_avg == "Fixed (video average)"
    )


model_info = """
- **MediaPipe Pose**: single person, high accuracy
- **MoveNet MultiPose**: detects multiple people
"""

demo = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Video(label="Upload a video"),
        gr.Radio(
            ["mediapipe", "movenet"],
            label="Model",
            value="mediapipe",
            info=model_info
        ),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.4, label="Head size (ratio of shoulder-hip distance)"),
        gr.Slider(minimum=1, maximum=10, step=1, value=2, label="Line thickness"),
        gr.Radio(
            ["Per frame", "Fixed (video average)"],
            value="Per frame",
            label="Head-size calculation (MediaPipe only)"
        )
    ],
    outputs=gr.Video(label="Stick-figure video"),
    title="Unified Stick-Figure Motion Tracking",
    description="Stick-figure video generator using MediaPipe or MoveNet"
)

if __name__ == "__main__":
    demo.launch()
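# Example (illustrative; the file name below is a placeholder): the pipeline can also be
# driven without the Gradio UI, e.g. from another module or a notebook cell:
#
#     result_path = process_video(
#         "input.mp4",               # any video readable by OpenCV
#         model_type="movenet",      # or "mediapipe"
#         head_size_factor=0.4,
#         line_thickness=2,
#         use_average_head_size=False,
#     )
#     print(result_path)             # path of the generated stick-figure video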