DeclK committed on
Commit
a4b0fcb
·
1 Parent(s): 841e1c5

add progress bar and draw human keypoints option

Browse files
Files changed (3) hide show
  1. app.py +14 -16
  2. tools/inferencer.py +7 -1
  3. tools/visualizer.py +2 -2
app.py CHANGED
@@ -3,7 +3,6 @@ from tools.inferencer import PoseInferencerV2
3
  from tools.dtw import DTWForKeypoints
4
  from tools.visualizer import FastVisualizer
5
  from tools.utils import convert_video_to_playable_mp4
6
- from argparse import ArgumentParser
7
  from pathlib import Path
8
  from tqdm import tqdm
9
  import mmengine
@@ -12,13 +11,6 @@ import mmcv
12
  import cv2
13
  import gradio as gr
14
 
15
- def parse_args():
16
- parser = ArgumentParser()
17
- parser.add_argument('--config', type=str, default='configs/mark2.py')
18
- parser.add_argument('--video1', type=str, default='assets/tennis1.mp4')
19
- parser.add_argument('--video2', type=str, default='assets/tennis2.mp4')
20
- return parser.parse_args()
21
-
22
  def concat(img1, img2, height=1080):
23
  w1, h1, _ = img1.shape
24
  w2, h2, _ = img2.shape
@@ -35,15 +27,19 @@ def concat(img1, img2, height=1080):
35
  image = cv2.hconcat([img1, img2])
36
  return image
37
 
38
- def draw(vis: FastVisualizer, img, keypoint, box, oks, oks_unnorm, draw_score_bar=True):
 
 
39
  vis.set_image(img)
40
  vis.draw_non_transparent_area(box)
41
  if draw_score_bar:
42
  vis.draw_score_bar(oks)
43
- vis.draw_human_keypoints(keypoint, oks_unnorm)
 
44
  return vis.get_image()
45
 
46
- def main(video1, video2):
 
47
  # build PoseInferencerV2
48
  config = 'configs/mark2.py'
49
  cfg = mmengine.Config.fromfile(config)
@@ -68,14 +64,14 @@ def main(video1, video2):
68
 
69
  vis = FastVisualizer()
70
 
71
- for i, j in tqdm(dtw_path):
72
  frame1 = v1[i]
73
  frame2 = v2[j]
74
 
75
  frame1_ = draw(vis, frame1.copy(), keypoints1[i], boxes1[i],
76
- oks[i, j], oks_unnorm[i, j])
77
  frame2_ = draw(vis, frame2.copy(), keypoints2[j], boxes2[j],
78
- oks[i, j], oks_unnorm[i, j], draw_score_bar=False)
79
  # concate two frames
80
  frame = concat(frame1_, frame2_)
81
  # draw logo
@@ -100,10 +96,12 @@ if __name__ == '__main__':
100
 
101
  inputs = [
102
  gr.Video(label="Input video 1"),
103
- gr.Video(label="Input video 2")
 
104
  ]
105
 
106
  output = gr.Video(label="Output video")
107
 
108
- demo = gr.Interface(fn=main, inputs=inputs, outputs=output).queue()
 
109
  demo.launch()
 
3
  from tools.dtw import DTWForKeypoints
4
  from tools.visualizer import FastVisualizer
5
  from tools.utils import convert_video_to_playable_mp4
 
6
  from pathlib import Path
7
  from tqdm import tqdm
8
  import mmengine
 
11
  import cv2
12
  import gradio as gr
13
 
 
 
 
 
 
 
 
14
  def concat(img1, img2, height=1080):
15
  w1, h1, _ = img1.shape
16
  w2, h2, _ = img2.shape
 
27
  image = cv2.hconcat([img1, img2])
28
  return image
29
 
30
+ def draw(vis: FastVisualizer, img, keypoint, box, oks, oks_unnorm,
31
+ draw_human_keypoints=True,
32
+ draw_score_bar=True):
33
  vis.set_image(img)
34
  vis.draw_non_transparent_area(box)
35
  if draw_score_bar:
36
  vis.draw_score_bar(oks)
37
+ if draw_human_keypoints:
38
+ vis.draw_human_keypoints(keypoint, oks_unnorm)
39
  return vis.get_image()
40
 
41
+ def main(video1, video2, draw_human_keypoints,
42
+ progress=gr.Progress(track_tqdm=True)):
43
  # build PoseInferencerV2
44
  config = 'configs/mark2.py'
45
  cfg = mmengine.Config.fromfile(config)
 
64
 
65
  vis = FastVisualizer()
66
 
67
+ for i, j in tqdm(dtw_path, desc='Visualizing'):
68
  frame1 = v1[i]
69
  frame2 = v2[j]
70
 
71
  frame1_ = draw(vis, frame1.copy(), keypoints1[i], boxes1[i],
72
+ oks[i, j], oks_unnorm[i, j], draw_human_keypoints)
73
  frame2_ = draw(vis, frame2.copy(), keypoints2[j], boxes2[j],
74
+ oks[i, j], oks_unnorm[i, j], draw_human_keypoints, draw_score_bar=False)
75
  # concate two frames
76
  frame = concat(frame1_, frame2_)
77
  # draw logo
 
96
 
97
  inputs = [
98
  gr.Video(label="Input video 1"),
99
+ gr.Video(label="Input video 2"),
100
+ "checkbox"
101
  ]
102
 
103
  output = gr.Video(label="Output video")
104
 
105
+ demo = gr.Interface(fn=main, inputs=inputs, outputs=output,
106
+ allow_flagging='never').queue()
107
  demo.launch()
tools/inferencer.py CHANGED
@@ -29,6 +29,8 @@ class PoseInferencer:
29
  self.pose_model = init_model(self.pose_model_cfg,
30
  self.pose_model_ckpt,
31
  device=device)
 
 
32
 
33
  def process_one_image(self, img):
34
  init_default_scope('mmdet')
@@ -101,6 +103,8 @@ class PoseInferencerV2:
101
  self.pose_model = init_model(self.pose_model_cfg,
102
  self.pose_model_ckpt,
103
  device)
 
 
104
 
105
  def process_one_image(self, img):
106
  init_default_scope('mmdet')
@@ -145,10 +149,12 @@ class PoseInferencerV2:
145
  video_reader = mmcv.VideoReader(video_path)
146
  all_pose, all_det = [], []
147
 
148
- for frame in tqdm(video_reader):
 
149
  # inference with detector
150
  det, pose = self.process_one_image(frame)
151
  all_pose.append(pose)
152
  all_det.append(det)
 
153
 
154
  return all_det, all_pose
 
29
  self.pose_model = init_model(self.pose_model_cfg,
30
  self.pose_model_ckpt,
31
  device=device)
32
+ # use this count to tell the progress
33
+ self.video_count = 0
34
 
35
  def process_one_image(self, img):
36
  init_default_scope('mmdet')
 
103
  self.pose_model = init_model(self.pose_model_cfg,
104
  self.pose_model_ckpt,
105
  device)
106
+ # use this count to tell the progress
107
+ self.video_count = 0
108
 
109
  def process_one_image(self, img):
110
  init_default_scope('mmdet')
 
149
  video_reader = mmcv.VideoReader(video_path)
150
  all_pose, all_det = [], []
151
 
152
+ count = self.video_count + 1
153
+ for frame in tqdm(video_reader, desc=f'Inference video {count}'):
154
  # inference with detector
155
  det, pose = self.process_one_image(frame)
156
  all_pose.append(pose)
157
  all_det.append(det)
158
+ self.video_count += 1
159
 
160
  return all_det, all_pose
tools/visualizer.py CHANGED
@@ -157,8 +157,8 @@ class FastVisualizer:
157
  else: lvl_names = self.score_level_names(scores)
158
 
159
  for idx, (point, lvl_name) in enumerate(zip(keypoints, lvl_names)):
160
- if idx in set((1, 2, 3, 4)):
161
- continue # do not draw eyes and years
162
  rectangle_xyhw = np.array((point[0], point[1], cube_size, cube_size))
163
  rectangle_xyxy = self.xyhw_to_xyxy(rectangle_xyhw)
164
  self.draw_rectangle(rectangle_xyxy,
 
157
  else: lvl_names = self.score_level_names(scores)
158
 
159
  for idx, (point, lvl_name) in enumerate(zip(keypoints, lvl_names)):
160
+ if idx in set((0, 1, 2, 3, 4)):
161
+ continue # do not draw head
162
  rectangle_xyhw = np.array((point[0], point[1], cube_size, cube_size))
163
  rectangle_xyxy = self.xyhw_to_xyxy(rectangle_xyhw)
164
  self.draw_rectangle(rectangle_xyxy,