mguven61 committed on
Commit
1dc7b14
·
verified ·
1 Parent(s): db980df

Upload 6 files

Browse files
Files changed (6) hide show
  1. LICENSE +21 -0
  2. app.py +173 -0
  3. pyproject.toml +59 -0
  4. requirements.txt +91 -0
  5. style.css +11 -0
  6. uv.lock +0 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import pathlib
4
+ import tempfile
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import PIL.Image
10
+ import spaces
11
+ import supervision as sv
12
+ import torch
13
+ import tqdm
14
+ from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation
15
+
16
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# ViTPose"

# Upper bound on the number of frames process_video reads from an input video.
MAX_NUM_FRAMES = 300

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stage 1: object detector whose boxes are used to locate people.
person_detector_name = "PekingU/rtdetr_r50vd_coco_o365"
person_image_processor = AutoProcessor.from_pretrained(person_detector_name)
person_model = RTDetrForObjectDetection.from_pretrained(person_detector_name, device_map=device)

# Stage 2: pose estimator that is run on the detected person boxes.
pose_model_name = "usyd-community/vitpose-base-simple"
pose_image_processor = AutoProcessor.from_pretrained(pose_model_name)
pose_model = VitPoseForPoseEstimation.from_pretrained(pose_model_name, device_map=device)
29
+
30
+
31
@spaces.GPU(duration=5)
@torch.inference_mode()
def process_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
    """Detect people in an image, estimate their keypoints, and draw the result.

    The image is first passed through the object detector; the boxes whose
    label is 0 are then handed to the pose model, and boxes, skeleton edges
    and keypoints are drawn on a copy of the input.

    Args:
        image: Input image.

    Returns:
        A tuple ``(annotated_image, results)``. ``results`` holds one dict per
        detected person with the keys ``"person_id"``, ``"bbox"`` and
        ``"keypoints"``; every keypoint is a dict with ``"name"``, ``"x"``,
        ``"y"`` and ``"score"``. When no person is detected, an unannotated
        copy of the input and an empty list are returned.
    """
    # Stage 1: object detection.
    inputs = person_image_processor(images=image, return_tensors="pt").to(device)
    outputs = person_model(**inputs)
    results = person_image_processor.post_process_object_detection(
        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
    )
    result = results[0]

    # Keep only label 0 (presumably the "person" class of the detector - confirm against its id2label).
    person_boxes_xyxy = result["boxes"][result["labels"] == 0]
    person_boxes_xyxy = person_boxes_xyxy.cpu().numpy()

    # Without any person there is nothing to feed the pose model, and the
    # torch.stack calls below would raise on an empty list. This matters for
    # videos, where individual frames frequently contain no person.
    if len(person_boxes_xyxy) == 0:
        return image.copy(), []

    # Convert (x1, y1, x2, y2) boxes to (x, y, width, height) for the pose processor.
    person_boxes = person_boxes_xyxy.copy()
    person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
    person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

    # Stage 2: pose estimation on the detected boxes.
    inputs = pose_image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

    # Checkpoints configured with several experts need a per-sample dataset
    # index; index 0 is used for every box.
    if pose_model.config.backbone_config.num_experts > 1:
        dataset_index = torch.tensor([0] * len(inputs["pixel_values"]))
        dataset_index = dataset_index.to(inputs["pixel_values"].device)
        inputs["dataset_index"] = dataset_index

    outputs = pose_model(**inputs)

    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
    image_pose_result = pose_results[0]

    # Build the JSON-serialisable summary shown next to the image.
    human_readable_results = []
    for i, person_pose in enumerate(image_pose_result):
        data = {
            "person_id": i,
            "bbox": person_pose["bbox"].cpu().numpy().tolist(),
            "keypoints": [],
        }
        for keypoint, label, score in zip(
            person_pose["keypoints"], person_pose["labels"], person_pose["scores"], strict=True
        ):
            keypoint_name = pose_model.config.id2label[label.item()]
            x, y = keypoint
            data["keypoints"].append({"name": keypoint_name, "x": x.item(), "y": y.item(), "score": score.item()})
        human_readable_results.append(data)

    # Stack the per-person tensors into arrays for supervision.
    xy = torch.stack([pose_result["keypoints"] for pose_result in image_pose_result]).cpu().numpy()
    scores = torch.stack([pose_result["scores"] for pose_result in image_pose_result]).cpu().numpy()

    keypoints = sv.KeyPoints(xy=xy, confidence=scores)
    detections = sv.Detections(xyxy=person_boxes_xyxy)

    edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=1)
    vertex_annotator = sv.VertexAnnotator(color=sv.Color.RED, radius=2)
    bounding_box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=1)

    # Draw on a copy so the caller's image is left untouched.
    annotated_frame = bounding_box_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=keypoints)
    annotated_frame = vertex_annotator.annotate(scene=annotated_frame, key_points=keypoints)
    return annotated_frame, human_readable_results
94
+
95
+
96
@spaces.GPU(duration=90)
def process_video(
    video_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:
    """Annotate the leading frames of a video with detected poses.

    At most ``MAX_NUM_FRAMES`` frames are read; each one is passed through
    ``process_image`` and the annotated frame is written to a new ``.mp4``
    file with the input's frame size and frame rate.

    Args:
        video_path: Path of the input video.
        progress: Gradio progress tracker. It is not referenced in the body;
            it is declared with ``track_tqdm=True`` so the tqdm loop below is
            reported in the UI.

    Returns:
        Path of the annotated video. The temporary file is created with
        ``delete=False``, so it is not removed by this function.
    """
    cap = cv2.VideoCapture(video_path)
    writer = None
    try:
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = cap.get(cv2.CAP_PROP_FPS)
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        # Only the path is needed: close our own handle before OpenCV opens
        # the file so that two writers never hold it at the same time.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
            out_path = out_file.name

        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        for _ in tqdm.auto.tqdm(range(min(MAX_NUM_FRAMES, num_frames))):
            ok, frame = cap.read()
            if not ok:
                break
            # Reverse the channel axis: OpenCV frames are BGR, PIL expects RGB.
            rgb_frame = frame[:, :, ::-1]
            annotated_frame, _ = process_image(PIL.Image.fromarray(rgb_frame))
            # Reverse the channels again before handing the frame back to OpenCV.
            writer.write(np.asarray(annotated_frame)[:, :, ::-1])
    finally:
        # Release both handles even when a frame fails, so the output file is
        # finalised and the capture is not leaked.
        if writer is not None:
            writer.release()
        cap.release()
    return out_path
121
+
122
+
123
# Gradio UI: one tab for single images and one for videos. Components are laid
# out in the order they are created inside the nested context managers.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column():
                    # type="pil" makes Gradio hand process_image a PIL image.
                    input_image = gr.Image(label="Input Image", type="pil")
                    run_button_image = gr.Button()
                with gr.Column():
                    output_image = gr.Image(label="Output Image")
                    output_json = gr.JSON(label="Output JSON")
            # Example inputs are every *.jpg found in the local "images" directory.
            gr.Examples(
                examples=sorted(pathlib.Path("images").glob("*.jpg")),
                inputs=input_image,
                outputs=[output_image, output_json],
                fn=process_image,
            )

            run_button_image.click(
                fn=process_image,
                inputs=input_image,
                outputs=[output_image, output_json],
            )

        with gr.Tab("Video"):
            gr.Markdown(f"The input video will be truncated to {MAX_NUM_FRAMES} frames.")

            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    run_button_video = gr.Button()
                with gr.Column():
                    output_video = gr.Video(label="Output Video")

            # Example inputs are every *.mp4 found in the local "videos"
            # directory; example caching is explicitly disabled for videos.
            gr.Examples(
                examples=sorted(pathlib.Path("videos").glob("*.mp4")),
                inputs=input_video,
                outputs=output_video,
                fn=process_video,
                cache_examples=False,
            )
            run_button_video.click(
                fn=process_video,
                inputs=input_video,
                outputs=output_video,
            )
170
+
171
+
172
# Start the app only when the file is executed as a script; the request queue
# is capped at 20 pending jobs.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()
pyproject.toml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "pose-detect"
3
+ version = "0.1.0"
4
+ description = ""
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "accelerate>=1.3.0",
9
+ "gradio>=5.13.2",
10
+ "hf-transfer>=0.1.9",
11
+ "opencv-python-headless>=4.11.0.86",
12
+ "setuptools>=75.8.0",
13
+ "spaces>=0.32.0",
14
+ "supervision>=0.25.1",
15
+ "torch==2.4.0",
16
+ "transformers>=4.48.1",
17
+ ]
18
+
19
+ [tool.ruff]
20
+ line-length = 119
21
+
22
+ [tool.ruff.lint]
23
+ select = ["ALL"]
24
+ ignore = [
25
+ "COM812", # missing-trailing-comma
26
+ "D203", # one-blank-line-before-class
27
+ "D213", # multi-line-summary-second-line
28
+ "E501", # line-too-long
29
+ "SIM117", # multiple-with-statements
30
+ ]
31
+ extend-ignore = [
32
+ "D100", # undocumented-public-module
33
+ "D101", # undocumented-public-class
34
+ "D102", # undocumented-public-method
35
+ "D103", # undocumented-public-function
36
+ "D104", # undocumented-public-package
37
+ "D105", # undocumented-magic-method
38
+ "D107", # undocumented-public-init
39
+ "EM101", # raw-string-in-exception
40
+ "FBT001", # boolean-type-hint-positional-argument
41
+ "FBT002", # boolean-default-value-positional-argument
42
+ "PD901", # pandas-df-variable-name
43
+ "PGH003", # blanket-type-ignore
44
+ "PLR0913", # too-many-arguments
45
+ "PLR0915", # too-many-statements
46
+ "TRY003", # raise-vanilla-args
47
+ ]
48
+ unfixable = [
49
+ "F401", # unused-import
50
+ ]
51
+
52
+ [tool.ruff.lint.pydocstyle]
53
+ convention = "google"
54
+
55
+ [tool.ruff.lint.per-file-ignores]
56
+ "*.ipynb" = ["T201"]
57
+
58
+ [tool.ruff.format]
59
+ docstring-code-format = true
requirements.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==1.3.0
4
+ aiofiles==23.2.1
5
+ annotated-types==0.7.0
6
+ anyio==4.8.0
7
+ certifi==2024.12.14
8
+ charset-normalizer==3.4.1
9
+ click==8.1.8
10
+ contourpy==1.3.1
11
+ cycler==0.12.1
12
+ defusedxml==0.7.1
13
+ exceptiongroup==1.2.2
14
+ fastapi==0.115.7
15
+ ffmpy==0.5.0
16
+ filelock==3.17.0
17
+ fonttools==4.55.7
18
+ fsspec==2024.12.0
19
+ gradio==5.13.2
20
+ gradio-client==1.6.0
21
+ h11==0.14.0
22
+ hf-transfer==0.1.9
23
+ httpcore==1.0.7
24
+ httpx==0.28.1
25
+ huggingface-hub==0.28.0
26
+ idna==3.10
27
+ jinja2==3.1.5
28
+ kiwisolver==1.4.8
29
+ markdown-it-py==3.0.0
30
+ markupsafe==2.1.5
31
+ matplotlib==3.10.0
32
+ mdurl==0.1.2
33
+ mpmath==1.3.0
34
+ networkx==3.4.2
35
+ numpy==2.2.2
36
+ nvidia-cublas-cu12==12.1.3.1
37
+ nvidia-cuda-cupti-cu12==12.1.105
38
+ nvidia-cuda-nvrtc-cu12==12.1.105
39
+ nvidia-cuda-runtime-cu12==12.1.105
40
+ nvidia-cudnn-cu12==9.1.0.70
41
+ nvidia-cufft-cu12==11.0.2.54
42
+ nvidia-curand-cu12==10.3.2.106
43
+ nvidia-cusolver-cu12==11.4.5.107
44
+ nvidia-cusparse-cu12==12.1.0.106
45
+ nvidia-nccl-cu12==2.20.5
46
+ nvidia-nvjitlink-cu12==12.8.61
47
+ nvidia-nvtx-cu12==12.1.105
48
+ opencv-python==4.11.0.86
49
+ opencv-python-headless==4.11.0.86
50
+ orjson==3.10.15
51
+ packaging==24.2
52
+ pandas==2.2.3
53
+ pillow==11.1.0
54
+ psutil==5.9.8
55
+ pydantic==2.10.6
56
+ pydantic-core==2.27.2
57
+ pydub==0.25.1
58
+ pygments==2.19.1
59
+ pyparsing==3.2.1
60
+ python-dateutil==2.9.0.post0
61
+ python-multipart==0.0.20
62
+ pytz==2024.2
63
+ pyyaml==6.0.2
64
+ regex==2024.11.6
65
+ requests==2.32.3
66
+ rich==13.9.4
67
+ ruff==0.9.3
68
+ safehttpx==0.1.6
69
+ safetensors==0.5.2
70
+ scipy==1.15.1
71
+ semantic-version==2.10.0
72
+ setuptools==75.8.0
73
+ shellingham==1.5.4
74
+ six==1.17.0
75
+ sniffio==1.3.1
76
+ spaces==0.32.0
77
+ starlette==0.45.3
78
+ supervision==0.25.1
79
+ sympy==1.13.3
80
+ tokenizers==0.21.0
81
+ tomlkit==0.13.2
82
+ torch==2.4.0
83
+ tqdm==4.67.1
84
+ transformers==4.48.1
85
+ triton==3.0.0
86
+ typer==0.15.1
87
+ typing-extensions==4.12.2
88
+ tzdata==2025.1
89
+ urllib3==2.3.0
90
+ uvicorn==0.34.0
91
+ websockets==14.2
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
5
+
6
+ #duplicate-button {
7
+ margin: auto;
8
+ color: #fff;
9
+ background: #1565c0;
10
+ border-radius: 100vh;
11
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff