Add a progress bar and an option to draw human keypoints
Browse files
- app.py +14 -16
- tools/inferencer.py +7 -1
- tools/visualizer.py +2 -2
app.py
CHANGED
|
@@ -3,7 +3,6 @@ from tools.inferencer import PoseInferencerV2
|
|
| 3 |
from tools.dtw import DTWForKeypoints
|
| 4 |
from tools.visualizer import FastVisualizer
|
| 5 |
from tools.utils import convert_video_to_playable_mp4
|
| 6 |
-
from argparse import ArgumentParser
|
| 7 |
from pathlib import Path
|
| 8 |
from tqdm import tqdm
|
| 9 |
import mmengine
|
|
@@ -12,13 +11,6 @@ import mmcv
|
|
| 12 |
import cv2
|
| 13 |
import gradio as gr
|
| 14 |
|
| 15 |
-
def parse_args():
|
| 16 |
-
parser = ArgumentParser()
|
| 17 |
-
parser.add_argument('--config', type=str, default='configs/mark2.py')
|
| 18 |
-
parser.add_argument('--video1', type=str, default='assets/tennis1.mp4')
|
| 19 |
-
parser.add_argument('--video2', type=str, default='assets/tennis2.mp4')
|
| 20 |
-
return parser.parse_args()
|
| 21 |
-
|
| 22 |
def concat(img1, img2, height=1080):
|
| 23 |
w1, h1, _ = img1.shape
|
| 24 |
w2, h2, _ = img2.shape
|
|
@@ -35,15 +27,19 @@ def concat(img1, img2, height=1080):
|
|
| 35 |
image = cv2.hconcat([img1, img2])
|
| 36 |
return image
|
| 37 |
|
| 38 |
-
def draw(vis: FastVisualizer, img, keypoint, box, oks, oks_unnorm,
|
|
|
|
|
|
|
| 39 |
vis.set_image(img)
|
| 40 |
vis.draw_non_transparent_area(box)
|
| 41 |
if draw_score_bar:
|
| 42 |
vis.draw_score_bar(oks)
|
| 43 |
-
|
|
|
|
| 44 |
return vis.get_image()
|
| 45 |
|
| 46 |
-
def main(video1, video2
|
|
|
|
| 47 |
# build PoseInferencerV2
|
| 48 |
config = 'configs/mark2.py'
|
| 49 |
cfg = mmengine.Config.fromfile(config)
|
|
@@ -68,14 +64,14 @@ def main(video1, video2):
|
|
| 68 |
|
| 69 |
vis = FastVisualizer()
|
| 70 |
|
| 71 |
-
for i, j in tqdm(dtw_path):
|
| 72 |
frame1 = v1[i]
|
| 73 |
frame2 = v2[j]
|
| 74 |
|
| 75 |
frame1_ = draw(vis, frame1.copy(), keypoints1[i], boxes1[i],
|
| 76 |
-
oks[i, j], oks_unnorm[i, j])
|
| 77 |
frame2_ = draw(vis, frame2.copy(), keypoints2[j], boxes2[j],
|
| 78 |
-
oks[i, j], oks_unnorm[i, j], draw_score_bar=False)
|
| 79 |
# concatenate the two frames
|
| 80 |
frame = concat(frame1_, frame2_)
|
| 81 |
# draw logo
|
|
@@ -100,10 +96,12 @@ if __name__ == '__main__':
|
|
| 100 |
|
| 101 |
inputs = [
|
| 102 |
gr.Video(label="Input video 1"),
|
| 103 |
-
gr.Video(label="Input video 2")
|
|
|
|
| 104 |
]
|
| 105 |
|
| 106 |
output = gr.Video(label="Output video")
|
| 107 |
|
| 108 |
-
demo = gr.Interface(fn=main, inputs=inputs, outputs=output
|
|
|
|
| 109 |
demo.launch()
|
|
|
|
| 3 |
from tools.dtw import DTWForKeypoints
|
| 4 |
from tools.visualizer import FastVisualizer
|
| 5 |
from tools.utils import convert_video_to_playable_mp4
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
from tqdm import tqdm
|
| 8 |
import mmengine
|
|
|
|
| 11 |
import cv2
|
| 12 |
import gradio as gr
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def concat(img1, img2, height=1080):
|
| 15 |
w1, h1, _ = img1.shape
|
| 16 |
w2, h2, _ = img2.shape
|
|
|
|
| 27 |
image = cv2.hconcat([img1, img2])
|
| 28 |
return image
|
| 29 |
|
| 30 |
+
def draw(vis: FastVisualizer, img, keypoint, box, oks, oks_unnorm,
|
| 31 |
+
draw_human_keypoints=True,
|
| 32 |
+
draw_score_bar=True):
|
| 33 |
vis.set_image(img)
|
| 34 |
vis.draw_non_transparent_area(box)
|
| 35 |
if draw_score_bar:
|
| 36 |
vis.draw_score_bar(oks)
|
| 37 |
+
if draw_human_keypoints:
|
| 38 |
+
vis.draw_human_keypoints(keypoint, oks_unnorm)
|
| 39 |
return vis.get_image()
|
| 40 |
|
| 41 |
+
def main(video1, video2, draw_human_keypoints,
|
| 42 |
+
progress=gr.Progress(track_tqdm=True)):
|
| 43 |
# build PoseInferencerV2
|
| 44 |
config = 'configs/mark2.py'
|
| 45 |
cfg = mmengine.Config.fromfile(config)
|
|
|
|
| 64 |
|
| 65 |
vis = FastVisualizer()
|
| 66 |
|
| 67 |
+
for i, j in tqdm(dtw_path, desc='Visualizing'):
|
| 68 |
frame1 = v1[i]
|
| 69 |
frame2 = v2[j]
|
| 70 |
|
| 71 |
frame1_ = draw(vis, frame1.copy(), keypoints1[i], boxes1[i],
|
| 72 |
+
oks[i, j], oks_unnorm[i, j], draw_human_keypoints)
|
| 73 |
frame2_ = draw(vis, frame2.copy(), keypoints2[j], boxes2[j],
|
| 74 |
+
oks[i, j], oks_unnorm[i, j], draw_human_keypoints, draw_score_bar=False)
|
| 75 |
# concatenate the two frames
|
| 76 |
frame = concat(frame1_, frame2_)
|
| 77 |
# draw logo
|
|
|
|
| 96 |
|
| 97 |
inputs = [
|
| 98 |
gr.Video(label="Input video 1"),
|
| 99 |
+
gr.Video(label="Input video 2"),
|
| 100 |
+
"checkbox"
|
| 101 |
]
|
| 102 |
|
| 103 |
output = gr.Video(label="Output video")
|
| 104 |
|
| 105 |
+
demo = gr.Interface(fn=main, inputs=inputs, outputs=output,
|
| 106 |
+
allow_flagging='never').queue()
|
| 107 |
demo.launch()
|
tools/inferencer.py
CHANGED
|
@@ -29,6 +29,8 @@ class PoseInferencer:
|
|
| 29 |
self.pose_model = init_model(self.pose_model_cfg,
|
| 30 |
self.pose_model_ckpt,
|
| 31 |
device=device)
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def process_one_image(self, img):
|
| 34 |
init_default_scope('mmdet')
|
|
@@ -101,6 +103,8 @@ class PoseInferencerV2:
|
|
| 101 |
self.pose_model = init_model(self.pose_model_cfg,
|
| 102 |
self.pose_model_ckpt,
|
| 103 |
device)
|
|
|
|
|
|
|
| 104 |
|
| 105 |
def process_one_image(self, img):
|
| 106 |
init_default_scope('mmdet')
|
|
@@ -145,10 +149,12 @@ class PoseInferencerV2:
|
|
| 145 |
video_reader = mmcv.VideoReader(video_path)
|
| 146 |
all_pose, all_det = [], []
|
| 147 |
|
| 148 |
-
|
|
|
|
| 149 |
# inference with detector
|
| 150 |
det, pose = self.process_one_image(frame)
|
| 151 |
all_pose.append(pose)
|
| 152 |
all_det.append(det)
|
|
|
|
| 153 |
|
| 154 |
return all_det, all_pose
|
|
|
|
| 29 |
self.pose_model = init_model(self.pose_model_cfg,
|
| 30 |
self.pose_model_ckpt,
|
| 31 |
device=device)
|
| 32 |
+
# number of videos processed so far; used to label the progress bar
|
| 33 |
+
self.video_count = 0
|
| 34 |
|
| 35 |
def process_one_image(self, img):
|
| 36 |
init_default_scope('mmdet')
|
|
|
|
| 103 |
self.pose_model = init_model(self.pose_model_cfg,
|
| 104 |
self.pose_model_ckpt,
|
| 105 |
device)
|
| 106 |
+
# number of videos processed so far; used to label the progress bar
|
| 107 |
+
self.video_count = 0
|
| 108 |
|
| 109 |
def process_one_image(self, img):
|
| 110 |
init_default_scope('mmdet')
|
|
|
|
| 149 |
video_reader = mmcv.VideoReader(video_path)
|
| 150 |
all_pose, all_det = [], []
|
| 151 |
|
| 152 |
+
count = self.video_count + 1
|
| 153 |
+
for frame in tqdm(video_reader, desc=f'Inference video {count}'):
|
| 154 |
# inference with detector
|
| 155 |
det, pose = self.process_one_image(frame)
|
| 156 |
all_pose.append(pose)
|
| 157 |
all_det.append(det)
|
| 158 |
+
self.video_count += 1
|
| 159 |
|
| 160 |
return all_det, all_pose
|
tools/visualizer.py
CHANGED
|
@@ -157,8 +157,8 @@ class FastVisualizer:
|
|
| 157 |
else: lvl_names = self.score_level_names(scores)
|
| 158 |
|
| 159 |
for idx, (point, lvl_name) in enumerate(zip(keypoints, lvl_names)):
|
| 160 |
-
if idx in set((1, 2, 3, 4)):
|
| 161 |
-
continue # do not draw
|
| 162 |
rectangle_xyhw = np.array((point[0], point[1], cube_size, cube_size))
|
| 163 |
rectangle_xyxy = self.xyhw_to_xyxy(rectangle_xyhw)
|
| 164 |
self.draw_rectangle(rectangle_xyxy,
|
|
|
|
| 157 |
else: lvl_names = self.score_level_names(scores)
|
| 158 |
|
| 159 |
for idx, (point, lvl_name) in enumerate(zip(keypoints, lvl_names)):
|
| 160 |
+
if idx in set((0, 1, 2, 3, 4)):
|
| 161 |
+
continue # do not draw head
|
| 162 |
rectangle_xyhw = np.array((point[0], point[1], cube_size, cube_size))
|
| 163 |
rectangle_xyxy = self.xyhw_to_xyxy(rectangle_xyhw)
|
| 164 |
self.draw_rectangle(rectangle_xyxy,
|