feat: π new supervision vitpose support and annotators improvement (#3)
feat: π new supervision vitpose support and annotators improvement and docs and gradio UI updates for more functionality (acb84529b192c3ec3d63c9abc055a1a8337dd91c)
Co-authored-by: Onuralp SEZER <onuralpszr@users.noreply.huggingface.co>
- app.py +122 -39
- pyproject.toml +1 -1
- requirements.txt +2 -8
app.py
CHANGED
@@ -13,10 +13,62 @@ import torch
 import tqdm
 from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation

-DESCRIPTION = "
+DESCRIPTION = """
+# ViTPose
+
+<div style="display: flex; gap: 10px;">
+    <a href="https://huggingface.co/docs/transformers/en/model_doc/vitpose">
+        <img src="https://img.shields.io/badge/Huggingface-FFD21E?style=flat&logo=Huggingface&logoColor=black" alt="Huggingface">
+    </a>
+    <a href="https://arxiv.org/abs/2204.12484">
+        <img src="https://img.shields.io/badge/Arvix-B31B1B?style=flat&logo=arXiv&logoColor=white" alt="Paper">
+    </a>
+    <a href="https://github.com/ViTAE-Transformer/ViTPose">
+        <img src="https://img.shields.io/badge/Github-100000?style=flat&logo=github&logoColor=white" alt="Github">
+    </a>
+</div>
+
+ViTPose is a state-of-the-art human pose estimation model based on Vision Transformers (ViT). It employs a standard, non-hierarchical ViT backbone and a simple decoder head to predict keypoint heatmaps from images. Despite its simplicity, ViTPose achieves top results on the MS COCO Keypoint Detection benchmark.
+
+ViTPose++ further improves performance with a mixture-of-experts (MoE) module and extensive pre-training. The model is scalable, flexible, and demonstrates strong transferability across pose estimation tasks.
+
+**Key features:**
+- PyTorch implementation
+- Scalable model size (100M to 1B parameters)
+- Flexible training and inference
+- State-of-the-art accuracy on challenging benchmarks
+
+"""
+
+
+COLORS = [
+    "#A351FB",
+    "#FF4040",
+    "#FFA1A0",
+    "#FF7633",
+    "#FFB633",
+    "#D1D435",
+    "#4CFB12",
+    "#94CF1A",
+    "#40DE8A",
+    "#1B9640",
+    "#00D6C1",
+    "#2E9CAA",
+    "#00C4FF",
+    "#364797",
+    "#6675FF",
+    "#0019EF",
+    "#863AFF",
+]
+COLORS = [sv.Color.from_hex(color_hex=c) for c in COLORS]

 MAX_NUM_FRAMES = 300

+keypoint_score = 0.3
+enable_labels_annotator = True
+enable_vertices_annotator = True
+
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 person_detector_name = "PekingU/rtdetr_r50vd_coco_o365"
@@ -30,11 +82,19 @@ pose_model = VitPoseForPoseEstimation.from_pretrained(pose_model_name, device_ma

 @spaces.GPU(duration=5)
 @torch.inference_mode()
-def detect_pose_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
+def detect_pose_image(
+    image: PIL.Image.Image,
+    threshold: float = 0.3,
+    enable_labels_annotator: bool = True,
+    enable_vertices_annotator: bool = True,
+) -> tuple[PIL.Image.Image, list[dict]]:
     """Detects persons and estimates their poses in a single image.

     Args:
         image (PIL.Image.Image): Input image in which to detect persons and estimate poses.
+        threshold (Float): Confidence threshold for pose keypoints.
+        enable_labels_annotator (bool): Whether to enable annotating labels for pose keypoints.
+        enable_vertices_annotator (bool): Whether to enable annotating vertices for pose keypoints

     Returns:
         tuple[PIL.Image.Image, list[dict]]:
@@ -44,20 +104,14 @@ def detect_pose_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
     inputs = person_image_processor(images=image, return_tensors="pt").to(device)
     outputs = person_model(**inputs)
     results = person_image_processor.post_process_object_detection(
-        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=
+        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=threshold
     )
     result = results[0]  # take first image results

-
-
-    person_boxes_xyxy = person_boxes_xyxy.cpu().numpy()
-
-    # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
-    person_boxes = person_boxes_xyxy.copy()
-    person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
-    person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
+    detections = sv.Detections.from_transformers(result)
+    person_detections_xywh = sv.xyxy_to_xywh(detections[detections.class_id == 0].xyxy)

-    inputs = pose_image_processor(image, boxes=[
+    inputs = pose_image_processor(image, boxes=[person_detections_xywh], return_tensors="pt").to(device)

     # for vitpose-plus-base checkpoint we should additionally provide dataset_index
     # to specify which MOE experts to use for inference
@@ -68,11 +122,12 @@ def detect_pose_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:

     outputs = pose_model(**inputs)

-    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[
+    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[person_detections_xywh])
     image_pose_result = pose_results[0]  # results for first image

     # make results more human-readable
     human_readable_results = []
+    person_pose_labels = []
     for i, person_pose in enumerate(image_pose_result):
         data = {
             "person_id": i,
@@ -83,43 +138,55 @@ def detect_pose_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
             person_pose["keypoints"], person_pose["labels"], person_pose["scores"], strict=True
         ):
             keypoint_name = pose_model.config.id2label[label.item()]
+            person_pose_labels.append(keypoint_name)
             x, y = keypoint
             data["keypoints"].append({"name": keypoint_name, "x": x.item(), "y": y.item(), "score": score.item()})
         human_readable_results.append(data)

-
-
-    xy = torch.stack(xy).cpu().numpy()
-
-    scores = [pose_result["scores"] for pose_result in image_pose_result]
-    scores = torch.stack(scores).cpu().numpy()
-
-    keypoints = sv.KeyPoints(xy=xy, confidence=scores)
-    detections = sv.Detections(xyxy=person_boxes_xyxy)
+    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=(image.width, image.height))
+    text_scale = sv.calculate_optimal_text_scale(resolution_wh=(image.width, image.height))

-    edge_annotator = sv.EdgeAnnotator(color=sv.Color.
-    vertex_annotator = sv.VertexAnnotator(color=sv.Color.
-
+    edge_annotator = sv.EdgeAnnotator(color=sv.Color.WHITE, thickness=line_thickness)
+    vertex_annotator = sv.VertexAnnotator(color=sv.Color.BLUE, radius=3)
+    box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=3)

-
+    vertex_label_annotator = sv.VertexLabelAnnotator(
+        color=COLORS, smart_position=True, border_radius=3, text_thickness=2, text_scale=text_scale
+    )

-
-
+    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections)
+
+    for _, person_pose in enumerate(image_pose_result):
+        person_keypoints = sv.KeyPoints.from_transformers([person_pose])
+        person_labels = [pose_model.config.id2label[label.item()] for label in person_pose["labels"]]
+        # annotate edges and vertices for this person
+        annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=person_keypoints)
+        # annotate labels for this person
+        if enable_labels_annotator:
+            annotated_frame = vertex_label_annotator.annotate(
+                scene=np.array(annotated_frame), key_points=person_keypoints, labels=person_labels
+            )
+        # annotate vertices for this person
+        if enable_vertices_annotator:
+            annotated_frame = vertex_annotator.annotate(scene=annotated_frame, key_points=person_keypoints)

-
-    annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=keypoints)
-    return vertex_annotator.annotate(scene=annotated_frame, key_points=keypoints), human_readable_results
+    return annotated_frame, human_readable_results


-@spaces.GPU(duration=90)
 def detect_pose_video(
     video_path: str,
+    threshold: float,
+    enable_labels_annotator: bool = True,
+    enable_vertices_annotator: bool = True,
     progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
 ) -> str:
     """Detects persons and estimates their poses for each frame in a video, saving the annotated video.

     Args:
         video_path (str): Path to the input video file.
+        threshold (Float): Confidence threshold for pose keypoints.
+        enable_labels_annotator (bool): Whether to enable annotating labels for pose keypoints.
+        enable_vertices_annotator (bool): Whether to enable annotating vertices for pose keypoints.
         progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).

     Returns:
@@ -140,7 +207,12 @@ def detect_pose_video(
         if not ok:
             break
         rgb_frame = frame[:, :, ::-1]
-        annotated_frame, _ = detect_pose_image(
+        annotated_frame, _ = detect_pose_image(
+            PIL.Image.fromarray(rgb_frame),
+            threshold=threshold,
+            enable_labels_annotator=enable_labels_annotator,
+            enable_vertices_annotator=enable_vertices_annotator,
+        )
         writer.write(np.asarray(annotated_frame)[:, :, ::-1])
     writer.release()
     cap.release()
@@ -150,6 +222,17 @@
 with gr.Blocks(css_paths="style.css") as demo:
     gr.Markdown(DESCRIPTION)

+    keypoint_score = gr.Slider(
+        minimum=0.0,
+        maximum=1.0,
+        value=0.6,
+        step=0.01,
+        info="Adjust the confidence threshold for keypoint detection.",
+        label="Keypoint Score Threshold",
+    )
+    enable_labels_annotator = gr.Checkbox(interactive=True, value=True, label="Enable Labels")
+    enable_vertices_annotator = gr.Checkbox(interactive=True, value=True, label="Enable Vertices")
+
     with gr.Tabs():
         with gr.Tab("Image"):
             with gr.Row():
@@ -160,15 +243,15 @@ with gr.Blocks(css_paths="style.css") as demo:
                 output_image = gr.Image(label="Output Image")
                 output_json = gr.JSON(label="Output JSON")
             gr.Examples(
-                examples=sorted(pathlib.Path("images").glob("*.jpg")),
-                inputs=input_image,
+                examples=[[str(img), 0.5, True, True] for img in sorted(pathlib.Path("images").glob("*.jpg"))],
+                inputs=[input_image, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                 outputs=[output_image, output_json],
                 fn=detect_pose_image,
             )

             run_button_image.click(
                 fn=detect_pose_image,
-                inputs=input_image,
+                inputs=[input_image, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                 outputs=[output_image, output_json],
             )

@@ -183,15 +266,15 @@ with gr.Blocks(css_paths="style.css") as demo:
                 output_video = gr.Video(label="Output Video")

             gr.Examples(
-                examples=sorted(pathlib.Path("videos").glob("*.mp4")),
-                inputs=input_video,
+                examples=[[str(video), 0.5, True, True] for video in sorted(pathlib.Path("videos").glob("*.mp4"))],
+                inputs=[input_video, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                 outputs=output_video,
                 fn=detect_pose_video,
                 cache_examples=False,
             )
             run_button_video.click(
                 fn=detect_pose_video,
-                inputs=input_video,
+                inputs=[input_video, keypoint_score, enable_labels_annotator, enable_vertices_annotator],
                 outputs=output_video,
             )

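For readers new to the supervision APIs this commit migrates to (sv.Detections.from_transformers, sv.xyxy_to_xywh, sv.KeyPoints.from_transformers, and the Edge/Vertex annotators), below is a minimal, self-contained sketch of the same detection-to-annotation flow outside the Gradio app. It only restates what the new detect_pose_image does; the image paths, the 0.3 threshold, and the usyd-community/vitpose-base-simple pose checkpoint are placeholders chosen for the sketch (app.py pins its own pose_model_name, and vitpose-plus checkpoints additionally require dataset_index), so treat it as an illustration rather than the Space's exact code.

# Sketch (not part of the commit): RT-DETR person detection -> ViTPose keypoints
# -> supervision annotators, mirroring the new detect_pose_image flow in app.py.
import numpy as np
import PIL.Image
import supervision as sv
import torch
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

person_image_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365").to(device)
pose_image_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")  # placeholder checkpoint
pose_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple").to(device)

image = PIL.Image.open("example.jpg")  # placeholder input path

# Person detection, wrapped in a supervision Detections object
# (this replaces the old manual VOC-to-COCO box conversion).
with torch.inference_mode():
    det_inputs = person_image_processor(images=image, return_tensors="pt").to(device)
    det_outputs = person_model(**det_inputs)
det_results = person_image_processor.post_process_object_detection(
    det_outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
)
detections = sv.Detections.from_transformers(det_results[0])

# ViTPose expects COCO-style (x, y, w, h) boxes for the person class (class_id == 0).
person_boxes_xywh = sv.xyxy_to_xywh(detections[detections.class_id == 0].xyxy)

with torch.inference_mode():
    pose_inputs = pose_image_processor(image, boxes=[person_boxes_xywh], return_tensors="pt").to(device)
    pose_outputs = pose_model(**pose_inputs)
pose_results = pose_image_processor.post_process_pose_estimation(pose_outputs, boxes=[person_boxes_xywh])

# Draw each person's skeleton edges and keypoint vertices with supervision annotators.
edge_annotator = sv.EdgeAnnotator(color=sv.Color.WHITE, thickness=2)
vertex_annotator = sv.VertexAnnotator(color=sv.Color.BLUE, radius=3)
annotated = np.array(image)
for person_pose in pose_results[0]:
    person_keypoints = sv.KeyPoints.from_transformers([person_pose])
    annotated = edge_annotator.annotate(scene=annotated, key_points=person_keypoints)
    annotated = vertex_annotator.annotate(scene=annotated, key_points=person_keypoints)
PIL.Image.fromarray(annotated).save("annotated.jpg")  # placeholder output path

The per-person loop mirrors the diff's approach of building one KeyPoints object per detected person instead of stacking all keypoints into a single tensor, which is what lets the label and vertex annotators be toggled independently from the new Gradio checkboxes.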
pyproject.toml
CHANGED
@@ -10,7 +10,7 @@ dependencies = [
     "hf-transfer>=0.1.9",
     "setuptools>=80.9.0",
     "spaces>=0.37.1",
-    "supervision>=0.
+    "supervision>=0.26.0",
     "torch==2.5.1",
     "transformers>=4.53.0",
 ]
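The floor moves to 0.26.0 because the rewritten app leans on newer supervision APIs (for example the KeyPoints.from_transformers and xyxy_to_xywh calls used in app.py above). As an optional sanity check (not part of this commit), something like the following can fail fast when an environment silently resolved an older release:

# Sketch (not part of the commit): fail fast if an older supervision release is installed.
from importlib.metadata import version

installed = version("supervision")
major, minor = (int(part) for part in installed.split(".")[:2])
if (major, minor) < (0, 26):
    raise RuntimeError(f"supervision>=0.26.0 is required, found {installed}")
print(f"supervision {installed} OK")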
requirements.txt
CHANGED
@@ -25,15 +25,11 @@ click==8.1.8
     #   typer
     #   uvicorn
 contourpy==1.3.1
-    # via
-    #   matplotlib
-    #   supervision
+    # via matplotlib
 cycler==0.12.1
     # via matplotlib
 defusedxml==0.7.1
     # via supervision
-exceptiongroup==1.2.2
-    # via anyio
 fastapi==0.115.7
     # via gradio
 ffmpy==0.5.0
@@ -254,7 +250,7 @@ starlette==0.45.3
     #   fastapi
     #   gradio
     #   mcp
-supervision==0.
+supervision==0.26.0
     # via vitpose-transformers (pyproject.toml)
 sympy==1.13.1
     # via torch
@@ -286,12 +282,10 @@ typing-extensions==4.12.2
     #   huggingface-hub
     #   pydantic
     #   pydantic-core
-    #   rich
     #   spaces
     #   torch
     #   typer
     #   typing-inspection
-    #   uvicorn
 typing-inspection==0.4.1
     # via
     #   pydantic